├── .gitignore ├── LICENSE ├── README.md └── guided_diffusion ├── configs ├── base_model.yaml ├── cifar10_model.yaml └── clip_laion.yaml ├── denoiser.py ├── diffuser.py ├── img_load_utils.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Text to Image Model in Keras 2 | 3 | 4 | #### NEW: Brand new Repo using PyTorch to train latent diffusion models using transformers. See [here](https://github.com/apapiu/transformer_latent_diffusion/tree/main). The new model is better, faster, and at 4x the resolution. 5 | 6 | #### [Kaggle notebook](https://www.kaggle.com/code/apapiu/train-latent-diffusion-in-keras-from-scratch) that trains a 128*128 Latent Diffusion model on the Kaggle kernel hardware (P100 GPU). This should be similar to the code for Stable Diffusion. 7 | 8 | Codebase to train a CLIP-conditioned Text to Image Diffusion model on Colab in Keras. See below for notebooks and examples with prompts.
9 | 10 | 11 | Images generated for the prompt: `A small village in the Alps, spring, sunset` 12 | 13 | image 14 | 15 | Images generated for the prompt: `Portrait of a young woman with curly red hair, photograph` 16 | 17 | image 18 | 19 | 20 | (more examples below - try with your own inputs in Colab here: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/123iljowP_b5o-_6RjHZK8vbgBkrNn8-c?usp=sharing) ) 21 | 22 | 23 | ## Table of Contents: 24 | - [Intro](#intro) 25 | - [Notebooks](#notebooks) 26 | - [Usage](#usage) 27 | - [Model Setup](#model-setup) 28 | - [Data](#data) 29 | - [Examples](#examples) 30 | 31 | ## Intro 32 | 33 | The goal of this repo is to provide a simple, self-contained codebase for Text to Image Diffusion that can be trained in Colab in a 34 | reasonable amount of time. 35 | 36 | While there are a lot of great resources around the math and usage of diffusion models, I haven't found many specifically 37 | focused on _training_ text to img diffusion models. 38 | In particular, the idea of training a Dall-E 2 or Stable Diffusion like model feels like a daunting task requiring immense 39 | computational resources and data. It turns out you can accomplish quite a lot with few resources and without having a PhD in thermodynamics! 40 | The easiest way to get acquainted with the code is through the notebooks below. 41 | 42 | #### Credits/Resources 43 | 44 | - Original Unet implementation in this [excellent blog post](https://keras.io/examples/generative/ddim/) - most of the code and Unet architecture in `denoiser.py` is based on this. I have added 45 | additional text/CLIP/masking embeddings/inputs and cross/self attention.
46 | - [Conditional CIFAR Model in Pytorch](https://colab.research.google.com/drive/1IJkrrV-D7boSCLVKhi7t5docRYqORtm3#scrollTo=TAUwPLG92r89) 47 | - [Laion Aesthetics 6.5+ Dataset](https://laion.ai/blog/laion-aesthetics/) - The 625K image-text pairs with predicted aesthetic scores of 6.5 or higher were used for training. 48 | - [Text 2 img package](https://github.com/hmiladhia/img2text) 49 | - [Variational Diffusion Models (Paper)](https://arxiv.org/abs/2107.00630) 50 | - [DDIM Paper](https://arxiv.org/abs/2010.02502) 51 | 52 | ## Notebooks 53 | 54 | If you are just starting out, I recommend trying out the next two notebooks in order. The first should be able to get you 55 | recognizable images on the Fashion MNIST dataset within minutes! 56 | 57 | - Train Class Conditional Fashion MNIST/CIFAR [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/16rJUyPn72-C30mZRUr-Oo6ZjYS89Z3yH?usp=sharing) 58 | - To try out the code you can use the notebook above in Colab. It is set to train on the Fashion MNIST dataset. 59 | You should be able to see reasonable image generations within 5 epochs (5-20 minutes depending on GPU)! 60 | - For CIFAR 10/100 - you just have to change the `file_name`. You can get reasonable results after 25 epochs for CIFAR 10 and 40 epochs for CIFAR 100. 61 | Training 50-100 epochs is even better. 62 | 63 | - Train CLIP Conditioned Text to Img Model on 115k 64x64 images+prompts sampled from the Laion Aesthetics 6.5+ dataset. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EoGdyZTGVeOrEnieWyzjItusBSes_1Ef?usp=sharing) 64 | - You can get recognizable results after ~15 epochs 65 | (~10 minutes per epoch on a V100) 66 | - Test Prompts on a model trained for about 60 epochs (~60 hours on 1 V100) on the entire 600k Laion Aesthetics 6.5+.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/123iljowP_b5o-_6RjHZK8vbgBkrNn8-c?usp=sharing) 67 | - This model has about 40 million parameters (150MB) and can be downloaded from [here](https://huggingface.co/apapiu/diffusion_model_aesthetic_keras/blob/main/model_64_65_epochs.h5) 68 | - The examples in this repo use this model 69 | 70 | ## Usage 71 | 72 | The model architecture, training parameters, and generation parameters are specified in a YAML file; see [here](https://github.com/apapiu/guided-diffusion-keras/tree/main/guided_diffusion/configs) for examples. If unsure, you can use the base_model config. The `get_train_data` helper is built to work with various known datasets. If you have 73 | your own dataset you can just use that instead. `train_label_embeddings` is expected to be a matrix of the embeddings the model conditions on (usually some embedding of text, but it could be anything). 74 | 75 | 76 | ```python 77 | config_path = "guided-diffusion-keras/guided_diffusion/configs/base_model.yaml" 78 | 79 | trainer = Trainer(config_path) 80 | print(trainer.__dict__) 81 | 82 | train_data, train_label_embeddings = get_train_data(trainer.file_name) #OR get your own images and label embeddings in matrix form. 83 | trainer.preprocess_data(train_data, train_label_embeddings) 84 | trainer.initialize_model() 85 | trainer.data_checks(train_data) 86 | trainer.train() 87 | ``` 88 | 89 | ## Model Setup 90 | 91 | The setup is fairly simple.
92 | 93 | We train a denoising U-Net neural net that takes the following three inputs: 94 | - `noise_level` (sampled from 0 to 1 with more values concentrated close to 0) 95 | - image (x) corrupted with a level of random noise 96 | - for a given `noise_level` between 0 and 1 the corruption is as follows: 97 | - `x_noisy = x*(1-noise_level) + eps*noise_level where eps ~ np.random.normal` 98 | - CLIP embeddings of a text prompt 99 | - You can think of this as a numerical representation of a text prompt (this is the only pretrained model we use). 100 | 101 | The output is a prediction of the denoised image - call it `f(x_noisy)`. 102 | 103 | The model is trained to minimize the absolute error `|f(x_noisy) - x|` between the prediction and actual image 104 | (you can also use squared error here). Note that I don't reparametrize the loss in terms of the noise here to keep things simple. 105 | 106 | Using this model we then iteratively generate an image from random noise as follows: 107 | 108 | for i in range(len(self.noise_levels) - 1): 109 | 110 | curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1] 111 | 112 | # predict original denoised image: 113 | x0_pred = self.predict_x_zero(new_img, label, curr_noise) 114 | 115 | # new image at next_noise level is a weighted average of old image and predicted x0: 116 | new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise 117 | 118 | The `predict_x_zero` method uses classifier-free guidance by combining the conditional and unconditional 119 | prediction: `x0_pred = class_guidance * x0_pred_conditional + (1 - class_guidance) * x0_pred_unconditional` 120 | 121 | A bit of math: the approach above falls within the VDM parametrization (see Section 3.1 in [Kingma et al.](https://arxiv.org/pdf/2107.00630.pdf)): 122 | 123 | $$ z_t = \alpha_t x + \sigma_t \epsilon, \quad \epsilon \sim \mathcal{N}(0, 1) $$ 124 | 125 | where $z_t$ is the noisy version of $x$ at time $t$.
126 | 127 | Generally, $\alpha_t$ is chosen to be $\sqrt{1-\sigma_t^2}$ so that the process is variance-preserving (since then $\alpha_t^2 + \sigma_t^2 = 1$). Here I chose $\alpha_t=1-\sigma_t$ so that we 128 | linearly interpolate between the image and random noise. Why? Honestly I just wondered if it was going to work :) It also simplifies the updating equation quite a bit, and it's easier to understand what the noise-to-signal ratio will look like. The updating equation above is the DDIM update for this parametrization, which simplifies to a simple weighted average. Note that the DDIM sampler deterministically maps random normal noise to images - this has two benefits: we can interpolate in the random normal latent space, and it generally takes fewer steps to get decent image quality. 129 | 130 | 131 | 132 | Note that I use a lot of unorthodox choices in the modelling. Since I am fairly new to generative models I found this to be a 133 | great way to learn what is crucial vs. what is nice to have. I generally did not see any divergences in training, which supports 134 | the notion that **diffusion models are stable to train and are fairly robust to model choices**. The flipside of 135 | this is that if you introduce subtle bugs in your code (of which I am sure there are many in this repo) they are pretty hard 136 | to spot. 137 | 138 | 139 | Architecture: TODO - add cross-attention description. 140 | 141 | 142 | ## Data 143 | 144 | The text-to-img models use the [Laion 6.5+](https://laion.ai/blog/laion-aesthetics/) datasets. You can see 145 | some samples [here](http://captions.christoph-schuhmann.de/2B-en-6.5.html). As you can 146 | see, this dataset is _very_ biased towards landscapes and portraits. Accordingly, the model 147 | does best at prompts related to art/landscapes/paintings/portraits/architecture. 148 | 149 | The script `img_load_utils.py` contains some code to use the img2dataset package to 150 | download and store images, texts, and their corresponding embeddings.
The Laion datasets are still 151 | quite messy, with a lot of duplicates, bad descriptions, etc. 152 | 153 | 154 | #### 115k Laion Sample: I have uploaded a sample of 115k 64x64-pixel images from the Laion 6.5+ dataset, plus their corresponding prompts and CLIP embeddings, to huggingface [here](https://huggingface.co/apapiu/diffusion_model_aesthetic_keras/tree/main). 155 | This can be used to quickly prototype new generative models. This 156 | dataset is also used in the notebook above. 157 | 158 | TODO: add more info and script on how to preprocess the data and link to huggingface repo. 159 | Talk about data quality issues. 160 | 161 | 162 | ## Training Process and Colab Hints: 163 | 164 | If you want to train the text-to-img model I highly recommend getting at least Colab Pro or even Colab 165 | Pro+ - it's going to be hard to train the model on a K80 GPU, unfortunately. NOTE: Colab will 166 | change its setup and introduce credits at the end of September - I will update this. 167 | 168 | Setting up this training workflow on Google Colab wasn't too bad. My approach has been 169 | very low tech and Google Drive played a large role. Basically, at the end of every epoch I save the model and the 170 | generated images on a small validation set (50-100 images) to Drive. 171 | 172 | This has a few advantages: 173 | - If I get kicked off my Colab instance the model is not lost and I just need to restart the instance 174 | - I can keep a record of image generation quality by epoch and go back to check and compare models 175 | - Important point here - make sure to use the _same_ random seed for every epoch - this controls for the randomness 176 | - Drive saves the past 100 versions of a file so I can always use a past model version within the past 100 epochs. 177 | - This is important since there is some variability in image quality from epoch to epoch 178 | - It's low tech and you don't have to learn a new platform like wandb.
179 | - Reading and saving data to and from Drive in Colab is _very_ fast. 180 | 181 | I have slowly moved some data/models to huggingface but this is WIP. 182 | 183 | #### GPU Speed: 184 | In terms of speed the GPUs go as follows: 185 | `A100>V100>P100>T4>K80` with the A100 being the fastest and every subsequent GPU being roughly twice as slow as 186 | the one before it for training (e.g. P100 is about 4x slower than A100). While I did get the A100 a few times 187 | the sweet spot was really V100/P100 on Colab Pro+ since the risk of being timed out decreased. With Colab Pro+ ($50/month) I managed to train on V100/P100 continuously for 12-24 hours at a time. 188 | 189 | 190 | 191 | ### Validation: 192 | 193 | I'm not an expert here, but generally the validation of generative models 194 | is still an open question. There are metrics like Inception Score, FID, and KID that measure 195 | whether the distribution of generated images is "close" to the training distribution in some way. 196 | The main issue with all of these metrics, however, is that a model that simply memorizes the training data 197 | will have a perfect score - so they don't account for overfitting. They are also fairly hard to understand, need large 198 | sample sizes, and are computationally intensive. For all these reasons I have chosen not to use them for now. 199 | 200 | Instead I have focused on analyzing the visual quality of generated images by uhm.. looking at them. This can quickly 201 | devolve into a tea-leaf reading exercise, however. To combat this, one can come up with different strategies to test for 202 | quality and diversity. For example, sampling from both generated and ground truth images and looking at them together 203 | (either side by side or permuted) is a reasonable way to test for sample quality. 204 | 205 | To test for generalization I have mostly focused on interpolations in both the CLIP space and the 206 | random normal latent space.
Ideally, as you move from embedding to embedding, you want the generated images 207 | along the path to be meaningful in some way. 208 | 209 | CLIP interpolation: "A lake in the forest in the summer" -> "A lake in the forest in the winter" 210 | 211 | image 212 | 213 | 214 | Does the model memorize the training data? This is an important question that has 215 | lots of implications. First of all, the models above don't have the capacity to memorize _all_ 216 | of the training data. For example: the model is about 150 MB but is trained on about 217 | 8GB of data. Second, it might not be in the model's best interest to memorize things. 218 | After digging around a bit in the predictions on the training data, I did find _one_ example where 219 | the model shamelessly copies a training example. Note that this is because the image appears many times 220 | in the training data. 221 | 222 | 223 | 224 | ### Examples: 225 | 226 | Prompt: `An Italian Villaga Painted by Picasso` 227 | 228 | image 229 | 230 | `City at night` 231 | 232 | image 233 | 234 | `Photograph of young woman in a field of flowers, bokeh` 235 | 236 | image 237 | 238 | `Street on an island in Greece` 239 | 240 | image 241 | 242 | `A Mountain Lake in the spring at sunset` 243 | 244 | image 245 | 246 | `A man in a suit in the field in wintertime` 247 | 248 | image 249 | 250 | 251 | CLIP interpolation: "A minimalist living room" -> "A Field in springtime, painting" 252 | 253 | image 254 | 255 | CLIP interpolation: "A lake in the forest in the summer" -> "A lake in the forest in the winter" 256 | 257 | image 258 | 259 | 260 | -------------------------------------------------------------------------------- /guided_diffusion/configs/base_model.yaml: -------------------------------------------------------------------------------- 1 | train_model: True 2 | epochs: 50 3 | file_name: "fashion_mnist" 4 | 5 | channels: 64 6 | channel_multiplier: [1, 2, 3] 7 | block_depth: 2 8 | emb_size: 512 9 | num_classes: 12 10 |
attention_levels: [0, 1, 0] 11 | embedding_dims: 32 12 | embedding_max_frequency: 1000.0 13 | 14 | precomputed_embedding: False 15 | save_in_drive: False 16 | 17 | batch_size: 64 18 | num_imgs: 64 19 | row: 8 20 | learning_rate: 0.0003 21 | 22 | class_guidance: 1 23 | MODEL_NAME: "model_test" 24 | from_scratch: True 25 | data_dir: "/content/diffusion_model_aesthetic_keras" -------------------------------------------------------------------------------- /guided_diffusion/configs/cifar10_model.yaml: -------------------------------------------------------------------------------- 1 | train_model: True 2 | epochs: 50 3 | file_name: "cifar10" 4 | 5 | channels: 64 6 | channel_multiplier: [1, 2, 3] 7 | block_depth: 2 8 | emb_size: 512 9 | num_classes: 12 10 | attention_levels: [0, 1, 0] 11 | embedding_dims: 32 12 | embedding_max_frequency: 1000.0 13 | 14 | precomputed_embedding: False 15 | save_in_drive: False 16 | 17 | batch_size: 64 18 | num_imgs: 64 19 | row: 8 20 | learning_rate: 0.0003 21 | 22 | class_guidance: 3 23 | MODEL_NAME: "model_cifar" 24 | from_scratch: True 25 | data_dir: "/content/diffusion_model_aesthetic_keras" -------------------------------------------------------------------------------- /guided_diffusion/configs/clip_laion.yaml: -------------------------------------------------------------------------------- 1 | train_model: True 2 | epochs: 3 3 | file_name: "from_huggingface" #data stored in huggingface 4 | 5 | channels: 96 6 | channel_multiplier: [1, 2, 3, 4] 7 | block_depth: 2 8 | emb_size: 512 9 | num_classes: 12 10 | attention_levels: [0, 0, 1, 0] 11 | embedding_dims: 32 12 | embedding_max_frequency: 1000.0 13 | 14 | precomputed_embedding: True 15 | save_in_drive: False 16 | 17 | batch_size: 64 18 | num_imgs: 36 19 | learning_rate: 0.0003 20 | 21 | class_guidance: 4 22 | MODEL_NAME: "model_test_aesthetic" 23 | from_scratch: True 24 | data_dir: "/content/diffusion_model_aesthetic_keras" 25 | 
# --------------------------------------------------------------------------------
# /guided_diffusion/denoiser.py:
# --------------------------------------------------------------------------------
import math

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model


