├── LICENSE ├── README.md ├── __init__.py ├── captions.txt ├── docs ├── Makefile ├── build │ ├── doctrees │ │ ├── environment.pickle │ │ ├── index.doctree │ │ └── minimagen.doctree │ └── html │ │ ├── .buildinfo │ │ ├── _images │ │ ├── posterior_mean.png │ │ ├── posterior_variance.png │ │ ├── q_posterior.png │ │ └── q_sample.png │ │ ├── _sources │ │ ├── index.rst.txt │ │ └── minimagen.rst.txt │ │ ├── _static │ │ ├── _sphinx_javascript_frameworks_compat.js │ │ ├── basic.css │ │ ├── clf_free_guidance.png │ │ ├── css │ │ │ ├── badge_only.css │ │ │ ├── fonts │ │ │ │ ├── Roboto-Slab-Bold.woff │ │ │ │ ├── Roboto-Slab-Bold.woff2 │ │ │ │ ├── Roboto-Slab-Regular.woff │ │ │ │ ├── Roboto-Slab-Regular.woff2 │ │ │ │ ├── fontawesome-webfont.eot │ │ │ │ ├── fontawesome-webfont.svg │ │ │ │ ├── fontawesome-webfont.ttf │ │ │ │ ├── fontawesome-webfont.woff │ │ │ │ ├── fontawesome-webfont.woff2 │ │ │ │ ├── lato-bold-italic.woff │ │ │ │ ├── lato-bold-italic.woff2 │ │ │ │ ├── lato-bold.woff │ │ │ │ ├── lato-bold.woff2 │ │ │ │ ├── lato-normal-italic.woff │ │ │ │ ├── lato-normal-italic.woff2 │ │ │ │ ├── lato-normal.woff │ │ │ │ └── lato-normal.woff2 │ │ │ └── theme.css │ │ ├── doctools.js │ │ ├── documentation_options.js │ │ ├── file.png │ │ ├── jquery-3.6.0.js │ │ ├── jquery.js │ │ ├── js │ │ │ ├── badge_only.js │ │ │ ├── html5shiv-printshiv.min.js │ │ │ ├── html5shiv.min.js │ │ │ └── theme.js │ │ ├── language_data.js │ │ ├── minus.png │ │ ├── plus.png │ │ ├── posterior_mean.png │ │ ├── posterior_mean_coeffs.png │ │ ├── posterior_variance.png │ │ ├── posterior_variance_box.png │ │ ├── pygments.css │ │ ├── q_posterior.png │ │ ├── q_sample.png │ │ ├── q_sample_reparam.png │ │ ├── searchtools.js │ │ ├── underscore-1.13.1.js │ │ ├── underscore.js │ │ └── x_tm1.png │ │ ├── genindex.html │ │ ├── index.html │ │ ├── minimagen.html │ │ ├── objects.inv │ │ ├── py-modindex.html │ │ ├── search.html │ │ └── searchindex.js ├── make.bat ├── requirements.txt └── source │ ├── _static │ ├── clf_free_guidance.png │ ├── file.png │ ├── minus.png │ ├── plus.png │ ├── posterior_mean.png │ ├── posterior_mean_coeffs.png │ ├── posterior_variance.png │ ├── posterior_variance_box.png │ ├── q_posterior.png │ ├── q_sample.png │ ├── q_sample_reparam.png │ └── x_tm1.png │ ├── conf.py │ ├── index.rst │ └── minimagen.rst ├── images ├── clf_free_guidance.png ├── conditioning_diagram.png ├── dynamic_threshold.mp4 ├── model_structure.png ├── posterior_mean.png ├── posterior_mean_coeffs.png ├── posterior_variance.png ├── posterior_variance_box.png ├── q_posterior.png ├── q_sample.png ├── q_sample_reparam.png ├── transformer_full.png └── x_tm1.png ├── inference.py ├── main.py ├── minimagen ├── Imagen.py ├── Unet.py ├── __init__.py ├── diffusion_model.py ├── generate.py ├── helpers.py ├── layers.py ├── t5.py └── training.py ├── parameters ├── imagen_params_20220816_165729.json ├── training_parameters_20220816_165729.txt ├── unet_0_params_20220816_165729.json └── unet_1_params_20220816_165729.json ├── requirements.txt ├── setup.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 AssemblyAI 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the 
Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MinImagen 2 | ### A Minimal implementation of the [Imagen](https://imagen.research.google/) text-to-image model. 3 | 4 |
5 | 6 |

7 | 8 |
9 | 10 | ### See [Build Your Own Imagen Text-to-Image Model](https://www.assemblyai.com/blog/build-your-own-imagen-text-to-image-model/) for a tutorial on how to build MinImagen. 11 | 12 | ### See [How Imagen Actually Works](https://www.assemblyai.com/blog/how-imagen-actually-works/) for a detailed explanation of Imagen's operating principles. 13 | 14 |
15 | 16 | Given an image caption, the text-to-image model **Imagen** will generate an image that reflects the scene described by the caption. The model is a [cascading diffusion model](https://arxiv.org/abs/2106.15282): a [T5 text encoder](https://arxiv.org/abs/1910.10683) produces a caption encoding that conditions a base image generator, and the output of the base generator is then passed through a sequence of super-resolution models. 17 | 18 | In particular, two notable contributions are the development of: 19 | 1. [**Noise Conditioning Augmentation**](https://www.assemblyai.com/blog/how-imagen-actually-works/#robust-cascaded-diffusion-models), which noises the low-resolution conditioning images fed to the super-resolution models, and 20 | 2. [**Dynamic Thresholding**](https://www.assemblyai.com/blog/how-imagen-actually-works/#dynamic-thresholding), which helps prevent image saturation at high [classifier-free guidance](https://www.assemblyai.com/blog/how-imagen-actually-works/#classifier-free-guidance) weights. 21 | 22 |
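Dynamic thresholding is simple to express in code. The snippet below is an illustrative sketch only (it is not taken from `minimagen`'s source, and the function name and the 0.9 percentile are assumptions for this example): at each sampling step, the predicted clean image is clipped to a per-sample percentile of its absolute pixel values and rescaled back into `[-1, 1]`.

```python
import torch

def dynamic_threshold(x0_hat: torch.Tensor, percentile: float = 0.9) -> torch.Tensor:
    # x0_hat: predicted denoised images, shape (batch, channels, height, width)
    # Per-sample percentile of the absolute pixel values
    s = torch.quantile(x0_hat.abs().flatten(start_dim=1), percentile, dim=1)
    # Never threshold below the static [-1, 1] range
    s = s.clamp(min=1.0).view(-1, 1, 1, 1)
    # Clip to [-s, s], then divide by s so the result lands back in [-1, 1]
    return x0_hat.clamp(-s, s) / s
```

Static thresholding would simply clamp to `[-1, 1]`; the dynamic version instead pushes saturated pixels inward, which is what prevents the washed-out images otherwise seen at high guidance weights.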
23 | 24 | **N.B. - This project is intended only for educational purposes, to demonstrate how Diffusion Models are implemented and incorporated into text-to-image models. Many components of the network that are not essential for these educational purposes have been stripped out for simplicity. For a full-fledged implementation, check out Phil Wang's repo (see the attribution note below).** 25 | 26 |
27 | 28 | ## Table of Contents 29 | - [Attribution Note](#attribution-note) 30 | - [Installation](#installation) 31 | - [Documentation](#documentation) 32 | - [Usage - Command Line](#usage---command-line) 33 | - [`main.py`](#mainpy) - training and image generation in sequence 34 | - [`train.py`](#trainpy) - training a MinImagen instance 35 | - [`inference.py`](#inferencepy) - generating images using a MinImagen instance 36 | - [Usage - Package](#usage---package) 37 | - [Training](#training) 38 | - [Image Generation](#image-generation) 39 | - [Modifying the Source Code](#modifying-the-source-code) 40 | - [Additional Resources](#additional-resources) 41 | - [Socials](#socials) 42 | 43 |
44 | 45 | ## Attribution Note 46 | This implementation is largely based on Phil Wang's [Imagen implementation](https://github.com/lucidrains/imagen-pytorch). 47 | 48 |
49 | 50 | ## Installation 51 | To install MinImagen, run the following command in the terminal: 52 | ```bash 53 | $ pip install minimagen 54 | ``` 55 | **Note that MinImagen requires Python 3.9 or higher.** 56 | 57 |
58 | 59 | ## Documentation 60 | See the [MinImagen Documentation](https://assemblyai-examples.github.io/MinImagen/) to learn more about the package. 61 | 62 |
63 | 64 | ## Usage - Command Line 65 | If you have cloned this repo (as opposed to just installing the `minimagen` package), you can use the provided scripts to get started with MinImagen. This repo can be cloned by running the following command in the terminal: 66 | 67 | ```bash 68 | $ git clone https://github.com/AssemblyAI-Examples/MinImagen.git 69 | ``` 70 | 71 |
72 | 73 | ### `main.py` 74 | For the most basic usage, simply enter the MinImagen directory and run the following in the terminal: 75 | ```bash 76 | $ python main.py 77 | ``` 78 | This will create a small MinImagen instance, train it on a tiny amount of data, and then use it to generate an image. 79 | 80 | After running the script, you will see a directory called `training_<TIMESTAMP>`. 81 | 1. This directory is called a *Training Directory* and is generated whenever a MinImagen instance is trained. 82 | 2. It contains the configuration information (`parameters` subdirectory) and the model checkpoints (`state_dicts` and `tmp` directories). 83 | 3. It also contains a `training_progress.txt` file that records training progress. 84 | 85 | You will also see a directory called `generated_images_<TIMESTAMP>`. 86 | 1. This directory contains a folder of images generated by the model (`generated_images`). 87 | 2. It also contains a `captions.txt` file, which documents the captions that were input to generate the images (the line index of a given caption corresponds to the image number in the `generated_images` folder). 88 | 3. Finally, this directory also contains `imagen_training_directory.txt`, which specifies the name of the Training Directory used to load the MinImagen instance and generate the images. 89 | 90 |
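For orientation, the two output directories look roughly like the following (an illustrative layout based on the description above; `<TIMESTAMP>` is the run's timestamp, and the exact contents may differ slightly between versions):

```
training_<TIMESTAMP>/
├── parameters/                    # MinImagen / U-Net configuration files
├── state_dicts/                   # model checkpoints
├── tmp/                           # temporary checkpoints saved during training
└── training_progress.txt          # record of training progress

generated_images_<TIMESTAMP>/
├── generated_images/              # the generated images
├── captions.txt                   # input captions (line index = image number)
└── imagen_training_directory.txt  # Training Directory used to load the MinImagen instance
```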
91 | 92 | ### `train.py` 93 | 94 | `main.py` simply runs `train.py` and `inference.py` in series, the former to train the model and the latter to generate the image. 95 | 96 | To train a model, simply run `train.py` and specify the relevant command line arguments. The [possible arguments](https://github.com/AssemblyAI-Examples/MinImagen/blob/d7de8350db17713fb630e127c010020820953872/minimagen/training.py#L178) are: 97 | 98 | - `--PARAMETERS` or `-p`, which specifies a directory containing the MinImagen configuration to use. It should be structured like a `parameters` subdirectory within a Training Directory (example in [`parameters`](https://github.com/AssemblyAI-Examples/MinImagen/tree/main/parameters)). 99 | - `--NUM_WORKERS` or `-n`, which specifies the number of workers to use for the DataLoaders. 100 | - `--BATCH_SIZE` or `-b`, which specifies the batch size to use during training. 101 | - `--MAX_NUM_WORDS` or `-mw`, which specifies the maximum number of words allowed in a caption. 102 | - `--IMG_SIDE_LEN` or `-s`, which specifies the final side length of the square images the MinImagen will output. 103 | - `--EPOCHS` or `-e`, which specifies the number of training epochs. 104 | - `--T5_NAME` or `-t5`, which specifies the name of the T5 encoder to use. 105 | - `--TRAIN_VALID_FRAC` or `-f`, which specifies the fraction of the dataset to use for training (vs. validation). 106 | - `--TIMESTEPS` or `-t`, which specifies the number of timesteps in the Diffusion Process. 107 | - `--OPTIM_LR` or `-lr`, which specifies the learning rate for the Adam optimizer. 108 | - `--ACCUM_ITER` or `-ai`, which specifies the number of batches to accumulate for gradient accumulation. 109 | - `--CHCKPT_NUM` or `-cn`, which specifies the interval (in batches) at which a temporary model checkpoint is created during training. 110 | - `--VALID_NUM` or `-vn`, which specifies the number of validation images to use. If `None`, the full amount from the train/valid split is used. This argument exists because, even with e.g. a 0.99 `--TRAIN_VALID_FRAC`, a prohibitively large number of images could still be left for validation with very large datasets. 111 | - `--RESTART_DIRECTORY` or `-rd`, which specifies the Training Directory from which to load a MinImagen instance when resuming training (an example is given below). A new Training Directory will be created for the run, leaving the previous Training Directory from which the checkpoint is loaded unperturbed. 112 | - `--TESTING` or `-test`, which is used to run the script with a small MinImagen instance and a small dataset for testing. 113 | 114 | For example, to run a small training using the provided example [`parameters`](https://github.com/AssemblyAI-Examples/MinImagen/tree/main/parameters) folder, run the following in the terminal: 115 | 116 | ```bash 117 | python train.py --PARAMETERS ./parameters --BATCH_SIZE 2 --TIMESTEPS 25 --TESTING 118 | ``` 119 | After execution, you will see a new `training_<TIMESTAMP>` [Training Directory](#training-directory) that contains the files from the training, as [listed above](#training-directory). 120 | 121 |
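To resume training from an existing run, point `--RESTART_DIRECTORY` at its Training Directory. A sketch (the timestamp is a placeholder, and which additional arguments still need to be supplied alongside `--RESTART_DIRECTORY` depends on the configuration saved in that directory):

```bash
python train.py --RESTART_DIRECTORY training_<TIMESTAMP> --EPOCHS 5
```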
122 | 123 | ### `inference.py` 124 | 125 | To generate images using a model from a [Training Directory](#training-directory), we can use `inference.py`. Simply run `inference.py` and specify the relevant command line arguments. The possible arguments are: 126 | 127 | - `--TRAINING_DIRECTORY` or `-d`, which specifies the training directory from which to load the MinImagen instance for inference. 128 | - `--CAPTIONS` or `-c`, which specifies either (a) a single caption to generate an image for, or (b) a filepath to a `.txt` file that contains a list of captions to generate images for, where each caption is on a new line. 129 | 130 | For example, to generate images for the example captions provided in [`captions.txt`](https://github.com/AssemblyAI-Examples/MinImagen/blob/main/captions.txt) using the model generated by the training command above, simply run 131 | 132 | ```bash 133 | python inference.py --CAPTIONS captions.txt --TRAINING_DIRECTORY training_<TIMESTAMP> 134 | ``` 135 | 136 | where `<TIMESTAMP>` is replaced with the appropriate value from your training run. 137 | 138 |
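Since `--CAPTIONS` also accepts a single caption directly, a one-off generation can be run as follows (an illustrative command using a caption from the included `captions.txt`; again, replace `<TIMESTAMP>` with your own value):

```bash
python inference.py --CAPTIONS "a happy dog" --TRAINING_DIRECTORY training_<TIMESTAMP>
```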
139 | 140 | ## Usage - Package 141 | 142 | ### Training 143 | 144 | A minimal training script using the `minimagen` package is shown below. See [`train.py`](https://github.com/AssemblyAI-Examples/MinImagen/blob/main/train.py) for a more built-up version of the below code. 145 | 146 | ```python 147 | import os 148 | from datetime import datetime 149 | 150 | import torch.utils.data 151 | from torch import optim 152 | 153 | from minimagen.Imagen import Imagen 154 | from minimagen.Unet import Unet, Base, Super, BaseTest, SuperTest 155 | from minimagen.generate import load_minimagen, load_params 156 | from minimagen.t5 import get_encoded_dim 157 | from minimagen.training import get_minimagen_parser, ConceptualCaptions, get_minimagen_dl_opts, \ 158 | create_directory, get_model_size, save_training_info, get_default_args, MinimagenTrain, \ 159 | load_testing_parameters 160 | 161 | # Get device 162 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 163 | 164 | # Command line argument parser 165 | parser = get_minimagen_parser() 166 | args = parser.parse_args() 167 | 168 | # Create training directory 169 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 170 | dir_path = f"./training_{timestamp}" 171 | training_dir = create_directory(dir_path) 172 | 173 | # Replace some cmd line args to lower computational load. 174 | args = load_testing_parameters(args) 175 | 176 | # Load subset of Conceptual Captions dataset. 177 | train_dataset, valid_dataset = ConceptualCaptions(args, smalldata=True) 178 | 179 | # Create dataloaders 180 | dl_opts = {**get_minimagen_dl_opts(device), 'batch_size': args.BATCH_SIZE, 'num_workers': args.NUM_WORKERS} 181 | train_dataloader = torch.utils.data.DataLoader(train_dataset, **dl_opts) 182 | valid_dataloader = torch.utils.data.DataLoader(valid_dataset, **dl_opts) 183 | 184 | # Use small U-Nets to lower computational load. 185 | unets_params = [get_default_args(BaseTest), get_default_args(SuperTest)] 186 | unets = [Unet(**unet_params).to(device) for unet_params in unets_params] 187 | 188 | # Specify MinImagen parameters 189 | imagen_params = dict( 190 | image_sizes=(int(args.IMG_SIDE_LEN / 2), args.IMG_SIDE_LEN), 191 | timesteps=args.TIMESTEPS, 192 | cond_drop_prob=0.15, 193 | text_encoder_name=args.T5_NAME 194 | ) 195 | 196 | # Create MinImagen from UNets with specified imagen parameters 197 | imagen = Imagen(unets=unets, **imagen_params).to(device) 198 | 199 | # Fill in unspecified arguments with defaults to record complete config (parameters) file 200 | unets_params = [{**get_default_args(Unet), **i} for i in unets_params] 201 | imagen_params = {**get_default_args(Imagen), **imagen_params} 202 | 203 | # Get the size of the Imagen model in megabytes 204 | model_size_MB = get_model_size(imagen) 205 | 206 | # Save all training info (config files, model size, etc.) 207 | save_training_info(args, timestamp, unets_params, imagen_params, model_size_MB, training_dir) 208 | 209 | # Create optimizer 210 | optimizer = optim.Adam(imagen.parameters(), lr=args.OPTIM_LR) 211 | 212 | # Train the MinImagen instance 213 | MinimagenTrain(timestamp, args, unets, imagen, train_dataloader, valid_dataloader, training_dir, optimizer, timeout=30) 214 | ``` 215 | 216 | ### Image Generation 217 | 218 | A minimal inference script using the `minimagen` package is shown below. See [`inference.py`](https://github.com/AssemblyAI-Examples/MinImagen/blob/main/inference.py) for a more built-up version of the below code. 
219 | 220 | ```python 221 | from argparse import ArgumentParser 222 | from minimagen.generate import load_minimagen, sample_and_save 223 | 224 | # Command line argument parser 225 | parser = ArgumentParser() 226 | parser.add_argument("-d", "--TRAINING_DIRECTORY", dest="TRAINING_DIRECTORY", help="Training directory to use for inference", type=str) 227 | args = parser.parse_args() 228 | 229 | # Specify the caption(s) to generate images for 230 | captions = ['a happy dog'] 231 | 232 | # Use `sample_and_save` to generate and save the images 233 | sample_and_save(captions, training_directory=args.TRAINING_DIRECTORY) 234 | 235 | 236 | 237 | # Alternatively, rather than specifying a Training Directory, you can pass just a MinImagen instance to use for image generation. 238 | # In this case, information about the MinImagen instance used to generate the images will not be saved. 239 | minimagen = load_minimagen(args.TRAINING_DIRECTORY) 240 | sample_and_save(captions, minimagen=minimagen) 241 | ``` 242 | 243 | To see more of what MinImagen has to offer, or to get additional details on the scripts above, check out the [MinImagen Documentation](https://assemblyai-examples.github.io/MinImagen/). 244 | 245 |
246 | 247 | ## Modifying the Source Code 248 | If you want to make modifications to the source code (rather than use the `minimagen` package), first clone this repository and navigate into it: 249 | 250 | ```bash 251 | $ git clone https://github.com/AssemblyAI-Examples/MinImagen.git 252 | $ cd MinImagen 253 | ``` 254 | 255 | After that, create a virtual environment: 256 | ```bash 257 | $ pip install virtualenv 258 | $ virtualenv venv 259 | ``` 260 | 261 | Then activate the virtual environment and install all dependencies: 262 | ```bash 263 | $ .\venv\Scripts\activate.bat # Windows 264 | $ source venv/bin/activate # MacOS/Linux 265 | $ pip install -r requirements.txt 266 | ``` 267 | 268 | Now you can modify the source code and the changes will be reflected when running any of the [included scripts](#usage---command-line) (as long as the virtual environment created above is active). 269 | 270 | 271 |
272 | 273 | ## Additional Resources 274 | 275 | - For a step-by-step guide on how to build the version of Imagen in this repository, see [Build Your Own Imagen Text-to-Image Model](https://www.assemblyai.com/blog/build-your-own-imagen-text-to-image-model/). 276 | - For a deep-dive into how Imagen works, see [How Imagen Actually Works](https://www.assemblyai.com/blog/how-imagen-actually-works/). 277 | - For a deep-dive into Diffusion Models, see our [Introduction to Diffusion Models for Machine Learning](https://www.assemblyai.com/blog/diffusion-models-for-machine-learning-introduction/) guide. 278 | - For additional learning resources on Machine Learning and Deep Learning, check out our [Blog](https://www.assemblyai.com/blog/) and [YouTube channel](https://www.youtube.com/c/AssemblyAI). 279 | - Read the original Imagen paper [here](https://arxiv.org/abs/2205.11487). 280 | 281 | ## Socials 282 | - Follow us on [Twitter](https://twitter.com/AssemblyAI) for more Deep Learning content. 283 | - [Follow our newsletter](https://assemblyai.us17.list-manage.com/subscribe?u=cb9db7b18b274c2d402a56c5f&id=2116bf7c68) to stay up to date on our recent content. 284 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .minimagen import Imagen 2 | from .minimagen import t5 3 | from .minimagen import Unet 4 | from .minimagen import diffusion_model -------------------------------------------------------------------------------- /captions.txt: -------------------------------------------------------------------------------- 1 | a happy dog 2 | a big red house -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/build/doctrees/environment.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/doctrees/environment.pickle -------------------------------------------------------------------------------- /docs/build/doctrees/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/doctrees/index.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/minimagen.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/doctrees/minimagen.doctree -------------------------------------------------------------------------------- /docs/build/html/.buildinfo: -------------------------------------------------------------------------------- 1 | # Sphinx build info version 1 2 | # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. 3 | config: 7fd9cf5bdbe1c20ffb2db5da0c8aa23e 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /docs/build/html/_images/posterior_mean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_images/posterior_mean.png -------------------------------------------------------------------------------- /docs/build/html/_images/posterior_variance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_images/posterior_variance.png -------------------------------------------------------------------------------- /docs/build/html/_images/q_posterior.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_images/q_posterior.png -------------------------------------------------------------------------------- /docs/build/html/_images/q_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_images/q_sample.png -------------------------------------------------------------------------------- /docs/build/html/_sources/index.rst.txt: -------------------------------------------------------------------------------- 1 | .. MinImagen documentation master file, created by 2 | sphinx-quickstart on Mon Aug 15 18:23:24 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to MinImagen's documentation! 7 | ===================================== 8 | 9 | .. 
toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | minimagen 14 | 15 | 16 | 17 | Indices and tables 18 | ================== 19 | 20 | * :ref:`genindex` 21 | * :ref:`modindex` 22 | * :ref:`search` 23 | -------------------------------------------------------------------------------- /docs/build/html/_sources/minimagen.rst.txt: -------------------------------------------------------------------------------- 1 | minimagen 2 | ================== 3 | 4 | Imagen 5 | ------------------------ 6 | 7 | .. automodule:: minimagen.Imagen 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | Unet 13 | ---------------------- 14 | 15 | .. automodule:: minimagen.Unet 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | Diffusion Model 21 | ---------------------------------- 22 | 23 | .. automodule:: minimagen.diffusion_model 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | T5 29 | -------------------- 30 | 31 | .. automodule:: minimagen.t5 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | Training 37 | -------------------- 38 | 39 | .. automodule:: minimagen.training 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | Generate 45 | -------------------- 46 | 47 | .. automodule:: minimagen.generate 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: -------------------------------------------------------------------------------- /docs/build/html/_static/_sphinx_javascript_frameworks_compat.js: -------------------------------------------------------------------------------- 1 | /* 2 | * _sphinx_javascript_frameworks_compat.js 3 | * ~~~~~~~~~~ 4 | * 5 | * Compatability shim for jQuery and underscores.js. 6 | * 7 | * WILL BE REMOVED IN Sphinx 6.0 8 | * xref RemovedInSphinx60Warning 9 | * 10 | */ 11 | 12 | /** 13 | * select a different prefix for underscore 14 | */ 15 | $u = _.noConflict(); 16 | 17 | 18 | /** 19 | * small helper function to urldecode strings 20 | * 21 | * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/decodeURIComponent#Decoding_query_parameters_from_a_URL 22 | */ 23 | jQuery.urldecode = function(x) { 24 | if (!x) { 25 | return x 26 | } 27 | return decodeURIComponent(x.replace(/\+/g, ' ')); 28 | }; 29 | 30 | /** 31 | * small helper function to urlencode strings 32 | */ 33 | jQuery.urlencode = encodeURIComponent; 34 | 35 | /** 36 | * This function returns the parsed url parameters of the 37 | * current request. Multiple values per key are supported, 38 | * it will always return arrays of strings for the value parts. 39 | */ 40 | jQuery.getQueryParameters = function(s) { 41 | if (typeof s === 'undefined') 42 | s = document.location.search; 43 | var parts = s.substr(s.indexOf('?') + 1).split('&'); 44 | var result = {}; 45 | for (var i = 0; i < parts.length; i++) { 46 | var tmp = parts[i].split('=', 2); 47 | var key = jQuery.urldecode(tmp[0]); 48 | var value = jQuery.urldecode(tmp[1]); 49 | if (key in result) 50 | result[key].push(value); 51 | else 52 | result[key] = [value]; 53 | } 54 | return result; 55 | }; 56 | 57 | /** 58 | * highlight a given string on a jquery object by wrapping it in 59 | * span elements with the given class name. 
60 | */ 61 | jQuery.fn.highlightText = function(text, className) { 62 | function highlight(node, addItems) { 63 | if (node.nodeType === 3) { 64 | var val = node.nodeValue; 65 | var pos = val.toLowerCase().indexOf(text); 66 | if (pos >= 0 && 67 | !jQuery(node.parentNode).hasClass(className) && 68 | !jQuery(node.parentNode).hasClass("nohighlight")) { 69 | var span; 70 | var isInSVG = jQuery(node).closest("body, svg, foreignObject").is("svg"); 71 | if (isInSVG) { 72 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); 73 | } else { 74 | span = document.createElement("span"); 75 | span.className = className; 76 | } 77 | span.appendChild(document.createTextNode(val.substr(pos, text.length))); 78 | node.parentNode.insertBefore(span, node.parentNode.insertBefore( 79 | document.createTextNode(val.substr(pos + text.length)), 80 | node.nextSibling)); 81 | node.nodeValue = val.substr(0, pos); 82 | if (isInSVG) { 83 | var rect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); 84 | var bbox = node.parentElement.getBBox(); 85 | rect.x.baseVal.value = bbox.x; 86 | rect.y.baseVal.value = bbox.y; 87 | rect.width.baseVal.value = bbox.width; 88 | rect.height.baseVal.value = bbox.height; 89 | rect.setAttribute('class', className); 90 | addItems.push({ 91 | "parent": node.parentNode, 92 | "target": rect}); 93 | } 94 | } 95 | } 96 | else if (!jQuery(node).is("button, select, textarea")) { 97 | jQuery.each(node.childNodes, function() { 98 | highlight(this, addItems); 99 | }); 100 | } 101 | } 102 | var addItems = []; 103 | var result = this.each(function() { 104 | highlight(this, addItems); 105 | }); 106 | for (var i = 0; i < addItems.length; ++i) { 107 | jQuery(addItems[i].parent).before(addItems[i].target); 108 | } 109 | return result; 110 | }; 111 | 112 | /* 113 | * backward compatibility for jQuery.browser 114 | * This will be supported until firefox bug is fixed. 115 | */ 116 | if (!jQuery.browser) { 117 | jQuery.uaMatch = function(ua) { 118 | ua = ua.toLowerCase(); 119 | 120 | var match = /(chrome)[ \/]([\w.]+)/.exec(ua) || 121 | /(webkit)[ \/]([\w.]+)/.exec(ua) || 122 | /(opera)(?:.*version|)[ \/]([\w.]+)/.exec(ua) || 123 | /(msie) ([\w.]+)/.exec(ua) || 124 | ua.indexOf("compatible") < 0 && /(mozilla)(?:.*? 
rv:([\w.]+)|)/.exec(ua) || 125 | []; 126 | 127 | return { 128 | browser: match[ 1 ] || "", 129 | version: match[ 2 ] || "0" 130 | }; 131 | }; 132 | jQuery.browser = {}; 133 | jQuery.browser[jQuery.uaMatch(navigator.userAgent).browser] = true; 134 | } 135 | -------------------------------------------------------------------------------- /docs/build/html/_static/clf_free_guidance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/clf_free_guidance.png -------------------------------------------------------------------------------- /docs/build/html/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:after,.clearfix:before{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-style:normal;font-weight:400;src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713?#iefix) format("embedded-opentype"),url(fonts/fontawesome-webfont.woff2?af7ae505a9eed503f8b8e6982036873e) format("woff2"),url(fonts/fontawesome-webfont.woff?fee66e712a8a08eef5805a46892932ad) format("woff"),url(fonts/fontawesome-webfont.ttf?b06871f281fee6b241d60582ae9369b9) format("truetype"),url(fonts/fontawesome-webfont.svg?912ec66d7572ff821749319396470bde#FontAwesome) format("svg")}.fa:before{font-family:FontAwesome;font-style:normal;font-weight:400;line-height:1}.fa:before,a .fa{text-decoration:inherit}.fa:before,a .fa,li .fa{display:inline-block}li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before,.icon-book:before{content:"\f02d"}.fa-caret-down:before,.icon-caret-down:before{content:"\f0d7"}.fa-caret-up:before,.icon-caret-up:before{content:"\f0d8"}.fa-caret-left:before,.icon-caret-left:before{content:"\f0d9"}.fa-caret-right:before,.icon-caret-right:before{content:"\f0da"}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60}.rst-versions .rst-current-version:after{clear:both;content:"";display:block}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd 
a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/Roboto-Slab-Bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/Roboto-Slab-Bold.woff -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/Roboto-Slab-Bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/Roboto-Slab-Bold.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/Roboto-Slab-Regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/Roboto-Slab-Regular.woff -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/Roboto-Slab-Regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/Roboto-Slab-Regular.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/fontawesome-webfont.woff2: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-bold-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/lato-bold-italic.woff -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-bold-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/lato-bold-italic.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/lato-bold.woff -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-normal-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/lato-normal-italic.woff -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-normal-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/lato-normal-italic.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-normal.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/lato-normal.woff -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-normal.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/css/fonts/lato-normal.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/doctools.js: -------------------------------------------------------------------------------- 1 | /* 2 | * doctools.js 3 | * ~~~~~~~~~~~ 4 | * 5 | * Base JavaScript utilities 
for all Sphinx HTML documentation. 6 | * 7 | * :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | "use strict"; 12 | 13 | const _ready = (callback) => { 14 | if (document.readyState !== "loading") { 15 | callback(); 16 | } else { 17 | document.addEventListener("DOMContentLoaded", callback); 18 | } 19 | }; 20 | 21 | /** 22 | * highlight a given string on a node by wrapping it in 23 | * span elements with the given class name. 24 | */ 25 | const _highlight = (node, addItems, text, className) => { 26 | if (node.nodeType === Node.TEXT_NODE) { 27 | const val = node.nodeValue; 28 | const parent = node.parentNode; 29 | const pos = val.toLowerCase().indexOf(text); 30 | if ( 31 | pos >= 0 && 32 | !parent.classList.contains(className) && 33 | !parent.classList.contains("nohighlight") 34 | ) { 35 | let span; 36 | 37 | const closestNode = parent.closest("body, svg, foreignObject"); 38 | const isInSVG = closestNode && closestNode.matches("svg"); 39 | if (isInSVG) { 40 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); 41 | } else { 42 | span = document.createElement("span"); 43 | span.classList.add(className); 44 | } 45 | 46 | span.appendChild(document.createTextNode(val.substr(pos, text.length))); 47 | parent.insertBefore( 48 | span, 49 | parent.insertBefore( 50 | document.createTextNode(val.substr(pos + text.length)), 51 | node.nextSibling 52 | ) 53 | ); 54 | node.nodeValue = val.substr(0, pos); 55 | 56 | if (isInSVG) { 57 | const rect = document.createElementNS( 58 | "http://www.w3.org/2000/svg", 59 | "rect" 60 | ); 61 | const bbox = parent.getBBox(); 62 | rect.x.baseVal.value = bbox.x; 63 | rect.y.baseVal.value = bbox.y; 64 | rect.width.baseVal.value = bbox.width; 65 | rect.height.baseVal.value = bbox.height; 66 | rect.setAttribute("class", className); 67 | addItems.push({ parent: parent, target: rect }); 68 | } 69 | } 70 | } else if (node.matches && !node.matches("button, select, textarea")) { 71 | node.childNodes.forEach((el) => _highlight(el, addItems, text, className)); 72 | } 73 | }; 74 | const _highlightText = (thisNode, text, className) => { 75 | let addItems = []; 76 | _highlight(thisNode, addItems, text, className); 77 | addItems.forEach((obj) => 78 | obj.parent.insertAdjacentElement("beforebegin", obj.target) 79 | ); 80 | }; 81 | 82 | /** 83 | * Small JavaScript module for the documentation. 84 | */ 85 | const Documentation = { 86 | init: () => { 87 | Documentation.highlightSearchWords(); 88 | Documentation.initDomainIndexTable(); 89 | Documentation.initOnKeyListeners(); 90 | }, 91 | 92 | /** 93 | * i18n support 94 | */ 95 | TRANSLATIONS: {}, 96 | PLURAL_EXPR: (n) => (n === 1 ? 0 : 1), 97 | LOCALE: "unknown", 98 | 99 | // gettext and ngettext don't access this so that the functions 100 | // can safely bound to a different name (_ = Documentation.gettext) 101 | gettext: (string) => { 102 | const translated = Documentation.TRANSLATIONS[string]; 103 | switch (typeof translated) { 104 | case "undefined": 105 | return string; // no translation 106 | case "string": 107 | return translated; // translation exists 108 | default: 109 | return translated[0]; // (singular, plural) translation tuple exists 110 | } 111 | }, 112 | 113 | ngettext: (singular, plural, n) => { 114 | const translated = Documentation.TRANSLATIONS[singular]; 115 | if (typeof translated !== "undefined") 116 | return translated[Documentation.PLURAL_EXPR(n)]; 117 | return n === 1 ? 
singular : plural; 118 | }, 119 | 120 | addTranslations: (catalog) => { 121 | Object.assign(Documentation.TRANSLATIONS, catalog.messages); 122 | Documentation.PLURAL_EXPR = new Function( 123 | "n", 124 | `return (${catalog.plural_expr})` 125 | ); 126 | Documentation.LOCALE = catalog.locale; 127 | }, 128 | 129 | /** 130 | * highlight the search words provided in the url in the text 131 | */ 132 | highlightSearchWords: () => { 133 | const highlight = 134 | new URLSearchParams(window.location.search).get("highlight") || ""; 135 | const terms = highlight.toLowerCase().split(/\s+/).filter(x => x); 136 | if (terms.length === 0) return; // nothing to do 137 | 138 | // There should never be more than one element matching "div.body" 139 | const divBody = document.querySelectorAll("div.body"); 140 | const body = divBody.length ? divBody[0] : document.querySelector("body"); 141 | window.setTimeout(() => { 142 | terms.forEach((term) => _highlightText(body, term, "highlighted")); 143 | }, 10); 144 | 145 | const searchBox = document.getElementById("searchbox"); 146 | if (searchBox === null) return; 147 | searchBox.appendChild( 148 | document 149 | .createRange() 150 | .createContextualFragment( 151 | '
" 155 | ) 156 | ); 157 | }, 158 | 159 | /** 160 | * helper function to hide the search marks again 161 | */ 162 | hideSearchWords: () => { 163 | document 164 | .querySelectorAll("#searchbox .highlight-link") 165 | .forEach((el) => el.remove()); 166 | document 167 | .querySelectorAll("span.highlighted") 168 | .forEach((el) => el.classList.remove("highlighted")); 169 | const url = new URL(window.location); 170 | url.searchParams.delete("highlight"); 171 | window.history.replaceState({}, "", url); 172 | }, 173 | 174 | /** 175 | * helper function to focus on search bar 176 | */ 177 | focusSearchBar: () => { 178 | document.querySelectorAll("input[name=q]")[0]?.focus(); 179 | }, 180 | 181 | /** 182 | * Initialise the domain index toggle buttons 183 | */ 184 | initDomainIndexTable: () => { 185 | const toggler = (el) => { 186 | const idNumber = el.id.substr(7); 187 | const toggledRows = document.querySelectorAll(`tr.cg-${idNumber}`); 188 | if (el.src.substr(-9) === "minus.png") { 189 | el.src = `${el.src.substr(0, el.src.length - 9)}plus.png`; 190 | toggledRows.forEach((el) => (el.style.display = "none")); 191 | } else { 192 | el.src = `${el.src.substr(0, el.src.length - 8)}minus.png`; 193 | toggledRows.forEach((el) => (el.style.display = "")); 194 | } 195 | }; 196 | 197 | const togglerElements = document.querySelectorAll("img.toggler"); 198 | togglerElements.forEach((el) => 199 | el.addEventListener("click", (event) => toggler(event.currentTarget)) 200 | ); 201 | togglerElements.forEach((el) => (el.style.display = "")); 202 | if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) togglerElements.forEach(toggler); 203 | }, 204 | 205 | initOnKeyListeners: () => { 206 | // only install a listener if it is really needed 207 | if ( 208 | !DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS && 209 | !DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS 210 | ) 211 | return; 212 | 213 | const blacklistedElements = new Set([ 214 | "TEXTAREA", 215 | "INPUT", 216 | "SELECT", 217 | "BUTTON", 218 | ]); 219 | document.addEventListener("keydown", (event) => { 220 | if (blacklistedElements.has(document.activeElement.tagName)) return; // bail for input elements 221 | if (event.altKey || event.ctrlKey || event.metaKey) return; // bail with special keys 222 | 223 | if (!event.shiftKey) { 224 | switch (event.key) { 225 | case "ArrowLeft": 226 | if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break; 227 | 228 | const prevLink = document.querySelector('link[rel="prev"]'); 229 | if (prevLink && prevLink.href) { 230 | window.location.href = prevLink.href; 231 | event.preventDefault(); 232 | } 233 | break; 234 | case "ArrowRight": 235 | if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break; 236 | 237 | const nextLink = document.querySelector('link[rel="next"]'); 238 | if (nextLink && nextLink.href) { 239 | window.location.href = nextLink.href; 240 | event.preventDefault(); 241 | } 242 | break; 243 | case "Escape": 244 | if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) break; 245 | Documentation.hideSearchWords(); 246 | event.preventDefault(); 247 | } 248 | } 249 | 250 | // some keyboard layouts may need Shift to get / 251 | switch (event.key) { 252 | case "/": 253 | if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) break; 254 | Documentation.focusSearchBar(); 255 | event.preventDefault(); 256 | } 257 | }); 258 | }, 259 | }; 260 | 261 | // quick alias for translations 262 | const _ = Documentation.gettext; 263 | 264 | _ready(Documentation.init); 265 | -------------------------------------------------------------------------------- 
/docs/build/html/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: '2022', 4 | LANGUAGE: 'en', 5 | COLLAPSE_INDEX: false, 6 | BUILDER: 'html', 7 | FILE_SUFFIX: '.html', 8 | LINK_SUFFIX: '.html', 9 | HAS_SOURCE: true, 10 | SOURCELINK_SUFFIX: '.txt', 11 | NAVIGATION_WITH_KEYS: false, 12 | SHOW_SEARCH_SUMMARY: true, 13 | ENABLE_SEARCH_SHORTCUTS: false, 14 | }; -------------------------------------------------------------------------------- /docs/build/html/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/file.png -------------------------------------------------------------------------------- /docs/build/html/_static/js/badge_only.js: -------------------------------------------------------------------------------- 1 | !function(e){var t={};function r(n){if(t[n])return t[n].exports;var o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=4)}({4:function(e,t,r){}}); -------------------------------------------------------------------------------- /docs/build/html/_static/js/html5shiv-printshiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3-pre | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=y.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=y.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),y.elements=c+" "+a,j(b)}function f(a){var b=x[a[v]];return b||(b={},w++,a[v]=w,x[w]=b),b}function g(a,c,d){if(c||(c=b),q)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():u.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||t.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),q)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return y.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var 
n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(y,b.frag)}function j(a){a||(a=b);var d=f(a);return!y.shivCSS||p||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),q||i(a,d),a}function k(a){for(var b,c=a.getElementsByTagName("*"),e=c.length,f=RegExp("^(?:"+d().join("|")+")$","i"),g=[];e--;)b=c[e],f.test(b.nodeName)&&g.push(b.applyElement(l(b)));return g}function l(a){for(var b,c=a.attributes,d=c.length,e=a.ownerDocument.createElement(A+":"+a.nodeName);d--;)b=c[d],b.specified&&e.setAttribute(b.nodeName,b.nodeValue);return e.style.cssText=a.style.cssText,e}function m(a){for(var b,c=a.split("{"),e=c.length,f=RegExp("(^|[\\s,>+~])("+d().join("|")+")(?=[[\\s,>+~#.:]|$)","gi"),g="$1"+A+"\\:$2";e--;)b=c[e]=c[e].split("}"),b[b.length-1]=b[b.length-1].replace(f,g),c[e]=b.join("}");return c.join("{")}function n(a){for(var b=a.length;b--;)a[b].removeNode()}function o(a){function b(){clearTimeout(g._removeSheetTimer),d&&d.removeNode(!0),d=null}var d,e,g=f(a),h=a.namespaces,i=a.parentWindow;return!B||a.printShived?a:("undefined"==typeof h[A]&&h.add(A),i.attachEvent("onbeforeprint",function(){b();for(var f,g,h,i=a.styleSheets,j=[],l=i.length,n=Array(l);l--;)n[l]=i[l];for(;h=n.pop();)if(!h.disabled&&z.test(h.media)){try{f=h.imports,g=f.length}catch(o){g=0}for(l=0;g>l;l++)n.push(f[l]);try{j.push(h.cssText)}catch(o){}}j=m(j.reverse().join("")),e=k(a),d=c(a,j)}),i.attachEvent("onafterprint",function(){n(e),clearTimeout(g._removeSheetTimer),g._removeSheetTimer=setTimeout(b,500)}),a.printShived=!0,a)}var p,q,r="3.7.3",s=a.html5||{},t=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,u=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,v="_html5shiv",w=0,x={};!function(){try{var a=b.createElement("a");a.innerHTML="",p="hidden"in a,q=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){p=!0,q=!0}}();var y={elements:s.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:r,shivCSS:s.shivCSS!==!1,supportsUnknownElements:q,shivMethods:s.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=y,j(b);var z=/^$|\b(?:all|print)\b/,A="html5shiv",B=!q&&function(){var c=b.documentElement;return!("undefined"==typeof b.namespaces||"undefined"==typeof b.parentWindow||"undefined"==typeof c.applyElement||"undefined"==typeof c.removeNode||"undefined"==typeof a.attachEvent)}();y.type+=" print",y.shivPrint=o,o(b),"object"==typeof module&&module.exports&&(module.exports=y)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /docs/build/html/_static/js/html5shiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3 | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var 
c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=t.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=t.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),t.elements=c+" "+a,j(b)}function f(a){var b=s[a[q]];return b||(b={},r++,a[q]=r,s[r]=b),b}function g(a,c,d){if(c||(c=b),l)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():p.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||o.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),l)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return t.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(t,b.frag)}function j(a){a||(a=b);var d=f(a);return!t.shivCSS||k||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),l||i(a,d),a}var k,l,m="3.7.3-pre",n=a.html5||{},o=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,p=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,q="_html5shiv",r=0,s={};!function(){try{var a=b.createElement("a");a.innerHTML="",k="hidden"in a,l=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){k=!0,l=!0}}();var t={elements:n.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:m,shivCSS:n.shivCSS!==!1,supportsUnknownElements:l,shivMethods:n.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=t,j(b),"object"==typeof module&&module.exports&&(module.exports=t)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /docs/build/html/_static/js/theme.js: -------------------------------------------------------------------------------- 1 | !function(n){var e={};function t(i){if(e[i])return e[i].exports;var o=e[i]={i:i,l:!1,exports:{}};return n[i].call(o.exports,o,o.exports,t),o.l=!0,o.exports}t.m=n,t.c=e,t.d=function(n,e,i){t.o(n,e)||Object.defineProperty(n,e,{enumerable:!0,get:i})},t.r=function(n){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(n,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(n,"__esModule",{value:!0})},t.t=function(n,e){if(1&e&&(n=t(n)),8&e)return n;if(4&e&&"object"==typeof n&&n&&n.__esModule)return n;var i=Object.create(null);if(t.r(i),Object.defineProperty(i,"default",{enumerable:!0,value:n}),2&e&&"string"!=typeof n)for(var o in n)t.d(i,o,function(e){return n[e]}.bind(null,o));return i},t.n=function(n){var e=n&&n.__esModule?function(){return 
n.default}:function(){return n};return t.d(e,"a",e),e},t.o=function(n,e){return Object.prototype.hasOwnProperty.call(n,e)},t.p="",t(t.s=0)}([function(n,e,t){t(1),n.exports=t(3)},function(n,e,t){(function(){var e="undefined"!=typeof window?window.jQuery:t(2);n.exports.ThemeNav={navBar:null,win:null,winScroll:!1,winResize:!1,linkScroll:!1,winPosition:0,winHeight:null,docHeight:null,isRunning:!1,enable:function(n){var t=this;void 0===n&&(n=!0),t.isRunning||(t.isRunning=!0,e((function(e){t.init(e),t.reset(),t.win.on("hashchange",t.reset),n&&t.win.on("scroll",(function(){t.linkScroll||t.winScroll||(t.winScroll=!0,requestAnimationFrame((function(){t.onScroll()})))})),t.win.on("resize",(function(){t.winResize||(t.winResize=!0,requestAnimationFrame((function(){t.onResize()})))})),t.onResize()})))},enableSticky:function(){this.enable(!0)},init:function(n){n(document);var e=this;this.navBar=n("div.wy-side-scroll:first"),this.win=n(window),n(document).on("click","[data-toggle='wy-nav-top']",(function(){n("[data-toggle='wy-nav-shift']").toggleClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift")})).on("click",".wy-menu-vertical .current ul li a",(function(){var t=n(this);n("[data-toggle='wy-nav-shift']").removeClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift"),e.toggleCurrent(t),e.hashChange()})).on("click","[data-toggle='rst-current-version']",(function(){n("[data-toggle='rst-versions']").toggleClass("shift-up")})),n("table.docutils:not(.field-list,.footnote,.citation)").wrap("
"),n("table.docutils.footnote").wrap("
"),n("table.docutils.citation").wrap("
"),n(".wy-menu-vertical ul").not(".simple").siblings("a").each((function(){var t=n(this);expand=n(''),expand.on("click",(function(n){return e.toggleCurrent(t),n.stopPropagation(),!1})),t.prepend(expand)}))},reset:function(){var n=encodeURI(window.location.hash)||"#";try{var e=$(".wy-menu-vertical"),t=e.find('[href="'+n+'"]');if(0===t.length){var i=$('.document [id="'+n.substring(1)+'"]').closest("div.section");0===(t=e.find('[href="#'+i.attr("id")+'"]')).length&&(t=e.find('[href="#"]'))}if(t.length>0){$(".wy-menu-vertical .current").removeClass("current").attr("aria-expanded","false"),t.addClass("current").attr("aria-expanded","true"),t.closest("li.toctree-l1").parent().addClass("current").attr("aria-expanded","true");for(let n=1;n<=10;n++)t.closest("li.toctree-l"+n).addClass("current").attr("aria-expanded","true");t[0].scrollIntoView()}}catch(n){console.log("Error expanding nav for anchor",n)}},onScroll:function(){this.winScroll=!1;var n=this.win.scrollTop(),e=n+this.winHeight,t=this.navBar.scrollTop()+(n-this.winPosition);n<0||e>this.docHeight||(this.navBar.scrollTop(t),this.winPosition=n)},onResize:function(){this.winResize=!1,this.winHeight=this.win.height(),this.docHeight=$(document).height()},hashChange:function(){this.linkScroll=!0,this.win.one("hashchange",(function(){this.linkScroll=!1}))},toggleCurrent:function(n){var e=n.closest("li");e.siblings("li.current").removeClass("current").attr("aria-expanded","false"),e.siblings().find("li.current").removeClass("current").attr("aria-expanded","false");var t=e.find("> ul li");t.length&&(t.removeClass("current").attr("aria-expanded","false"),e.toggleClass("current").attr("aria-expanded",(function(n,e){return"true"==e?"false":"true"})))}},"undefined"!=typeof window&&(window.SphinxRtdTheme={Navigation:n.exports.ThemeNav,StickyNav:n.exports.ThemeNav}),function(){for(var n=0,e=["ms","moz","webkit","o"],t=0;t0 63 | var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1 64 | var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1 65 | var s_v = "^(" + C + ")?" 
+ v; // vowel in stem 66 | 67 | this.stemWord = function (w) { 68 | var stem; 69 | var suffix; 70 | var firstch; 71 | var origword = w; 72 | 73 | if (w.length < 3) 74 | return w; 75 | 76 | var re; 77 | var re2; 78 | var re3; 79 | var re4; 80 | 81 | firstch = w.substr(0,1); 82 | if (firstch == "y") 83 | w = firstch.toUpperCase() + w.substr(1); 84 | 85 | // Step 1a 86 | re = /^(.+?)(ss|i)es$/; 87 | re2 = /^(.+?)([^s])s$/; 88 | 89 | if (re.test(w)) 90 | w = w.replace(re,"$1$2"); 91 | else if (re2.test(w)) 92 | w = w.replace(re2,"$1$2"); 93 | 94 | // Step 1b 95 | re = /^(.+?)eed$/; 96 | re2 = /^(.+?)(ed|ing)$/; 97 | if (re.test(w)) { 98 | var fp = re.exec(w); 99 | re = new RegExp(mgr0); 100 | if (re.test(fp[1])) { 101 | re = /.$/; 102 | w = w.replace(re,""); 103 | } 104 | } 105 | else if (re2.test(w)) { 106 | var fp = re2.exec(w); 107 | stem = fp[1]; 108 | re2 = new RegExp(s_v); 109 | if (re2.test(stem)) { 110 | w = stem; 111 | re2 = /(at|bl|iz)$/; 112 | re3 = new RegExp("([^aeiouylsz])\\1$"); 113 | re4 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 114 | if (re2.test(w)) 115 | w = w + "e"; 116 | else if (re3.test(w)) { 117 | re = /.$/; 118 | w = w.replace(re,""); 119 | } 120 | else if (re4.test(w)) 121 | w = w + "e"; 122 | } 123 | } 124 | 125 | // Step 1c 126 | re = /^(.+?)y$/; 127 | if (re.test(w)) { 128 | var fp = re.exec(w); 129 | stem = fp[1]; 130 | re = new RegExp(s_v); 131 | if (re.test(stem)) 132 | w = stem + "i"; 133 | } 134 | 135 | // Step 2 136 | re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/; 137 | if (re.test(w)) { 138 | var fp = re.exec(w); 139 | stem = fp[1]; 140 | suffix = fp[2]; 141 | re = new RegExp(mgr0); 142 | if (re.test(stem)) 143 | w = stem + step2list[suffix]; 144 | } 145 | 146 | // Step 3 147 | re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/; 148 | if (re.test(w)) { 149 | var fp = re.exec(w); 150 | stem = fp[1]; 151 | suffix = fp[2]; 152 | re = new RegExp(mgr0); 153 | if (re.test(stem)) 154 | w = stem + step3list[suffix]; 155 | } 156 | 157 | // Step 4 158 | re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/; 159 | re2 = /^(.+?)(s|t)(ion)$/; 160 | if (re.test(w)) { 161 | var fp = re.exec(w); 162 | stem = fp[1]; 163 | re = new RegExp(mgr1); 164 | if (re.test(stem)) 165 | w = stem; 166 | } 167 | else if (re2.test(w)) { 168 | var fp = re2.exec(w); 169 | stem = fp[1] + fp[2]; 170 | re2 = new RegExp(mgr1); 171 | if (re2.test(stem)) 172 | w = stem; 173 | } 174 | 175 | // Step 5 176 | re = /^(.+?)e$/; 177 | if (re.test(w)) { 178 | var fp = re.exec(w); 179 | stem = fp[1]; 180 | re = new RegExp(mgr1); 181 | re2 = new RegExp(meq1); 182 | re3 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 183 | if (re.test(stem) || (re2.test(stem) && !(re3.test(stem)))) 184 | w = stem; 185 | } 186 | re = /ll$/; 187 | re2 = new RegExp(mgr1); 188 | if (re.test(w) && re2.test(w)) { 189 | re = /.$/; 190 | w = w.replace(re,""); 191 | } 192 | 193 | // and turn initial Y back to y 194 | if (firstch == "y") 195 | w = firstch.toLowerCase() + w.substr(1); 196 | return w; 197 | } 198 | } 199 | 200 | -------------------------------------------------------------------------------- /docs/build/html/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/minus.png 
-------------------------------------------------------------------------------- /docs/build/html/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/plus.png -------------------------------------------------------------------------------- /docs/build/html/_static/posterior_mean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/posterior_mean.png -------------------------------------------------------------------------------- /docs/build/html/_static/posterior_mean_coeffs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/posterior_mean_coeffs.png -------------------------------------------------------------------------------- /docs/build/html/_static/posterior_variance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/posterior_variance.png -------------------------------------------------------------------------------- /docs/build/html/_static/posterior_variance_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/posterior_variance_box.png -------------------------------------------------------------------------------- /docs/build/html/_static/pygments.css: -------------------------------------------------------------------------------- 1 | pre { line-height: 125%; } 2 | td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 3 | span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 4 | td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 5 | span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 6 | .highlight .hll { background-color: #ffffcc } 7 | .highlight { background: #f8f8f8; } 8 | .highlight .c { color: #3D7B7B; font-style: italic } /* Comment */ 9 | .highlight .err { border: 1px solid #FF0000 } /* Error */ 10 | .highlight .k { color: #008000; font-weight: bold } /* Keyword */ 11 | .highlight .o { color: #666666 } /* Operator */ 12 | .highlight .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */ 13 | .highlight .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */ 14 | .highlight .cp { color: #9C6500 } /* Comment.Preproc */ 15 | .highlight .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */ 16 | .highlight .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */ 17 | .highlight .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */ 18 | .highlight .gd { color: #A00000 } /* Generic.Deleted */ 19 | .highlight .ge { font-style: italic } /* Generic.Emph */ 20 | .highlight .gr { color: #E40000 } /* Generic.Error */ 21 | .highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */ 
22 | .highlight .gi { color: #008400 } /* Generic.Inserted */ 23 | .highlight .go { color: #717171 } /* Generic.Output */ 24 | .highlight .gp { color: #000080; font-weight: bold } /* Generic.Prompt */ 25 | .highlight .gs { font-weight: bold } /* Generic.Strong */ 26 | .highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ 27 | .highlight .gt { color: #0044DD } /* Generic.Traceback */ 28 | .highlight .kc { color: #008000; font-weight: bold } /* Keyword.Constant */ 29 | .highlight .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */ 30 | .highlight .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */ 31 | .highlight .kp { color: #008000 } /* Keyword.Pseudo */ 32 | .highlight .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */ 33 | .highlight .kt { color: #B00040 } /* Keyword.Type */ 34 | .highlight .m { color: #666666 } /* Literal.Number */ 35 | .highlight .s { color: #BA2121 } /* Literal.String */ 36 | .highlight .na { color: #687822 } /* Name.Attribute */ 37 | .highlight .nb { color: #008000 } /* Name.Builtin */ 38 | .highlight .nc { color: #0000FF; font-weight: bold } /* Name.Class */ 39 | .highlight .no { color: #880000 } /* Name.Constant */ 40 | .highlight .nd { color: #AA22FF } /* Name.Decorator */ 41 | .highlight .ni { color: #717171; font-weight: bold } /* Name.Entity */ 42 | .highlight .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */ 43 | .highlight .nf { color: #0000FF } /* Name.Function */ 44 | .highlight .nl { color: #767600 } /* Name.Label */ 45 | .highlight .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */ 46 | .highlight .nt { color: #008000; font-weight: bold } /* Name.Tag */ 47 | .highlight .nv { color: #19177C } /* Name.Variable */ 48 | .highlight .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */ 49 | .highlight .w { color: #bbbbbb } /* Text.Whitespace */ 50 | .highlight .mb { color: #666666 } /* Literal.Number.Bin */ 51 | .highlight .mf { color: #666666 } /* Literal.Number.Float */ 52 | .highlight .mh { color: #666666 } /* Literal.Number.Hex */ 53 | .highlight .mi { color: #666666 } /* Literal.Number.Integer */ 54 | .highlight .mo { color: #666666 } /* Literal.Number.Oct */ 55 | .highlight .sa { color: #BA2121 } /* Literal.String.Affix */ 56 | .highlight .sb { color: #BA2121 } /* Literal.String.Backtick */ 57 | .highlight .sc { color: #BA2121 } /* Literal.String.Char */ 58 | .highlight .dl { color: #BA2121 } /* Literal.String.Delimiter */ 59 | .highlight .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */ 60 | .highlight .s2 { color: #BA2121 } /* Literal.String.Double */ 61 | .highlight .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */ 62 | .highlight .sh { color: #BA2121 } /* Literal.String.Heredoc */ 63 | .highlight .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */ 64 | .highlight .sx { color: #008000 } /* Literal.String.Other */ 65 | .highlight .sr { color: #A45A77 } /* Literal.String.Regex */ 66 | .highlight .s1 { color: #BA2121 } /* Literal.String.Single */ 67 | .highlight .ss { color: #19177C } /* Literal.String.Symbol */ 68 | .highlight .bp { color: #008000 } /* Name.Builtin.Pseudo */ 69 | .highlight .fm { color: #0000FF } /* Name.Function.Magic */ 70 | .highlight .vc { color: #19177C } /* Name.Variable.Class */ 71 | .highlight .vg { color: #19177C } /* Name.Variable.Global */ 72 | .highlight .vi { color: #19177C } /* Name.Variable.Instance */ 73 | .highlight .vm { color: #19177C } /* Name.Variable.Magic */ 
74 | .highlight .il { color: #666666 } /* Literal.Number.Integer.Long */ -------------------------------------------------------------------------------- /docs/build/html/_static/q_posterior.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/q_posterior.png -------------------------------------------------------------------------------- /docs/build/html/_static/q_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/q_sample.png -------------------------------------------------------------------------------- /docs/build/html/_static/q_sample_reparam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/_static/q_sample_reparam.png -------------------------------------------------------------------------------- /docs/build/html/_static/searchtools.js: -------------------------------------------------------------------------------- 1 | /* 2 | * searchtools.js 3 | * ~~~~~~~~~~~~~~~~ 4 | * 5 | * Sphinx JavaScript utilities for the full-text search. 6 | * 7 | * :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | "use strict"; 12 | 13 | /** 14 | * Simple result scoring code. 15 | */ 16 | if (typeof Scorer === "undefined") { 17 | var Scorer = { 18 | // Implement the following function to further tweak the score for each result 19 | // The function takes a result array [docname, title, anchor, descr, score, filename] 20 | // and returns the new score. 21 | /* 22 | score: result => { 23 | const [docname, title, anchor, descr, score, filename] = result 24 | return score 25 | }, 26 | */ 27 | 28 | // query matches the full name of an object 29 | objNameMatch: 11, 30 | // or matches in the last dotted part of the object name 31 | objPartialMatch: 6, 32 | // Additive scores depending on the priority of the object 33 | objPrio: { 34 | 0: 15, // used to be importantResults 35 | 1: 5, // used to be objectResults 36 | 2: -5, // used to be unimportantResults 37 | }, 38 | // Used when the priority is not in the mapping. 
39 | objPrioDefault: 0, 40 | 41 | // query found in title 42 | title: 15, 43 | partialTitle: 7, 44 | // query found in terms 45 | term: 5, 46 | partialTerm: 2, 47 | }; 48 | } 49 | 50 | const _removeChildren = (element) => { 51 | while (element && element.lastChild) element.removeChild(element.lastChild); 52 | }; 53 | 54 | /** 55 | * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Regular_Expressions#escaping 56 | */ 57 | const _escapeRegExp = (string) => 58 | string.replace(/[.*+\-?^${}()|[\]\\]/g, "\\$&"); // $& means the whole matched string 59 | 60 | const _displayItem = (item, highlightTerms, searchTerms) => { 61 | const docBuilder = DOCUMENTATION_OPTIONS.BUILDER; 62 | const docUrlRoot = DOCUMENTATION_OPTIONS.URL_ROOT; 63 | const docFileSuffix = DOCUMENTATION_OPTIONS.FILE_SUFFIX; 64 | const docLinkSuffix = DOCUMENTATION_OPTIONS.LINK_SUFFIX; 65 | const showSearchSummary = DOCUMENTATION_OPTIONS.SHOW_SEARCH_SUMMARY; 66 | 67 | const [docName, title, anchor, descr] = item; 68 | 69 | let listItem = document.createElement("li"); 70 | let requestUrl; 71 | let linkUrl; 72 | if (docBuilder === "dirhtml") { 73 | // dirhtml builder 74 | let dirname = docName + "/"; 75 | if (dirname.match(/\/index\/$/)) 76 | dirname = dirname.substring(0, dirname.length - 6); 77 | else if (dirname === "index/") dirname = ""; 78 | requestUrl = docUrlRoot + dirname; 79 | linkUrl = requestUrl; 80 | } else { 81 | // normal html builders 82 | requestUrl = docUrlRoot + docName + docFileSuffix; 83 | linkUrl = docName + docLinkSuffix; 84 | } 85 | const params = new URLSearchParams(); 86 | params.set("highlight", [...highlightTerms].join(" ")); 87 | let linkEl = listItem.appendChild(document.createElement("a")); 88 | linkEl.href = linkUrl + "?" + params.toString() + anchor; 89 | linkEl.innerHTML = title; 90 | if (descr) 91 | listItem.appendChild(document.createElement("span")).innerText = 92 | " (" + descr + ")"; 93 | else if (showSearchSummary) 94 | fetch(requestUrl) 95 | .then((responseData) => responseData.text()) 96 | .then((data) => { 97 | if (data) 98 | listItem.appendChild( 99 | Search.makeSearchSummary(data, searchTerms, highlightTerms) 100 | ); 101 | }); 102 | Search.output.appendChild(listItem); 103 | }; 104 | const _finishSearch = (resultCount) => { 105 | Search.stopPulse(); 106 | Search.title.innerText = _("Search Results"); 107 | if (!resultCount) 108 | Search.status.innerText = Documentation.gettext( 109 | "Your search did not match any documents. Please make sure that all words are spelled correctly and that you've selected enough categories." 110 | ); 111 | else 112 | Search.status.innerText = _( 113 | `Search finished, found ${resultCount} page(s) matching the search query.` 114 | ); 115 | }; 116 | const _displayNextItem = ( 117 | results, 118 | resultCount, 119 | highlightTerms, 120 | searchTerms 121 | ) => { 122 | // results left, load the summary and display it 123 | // this is intended to be dynamic (don't sub resultsCount) 124 | if (results.length) { 125 | _displayItem(results.pop(), highlightTerms, searchTerms); 126 | setTimeout( 127 | () => _displayNextItem(results, resultCount, highlightTerms, searchTerms), 128 | 5 129 | ); 130 | } 131 | // search finished, update title and status message 132 | else _finishSearch(resultCount); 133 | }; 134 | 135 | /** 136 | * Default splitQuery function. Can be overridden in ``sphinx.search`` with a 137 | * custom function per language. 
138 | * 139 | * The regular expression works by splitting the string on consecutive characters 140 | * that are not Unicode letters, numbers, underscores, or emoji characters. 141 | * This is the same as ``\W+`` in Python, preserving the surrogate pair area. 142 | */ 143 | if (typeof splitQuery === "undefined") { 144 | var splitQuery = (query) => query 145 | .split(/[^\p{Letter}\p{Number}_\p{Emoji_Presentation}]+/gu) 146 | .filter(term => term) // remove remaining empty strings 147 | } 148 | 149 | /** 150 | * Search Module 151 | */ 152 | const Search = { 153 | _index: null, 154 | _queued_query: null, 155 | _pulse_status: -1, 156 | 157 | htmlToText: (htmlString) => { 158 | const htmlElement = document 159 | .createRange() 160 | .createContextualFragment(htmlString); 161 | _removeChildren(htmlElement.querySelectorAll(".headerlink")); 162 | const docContent = htmlElement.querySelector('[role="main"]'); 163 | if (docContent !== undefined) return docContent.textContent; 164 | console.warn( 165 | "Content block not found. Sphinx search tries to obtain it via '[role=main]'. Could you check your theme or template." 166 | ); 167 | return ""; 168 | }, 169 | 170 | init: () => { 171 | const query = new URLSearchParams(window.location.search).get("q"); 172 | document 173 | .querySelectorAll('input[name="q"]') 174 | .forEach((el) => (el.value = query)); 175 | if (query) Search.performSearch(query); 176 | }, 177 | 178 | loadIndex: (url) => 179 | (document.body.appendChild(document.createElement("script")).src = url), 180 | 181 | setIndex: (index) => { 182 | Search._index = index; 183 | if (Search._queued_query !== null) { 184 | const query = Search._queued_query; 185 | Search._queued_query = null; 186 | Search.query(query); 187 | } 188 | }, 189 | 190 | hasIndex: () => Search._index !== null, 191 | 192 | deferQuery: (query) => (Search._queued_query = query), 193 | 194 | stopPulse: () => (Search._pulse_status = -1), 195 | 196 | startPulse: () => { 197 | if (Search._pulse_status >= 0) return; 198 | 199 | const pulse = () => { 200 | Search._pulse_status = (Search._pulse_status + 1) % 4; 201 | Search.dots.innerText = ".".repeat(Search._pulse_status); 202 | if (Search._pulse_status >= 0) window.setTimeout(pulse, 500); 203 | }; 204 | pulse(); 205 | }, 206 | 207 | /** 208 | * perform a search for something (or wait until index is loaded) 209 | */ 210 | performSearch: (query) => { 211 | // create the required interface elements 212 | const searchText = document.createElement("h2"); 213 | searchText.textContent = _("Searching"); 214 | const searchSummary = document.createElement("p"); 215 | searchSummary.classList.add("search-summary"); 216 | searchSummary.innerText = ""; 217 | const searchList = document.createElement("ul"); 218 | searchList.classList.add("search"); 219 | 220 | const out = document.getElementById("search-results"); 221 | Search.title = out.appendChild(searchText); 222 | Search.dots = Search.title.appendChild(document.createElement("span")); 223 | Search.status = out.appendChild(searchSummary); 224 | Search.output = out.appendChild(searchList); 225 | 226 | const searchProgress = document.getElementById("search-progress"); 227 | // Some themes don't use the search progress node 228 | if (searchProgress) { 229 | searchProgress.innerText = _("Preparing search..."); 230 | } 231 | Search.startPulse(); 232 | 233 | // index already loaded, the browser was quick! 
234 | if (Search.hasIndex()) Search.query(query); 235 | else Search.deferQuery(query); 236 | }, 237 | 238 | /** 239 | * execute search (requires search index to be loaded) 240 | */ 241 | query: (query) => { 242 | // stem the search terms and add them to the correct list 243 | const stemmer = new Stemmer(); 244 | const searchTerms = new Set(); 245 | const excludedTerms = new Set(); 246 | const highlightTerms = new Set(); 247 | const objectTerms = new Set(splitQuery(query.toLowerCase().trim())); 248 | splitQuery(query.trim()).forEach((queryTerm) => { 249 | const queryTermLower = queryTerm.toLowerCase(); 250 | 251 | // maybe skip this "word" 252 | // stopwords array is from language_data.js 253 | if ( 254 | stopwords.indexOf(queryTermLower) !== -1 || 255 | queryTerm.match(/^\d+$/) 256 | ) 257 | return; 258 | 259 | // stem the word 260 | let word = stemmer.stemWord(queryTermLower); 261 | // select the correct list 262 | if (word[0] === "-") excludedTerms.add(word.substr(1)); 263 | else { 264 | searchTerms.add(word); 265 | highlightTerms.add(queryTermLower); 266 | } 267 | }); 268 | 269 | // console.debug("SEARCH: searching for:"); 270 | // console.info("required: ", [...searchTerms]); 271 | // console.info("excluded: ", [...excludedTerms]); 272 | 273 | // array of [docname, title, anchor, descr, score, filename] 274 | let results = []; 275 | _removeChildren(document.getElementById("search-progress")); 276 | 277 | // lookup as object 278 | objectTerms.forEach((term) => 279 | results.push(...Search.performObjectSearch(term, objectTerms)) 280 | ); 281 | 282 | // lookup as search terms in fulltext 283 | results.push(...Search.performTermsSearch(searchTerms, excludedTerms)); 284 | 285 | // let the scorer override scores with a custom scoring function 286 | if (Scorer.score) results.forEach((item) => (item[4] = Scorer.score(item))); 287 | 288 | // now sort the results by score (in opposite order of appearance, since the 289 | // display function below uses pop() to retrieve items) and then 290 | // alphabetically 291 | results.sort((a, b) => { 292 | const leftScore = a[4]; 293 | const rightScore = b[4]; 294 | if (leftScore === rightScore) { 295 | // same score: sort alphabetically 296 | const leftTitle = a[1].toLowerCase(); 297 | const rightTitle = b[1].toLowerCase(); 298 | if (leftTitle === rightTitle) return 0; 299 | return leftTitle > rightTitle ? -1 : 1; // inverted is intentional 300 | } 301 | return leftScore > rightScore ? 
1 : -1; 302 | }); 303 | 304 | // remove duplicate search results 305 | // note the reversing of results, so that in the case of duplicates, the highest-scoring entry is kept 306 | let seen = new Set(); 307 | results = results.reverse().reduce((acc, result) => { 308 | let resultStr = result.slice(0, 4).concat([result[5]]).map(v => String(v)).join(','); 309 | if (!seen.has(resultStr)) { 310 | acc.push(result); 311 | seen.add(resultStr); 312 | } 313 | return acc; 314 | }, []); 315 | 316 | results = results.reverse(); 317 | 318 | // for debugging 319 | //Search.lastresults = results.slice(); // a copy 320 | // console.info("search results:", Search.lastresults); 321 | 322 | // print the results 323 | _displayNextItem(results, results.length, highlightTerms, searchTerms); 324 | }, 325 | 326 | /** 327 | * search for object names 328 | */ 329 | performObjectSearch: (object, objectTerms) => { 330 | const filenames = Search._index.filenames; 331 | const docNames = Search._index.docnames; 332 | const objects = Search._index.objects; 333 | const objNames = Search._index.objnames; 334 | const titles = Search._index.titles; 335 | 336 | const results = []; 337 | 338 | const objectSearchCallback = (prefix, match) => { 339 | const name = match[4] 340 | const fullname = (prefix ? prefix + "." : "") + name; 341 | const fullnameLower = fullname.toLowerCase(); 342 | if (fullnameLower.indexOf(object) < 0) return; 343 | 344 | let score = 0; 345 | const parts = fullnameLower.split("."); 346 | 347 | // check for different match types: exact matches of full name or 348 | // "last name" (i.e. last dotted part) 349 | if (fullnameLower === object || parts.slice(-1)[0] === object) 350 | score += Scorer.objNameMatch; 351 | else if (parts.slice(-1)[0].indexOf(object) > -1) 352 | score += Scorer.objPartialMatch; // matches in last name 353 | 354 | const objName = objNames[match[1]][2]; 355 | const title = titles[match[0]]; 356 | 357 | // If more than one term searched for, we require other words to be 358 | // found in the name/title/description 359 | const otherTerms = new Set(objectTerms); 360 | otherTerms.delete(object); 361 | if (otherTerms.size > 0) { 362 | const haystack = `${prefix} ${name} ${objName} ${title}`.toLowerCase(); 363 | if ( 364 | [...otherTerms].some((otherTerm) => haystack.indexOf(otherTerm) < 0) 365 | ) 366 | return; 367 | } 368 | 369 | let anchor = match[3]; 370 | if (anchor === "") anchor = fullname; 371 | else if (anchor === "-") anchor = objNames[match[1]][1] + "-" + fullname; 372 | 373 | const descr = objName + _(", in ") + title; 374 | 375 | // add custom score for some objects according to scorer 376 | if (Scorer.objPrio.hasOwnProperty(match[2])) 377 | score += Scorer.objPrio[match[2]]; 378 | else score += Scorer.objPrioDefault; 379 | 380 | results.push([ 381 | docNames[match[0]], 382 | fullname, 383 | "#" + anchor, 384 | descr, 385 | score, 386 | filenames[match[0]], 387 | ]); 388 | }; 389 | Object.keys(objects).forEach((prefix) => 390 | objects[prefix].forEach((array) => 391 | objectSearchCallback(prefix, array) 392 | ) 393 | ); 394 | return results; 395 | }, 396 | 397 | /** 398 | * search for full-text terms in the index 399 | */ 400 | performTermsSearch: (searchTerms, excludedTerms) => { 401 | // prepare search 402 | const terms = Search._index.terms; 403 | const titleTerms = Search._index.titleterms; 404 | const docNames = Search._index.docnames; 405 | const filenames = Search._index.filenames; 406 | const titles = Search._index.titles; 407 | 408 | const scoreMap = new Map(); 409 | const 
fileMap = new Map(); 410 | 411 | // perform the search on the required terms 412 | searchTerms.forEach((word) => { 413 | const files = []; 414 | const arr = [ 415 | { files: terms[word], score: Scorer.term }, 416 | { files: titleTerms[word], score: Scorer.title }, 417 | ]; 418 | // add support for partial matches 419 | if (word.length > 2) { 420 | const escapedWord = _escapeRegExp(word); 421 | Object.keys(terms).forEach((term) => { 422 | if (term.match(escapedWord) && !terms[word]) 423 | arr.push({ files: terms[term], score: Scorer.partialTerm }); 424 | }); 425 | Object.keys(titleTerms).forEach((term) => { 426 | if (term.match(escapedWord) && !titleTerms[word]) 427 | arr.push({ files: titleTerms[word], score: Scorer.partialTitle }); 428 | }); 429 | } 430 | 431 | // no match but word was a required one 432 | if (arr.every((record) => record.files === undefined)) return; 433 | 434 | // found search word in contents 435 | arr.forEach((record) => { 436 | if (record.files === undefined) return; 437 | 438 | let recordFiles = record.files; 439 | if (recordFiles.length === undefined) recordFiles = [recordFiles]; 440 | files.push(...recordFiles); 441 | 442 | // set score for the word in each file 443 | recordFiles.forEach((file) => { 444 | if (!scoreMap.has(file)) scoreMap.set(file, {}); 445 | scoreMap.get(file)[word] = record.score; 446 | }); 447 | }); 448 | 449 | // create the mapping 450 | files.forEach((file) => { 451 | if (fileMap.has(file) && fileMap.get(file).indexOf(word) === -1) 452 | fileMap.get(file).push(word); 453 | else fileMap.set(file, [word]); 454 | }); 455 | }); 456 | 457 | // now check if the files don't contain excluded terms 458 | const results = []; 459 | for (const [file, wordList] of fileMap) { 460 | // check if all requirements are matched 461 | 462 | // as search terms with length < 3 are discarded 463 | const filteredTermCount = [...searchTerms].filter( 464 | (term) => term.length > 2 465 | ).length; 466 | if ( 467 | wordList.length !== searchTerms.size && 468 | wordList.length !== filteredTermCount 469 | ) 470 | continue; 471 | 472 | // ensure that none of the excluded terms is in the search result 473 | if ( 474 | [...excludedTerms].some( 475 | (term) => 476 | terms[term] === file || 477 | titleTerms[term] === file || 478 | (terms[term] || []).includes(file) || 479 | (titleTerms[term] || []).includes(file) 480 | ) 481 | ) 482 | break; 483 | 484 | // select one (max) score for the file. 485 | const score = Math.max(...wordList.map((w) => scoreMap.get(file)[w])); 486 | // add result to the result list 487 | results.push([ 488 | docNames[file], 489 | titles[file], 490 | "", 491 | null, 492 | score, 493 | filenames[file], 494 | ]); 495 | } 496 | return results; 497 | }, 498 | 499 | /** 500 | * helper function to return a node containing the 501 | * search summary for a given text. keywords is a list 502 | * of stemmed words, highlightWords is the list of normal, unstemmed 503 | * words. the first one is used to find the occurrence, the 504 | * latter for highlighting it. 505 | */ 506 | makeSearchSummary: (htmlText, keywords, highlightWords) => { 507 | const text = Search.htmlToText(htmlText).toLowerCase(); 508 | if (text === "") return null; 509 | 510 | const actualStartPosition = [...keywords] 511 | .map((k) => text.indexOf(k.toLowerCase())) 512 | .filter((i) => i > -1) 513 | .slice(-1)[0]; 514 | const startWithContext = Math.max(actualStartPosition - 120, 0); 515 | 516 | const top = startWithContext === 0 ? 
"" : "..."; 517 | const tail = startWithContext + 240 < text.length ? "..." : ""; 518 | 519 | let summary = document.createElement("div"); 520 | summary.classList.add("context"); 521 | summary.innerText = top + text.substr(startWithContext, 240).trim() + tail; 522 | 523 | highlightWords.forEach((highlightWord) => 524 | _highlightText(summary, highlightWord, "highlighted") 525 | ); 526 | 527 | return summary; 528 | }, 529 | }; 530 | 531 | _ready(Search.init); 532 | -------------------------------------------------------------------------------- /docs/build/html/_static/underscore.js: -------------------------------------------------------------------------------- 1 | !function(n,r){"object"==typeof exports&&"undefined"!=typeof module?module.exports=r():"function"==typeof define&&define.amd?define("underscore",r):(n="undefined"!=typeof globalThis?globalThis:n||self,function(){var t=n._,e=n._=r();e.noConflict=function(){return n._=t,e}}())}(this,(function(){ 2 | // Underscore.js 1.13.1 3 | // https://underscorejs.org 4 | // (c) 2009-2021 Jeremy Ashkenas, Julian Gonggrijp, and DocumentCloud and Investigative Reporters & Editors 5 | // Underscore may be freely distributed under the MIT license. 6 | var n="1.13.1",r="object"==typeof self&&self.self===self&&self||"object"==typeof global&&global.global===global&&global||Function("return this")()||{},t=Array.prototype,e=Object.prototype,u="undefined"!=typeof Symbol?Symbol.prototype:null,o=t.push,i=t.slice,a=e.toString,f=e.hasOwnProperty,c="undefined"!=typeof ArrayBuffer,l="undefined"!=typeof DataView,s=Array.isArray,p=Object.keys,v=Object.create,h=c&&ArrayBuffer.isView,y=isNaN,d=isFinite,g=!{toString:null}.propertyIsEnumerable("toString"),b=["valueOf","isPrototypeOf","toString","propertyIsEnumerable","hasOwnProperty","toLocaleString"],m=Math.pow(2,53)-1;function j(n,r){return r=null==r?n.length-1:+r,function(){for(var t=Math.max(arguments.length-r,0),e=Array(t),u=0;u=0&&t<=m}}function J(n){return function(r){return null==r?void 0:r[n]}}var G=J("byteLength"),H=K(G),Q=/\[object ((I|Ui)nt(8|16|32)|Float(32|64)|Uint8Clamped|Big(I|Ui)nt64)Array\]/;var X=c?function(n){return h?h(n)&&!q(n):H(n)&&Q.test(a.call(n))}:C(!1),Y=J("length");function Z(n,r){r=function(n){for(var r={},t=n.length,e=0;e":">",'"':""","'":"'","`":"`"},Cn=Ln($n),Kn=Ln(_n($n)),Jn=tn.templateSettings={evaluate:/<%([\s\S]+?)%>/g,interpolate:/<%=([\s\S]+?)%>/g,escape:/<%-([\s\S]+?)%>/g},Gn=/(.)^/,Hn={"'":"'","\\":"\\","\r":"r","\n":"n","\u2028":"u2028","\u2029":"u2029"},Qn=/\\|'|\r|\n|\u2028|\u2029/g;function Xn(n){return"\\"+Hn[n]}var Yn=/^\s*(\w|\$)+\s*$/;var Zn=0;function nr(n,r,t,e,u){if(!(e instanceof r))return n.apply(t,u);var o=Mn(n.prototype),i=n.apply(o,u);return _(i)?i:o}var rr=j((function(n,r){var t=rr.placeholder,e=function(){for(var u=0,o=r.length,i=Array(o),a=0;a1)ur(a,r-1,t,e),u=e.length;else for(var f=0,c=a.length;f0&&(t=r.apply(this,arguments)),n<=1&&(r=null),t}}var lr=rr(cr,2);function sr(n,r,t){r=qn(r,t);for(var e,u=nn(n),o=0,i=u.length;o0?0:u-1;o>=0&&o0?a=o>=0?o:Math.max(o+f,a):f=o>=0?Math.min(o+1,f):o+f+1;else if(t&&o&&f)return e[o=t(e,u)]===u?o:-1;if(u!=u)return(o=r(i.call(e,a,f),$))>=0?o+a:-1;for(o=n>0?a:f-1;o>=0&&o0?0:i-1;for(u||(e=r[o?o[a]:a],a+=n);a>=0&&a=3;return r(n,Fn(t,u,4),e,o)}}var Ar=wr(1),xr=wr(-1);function Sr(n,r,t){var e=[];return r=qn(r,t),jr(n,(function(n,t,u){r(n,t,u)&&e.push(n)})),e}function Or(n,r,t){r=qn(r,t);for(var e=!er(n)&&nn(n),u=(e||n).length,o=0;o=0}var Br=j((function(n,r,t){var e,u;return 
D(r)?u=r:(r=Nn(r),e=r.slice(0,-1),r=r[r.length-1]),_r(n,(function(n){var o=u;if(!o){if(e&&e.length&&(n=In(n,e)),null==n)return;o=n[r]}return null==o?o:o.apply(n,t)}))}));function Nr(n,r){return _r(n,Rn(r))}function Ir(n,r,t){var e,u,o=-1/0,i=-1/0;if(null==r||"number"==typeof r&&"object"!=typeof n[0]&&null!=n)for(var a=0,f=(n=er(n)?n:jn(n)).length;ao&&(o=e);else r=qn(r,t),jr(n,(function(n,t,e){((u=r(n,t,e))>i||u===-1/0&&o===-1/0)&&(o=n,i=u)}));return o}function Tr(n,r,t){if(null==r||t)return er(n)||(n=jn(n)),n[Wn(n.length-1)];var e=er(n)?En(n):jn(n),u=Y(e);r=Math.max(Math.min(r,u),0);for(var o=u-1,i=0;i1&&(e=Fn(e,r[1])),r=an(n)):(e=qr,r=ur(r,!1,!1),n=Object(n));for(var u=0,o=r.length;u1&&(t=r[1])):(r=_r(ur(r,!1,!1),String),e=function(n,t){return!Er(r,t)}),Ur(n,e,t)}));function zr(n,r,t){return i.call(n,0,Math.max(0,n.length-(null==r||t?1:r)))}function Lr(n,r,t){return null==n||n.length<1?null==r||t?void 0:[]:null==r||t?n[0]:zr(n,n.length-r)}function $r(n,r,t){return i.call(n,null==r||t?1:r)}var Cr=j((function(n,r){return r=ur(r,!0,!0),Sr(n,(function(n){return!Er(r,n)}))})),Kr=j((function(n,r){return Cr(n,r)}));function Jr(n,r,t,e){A(r)||(e=t,t=r,r=!1),null!=t&&(t=qn(t,e));for(var u=[],o=[],i=0,a=Y(n);ir?(e&&(clearTimeout(e),e=null),a=c,i=n.apply(u,o),e||(u=o=null)):e||!1===t.trailing||(e=setTimeout(f,l)),i};return c.cancel=function(){clearTimeout(e),a=0,e=u=o=null},c},debounce:function(n,r,t){var e,u,o,i,a,f=function(){var c=zn()-u;r>c?e=setTimeout(f,r-c):(e=null,t||(i=n.apply(a,o)),e||(o=a=null))},c=j((function(c){return a=this,o=c,u=zn(),e||(e=setTimeout(f,r),t&&(i=n.apply(a,o))),i}));return c.cancel=function(){clearTimeout(e),e=o=a=null},c},wrap:function(n,r){return rr(r,n)},negate:fr,compose:function(){var n=arguments,r=n.length-1;return function(){for(var t=r,e=n[r].apply(this,arguments);t--;)e=n[t].call(this,e);return e}},after:function(n,r){return function(){if(--n<1)return r.apply(this,arguments)}},before:cr,once:lr,findKey:sr,findIndex:vr,findLastIndex:hr,sortedIndex:yr,indexOf:gr,lastIndexOf:br,find:mr,detect:mr,findWhere:function(n,r){return mr(n,Dn(r))},each:jr,forEach:jr,map:_r,collect:_r,reduce:Ar,foldl:Ar,inject:Ar,reduceRight:xr,foldr:xr,filter:Sr,select:Sr,reject:function(n,r,t){return Sr(n,fr(qn(r)),t)},every:Or,all:Or,some:Mr,any:Mr,contains:Er,includes:Er,include:Er,invoke:Br,pluck:Nr,where:function(n,r){return Sr(n,Dn(r))},max:Ir,min:function(n,r,t){var e,u,o=1/0,i=1/0;if(null==r||"number"==typeof r&&"object"!=typeof n[0]&&null!=n)for(var a=0,f=(n=er(n)?n:jn(n)).length;ae||void 0===t)return 1;if(t 2 | 3 | 4 | 5 | 6 | Index — MinImagen 2022 documentation 7 | 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 |
[Sphinx build artifact, markup not preserved: the general index page, with navigation chrome and alphabetical entry groups (B, D, F, G, I, L, M, P, Q, S, T, U) whose surviving entries reference the minimagen.Imagen, minimagen.diffusion_model, minimagen.generate, and minimagen.t5 modules.]
-------------------------------------------------------------------------------- /docs/build/html/index.html: --------------------------------------------------------------------------------
[Sphinx build artifact, markup not preserved: the "Welcome to MinImagen’s documentation!" landing page, with a Contents toctree pointing at the minimagen module reference and the standard Indices and tables links (genindex, modindex, search).]
-------------------------------------------------------------------------------- /docs/build/html/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/build/html/objects.inv
-------------------------------------------------------------------------------- /docs/build/html/py-modindex.html: --------------------------------------------------------------------------------
[Sphinx build artifact, markup not preserved: the "Python Module Index" page listing the minimagen package and its minimagen.Imagen, minimagen.Unet, minimagen.diffusion_model, minimagen.generate, and minimagen.t5 submodules.]
-------------------------------------------------------------------------------- /docs/build/html/search.html: --------------------------------------------------------------------------------
[Sphinx build artifact, markup not preserved: the "Search" page, which hosts the query form and results list for the client-side full-text search.]
102 | 107 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /docs/build/html/searchindex.js: -------------------------------------------------------------------------------- 1 | Search.setIndex({"docnames": ["index", "minimagen"], "filenames": ["index.rst", "minimagen.rst"], "titles": ["Welcome to MinImagen\u2019s documentation!", "minimagen"], "terms": {"imagen": 0, "unet": 0, "diffus": 0, "model": 0, "t5": 0, "train": 0, "gener": 0, "index": 0, "modul": [0, 1], "search": 0, "page": 0, "class": 1, "union": 1, "list": 1, "tupl": 1, "text_encoder_nam": 1, "str": 1, "image_s": 1, "int": 1, "text_embed_dim": 1, "option": 1, "none": 1, "channel": 1, "3": 1, "timestep": 1, "1000": 1, "cond_drop_prob": 1, "float": 1, "0": 1, "1": 1, "loss_typ": 1, "liter": 1, "l1": 1, "l2": 1, "huber": 1, "lowres_sample_noise_level": 1, "2": 1, "auto_normalize_img": 1, "bool": 1, "true": 1, "dynamic_thresholding_percentil": 1, "9": 1, "only_train_unet_numb": 1, "base": 1, "minim": 1, "implement": 1, "properti": 1, "devic": 1, "forward": 1, "imag": 1, "text": 1, "text_emb": 1, "tensor": 1, "text_mask": 1, "unet_numb": 1, "pass": 1, "nois": 1, "calcul": 1, "loss": 1, "from": 1, "u": 1, "net": 1, "predict": 1, "paramet": 1, "oper": 1, "shape": 1, "b": 1, "c": 1, "s": 1, "caption": 1, "condit": 1, "length": 1, "embed": 1, "us": 1, "mask": 1, "which": 1, "number": 1, "ar": 1, "multipl": 1, "return": 1, "load_state_dict": 1, "arg": 1, "kwarg": 1, "overrid": 1, "place": 1, "all": 1, "instanc": 1, "one": 1, "when": 1, "call": 1, "sampl": 1, "state_dict": 1, "default": 1, "argument": 1, "origin": 1, "dim": 1, "512": 1, "dim_mult": 1, "4": 1, "num_resnet_block": 1, "layer_attn": 1, "fals": 1, "layer_cross_attn": 1, "memory_effici": 1, "basetest": 1, "intend": 1, "test": 1, "8": 1, "super": 1, "resolut": 1, "128": 1, "supertest": 1, "channels_out": 1, "cond_dim": 1, "attn_head": 1, "lowres_cond": 1, "attend_at_middl": 1, "denois": 1, "via": 1, "see": 1, "also": 1, "diffusion_model": 1, "gaussiandiffus": 1, "x": 1, "time": 1, "lowres_cond_img": 1, "lowres_noise_tim": 1, "input": 1, "each": 1, "upsampl": 1, "low": 1, "re": 1, "augment": 1, "size": 1, "256": 1, "embedding_dim": 1, "t5_encode_text": 1, "probabl": 1, "drop": 1, "info": 1, "classifi": 1, "free": 1, "guidanc": 1, "rang": 1, "forward_with_cond_scal": 1, "cond_scal": 1, "add": 1, "scale": 1, "uncondit": 1, "standard": 1, "larg": 1, "weight": 1, "improv": 1, "qualiti": 1, "fidel": 1, "cost": 1, "divers": 1, "here": 1, "more": 1, "inform": 1, "keyword": 1, "guid": 1, "predict_start_from_nois": 1, "x_t": 1, "t": 1, "given": 1, "its": 1, "compon": 1, "unnois": 1, "x_0": 1, "un": 1, "q_posterior": 1, "x_start": 1, "built": 1, "method": 1, "type": 1, "object": 1, "0x00007ffa641ea7f0": 1, "start": 1, "where": 1, "mean": 1, "And": 1, "varianc": 1, "prefactor": 1, "h": 1, "w": 1, "current": 1, "posterior": 1, "clip": 1, "log": 1, "q_sampl": 1, "q": 1, "valu": 1, "batch": 1, "suppli": 1, "gaussian": 1, "get_encoded_dim": 1, "name": 1, "get": 1, "encod": 1, "dimension": 1, "t5_base": 1, "max_length": 1, "sequenc": 1, "t5_small": 1, "24": 1, "gb": 1, "89": 1, "768": 1, "t5_larg": 1, "75": 1, "1024": 1, "t5_3b": 1, "10": 1, "6": 1, "t5_11b": 1, "42": 1, "attent": 1, "element": 1, "i": 1, "j": 1, "k": 1, "final": 1, "correspond": 1, "th": 1, "token": 1, "load_minimagen": 1, "directori": 1, "load": 1, "structur": 1, "accord": 1, "create_directori": 1, "readi": 1, "infer": 1, "load_param": 1, "path": 1, 
"unets_param": 1, "imagen_param": 1, "whose": 1, "sample_and_sav": 1, "training_directori": 1, "sample_arg": 1, "dict": 1, "save_directori": 1, "filetyp": 1, "png": 1, "save": 1, "generated_imag": 1, "image_": 1, "caption_index": 1, "string": 1, "e": 1, "must": 1, "specifi": 1, "addit": 1, "function": 1, "do": 1, "includ": 1, "code": 1, "return_pil_imag": 1, "thi": 1, "dictionari": 1, "datetim": 1, "stamp": 1}, "objects": {"minimagen": [[1, 0, 0, "-", "Imagen"], [1, 0, 0, "-", "Unet"], [1, 0, 0, "-", "diffusion_model"], [1, 0, 0, "-", "generate"], [1, 0, 0, "-", "t5"]], "minimagen.Imagen": [[1, 1, 1, "", "Imagen"]], "minimagen.Imagen.Imagen": [[1, 2, 1, "", "device"], [1, 3, 1, "", "forward"], [1, 3, 1, "", "load_state_dict"], [1, 3, 1, "", "sample"], [1, 3, 1, "", "state_dict"]], "minimagen.Unet": [[1, 1, 1, "", "Base"], [1, 1, 1, "", "BaseTest"], [1, 1, 1, "", "Super"], [1, 1, 1, "", "SuperTest"], [1, 1, 1, "", "Unet"]], "minimagen.Unet.Base": [[1, 4, 1, "", "defaults"]], "minimagen.Unet.BaseTest": [[1, 4, 1, "", "defaults"]], "minimagen.Unet.Super": [[1, 4, 1, "", "defaults"]], "minimagen.Unet.SuperTest": [[1, 4, 1, "", "defaults"]], "minimagen.Unet.Unet": [[1, 3, 1, "", "forward"], [1, 3, 1, "", "forward_with_cond_scale"]], "minimagen.diffusion_model": [[1, 1, 1, "", "GaussianDiffusion"]], "minimagen.diffusion_model.GaussianDiffusion": [[1, 3, 1, "", "predict_start_from_noise"], [1, 3, 1, "", "q_posterior"], [1, 3, 1, "", "q_sample"]], "minimagen.generate": [[1, 5, 1, "", "load_minimagen"], [1, 5, 1, "", "load_params"], [1, 5, 1, "", "sample_and_save"]], "minimagen.t5": [[1, 5, 1, "", "get_encoded_dim"], [1, 5, 1, "", "t5_encode_text"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:property", "3": "py:method", "4": "py:attribute", "5": "py:function"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "property", "Python property"], "3": ["py", "method", "Python method"], "4": ["py", "attribute", "Python attribute"], "5": ["py", "function", "Python function"]}, "titleterms": {"welcom": 0, "minimagen": [0, 1], "s": 0, "document": 0, "content": 0, "indic": 0, "tabl": 0, "imagen": 1, "unet": 1, "diffus": 1, "model": 1, "t5": 1, "train": 1, "gener": 1}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 6, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx": 56}}) -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/requirements.txt -------------------------------------------------------------------------------- /docs/source/_static/clf_free_guidance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/source/_static/clf_free_guidance.png -------------------------------------------------------------------------------- /docs/source/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/source/_static/file.png -------------------------------------------------------------------------------- /docs/source/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/source/_static/minus.png -------------------------------------------------------------------------------- /docs/source/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/source/_static/plus.png -------------------------------------------------------------------------------- /docs/source/_static/posterior_mean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/source/_static/posterior_mean.png -------------------------------------------------------------------------------- /docs/source/_static/posterior_mean_coeffs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/source/_static/posterior_mean_coeffs.png -------------------------------------------------------------------------------- /docs/source/_static/posterior_variance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/source/_static/posterior_variance.png -------------------------------------------------------------------------------- /docs/source/_static/posterior_variance_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/source/_static/posterior_variance_box.png -------------------------------------------------------------------------------- /docs/source/_static/q_posterior.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/source/_static/q_posterior.png -------------------------------------------------------------------------------- /docs/source/_static/q_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/source/_static/q_sample.png -------------------------------------------------------------------------------- /docs/source/_static/q_sample_reparam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/source/_static/q_sample_reparam.png -------------------------------------------------------------------------------- /docs/source/_static/x_tm1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/docs/source/_static/x_tm1.png -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('..\\..\\venv\\Lib\\site-packages')) 16 | sys.path.insert(0, os.path.abspath('..\\..')) 17 | 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = 'MinImagen' 22 | copyright = '2022, AssemblyAI' 23 | author = 'AssemblyAI' 24 | 25 | # The full version, including alpha/beta/rc tags 26 | release = '2022' 27 | 28 | 29 | # -- General configuration --------------------------------------------------- 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 34 | extensions = ["sphinx.ext.autodoc" 35 | ] 36 | 37 | # Add any paths that contain templates here, relative to this directory. 38 | templates_path = ['_templates'] 39 | 40 | # List of patterns, relative to source directory, that match files and 41 | # directories to ignore when looking for source files. 42 | # This pattern also affects html_static_path and html_extra_path. 43 | exclude_patterns = [] 44 | 45 | 46 | # -- Options for HTML output ------------------------------------------------- 47 | 48 | # The theme to use for HTML and HTML Help pages. See the documentation for 49 | # a list of builtin themes. 50 | # 51 | html_theme = 'sphinx_rtd_theme' 52 | 53 | # Add any paths that contain custom static files (such as style sheets) here, 54 | # relative to this directory. 
They are copied after the builtin static files, 55 | # so a file named "default.css" will overwrite the builtin "default.css". 56 | html_static_path = ['_static'] 57 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. MinImagen documentation master file, created by 2 | sphinx-quickstart on Mon Aug 15 18:23:24 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to MinImagen's documentation! 7 | ===================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | minimagen 14 | 15 | 16 | 17 | Indices and tables 18 | ================== 19 | 20 | * :ref:`genindex` 21 | * :ref:`modindex` 22 | * :ref:`search` 23 | -------------------------------------------------------------------------------- /docs/source/minimagen.rst: -------------------------------------------------------------------------------- 1 | minimagen 2 | ================== 3 | 4 | Imagen 5 | ------------------------ 6 | 7 | .. automodule:: minimagen.Imagen 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | Unet 13 | ---------------------- 14 | 15 | .. automodule:: minimagen.Unet 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | Diffusion Model 21 | ---------------------------------- 22 | 23 | .. automodule:: minimagen.diffusion_model 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | T5 29 | -------------------- 30 | 31 | .. automodule:: minimagen.t5 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | Training 37 | -------------------- 38 | 39 | .. automodule:: minimagen.training 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | Generate 45 | -------------------- 46 | 47 | .. 
automodule:: minimagen.generate 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: -------------------------------------------------------------------------------- /images/clf_free_guidance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/images/clf_free_guidance.png -------------------------------------------------------------------------------- /images/conditioning_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/images/conditioning_diagram.png -------------------------------------------------------------------------------- /images/dynamic_threshold.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/images/dynamic_threshold.mp4 -------------------------------------------------------------------------------- /images/model_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/images/model_structure.png -------------------------------------------------------------------------------- /images/posterior_mean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/images/posterior_mean.png -------------------------------------------------------------------------------- /images/posterior_mean_coeffs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/images/posterior_mean_coeffs.png -------------------------------------------------------------------------------- /images/posterior_variance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/images/posterior_variance.png -------------------------------------------------------------------------------- /images/posterior_variance_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/images/posterior_variance_box.png -------------------------------------------------------------------------------- /images/q_posterior.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/images/q_posterior.png -------------------------------------------------------------------------------- /images/q_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/images/q_sample.png -------------------------------------------------------------------------------- /images/q_sample_reparam.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/images/q_sample_reparam.png -------------------------------------------------------------------------------- /images/transformer_full.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/images/transformer_full.png -------------------------------------------------------------------------------- /images/x_tm1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/images/x_tm1.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from minimagen.generate import load_minimagen, sample_and_save 3 | 4 | # Command line argument parser 5 | parser = ArgumentParser() 6 | parser.add_argument("-c", "--CAPTIONS", dest="CAPTIONS", help="Single caption to generate for or filepath for .txt " 7 | "file of captions to generate for", default=None, type=str) 8 | parser.add_argument("-d", "--TRAINING_DIRECTORY", dest="TRAINING_DIRECTORY", help="Training directory to use for inference", type=str) 9 | args = parser.parse_args() 10 | 11 | minimagen = load_minimagen(args.TRAINING_DIRECTORY) 12 | 13 | if args.CAPTIONS is None: 14 | print("\nNo caption supplied - using the default of \"a happy dog\".\n") 15 | captions = ['a happy dog'] 16 | elif not args.CAPTIONS.endswith(".txt"): 17 | captions = [args.CAPTIONS] 18 | elif args.CAPTIONS.endswith(".txt"): 19 | with open(args.CAPTIONS, 'r') as f: 20 | lines = f.readlines() 21 | captions = [line[:-1] if line.endswith('\n') else line for line in lines] 22 | else: 23 | raise ValueError("Please input a valid argument for --CAPTIONS") 24 | 25 | # Can supply a training dictionary to load from for inference 26 | sample_and_save(captions, training_directory=args.TRAINING_DIRECTORY, sample_args={'cond_scale':3.}) 27 | 28 | # Otherwise, can supply a MinImagen instance itself. In this case, information about the instance will not be saved. 
29 | # sample_and_save(captions, minimagen=minimagen, sample_args={'cond_scale':3.}) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from datetime import datetime 3 | 4 | # Get timestamp for training 5 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 6 | 7 | # Run training on small test Imagen 8 | subprocess.call(["venv/Scripts/python", "train.py", "-test", "-ts", timestamp]) 9 | 10 | # Use small test Imagen to generate image 11 | subprocess.call(["venv/Scripts/python", "inference.py", "-d", f"training_{timestamp}"]) -------------------------------------------------------------------------------- /minimagen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/minimagen/__init__.py -------------------------------------------------------------------------------- /minimagen/diffusion_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from .helpers import default, log, extract 6 | 7 | 8 | class GaussianDiffusion(nn.Module): 9 | """ 10 | `Diffusion Model `_. 11 | """ 12 | 13 | def __init__( 14 | self, 15 | *, 16 | timesteps: int 17 | ): 18 | """ 19 | :param timesteps: Number of timesteps in the Diffusion Process. 20 | """ 21 | super().__init__() 22 | 23 | # Timesteps < 20 => scale > 50 => beta_end > 1 => alphas[-1] < 0 => sqrt_alphas_cumprod[-1] is NaN 24 | assert not timesteps < 20, f'timsteps must be at least 20' 25 | self.num_timesteps = timesteps 26 | 27 | # Create variance schedule. 28 | scale = 1000 / timesteps 29 | beta_start = scale * 0.0001 30 | beta_end = scale * 0.02 31 | betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64) 32 | 33 | # Diffusion model constants/buffers. See https://arxiv.org/pdf/2006.11239.pdf 34 | alphas = 1. - betas 35 | alphas_cumprod = torch.cumprod(alphas, axis=0) 36 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) 37 | 38 | # register buffer helper function 39 | register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32), persistent=False) 40 | 41 | # Register variance schedule related buffers 42 | register_buffer('betas', betas) 43 | register_buffer('alphas_cumprod', alphas_cumprod) 44 | register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) 45 | 46 | # Buffer for diffusion calculations q(x_t | x_{t-1}) and others 47 | register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) 48 | register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) 49 | register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) 50 | register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) 51 | register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) 52 | 53 | # calculations for posterior q(x_{t-1} | x_t, x_0) 54 | 55 | # Posterior variance: 56 | # https://github.com/AssemblyAI-Examples/build-your-own-imagen/blob/main/images/posterior_variance.png 57 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. 
- alphas_cumprod) 58 | register_buffer('posterior_variance', posterior_variance) 59 | 60 | # Clipped because posterior variance is 0 at the beginning of the diffusion chain 61 | register_buffer('posterior_log_variance_clipped', log(posterior_variance, eps=1e-20)) 62 | 63 | # Buffers for calculating the q_posterior mean $\~{\mu}$. See 64 | # https://github.com/oconnoob/minimal_imagen/blob/minimal/images/posterior_mean_coeffs.png 65 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) 66 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) 67 | 68 | def _get_times(self, batch_size: int, noise_level: float, *, device: torch.device) -> torch.tensor: 69 | return torch.full((batch_size,), int(self.num_timesteps * noise_level), device=device, dtype=torch.long) 70 | 71 | def _sample_random_times(self, batch_size: int, *, device: torch.device) -> torch.tensor: 72 | """ 73 | Randomly sample `batch_size` timestep values uniformly from [0, 1, ..., `self.num_timesteps`] 74 | 75 | :param batch_size: Number of images in the batch. 76 | :param device: Device on which to place the return tensor. 77 | :return: Tensor of integers (`dtype=torch.long`) of shape `(batch_size,)` 78 | """ 79 | return torch.randint(0, self.num_timesteps, (batch_size,), device=device, dtype=torch.long) 80 | 81 | def _get_sampling_timesteps(self, batch: int, *, device: torch.device) -> list[torch.tensor]: 82 | time_transitions = [] 83 | 84 | for i in reversed(range(self.num_timesteps)): 85 | time_transitions.append((torch.full((batch,), i, device=device, dtype=torch.long))) 86 | 87 | return time_transitions 88 | 89 | def q_posterior(self, x_start: torch.tensor, x_t: torch.tensor, t: torch.tensor) -> tuple[torch.tensor, 90 | torch.tensor, 91 | torch.tensor]: 92 | """ 93 | Calculates q_posterior parameters given a starting image :code:`x_start` (x_0) and a noised image :code:`x_t`. 94 | 95 | .. figure:: _static/q_posterior.png 96 | :alt: q posterior formula 97 | 98 | Where the mean is 99 | 100 | .. image:: _static/posterior_mean.png 101 | 102 | And the variance prefactor is 103 | 104 | .. image:: _static/posterior_variance.png 105 | :width: 300px 106 | 107 | :param x_start: Original input images x_0. Shape (b, c, h, w) 108 | :param x_t: Images at current time x_t. Shape (b, c, h, w) 109 | :param t: Current time. Shape (b,) 110 | :return: Tuple of 111 | 112 | - **posterior mean**, shape (b, c, s, s), 113 | 114 | - **posterior variance**, shape (b, 1, 1, 1), 115 | 116 | - **clipped log of the posterior variance**, shape (b, 1, 1, 1) 117 | """ 118 | posterior_mean = ( 119 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 120 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 121 | ) 122 | # Extract the value corresponding to the current time from the buffers, and then reshape to (b, 1, 1, 1) 123 | posterior_variance = extract(self.posterior_variance, t, x_t.shape) 124 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) 125 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 126 | 127 | def q_sample(self, x_start: torch.tensor, t: torch.tensor, noise: torch.tensor = None) -> torch.tensor: 128 | """ 129 | Sample from q at a given timestep: 130 | 131 | .. image:: _static/q_sample.png 132 | 133 | 134 | :param x_start: Original input images. Shape (b, c, h, w). 135 | :param t: Timestep value for each image in the batch. Shape (b,). 
136 | :param noise: Optionally supply noise to use. Defaults to Gaussian. Shape (b, c, s, s). 137 | :return: Noised image. Shape (b, c, h, w). 138 | 139 | """ 140 | noise = default(noise, lambda: torch.randn_like(x_start)) 141 | 142 | noised = ( 143 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 144 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 145 | ) 146 | 147 | return noised 148 | 149 | def predict_start_from_noise(self, x_t: torch.tensor, t: torch.tensor, noise: torch.tensor) -> torch.tensor: 150 | """ 151 | Given a noised image and its noise component, calculates the un-noised image :code:`x_0`. 152 | 153 | :param x_t: Noised images. Shape (b, c, s, s). 154 | :param t: Timestep for each image. Shape (b,). 155 | :param noise: Noise component for each image. Shape (b, c, s, s). 156 | :return: Un-noised images. Shape (b, c, s, s). 157 | 158 | """ 159 | return ( 160 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 161 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 162 | ) 163 | -------------------------------------------------------------------------------- /minimagen/generate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from contextlib import contextmanager 4 | from datetime import datetime 5 | 6 | import torch 7 | 8 | from minimagen.Unet import Unet 9 | from minimagen.Imagen import Imagen 10 | 11 | 12 | def _create_directory(dir_path): 13 | """ 14 | Creates a directory at the given path if it does not exist already and returns a context manager that allows the user 15 | to temporarily enter the directory. Similar to `minimagen.training.create_directory` but for the generated_images 16 | folder rather than the training folder. 17 | """ 18 | original_dir = os.getcwd() 19 | img_path = os.path.join(original_dir, dir_path, "generated_images") 20 | if not os.path.exists(img_path): 21 | os.makedirs(img_path) 22 | elif not len(os.listdir(img_path)) == 0: 23 | raise FileExistsError(f"The directory {os.path.join(original_dir, img_path)} already exists and is nonempty") 24 | 25 | @contextmanager 26 | def cm(subdir=""): 27 | os.chdir(os.path.join(original_dir, dir_path, subdir)) 28 | yield 29 | os.chdir(original_dir) 30 | return cm 31 | 32 | 33 | def _get_best_state_dict(unet_number, files): 34 | """ Gets the filename for the state_dict with lowest validation loss for given unet number""" 35 | # Filter out files not for current unet 36 | filt_list = list(filter(lambda x: x.startswith(f"unet_{unet_number}"), files)) 37 | # Get validation loss of best state_dict for this unet 38 | min_val = min([i.split("_")[-1].split(".pth")[0] for i in filt_list]) 39 | # Get the filename for the best state_dict for this unet 40 | return list(filter(lambda x: x.endswith(f"{min_val}.pth"), filt_list))[0] 41 | 42 | 43 | def _read_params(directory, filename): 44 | """Returns dictionary from JSON config file in the parameters folder of a training directory""" 45 | with open(os.path.join(directory, "parameters", filename), 'r') as _file: 46 | return json.loads(_file.read()) 47 | 48 | 49 | def load_params(directory): 50 | """ 51 | Loads Unets and Imagen parameters from a training directory 52 | 53 | :param directory: Path of training directory generated by training 54 | :return: (unets_params, imagen_params) where unets_params is a list whose i-th element contains the parameters of the 55 | i-th Unet in the Imagen instance.
56 | """ 57 | # Files in parameters directory 58 | files = os.listdir(os.path.join(directory, "parameters")) 59 | 60 | # Filter only param files for U-Nets 61 | unets_params_files = sorted(list(filter(lambda x: x.startswith("unet_", ), files)), 62 | key=lambda x: int(x.split("_")[1])) 63 | 64 | # Load U-Nets / MinImagen parameters 65 | unets_params = [_read_params(directory, f) for f in unets_params_files] 66 | imagen_params_files = _read_params(directory, list(filter(lambda x: x.startswith("imagen_"), files))[0]) 67 | return unets_params, imagen_params_files 68 | 69 | 70 | def _instatiate_minimagen(directory): 71 | # TODO: When restarted training, parameters folder only has the cmd line args, not the unet/imagen params. 72 | # had to copy from training folder this one was restarted from. Fix this so it copies. 73 | """ Instantiate an Imagen model with given parameters """ 74 | unets_params, imagen_params_files = load_params(directory) 75 | 76 | return Imagen(unets=[Unet(**params) for params in unets_params], **imagen_params_files) 77 | 78 | 79 | def load_minimagen(directory): 80 | """ 81 | Load a :obj:`MinImagen <.minimagen.Imagen.Imagen>`. instance from a training directory. 82 | 83 | :param directory: MinImagen training directory as structured according to :func:`.minimagen.training.create_directory`. 84 | :return: :obj:`MinImagen <.minimagen.Imagen.Imagen>` instance (ready for inference). 85 | """ 86 | map_location = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 87 | 88 | minimagen = _instatiate_minimagen(directory) 89 | 90 | # Filepaths for all statedicts 91 | files = os.listdir(os.path.join(directory, "state_dicts")) 92 | 93 | # Use tmp folder if state_dicts empty 94 | if files != []: 95 | subdir = "state_dicts" 96 | num_unets = int(max(set([i.split("_")[1] for i in list(filter(lambda x: x.startswith("unet_"), files))]))) + 1 97 | 98 | # Load best state for each unet in the minimagen instance 99 | unet_state_dicts = [list(filter(lambda x: x.startswith(f"unet_{i}"), files))[0] for i in range(num_unets)] 100 | for idx, file in enumerate(unet_state_dicts): 101 | pth = os.path.join(directory, f'{subdir}', file) 102 | minimagen.unets[idx].load_state_dict(torch.load(pth, map_location=map_location)) 103 | 104 | else: 105 | subdir = "tmp" 106 | print(f"\n\"state_dicts\" folder in {directory} is empty, using the most recent checkpoint from \"tmp\".\n") 107 | files = os.listdir(os.path.join(directory, f"{subdir}")) 108 | 109 | if files == []: 110 | raise ValueError(f"Both \"/state_dicts\" and \"/tmp\" in {directory} are empty. Train the model to acquire state dictionaries for inference. ") 111 | 112 | num_unets = int(max(set([i.split("_")[1] for i in list(filter(lambda x: x.startswith("unet_"), files))]))) + 1 113 | 114 | 115 | # Load best state for each unet in the minimagen instance 116 | unet_state_dicts = [list(filter(lambda x: x.startswith(f"unet_{i}"), files))[0] for i in range(num_unets)] 117 | for idx, file in enumerate(unet_state_dicts): 118 | pth = os.path.join(directory, f'{subdir}', file) 119 | minimagen.unets[idx].load_state_dict(torch.load(pth, map_location=map_location)) 120 | 121 | return minimagen 122 | 123 | 124 | def sample_and_save(captions: list, 125 | *, 126 | minimagen: Imagen = None, 127 | training_directory: str = None, 128 | sample_args: dict = {}, 129 | save_directory: str = None, 130 | filetype: str = "png"): 131 | """ 132 | Generate and save images for a list of captions using a :obj:`MinImagen <.minimagen.Imagen.Imagen>` instance. 
133 | Images are saved into a "generated_images" directory as "image_<index>.<filetype>" 134 | 135 | :param captions: List of captions (strings) to generate images for. 136 | :param minimagen: :obj:`MinImagen <.minimagen.Imagen.Imagen>` instance to use for sampling (i.e. 137 | generating images). Must specify one of :code:`minimagen` or :code:`training_directory`. 138 | :param training_directory: Training directory of MinImagen instance to use for inference. Must specify one 139 | of :code:`minimagen` or :code:`training_directory`. 140 | :param sample_args: Additional keyword arguments to pass for :obj:`Imagen.sample <.minimagen.Imagen.Imagen.sample>` 141 | function. Do not include :code:`texts` or :code:`return_pil_images` in this dictionary. 142 | :param save_directory: Directory to save images to. Defaults to datetime-stamped directory if not specified. 143 | :param filetype: Filetype of saved images. 144 | :return: 145 | """ 146 | assert not (minimagen is None and training_directory is None), \ 147 | "Must supply either a training directory or MinImagen instance." 148 | 149 | assert (minimagen != None) ^ (training_directory != None), \ 150 | "Cannot supply both a MinImagen instance and a training directory" 151 | 152 | if save_directory is None: 153 | save_directory = datetime.now().strftime("generated_images_%Y%m%d_%H%M%S") 154 | 155 | cm = _create_directory(save_directory) 156 | 157 | with cm(): 158 | with open('captions.txt', 'w') as f: 159 | for caption in captions: 160 | f.write(f"{caption}\n") 161 | if training_directory is not None: 162 | with open('imagen_training_directory.txt', 'w') as f: 163 | f.write(training_directory) 164 | 165 | if training_directory is not None: 166 | minimagen = load_minimagen(training_directory) 167 | 168 | 169 | images = minimagen.sample(texts=captions, return_pil_images=True, **sample_args) 170 | 171 | with cm("generated_images"): 172 | for idx, img in enumerate(images): 173 | img.save(f'image_{idx}.{filetype}') 174 | -------------------------------------------------------------------------------- /minimagen/helpers.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from functools import wraps 3 | 4 | import torch 5 | from typing import Literal, Callable 6 | from resize_right import resize 7 | 8 | 9 | def cast_tuple(val, length: int = None) -> tuple: 10 | ''' 11 | Casts input to a tuple. If the input is a list, converts it to a tuple. If the input is a single value, casts it to a 12 | tuple of length `length`, which is 1 if not provided. 13 | ''' 14 | if isinstance(val, list): 15 | val = tuple(val) 16 | 17 | output = val if isinstance(val, tuple) else ((val,) * default(length, 1)) 18 | 19 | if exists(length): 20 | assert len(output) == length 21 | 22 | return output 23 | 24 | 25 | def default(val, d): 26 | """ 27 | Returns the input value `val` unless it is `None`, in which case the default `d` is returned if it is a value or 28 | `d()` is returned if it is a callable. 29 | """ 30 | if exists(val): 31 | return val 32 | return d() if callable(d) else d 33 | 34 | 35 | def eval_decorator(fn): 36 | """ 37 | Decorator for sampling from Imagen. Temporarily sets the model in evaluation mode if it was training.
38 | """ 39 | def inner(model, *args, **kwargs): 40 | was_training = model.training 41 | model.eval() 42 | out = fn(model, *args, **kwargs) 43 | model.train(was_training) 44 | return out 45 | 46 | return inner 47 | 48 | 49 | def exists(val) -> bool: 50 | """ 51 | Checks to see if a value is not `None` 52 | """ 53 | return val is not None 54 | 55 | 56 | def extract(a: torch.tensor, t: torch.tensor, x_shape: torch.Size) -> torch.tensor: 57 | """ 58 | Extracts values from `a` using `t` as indices 59 | 60 | :param a: 1D tensor of length L. 61 | :param t: 1D tensor of length b. 62 | :param x_shape: Tensor of size (b, c, h, w). 63 | :return: Tensor of shape (b, 1, 1, 1) that selects elements of a, using t as indices of selection. 64 | """ 65 | b, *_ = t.shape 66 | out = a.gather(-1, t) 67 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 68 | 69 | 70 | def identity(t, *args, **kwargs): 71 | return t 72 | 73 | 74 | def log(t: torch.tensor, eps: float = 1e-12) -> torch.tensor: 75 | """ 76 | Calculates the natural logarithm of a torch tensor, clamping values to a minimum of `eps`. 77 | """ 78 | return torch.log(t.clamp(min=eps)) 79 | 80 | 81 | def maybe(fn: Callable) -> Callable: 82 | """ 83 | Returns a new function that simply applies the input function in all cases where the input is not `None`. If the 84 | input is `None`, `maybe` returns `None`. 85 | 86 | Passes through function name, docstring, etc. with [functools.wraps](https://docs.python.org/3/library/functools.html#functools.wraps) 87 | """ 88 | 89 | @wraps(fn) 90 | def inner(x): 91 | if not exists(x): 92 | return x 93 | return fn(x) 94 | 95 | return inner 96 | 97 | 98 | def module_device(module: torch.nn.Module) -> torch.device: 99 | """ 100 | Returns the device on which a Module's parameters lie 101 | """ 102 | return next(module.parameters()).device 103 | 104 | 105 | def normalize_neg_one_to_one(img: torch.tensor) -> torch.tensor: 106 | """ 107 | Normalizes an image in the range (0., 1.) to be in the range (-1., 1.). Inverse of 108 | :func:`.unnormalize_zero_to_one` 109 | """ 110 | return img * 2 - 1 111 | 112 | 113 | @contextmanager 114 | def null_context(*args, **kwargs): 115 | """ 116 | A placeholder null context manager that does nothing. 117 | """ 118 | yield 119 | 120 | 121 | def prob_mask_like(shape: tuple, prob: float, device: torch.device) -> torch.Tensor: 122 | """ 123 | For classifier free guidance. Creates a boolean mask for given input shape and probability of `True`. 124 | 125 | :param shape: Shape of mask. 126 | :param prob: Probability of True. In interval [0., 1.]. 127 | :param device: Device to put the mask on. Should be the same as that of the tensor which it will be used on. 128 | :return: The mask. 129 | """ 130 | if prob == 1: 131 | return torch.ones(shape, device=device, dtype=torch.bool) 132 | elif prob == 0: 133 | return torch.zeros(shape, device=device, dtype=torch.bool) 134 | else: 135 | return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob 136 | 137 | 138 | def resize_image_to(image: torch.tensor, 139 | target_image_size: int, 140 | clamp_range: tuple = None, 141 | pad_mode: Literal['constant', 'edge', 'reflect', 'symmetric'] = 'reflect' 142 | ) -> torch.tensor: 143 | """ 144 | Resizes image to desired size. 145 | 146 | :param image: Images to resize. Shape (b, c, s, s) 147 | :param target_image_size: Edge length to resize to. 148 | :param clamp_range: Range to clamp values to. Tuple of length 2. 149 | :param pad_mode: `constant`, `edge`, `reflect`, `symmetric`. 
150 | See [TorchVision documentation](https://pytorch.org/vision/main/generated/torchvision.transforms.functional.pad.html) for additional details 151 | :return: Resized image. Shape (b, c, target_image_size, target_image_size) 152 | """ 153 | orig_image_size = image.shape[-1] 154 | 155 | if orig_image_size == target_image_size: 156 | return image 157 | 158 | scale_factors = target_image_size / orig_image_size 159 | out = resize(image, scale_factors=scale_factors, pad_mode=pad_mode) 160 | 161 | if exists(clamp_range): 162 | out = out.clamp(*clamp_range) 163 | 164 | return out 165 | 166 | 167 | def right_pad_dims_to(x: torch.tensor, t: torch.tensor) -> torch.tensor: 168 | """ 169 | Pads `t` with empty dimensions to the number of dimensions `x` has. If `t` does not have fewer dimensions than `x` 170 | it is returned without change. 171 | """ 172 | padding_dims = x.ndim - t.ndim 173 | if padding_dims <= 0: 174 | return t 175 | return t.view(*t.shape, *((1,) * padding_dims)) 176 | 177 | 178 | def unnormalize_zero_to_one(normed_img): 179 | """ 180 | Un-normalizes an image in the range (-1., 1.) to be in the range (0., 1.). Inverse of 181 | :func:`.normalize_neg_one_to_one`. 182 | """ 183 | return (normed_img + 1) * 0.5 184 | -------------------------------------------------------------------------------- /minimagen/layers.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from einops import rearrange 4 | from einops_exts import rearrange_many, repeat_many 5 | from einops_exts.torch import EinopsToAndFrom 6 | import math 7 | import torch 8 | from torch import nn, einsum 9 | import torch.nn.functional as F 10 | 11 | from .helpers import default, exists 12 | 13 | 14 | class Attention(nn.Module): 15 | """ 16 | Multi-headed attention 17 | """ 18 | 19 | def __init__( 20 | self, 21 | dim: int, 22 | *, 23 | dim_head: int = 64, 24 | heads: int = 8, 25 | context_dim: int = None 26 | ): 27 | """ 28 | :param dim: Input dimensionality. 29 | :param dim_head: Dimensionality for each attention head. 30 | :param heads: Number of attention heads. 31 | :param context_dim: Context dimensionality.
32 | """ 33 | super().__init__() 34 | self.scale = dim_head ** -0.5 35 | self.heads = heads 36 | inner_dim = dim_head * heads 37 | 38 | self.norm = LayerNorm(dim) 39 | 40 | self.null_kv = nn.Parameter(torch.randn(2, dim_head)) 41 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 42 | self.to_kv = nn.Linear(dim, dim_head * 2, bias=False) 43 | 44 | self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists( 45 | context_dim) else None 46 | 47 | self.to_out = nn.Sequential( 48 | nn.Linear(inner_dim, dim, bias=False), 49 | LayerNorm(dim) 50 | ) 51 | 52 | def forward(self, x: torch.tensor, context: torch.tensor = None, mask: torch.tensor = None, 53 | attn_bias: torch.tensor = None) -> torch.tensor: 54 | 55 | b, n, device = *x.shape[:2], x.device 56 | 57 | x = self.norm(x) 58 | q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)) 59 | 60 | q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads) 61 | q = q * self.scale 62 | 63 | # add null key / value for classifier free guidance in prior net 64 | 65 | nk, nv = repeat_many(self.null_kv.unbind(dim=-2), 'd -> b 1 d', b=b) 66 | k = torch.cat((nk, k), dim=-2) 67 | v = torch.cat((nv, v), dim=-2) 68 | 69 | # add text conditioning, if present 70 | 71 | if exists(context): 72 | assert exists(self.to_context) 73 | ck, cv = self.to_context(context).chunk(2, dim=-1) 74 | k = torch.cat((ck, k), dim=-2) 75 | v = torch.cat((cv, v), dim=-2) 76 | 77 | # calculate query / key similarities 78 | 79 | sim = einsum('b h i d, b j d -> b h i j', q, k) 80 | 81 | # relative positional encoding (T5 style) 82 | 83 | if exists(attn_bias): 84 | sim = sim + attn_bias 85 | 86 | # masking 87 | 88 | max_neg_value = -torch.finfo(sim.dtype).max 89 | 90 | if exists(mask): 91 | mask = F.pad(mask, (1, 0), value=True) 92 | mask = rearrange(mask, 'b j -> b 1 1 j') 93 | sim = sim.masked_fill(~mask, max_neg_value) 94 | 95 | # attention 96 | 97 | attn = sim.softmax(dim=-1, dtype=torch.float32) 98 | 99 | # aggregate values 100 | 101 | out = einsum('b h i j, b j d -> b h i d', attn, v) 102 | 103 | out = rearrange(out, 'b h n d -> b n (h d)') 104 | return self.to_out(out) 105 | 106 | 107 | class Block(nn.Module): 108 | """ 109 | Sub-block for :class:`.ResnetBlock`. GroupNorm/Identity, SiLU, and Conv2D in sequence, with the potential for 110 | scale-shift incorporation of timestep information. 111 | """ 112 | 113 | def __init__( 114 | self, 115 | dim: int, 116 | dim_out: int, 117 | groups: int = 8, 118 | norm: bool = True 119 | ): 120 | """ 121 | :param dim: Input number of channels. 122 | :param dim_out: Output number of channels. 123 | :param groups: Number of GroupNorm groups. 124 | :param norm: Whether to use GroupNorm or Identity. 125 | """ 126 | super().__init__() 127 | self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity() 128 | self.activation = nn.SiLU() 129 | self.project = nn.Conv2d(dim, dim_out, 3, padding=1) 130 | 131 | def forward(self, x: torch.tensor, scale_shift: tuple[torch.tensor, torch.tensor] = None) -> torch.tensor: 132 | """ 133 | Forward pass 134 | 135 | :param x: Input images. 136 | :param scale_shift: Tensors to use for scale-shift. 
137 | """ 138 | x = self.groupnorm(x) 139 | 140 | if exists(scale_shift): 141 | scale, shift = scale_shift 142 | x = x * (scale + 1) + shift 143 | 144 | x = self.activation(x) 145 | return self.project(x) 146 | 147 | 148 | def ChanFeedForward(dim: int, 149 | mult: int = 2) -> torch.nn.Sequential: # in paper, it seems for self attention layers they did feedforwards with twice channel width 150 | """ 151 | MLP for :class:`.TransformerBlock`. Maps images to a multiple of the number of channels and then back with 152 | convolutions, with layernorms before each convolution a GELU between them. 153 | """ 154 | hidden_dim = int(dim * mult) 155 | return nn.Sequential( 156 | ChanLayerNorm(dim), 157 | nn.Conv2d(dim, hidden_dim, 1, bias=False), 158 | nn.GELU(), 159 | ChanLayerNorm(hidden_dim), 160 | nn.Conv2d(hidden_dim, dim, 1, bias=False) 161 | ) 162 | 163 | 164 | class ChanLayerNorm(nn.Module): 165 | """ 166 | LayerNorm for :class:`.ChanFeedForward`. 167 | """ 168 | 169 | def __init__(self, dim: int, eps: float = 1e-5): 170 | super().__init__() 171 | self.eps = eps 172 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) 173 | 174 | def forward(self, x: torch.tensor) -> torch.tensor: 175 | var = torch.var(x, dim=1, unbiased=False, keepdim=True) 176 | mean = torch.mean(x, dim=1, keepdim=True) 177 | return (x - mean) / (var + self.eps).sqrt() * self.g 178 | 179 | 180 | class CrossAttention(nn.Module): 181 | """ 182 | Multi-headed cross attention. 183 | """ 184 | 185 | def __init__( 186 | self, 187 | dim: int, 188 | *, 189 | context_dim: int = None, 190 | dim_head: int = 64, 191 | heads: int = 8, 192 | norm_context: bool = False 193 | ): 194 | """ 195 | :param dim: Input dimensionality. 196 | :param context_dim: Context dimensionality. 197 | :param dim_head: Dimensionality for each attention head. 198 | :param heads: Number of attention heads. 199 | :param norm_context: Whether to LayerNorm the context. 
200 | """ 201 | super().__init__() 202 | self.scale = dim_head ** -0.5 203 | self.heads = heads 204 | inner_dim = dim_head * heads 205 | 206 | context_dim = default(context_dim, dim) 207 | 208 | self.norm = LayerNorm(dim) 209 | self.norm_context = LayerNorm(context_dim) if norm_context else Identity() 210 | 211 | self.null_kv = nn.Parameter(torch.randn(2, dim_head)) 212 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 213 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) 214 | 215 | self.to_out = nn.Sequential( 216 | nn.Linear(inner_dim, dim, bias=False), 217 | LayerNorm(dim) 218 | ) 219 | 220 | def forward(self, x: torch.tensor, context: torch.tensor, mask: torch.tensor = None) -> torch.tensor: 221 | b, n, device = *x.shape[:2], x.device 222 | 223 | x = self.norm(x) 224 | context = self.norm_context(context) 225 | 226 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1)) 227 | 228 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h=self.heads) 229 | 230 | # add null key / value for classifier free guidance in prior net 231 | 232 | nk, nv = repeat_many(self.null_kv.unbind(dim=-2), 'd -> b h 1 d', h=self.heads, b=b) 233 | 234 | k = torch.cat((nk, k), dim=-2) 235 | v = torch.cat((nv, v), dim=-2) 236 | 237 | q = q * self.scale 238 | 239 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 240 | max_neg_value = -torch.finfo(sim.dtype).max 241 | 242 | if exists(mask): 243 | mask = F.pad(mask, (1, 0), value=True) 244 | mask = rearrange(mask, 'b j -> b 1 1 j') 245 | sim = sim.masked_fill(~mask, max_neg_value) 246 | 247 | attn = sim.softmax(dim=-1, dtype=torch.float32) 248 | 249 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 250 | out = rearrange(out, 'b h n d -> b n (h d)') 251 | return self.to_out(out) 252 | 253 | 254 | class CrossEmbedLayer(nn.Module): 255 | ''' 256 | Module that performs cross embedding on an input image (essentially an Inception module) which maintains channel 257 | depth. 258 | 259 | E.g. If input a 64x64 image with 128 channels and use kernel_sizes = (3, 7, 15) and stride=1, then 3 convolutions 260 | will be performed: 261 | 262 | 1: 64 filters, (3x3) kernel, stride=(1x1), padding=(1x1) -> 64x64 output 263 | 2: 32 filters, (7x7) kernel, stride=(1x1), padding=(3x3) -> 64x64 output 264 | 3: 32 filters, (15x15) kernel, stride=(1x1), padding=(7x7) -> 64x64 output 265 | 266 | Concatenate them for a resulting 64x64 image with 128 output channels 267 | ''' 268 | 269 | def __init__( 270 | self, 271 | dim_in: int, 272 | kernel_sizes: tuple[int, ...], 273 | dim_out: int = None, 274 | stride: int = 2 275 | ): 276 | """ 277 | :param dim_in: Number of channels in the input image. 278 | :param kernel_sizes: Tuple of kernel sizes to use for convolutions. 279 | :param dim_out: Number of channels in output image. Defaults to `dim_in`. 280 | :param stride: Stride of convolutions. 281 | """ 282 | super().__init__() 283 | # Ensures stride and all kernels are either all odd or all even 284 | assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)]) 285 | 286 | # Set output dimensionality to be same as input if not provided 287 | dim_out = default(dim_out, dim_in) 288 | 289 | # Sort the kernels by size and determine number of kernels 290 | kernel_sizes = sorted(kernel_sizes) 291 | num_scales = len(kernel_sizes) 292 | 293 | # Determine number of filters for each kernel. 
They will sum to dim_out and be descending with kernel size 294 | dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)] 295 | dim_scales = [*dim_scales, dim_out - sum(dim_scales)] 296 | 297 | # Create the convolution objects 298 | self.convs = nn.ModuleList([]) 299 | for kernel, dim_scale in zip(kernel_sizes, dim_scales): 300 | self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride=stride, padding=(kernel - stride) // 2)) 301 | 302 | def forward(self, x: torch.tensor) -> torch.tensor: 303 | # Perform each convolution and then concatenate the results along the channel dim. 304 | fmaps = tuple(map(lambda conv: conv(x), self.convs)) 305 | return torch.cat(fmaps, dim=1) 306 | 307 | 308 | def Downsample(dim: int, dim_out: int = None) -> torch.nn.Conv2d: 309 | """ 310 | Return a convolution layer that cuts the spatial dimensions of an image in half and potentially modifies the 311 | number of channels 312 | 313 | :param dim: Input dimensionality of the image 314 | :param dim_out: Output dimensionality of the image. Defaults to `dim`. 315 | :return: Convolution layer. 316 | """ 317 | 318 | dim_out = default(dim_out, dim) 319 | return nn.Conv2d(dim, dim_out, kernel_size=4, stride=2, padding=1) 320 | 321 | 322 | class Identity(nn.Module): 323 | """ 324 | Identity module - forward pass returns input. 325 | """ 326 | def __init__(self, *args, **kwargs): 327 | super().__init__() 328 | 329 | def forward(self, x: torch.tensor, *args, **kwargs) -> torch.tensor: 330 | return x 331 | 332 | 333 | class LayerNorm(nn.Module): 334 | """ 335 | LayerNorm 336 | """ 337 | def __init__(self, dim: int): 338 | super().__init__() 339 | self.gamma = nn.Parameter(torch.ones(dim)) 340 | self.register_buffer('beta', torch.zeros(dim)) 341 | 342 | def forward(self, x: torch.tensor) -> torch.tensor: 343 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) 344 | 345 | 346 | class Parallel(nn.Module): 347 | """ 348 | Passes input through parallel functions and then sums the result. 349 | """ 350 | def __init__(self, *fns: tuple[Callable, ...]): 351 | super().__init__() 352 | self.fns = nn.ModuleList(fns) 353 | 354 | def forward(self, x: torch.tensor) -> torch.tensor: 355 | outputs = [fn(x) for fn in self.fns] 356 | return sum(outputs) 357 | 358 | 359 | class Residual(nn.Module): 360 | """ 361 | Residual module. Passes input through a function and then adds the result to the input. 362 | """ 363 | def __init__(self, fn: callable): 364 | super().__init__() 365 | self.fn = fn 366 | 367 | def forward(self, x: torch.tensor, **kwargs) -> torch.tensor: 368 | return self.fn(x, **kwargs) + x 369 | 370 | 371 | class ResnetBlock(nn.Module): 372 | """ 373 | ResNet Block. 374 | """ 375 | def __init__( 376 | self, 377 | dim: int, 378 | dim_out: int, 379 | *, 380 | cond_dim: int = None, 381 | time_cond_dim: int = None, 382 | groups: int = 8, 383 | ): 384 | """ 385 | :param dim: Number of channels in the input. 386 | :param dim_out: Number of channels in the output. 387 | :param cond_dim: Dimension of the conditioning tokens on which to perform cross attention with the input. 388 | :param time_cond_dim: Dimension of the time conditioning tensor. 389 | :param groups: Number of groups to use in the GroupNorms. See :class:`.Block`. 
390 | """ 391 | super().__init__() 392 | 393 | self.time_mlp = None 394 | 395 | if exists(time_cond_dim): 396 | self.time_mlp = nn.Sequential( 397 | nn.SiLU(), 398 | nn.Linear(time_cond_dim, dim_out * 2) 399 | ) 400 | 401 | self.cross_attn = None 402 | if exists(cond_dim): 403 | self.cross_attn = EinopsToAndFrom( 404 | 'b c h w', 405 | 'b (h w) c', 406 | CrossAttention( 407 | dim=dim_out, 408 | context_dim=cond_dim 409 | ) 410 | ) 411 | 412 | self.block1 = Block(dim, dim_out, groups=groups) 413 | self.block2 = Block(dim_out, dim_out, groups=groups) 414 | 415 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else Identity() 416 | 417 | def forward(self, x: torch.tensor, time_emb: torch.tensor = None, cond: torch.tensor = None) -> torch.tensor: 418 | """ 419 | :param x: Input image. Shape (b, c, s, s). 420 | :param time_emb: Time conditioning tensor. Shape (b, c2). 421 | :param cond: Main conditioning tensor. Shape (b, c3). 422 | :return: Output image. Shape (b, c, s, s) 423 | """ 424 | 425 | scale_shift = None 426 | if exists(self.time_mlp) and exists(time_emb): 427 | time_emb = self.time_mlp(time_emb) 428 | time_emb = rearrange(time_emb, 'b c -> b c 1 1') 429 | scale_shift = time_emb.chunk(2, dim=1) 430 | 431 | h = self.block1(x) 432 | 433 | if exists(self.cross_attn): 434 | assert exists(cond) 435 | h = self.cross_attn(h, context=cond) + h 436 | 437 | h = self.block2(h, scale_shift=scale_shift) 438 | 439 | return h + self.res_conv(x) 440 | 441 | 442 | class SinusoidalPosEmb(nn.Module): 443 | ''' 444 | Generates sinusoidal positional embedding tensor. In this case, position corresponds to time. For more information 445 | on sinusoidal embeddings, see ["Positional Encoding - Additional Details"](https://www.assemblyai.com/blog/how-imagen-actually-works/#timestep-conditioning). 446 | ''' 447 | 448 | def __init__(self, dim: int): 449 | """ 450 | :param dim: Dimensionality of the embedding space 451 | """ 452 | super().__init__() 453 | self.dim = dim 454 | 455 | def forward(self, x: torch.tensor) -> torch.tensor: 456 | """ 457 | :param x: Tensor of positions (i.e. times) to generate embeddings for. 458 | :return: T x D tensor where T is the number of positions/times and D is the dimensionality of the embedding 459 | space 460 | """ 461 | half_dim = self.dim // 2 462 | emb = math.log(10000) / (half_dim - 1) 463 | emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb) 464 | emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j') 465 | return torch.cat((emb.sin(), emb.cos()), dim=-1) 466 | 467 | 468 | class TransformerBlock(nn.Module): 469 | """ 470 | Transformer encoder block. Responsible for applying attention at the end of a chain of :class:`.ResnetBlock`s at 471 | each layer in the U-Met. 472 | """ 473 | def __init__( 474 | self, 475 | dim: int, 476 | *, 477 | heads: int = 8, 478 | dim_head: int = 32, 479 | ff_mult: int = 2, 480 | context_dim: int = None 481 | ): 482 | """ 483 | 484 | :param dim: Number of channels in the input. 485 | :param heads: Number of attention heads for multi-headed :class:`.Attention`. 486 | :param dim_head: Dimensionality for each attention head in multi-headed :class:`.Attention`. 487 | :param ff_mult: Channel depth multiplier for the :class:`.ChanFeedForward` MLP applied after multi-headed 488 | attention. 489 | :param context_dim: Dimensionality of the context. 
490 | """ 491 | super().__init__() 492 | self.attn = EinopsToAndFrom('b c h w', 'b (h w) c', 493 | Attention(dim=dim, heads=heads, dim_head=dim_head, context_dim=context_dim)) 494 | self.ff = ChanFeedForward(dim=dim, mult=ff_mult) 495 | 496 | def forward(self, x: torch.tensor, context: torch.tensor = None) -> torch.tensor: 497 | x = self.attn(x, context=context) + x 498 | x = self.ff(x) + x 499 | return x 500 | 501 | 502 | def Upsample(dim: int, dim_out: int = None) -> torch.nn.Sequential: 503 | """ 504 | Returns Sequential module that upsamples to twice the spatial width with an [Upsample](https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html) 505 | followed by a [Conv2d](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html). 506 | 507 | :param dim: Number of channels in the input. 508 | :param dim_out: Number of channels in the output. Defaults to `dim`. 509 | """ 510 | dim_out = default(dim_out, dim) 511 | 512 | return nn.Sequential( 513 | nn.Upsample(scale_factor=2, mode='nearest'), 514 | nn.Conv2d(dim, dim_out, 3, padding=1) 515 | ) 516 | -------------------------------------------------------------------------------- /minimagen/t5.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from transformers import T5Tokenizer, T5EncoderModel 4 | 5 | MAX_LENGTH = 256 6 | 7 | DEFAULT_T5_NAME = 't5_small' 8 | 9 | # Variants: https://huggingface.co/docs/transformers/model_doc/t5v1.1. 1.1 versions must be finetuned. 10 | T5_VERSIONS = { 11 | 't5_small': {'tokenizer': None, 'model': None, 'handle': 't5-small', 'dim': 512, 'size': .24}, 12 | 't5_base': {'tokenizer': None, 'model': None, 'handle': 't5-base', 'dim': 768, 'size': .890}, 13 | 't5_large': {'tokenizer': None, 'model': None, 'handle': 't5-large', 'dim': 1024, 'size': 2.75}, 14 | 't5_3b': {'tokenizer': None, 'model': None, 'handle': 't5-3b', 'dim': 1024, 'size': 10.6}, 15 | 't5_11b': {'tokenizer': None, 'model': None, 'handle': 't5-11b', 'dim': 1024, 'size': 42.1}, 16 | 'small1.1': {'tokenizer': None, 'model': None, 'handle': 'google/t5-v1_1-small', 'dim': 512, 'size': .3}, 17 | 'base1.1': {'tokenizer': None, 'model': None, 'handle': 'google/t5-v1_1-base', 'dim': 768, 'size': .99}, 18 | 'large1.1': {'tokenizer': None, 'model': None, 'handle': 'google/t5-v1_1-large', 'dim': 1024, 'size': 3.13}, 19 | 'xl1.1': {'tokenizer': None, 'model': None, 'handle': 'google/t5-v1_1-xl', 'dim': 2048, 'size': 11.4}, 20 | 'xxl1.1': {'tokenizer': None, 'model': None, 'handle': 'google/t5-v1_1-xxl', 'dim': 4096, 'size': 44.5}, 21 | } 22 | 23 | # Fast tokenizers: https://huggingface.co/docs/transformers/main_classes/tokenizer 24 | def _check_downloads(name): 25 | if T5_VERSIONS[name]['tokenizer'] is None: 26 | T5_VERSIONS[name]['tokenizer'] = T5Tokenizer.from_pretrained(T5_VERSIONS[name]['handle']) 27 | if T5_VERSIONS[name]['model'] is None: 28 | T5_VERSIONS[name]['model'] = T5EncoderModel.from_pretrained(T5_VERSIONS[name]['handle']) 29 | 30 | 31 | def t5_encode_text(text, name: str = 't5_base', max_length=MAX_LENGTH): 32 | """ 33 | Encodes a sequence of text with a T5 text encoder. 34 | 35 | :param text: List of text to encode. 36 | :param name: Name of T5 model to use. 
Options are: 37 | 38 | - :code:`'t5_small'` (~0.24 GB, 512 encoding dim), 39 | 40 | - :code:`'t5_base'` (~0.89 GB, 768 encoding dim), 41 | 42 | - :code:`'t5_large'` (~2.75 GB, 1024 encoding dim), 43 | 44 | - :code:`'t5_3b'` (~10.6 GB, 1024 encoding dim), 45 | 46 | - :code:`'t5_11b'` (~42.1 GB, 1024 encoding dim), 47 | 48 | :return: Returns encodings and attention mask. Element **[i,j,k]** of the final encoding corresponds to the **k**-th 49 | encoding component of the **j**-th token in the **i**-th input list element. 50 | """ 51 | _check_downloads(name) 52 | tokenizer = T5_VERSIONS[name]['tokenizer'] 53 | model = T5_VERSIONS[name]['model'] 54 | 55 | # Move to cuda if available 56 | if torch.cuda.is_available(): 57 | device = torch.device('cuda') 58 | model = model.to(device) 59 | else: 60 | device = torch.device('cpu') 61 | 62 | # Tokenize text 63 | tokenized = tokenizer.batch_encode_plus( 64 | text, 65 | padding='longest', 66 | max_length=max_length, 67 | truncation=True, 68 | return_tensors="pt", # Returns torch.tensor instead of python integers 69 | ) 70 | 71 | input_ids = tokenized.input_ids.to(device) 72 | attention_mask = tokenized.attention_mask.to(device) 73 | 74 | model.eval() 75 | 76 | # Don't need gradient - T5 frozen during Imagen training 77 | with torch.no_grad(): 78 | t5_output = model(input_ids=input_ids, attention_mask=attention_mask) 79 | final_encoding = t5_output.last_hidden_state.detach() 80 | 81 | # Wherever the encoding is masked, make equal to zero 82 | final_encoding = final_encoding.masked_fill(~rearrange(attention_mask, '... -> ... 1').bool(), 0.) 83 | 84 | return final_encoding, attention_mask.bool() 85 | 86 | 87 | def get_encoded_dim(name: str) -> int: 88 | """ 89 | Gets the encoding dimensionality of a given T5 encoder.
90 | """ 91 | return T5_VERSIONS[name]['dim'] -------------------------------------------------------------------------------- /parameters/imagen_params_20220816_165729.json: -------------------------------------------------------------------------------- 1 | { 2 | "text_embed_dim": null, 3 | "channels": 3, 4 | "timesteps": 25, 5 | "cond_drop_prob": 0.15, 6 | "loss_type": "l2", 7 | "lowres_sample_noise_level": 0.2, 8 | "auto_normalize_img": true, 9 | "dynamic_thresholding_percentile": 0.9, 10 | "only_train_unet_number": null, 11 | "image_sizes": [ 12 | 64, 13 | 128 14 | ], 15 | "text_encoder_name": "t5_small" 16 | } -------------------------------------------------------------------------------- /parameters/training_parameters_20220816_165729.txt: -------------------------------------------------------------------------------- 1 | --PARAMETERS=None 2 | --NUM_WORKERS=0 3 | --BATCH_SIZE=2 4 | --MAX_NUM_WORDS=32 5 | --IMG_SIDE_LEN=128 6 | --EPOCHS=2 7 | --T5_NAME=t5_small 8 | --TRAIN_VALID_FRAC=0.5 9 | --TIMESTEPS=25 10 | --OPTIM_LR=0.0001 11 | --ACCUM_ITER=1 12 | --CHCKPT_NUM=500 13 | --VALID_NUM=None 14 | --RESTART_DIRECTORY=None 15 | --TESTING=True 16 | --timestamp=None 17 | -------------------------------------------------------------------------------- /parameters/unet_0_params_20220816_165729.json: -------------------------------------------------------------------------------- 1 | { 2 | "dim": 8, 3 | "dim_mults": [ 4 | 1, 5 | 2 6 | ], 7 | "channels": 3, 8 | "channels_out": null, 9 | "cond_dim": null, 10 | "text_embed_dim": 512, 11 | "num_resnet_blocks": 1, 12 | "layer_attns": false, 13 | "layer_cross_attns": false, 14 | "attn_heads": 8, 15 | "lowres_cond": false, 16 | "memory_efficient": false, 17 | "attend_at_middle": false 18 | } -------------------------------------------------------------------------------- /parameters/unet_1_params_20220816_165729.json: -------------------------------------------------------------------------------- 1 | { 2 | "dim": 8, 3 | "dim_mults": [ 4 | 1, 5 | 2 6 | ], 7 | "channels": 3, 8 | "channels_out": null, 9 | "cond_dim": null, 10 | "text_embed_dim": 512, 11 | "num_resnet_blocks": [ 12 | 1, 13 | 2 14 | ], 15 | "layer_attns": false, 16 | "layer_cross_attns": false, 17 | "attn_heads": 8, 18 | "lowres_cond": false, 19 | "memory_efficient": true, 20 | "attend_at_middle": false 21 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI-Community/MinImagen/502c6962fa55285a871995de716cdb0ed3e3d81e/requirements.txt -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | VERSION = '0.0.9' 4 | DESCRIPTION = 'Minimal Imagen text-to-image model implementation.' 5 | with open("README.md", "r") as f: 6 | LONG_DESCRIPTION = f.read() 7 | 8 | # Replace README local image paths with GitHub paths so images render on PyPi 9 | LONG_DESCRIPTION = LONG_DESCRIPTION.replace("./images/", "https://github.com/AssemblyAI-Examples/MinImagen/raw/main/images/") 10 | 11 | #"Minimal Imagen text-to-image model implementation. 
See the [GitHub repo](https://github.com/AssemblyAI-Examples/MinImagen) or the [how-to build guide](www.assemblyai.com/blog/build-your-own-imagen-text-to-image-model/) for more details" 12 | 13 | with open('requirements.txt', "r", encoding="utf-16") as f: 14 | required = f.read().splitlines() 15 | 16 | # Setting up 17 | setup( 18 | name="minimagen", 19 | version=VERSION, 20 | author="AssemblyAI", 21 | author_email="", 22 | description=DESCRIPTION, 23 | long_description=LONG_DESCRIPTION, 24 | long_description_content_type='text/markdown', 25 | packages=find_packages(), 26 | install_requires=required, 27 | keywords=[ 'imagen', 28 | 'text-to-image', 29 | 'diffusion model', 30 | 'super resolution', 31 | 'image generation', 32 | 'machine learning', 33 | 'deep learning', 34 | 'pytorch', 35 | 'python'], 36 | classifiers=[ 37 | "Intended Audience :: Developers", 38 | "Programming Language :: Python :: 3", 39 | "Operating System :: Unix", 40 | "Operating System :: MacOS :: MacOS X", 41 | "Operating System :: Microsoft :: Windows", 42 | ] 43 | ) 44 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | import torch.utils.data 5 | from torch import optim 6 | 7 | 8 | from minimagen.Imagen import Imagen 9 | from minimagen.Unet import Unet, Base, Super, BaseTest, SuperTest 10 | from minimagen.generate import load_minimagen, load_params 11 | from minimagen.t5 import get_encoded_dim 12 | from minimagen.training import get_minimagen_parser, ConceptualCaptions, get_minimagen_dl_opts, \ 13 | create_directory, get_model_params, get_model_size, save_training_info, get_default_args, MinimagenTrain, \ 14 | load_restart_training_parameters, load_testing_parameters 15 | 16 | # Get device 17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | 19 | # Command line argument parser. See `training.get_minimagen_parser()`. 20 | parser = get_minimagen_parser() 21 | # Add argument for when using `main.py` 22 | parser.add_argument("-ts", "--TIMESTAMP", dest="timestamp", help="Timestamp for training directory", type=str, 23 | default=None) 24 | args = parser.parse_args() 25 | timestamp = args.timestamp 26 | 27 | # Get training timestamp for when running train.py as main rather than via main.py 28 | if timestamp is None: 29 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 30 | 31 | # Create training directory 32 | dir_path = f"./training_{timestamp}" 33 | training_dir = create_directory(dir_path) 34 | 35 | # If loading from a parameters/training directory 36 | if args.RESTART_DIRECTORY is not None: 37 | args = load_restart_training_parameters(args) 38 | elif args.PARAMETERS is not None: 39 | args = load_restart_training_parameters(args, justparams=True) 40 | 41 | # If testing, lower parameter values to lower computational load and also to lower amount of data being used. 
42 | if args.TESTING: 43 | args = load_testing_parameters(args) 44 | train_dataset, valid_dataset = ConceptualCaptions(args, smalldata=True) 45 | else: 46 | train_dataset, valid_dataset = ConceptualCaptions(args, smalldata=False) 47 | 48 | # Create dataloaders 49 | dl_opts = {**get_minimagen_dl_opts(device), 'batch_size': args.BATCH_SIZE, 'num_workers': args.NUM_WORKERS} 50 | train_dataloader = torch.utils.data.DataLoader(train_dataset, **dl_opts) 51 | valid_dataloader = torch.utils.data.DataLoader(valid_dataset, **dl_opts) 52 | 53 | # Create Unets 54 | if args.RESTART_DIRECTORY is None: 55 | imagen_params = dict( 56 | image_sizes=(int(args.IMG_SIDE_LEN / 2), args.IMG_SIDE_LEN), 57 | timesteps=args.TIMESTEPS, 58 | cond_drop_prob=0.15, 59 | text_encoder_name=args.T5_NAME 60 | ) 61 | 62 | # If not loading a training from a checkpoint 63 | if args.TESTING: 64 | # If testing, use tiny MinImagen for low computational load 65 | unets_params = [get_default_args(BaseTest), get_default_args(SuperTest)] 66 | 67 | # Else if not loading Unet/Imagen settings from a config (parameters) folder, use defaults 68 | elif not args.PARAMETERS: 69 | # If no parameters provided, use params from minimagen.Imagen.Base and minimagen.Imagen.Super built-in classes 70 | unets_params = [get_default_args(Base), get_default_args(Super)] 71 | 72 | # Else load unet/Imagen configs from config (parameters) folder (override imagen+params) 73 | else: 74 | # If parameters are provided, load them 75 | unets_params, imagen_params = get_model_params(args.PARAMETERS) 76 | 77 | # Create Unets accoridng to unets_params 78 | unets = [Unet(**unet_params).to(device) for unet_params in unets_params] 79 | 80 | # Create Imagen from UNets with specified imagen parameters 81 | imagen = Imagen(unets=unets, **imagen_params).to(device) 82 | else: 83 | # If training is being resumed from a previous one, load all relevant models/info (load config AND state dicts) 84 | orig_train_dir = os.path.join(os.getcwd(), args.RESTART_DIRECTORY) 85 | unets_params, imagen_params = load_params(orig_train_dir) 86 | imagen = load_minimagen(orig_train_dir).to(device) 87 | unets = imagen.unets 88 | 89 | # Fill in unspecified arguments with defaults for complete config (parameters) file 90 | unets_params = [{**get_default_args(Unet), **i} for i in unets_params] 91 | imagen_params = {**get_default_args(Imagen), **imagen_params} 92 | 93 | # Get the size of the Imagen model in megabytes 94 | model_size_MB = get_model_size(imagen) 95 | 96 | # Save all training info (config files, model size, etc.) 97 | save_training_info(args, timestamp, unets_params, imagen_params, model_size_MB, training_dir) 98 | 99 | # Create optimizer 100 | optimizer = optim.Adam(imagen.parameters(), lr=args.OPTIM_LR) 101 | 102 | # Train the MinImagen instance 103 | MinimagenTrain(timestamp, args, unets, imagen, train_dataloader, valid_dataloader, training_dir, optimizer, timeout=30) --------------------------------------------------------------------------------
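The scripts above (train.py and inference.py) drive everything through a training directory via minimagen.training and minimagen.generate. As a minimal illustrative sketch only, assuming the parameter values shown in the parameters/*.json files above and not code that ships with the repository, the following builds a tiny two-stage MinImagen directly and samples from it; with untrained weights the output will of course be noise.

import torch
from minimagen.Unet import Unet
from minimagen.Imagen import Imagen

# Two tiny U-Nets, using values from parameters/unet_0_params_20220816_165729.json above
unet_params = dict(dim=8, dim_mults=(1, 2), channels=3, text_embed_dim=512, num_resnet_blocks=1,
                   layer_attns=False, layer_cross_attns=False, attn_heads=8, lowres_cond=False,
                   memory_efficient=False, attend_at_middle=False)
unets = [Unet(**unet_params) for _ in range(2)]

# Imagen wrapper, mirroring the imagen_params built in train.py: a 64px base stage, a 128px
# super-resolution stage, 25 diffusion timesteps, and the T5-small text encoder
imagen = Imagen(unets=unets, image_sizes=(64, 128), timesteps=25,
                cond_drop_prob=0.15, text_encoder_name='t5_small')

# Sampling mirrors minimagen.generate.sample_and_save: captions in, PIL images out
images = imagen.sample(texts=['a happy dog'], cond_scale=3., return_pil_images=True)
images[0].save('image_0.png')

This is essentially what load_minimagen and sample_and_save reconstruct from a training directory, minus loading the trained Unet state dicts.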