├── .gitignore
├── LICENSE
├── README.md
└── guided_diffusion
    ├── configs
    │   ├── base_model.yaml
    │   ├── cifar10_model.yaml
    │   └── clip_laion.yaml
    ├── denoiser.py
    ├── diffuser.py
    ├── img_load_utils.py
    ├── trainer.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Text to Image Model in Keras
2 |
3 |
4 | #### NEW: Brand new repo using PyTorch to train latent diffusion models with transformers. See [here](https://github.com/apapiu/transformer_latent_diffusion/tree/main). The new model is better, faster, and works at 4x the resolution.
5 |
6 | #### [Kaggle notebook](https://www.kaggle.com/code/apapiu/train-latent-diffusion-in-keras-from-scratch) that trains a 128x128 Latent Diffusion model on the Kaggle kernel hardware (P100 GPU). This should be similar in spirit to the code behind Stable Diffusion.
7 |
8 | Codebase to train a CLIP conditioned Text to Image Diffusion model on Colab in Keras. See below for notebooks and examples with prompts.
9 |
10 |
11 | Images generated for the prompt: `A small village in the Alps, spring, sunset`
12 |
13 |
14 |
15 | Images generated for the prompt: `Portrait of a young woman with curly red hair, photograph`
16 |
17 |
18 |
19 |
20 | (more examples below - try with your own inputs in Colab here: [](https://colab.research.google.com/drive/123iljowP_b5o-_6RjHZK8vbgBkrNn8-c?usp=sharing) )
21 |
22 |
23 | ## Table of Contents:
24 | - [Intro](#intro)
25 | - [Notebooks](#notebooks)
26 | - [Usage](#usage)
27 | - [Model Setup](#model-setup)
28 | - [Data](#data)
29 | - [Examples](#examples)
30 |
31 | ## Intro
32 |
33 | The goal of this repo is to provide a simple, self-contained codebase for Text to Image Diffusion that can be trained in Colab in a
34 | reasonable amount of time.
35 |
36 | While there are a lot of great resources around the math and usage of diffusion models, I haven't found many specifically
37 | focused on _training_ text-to-img diffusion models.
38 | In particular, the idea of training a DALL-E 2 or Stable Diffusion-like model feels like a daunting task requiring immense
39 | computational resources and data. It turns out you can accomplish quite a lot with modest resources and without having a PhD in thermodynamics!
40 | The easiest way to get acquainted with the code is through the notebooks below.
41 |
42 | #### Credits/Resources
43 |
44 | - Original Unet implementation in this [excellent blog post](https://keras.io/examples/generative/ddim/) - most of the code and Unet architecture in `denoiser.py` is based on this. I have added
45 | additional text/CLIP/masking embeddings/inputs and cross/self attention.
46 | - [Conditional CIFAR Model in Pytorch](https://colab.research.google.com/drive/1IJkrrV-D7boSCLVKhi7t5docRYqORtm3#scrollTo=TAUwPLG92r89)
47 | - [Laion Aesthetics 6.5+ Dataset](https://laion.ai/blog/laion-aesthetics/) - The 625K image-text pairs with predicted aesthetics scores of 6.5 or higher were used for training.
48 | - [Text 2 img package](https://github.com/hmiladhia/img2text)
49 | - [Variational Diffusion Models (Paper)](https://arxiv.org/abs/2107.00630)
50 | - [DDIM Paper](https://arxiv.org/abs/2010.02502)
51 |
52 | ## Notebooks
53 |
54 | If you are just starting out I recommend trying out the next two notebooks in order. The first should be able to get you
55 | recognizable images on the Fashion MNIST dataset within minutes!
56 |
57 | - Train Class Conditional Fashion MNIST/CIFAR [](https://colab.research.google.com/drive/16rJUyPn72-C30mZRUr-Oo6ZjYS89Z3yH?usp=sharing)
58 |    - To try out the code you can use the notebook above in Colab. It is set to train on the Fashion MNIST dataset.
59 |  You should be able to see reasonable image generations within 5 epochs (5-20 minutes depending on GPU)!
60 | - For CIFAR 10/100 - you just have to change the `file_name`. You can get reasonable results after 25 epochs for CIFAR 10 and 40 epochs for CIFAR 100.
61 | Training 50-100 epochs is even better.
62 |
63 | - Train CLIP Conditioned Text to Img Model on 115k 64x64 images+prompts sampled from the Laion Aesthetics 6.5+ dataset. [](https://colab.research.google.com/drive/1EoGdyZTGVeOrEnieWyzjItusBSes_1Ef?usp=sharing)
64 |   - You can get recognizable results after ~15 epochs,
65 |  at ~10 minutes per epoch (V100)
66 | - Test prompts on a model trained for about 60 epochs (~60 hours on 1 V100) on the entire 600k Laion Aesthetics 6.5+ dataset. [](https://colab.research.google.com/drive/123iljowP_b5o-_6RjHZK8vbgBkrNn8-c?usp=sharing)
67 |    - This model has about 40 million parameters (150MB) and can be downloaded from [here](https://huggingface.co/apapiu/diffusion_model_aesthetic_keras/blob/main/model_64_65_epochs.h5) (see the loading sketch below)
68 | - The examples in this repo use this model
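A minimal sketch of downloading and loading this pretrained denoiser (the file name comes from the link above; loading with `keras.models.load_model` mirrors what `trainer.py` does):

```python
from huggingface_hub import hf_hub_download
from tensorflow import keras

# download the ~150MB checkpoint from the huggingface repo linked above:
model_path = hf_hub_download("apapiu/diffusion_model_aesthetic_keras",
                             "model_64_65_epochs.h5")
denoiser = keras.models.load_model(model_path)
```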
69 |
70 | ## Usage
71 |
72 | The model architecture, training parameters, and generation parameters are specified in a yaml file - see [here](https://github.com/apapiu/guided-diffusion-keras/tree/main/guided_diffusion/configs) for examples. If unsure, you can use `base_model.yaml`. The `get_train_data` function is built to work with various well-known datasets. If you have
73 | your own dataset you can just use that instead. `train_label_embeddings` is expected to be a matrix of embeddings the model conditions on (usually some embedding of the text, but it could be anything).
74 |
75 |
76 | ```python
77 | config_path = "guided-diffusion-keras/guided_diffusion/configs/base_model.yaml"
78 |
79 | trainer = Trainer(config_path)
80 | print(trainer.__dict__)
81 |
82 | train_data, train_label_embeddings = get_train_data(trainer.file_name) #OR get your own images and label embeddings in matrix form.
83 | trainer.preprocess_data(train_data, train_label_embeddings)
84 | trainer.initialize_model()
85 | trainer.data_checks(train_data)
86 | trainer.train()
87 | ```
88 |
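Once training has run (or a pretrained denoiser is loaded), images can be sampled with the `Diffuser` class. A minimal sketch, reusing the trainer attributes from the snippet above (for the CLIP-conditioned model the labels are precomputed CLIP embeddings; for the class-conditional configs they are integer labels):

```python
import numpy as np
from diffuser import Diffuser

diffuser = Diffuser(trainer.autoencoder, class_guidance=trainer.class_guidance,
                    diffusion_steps=35)

num_imgs = 16
seeds = np.random.normal(0, 1, (num_imgs, trainer.image_size,
                                trainer.image_size, trainer.num_channels))
labels = train_label_embeddings[:num_imgs]  # one conditioning vector per generated image

imgs = diffuser.reverse_diffusion(seeds, labels)  # generated images, roughly in [-1, 1]
```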
89 | ## Model Setup
90 |
91 | The setup is fairly simple.
92 |
93 | We train a denoising U-NET neural net that takes the following three inputs:
94 | - `noise_level` (sampled from 0 to 1 with more values concentrated close to 0)
95 | - image (x) corrupted with a level of random noise
96 | - for a given `noise_level` between 0 and 1 the corruption is as follows:
97 |       - `x_noisy = x*(1-noise_level) + eps*noise_level`, where `eps ~ np.random.normal`
98 | - CLIP embeddings of a text prompt
99 | - You can think of this as a numerical representation of a text prompt (this is the only pretrained model we use).
100 |
101 | The output is a prediction of the denoised image - call it `f(x_noisy)`.
102 |
103 | The model is trained to minimize the absolute error `|f(x_noisy) - x|` between the prediction and actual image
104 | (you can also use squared error here). Note that I don't reparametrize the loss in terms of the noise here to keep things simple.
105 |
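Putting the corruption and the training target together, a minimal numpy sketch (the actual noise-level sampling lives in `add_noise` in `utils.py`, which concentrates values near 0):

```python
import numpy as np

def corrupt(x, noise_level):
    """Mix a batch of images x (scaled to [-1, 1]) with pure Gaussian noise."""
    eps = np.random.normal(0, 1, size=x.shape).astype("float32")
    level = noise_level[:, None, None, None]   # broadcast over height/width/channels
    return x * (1 - level) + eps * level       # x_noisy

# the denoiser f is trained to minimize the mean absolute error |f(x_noisy) - x|
```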
106 | Using this model we then iteratively generate an image from random noise as follows:
107 |
108 |     for i in range(len(self.noise_levels) - 1):
109 | 
110 |         curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]
111 | 
112 |         # predict original denoised image:
113 |         x0_pred = self.predict_x_zero(new_img, label, curr_noise)
114 | 
115 |         # new image at next_noise level is a weighted average of old image and predicted x0:
116 |         new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
117 |
118 | The `predict_x_zero` method uses classifier free guidance by combining the conditional and unconditional
119 | prediction: `x0_pred = class_guidance * x0_pred_conditional + (1 - class_guidance) * x0_pred_unconditional`
120 |
121 | A bit of math: The approach above falls within the VDM parametrization - see section 3.1 in [Kingma et al.](https://arxiv.org/pdf/2107.00630.pdf):
122 |
123 | $$ z_t = \alpha_t x + \sigma_t \epsilon, \quad \epsilon \sim \mathcal{N}(0,1) $$
124 |
125 | Where $z_t$ is the noisy version of $x$ at time $t$.
126 |
127 | Generally $\alpha_t$ is chosen to be $\sqrt{1-\sigma_t^2}$ so that the process is variance preserving. Here I chose $\alpha_t=1-\sigma_t$ so that we
128 | linearly interpolate between the image and random noise. Why? Honestly, I just wondered if it was going to work :) It also simplifies the update equation quite a bit, and it's easier to understand what the noise-to-signal ratio will look like. The update equation above is the DDIM update for this parametrization, which simplifies to a simple weighted average (see the derivation below). Note that the DDIM model deterministically maps random normal noise to images - this has two benefits: we can interpolate in the random normal latent space, and it generally takes fewer steps to get decent image quality.
129 |
130 |
131 |
132 | Note that I make a lot of unorthodox choices in the modelling. Since I am fairly new to generative models I found this to be a
133 | great way to learn what is crucial vs. what is nice to have. I generally did not see any divergences in training, which supports
134 | the notion that **diffusion models are stable to train and are fairly robust to model choices**. The flip side of
135 | this is that if you introduce subtle bugs in your code (of which I am sure there are many in this repo) they are pretty hard
136 | to spot.
137 |
138 |
139 | Architecture: TODO - add cross-attention description.
140 |
141 |
142 | ## Data
143 |
144 | The text-to-img models use the [Laion 6.5+ ](https://laion.ai/blog/laion-aesthetics/) datasets. You can see
145 | some samples [here](http://captions.christoph-schuhmann.de/2B-en-6.5.html). As you can
146 | see this dataset is _very_ biased towards landscapes and portraits. Accordingly, the model
147 | does best at prompts related to art/landscapes/paintings/portraits/architecture.
148 |
149 | The script `img_load_utils.py` contains some code to use the img2dataset package to
150 | download and store images, texts, and their corresponding embeddings. The Laion datasets are still
151 | quite messy with a lot of duplicates, bad descriptions etc.
152 |
153 |
154 | #### 115k Laion Sample: I have uploaded a sample of 115k 64x64-pixel images from the Laion 6.5+ dataset, together with their corresponding prompts and CLIP embeddings, to huggingface [here](https://huggingface.co/apapiu/diffusion_model_aesthetic_keras/tree/main).
155 | This can be used to quickly prototype new generative models. This
156 | dataset is also used in the notebook above.
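A hypothetical sketch of pulling this sample down with `huggingface_hub` - the exact file names in the repo are an assumption on my part (`trainer.py` expects `imgs.npy`, `captions.csv`, and `embeddings.npy` in `data_dir`), so check the repo listing first:

```python
import numpy as np
from huggingface_hub import hf_hub_download

repo_id = "apapiu/diffusion_model_aesthetic_keras"

# file names below are assumptions based on what trainer.py expects:
imgs_path = hf_hub_download(repo_id, "imgs.npy")
emb_path = hf_hub_download(repo_id, "embeddings.npy")

train_data = np.load(imgs_path)
train_label_embeddings = np.load(emb_path)
```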
157 |
158 | TODO: add more info and script on how to preprocess the data and link to huggingface repo.
159 | Talk about data quality issues.
160 |
161 |
162 | ## Training Process and Colab Hints:
163 |
164 | If you want to train the text-to-img model I highly recommend getting at least Colab Pro or even Colab
165 | Pro+ - it's going to be hard to train the model on a K80 GPU, unfortunately. NOTE: Colab will
166 | change its setup and introduce credits at the end of September - I will update this.
167 |
168 | Setting up this training workflow on Google Colab wasn't too bad. My approach has been
169 | very low-tech and Google Drive played a large role. Basically, at the end of every epoch I save the model and the
170 | images generated on a small validation set (50-100 images) to Drive.
171 |
172 | This has a few advantages:
173 | - If I get kicked off my Colab instance the model is not lost and I just need to restart the instance
174 | - I can keep a record of image generation quality by epoch and go back to check and compare models
175 |     - Important point here - make sure to use the _same_ random seed for every epoch - this controls for the randomness
176 | - Drive saves the past 100 versions of a file so I can always go back to any model version from the past 100 epochs.
177 |     - This is important since there is some variability in image quality from epoch to epoch
178 | - It's low tech and you don't have to learn a new platform like wandb.
179 | - Reading and saving data from Drive in Colab is _very_ fast.
180 |
181 | I have slowly been moving some data/models to huggingface but this is still WIP.
182 |
183 | #### GPU Speed:
184 | In terms of speed the GPUs go as follows:
185 | `A100>V100>P100>T4>K80` with the A100 being the fastest and every subsequent GPU being roughly twice as slow as
186 | the one before it for training (e.g. the P100 is about 4x slower than the A100). While I did get the A100 a few times,
187 | the sweet spot was really V100/P100 on Colab Pro+ since the risk of being timed out decreased. With Colab Pro+ ($50/month) I managed to train on V100/P100 continuously for 12-24 hours at a time.
188 |
189 |
190 |
191 | ### Validation:
192 |
193 | I'm not an expert here but generally the validation of generative models
194 | is still an open question. There are metrics like Inception Score, FID, and KID that measure
195 | whether the distribution of generated images is "close" to the training distribution in some way.
196 | The main issue with all of these metrics however is that a model that simply memorizes the training data
197 | will have a perfect score - so they don't account for overfitting. They are also fairly hard to understand, need large
198 | sample sizes, and are computationally intensive. For all these reasons I have chosen not to use them for now.
199 |
200 | Instead I have focused on analyzing the visual quality of generated images by, uhm... looking at them. This can quickly
201 | devolve into a tea-leaf reading exercise, however. To combat this, one can come up with different strategies to test for
202 | quality and diversity. For example, sampling from both generated and ground-truth images and looking at them together,
203 | either side by side or permuted, is a reasonable way to check sample quality.
204 |
205 | To test for generalization I have mostly focused on interpolations in both the CLIP space and the
206 | random normal latent space. Ideally, as you move from embedding to embedding, you want the generated images
207 | along the path to be meaningful in some way.
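A minimal sketch of how a CLIP interpolation can be set up with the helpers in `utils.py` (this assumes the `clip` package, a GPU, since `get_text_encodings` calls `.cuda()`, and a trained `diffuser`; the 64x64x3 shape matches the text-to-img model):

```python
import numpy as np
import clip
from utils import get_text_encodings, slerp

clip_model, _ = clip.load("ViT-B/32")
emb_a, emb_b = get_text_encodings(["A lake in the forest in the summer",
                                   "A lake in the forest in the winter"], clip_model)

# spherically interpolate between the two CLIP embeddings:
steps = np.linspace(0, 1, 9)
labels = np.array([slerp(emb_a, emb_b, t) for t in steps])

# keep the random seed fixed so that only the conditioning changes along the path:
seed = np.random.normal(0, 1, (1, 64, 64, 3))
seeds = np.repeat(seed, len(steps), axis=0)

imgs = diffuser.reverse_diffusion(seeds, labels)
```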
208 |
209 | CLIP interpolation: "A lake in the forest in the summer" -> "A lake in the forest in the winter"
210 |
211 |
212 |
213 |
214 | Does the model memorize the training data? This is an important question that has
215 | lots of implications. First of all, the models above don't have the capacity to memorize _all_
216 | of the training data: the model is about 150 MB but is trained on about
217 | 8GB of data. Second, it might not be in the model's best interest to memorize things.
218 | After digging around the predictions on the training data a bit, I did find _one_ example where
219 | the model shamelessly copies a training example. Note this is because the image appears many times
220 | in the training data.
221 |
222 |
223 |
224 | ### Examples:
225 |
226 | Prompt: `An Italian Villaga Painted by Picasso`
227 |
228 |
229 |
230 | `City at night`
231 |
232 |
233 |
234 | `Photograph of young woman in a field of flowers, bokeh`
235 |
236 |
237 |
238 | `Street on an island in Greece`
239 |
240 |
241 |
242 | `A Mountain Lake in the spring at sunset`
243 |
244 |
245 |
246 | `A man in a suit in the field in wintertime`
247 |
248 |
249 |
250 |
251 | CLIP interpolation: "A minimalist living room" -> "A Field in springtime, painting"
252 |
253 |
254 |
255 | CLIP interpolation: "A lake in the forest in the summer" -> "A lake in the forest in the winter"
256 |
257 |
258 |
259 |
260 |
--------------------------------------------------------------------------------
/guided_diffusion/configs/base_model.yaml:
--------------------------------------------------------------------------------
1 | train_model: True
2 | epochs: 50
3 | file_name: "fashion_mnist"
4 |
5 | channels: 64
6 | channel_multiplier: [1, 2, 3]
7 | block_depth: 2
8 | emb_size: 512
9 | num_classes: 12
10 | attention_levels: [0, 1, 0]
11 | embedding_dims: 32
12 | embedding_max_frequency: 1000.0
13 |
14 | precomputed_embedding: False
15 | save_in_drive: False
16 |
17 | batch_size: 64
18 | num_imgs: 64
19 | row: 8
20 | learning_rate: 0.0003
21 |
22 | class_guidance: 1
23 | MODEL_NAME: "model_test"
24 | from_scratch: True
25 | data_dir: "/content/diffusion_model_aesthetic_keras"
--------------------------------------------------------------------------------
/guided_diffusion/configs/cifar10_model.yaml:
--------------------------------------------------------------------------------
1 | train_model: True
2 | epochs: 50
3 | file_name: "cifar10"
4 |
5 | channels: 64
6 | channel_multiplier: [1, 2, 3]
7 | block_depth: 2
8 | emb_size: 512
9 | num_classes: 12
10 | attention_levels: [0, 1, 0]
11 | embedding_dims: 32
12 | embedding_max_frequency: 1000.0
13 |
14 | precomputed_embedding: False
15 | save_in_drive: False
16 |
17 | batch_size: 64
18 | num_imgs: 64
19 | row: 8
20 | learning_rate: 0.0003
21 |
22 | class_guidance: 3
23 | MODEL_NAME: "model_cifar"
24 | from_scratch: True
25 | data_dir: "/content/diffusion_model_aesthetic_keras"
--------------------------------------------------------------------------------
/guided_diffusion/configs/clip_laion.yaml:
--------------------------------------------------------------------------------
1 | train_model: True
2 | epochs: 3
3 | file_name: "from_huggingface" #data stored in huggingface
4 |
5 | channels: 96
6 | channel_multiplier: [1, 2, 3, 4]
7 | block_depth: 2
8 | emb_size: 512
9 | num_classes: 12
10 | attention_levels: [0, 0, 1, 0]
11 | embedding_dims: 32
12 | embedding_max_frequency: 1000.0
13 |
14 | precomputed_embedding: True
15 | save_in_drive: False
16 |
17 | batch_size: 64
18 | num_imgs: 36
19 | learning_rate: 0.0003
20 |
21 | class_guidance: 4
22 | MODEL_NAME: "model_test_aesthetic"
23 | from_scratch: True
24 | data_dir: "/content/diffusion_model_aesthetic_keras"
25 |
--------------------------------------------------------------------------------
/guided_diffusion/denoiser.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import tensorflow as tf
4 | from tensorflow import keras
5 | from tensorflow.keras import layers
6 | from tensorflow.keras.models import Model
7 |
8 |
9 | def attention(qkv):
10 |
11 | q, k, v = qkv
12 | # should we scale this?
13 | s = tf.matmul(k, q, transpose_b=True) # [bs, h*w, h*w]
14 | beta = tf.nn.softmax(s) # attention map
15 | o = tf.matmul(beta, v) # [bs, h*w, C]
16 | return o
17 |
18 |
19 | def spatial_attention(img):
20 | # Attention implementation a combination of:
21 | # https://github.com/taki0112/Self-Attention-GAN-Tensorflow and
22 | # https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/attention.py
23 |
24 | filters = img.shape[3]
25 | orig_shape = ((img.shape[1], img.shape[2], img.shape[3]))
26 | print(orig_shape)
27 | img = layers.BatchNormalization()(img)
28 |
29 | # projections:
30 | q = layers.Conv2D(filters // 8, kernel_size=1, padding="same")(img)
31 | k = layers.Conv2D(filters // 8, kernel_size=1, padding="same")(img)
32 | v = layers.Conv2D(filters, kernel_size=1, padding="same")(img)
33 | k = layers.Reshape((k.shape[1] * k.shape[2], k.shape[3],))(k)
34 |
35 | q = layers.Reshape((q.shape[1] * q.shape[2], q.shape[3]))(q)
36 | v = layers.Reshape((v.shape[1] * v.shape[2], v.shape[3],))(v)
37 |
38 | img = layers.Lambda(attention)([q, k, v])
39 | img = layers.Reshape(orig_shape)(img)
40 |
41 | # out_projection:
42 | img = layers.Conv2D(filters, kernel_size=1, padding="same")(img)
43 | img = layers.BatchNormalization()(img)
44 |
45 | return img
46 |
47 |
48 | def cross_attention(img, text):
49 | filters = img.shape[3]
50 | orig_shape = ((img.shape[1], img.shape[2], img.shape[3]))
51 | print(orig_shape)
52 | img = layers.BatchNormalization()(img)
53 | text = layers.BatchNormalization()(text)
54 |
55 | # projections:
56 | q = layers.Conv2D(filters // 8, kernel_size=1, padding="same")(text)
57 | k = layers.Conv2D(filters // 8, kernel_size=1, padding="same")(img)
58 | v = layers.Conv2D(filters, kernel_size=1, padding="same")(text)
59 |
60 | q = layers.Reshape((q.shape[1] * q.shape[2], q.shape[3]))(q)
61 | k = layers.Reshape((k.shape[1] * k.shape[2], k.shape[3],))(k)
62 | v = layers.Reshape((v.shape[1] * v.shape[2], v.shape[3],))(v)
63 |
64 | img = layers.Lambda(attention)([q, k, v])
65 | img = layers.Reshape(orig_shape)(img)
66 |
67 | # out_projection:
68 | img = layers.Conv2D(filters, kernel_size=1, padding="same")(img)
69 | img = layers.BatchNormalization()(img)
70 |
71 | return img
72 |
73 | ### The sinusoidal_embedding, ResidualBlock, Down/UP Block taken from
74 | ### https://github.com/keras-team/keras-io/blob/master/examples/generative/ddim.py
75 | ### Only change is adding self/cross attention:
76 |
77 | def sinusoidal_embedding(x):
78 | #TODO: remove the hardcoded values here:
79 | embedding_min_frequency = 1.0
80 | embedding_max_frequency = 1000.0
81 | embedding_dims = 32
82 | frequencies = tf.exp(
83 | tf.linspace(
84 | tf.math.log(embedding_min_frequency),
85 | tf.math.log(embedding_max_frequency),
86 | embedding_dims // 2,
87 | )
88 | )
89 | angular_speeds = 2.0 * math.pi * frequencies
90 | embeddings = tf.concat(
91 | [tf.sin(angular_speeds * x), tf.cos(angular_speeds * x)], axis=3
92 | )
93 | return embeddings
94 |
95 |
96 | def ResidualBlock(width):
97 | def apply(x):
98 | input_width = x.shape[3]
99 | if input_width == width:
100 | residual = x
101 | else:
102 | residual = layers.Conv2D(width, kernel_size=1)(x)
103 | x = layers.BatchNormalization(center=False, scale=False)(x)
104 | x = layers.Conv2D(
105 | width, kernel_size=3, padding="same", activation=keras.activations.swish
106 | )(x)
107 | # intermediate layer ... add mlp embedding of class plus timestamp here?
108 | x = layers.Conv2D(width, kernel_size=3, padding="same")(x)
109 | x = layers.Add()([x, residual])
110 | return x
111 |
112 | return apply
113 |
114 |
115 | def DownBlock(width, block_depth, use_self_attention=False):
116 | def apply(x):
117 | x, skips, emb_and_noise = x
118 | for _ in range(block_depth):
119 | x = ResidualBlock(width)(x)
120 |
121 | if use_self_attention:
122 | o = spatial_attention(x)
123 | x = layers.Add()([x, o])
124 | cross_att = cross_attention(x, emb_and_noise)
125 | x = layers.Add()([x, cross_att])
126 |
127 | skips.append(x)
128 | x = layers.AveragePooling2D(pool_size=2)(x)
129 | return x
130 |
131 | return apply
132 |
133 |
134 | def UpBlock(width, block_depth, use_self_attention=False):
135 | def apply(x):
136 | x, skips, emb_and_noise = x
137 | x = layers.UpSampling2D(size=2, interpolation="bilinear")(x)
138 | for _ in range(block_depth):
139 | x = layers.Concatenate()([x, skips.pop()])
140 | x = ResidualBlock(width)(x)
141 |
142 | if use_self_attention:
143 | o = spatial_attention(x)
144 | x = layers.Add()([x, o])
145 | cross_att = cross_attention(x, emb_and_noise)
146 | x = layers.Add()([x, cross_att])
147 |
148 | return x
149 |
150 | return apply
151 |
152 |
153 | def get_network(image_size, widths, block_depth, num_classes,
154 | num_channels, emb_size, attention_levels,
155 | precomputed_embedding=True):
156 | noisy_images = keras.Input(shape=(image_size, image_size, num_channels))
157 |
158 | noise_variances = keras.Input(shape=(1, 1, 1))
159 | e = layers.Lambda(sinusoidal_embedding)(noise_variances)
160 | e = layers.UpSampling2D(size=image_size, interpolation="nearest")(e)
161 |
162 | if precomputed_embedding:
163 | input_label = layers.Input(shape=emb_size) # CLIP/glove embedding.
164 | emb_label = layers.Dense(emb_size // 2)(input_label)
165 | emb_label = layers.Reshape((1, 1, emb_size // 2))(emb_label)
166 | else:
167 | input_label = layers.Input(shape=1) # label/word - integer encoded
168 | emb_label = layers.Embedding(input_dim=num_classes, output_dim=emb_size)(input_label)
169 | emb_label = layers.Reshape((1, 1, emb_size))(emb_label)
170 |
171 | emb_label = layers.UpSampling2D(size=image_size, interpolation="nearest")(emb_label)
172 |
173 | x = layers.Conv2D(widths[0], kernel_size=1)(noisy_images)
174 | x = layers.Concatenate()([x, e, emb_label])
175 |
176 | emb_and_noise = layers.Concatenate()([emb_label, e])
177 | emb_and_noise = layers.BatchNormalization()(emb_and_noise)
178 |
179 | skips = []
180 | level = 0
181 | for width in widths[:-1]:
182 | use_self_attention = bool(attention_levels[level])
183 | x = DownBlock(width, block_depth, use_self_attention)([x, skips, emb_and_noise])
184 |
185 | emb_and_noise = layers.AveragePooling2D()(emb_and_noise)
186 | level += 1
187 |
188 | for _ in range(block_depth):
189 | x = ResidualBlock(widths[-1])(x)
190 | if bool(attention_levels[level]):
191 | o = spatial_attention(x)
192 | x = layers.Add()([x, o])
193 | cross_att = cross_attention(x, emb_and_noise)
194 | x = layers.Add()([x, cross_att])
195 |
196 | for width in reversed(widths[:-1]):
197 | level -= 1
198 |
199 | emb_and_noise = layers.UpSampling2D(size=2, interpolation="bilinear")(emb_and_noise)
200 | use_self_attention = bool(attention_levels[level])
201 | x = UpBlock(width, block_depth, use_self_attention)([x, skips, emb_and_noise])
202 |
203 | x = layers.Conv2D(num_channels, kernel_size=1, kernel_initializer="zeros",
204 | activation="linear"
205 | )(x)
206 |
207 | return Model([noisy_images, noise_variances, input_label], x, name="residual_unet")
--------------------------------------------------------------------------------
/guided_diffusion/diffuser.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import matplotlib.pyplot as plt
4 | from utils import imshow, plot_images
5 |
6 |
7 | def dynamic_thresholding(img, perc=99.5):
8 | s = np.percentile(np.abs(img.ravel()), perc)
9 | s = np.max([s, 1])
10 | img = img.clip(-s, s) / s
11 |
12 | return img
13 |
14 |
15 | class Diffuser:
16 |
17 | def __init__(self, denoiser, class_guidance, diffusion_steps, perc_thresholding=99.5, batch_size=64):
18 | self.denoiser = denoiser
19 | self.class_guidance = class_guidance
20 | self.diffusion_steps = diffusion_steps
21 | #TODO: parametrize this better:
22 | self.noise_levels = 1 - np.power(np.arange(0.0001, 0.99, 1 / self.diffusion_steps), 1 / 3)
23 | self.noise_levels[-1] = 0.01
24 | self.perc_thresholding = perc_thresholding
25 | self.batch_size = batch_size
26 |
27 | def predict_x_zero(self, x_t, label, noise_level):
28 | """Predict original image based on noisy image (or matrix of noisy images) at noise level plus conditional label"""
29 |
30 | # we use 0 for the unconditional embedding:
31 | num_imgs = len(x_t)
32 | label_empty_ohe = np.zeros(shape=label.shape)
33 |
34 | # predict x0:
35 | noise_in = np.array([noise_level] * num_imgs)[:, None, None, None]
36 |
37 | # TODO: can we do some of this in tensorflow?
38 | # concatenate the conditional and unconditional inputs to speed inference:
39 | nn_inputs = [np.vstack([x_t, x_t]),
40 | np.vstack([noise_in, noise_in]),
41 | np.vstack([label, label_empty_ohe])]
42 |
43 | x0_pred = self.denoiser.predict(nn_inputs, batch_size=self.batch_size)
44 |
45 | x0_pred_label = x0_pred[:num_imgs]
46 | x0_pred_no_label = x0_pred[num_imgs:]
47 |
48 | # classifier free guidance:
49 | x0_pred = self.class_guidance * x0_pred_label + (1 - self.class_guidance) * x0_pred_no_label
50 |
51 | if self.perc_thresholding:
52 | # clip the prediction using dynamic thresholding a la Imagen:
53 | x0_pred = dynamic_thresholding(x0_pred, perc=self.perc_thresholding)
54 |
55 | return x0_pred
56 |
57 | def reverse_diffusion(self, seeds, label, show_img=False):
58 | """Reverse Guided Diffusion on a matrix of random images (seeds). Returns generated images"""
59 |
60 | new_img = seeds
61 |
62 | for i in tqdm(range(len(self.noise_levels) - 1)):
63 |
64 | curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]
65 |
66 | # predict original denoised image:
67 | x0_pred = self.predict_x_zero(new_img, label, curr_noise)
68 |
69 | # new image at next_noise level is a weighted average of old image and predicted x0:
70 | new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
71 |
72 | if show_img:
73 |                 plot_images(x0_pred, nrows=int(np.sqrt(len(new_img))),
74 | save_name=str(i),
75 | size=12)
76 | plt.show()
77 |
78 | return x0_pred
79 |
80 | def reverse_diffusion_masked(self, seeds, label_ohe, show_img=False, masked_imgs=None, mask=None, u=1):
81 | """Reverse Guided Diffusion on a matrix of random images (seeds) with a mask. Can be used for in/outpainting
82 | Based on the algorithm from Repaint: https://github.com/andreas128/RePaint
83 | """
84 |
85 | new_img = seeds
86 | num_imgs = len(new_img)
87 |
88 | for i in tqdm(range(len(self.noise_levels) - 1)):
89 |
90 | curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1] # alpha1, alpha2
91 |
92 | for j in range(1, u + 1):
93 |
94 | # predict original denoised image:
95 | x0_pred = self.predict_x_zero(new_img, label_ohe, curr_noise)
96 |
97 | # new image at next_noise level is a weighted average of old image and predicted x0:
98 | new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
99 |
100 | if masked_imgs is not None:
101 | # let the last 20% of steps be done without masking.
102 | if i <= int(self.diffusion_steps * 0.8):
103 | ####masked part:
104 | new_img_known = (1 - next_noise) * masked_imgs + np.random.normal(0, 1,
105 | size=new_img.shape) * next_noise
106 |
107 | # mask here is empty/unknown part
108 | new_img = new_img_known * mask + new_img * (1 - mask)
109 | else:
110 | break
111 |
112 | if j != u:
113 | print(j)
114 | ### noise the image back to the previous step:
115 | s = (1 - curr_noise) / (1 - next_noise)
116 | new_img = s * new_img + np.sqrt(curr_noise ** 2 - s ** 2 * next_noise ** 2) * np.random.normal(0, 1,
117 | size=new_img.shape)
118 |
119 | if show_img:
120 |             plot_images(x0_pred, nrows=int(np.sqrt(len(new_img))),
121 | save_name=str(i),
122 | size=12)
123 | plt.show()
124 |
125 | return x0_pred
--------------------------------------------------------------------------------
/guided_diffusion/img_load_utils.py:
--------------------------------------------------------------------------------
1 | # !pip install img2dataset
2 | # !pip install datasets
3 |
4 | from datasets import load_dataset
5 | import pandas as pd
6 | import numpy as np
7 | from sklearn.feature_extraction.text import CountVectorizer
8 | from img2dataset import download
9 | import os
10 | import matplotlib.pyplot as plt
11 | import glob
12 | import cv2
13 | 
14 | from tqdm import tqdm
15 |
16 |
17 | def clean_df(df):
18 |
19 | df["pixels"] = df["HEIGHT"]*df["WIDTH"]
20 | df["ratio"] = df["HEIGHT"]/df["WIDTH"]
21 | x = df["similarity"]
22 | x[x<1].hist(bins=100)
23 | plt.show()
24 | df["ratio"].quantile(np.arange(0, 1, 0.01)).plot()
25 | plt.show()
26 | df["pixels"].quantile(np.arange(0, 1, 0.01)).plot()
27 | plt.show()
28 |
29 | df = df[df["similarity"] >= 0.3]
30 | print(df.shape)
31 |
32 | #### only pick the first URL - there are a lot of duplicate images:
33 | df = df.groupby("URL").first().reset_index()
34 |
35 | df = df.drop_duplicates(subset = ["TEXT", "WIDTH", "HEIGHT"])
36 | print(df.shape)
37 |
38 | #remove huge images for faster download:
39 | df = df[df["pixels"] <= 1024*1024]
40 | print(df.shape)
41 |
42 | # remove images that aren't close to being square - otherwise faces get cropped.
43 | df = df[df["ratio"] > 0.3]
44 | print(df.shape)
45 | df = df[df["ratio"] < 2]
46 | print(df.shape)
47 |
48 | df = df[df["AESTHETIC_SCORE"] > 5.5]
49 | print(df.shape)
50 |
51 | vectorizer = CountVectorizer(min_df=25, stop_words="english")
52 | text = df["TEXT"]
53 | vectorizer.fit(text)
54 |
55 | word_counts = vectorizer.transform(text)
56 | x = word_counts.sum(1)
57 | df["word_counts"] = pd.DataFrame(x)[0].values
58 |
59 | df["word_counts"].value_counts().sort_index()[:30].plot(kind="bar")
60 |
61 | df = df[df["word_counts"] >= 1.0]
62 | print(df.shape)
63 | df = df[df["word_counts"] <= 35.0]
64 | print(df.shape)
65 |
66 | return df
67 |
68 | def download_data(url_path="df", output_dir="output_dir"):
69 | download(
70 | processes_count=8,
71 | thread_count=16,
72 | url_list=url_path,
73 | image_size=64,
74 | output_folder=output_dir,
75 | output_format="files",
76 | input_format="parquet",
77 | url_col="URL",
78 | caption_col="TEXT",
79 | enable_wandb=False,
80 | number_sample_per_shard=10000,
81 | distributor="multiprocessing",
82 | resize_mode="center_crop",
83 | )
84 |
85 |
86 | def get_imgs_and_captions():
87 | #TODO: This is sketchy...and error prone:
88 | imgs_files = glob.glob("/content/output_dir/*/*.jpg")
89 | imgs_files.sort()
90 | text_files = glob.glob("/content/output_dir/*/*.txt")
91 | text_files.sort()
92 |
93 | caption_list = []
94 | for txt_path in tqdm(text_files):
95 | with open(txt_path) as f:
96 | lines = f.readlines()
97 | caption_list.append(lines[0])
98 |
99 | img_list = []
100 | for img_path in tqdm(imgs_files):
101 | img = cv2.imread(img_path)
102 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
103 | img_list.append(img)
104 |
105 | img_array = np.array(img_list)
106 | del img_list
107 | return img_array, caption_list
108 |
109 | def get_dataset_from_huggingface(hugging_face_dataset_name):
110 | dataset = load_dataset(hugging_face_dataset_name, split="train")
111 | df = dataset.to_pandas()
112 |
113 | return df
114 |
115 | def build_dataset_and_save(hugging_face_dataset_name):
116 | sample_size = 120000
117 | seed = 1
118 | df = get_dataset_from_huggingface(hugging_face_dataset_name)
119 |
120 | df = clean_df(df)
121 | np.random.seed(seed)
122 | df.sample(sample_size).to_parquet("df") #note that this keeps the original indexes from the data.
123 | output_dir = os.path.abspath("output_dir")
124 |     download_data(url_path="df", output_dir=output_dir)
125 | train_data, caption_list = get_imgs_and_captions()
126 | pd.Series(caption_list).to_csv("/content/diffusion_model_aesthetic_keras/captions.csv", index=None)
127 | np.save("/content/diffusion_model_aesthetic_keras/imgs.npy", train_data)
--------------------------------------------------------------------------------
/guided_diffusion/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from tensorflow.keras.datasets import mnist, fashion_mnist, cifar10, cifar100
7 | from keras.utils import plot_model
8 | from tensorflow import keras
9 | from matplotlib import pyplot as plt
10 |
11 | from denoiser import get_network
12 | from utils import batch_generator, plot_images, get_data, preprocess
13 | from diffuser import Diffuser
14 |
15 |
16 | class Trainer:
17 | def __init__(self, config_file):
18 | # Load YAML config
19 | with open(config_file, 'r') as f:
20 | config = yaml.safe_load(f)
21 |
22 | # Unpack config into instance attributes
23 | for key, value in config.items():
24 | setattr(self, key, value)
25 |
26 | # Additional setups
27 | self._setup_paths_and_dirs()
28 |
29 | def _setup_paths_and_dirs(self):
30 | self.widths = [c * self.channels for c in self.channel_multiplier]
31 | self.captions_path = os.path.join(self.data_dir, "captions.csv")
32 | self.imgs_path = os.path.join(self.data_dir, "imgs.npy")
33 | self.embedding_path = os.path.join(self.data_dir, "embeddings.npy")
34 |
35 | if self.save_in_drive:
36 | from google.colab import drive
37 | drive.mount('/content/drive')
38 | drive_path = '/content/drive/MyDrive/'
39 | self.home_dir = os.path.join(drive_path, self.MODEL_NAME)
40 | else:
41 | self.home_dir = self.MODEL_NAME
42 |
43 | if not os.path.exists(self.home_dir):
44 | os.mkdir(self.home_dir)
45 |
46 | self.model_path = os.path.join(self.home_dir, self.MODEL_NAME + ".h5")
47 |
48 | def preprocess_data(self, train_data, train_label_embeddings):
49 | print(train_data.shape)
50 | self.train_data = train_data
51 | self.train_label_embeddings = train_label_embeddings
52 | self.image_size = train_data.shape[1]
53 | self.num_channels = train_data.shape[-1]
54 | self.row = int(np.sqrt(self.num_imgs))
55 | self.labels = self._get_labels(train_data, train_label_embeddings)
56 |
57 | def _get_labels(self, train_data, train_label_embeddings):
58 | if self.precomputed_embedding:
59 | return train_label_embeddings[:self.num_imgs]
60 | else:
61 | row_labels = np.array([[i] * self.row for i in np.arange(self.row)]).flatten()[:, None]
62 | return row_labels + 1
63 |
64 | def initialize_model(self):
65 |
66 | if self.from_scratch:
67 | self.autoencoder = get_network(self.image_size,
68 | self.widths,
69 | self.block_depth,
70 | num_classes=self.num_classes,
71 | attention_levels=self.attention_levels,
72 | emb_size=self.emb_size,
73 | num_channels=self.num_channels,
74 | precomputed_embedding=self.precomputed_embedding)
75 |
76 | self.autoencoder.compile(optimizer="adam", loss="mae")
77 | else:
78 | self.autoencoder = keras.models.load_model(self.model_path)
79 |
80 | def data_checks(self, train_data):
81 | print("Number of parameters is {0}".format(self.autoencoder.count_params()))
82 | pd.Series(train_data[:1000].ravel()).hist(bins=80)
83 | plt.show()
84 | print("Original Images:")
85 | plot_images(preprocess(train_data[:self.num_imgs]), nrows=int(np.sqrt(self.num_imgs)))
86 | plot_model(self.autoencoder, to_file=os.path.join(self.home_dir, 'model_plot.png'),
87 | show_shapes=True, show_layer_names=True)
88 | print("Generating Images below:")
89 |
90 | def train(self):
91 | np.random.seed(100)
92 | self.rand_image = np.random.normal(0, 1, (self.num_imgs, self.image_size, self.image_size, self.num_channels))
93 |
94 | self.diffuser = Diffuser(self.autoencoder,
95 | class_guidance=self.class_guidance,
96 | diffusion_steps=35)
97 |
98 | if self.train_model:
99 | train_generator = batch_generator(self.autoencoder,
100 | self.model_path,
101 | self.train_data,
102 | self.train_label_embeddings,
103 | self.epochs,
104 | self.batch_size,
105 | self.rand_image,
106 | self.labels,
107 | self.home_dir,
108 | self.diffuser)
109 |
110 | self.autoencoder.optimizer.learning_rate.assign(self.learning_rate)
111 |
112 | self.eval_nums = self.autoencoder.fit(
113 | x=train_generator,
114 | epochs=self.epochs
115 | )
116 |
117 |
118 | def get_train_data(file_name, captions_path=None, imgs_path=None, embedding_path=None):
119 | dataset_loaders = {
120 | "cifar10": cifar10.load_data,
121 | "cifar100": cifar100.load_data,
122 | "fashion_mnist": fashion_mnist.load_data,
123 | "mnist": mnist.load_data
124 | }
125 |
126 | if file_name in dataset_loaders:
127 | (train_data, train_label_embeddings), (_, _) = dataset_loaders[file_name]()
128 |
129 |         # shift labels by 1 so that 0 is reserved for the unconditional (null) class
130 | train_label_embeddings = train_label_embeddings + 1
131 |
132 | if file_name in ["fashion_mnist", "mnist"]:
133 | train_data = train_data[:, :, :, None]
134 | train_label_embeddings = train_label_embeddings[:, None]
135 |
136 | else:
137 | captions = pd.read_csv(captions_path)
138 | train_data, train_label_embeddings = np.load(imgs_path), np.load(embedding_path)
139 |
140 | return train_data, train_label_embeddings
141 | #train_data, train_label_embeddings = get_data(npz_file_name=file_name, prop=0.6, captions=False)
142 |
143 |
144 | if __name__=='__main__':
145 |
146 | trainer = Trainer('guided_diffusion/configs/base_model.yaml')
147 |
148 | train_data, train_label_embeddings = get_train_data(trainer.file_name)
149 | trainer.preprocess_data(train_data, train_label_embeddings)
150 | trainer.initialize_model()
151 | trainer.data_checks(train_data)
152 | trainer.train()
153 |
--------------------------------------------------------------------------------
/guided_diffusion/utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from matplotlib import pyplot as plt
4 | import os
5 | from numpy.linalg import norm
6 |
7 | import torch
8 | from torch.utils.data import DataLoader
9 |
10 | def get_text_encodings(prompts, model):
11 | import clip
12 | #model, preprocess = clip.load("ViT-B/32")
13 | data_loader = DataLoader(prompts, batch_size=256)
14 |
15 | all_encodings = []
16 | for labels in data_loader:
17 | text_tokens = clip.tokenize(labels, truncate=True).cuda()
18 |
19 | with torch.no_grad():
20 | text_encoding = model.encode_text(text_tokens)
21 |
22 | all_encodings.append(text_encoding)
23 |
24 | all_encodings = torch.cat(all_encodings).cpu().numpy()
25 |
26 | return all_encodings
27 |
28 |
29 | def preprocess(array):
30 | """
31 | Normalizes the supplied array and reshapes it into the appropriate format.
32 | """
33 |
34 | array = array.astype("float32") / 255.0
35 | array = array * 2 - 1
36 |
37 | return array
38 |
39 |
40 | #######
41 | # TRAIN UTILS:
42 | #######
43 |
44 | def add_noise(array, mu=0, std=1):
45 | #TODO: have a better way to parametrize this:
46 | x = np.abs(np.random.normal(0, std, 2 * len(array)))
47 | x = x[x < 3]
48 | x = x / 3
49 | x = x[:len(array)]
50 | noise_levels = x
51 |
52 | signal_levels = 1 - noise_levels #OR: np.sqrt(1-np.square(noise_levels))
53 |
54 | # reshape so that the multiplication makes sense:
55 | noise_level_reshape = noise_levels[:, None, None, None]
56 | signal_level_reshape = signal_levels[:, None, None, None]
57 |
58 | pure_noise = np.random.normal(0, 1, size=array.shape).astype("float32")
59 | noisy_data = array * signal_level_reshape + pure_noise * noise_level_reshape
60 |
61 | return noisy_data, noise_levels
62 |
63 |
64 | def slerp(p0, p1, t):
65 | """spherical interpolation"""
66 | omega = np.arccos(np.dot(p0 / norm(p0), p1 / norm(p1)))
67 | so = np.sin(omega)
68 | return np.sin((1.0 - t) * omega) / so * p0 + np.sin(t * omega) / so * p1
69 |
70 |
71 | def interpolate_two_points(image_size, num_channels, num_points=100):
72 | pA = np.random.normal(0, 1, (image_size * image_size * num_channels))
73 | pB = np.random.normal(0, 1, (image_size * image_size * num_channels))
74 |
75 | ps = np.array([slerp(pA, pB, t) for t in np.arange(0.0, 1.0, 1 / num_points)])
76 | rand_image = ps.reshape(len(ps), image_size, image_size, num_channels)
77 |
78 | return rand_image
79 |
80 |
81 | def imshow(img):
82 | def norm_0_1(img):
83 | return (img + 1) / 2
84 |
85 |     # img here is between -1 and 1:
86 | if img.shape[-1] == 1:
87 | img = img.reshape(img.shape[0], img.shape[1])
88 | img = np.clip(img, -1, 1)
89 | plt.imshow(norm_0_1(img))
90 |
91 |
92 | def plot_images(imgs, size=16, nrows=8, save_name=None):
93 | plt.rcParams["figure.figsize"] = (size, size)
94 |
95 | for i in range(len(imgs)):
96 | ax = plt.subplot(nrows, nrows, i + 1)
97 | imshow(imgs[i])
98 | plt.gray()
99 | ax.get_xaxis().set_visible(False)
100 | ax.get_yaxis().set_visible(False)
101 | if save_name:
102 | save_to = "{0}.png".format(save_name)
103 | plt.savefig(save_to, dpi=200)
104 | plt.show()
105 |
106 |
107 | def get_data(npz_file_path, prop=0.6, captions=False):
108 | data = np.load(npz_file_path)
109 |
110 | if captions:
111 | train_data, train_label_embeddings, caption_list = data["arr_0"], data["arr_1"], data["arr_2"]
112 | else:
113 | train_data, train_label_embeddings = data["arr_0"], data["arr_1"]
114 |
115 | # eliminate if perc white pixels > 60% - not really used.
116 | white_pixels = (train_data >= 254).mean(axis=(1, 2, 3))
117 | mask = (white_pixels < prop)
118 |
119 | if captions:
120 | train_data, train_label_embeddings, caption_list = train_data[mask], train_label_embeddings[mask], caption_list[mask]
121 | return train_data, train_label_embeddings, caption_list
122 | else:
123 | train_data, train_label_embeddings = train_data[mask], train_label_embeddings[mask]
124 | return train_data, train_label_embeddings
125 |
126 |
127 | def batch_generator(model, model_path, train_data, train_label_embeddings, epochs,
128 | batch_size, rand_image, labels, home_dir, diffuser):
129 | indices = np.arange(len(train_data))
130 | batch = []
131 | epoch = 0
132 | print("Training for {0}".format(epochs))
133 | while epoch < epochs:
134 | print("saving model:")
135 | model.save(model_path)
136 |
137 | print(" Generating images:")
138 | diffuser.denoiser = model
139 | imgs = diffuser.reverse_diffusion(rand_image, labels)
140 | img_path = os.path.join(home_dir, str(epoch))
141 | plot_images(imgs, save_name=img_path, nrows=int(np.sqrt(len(imgs))))
142 |
143 | print("new epoch {0}".format(epoch))
144 | # it might be a good idea to shuffle your data before each epoch
145 | np.random.shuffle(indices)
146 | for i in indices:
147 | batch.append(i)
148 | if len(batch) == batch_size:
149 | tr_batch = train_data[batch].copy()
150 | tr_batch = preprocess(tr_batch)
151 |
152 | # random dropout for CFG:
153 | s = np.random.binomial(1, 0.15, size=batch_size).astype("bool")
154 | train_label_dropout = train_label_embeddings[batch].copy()
155 | train_label_dropout[s] = np.zeros(shape=train_label_embeddings.shape[1])
156 |
157 | # Add Noise to Images:
158 | noisy_train_data, noise_level_train = add_noise(tr_batch, mu=0, std=1)
159 | noise_level_train = noise_level_train[:, None, None, None] # for correct dim
160 |
161 | yield [noisy_train_data, noise_level_train, train_label_dropout], tr_batch
162 | batch = []
163 | epoch += 1
164 |
165 |
166 | def get_common_words(caption_list, n):
167 | from sklearn.feature_extraction.text import CountVectorizer
168 | vectorizer = CountVectorizer(min_df=10, stop_words="english")
169 | text = caption_list
170 | vectorizer.fit(text)
171 |
172 | word_counts = vectorizer.transform(text)
173 | word_counts = (word_counts.sum(0))
174 | word_counts = pd.DataFrame(word_counts.T, index=vectorizer.get_feature_names_out())
175 | return word_counts.sort_values(0, ascending=False).head(n)
176 |
--------------------------------------------------------------------------------