def attention(qkv):
    """Dot-product attention used by the self-/cross-attention blocks below.

    Args:
        qkv: ``[q, k, v]`` rank-3 tensors ``(batch, positions, channels)`` as
            produced by the Reshape layers in this module. ``q`` and ``k`` must
            share their channel dimension; ``q`` and ``v`` must share their
            position dimension.

    Returns:
        ``softmax(k @ q^T) @ v``: one output row per row of ``k``.

    Note: the roles are swapped relative to the textbook ``softmax(q k^T) v``;
    rows of ``k`` act as queries and rows of ``q`` as keys. ``cross_attention``
    relies on this (its image features are passed as ``k``), so it is kept.
    """
    q, k, v = qkv
    # NOTE(review): scores are not scaled by 1/sqrt(channels) as in standard
    # attention; adding the scale would change the output of trained models.
    s = tf.matmul(k, q, transpose_b=True)  # [bs, len(k), len(q)]
    beta = tf.nn.softmax(s)  # attention weights over the rows of q / v
    o = tf.matmul(beta, v)  # [bs, len(k), channels(v)]
    return o


def _flatten_positions(t):
    """Reshape a ``(batch, h, w, c)`` tensor to ``(batch, h*w, c)``."""
    return layers.Reshape((t.shape[1] * t.shape[2], t.shape[3]))(t)


def _attend_and_project(qkv, orig_shape, filters):
    """Apply ``attention``, restore ``orig_shape``, then 1x1 out-projection + BatchNorm."""
    img = layers.Lambda(attention)(qkv)
    img = layers.Reshape(orig_shape)(img)

    # out_projection:
    img = layers.Conv2D(filters, kernel_size=1, padding="same")(img)
    img = layers.BatchNormalization()(img)
    return img


def spatial_attention(img):
    """Self-attention over the spatial positions of a feature map.

    Attention implementation a combination of:
    https://github.com/taki0112/Self-Attention-GAN-Tensorflow and
    https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/attention.py

    Args:
        img: 4-D Keras tensor ``(batch, h, w, filters)``.

    Returns:
        Tensor with the same ``(h, w, filters)`` shape as ``img``; callers in
        this module add it back to ``img`` as a residual.
    """
    filters = img.shape[3]
    orig_shape = (img.shape[1], img.shape[2], img.shape[3])
    img = layers.BatchNormalization()(img)

    # 1x1 projections; q and k are reduced to filters // 8 channels.
    q = layers.Conv2D(filters // 8, kernel_size=1, padding="same")(img)
    k = layers.Conv2D(filters // 8, kernel_size=1, padding="same")(img)
    v = layers.Conv2D(filters, kernel_size=1, padding="same")(img)
    k = _flatten_positions(k)
    q = _flatten_positions(q)
    v = _flatten_positions(v)

    return _attend_and_project([q, k, v], orig_shape, filters)


def cross_attention(img, text):
    """Attention from the positions of ``img`` to the positions of a conditioning map.

    Args:
        img: 4-D Keras tensor ``(batch, h, w, filters)``.
        text: 4-D conditioning tensor ``(batch, h', w', c')``; in ``get_network``
            this is the concatenated label + noise embedding map.

    Returns:
        Tensor with ``img``'s ``(h, w, filters)`` shape.
    """
    filters = img.shape[3]
    orig_shape = (img.shape[1], img.shape[2], img.shape[3])
    img = layers.BatchNormalization()(img)
    text = layers.BatchNormalization()(text)

    # projections (see ``attention`` for why the image features go into ``k``):
    q = layers.Conv2D(filters // 8, kernel_size=1, padding="same")(text)
    k = layers.Conv2D(filters // 8, kernel_size=1, padding="same")(img)
    v = layers.Conv2D(filters, kernel_size=1, padding="same")(text)

    q = _flatten_positions(q)
    k = _flatten_positions(k)
    v = _flatten_positions(v)

    return _attend_and_project([q, k, v], orig_shape, filters)


### The sinusoidal_embedding, ResidualBlock, Down/UP Block taken from
### https://github.com/keras-team/keras-io/blob/master/examples/generative/ddim.py
### Only change is adding self/cross attention:

def sinusoidal_embedding(x, embedding_min_frequency=1.0, embedding_max_frequency=1000.0,
                         embedding_dims=32):
    """Sinusoidal embedding of the noise variance (keras-io DDIM example).

    Args:
        x: tensor whose last axis has size 1 (``get_network`` passes
            ``(batch, 1, 1, 1)``).
        embedding_min_frequency: lowest frequency (float).
        embedding_max_frequency: highest frequency (float); frequencies are
            log-spaced between the two.
        embedding_dims: size of the output's last axis (half sines, half cosines).

    Returns:
        ``sin`` and ``cos`` of ``2*pi*f*x`` for ``embedding_dims // 2``
        frequencies ``f``, concatenated along the last axis.
    """
    frequencies = tf.exp(
        tf.linspace(
            tf.math.log(embedding_min_frequency),
            tf.math.log(embedding_max_frequency),
            embedding_dims // 2,
        )
    )
    angular_speeds = 2.0 * math.pi * frequencies
    embeddings = tf.concat(
        [tf.sin(angular_speeds * x), tf.cos(angular_speeds * x)], axis=-1
    )
    return embeddings


def ResidualBlock(width):
    """Return ``apply(x)``: BN -> 3x3 conv (swish) -> 3x3 conv, plus a skip connection.

    The skip is ``x`` itself when it already has ``width`` channels, otherwise a
    1x1 convolution of ``x``.
    """

    def apply(x):
        input_width = x.shape[3]
        if input_width == width:
            residual = x
        else:
            residual = layers.Conv2D(width, kernel_size=1)(x)
        x = layers.BatchNormalization(center=False, scale=False)(x)
        x = layers.Conv2D(
            width, kernel_size=3, padding="same", activation=keras.activations.swish
        )(x)
        # intermediate layer ... add mlp embedding of class plus timestamp here?
        x = layers.Conv2D(width, kernel_size=3, padding="same")(x)
        x = layers.Add()([x, residual])
        return x

    return apply


def _attention_block(x, emb_and_noise):
    """Residual self-attention on ``x`` followed by residual cross-attention to ``emb_and_noise``."""
    o = spatial_attention(x)
    x = layers.Add()([x, o])
    cross_att = cross_attention(x, emb_and_noise)
    x = layers.Add()([x, cross_att])
    return x


def DownBlock(width, block_depth, use_self_attention=False):
    """Return ``apply([x, skips, emb_and_noise])`` for one encoder level.

    Runs ``block_depth`` residual blocks (each optionally followed by self- and
    cross-attention), appends each result to ``skips`` (mutated in place) and
    returns ``x`` average-pooled by 2.
    """

    def apply(x):
        x, skips, emb_and_noise = x
        for _ in range(block_depth):
            x = ResidualBlock(width)(x)
            if use_self_attention:
                x = _attention_block(x, emb_and_noise)
            skips.append(x)
        x = layers.AveragePooling2D(pool_size=2)(x)
        return x

    return apply


def UpBlock(width, block_depth, use_self_attention=False):
    """Return ``apply([x, skips, emb_and_noise])`` for one decoder level.

    Upsamples ``x`` by 2, then ``block_depth`` times concatenates the most
    recent skip (popped from ``skips``), applies a residual block and optionally
    self- and cross-attention.
    """

    def apply(x):
        x, skips, emb_and_noise = x
        x = layers.UpSampling2D(size=2, interpolation="bilinear")(x)
        for _ in range(block_depth):
            x = layers.Concatenate()([x, skips.pop()])
            x = ResidualBlock(width)(x)
            if use_self_attention:
                x = _attention_block(x, emb_and_noise)
        return x

    return apply


def get_network(image_size, widths, block_depth, num_classes,
                num_channels, emb_size, attention_levels,
                precomputed_embedding=True):
    """Build the conditional residual U-Net denoiser.

    Args:
        image_size: height/width of the square input images.
        widths: channel width per level; ``widths[-1]`` is the bottleneck.
        block_depth: residual blocks per level.
        num_classes: Embedding vocabulary size (used only when
            ``precomputed_embedding`` is False).
        num_channels: image channels (input and output).
        emb_size: size of the label embedding.
        attention_levels: per-level flags (indexed 0..len(widths)-1); a truthy
            entry enables self- and cross-attention at that level.
        precomputed_embedding: if True the label input is an ``emb_size`` vector
            (e.g. CLIP), otherwise a single integer id.

    Returns:
        ``Model([noisy_images, noise_variances, input_label], x)`` named
        "residual_unet", where ``x`` has ``num_channels`` channels.
    """
    noisy_images = keras.Input(shape=(image_size, image_size, num_channels))

    noise_variances = keras.Input(shape=(1, 1, 1))
    e = layers.Lambda(sinusoidal_embedding)(noise_variances)
    e = layers.UpSampling2D(size=image_size, interpolation="nearest")(e)

    if precomputed_embedding:
        input_label = layers.Input(shape=(emb_size,))  # CLIP/glove embedding.
        emb_label = layers.Dense(emb_size // 2)(input_label)
        emb_label = layers.Reshape((1, 1, emb_size // 2))(emb_label)
    else:
        input_label = layers.Input(shape=(1,))  # label/word - integer encoded
        # NOTE(review): ids fed here must be < num_classes, including the 0 used
        # as the unconditional label -- confirm against the config.
        emb_label = layers.Embedding(input_dim=num_classes, output_dim=emb_size)(input_label)
        emb_label = layers.Reshape((1, 1, emb_size))(emb_label)

    emb_label = layers.UpSampling2D(size=image_size, interpolation="nearest")(emb_label)

    x = layers.Conv2D(widths[0], kernel_size=1)(noisy_images)
    x = layers.Concatenate()([x, e, emb_label])

    # Conditioning map for cross-attention; resized in step with x at every level.
    emb_and_noise = layers.Concatenate()([emb_label, e])
    emb_and_noise = layers.BatchNormalization()(emb_and_noise)

    skips = []
    for level, width in enumerate(widths[:-1]):
        use_self_attention = bool(attention_levels[level])
        x = DownBlock(width, block_depth, use_self_attention)([x, skips, emb_and_noise])
        emb_and_noise = layers.AveragePooling2D()(emb_and_noise)

    bottleneck_level = len(widths) - 1
    for _ in range(block_depth):
        x = ResidualBlock(widths[-1])(x)
        if bool(attention_levels[bottleneck_level]):
            x = _attention_block(x, emb_and_noise)

    for level in reversed(range(len(widths) - 1)):
        emb_and_noise = layers.UpSampling2D(size=2, interpolation="bilinear")(emb_and_noise)
        use_self_attention = bool(attention_levels[level])
        x = UpBlock(widths[level], block_depth, use_self_attention)([x, skips, emb_and_noise])

    x = layers.Conv2D(num_channels, kernel_size=1, kernel_initializer="zeros",
                      activation="linear"
                      )(x)

    return Model([noisy_images, noise_variances, input_label], x, name="residual_unet")
# --------------------------------------------------------------------------------
# /guided_diffusion/diffuser.py:
# --------------------------------------------------------------------------------
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from utils import imshow, plot_images


def dynamic_thresholding(img, perc=99.5):
    """Clip ``img`` to +/- s and divide by s (dynamic thresholding a la Imagen).

    s is the ``perc``-th percentile of ``|img|`` but never less than 1, so arrays
    already inside [-1, 1] come back numerically unchanged.
    """
    s = np.percentile(np.abs(img.ravel()), perc)
    s = np.max([s, 1])
    img = img.clip(-s, s) / s

    return img


class Diffuser:
    """Classifier-free-guided sampler around a denoiser that predicts clean images.

    Args:
        denoiser: Keras model called as ``predict([x_t, noise_level, label])``.
        class_guidance: classifier-free guidance weight applied in ``predict_x_zero``.
        diffusion_steps: sets the spacing of the noise schedule.
        perc_thresholding: percentile for ``dynamic_thresholding``; falsy disables it.
        batch_size: batch size passed to ``denoiser.predict``.
    """

    def __init__(self, denoiser, class_guidance, diffusion_steps, perc_thresholding=99.5, batch_size=64):
        self.denoiser = denoiser
        self.class_guidance = class_guidance
        self.diffusion_steps = diffusion_steps
        # Decreasing schedule 1 - t**(1/3) over t in [0.0001, 0.99) with step 1/diffusion_steps.
        # TODO: parametrize this better:
        self.noise_levels = 1 - np.power(np.arange(0.0001, 0.99, 1 / self.diffusion_steps), 1 / 3)
        self.noise_levels[-1] = 0.01
        self.perc_thresholding = perc_thresholding
        self.batch_size = batch_size

    def predict_x_zero(self, x_t, label, noise_level):
        """Predict original image based on noisy image (or matrix of noisy images) at noise level plus conditional label"""

        # we use 0 for the unconditional embedding:
        num_imgs = len(x_t)
        label_empty_ohe = np.zeros(shape=label.shape)

        # predict x0:
        noise_in = np.array([noise_level] * num_imgs)[:, None, None, None]

        # TODO: can we do some of this in tensorflow?
        # concatenate the conditional and unconditional inputs to speed inference:
        nn_inputs = [np.vstack([x_t, x_t]),
                     np.vstack([noise_in, noise_in]),
                     np.vstack([label, label_empty_ohe])]

        x0_pred = self.denoiser.predict(nn_inputs, batch_size=self.batch_size)

        x0_pred_label = x0_pred[:num_imgs]
        x0_pred_no_label = x0_pred[num_imgs:]

        # classifier free guidance:
        x0_pred = self.class_guidance * x0_pred_label + (1 - self.class_guidance) * x0_pred_no_label

        if self.perc_thresholding:
            # clip the prediction using dynamic thresholding a la Imagen:
            x0_pred = dynamic_thresholding(x0_pred, perc=self.perc_thresholding)

        return x0_pred

    def reverse_diffusion(self, seeds, label, show_img=False):
        """Reverse Guided Diffusion on a matrix of random images (seeds).

        Args:
            seeds: starting noise images.
            label: conditioning labels/embeddings, one per seed.
            show_img: plot the intermediate x0 prediction after every step.

        Returns:
            The x0 prediction of the final step.
        """
        new_img = seeds

        for i in tqdm(range(len(self.noise_levels) - 1)):

            curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]

            # predict original denoised image:
            x0_pred = self.predict_x_zero(new_img, label, curr_noise)

            # new image at next_noise level is a weighted average of old image and predicted x0:
            new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise

            if show_img:
                # plt.subplot requires an integer grid size.
                plot_images(x0_pred, nrows=int(np.sqrt(len(new_img))),
                            save_name=str(i),
                            size=12)
                plt.show()

        return x0_pred

    def reverse_diffusion_masked(self, seeds, label_ohe, show_img=False, masked_imgs=None, mask=None, u=1):
        """Reverse Guided Diffusion on a matrix of random images (seeds) with a mask.

        Can be used for in/outpainting.
        Based on the algorithm from RePaint: https://github.com/andreas128/RePaint

        Args:
            seeds: starting noise images.
            label_ohe: conditioning labels/embeddings, one per seed.
            show_img: plot the intermediate x0 prediction after every step.
            masked_imgs: images providing the known pixels (None disables masking).
            mask: 1 where pixels are taken from ``masked_imgs``, 0 where they are generated.
            u: RePaint resampling count per step (re-noise and repeat ``u - 1`` times).

        Returns:
            The x0 prediction of the final step.
        """
        new_img = seeds

        for i in tqdm(range(len(self.noise_levels) - 1)):

            curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]  # alpha1, alpha2

            for j in range(1, u + 1):

                # predict original denoised image:
                x0_pred = self.predict_x_zero(new_img, label_ohe, curr_noise)

                # new image at next_noise level is a weighted average of old image and predicted x0:
                new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise

                if masked_imgs is not None:
                    # let the last 20% of steps be done without masking.
                    if i <= int(self.diffusion_steps * 0.8):
                        # known part, noised to the next_noise level:
                        new_img_known = (1 - next_noise) * masked_imgs + np.random.normal(
                            0, 1, size=new_img.shape) * next_noise

                        # mask == 1 keeps the known pixels, mask == 0 keeps the generated ones.
                        new_img = new_img_known * mask + new_img * (1 - mask)
                    else:
                        break

                if j != u:
                    # noise the image back to the previous (curr_noise) level:
                    s = (1 - curr_noise) / (1 - next_noise)
                    new_img = s * new_img + np.sqrt(curr_noise ** 2 - s ** 2 * next_noise ** 2) * np.random.normal(
                        0, 1, size=new_img.shape)

            if show_img:
                # plt.subplot requires an integer grid size.
                plot_images(x0_pred, nrows=int(np.sqrt(len(new_img))),
                            save_name=str(i),
                            size=12)
                plt.show()

        return x0_pred
# --------------------------------------------------------------------------------
# /guided_diffusion/img_load_utils.py:
# --------------------------------------------------------------------------------
# !pip install img2dataset
# !pip install datasets

from datasets import load_dataset
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from img2dataset import download
import os
import matplotlib.pyplot as plt
import glob
import cv2
from tqdm import tqdm


def clean_df(df):
    """Filter a LAION-style metadata frame down to well-captioned, near-square images.

    Reads the columns HEIGHT, WIDTH, similarity, URL, TEXT and AESTHETIC_SCORE,
    shows diagnostic plots and prints the frame shape after each filter. The
    input frame is not modified; a filtered frame with extra ``pixels``,
    ``ratio`` and ``word_counts`` columns is returned.
    """
    # .assign returns a new frame instead of adding columns to the caller's df.
    df = df.assign(pixels=df["HEIGHT"] * df["WIDTH"], ratio=df["HEIGHT"] / df["WIDTH"])
    x = df["similarity"]
    x[x < 1].hist(bins=100)
    plt.show()
    df["ratio"].quantile(np.arange(0, 1, 0.01)).plot()
    plt.show()
    df["pixels"].quantile(np.arange(0, 1, 0.01)).plot()
    plt.show()

    df = df[df["similarity"] >= 0.3]
    print(df.shape)

    # only pick the first URL - there are a lot of duplicate images:
    df = df.groupby("URL").first().reset_index()

    df = df.drop_duplicates(subset=["TEXT", "WIDTH", "HEIGHT"])
    print(df.shape)

    # remove huge images for faster download:
    df = df[df["pixels"] <= 1024 * 1024]
    print(df.shape)

    # remove images that aren't close to being square - otherwise faces get cropped.
    df = df[df["ratio"] > 0.3]
    print(df.shape)
    df = df[df["ratio"] < 2]
    print(df.shape)

    df = df[df["AESTHETIC_SCORE"] > 5.5]
    print(df.shape)

    # Count caption words from a stop-word-free vocabulary of terms found in >= 25 captions.
    vectorizer = CountVectorizer(min_df=25, stop_words="english")
    text = df["TEXT"]
    vectorizer.fit(text)

    word_counts = vectorizer.transform(text)
    # .assign on the filtered frame avoids pandas' SettingWithCopyWarning.
    df = df.assign(word_counts=np.asarray(word_counts.sum(1)).ravel())

    df["word_counts"].value_counts().sort_index().iloc[:30].plot(kind="bar")

    df = df[df["word_counts"] >= 1.0]
    print(df.shape)
    df = df[df["word_counts"] <= 35.0]
    print(df.shape)

    return df


def download_data(url_path="df", output_dir=None):
    """Download the images listed in a parquet file with img2dataset (64px center crops).

    Args:
        url_path: parquet file with ``URL`` and ``TEXT`` columns.
        output_dir: destination folder; defaults to ``os.path.abspath("output_dir")``.
    """
    if output_dir is None:
        output_dir = os.path.abspath("output_dir")
    download(
        processes_count=8,
        thread_count=16,
        url_list=url_path,
        image_size=64,
        output_folder=output_dir,
        output_format="files",
        input_format="parquet",
        url_col="URL",
        caption_col="TEXT",
        enable_wandb=False,
        number_sample_per_shard=10000,
        distributor="multiprocessing",
        resize_mode="center_crop",
    )


def get_imgs_and_captions(output_dir="/content/output_dir"):
    """Load the downloaded ``*.jpg`` images (as RGB) together with their captions.

    Each image is paired with the first line of the ``.txt`` file sharing its
    name, so a missing or unreadable file cannot shift the pairing; such
    images are skipped.

    Returns:
        ``(img_array, caption_list)`` in sorted image-path order.
    """
    imgs_files = sorted(glob.glob(os.path.join(output_dir, "*", "*.jpg")))

    caption_list = []
    img_list = []
    for img_path in tqdm(imgs_files):
        txt_path = os.path.splitext(img_path)[0] + ".txt"
        if not os.path.exists(txt_path):
            continue  # no caption for this image
        img = cv2.imread(img_path)
        if img is None:
            continue  # unreadable image
        with open(txt_path, encoding="utf-8") as f:
            caption_list.append(f.readline())
        img_list.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    img_array = np.array(img_list)
    del img_list
    return img_array, caption_list


def get_dataset_from_huggingface(hugging_face_dataset_name):
    """Load the ``train`` split of a Hugging Face dataset as a pandas DataFrame."""
    dataset = load_dataset(hugging_face_dataset_name, split="train")
    df = dataset.to_pandas()

    return df


def build_dataset_and_save(hugging_face_dataset_name, sample_size=120000, seed=1,
                           save_dir="/content/diffusion_model_aesthetic_keras"):
    """Clean, sample and download a captioned image dataset, then save it as npy/csv.

    Writes ``captions.csv`` and ``imgs.npy`` into ``save_dir`` (plus the
    intermediate ``df`` parquet file and ``output_dir`` folder in the cwd).
    """
    df = get_dataset_from_huggingface(hugging_face_dataset_name)

    df = clean_df(df)
    # note that this keeps the original indexes from the data; capped so a
    # smaller cleaned frame does not make sampling without replacement fail.
    df.sample(min(sample_size, len(df)), random_state=seed).to_parquet("df")
    output_dir = os.path.abspath("output_dir")
    download_data(url_path="df", output_dir=output_dir)
    train_data, caption_list = get_imgs_and_captions(output_dir)
    pd.Series(caption_list).to_csv(os.path.join(save_dir, "captions.csv"), index=None)
    np.save(os.path.join(save_dir, "imgs.npy"), train_data)
# --------------------------------------------------------------------------------
# /guided_diffusion/trainer.py:
# --------------------------------------------------------------------------------
import os
import yaml
import numpy as np
import pandas as pd

from tensorflow.keras.datasets import mnist, fashion_mnist, cifar10, cifar100
from keras.utils import plot_model
from tensorflow import keras
from matplotlib import pyplot as plt

from denoiser import get_network
from utils import batch_generator, plot_images, get_data, preprocess
from diffuser import Diffuser


class Trainer:
    """Config-driven training loop for the guided diffusion denoiser.

    Every top-level key of the YAML config becomes an instance attribute.
    """

    def __init__(self, config_file):
        # Load YAML config
        with open(config_file, 'r') as f:
            config = yaml.safe_load(f)

        # Unpack config into instance attributes
        for key, value in config.items():
            setattr(self, key, value)

        # Additional setups
        self._setup_paths_and_dirs()

    def _setup_paths_and_dirs(self):
        """Derive layer widths and data/model paths; create the model directory."""
        self.widths = [c * self.channels for c in self.channel_multiplier]
        self.captions_path = os.path.join(self.data_dir, "captions.csv")
        self.imgs_path = os.path.join(self.data_dir, "imgs.npy")
        self.embedding_path = os.path.join(self.data_dir, "embeddings.npy")

        if self.save_in_drive:
            from google.colab import drive
            drive.mount('/content/drive')
            drive_path = '/content/drive/MyDrive/'
            self.home_dir = os.path.join(drive_path, self.MODEL_NAME)
        else:
            self.home_dir = self.MODEL_NAME

        os.makedirs(self.home_dir, exist_ok=True)

        self.model_path = os.path.join(self.home_dir, self.MODEL_NAME + ".h5")

    def preprocess_data(self, train_data, train_label_embeddings):
        """Store the training arrays and derive image geometry and preview labels."""
        print(train_data.shape)
        self.train_data = train_data
        self.train_label_embeddings = train_label_embeddings
        self.image_size = train_data.shape[1]
        self.num_channels = train_data.shape[-1]
        self.row = int(np.sqrt(self.num_imgs))
        self.labels = self._get_labels(train_data, train_label_embeddings)

    def _get_labels(self, train_data, train_label_embeddings):
        """Labels for the preview grid.

        The first ``num_imgs`` precomputed embeddings, or integer ids 1..row with
        one id per grid row (``row * row`` labels).
        NOTE(review): the integer case matches ``num_imgs`` only when it is a
        perfect square.
        """
        if self.precomputed_embedding:
            return train_label_embeddings[:self.num_imgs]
        else:
            row_labels = np.array([[i] * self.row for i in np.arange(self.row)]).flatten()[:, None]
            return row_labels + 1

    def initialize_model(self):
        """Build and compile a new network, or load the saved one from ``model_path``."""
        if self.from_scratch:
            self.autoencoder = get_network(self.image_size,
                                           self.widths,
                                           self.block_depth,
                                           num_classes=self.num_classes,
                                           attention_levels=self.attention_levels,
                                           emb_size=self.emb_size,
                                           num_channels=self.num_channels,
                                           precomputed_embedding=self.precomputed_embedding)

            self.autoencoder.compile(optimizer="adam", loss="mae")
        else:
            self.autoencoder = keras.models.load_model(self.model_path)

    def data_checks(self, train_data):
        """Print/plot sanity checks: parameter count, pixel histogram, sample images, model graph."""
        print("Number of parameters is {0}".format(self.autoencoder.count_params()))
        pd.Series(train_data[:1000].ravel()).hist(bins=80)
        plt.show()
        print("Original Images:")
        plot_images(preprocess(train_data[:self.num_imgs]), nrows=int(np.sqrt(self.num_imgs)))
        plot_model(self.autoencoder, to_file=os.path.join(self.home_dir, 'model_plot.png'),
                   show_shapes=True, show_layer_names=True)
        print("Generating Images below:")

    def train(self):
        """Create the Diffuser and, when ``train_model`` is set, fit on ``batch_generator``.

        The number of sampling steps comes from an optional ``diffusion_steps``
        config key (default 35).
        """
        np.random.seed(100)
        self.rand_image = np.random.normal(0, 1, (self.num_imgs, self.image_size, self.image_size, self.num_channels))

        self.diffuser = Diffuser(self.autoencoder,
                                 class_guidance=self.class_guidance,
                                 diffusion_steps=getattr(self, "diffusion_steps", 35))

        if self.train_model:
            train_generator = batch_generator(self.autoencoder,
                                              self.model_path,
                                              self.train_data,
                                              self.train_label_embeddings,
                                              self.epochs,
                                              self.batch_size,
                                              self.rand_image,
                                              self.labels,
                                              self.home_dir,
                                              self.diffuser)

            self.autoencoder.optimizer.learning_rate.assign(self.learning_rate)

            # NOTE(review): the generator already loops over self.epochs passes and
            # has no steps_per_epoch, so Keras may consume it all in its first epoch.
            self.eval_nums = self.autoencoder.fit(
                x=train_generator,
                epochs=self.epochs
            )


def get_train_data(file_name, captions_path=None, imgs_path=None, embedding_path=None):
    """Load training images and their labels/embeddings.

    For a Keras dataset name (cifar10, cifar100, fashion_mnist, mnist) the
    training split is loaded with class ids shifted by +1 (0 is the
    unconditional label); grayscale sets get a channel axis and 2-D labels.
    Any other ``file_name`` loads ``imgs_path`` and ``embedding_path`` with
    ``np.load``. ``captions_path`` is accepted for compatibility but unused.
    """
    dataset_loaders = {
        "cifar10": cifar10.load_data,
        "cifar100": cifar100.load_data,
        "fashion_mnist": fashion_mnist.load_data,
        "mnist": mnist.load_data
    }

    if file_name in dataset_loaders:
        (train_data, train_label_embeddings), (_, _) = dataset_loaders[file_name]()

        # Add unconditional embedding
        train_label_embeddings = train_label_embeddings + 1

        if file_name in ["fashion_mnist", "mnist"]:
            train_data = train_data[:, :, :, None]
            train_label_embeddings = train_label_embeddings[:, None]

    else:
        train_data, train_label_embeddings = np.load(imgs_path), np.load(embedding_path)

    return train_data, train_label_embeddings


if __name__ == '__main__':

    trainer = Trainer('guided_diffusion/configs/base_model.yaml')

    train_data, train_label_embeddings = get_train_data(trainer.file_name,
                                                        captions_path=trainer.captions_path,
                                                        imgs_path=trainer.imgs_path,
                                                        embedding_path=trainer.embedding_path)
    trainer.preprocess_data(train_data, train_label_embeddings)
    trainer.initialize_model()
    trainer.data_checks(train_data)
    trainer.train()
# --------------------------------------------------------------------------------
# /guided_diffusion/utils.py:
# --------------------------------------------------------------------------------
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import os
from numpy.linalg import norm

import torch
from torch.utils.data import DataLoader


def get_text_encodings(prompts, model, device="cuda"):
    """Encode ``prompts`` with a CLIP model, 256 prompts at a time.

    Args:
        prompts: sequence of strings.
        model: CLIP model exposing ``encode_text`` (e.g. ``clip.load("ViT-B/32")``).
        device: torch device the tokens are moved to; must match ``model``.

    Returns:
        numpy array with the encodings of all batches concatenated along axis 0.
    """
    import clip

    data_loader = DataLoader(prompts, batch_size=256)

    all_encodings = []
    for labels in data_loader:
        text_tokens = clip.tokenize(labels, truncate=True).to(device)

        with torch.no_grad():
            text_encoding = model.encode_text(text_tokens)

        all_encodings.append(text_encoding)

    all_encodings = torch.cat(all_encodings).cpu().numpy()

    return all_encodings


def preprocess(array):
    """Scale pixel values from [0, 255] to float32 values in [-1, 1]."""
    array = array.astype("float32") / 255.0
    array = array * 2 - 1

    return array


#######
# TRAIN UTILS:
#######

def add_noise(array, mu=0, std=1):
    """Noise a batch of images at random per-image noise levels.

    Levels are ``|N(0, std)|`` samples below 3, divided by 3 (so in [0, 1));
    each image becomes ``image * (1 - level) + N(0, 1) * level``.
    ``mu`` is unused and kept only for backward compatibility.

    Returns:
        ``(noisy_data, noise_levels)`` with one level per image.
    """
    # TODO: have a better way to parametrize this:
    num = len(array)
    noise_levels = np.empty(0)
    # Rejection sampling; draw again in the rare case too few samples survive,
    # which would otherwise leave fewer levels than images.
    while len(noise_levels) < num:
        x = np.abs(np.random.normal(0, std, 2 * num))
        noise_levels = np.concatenate([noise_levels, x[x < 3] / 3])
    noise_levels = noise_levels[:num]

    signal_levels = 1 - noise_levels  # OR: np.sqrt(1-np.square(noise_levels))

    # reshape so that the multiplication makes sense:
    noise_level_reshape = noise_levels[:, None, None, None]
    signal_level_reshape = signal_levels[:, None, None, None]

    pure_noise = np.random.normal(0, 1, size=array.shape).astype("float32")
    noisy_data = array * signal_level_reshape + pure_noise * noise_level_reshape

    return noisy_data, noise_levels


def slerp(p0, p1, t):
    """Spherical interpolation between vectors ``p0`` and ``p1`` at fraction ``t``.

    Falls back to linear interpolation when the vectors are (anti-)parallel,
    where the slerp formula is 0/0.
    """
    # clip guards arccos against rounding just outside [-1, 1].
    cos_omega = np.clip(np.dot(p0 / norm(p0), p1 / norm(p1)), -1.0, 1.0)
    omega = np.arccos(cos_omega)
    so = np.sin(omega)
    if np.isclose(so, 0.0):
        return (1.0 - t) * p0 + t * p1
    return np.sin((1.0 - t) * omega) / so * p0 + np.sin(t * omega) / so * p1


def _legacy_global(name):
    """Return module global ``name``; some helpers historically read notebook-set globals."""
    try:
        return globals()[name]
    except KeyError:
        raise ValueError("{0!r} was not passed and is not set as a module global".format(name)) from None


def interpolate_two_points(num_points=100, image_size=None, num_channels=None):
    """Slerp between two random Gaussian images.

    Args:
        num_points: one image per t in ``np.arange(0.0, 1.0, 1 / num_points)``.
        image_size: square image size; read from a module global if omitted.
        num_channels: image channels; read from a module global if omitted.

    Returns:
        Array of shape ``(len(ts), image_size, image_size, num_channels)``.
    """
    if image_size is None:
        image_size = _legacy_global("image_size")
    if num_channels is None:
        num_channels = _legacy_global("num_channels")

    dims = image_size * image_size * num_channels
    pA = np.random.normal(0, 1, dims)
    pB = np.random.normal(0, 1, dims)

    ps = np.array([slerp(pA, pB, t) for t in np.arange(0.0, 1.0, 1 / num_points)])
    rand_image = ps.reshape(len(ps), image_size, image_size, num_channels)

    return rand_image


def imshow(img):
    """Show one image whose values lie in [-1, 1] (clipped, then mapped to [0, 1])."""
    def norm_0_1(img):
        return (img + 1) / 2

    # img here is between -1 and 1:
    if img.shape[-1] == 1:
        img = img.reshape(img.shape[0], img.shape[1])
    img = np.clip(img, -1, 1)
    plt.imshow(norm_0_1(img))


def plot_images(imgs, size=16, nrows=8, save_name=None):
    """Plot ``imgs`` on an ``nrows`` x ``nrows`` grid and show it.

    ``nrows`` may be a float (e.g. ``np.sqrt(n)``); it is truncated to int
    because ``plt.subplot`` requires integers. If ``save_name`` is given the
    figure is also saved to ``<save_name>.png`` at 200 dpi.
    """
    nrows = int(nrows)
    plt.rcParams["figure.figsize"] = (size, size)

    for i, img in enumerate(imgs):
        ax = plt.subplot(nrows, nrows, i + 1)
        imshow(img)
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    if save_name:
        save_to = "{0}.png".format(save_name)
        plt.savefig(save_to, dpi=200)
    plt.show()


def get_data(npz_file_path, prop=0.6, captions=False):
    """Load ``arr_0`` (images), ``arr_1`` (labels) and optionally ``arr_2`` (captions) from an npz.

    Images whose fraction of values >= 254 is ``prop`` or more are dropped,
    together with their labels/captions.
    """
    data = np.load(npz_file_path)

    if captions:
        train_data, train_label_embeddings, caption_list = data["arr_0"], data["arr_1"], data["arr_2"]
    else:
        train_data, train_label_embeddings = data["arr_0"], data["arr_1"]

    # eliminate if perc white pixels > 60% - not really used.
    white_pixels = (train_data >= 254).mean(axis=(1, 2, 3))
    mask = (white_pixels < prop)

    if captions:
        train_data, train_label_embeddings, caption_list = train_data[mask], train_label_embeddings[mask], caption_list[mask]
        return train_data, train_label_embeddings, caption_list
    else:
        train_data, train_label_embeddings = train_data[mask], train_label_embeddings[mask]
        return train_data, train_label_embeddings


def batch_generator(model, model_path, train_data, train_label_embeddings, epochs,
                    batch_size, rand_image, labels, home_dir, diffuser):
    """Yield ``([noisy_images, noise_levels, labels], clean_images)`` batches for ``epochs`` passes.

    At the start of every pass the model is saved to ``model_path`` and preview
    images sampled from ``rand_image``/``labels`` are saved to
    ``home_dir/<epoch>.png``. 15% of the labels in each batch are zeroed for
    classifier-free guidance. Indices left over at the end of a pass (fewer
    than ``batch_size``) start the first batch of the next pass.
    """
    indices = np.arange(len(train_data))
    batch = []
    epoch = 0
    print("Training for {0}".format(epochs))
    while epoch < epochs:
        print("saving model:")
        model.save(model_path)

        print(" Generating images:")
        diffuser.denoiser = model
        imgs = diffuser.reverse_diffusion(rand_image, labels)
        img_path = os.path.join(home_dir, str(epoch))
        plot_images(imgs, save_name=img_path, nrows=int(np.sqrt(len(imgs))))

        print("new epoch {0}".format(epoch))
        # it might be a good idea to shuffle your data before each epoch
        np.random.shuffle(indices)
        for i in indices:
            batch.append(i)
            if len(batch) == batch_size:
                tr_batch = train_data[batch].copy()
                tr_batch = preprocess(tr_batch)

                # random dropout for CFG:
                s = np.random.binomial(1, 0.15, size=batch_size).astype("bool")
                train_label_dropout = train_label_embeddings[batch].copy()
                train_label_dropout[s] = np.zeros(shape=train_label_embeddings.shape[1])

                # Add Noise to Images:
                noisy_train_data, noise_level_train = add_noise(tr_batch, mu=0, std=1)
                noise_level_train = noise_level_train[:, None, None, None]  # for correct dim

                yield [noisy_train_data, noise_level_train, train_label_dropout], tr_batch
                batch = []
        epoch += 1


def get_common_words(n, caption_list=None):
    """Return the ``n`` most frequent non-stop-words occurring in at least 10 captions.

    Args:
        n: number of words to return.
        caption_list: captions to count; read from a module global if omitted.

    Returns:
        One-column DataFrame of total counts indexed by word, sorted descending.
    """
    from sklearn.feature_extraction.text import CountVectorizer

    if caption_list is None:
        caption_list = _legacy_global("caption_list")

    vectorizer = CountVectorizer(min_df=10, stop_words="english")
    text = caption_list
    vectorizer.fit(text)

    word_counts = vectorizer.transform(text)
    word_counts = word_counts.sum(0)
    word_counts = pd.DataFrame(word_counts.T, index=vectorizer.get_feature_names_out())
    return word_counts.sort_values(0, ascending=False).head(n)
# --------------------------------------------------------------------------------