├── .github ├── FUNDING.yml └── workflows │ ├── ci.yml │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── configs ├── README.md ├── train_decoder_config.example.json ├── train_decoder_config.test.json └── train_prior_config.example.json ├── dalle2.png ├── dalle2_pytorch ├── __init__.py ├── cli.py ├── dalle2_pytorch.py ├── data │ └── bpe_simple_vocab_16e6.txt ├── dataloaders │ ├── README.md │ ├── __init__.py │ ├── decoder_loader.py │ ├── prior_loader.py │ └── simple_image_only_dataloader.py ├── optimizer.py ├── tokenizer.py ├── trackers.py ├── train_configs.py ├── trainer.py ├── utils.py ├── version.py ├── vqgan_vae.py └── vqgan_vae_trainer.py ├── prior.md ├── samples └── oxford.png ├── setup.py ├── test_data ├── 0.tar ├── 1.tar ├── 2.tar ├── 3.tar ├── 4.tar ├── 5.tar ├── 6.tar ├── 7.tar ├── 8.tar └── 9.tar ├── train_decoder.py └── train_diffusion_prior.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [nousr, Veldrovive, lucidrains] 2 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Continuous integration 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | tests: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: [3.8] 17 | 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Install 25 | run: | 26 | python3 -m venv .env 27 | source .env/bin/activate 28 | make install 29 | - name: Tests 30 | run: | 31 | source .env/bin/activate 32 | make test 33 | 34 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 
10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # default experiment tracker data 2 | .tracker-data/ 3 | 4 | # Configuration Files 5 | configs/* 6 | !configs/*.example 7 | !configs/*_defaults.py 8 | !configs/README.md 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | .tracker_data 140 | *.pth 141 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include dalle2_pytorch *.txt 2 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | install: 2 | pip install -U pip 3 | pip install -e . 4 | 5 | test: 6 | CUDA_VISIBLE_DEVICES= python train_decoder.py --config_file configs/train_decoder_config.test.json 7 | -------------------------------------------------------------------------------- /configs/README.md: -------------------------------------------------------------------------------- 1 | ## DALLE2 Training Configurations 2 | 3 | For more complex configuration, we provide the option of using a configuration file instead of command line arguments. 4 | 5 | ### Decoder Trainer 6 | 7 | The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json). 8 | 9 | **Unet:** 10 | 11 | This is a single unet config, which belongs as an array nested under the decoder config as a list of `unets` 12 | 13 | | Option | Required | Default | Description | 14 | | ------ | -------- | ------- | ----------- | 15 | | `dim` | Yes | N/A | The starting channels of the unet. | 16 | | `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. 
| 17 | | `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. | 18 | 19 | Any parameter from the `Unet` constructor can also be given here. 20 | 21 | **Decoder:** 22 | 23 | Defines the configuration options for the decoder model. The unets defined above will automatically be inserted. 24 | | Option | Required | Default | Description | 25 | | ------ | -------- | ------- | ----------- | 26 | | `unets` | Yes | N/A | A list of unets, using the configuration above | 27 | | `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. | 28 | | `image_size` | Yes | N/A | Not used. Can be any number. | 29 | | `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. | 30 | | `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. | 31 | | `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. | 32 | | `learned_variance` | No | `True` | Whether to learn the variance. | 33 | | `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. | 34 | 35 | Any parameter from the `Decoder` constructor can also be given here. 36 | 37 | **Data:** 38 | 39 | Settings for creation of the dataloaders. 40 | | Option | Required | Default | Description | 41 | | ------ | -------- | ------- | ----------- | 42 | | `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. | 43 | | `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. | 44 | | `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. | 45 | | `num_workers` | No | `4` | The number of workers used in the dataloader. | 46 | | `batch_size` | No | `64` | The batch size. | 47 | | `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. | 48 | | `end_shard` | No | `9999999` | Defines the end of the shard range the dataset will recall. | 49 | | `shard_width` | No | `6` | Defines the width of one webdataset shard number[^2]. | 50 | | `index_width` | No | `4` | Defines the width of the index of a file inside a shard[^3]. | 51 | | `splits` | No | `{ "train": 0.75, "val": 0.15, "test": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. | 52 | | `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. | 53 | | `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the datasets making the epoch length infinite if a limit is not set. Cannot be enabled if `shuffle_train` is enabled. | 54 | | `preprocessing` | No | `{ "ToTensor": True }` | Defines preprocessing applied to images from the datasets. | 55 | 56 | [^1]: If your shard files have the paths `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example `pipe:s3cmd get s3://bucket/path/{}.tar -`. 57 | 58 | [^2]: This refers to the string length of the shard number for your webdataset shards. 
For instance, if your webdataset shard has the filename `00104.tar`, your shard width is 5. 59 | 60 | [^3]: Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard, and 4 refer to the index of the file in the webdataset (shard is `00104` and index is `5945`). The `index_width` in this case is 4. 61 | 62 | **Train:** 63 | 64 | Settings for controlling the training hyperparameters. 65 | | Option | Required | Default | Description | 66 | | ------ | -------- | ------- | ----------- | 67 | | `epochs` | No | `20` | The number of epochs in the training run. | 68 | | `lr` | No | `1e-4` | The learning rate. | 69 | | `wd` | No | `0.01` | The weight decay. | 70 | | `max_grad_norm` | No | `0.5` | The grad norm clipping. | 71 | | `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. | 72 | | `cond_scale` | No | `1.0` | Conditioning scale to use for sampling. Can also be an array of values, one for each unet. | 73 | | `device` | No | `cuda:0` | The device to train on. | 74 | | `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. | 75 | | `validation_samples` | No | `None` | The number of samples to use for validation. None means the entire validation set. | 76 | | `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. | 77 | | `ema_beta` | No | `0.99` | The ema coefficient. | 78 | | `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. | 79 | 80 | **Evaluate:** 81 | 82 | Defines which evaluation metrics will be used to test the model. 83 | Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors, which are linked below. 84 | | Option | Required | Default | Description | 85 | | ------ | -------- | ------- | ----------- | 86 | | `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. | 87 | | `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric. | 88 | | `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric. | 89 | | `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. | 90 | | `LPIPS` | No | `None` | Setting to an object enables the [Learned Perceptual Image Patch Similarity](https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html) metric. | 91 | 92 | **Tracker:** 93 | 94 | Selects how the experiment will be tracked. 95 | | Option | Required | Default | Description | 96 | | ------ | -------- | ------- | ----------- | 97 | | `data_path` | No | `./.tracker-data` | The path to the folder where temporary tracker data will be saved. | 98 | | `overwrite_data_path` | No | `False` | If true, the data path will be overwritten. Otherwise, you need to delete it yourself. | 99 | | `log` | Yes | N/A | Logging configuration. | 100 | | `load` | No | `None` | Checkpoint loading configuration.
| 101 | | `save` | Yes | N/A | Checkpoint/Model saving configuration. | 102 | Tracking is split up into three sections: 103 | * Log: Where to save run metadata and image output. Options are `console` or `wandb`. 104 | * Load: Where to load a checkpoint from. Options are `local`, `url`, or `wandb`. 105 | * Save: Where to save a checkpoint to. Options are `local`, `huggingface`, or `wandb`. 106 | 107 | **Logging:** 108 | 109 | All loggers have the following keys: 110 | | Option | Required | Default | Description | 111 | | ------ | -------- | ------- | ----------- | 112 | | `log_type` | Yes | N/A | The type of logger class to use. | 113 | | `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using maually input parameters. | 114 | | `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. | 115 | 116 | If using `console` there is no further configuration than setting `log_type` to `console`. 117 | | Option | Required | Default | Description | 118 | | ------ | -------- | ------- | ----------- | 119 | | `log_type` | Yes | N/A | Must be `console`. | 120 | 121 | If using `wandb` 122 | | Option | Required | Default | Description | 123 | | ------ | -------- | ------- | ----------- | 124 | | `log_type` | Yes | N/A | Must be `wandb`. | 125 | | `wandb_entity` | Yes | N/A | The wandb entity to log to. | 126 | | `wandb_project` | Yes | N/A | The wandb project save the run to. | 127 | | `wandb_run_name` | No | `None` | The wandb run name. | 128 | | `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. | 129 | 130 | **Loading:** 131 | 132 | All loaders have the following keys: 133 | | Option | Required | Default | Description | 134 | | ------ | -------- | ------- | ----------- | 135 | | `load_from` | Yes | N/A | The type of loader class to use. | 136 | | `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. | 137 | 138 | If using `local` 139 | | Option | Required | Default | Description | 140 | | ------ | -------- | ------- | ----------- | 141 | | `load_from` | Yes | N/A | Must be `local`. | 142 | | `file_path` | Yes | N/A | The path to the checkpoint file. | 143 | 144 | If using `url` 145 | | Option | Required | Default | Description | 146 | | ------ | -------- | ------- | ----------- | 147 | | `load_from` | Yes | N/A | Must be `url`. | 148 | | `url` | Yes | N/A | The url of the checkpoint file. | 149 | 150 | If using `wandb` 151 | | Option | Required | Default | Description | 152 | | ------ | -------- | ------- | ----------- | 153 | | `load_from` | Yes | N/A | Must be `wandb`. | 154 | | `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the run that is being resumed. | 155 | | `wandb_file_path` | Yes | N/A | The path to the checkpoint file in the W&B file system. | 156 | 157 | **Saving:** 158 | Unlike `log` and `load`, `save` may be an array of options so that you can save to different locations in a run. 159 | 160 | All save locations have these configuration options 161 | | Option | Required | Default | Description | 162 | | ------ | -------- | ------- | ----------- | 163 | | `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. | 164 | | `save_latest_to` | No | `None` | Sets the relative path to save the latest model to. 
| 165 | | `save_best_to` | No | `None` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. | 166 | | `save_meta_to` | No | `None` | The path to save metadata files in. This includes the config files used to start the training. | 167 | | `save_type` | No | `checkpoint` | The type of save. `checkpoint` saves a checkpoint, `model` saves a model without any fluff (Saves with ema if ema is enabled). | 168 | 169 | If using `local` 170 | | Option | Required | Default | Description | 171 | | ------ | -------- | ------- | ----------- | 172 | | `save_to` | Yes | N/A | Must be `local`. | 173 | 174 | If using `huggingface` 175 | | Option | Required | Default | Description | 176 | | ------ | -------- | ------- | ----------- | 177 | | `save_to` | Yes | N/A | Must be `huggingface`. | 178 | | `huggingface_repo` | Yes | N/A | The huggingface repository to save to. | 179 | | `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. | 180 | 181 | If using `wandb` 182 | | Option | Required | Default | Description | 183 | | ------ | -------- | ------- | ----------- | 184 | | `save_to` | Yes | N/A | Must be `wandb`. | 185 | | `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the current run. You will almost always want this to be `None`. | 186 | -------------------------------------------------------------------------------- /configs/train_decoder_config.example.json: -------------------------------------------------------------------------------- 1 | { 2 | "decoder": { 3 | "unets": [ 4 | { 5 | "dim": 128, 6 | "image_embed_dim": 768, 7 | "cond_dim": 64, 8 | "channels": 3, 9 | "dim_mults": [1, 2, 4, 8], 10 | "attn_dim_head": 32, 11 | "attn_heads": 16 12 | } 13 | ], 14 | "image_sizes": [64], 15 | "channels": 3, 16 | "timesteps": 1000, 17 | "loss_type": "l2", 18 | "beta_schedule": ["cosine"], 19 | "learned_variance": true 20 | }, 21 | "data": { 22 | "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -", 23 | "img_embeddings_url": "s3://bucket/img_embeddings/path/", 24 | "num_workers": 4, 25 | "batch_size": 64, 26 | "start_shard": 0, 27 | "end_shard": 9999999, 28 | "shard_width": 6, 29 | "index_width": 4, 30 | "splits": { 31 | "train": 0.75, 32 | "val": 0.15, 33 | "test": 0.1 34 | }, 35 | "shuffle_train": true, 36 | "resample_train": false, 37 | "preprocessing": { 38 | "RandomResizedCrop": { 39 | "size": [128, 128], 40 | "scale": [0.75, 1.0], 41 | "ratio": [1.0, 1.0] 42 | }, 43 | "ToTensor": true 44 | } 45 | }, 46 | "train": { 47 | "epochs": 20, 48 | "lr": 1e-4, 49 | "wd": 0.01, 50 | "max_grad_norm": 0.5, 51 | "save_every_n_samples": 100000, 52 | "n_sample_images": 6, 53 | "device": "cuda:0", 54 | "epoch_samples": null, 55 | "validation_samples": null, 56 | "use_ema": true, 57 | "ema_beta": 0.99, 58 | "amp": false, 59 | "unet_training_mask": [true] 60 | }, 61 | "evaluate": { 62 | "n_evaluation_samples": 1000, 63 | "FID": { 64 | "feature": 64 65 | }, 66 | "IS": { 67 | "feature": 64, 68 | "splits": 10 69 | }, 70 | "KID": { 71 | "feature": 64, 72 | "subset_size": 10 73 | }, 74 | "LPIPS": { 75 | "net_type": "vgg", 76 | "reduction": "mean" 77 | } 78 | }, 79 | "tracker": { 80 | "overwrite_data_path": true, 81 | 82 | "log": { 83 | "log_type": "wandb", 84 | 85 | "wandb_entity": "your_wandb", 86 | "wandb_project": "your_project", 87 | 88 | "verbose": true 89 | }, 90 | 91 | "load": { 92 | "load_from": null 93 | }, 94 | 95 | "save": [{ 96 | "save_to": "wandb", 97 | 
"save_latest_to": "latest.pth" 98 | }, { 99 | "save_to": "huggingface", 100 | "huggingface_repo": "Veldrovive/test_model", 101 | 102 | "save_latest_to": "path/to/model_dir/latest.pth", 103 | "save_best_to": "path/to/model_dir/best.pth", 104 | "save_meta_to": "path/to/directory/for/assorted/files", 105 | 106 | "save_type": "model" 107 | }] 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /configs/train_decoder_config.test.json: -------------------------------------------------------------------------------- 1 | { 2 | "decoder": { 3 | "unets": [ 4 | { 5 | "dim": 16, 6 | "image_embed_dim": 768, 7 | "cond_dim": 16, 8 | "channels": 3, 9 | "dim_mults": [1, 2, 4, 8], 10 | "attn_dim_head": 16, 11 | "attn_heads": 4, 12 | "self_attn": [false, true, true, true] 13 | } 14 | ], 15 | "clip": { 16 | "make": "openai", 17 | "model": "ViT-L/14" 18 | }, 19 | 20 | "timesteps": 10, 21 | "image_sizes": [64], 22 | "channels": 3, 23 | "loss_type": "l2", 24 | "beta_schedule": ["cosine"], 25 | "learned_variance": true 26 | }, 27 | "data": { 28 | "webdataset_base_url": "test_data/{}.tar", 29 | "num_workers": 4, 30 | "batch_size": 4, 31 | "start_shard": 0, 32 | "end_shard": 9, 33 | "shard_width": 1, 34 | "index_width": 1, 35 | "splits": { 36 | "train": 0.75, 37 | "val": 0.15, 38 | "test": 0.1 39 | }, 40 | "shuffle_train": false, 41 | "resample_train": true, 42 | "preprocessing": { 43 | "RandomResizedCrop": { 44 | "size": [224, 224], 45 | "scale": [0.75, 1.0], 46 | "ratio": [1.0, 1.0] 47 | }, 48 | "ToTensor": true 49 | } 50 | }, 51 | "train": { 52 | "epochs": 1, 53 | "lr": 1e-16, 54 | "wd": 0.01, 55 | "max_grad_norm": 0.5, 56 | "save_every_n_samples": 100, 57 | "n_sample_images": 1, 58 | "device": "cpu", 59 | "epoch_samples": 50, 60 | "validation_samples": 5, 61 | "use_ema": true, 62 | "ema_beta": 0.99, 63 | "amp": false, 64 | "unet_training_mask": [true] 65 | }, 66 | "evaluate": { 67 | "n_evaluation_samples": 2, 68 | "FID": { 69 | "feature": 64 70 | }, 71 | "IS": { 72 | "feature": 64, 73 | "splits": 10 74 | }, 75 | "KID": { 76 | "feature": 64, 77 | "subset_size": 2 78 | }, 79 | "LPIPS": { 80 | "net_type": "vgg", 81 | "reduction": "mean" 82 | } 83 | }, 84 | "tracker": { 85 | "overwrite_data_path": true, 86 | 87 | "log": { 88 | "log_type": "console" 89 | }, 90 | 91 | "load": { 92 | "load_from": null 93 | }, 94 | 95 | "save": [{ 96 | "save_to": "local", 97 | "save_latest_to": "latest.pth" 98 | }] 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /configs/train_prior_config.example.json: -------------------------------------------------------------------------------- 1 | { 2 | "prior": { 3 | "clip": { 4 | "make": "openai", 5 | "model": "ViT-L/14" 6 | }, 7 | "net": { 8 | "dim": 768, 9 | "depth": 12, 10 | "num_timesteps": 1000, 11 | "max_text_len": 77, 12 | "num_time_embeds": 1, 13 | "num_image_embeds": 1, 14 | "num_text_embeds": 1, 15 | "dim_head": 64, 16 | "heads": 12, 17 | "ff_mult": 4, 18 | "norm_out": true, 19 | "attn_dropout": 0.05, 20 | "ff_dropout": 0.05, 21 | "final_proj": true, 22 | "normformer": true, 23 | "rotary_emb": true 24 | }, 25 | "image_embed_dim": 768, 26 | "image_size": 224, 27 | "image_channels": 3, 28 | "timesteps": 1000, 29 | "sample_timesteps": 64, 30 | "cond_drop_prob": 0.1, 31 | "loss_type": "l2", 32 | "predict_x_start": true, 33 | "beta_schedule": "cosine", 34 | "condition_on_text_encodings": true 35 | }, 36 | "data": { 37 | "batch_size": 128, 38 | "num_data_points": 100000, 39 | 
"eval_every_seconds": 1600, 40 | "image_url": "", 41 | "meta_url": "", 42 | "splits": { 43 | "train": 0.8, 44 | "val": 0.1, 45 | "test": 0.1 46 | } 47 | }, 48 | "train": { 49 | "epochs": 5, 50 | "lr": 1.1e-4, 51 | "wd": 6.02e-2, 52 | "max_grad_norm": 0.5, 53 | "use_ema": true, 54 | "ema_beta": 0.9999, 55 | "ema_update_after_step": 50, 56 | "warmup_steps": 50, 57 | "amp": false, 58 | "save_every_seconds": 3600, 59 | "eval_timesteps": [64, 1000], 60 | "random_seed": 84513 61 | }, 62 | "tracker": { 63 | "data_path": ".prior", 64 | "overwrite_data_path": true, 65 | "log": { 66 | "log_type": "wandb", 67 | "wandb_entity": "", 68 | "wandb_project": "prior_debugging", 69 | "wandb_resume": false, 70 | "verbose": true 71 | }, 72 | "save": [ 73 | { 74 | "save_to": "local", 75 | "save_type": "checkpoint", 76 | "save_latest_to": ".prior/latest_checkpoint.pth", 77 | "save_best_to": ".prior/best_checkpoint.pth" 78 | } 79 | ] 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /dalle2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/dalle2.png -------------------------------------------------------------------------------- /dalle2_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from dalle2_pytorch.version import __version__ 2 | from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder 3 | from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter 4 | from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer 5 | 6 | from dalle2_pytorch.vqgan_vae import VQGanVAE 7 | from x_clip import CLIP 8 | -------------------------------------------------------------------------------- /dalle2_pytorch/cli.py: -------------------------------------------------------------------------------- 1 | import click 2 | import torch 3 | import torchvision.transforms as T 4 | from functools import reduce 5 | from pathlib import Path 6 | 7 | from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior 8 | 9 | def safeget(dictionary, keys, default = None): 10 | return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary) 11 | 12 | def simple_slugify(text, max_length = 255): 13 | return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length] 14 | 15 | def get_pkg_version(): 16 | from pkg_resources import get_distribution 17 | return get_distribution('dalle2_pytorch').version 18 | 19 | def main(): 20 | pass 21 | 22 | @click.command() 23 | @click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model') 24 | @click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder') 25 | @click.argument('text') 26 | def dream( 27 | model, 28 | cond_scale, 29 | text 30 | ): 31 | model_path = Path(model) 32 | full_model_path = str(model_path.resolve()) 33 | assert model_path.exists(), f'model not found at {full_model_path}' 34 | loaded = torch.load(str(model_path)) 35 | 36 | version = safeget(loaded, 'version') 37 | print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}') 38 | 39 | prior_init_params = safeget(loaded, 'init_params.prior') 40 | decoder_init_params = safeget(loaded, 
'init_params.decoder') 41 | model_params = safeget(loaded, 'model_params') 42 | 43 | prior = DiffusionPrior(**prior_init_params) 44 | decoder = Decoder(**decoder_init_params) 45 | 46 | dalle2 = DALLE2(prior, decoder) 47 | dalle2.load_state_dict(model_params) 48 | 49 | image = dalle2(text, cond_scale = cond_scale) 50 | 51 | pil_image = T.ToPILImage()(image) 52 | return pil_image.save(f'./{simple_slugify(text)}.png') 53 | -------------------------------------------------------------------------------- /dalle2_pytorch/dataloaders/README.md: -------------------------------------------------------------------------------- 1 | ## Dataloaders 2 | In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network. 3 | 4 | ### Decoder: Image Embedding Dataset 5 | When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509. 6 | 7 | Generating a dataset of this type: 8 | 1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset. 9 | 2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings. 10 | 3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format. 11 | 12 | Usage: 13 | ```python 14 | from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader 15 | 16 | # Create a dataloader directly. 17 | dataloader = create_image_embedding_dataloader( 18 | tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar 19 | embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise 20 | num_workers=4, 21 | batch_size=32, 22 | shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index 23 | shuffle_num=200, # Does a shuffle of the data with a buffer size of 200 24 | shuffle_shards=True, # Shuffle the order the shards are read in 25 | resample_shards=False, # Sample shards with replacement. 
If true, an epoch will be infinite unless stopped manually 26 | ) 27 | for img, emb in dataloader: 28 | print(img.shape) # torch.Size([32, 3, 256, 256]) 29 | print(emb.shape) # torch.Size([32, 512]) 30 | # Train decoder only as shown above 31 | 32 | # Or create a dataset without a loader so you can configure it manually 33 | dataset = ImageEmbeddingDataset( 34 | urls="/path/or/url/to/webdataset/{0000..9999}.tar", 35 | embedding_folder_url="path/or/url/to/embeddings/folder", 36 | shard_width=4, 37 | shuffle_shards=True, 38 | resample=False 39 | ) 40 | ``` 41 | 42 | ### Diffusion Prior: Prior Embedding Dataset 43 | When training the prior it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code. 44 | 45 | To utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()` which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects from for your training run. 46 | 47 | If you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters. 48 | 49 | Usage: 50 | ```python 51 | from dalle2_pytorch.dataloaders import get_reader, make_splits 52 | 53 | # grab embeddings from some specified location 54 | IMG_URL = "data/img_emb/" 55 | META_URL = "data/meta/" 56 | 57 | reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL) 58 | 59 | # some config for training 60 | TRAIN_ARGS = { 61 | "world_size": 3, 62 | "text_conditioned": True, 63 | "start": 0, 64 | "num_data_points": 10000, 65 | "batch_size": 2, 66 | "train_split": 0.5, 67 | "eval_split": 0.25, 68 | "image_reader": reader, 69 | } 70 | 71 | # specifying a rank will handle allocation internally 72 | rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS) 73 | rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS) 74 | rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS) 75 | ``` 76 | -------------------------------------------------------------------------------- /dalle2_pytorch/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader 2 | from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset 3 | -------------------------------------------------------------------------------- /dalle2_pytorch/dataloaders/decoder_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import webdataset as wds 3 | import torch 4 | from torch.utils.data import DataLoader 5 | import numpy as np 6 | import fsspec 7 | import shutil 8 | 9 | def get_shard(filename): 10 | """ 11 | Filenames with shards in them have a consistent structure that we can take advantage of 12 | Standard structure: path/to/file/prefix_string_00001.ext 13 | """ 14 | try: 15 | return filename.split("_")[-1].split(".")[0] 16 | except ValueError: 17 | raise RuntimeError(f"Could not find shard for filename {filename}") 18 | 19 | def 
get_example_file(fs, path, file_format): 20 | """ 21 | Given a file system, a path and a file extension, return an example file 22 | """ 23 | return fs.glob(os.path.join(path, f"*.{file_format}"))[0] 24 | 25 | def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception): 26 | """Given a datum of {"__key__": str, "__url__": str, ...} adds the corresponding embedding and yields""" 27 | previous_tar_url = None 28 | current_embeddings = None 29 | # Get a reference to an abstract file system where the embeddings are stored 30 | embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url) 31 | example_embedding_file = get_example_file(embeddings_fs, embeddings_path, "npy") 32 | example_embedding_shard = get_shard(example_embedding_file) 33 | emb_shard_width = len(example_embedding_shard) 34 | # Easier to get the basename without the shard once than search through for the correct file every time 35 | embedding_file_basename = '_'.join(example_embedding_file.split("_")[:-1]) + "_" 36 | 37 | def load_corresponding_embeds(tar_url): 38 | """Finds and reads the npy file that contains embeddings for the given webdataset tar""" 39 | shard = int(tar_url.split("/")[-1].split(".")[0]) 40 | embedding_url = embedding_file_basename + str(shard).zfill(emb_shard_width) + '.npy' 41 | with embeddings_fs.open(embedding_url) as f: 42 | data = np.load(f) 43 | return torch.from_numpy(data) 44 | 45 | for sample in samples: 46 | try: 47 | tar_url = sample["__url__"] 48 | key = sample["__key__"] 49 | if tar_url != previous_tar_url: 50 | # If the tar changed, we need to download new embeddings 51 | # This means if we shuffle before inserting it will load many more files than we expect and be very inefficient. 52 | previous_tar_url = tar_url 53 | current_embeddings = load_corresponding_embeds(tar_url) 54 | 55 | embedding_index = int(key[-index_width:]) 56 | embedding = current_embeddings[embedding_index] 57 | # An all-zero embedding means no embedding was stored for this sample, so it is not valid and we raise for the handler to deal with 58 | if torch.count_nonzero(embedding) == 0: 59 | raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}") 60 | sample[sample_key] = embedding 61 | yield sample 62 | except Exception as exn: # From wds implementation 63 | if handler(exn): 64 | continue 65 | else: 66 | break 67 | insert_embedding = wds.filters.pipelinefilter(embedding_inserter) 68 | 69 | def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception): 70 | """Finds whether there is a corresponding embedding for the tarfile at { url: [URL] }""" 71 | embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url) 72 | embedding_files = embeddings_fs.ls(embeddings_path) 73 | get_embedding_shard = lambda embedding_file: int(embedding_file.split("_")[-1].split(".")[0]) 74 | embedding_shards = set([get_embedding_shard(filename) for filename in embedding_files]) # Sets have O(1) check for member 75 | 76 | get_tar_shard = lambda tar_file: int(tar_file.split("/")[-1].split(".")[0]) 77 | for tarfile in tarfiles: 78 | try: 79 | webdataset_shard = get_tar_shard(tarfile["url"]) 80 | # If this shard has an associated embeddings file, we pass it through.
Otherwise we iterate until we do have one 81 | if webdataset_shard in embedding_shards: 82 | yield tarfile 83 | except Exception as exn: # From wds implementation 84 | if handler(exn): 85 | continue 86 | else: 87 | break 88 | skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper) 89 | 90 | def join_embeddings(samples, handler=wds.handlers.reraise_exception): 91 | """ 92 | Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb } 93 | either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist 94 | """ 95 | for sample in samples: 96 | try: 97 | sample['emb'] = {} 98 | if 'text_emb' in sample: 99 | sample['emb']['text'] = sample['text_emb'] 100 | if 'img_emb' in sample: 101 | sample['emb']['img'] = sample['img_emb'] 102 | yield sample 103 | except Exception as exn: # From wds implementation 104 | if handler(exn): 105 | continue 106 | else: 107 | break 108 | 109 | def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception): 110 | """ 111 | Requires that both the image and embedding are present in the sample 112 | This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter. 113 | """ 114 | for sample in samples: 115 | try: 116 | for key in required_keys: 117 | assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}" 118 | yield sample 119 | except Exception as exn: # From wds implementation 120 | if handler(exn): 121 | continue 122 | else: 123 | break 124 | key_verifier = wds.filters.pipelinefilter(verify_keys) 125 | 126 | class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface): 127 | """ 128 | A fluid interface wrapper for DataPipline that returns image embedding pairs 129 | Reads embeddings as npy files from the webdataset if they exist. If embedding_folder_url is set, they will be inserted in from the alternate source. 130 | """ 131 | 132 | def __init__( 133 | self, 134 | urls, 135 | img_embedding_folder_url=None, 136 | text_embedding_folder_url=None, 137 | index_width=None, 138 | img_preproc=None, 139 | extra_keys=[], 140 | handler=wds.handlers.reraise_exception, 141 | resample=False, 142 | shuffle_shards=True 143 | ): 144 | """ 145 | Modeled directly off of the WebDataset constructor 146 | 147 | :param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar 148 | :param embedding_folder_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset. 149 | Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros. 150 | :param index_width: The number of digits in the index. This is used to align the embedding index with the image index. 151 | For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width. 152 | :param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to torch tensor. 153 | :param handler: A webdataset handler. 154 | :param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely. 
155 | :param shuffle_shards: If true, shuffle the shards before resampling. This cannot be true if resample is true. 156 | 157 | 158 | """ 159 | super().__init__() 160 | keys = ["jpg", "emb"] + extra_keys 161 | # if img_embedding_folder_url is not None: 162 | # keys.append("img_emb") 163 | # if text_embedding_folder_url is not None: 164 | # keys.append("text_emb") 165 | # keys.extend(extra_keys) 166 | self.key_map = {key: i for i, key in enumerate(keys)} 167 | self.resampling = resample 168 | self.img_preproc = img_preproc 169 | # If s3, check if s3fs is installed and s3cmd is installed and check if the data is piped instead of straight up 170 | if (isinstance(urls, str) and "s3:" in urls) or (isinstance(urls, list) and any(["s3:" in url for url in urls])): 171 | # Then this has an s3 link for the webdataset and we need extra packages 172 | if shutil.which("s3cmd") is None: 173 | raise RuntimeError("s3cmd is required for s3 webdataset") 174 | if (img_embedding_folder_url is not None and "s3:" in img_embedding_folder_url) or (text_embedding_folder_url is not None and "s3:" in text_embedding_folder_url): 175 | # Then the embeddings are being loaded from s3 and fsspec requires s3fs 176 | try: 177 | import s3fs 178 | except ImportError: 179 | raise RuntimeError("s3fs is required to load embeddings from s3") 180 | # Add the shardList and randomize or resample if requested 181 | if resample: 182 | assert not shuffle_shards, "Cannot both resample and shuffle" 183 | self.append(wds.ResampledShards(urls)) 184 | else: 185 | self.append(wds.SimpleShardList(urls)) 186 | if shuffle_shards: 187 | self.append(wds.filters.shuffle(1000)) 188 | 189 | if img_embedding_folder_url is not None: 190 | # There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues. 
191 | self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler)) 192 | if text_embedding_folder_url is not None: 193 | self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler)) 194 | 195 | self.append(wds.tarfile_to_samples(handler=handler)) 196 | self.append(wds.decode("pilrgb", handler=handler)) 197 | if img_embedding_folder_url is not None: 198 | # Then we are loading image embeddings from a remote source 199 | assert index_width is not None, "Reading embeddings separately requires index width length to be given" 200 | self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler)) 201 | if text_embedding_folder_url is not None: 202 | # Then we are loading text embeddings from a remote source 203 | assert index_width is not None, "Reading embeddings separately requires index width length to be given" 204 | self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler)) 205 | self.append(join_embeddings) 206 | self.append(key_verifier(required_keys=keys, handler=handler)) 207 | # Apply preprocessing 208 | self.append(wds.map(self.preproc)) 209 | self.append(wds.to_tuple(*keys)) 210 | 211 | def preproc(self, sample): 212 | """Applies the preprocessing for images""" 213 | if self.img_preproc is not None: 214 | sample["jpg"] = self.img_preproc(sample["jpg"]) 215 | return sample 216 | 217 | def create_image_embedding_dataloader( 218 | tar_url, 219 | num_workers, 220 | batch_size, 221 | img_embeddings_url=None, 222 | text_embeddings_url=None, 223 | index_width=None, 224 | shuffle_num = None, 225 | shuffle_shards = True, 226 | resample_shards = False, 227 | img_preproc=None, 228 | extra_keys=[], 229 | handler=wds.handlers.reraise_exception#warn_and_continue 230 | ): 231 | """ 232 | Convenience function to create an image embedding dataset and dataloader in one line 233 | 234 | :param tar_url: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar 235 | :param num_workers: The number of workers to use for the dataloader 236 | :param batch_size: The batch size to use for the dataloader 237 | :param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset. 238 | Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros. 239 | :param index_width: The number of digits in the index. This is used to align the embedding index with the image index. 240 | For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width. 241 | :param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling. 242 | :param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true. 243 | :param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely. 244 | :param handler: A webdataset handler.
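    :param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to torch tensor.
    :param extra_keys: Additional webdataset keys to decode and return with each sample, after "jpg" and "emb".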
245 | """ 246 | ds = ImageEmbeddingDataset( 247 | tar_url, 248 | img_embedding_folder_url=img_embeddings_url, 249 | text_embedding_folder_url=text_embeddings_url, 250 | index_width=index_width, 251 | shuffle_shards=shuffle_shards, 252 | resample=resample_shards, 253 | extra_keys=extra_keys, 254 | img_preproc=img_preproc, 255 | handler=handler 256 | ) 257 | if shuffle_num is not None and shuffle_num > 0: 258 | ds.shuffle(1000) 259 | return DataLoader( 260 | ds, 261 | num_workers=num_workers, 262 | batch_size=batch_size, 263 | prefetch_factor=2, # This might be good to have high so the next npy file is prefetched 264 | pin_memory=True, 265 | shuffle=False 266 | ) 267 | -------------------------------------------------------------------------------- /dalle2_pytorch/dataloaders/prior_loader.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from clip import tokenize 3 | from embedding_reader import EmbeddingReader 4 | from torch import from_numpy 5 | from torch.utils.data import IterableDataset, DataLoader 6 | 7 | 8 | class PriorEmbeddingDataset(IterableDataset): 9 | """ 10 | PriorEmbeddingDataset is a wrapper of EmbeddingReader. 11 | 12 | It enables one to simplify the logic necessary to yield samples from 13 | the different EmbeddingReader configurations available. 14 | """ 15 | 16 | def __init__( 17 | self, 18 | text_conditioned: bool, 19 | batch_size: int, 20 | start: int, 21 | stop: int, 22 | image_reader, 23 | text_reader: EmbeddingReader = None, 24 | ) -> None: 25 | super(PriorEmbeddingDataset).__init__() 26 | 27 | self.text_conditioned = text_conditioned 28 | 29 | if not self.text_conditioned: 30 | self.text_reader = text_reader 31 | 32 | self.image_reader = image_reader 33 | self.start = start 34 | self.stop = stop 35 | self.batch_size = batch_size 36 | 37 | def __len__(self): 38 | return self.stop - self.start 39 | 40 | def __iter__(self): 41 | # D.R.Y loader args 42 | loader_args = dict( 43 | batch_size=self.batch_size, 44 | start=self.start, 45 | end=self.stop, 46 | show_progress=False, 47 | ) 48 | 49 | # if the data requested is text conditioned, only load images 50 | if self.text_conditioned: 51 | self.loader = self.image_reader(**loader_args) 52 | # otherwise, include text embeddings and bypass metadata 53 | else: 54 | self.loader = zip( 55 | self.image_reader(**loader_args), self.text_reader(**loader_args) 56 | ) 57 | 58 | # return the data loader in its formatted state 59 | return self 60 | 61 | def __next__(self): 62 | try: 63 | return self.get_sample() 64 | except StopIteration: 65 | raise StopIteration 66 | 67 | def __str__(self): 68 | return f"" 69 | 70 | def set_start(self, start): 71 | """ 72 | Adjust the starting point within the reader, useful for resuming an epoch 73 | """ 74 | self.start = start 75 | 76 | def get_start(self): 77 | return self.start 78 | 79 | def get_sample(self): 80 | """ 81 | pre-proocess data from either reader into a common format 82 | """ 83 | if self.text_conditioned: 84 | image_embedding, caption = next(self.loader) 85 | 86 | image_embedding = from_numpy(image_embedding) 87 | tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True) 88 | 89 | return image_embedding, tokenized_caption 90 | 91 | else: 92 | (image_embedding, _), (text_embedding, _) = next(self.loader) 93 | 94 | image_embedding = from_numpy(image_embedding) 95 | text_embedding = from_numpy(text_embedding) 96 | 97 | return image_embedding, text_embedding 98 | 99 | 100 | # helper functions 101 | 102 | 103 
| def distribute_to_rank(start, stop, rank, world_size): 104 | """ 105 | Distribute data to each rank given the world size. 106 | 107 | Return: 108 | - New start and stop points for this rank. 109 | """ 110 | num_samples = int(stop - start) 111 | 112 | per_rank = int(ceil((num_samples) / float(world_size))) 113 | 114 | assert ( 115 | per_rank > 0 116 | ), f"Number of samples per rank must be larger than 0, (found: {per_rank})" 117 | 118 | rank_start = start + rank * per_rank 119 | 120 | rank_stop = min(rank_start + per_rank, stop) 121 | 122 | new_length = rank_stop - rank_start 123 | 124 | assert ( 125 | new_length > 0 126 | ), "Calculated start and stop points result in a length of zero for this rank." 127 | 128 | return rank_start, rank_stop 129 | 130 | 131 | def get_reader( 132 | text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None 133 | ): 134 | """ 135 | Create an EmbeddingReader object from the specified URLs 136 | 137 | get_reader() will always expect a url to image embeddings. 138 | 139 | If text-conditioned, it will also expect a meta_url for the captions. 140 | Otherwise, it will need txt_url for the matching text embeddings. 141 | 142 | Returns an image_reader object if text-conditioned. 143 | Otherwise it returns both an image_reader and a text_reader 144 | """ 145 | 146 | assert img_url is not None, "Must supply a image url" 147 | 148 | if text_conditioned: 149 | assert meta_url is not None, "Must supply meta url if text-conditioned" 150 | 151 | image_reader = EmbeddingReader( 152 | embeddings_folder=img_url, 153 | file_format="parquet_npy", 154 | # will assume the caption column exists and is the only one requested 155 | meta_columns=["caption"], 156 | metadata_folder=meta_url, 157 | ) 158 | 159 | return image_reader 160 | 161 | # otherwise we will require text embeddings as well and return two readers 162 | assert ( 163 | txt_url is not None 164 | ), "Must supply text embedding url if not text-conditioning" 165 | 166 | image_reader = EmbeddingReader(img_url, file_format="npy") 167 | text_reader = EmbeddingReader(txt_url, file_format="npy") 168 | 169 | return image_reader, text_reader 170 | 171 | 172 | def make_splits( 173 | text_conditioned: bool, 174 | batch_size: int, 175 | num_data_points: int, 176 | train_split: float, 177 | eval_split: float, 178 | image_reader: EmbeddingReader, 179 | text_reader: EmbeddingReader = None, 180 | start=0, 181 | rank=0, 182 | world_size=1, 183 | ): 184 | """ 185 | Split an embedding reader object as needed. 186 | 187 | NOTE: make_splits() will infer the test set size from your train and eval. 188 | 189 | Input: 190 | - text_conditioned: whether to prepare text-conditioned training data 191 | - batch_size: the batch size for a single gpu 192 | - num_data_points: the total number of data points you wish to train on 193 | - train_split: the percentage of data you wish to train on 194 | - eval_split: the percentage of data you wish to validate on 195 | - image_reader: the image_reader you wish to split 196 | - text_reader: the text_reader you want to split (if !text_conditioned) 197 | - start: the starting point within your dataset 198 | - rank: the rank of your worker 199 | - world_size: the total world size of your distributed training run 200 | 201 | Returns: 202 | - PyTorch Dataloaders that yield tuples of (img, txt) data. 203 | """ 204 | 205 | assert start < image_reader.count, "start position cannot exceed reader count." 
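    # Illustrative walk-through (not from the original source): with start=0, num_data_points=100_000,
    # train_split=0.8 and eval_split=0.1, the split points computed below are train=[0, 80_000),
    # eval=[80_000, 90_000) and test=[90_000, 100_000); distribute_to_rank() then carves each range
    # into world_size contiguous chunks, e.g. with world_size=2 rank 0 trains on [0, 40_000) and
    # rank 1 on [40_000, 80_000).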
206 | 207 | # verify that the num_data_points does not exceed the max points 208 | if num_data_points > (image_reader.count - start): 209 | print( 210 | "Specified count is larger than what's available...defaulting to reader's count." 211 | ) 212 | num_data_points = image_reader.count 213 | 214 | # compute split points 215 | train_set_size = int(train_split * num_data_points) 216 | eval_set_size = int(eval_split * num_data_points) 217 | eval_start = train_set_size 218 | eval_stop = int(eval_start + eval_set_size) 219 | 220 | assert ( 221 | train_split + eval_split 222 | ) < 1.0, "Specified train and eval split is too large to infer a test split." 223 | 224 | # distribute to rank 225 | rank_train_start, rank_train_stop = distribute_to_rank( 226 | start, train_set_size, rank, world_size 227 | ) 228 | rank_eval_start, rank_eval_stop = distribute_to_rank( 229 | train_set_size, eval_stop, rank, world_size 230 | ) 231 | rank_test_start, rank_test_stop = distribute_to_rank( 232 | eval_stop, num_data_points, rank, world_size 233 | ) 234 | 235 | # wrap up splits into a dict 236 | train_split_args = dict( 237 | start=rank_train_start, stop=rank_train_stop, batch_size=batch_size 238 | ) 239 | eval_split_args = dict( 240 | start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size 241 | ) 242 | test_split_args = dict( 243 | start=rank_test_start, stop=rank_test_stop, batch_size=batch_size 244 | ) 245 | 246 | if text_conditioned: 247 | # add the text-conditioned args to a unified dict 248 | reader_args = dict( 249 | text_conditioned=text_conditioned, 250 | image_reader=image_reader, 251 | ) 252 | 253 | train_split_args = dict(**reader_args, **train_split_args) 254 | eval_split_args = dict(**reader_args, **eval_split_args) 255 | test_split_args = dict(**reader_args, **test_split_args) 256 | 257 | train = PriorEmbeddingDataset(**train_split_args) 258 | val = PriorEmbeddingDataset(**eval_split_args) 259 | test = PriorEmbeddingDataset(**test_split_args) 260 | 261 | else: 262 | # add the non-conditioned args to a unified dict 263 | reader_args = dict( 264 | text_conditioned=text_conditioned, 265 | image_reader=image_reader, 266 | text_reader=text_reader, 267 | ) 268 | 269 | train_split_args = dict(**reader_args, **train_split_args) 270 | eval_split_args = dict(**reader_args, **eval_split_args) 271 | test_split_args = dict(**reader_args, **test_split_args) 272 | 273 | train = PriorEmbeddingDataset(**train_split_args) 274 | val = PriorEmbeddingDataset(**eval_split_args) 275 | test = PriorEmbeddingDataset(**test_split_args) 276 | 277 | # true batch size is specifed in the PriorEmbeddingDataset 278 | train_loader = DataLoader(train, batch_size=None) 279 | eval_loader = DataLoader(val, batch_size=None) 280 | test_loader = DataLoader(test, batch_size=None) 281 | 282 | return train_loader, eval_loader, test_loader 283 | -------------------------------------------------------------------------------- /dalle2_pytorch/dataloaders/simple_image_only_dataloader.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from torch.utils import data 5 | from torchvision import transforms, utils 6 | 7 | from PIL import Image 8 | 9 | # helpers functions 10 | 11 | def cycle(dl): 12 | while True: 13 | for data in dl: 14 | yield data 15 | 16 | # dataset and dataloader 17 | 18 | class Dataset(data.Dataset): 19 | def __init__( 20 | self, 21 | folder, 22 | image_size, 23 | exts = ['jpg', 'jpeg', 'png'] 24 | ): 25 | super().__init__() 26 | 
self.folder = folder 27 | self.image_size = image_size 28 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] 29 | 30 | self.transform = transforms.Compose([ 31 | transforms.Resize(image_size), 32 | transforms.RandomHorizontalFlip(), 33 | transforms.CenterCrop(image_size), 34 | transforms.ToTensor() 35 | ]) 36 | 37 | def __len__(self): 38 | return len(self.paths) 39 | 40 | def __getitem__(self, index): 41 | path = self.paths[index] 42 | img = Image.open(path) 43 | return self.transform(img) 44 | 45 | def get_images_dataloader( 46 | folder, 47 | *, 48 | batch_size, 49 | image_size, 50 | shuffle = True, 51 | cycle_dl = True, 52 | pin_memory = True 53 | ): 54 | ds = Dataset(folder, image_size) 55 | dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory) 56 | 57 | if cycle_dl: 58 | dl = cycle(dl) 59 | return dl 60 | -------------------------------------------------------------------------------- /dalle2_pytorch/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import AdamW, Adam 2 | 3 | def separate_weight_decayable_params(params): 4 | wd_params, no_wd_params = [], [] 5 | for param in params: 6 | param_list = no_wd_params if param.ndim < 2 else wd_params 7 | param_list.append(param) 8 | return wd_params, no_wd_params 9 | 10 | def get_optimizer( 11 | params, 12 | lr = 1e-4, 13 | wd = 1e-2, 14 | betas = (0.9, 0.99), 15 | eps = 1e-8, 16 | filter_by_requires_grad = False, 17 | group_wd_params = True, 18 | **kwargs 19 | ): 20 | if filter_by_requires_grad: 21 | params = list(filter(lambda t: t.requires_grad, params)) 22 | 23 | if wd == 0: 24 | return Adam(params, lr = lr, betas = betas, eps = eps) 25 | 26 | if group_wd_params: 27 | wd_params, no_wd_params = separate_weight_decayable_params(params) 28 | 29 | params = [ 30 | {'params': wd_params}, 31 | {'params': no_wd_params, 'weight_decay': 0}, 32 | ] 33 | 34 | return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps) 35 | -------------------------------------------------------------------------------- /dalle2_pytorch/tokenizer.py: -------------------------------------------------------------------------------- 1 | # take from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py 2 | # to give users a quick easy start to training DALL-E without doing BPE 3 | 4 | import torch 5 | 6 | import html 7 | import os 8 | import ftfy 9 | import regex as re 10 | from functools import lru_cache 11 | from pathlib import Path 12 | 13 | from dalle2_pytorch.utils import import_or_print_error 14 | 15 | # OpenAI simple tokenizer 16 | 17 | @lru_cache() 18 | def default_bpe(): 19 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt") 20 | 21 | @lru_cache() 22 | def bytes_to_unicode(): 23 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 24 | cs = bs[:] 25 | n = 0 26 | for b in range(2 ** 8): 27 | if b not in bs: 28 | bs.append(b) 29 | cs.append(2 ** 8 + n) 30 | n += 1 31 | cs = [chr(n) for n in cs] 32 | return dict(zip(bs, cs)) 33 | 34 | def get_pairs(word): 35 | pairs = set() 36 | prev_char = word[0] 37 | for char in word[1:]: 38 | pairs.add((prev_char, char)) 39 | prev_char = char 40 | return pairs 41 | 42 | def basic_clean(text): 43 | text = ftfy.fix_text(text) 44 | text = html.unescape(html.unescape(text)) 45 | return text.strip() 46 | 47 | def whitespace_clean(text): 48 | text = 
re.sub(r'\s+', ' ', text) 49 | text = text.strip() 50 | return text 51 | 52 | class SimpleTokenizer(object): 53 | def __init__(self, bpe_path = default_bpe()): 54 | self.byte_encoder = bytes_to_unicode() 55 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 56 | merges = Path(bpe_path).read_text(encoding='utf8').split('\n') 57 | merges = merges[1:49152 - 256 - 2 + 1] 58 | merges = [tuple(merge.split()) for merge in merges] 59 | vocab = list(bytes_to_unicode().values()) 60 | vocab = vocab + [v + '' for v in vocab] 61 | for merge in merges: 62 | vocab.append(''.join(merge)) 63 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 64 | 65 | self.vocab_size = 49408 66 | 67 | self.encoder = dict(zip(vocab, range(len(vocab)))) 68 | self.decoder = {v: k for k, v in self.encoder.items()} 69 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 70 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 71 | self.pat = re.compile( 72 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 73 | re.IGNORECASE) 74 | 75 | def bpe(self, token): 76 | if token in self.cache: 77 | return self.cache[token] 78 | word = tuple(token[:-1]) + (token[-1] + '',) 79 | pairs = get_pairs(word) 80 | 81 | if not pairs: 82 | return token + '' 83 | 84 | while True: 85 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 86 | if bigram not in self.bpe_ranks: 87 | break 88 | first, second = bigram 89 | new_word = [] 90 | i = 0 91 | while i < len(word): 92 | try: 93 | j = word.index(first, i) 94 | new_word.extend(word[i:j]) 95 | i = j 96 | except: 97 | new_word.extend(word[i:]) 98 | break 99 | 100 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 101 | new_word.append(first + second) 102 | i += 2 103 | else: 104 | new_word.append(word[i]) 105 | i += 1 106 | new_word = tuple(new_word) 107 | word = new_word 108 | if len(word) == 1: 109 | break 110 | else: 111 | pairs = get_pairs(word) 112 | word = ' '.join(word) 113 | self.cache[token] = word 114 | return word 115 | 116 | def encode(self, text): 117 | bpe_tokens = [] 118 | text = whitespace_clean(basic_clean(text)).lower() 119 | for token in re.findall(self.pat, text): 120 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 121 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 122 | return bpe_tokens 123 | 124 | def decode(self, tokens, remove_start_end = True, pad_tokens = set()): 125 | if torch.is_tensor(tokens): 126 | tokens = tokens.tolist() 127 | 128 | if remove_start_end: 129 | tokens = [token for token in tokens if token not in (49406, 40407, 0)] 130 | text = ''.join([self.decoder[token] for token in tokens if token not in pad_tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | 134 | def tokenize(self, texts, context_length = 256, truncate_text = False): 135 | if isinstance(texts, str): 136 | texts = [texts] 137 | 138 | all_tokens = [self.encode(text) for text in texts] 139 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 140 | 141 | for i, tokens in enumerate(all_tokens): 142 | if len(tokens) > context_length: 143 | if truncate_text: 144 | tokens = tokens[:context_length] 145 | else: 146 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 147 | result[i, :len(tokens)] = torch.tensor(tokens) 148 | 149 | return result 
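# Example usage of the tokenizer instantiated below (illustrative; exact token ids depend on the bundled BPE vocab):
#   tokens = tokenizer.tokenize(['a corgi wearing sunglasses'], context_length = 256)
#   tokens.shape                 -> torch.Size([1, 256]), ids right-padded with zeros
#   tokenizer.decode(tokens[0])  -> maps the ids back to the (lowercased, cleaned) text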
150 | 151 | tokenizer = SimpleTokenizer() 152 | 153 | # YTTM tokenizer 154 | 155 | class YttmTokenizer: 156 | def __init__(self, bpe_path = None): 157 | bpe_path = Path(bpe_path) 158 | assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist' 159 | 160 | self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`') 161 | 162 | tokenizer = self.yttm.BPE(model = str(bpe_path)) 163 | self.tokenizer = tokenizer 164 | self.vocab_size = tokenizer.vocab_size() 165 | 166 | def decode(self, tokens, pad_tokens = set()): 167 | if torch.is_tensor(tokens): 168 | tokens = tokens.tolist() 169 | 170 | return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0})) 171 | 172 | def encode(self, texts): 173 | encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID) 174 | return list(map(torch.tensor, encoded)) 175 | 176 | def tokenize(self, texts, context_length = 256, truncate_text = False): 177 | if isinstance(texts, str): 178 | texts = [texts] 179 | 180 | all_tokens = self.encode(texts) 181 | 182 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 183 | for i, tokens in enumerate(all_tokens): 184 | if len(tokens) > context_length: 185 | if truncate_text: 186 | tokens = tokens[:context_length] 187 | else: 188 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 189 | result[i, :len(tokens)] = torch.tensor(tokens) 190 | 191 | return result 192 | -------------------------------------------------------------------------------- /dalle2_pytorch/trackers.py: -------------------------------------------------------------------------------- 1 | import urllib.request 2 | import os 3 | import json 4 | from pathlib import Path 5 | import shutil 6 | from itertools import zip_longest 7 | from typing import Any, Optional, List, Union 8 | from pydantic import BaseModel 9 | 10 | import torch 11 | from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior 12 | from dalle2_pytorch.utils import import_or_print_error 13 | from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer 14 | from dalle2_pytorch.version import __version__ 15 | from packaging import version 16 | 17 | # constants 18 | 19 | DEFAULT_DATA_PATH = './.tracker-data' 20 | 21 | # helper functions 22 | 23 | def exists(val): 24 | return val is not None 25 | 26 | class BaseLogger: 27 | """ 28 | An abstract class representing an object that can log data. 29 | Parameters: 30 | data_path (str): A file path for storing temporary data. 31 | verbose (bool): Whether of not to always print logs to the console. 32 | """ 33 | def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs): 34 | self.data_path = Path(data_path) 35 | self.resume = resume 36 | self.auto_resume = auto_resume 37 | self.verbose = verbose 38 | 39 | def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None: 40 | """ 41 | Initializes the logger. 42 | Errors if the logger is invalid. 43 | full_config is the config file dict while extra_config is anything else from the script that is not defined the config file. 
44 | """ 45 | raise NotImplementedError 46 | 47 | def log(self, log, **kwargs) -> None: 48 | raise NotImplementedError 49 | 50 | def log_images(self, images, captions=[], image_section="images", **kwargs) -> None: 51 | raise NotImplementedError 52 | 53 | def log_file(self, file_path, **kwargs) -> None: 54 | raise NotImplementedError 55 | 56 | def log_error(self, error_string, **kwargs) -> None: 57 | raise NotImplementedError 58 | 59 | def get_resume_data(self, **kwargs) -> dict: 60 | """ 61 | Sets tracker attributes that along with { "resume": True } will be used to resume training. 62 | It is assumed that after init is called this data will be complete. 63 | If the logger does not have any resume functionality, it should return an empty dict. 64 | """ 65 | raise NotImplementedError 66 | 67 | class ConsoleLogger(BaseLogger): 68 | def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None: 69 | print("Logging to console") 70 | 71 | def log(self, log, **kwargs) -> None: 72 | print(log) 73 | 74 | def log_images(self, images, captions=[], image_section="images", **kwargs) -> None: 75 | pass 76 | 77 | def log_file(self, file_path, **kwargs) -> None: 78 | pass 79 | 80 | def log_error(self, error_string, **kwargs) -> None: 81 | print(error_string) 82 | 83 | def get_resume_data(self, **kwargs) -> dict: 84 | return {} 85 | 86 | class WandbLogger(BaseLogger): 87 | """ 88 | Logs to a wandb run. 89 | Parameters: 90 | data_path (str): A file path for storing temporary data. 91 | wandb_entity (str): The wandb entity to log to. 92 | wandb_project (str): The wandb project to log to. 93 | wandb_run_id (str): The wandb run id to resume. 94 | wandb_run_name (str): The wandb run name to use. 95 | """ 96 | def __init__(self, 97 | data_path: str, 98 | wandb_entity: str, 99 | wandb_project: str, 100 | wandb_run_id: Optional[str] = None, 101 | wandb_run_name: Optional[str] = None, 102 | **kwargs 103 | ): 104 | super().__init__(data_path, **kwargs) 105 | self.entity = wandb_entity 106 | self.project = wandb_project 107 | self.run_id = wandb_run_id 108 | self.run_name = wandb_run_name 109 | 110 | def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None: 111 | assert self.entity is not None, "wandb_entity must be specified for wandb logger" 112 | assert self.project is not None, "wandb_project must be specified for wandb logger" 113 | self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger') 114 | os.environ["WANDB_SILENT"] = "true" 115 | # Initializes the wandb run 116 | init_object = { 117 | "entity": self.entity, 118 | "project": self.project, 119 | "config": {**full_config.dict(), **extra_config} 120 | } 121 | if self.run_name is not None: 122 | init_object['name'] = self.run_name 123 | if self.resume: 124 | assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True' 125 | if self.run_name is not None: 126 | print("You are renaming a run. I hope that is what you intended.") 127 | init_object['resume'] = 'must' 128 | init_object['id'] = self.run_id 129 | 130 | self.wandb.init(**init_object) 131 | print(f"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}") 132 | 133 | def log(self, log, **kwargs) -> None: 134 | if self.verbose: 135 | print(log) 136 | self.wandb.log(log, **kwargs) 137 | 138 | def log_images(self, images, captions=[], image_section="images", **kwargs) -> None: 139 | """ 140 | Takes a tensor of images and a list of captions and logs them to wandb. 
141 | """ 142 | wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)] 143 | self.wandb.log({ image_section: wandb_images }, **kwargs) 144 | 145 | def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None: 146 | if base_path is None: 147 | # Then we take the basepath as the parent of the file_path 148 | base_path = Path(file_path).parent 149 | self.wandb.save(str(file_path), base_path = str(base_path)) 150 | 151 | def log_error(self, error_string, step=None, **kwargs) -> None: 152 | if self.verbose: 153 | print(error_string) 154 | self.wandb.log({"error": error_string, **kwargs}, step=step) 155 | 156 | def get_resume_data(self, **kwargs) -> dict: 157 | # In order to resume, we need wandb_entity, wandb_project, and wandb_run_id 158 | return { 159 | "entity": self.entity, 160 | "project": self.project, 161 | "run_id": self.wandb.run.id 162 | } 163 | 164 | logger_type_map = { 165 | 'console': ConsoleLogger, 166 | 'wandb': WandbLogger, 167 | } 168 | def create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger: 169 | if logger_type == 'custom': 170 | raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.') 171 | try: 172 | logger_class = logger_type_map[logger_type] 173 | except KeyError: 174 | raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}') 175 | return logger_class(data_path, **kwargs) 176 | 177 | class BaseLoader: 178 | """ 179 | An abstract class representing an object that can load a model checkpoint. 180 | Parameters: 181 | data_path (str): A file path for storing temporary data. 182 | """ 183 | def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs): 184 | self.data_path = Path(data_path) 185 | self.only_auto_resume = only_auto_resume 186 | 187 | def init(self, logger: BaseLogger, **kwargs) -> None: 188 | raise NotImplementedError 189 | 190 | def recall() -> dict: 191 | raise NotImplementedError 192 | 193 | class UrlLoader(BaseLoader): 194 | """ 195 | A loader that downloads the file from a url and loads it 196 | Parameters: 197 | data_path (str): A file path for storing temporary data. 198 | url (str): The url to download the file from. 199 | """ 200 | def __init__(self, data_path: str, url: str, **kwargs): 201 | super().__init__(data_path, **kwargs) 202 | self.url = url 203 | 204 | def init(self, logger: BaseLogger, **kwargs) -> None: 205 | # Makes sure the file exists to be downloaded 206 | pass # TODO: Actually implement that 207 | 208 | def recall(self) -> dict: 209 | # Download the file 210 | save_path = self.data_path / 'loaded_checkpoint.pth' 211 | urllib.request.urlretrieve(self.url, str(save_path)) 212 | # Load the file 213 | return torch.load(str(save_path), map_location='cpu') 214 | 215 | 216 | class LocalLoader(BaseLoader): 217 | """ 218 | A loader that loads a file from a local path 219 | Parameters: 220 | data_path (str): A file path for storing temporary data. 221 | file_path (str): The path to the file to load. 
222 | """ 223 | def __init__(self, data_path: str, file_path: str, **kwargs): 224 | super().__init__(data_path, **kwargs) 225 | self.file_path = Path(file_path) 226 | 227 | def init(self, logger: BaseLogger, **kwargs) -> None: 228 | # Makes sure the file exists to be loaded 229 | if not self.file_path.exists() and not self.only_auto_resume: 230 | raise FileNotFoundError(f'Model not found at {self.file_path}') 231 | 232 | def recall(self) -> dict: 233 | # Load the file 234 | return torch.load(str(self.file_path), map_location='cpu') 235 | 236 | class WandbLoader(BaseLoader): 237 | """ 238 | A loader that loads a model from an existing wandb run 239 | """ 240 | def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs): 241 | super().__init__(data_path, **kwargs) 242 | self.run_path = wandb_run_path 243 | self.file_path = wandb_file_path 244 | 245 | def init(self, logger: BaseLogger, **kwargs) -> None: 246 | self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function') 247 | # Make sure the file can be downloaded 248 | if self.wandb.run is not None and self.run_path is None: 249 | self.run_path = self.wandb.run.path 250 | assert self.run_path is not None, 'wandb run was not found to load from. If not using the wandb logger must specify the `wandb_run_path`.' 251 | assert self.run_path is not None, '`wandb_run_path` must be provided for the wandb loader' 252 | assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader' 253 | 254 | os.environ["WANDB_SILENT"] = "true" 255 | pass # TODO: Actually implement that 256 | 257 | def recall(self) -> dict: 258 | file_reference = self.wandb.restore(self.file_path, run_path=self.run_path) 259 | return torch.load(file_reference.name, map_location='cpu') 260 | 261 | loader_type_map = { 262 | 'url': UrlLoader, 263 | 'local': LocalLoader, 264 | 'wandb': WandbLoader, 265 | } 266 | def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader: 267 | if loader_type == 'custom': 268 | raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.') 269 | try: 270 | loader_class = loader_type_map[loader_type] 271 | except KeyError: 272 | raise ValueError(f'Unknown loader type: {loader_type}. 
Must be one of {list(loader_type_map.keys())}') 273 | return loader_class(data_path, **kwargs) 274 | 275 | class BaseSaver: 276 | def __init__(self, 277 | data_path: str, 278 | save_latest_to: Optional[Union[str, bool]] = None, 279 | save_best_to: Optional[Union[str, bool]] = None, 280 | save_meta_to: Optional[str] = None, 281 | save_type: str = 'checkpoint', 282 | **kwargs 283 | ): 284 | self.data_path = Path(data_path) 285 | self.save_latest_to = save_latest_to 286 | self.saving_latest = save_latest_to is not None and save_latest_to is not False 287 | self.save_best_to = save_best_to 288 | self.saving_best = save_best_to is not None and save_best_to is not False 289 | self.save_meta_to = save_meta_to 290 | self.saving_meta = save_meta_to is not None 291 | self.save_type = save_type 292 | assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`' 293 | assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified' 294 | 295 | def init(self, logger: BaseLogger, **kwargs) -> None: 296 | raise NotImplementedError 297 | 298 | def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None: 299 | """ 300 | Save a general file under save_meta_to 301 | """ 302 | raise NotImplementedError 303 | 304 | class LocalSaver(BaseSaver): 305 | def __init__(self, 306 | data_path: str, 307 | **kwargs 308 | ): 309 | super().__init__(data_path, **kwargs) 310 | 311 | def init(self, logger: BaseLogger, **kwargs) -> None: 312 | # Makes sure the directory exists to be saved to 313 | print(f"Saving {self.save_type} locally") 314 | if not self.data_path.exists(): 315 | self.data_path.mkdir(parents=True) 316 | 317 | def save_file(self, local_path: str, save_path: str, **kwargs) -> None: 318 | # Copy the file to save_path 319 | save_path_file_name = Path(save_path).name 320 | # Make sure parent directory exists 321 | save_path_parent = Path(save_path).parent 322 | if not save_path_parent.exists(): 323 | save_path_parent.mkdir(parents=True) 324 | print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}") 325 | shutil.copy(local_path, save_path) 326 | 327 | class WandbSaver(BaseSaver): 328 | def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs): 329 | super().__init__(data_path, **kwargs) 330 | self.run_path = wandb_run_path 331 | 332 | def init(self, logger: BaseLogger, **kwargs) -> None: 333 | self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger') 334 | os.environ["WANDB_SILENT"] = "true" 335 | # Makes sure that the user can upload tot his run 336 | if self.run_path is not None: 337 | entity, project, run_id = self.run_path.split("/") 338 | self.run = self.wandb.init(entity=entity, project=project, id=run_id) 339 | else: 340 | assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`' 341 | self.run = self.wandb.run 342 | # TODO: Now actually check if upload is possible 343 | print(f"Saving to wandb run {self.run.path}-{self.run.name}") 344 | 345 | def save_file(self, local_path: Path, save_path: str, **kwargs) -> None: 346 | # In order to log something in the correct place in wandb, we need to have the same file structure here 347 | save_path_file_name = Path(save_path).name 348 | print(f"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}") 349 | save_path = Path(self.data_path) / save_path 
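        # e.g. (hypothetical paths) save_path='checkpoints/latest.pth' is first copied to
        # '<data_path>/checkpoints/latest.pth', then uploaded with base_path=<data_path>,
        # so it appears in the wandb run under 'checkpoints/latest.pth'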
350 | save_path.parent.mkdir(parents=True, exist_ok=True) 351 | shutil.copy(local_path, save_path) 352 | self.run.save(str(save_path), base_path = str(self.data_path), policy='now') 353 | 354 | class HuggingfaceSaver(BaseSaver): 355 | def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs): 356 | super().__init__(data_path, **kwargs) 357 | self.huggingface_repo = huggingface_repo 358 | self.token_path = token_path 359 | 360 | def init(self, logger: BaseLogger, **kwargs): 361 | # Makes sure this user can upload to the repo 362 | self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver') 363 | try: 364 | identity = self.hub.whoami() # Errors if not logged in 365 | # Then we are logged in 366 | except: 367 | # We are not logged in. Use the token_path to set the token. 368 | if not os.path.exists(self.token_path): 369 | raise Exception("Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.") 370 | with open(self.token_path, "r") as f: 371 | token = f.read().strip() 372 | self.hub.HfApi.set_access_token(token) 373 | identity = self.hub.whoami() 374 | print(f"Saving to huggingface repo {self.huggingface_repo}") 375 | 376 | def save_file(self, local_path: Path, save_path: str, **kwargs) -> None: 377 | # Saving to huggingface is easy, we just need to upload the file with the correct name 378 | save_path_file_name = Path(save_path).name 379 | print(f"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}") 380 | self.hub.upload_file( 381 | path_or_fileobj=str(local_path), 382 | path_in_repo=str(save_path), 383 | repo_id=self.huggingface_repo 384 | ) 385 | 386 | saver_type_map = { 387 | 'local': LocalSaver, 388 | 'wandb': WandbSaver, 389 | 'huggingface': HuggingfaceSaver 390 | } 391 | def create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver: 392 | if saver_type == 'custom': 393 | raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.') 394 | try: 395 | saver_class = saver_type_map[saver_type] 396 | except KeyError: 397 | raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}') 398 | return saver_class(data_path, **kwargs) 399 | 400 | 401 | class Tracker: 402 | def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False): 403 | self.data_path = Path(data_path) 404 | if not dummy_mode: 405 | if not overwrite_data_path: 406 | assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.' 407 | if not self.data_path.exists(): 408 | self.data_path.mkdir(parents=True) 409 | self.logger: BaseLogger = None 410 | self.loader: Optional[BaseLoader] = None 411 | self.savers: List[BaseSaver]= [] 412 | self.dummy_mode = dummy_mode 413 | 414 | def _load_auto_resume(self) -> bool: 415 | # If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run. 416 | if not self.auto_resume_path.exists(): 417 | if self.logger.auto_resume: 418 | print("Auto_resume is enabled but no auto_resume.json file exists. 
Assuming this is the first run.") 419 | return False 420 | 421 | # Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time 422 | if not self.logger.auto_resume: 423 | print(f'Removing auto_resume.json because auto_resume is not enabled in the config') 424 | self.auto_resume_path.unlink() 425 | return False 426 | 427 | # Otherwise we read the json into a dictionary will will override parts of logger.__dict__ 428 | with open(self.auto_resume_path, 'r') as f: 429 | auto_resume_dict = json.load(f) 430 | # Check if the logger is of the same type as the autoresume save 431 | if auto_resume_dict["logger_type"] != self.logger.__class__.__name__: 432 | raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.') 433 | # Then we are ready to override the logger with the autoresume save 434 | self.logger.__dict__["resume"] = True 435 | print(f"Updating {self.logger.__dict__} with {auto_resume_dict}") 436 | self.logger.__dict__.update(auto_resume_dict) 437 | return True 438 | 439 | def _save_auto_resume(self): 440 | # Gets the autoresume dict from the logger and adds "logger_type" to it then saves it to the auto_resume file 441 | auto_resume_dict = self.logger.get_resume_data() 442 | auto_resume_dict['logger_type'] = self.logger.__class__.__name__ 443 | with open(self.auto_resume_path, 'w') as f: 444 | json.dump(auto_resume_dict, f) 445 | 446 | def init(self, full_config: BaseModel, extra_config: dict): 447 | self.auto_resume_path = self.data_path / 'auto_resume.json' 448 | # Check for resuming the run 449 | self.did_auto_resume = self._load_auto_resume() 450 | if self.did_auto_resume: 451 | print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n') 452 | print(f"New logger config: {self.logger.__dict__}") 453 | 454 | self.save_metadata = dict( 455 | version = version.parse(__version__) 456 | ) # Data that will be saved alongside the checkpoint or model 457 | self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps'] # These keys would cause us to error if we try to save them as metadata 458 | 459 | assert self.logger is not None, '`logger` must be set before `init` is called' 460 | if self.dummy_mode: 461 | # The only thing we need is a loader 462 | if self.loader is not None: 463 | self.loader.init(self.logger) 464 | return 465 | assert len(self.savers) > 0, '`savers` must be set before `init` is called' 466 | 467 | self.logger.init(full_config, extra_config) 468 | if self.loader is not None: 469 | self.loader.init(self.logger) 470 | for saver in self.savers: 471 | saver.init(self.logger) 472 | 473 | if self.logger.auto_resume: 474 | # Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved. 
475 | self._save_auto_resume() 476 | 477 | def add_logger(self, logger: BaseLogger): 478 | self.logger = logger 479 | 480 | def add_loader(self, loader: BaseLoader): 481 | self.loader = loader 482 | 483 | def add_saver(self, saver: BaseSaver): 484 | self.savers.append(saver) 485 | 486 | def log(self, *args, **kwargs): 487 | if self.dummy_mode: 488 | return 489 | self.logger.log(*args, **kwargs) 490 | 491 | def log_images(self, *args, **kwargs): 492 | if self.dummy_mode: 493 | return 494 | self.logger.log_images(*args, **kwargs) 495 | 496 | def log_file(self, *args, **kwargs): 497 | if self.dummy_mode: 498 | return 499 | self.logger.log_file(*args, **kwargs) 500 | 501 | def save_config(self, current_config_path: str, config_name = 'config.json'): 502 | if self.dummy_mode: 503 | return 504 | # Save the config under config_name in the root folder of data_path 505 | shutil.copy(current_config_path, self.data_path / config_name) 506 | for saver in self.savers: 507 | if saver.saving_meta: 508 | remote_path = Path(saver.save_meta_to) / config_name 509 | saver.save_file(current_config_path, str(remote_path)) 510 | 511 | def add_save_metadata(self, state_dict_key: str, metadata: Any): 512 | """ 513 | Adds a new piece of metadata that will be saved along with the model or decoder. 514 | """ 515 | self.save_metadata[state_dict_key] = metadata 516 | 517 | def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path: 518 | """ 519 | Gets the state dict to be saved and writes it to file_path. 520 | If save_type is 'checkpoint', we save the entire trainer state dict. 521 | If save_type is 'model', we save only the model state dict. 522 | """ 523 | assert save_type in ['checkpoint', 'model'] 524 | if save_type == 'checkpoint': 525 | # Create a metadata dict without the blacklisted keys so we do not error when we create the state dict 526 | metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys} 527 | trainer.save(file_path, overwrite=True, **kwargs, **metadata) 528 | elif save_type == 'model': 529 | if isinstance(trainer, DiffusionPriorTrainer): 530 | prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior 531 | prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior) 532 | # Remove CLIP if it is part of the model 533 | original_clip = prior.clip 534 | prior.clip = None 535 | model_state_dict = prior.state_dict() 536 | prior.clip = original_clip 537 | elif isinstance(trainer, DecoderTrainer): 538 | decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder) 539 | # Remove CLIP if it is part of the model 540 | original_clip = decoder.clip 541 | decoder.clip = None 542 | if trainer.use_ema: 543 | trainable_unets = decoder.unets 544 | decoder.unets = trainer.unets # Swap EMA unets in 545 | model_state_dict = decoder.state_dict() 546 | decoder.unets = trainable_unets # Swap back 547 | else: 548 | model_state_dict = decoder.state_dict() 549 | decoder.clip = original_clip 550 | else: 551 | raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. 
Actually, how did you get here?') 552 | state_dict = { 553 | **self.save_metadata, 554 | 'model': model_state_dict 555 | } 556 | torch.save(state_dict, file_path) 557 | return Path(file_path) 558 | 559 | def save(self, trainer, is_best: bool, is_latest: bool, **kwargs): 560 | if self.dummy_mode: 561 | return 562 | if not is_best and not is_latest: 563 | # Nothing to do 564 | return 565 | # Save the checkpoint and model to data_path 566 | checkpoint_path = self.data_path / 'checkpoint.pth' 567 | self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs) 568 | model_path = self.data_path / 'model.pth' 569 | self._save_state_dict(trainer, 'model', model_path, **kwargs) 570 | print("Saved cached models") 571 | # Call the save methods on the savers 572 | for saver in self.savers: 573 | local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path 574 | if saver.saving_latest and is_latest: 575 | latest_checkpoint_path = saver.save_latest_to.format(**kwargs) 576 | try: 577 | saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs) 578 | except Exception as e: 579 | self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs) 580 | print(f'Error saving checkpoint: {e}') 581 | if saver.saving_best and is_best: 582 | best_checkpoint_path = saver.save_best_to.format(**kwargs) 583 | try: 584 | saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs) 585 | except Exception as e: 586 | self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs) 587 | print(f'Error saving checkpoint: {e}') 588 | 589 | @property 590 | def can_recall(self): 591 | # Defines whether a recall can be performed. 592 | return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume) 593 | 594 | def recall(self): 595 | if self.can_recall: 596 | return self.loader.recall() 597 | else: 598 | raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.') 599 | 600 | 601 | -------------------------------------------------------------------------------- /dalle2_pytorch/train_configs.py: -------------------------------------------------------------------------------- 1 | import json 2 | from torchvision import transforms as T 3 | from pydantic import BaseModel, validator, model_validator 4 | from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar 5 | 6 | from x_clip import CLIP as XCLIP 7 | from open_clip import list_pretrained 8 | from coca_pytorch import CoCa 9 | 10 | from dalle2_pytorch.dalle2_pytorch import ( 11 | CoCaAdapter, 12 | OpenAIClipAdapter, 13 | OpenClipAdapter, 14 | Unet, 15 | Decoder, 16 | DiffusionPrior, 17 | DiffusionPriorNetwork, 18 | XClipAdapter 19 | ) 20 | from dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver 21 | 22 | # helper functions 23 | 24 | def exists(val): 25 | return val is not None 26 | 27 | def default(val, d): 28 | return val if exists(val) else d 29 | 30 | InnerType = TypeVar('InnerType') 31 | ListOrTuple = Union[List[InnerType], Tuple[InnerType]] 32 | SingularOrIterable = Union[InnerType, ListOrTuple[InnerType]] 33 | 34 | # general pydantic classes 35 | 36 | class TrainSplitConfig(BaseModel): 37 | train: float = 0.75 38 | val: float = 0.15 39 | test: float = 0.1 40 | 41 | @model_validator(mode = 'after') 42 | def validate_all(self, m): 43 | actual_sum = sum([*dict(self).values()]) 44 | if actual_sum != 1.: 45 | raise ValueError(f'{dict(self).keys()} must sum to 1.0. 
Found: {actual_sum}') 46 | return self 47 | 48 | class TrackerLogConfig(BaseModel): 49 | log_type: str = 'console' 50 | resume: bool = False # For logs that are saved to unique locations, resume a previous run 51 | auto_resume: bool = False # If the process crashes and restarts, resume from the run that crashed 52 | verbose: bool = False 53 | 54 | class Config: 55 | # Each individual log type has it's own arguments that will be passed through the config 56 | extra = "allow" 57 | 58 | def create(self, data_path: str): 59 | kwargs = self.dict() 60 | return create_logger(self.log_type, data_path, **kwargs) 61 | 62 | 63 | class TrackerLoadConfig(BaseModel): 64 | load_from: Optional[str] = None 65 | only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming 66 | 67 | class Config: 68 | extra = "allow" 69 | 70 | def create(self, data_path: str): 71 | kwargs = self.dict() 72 | if self.load_from is None: 73 | return None 74 | return create_loader(self.load_from, data_path, **kwargs) 75 | 76 | class TrackerSaveConfig(BaseModel): 77 | save_to: str = 'local' 78 | save_all: bool = False 79 | save_latest: bool = True 80 | save_best: bool = True 81 | 82 | class Config: 83 | extra = "allow" 84 | 85 | def create(self, data_path: str): 86 | kwargs = self.dict() 87 | return create_saver(self.save_to, data_path, **kwargs) 88 | 89 | class TrackerConfig(BaseModel): 90 | data_path: str = '.tracker_data' 91 | overwrite_data_path: bool = False 92 | log: TrackerLogConfig 93 | load: Optional[TrackerLoadConfig] = None 94 | save: Union[List[TrackerSaveConfig], TrackerSaveConfig] 95 | 96 | def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker: 97 | tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path) 98 | # Add the logger 99 | tracker.add_logger(self.log.create(self.data_path)) 100 | # Add the loader 101 | if self.load is not None: 102 | tracker.add_loader(self.load.create(self.data_path)) 103 | # Add the saver or savers 104 | if isinstance(self.save, list): 105 | for save_config in self.save: 106 | tracker.add_saver(save_config.create(self.data_path)) 107 | else: 108 | tracker.add_saver(self.save.create(self.data_path)) 109 | # Initialize all the components and verify that all data is valid 110 | tracker.init(full_config, extra_config) 111 | return tracker 112 | 113 | # diffusion prior pydantic classes 114 | 115 | class AdapterConfig(BaseModel): 116 | make: str = "openai" 117 | model: str = "ViT-L/14" 118 | base_model_kwargs: Optional[Dict[str, Any]] = None 119 | 120 | def create(self): 121 | if self.make == "openai": 122 | return OpenAIClipAdapter(self.model) 123 | elif self.make == "open_clip": 124 | pretrained = dict(list_pretrained()) 125 | checkpoint = pretrained[self.model] 126 | return OpenClipAdapter(name=self.model, pretrained=checkpoint) 127 | elif self.make == "x-clip": 128 | return XClipAdapter(XCLIP(**self.base_model_kwargs)) 129 | elif self.make == "coca": 130 | return CoCaAdapter(CoCa(**self.base_model_kwargs)) 131 | else: 132 | raise AttributeError("No adapter with that name is available.") 133 | 134 | class DiffusionPriorNetworkConfig(BaseModel): 135 | dim: int 136 | depth: int 137 | max_text_len: Optional[int] = None 138 | num_timesteps: Optional[int] = None 139 | num_time_embeds: int = 1 140 | num_image_embeds: int = 1 141 | num_text_embeds: int = 1 142 | dim_head: int = 64 143 | heads: int = 8 144 | ff_mult: int = 4 145 | norm_in: bool = False 146 | norm_out: bool = True 147 | 
attn_dropout: float = 0. 148 | ff_dropout: float = 0. 149 | final_proj: bool = True 150 | normformer: bool = False 151 | rotary_emb: bool = True 152 | 153 | class Config: 154 | extra = "allow" 155 | 156 | def create(self): 157 | kwargs = self.dict() 158 | return DiffusionPriorNetwork(**kwargs) 159 | 160 | class DiffusionPriorConfig(BaseModel): 161 | clip: Optional[AdapterConfig] = None 162 | net: DiffusionPriorNetworkConfig 163 | image_embed_dim: int 164 | image_size: int 165 | image_channels: int = 3 166 | timesteps: int = 1000 167 | sample_timesteps: Optional[int] = None 168 | cond_drop_prob: float = 0. 169 | loss_type: str = 'l2' 170 | predict_x_start: bool = True 171 | beta_schedule: str = 'cosine' 172 | condition_on_text_encodings: bool = True 173 | 174 | class Config: 175 | extra = "allow" 176 | 177 | def create(self): 178 | kwargs = self.dict() 179 | 180 | has_clip = exists(kwargs.pop('clip')) 181 | kwargs.pop('net') 182 | 183 | clip = None 184 | if has_clip: 185 | clip = self.clip.create() 186 | 187 | diffusion_prior_network = self.net.create() 188 | return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs) 189 | 190 | class DiffusionPriorTrainConfig(BaseModel): 191 | epochs: int = 1 192 | lr: float = 1.1e-4 193 | wd: float = 6.02e-2 194 | max_grad_norm: float = 0.5 195 | use_ema: bool = True 196 | ema_beta: float = 0.99 197 | amp: bool = False 198 | warmup_steps: Optional[int] = None # number of warmup steps 199 | save_every_seconds: int = 3600 # how often to save 200 | eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with 201 | best_validation_loss: float = 1e9 # the current best valudation loss observed 202 | current_epoch: int = 0 # the current epoch 203 | num_samples_seen: int = 0 # the current number of samples seen 204 | random_seed: int = 0 # manual seed for torch 205 | 206 | class DiffusionPriorDataConfig(BaseModel): 207 | image_url: str # path to embeddings folder 208 | meta_url: str # path to metadata (captions) for images 209 | splits: TrainSplitConfig # define train, validation, test splits for your dataset 210 | batch_size: int # per-gpu batch size used to train the model 211 | num_data_points: int = 25e7 # total number of datapoints to train on 212 | eval_every_seconds: int = 3600 # validation statistics will be performed this often 213 | 214 | class TrainDiffusionPriorConfig(BaseModel): 215 | prior: DiffusionPriorConfig 216 | data: DiffusionPriorDataConfig 217 | train: DiffusionPriorTrainConfig 218 | tracker: TrackerConfig 219 | 220 | @classmethod 221 | def from_json_path(cls, json_path): 222 | with open(json_path) as f: 223 | config = json.load(f) 224 | return cls(**config) 225 | 226 | # decoder pydantic classes 227 | 228 | class UnetConfig(BaseModel): 229 | dim: int 230 | dim_mults: ListOrTuple[int] 231 | image_embed_dim: Optional[int] = None 232 | text_embed_dim: Optional[int] = None 233 | cond_on_text_encodings: Optional[bool] = None 234 | cond_dim: Optional[int] = None 235 | channels: int = 3 236 | self_attn: SingularOrIterable[bool] = False 237 | attn_dim_head: int = 32 238 | attn_heads: int = 16 239 | init_cross_embed: bool = True 240 | 241 | class Config: 242 | extra = "allow" 243 | 244 | class DecoderConfig(BaseModel): 245 | unets: ListOrTuple[UnetConfig] 246 | image_size: Optional[int] = None 247 | image_sizes: ListOrTuple[int] = None 248 | clip: Optional[AdapterConfig] = None # The clip model to use if embeddings are not provided 249 | channels: int = 3 250 | timesteps: int = 1000 251 | sample_timesteps: 
Optional[SingularOrIterable[Optional[int]]] = None 252 | loss_type: str = 'l2' 253 | beta_schedule: Optional[ListOrTuple[str]] = None # None means all cosine 254 | learned_variance: SingularOrIterable[bool] = True 255 | image_cond_drop_prob: float = 0.1 256 | text_cond_drop_prob: float = 0.5 257 | 258 | def create(self): 259 | decoder_kwargs = self.dict() 260 | 261 | unet_configs = decoder_kwargs.pop('unets') 262 | unets = [Unet(**config) for config in unet_configs] 263 | 264 | has_clip = exists(decoder_kwargs.pop('clip')) 265 | clip = None 266 | if has_clip: 267 | clip = self.clip.create() 268 | 269 | return Decoder(unets, clip=clip, **decoder_kwargs) 270 | 271 | @validator('image_sizes') 272 | def check_image_sizes(cls, image_sizes, values): 273 | if exists(values.get('image_size')) ^ exists(image_sizes): 274 | return image_sizes 275 | raise ValueError('either image_size or image_sizes is required, but not both') 276 | 277 | class Config: 278 | extra = "allow" 279 | 280 | class DecoderDataConfig(BaseModel): 281 | webdataset_base_url: str # path to a webdataset with jpg images 282 | img_embeddings_url: Optional[str] = None # path to .npy files with embeddings 283 | text_embeddings_url: Optional[str] = None # path to .npy files with embeddings 284 | num_workers: int = 4 285 | batch_size: int = 64 286 | start_shard: int = 0 287 | end_shard: int = 9999999 288 | shard_width: int = 6 289 | index_width: int = 4 290 | splits: TrainSplitConfig 291 | shuffle_train: bool = True 292 | resample_train: bool = False 293 | preprocessing: Dict[str, Any] = {'ToTensor': True} 294 | 295 | @property 296 | def img_preproc(self): 297 | def _get_transformation(transformation_name, **kwargs): 298 | if transformation_name == "RandomResizedCrop": 299 | return T.RandomResizedCrop(**kwargs) 300 | elif transformation_name == "RandomHorizontalFlip": 301 | return T.RandomHorizontalFlip() 302 | elif transformation_name == "ToTensor": 303 | return T.ToTensor() 304 | 305 | transforms = [] 306 | for transform_name, transform_kwargs_or_bool in self.preprocessing.items(): 307 | transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool 308 | transforms.append(_get_transformation(transform_name, **transform_kwargs)) 309 | return T.Compose(transforms) 310 | 311 | class DecoderTrainConfig(BaseModel): 312 | epochs: int = 20 313 | lr: SingularOrIterable[float] = 1e-4 314 | wd: SingularOrIterable[float] = 0.01 315 | warmup_steps: Optional[SingularOrIterable[int]] = None 316 | find_unused_parameters: bool = True 317 | static_graph: bool = True 318 | max_grad_norm: SingularOrIterable[float] = 0.5 319 | save_every_n_samples: int = 100000 320 | n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset 321 | cond_scale: Union[float, List[float]] = 1.0 322 | device: str = 'cuda:0' 323 | epoch_samples: Optional[int] = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite. 324 | validation_samples: Optional[int] = None # Same as above but for validation. 
325 | save_immediately: bool = False 326 | use_ema: bool = True 327 | ema_beta: float = 0.999 328 | amp: bool = False 329 | unet_training_mask: Optional[ListOrTuple[bool]] = None # If None, use all unets 330 | 331 | class DecoderEvaluateConfig(BaseModel): 332 | n_evaluation_samples: int = 1000 333 | FID: Optional[Dict[str, Any]] = None 334 | IS: Optional[Dict[str, Any]] = None 335 | KID: Optional[Dict[str, Any]] = None 336 | LPIPS: Optional[Dict[str, Any]] = None 337 | 338 | class TrainDecoderConfig(BaseModel): 339 | decoder: DecoderConfig 340 | data: DecoderDataConfig 341 | train: DecoderTrainConfig 342 | evaluate: DecoderEvaluateConfig 343 | tracker: TrackerConfig 344 | seed: int = 0 345 | 346 | @classmethod 347 | def from_json_path(cls, json_path): 348 | with open(json_path) as f: 349 | config = json.load(f) 350 | print(config) 351 | return cls(**config) 352 | 353 | @model_validator(mode = 'after') 354 | def check_has_embeddings(self, m): 355 | # Makes sure that enough information is provided to get the embeddings specified for training 356 | values = dict(self) 357 | 358 | data_config, decoder_config = values.get('data'), values.get('decoder') 359 | 360 | if not exists(data_config) or not exists(decoder_config): 361 | # Then something else errored and we should just pass through 362 | return values 363 | 364 | using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets]) 365 | using_clip = exists(decoder_config.clip) 366 | img_emb_url = data_config.img_embeddings_url 367 | text_emb_url = data_config.text_embeddings_url 368 | 369 | if using_text_embeddings: 370 | # Then we need some way to get the embeddings 371 | assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided' 372 | 373 | if using_clip: 374 | if using_text_embeddings: 375 | assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings' 376 | else: 377 | assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings' 378 | 379 | if text_emb_url: 380 | assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason." 
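        # summarizing the checks above: text conditioning requires either a clip adapter or text_embeddings_url;
        # a clip adapter combined with both embedding urls (or with img_embeddings_url when not text conditioning)
        # is rejected as redundant; and text_embeddings_url without text conditioning is rejected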
381 | 382 | return m 383 | -------------------------------------------------------------------------------- /dalle2_pytorch/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import copy 3 | from pathlib import Path 4 | from math import ceil 5 | from functools import partial, wraps 6 | from contextlib import nullcontext 7 | from collections.abc import Iterable 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR 13 | from torch.cuda.amp import autocast, GradScaler 14 | 15 | from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior 16 | from dalle2_pytorch.optimizer import get_optimizer 17 | from dalle2_pytorch.version import __version__ 18 | from packaging import version 19 | 20 | import pytorch_warmup as warmup 21 | 22 | from ema_pytorch import EMA 23 | 24 | from accelerate import Accelerator, DistributedType 25 | 26 | import numpy as np 27 | 28 | # helper functions 29 | 30 | def exists(val): 31 | return val is not None 32 | 33 | def default(val, d): 34 | if exists(val): 35 | return val 36 | return d() if callable(d) else d 37 | 38 | def cast_tuple(val, length = 1): 39 | return val if isinstance(val, tuple) else ((val,) * length) 40 | 41 | def pick_and_pop(keys, d): 42 | values = list(map(lambda key: d.pop(key), keys)) 43 | return dict(zip(keys, values)) 44 | 45 | def group_dict_by_key(cond, d): 46 | return_val = [dict(),dict()] 47 | for key in d.keys(): 48 | match = bool(cond(key)) 49 | ind = int(not match) 50 | return_val[ind][key] = d[key] 51 | return (*return_val,) 52 | 53 | def string_begins_with(prefix, str): 54 | return str.startswith(prefix) 55 | 56 | def group_by_key_prefix(prefix, d): 57 | return group_dict_by_key(partial(string_begins_with, prefix), d) 58 | 59 | def groupby_prefix_and_trim(prefix, d): 60 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) 61 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 62 | return kwargs_without_prefix, kwargs 63 | 64 | def num_to_groups(num, divisor): 65 | groups = num // divisor 66 | remainder = num % divisor 67 | arr = [divisor] * groups 68 | if remainder > 0: 69 | arr.append(remainder) 70 | return arr 71 | 72 | # decorators 73 | 74 | def cast_torch_tensor(fn): 75 | @wraps(fn) 76 | def inner(model, *args, **kwargs): 77 | device = kwargs.pop('_device', next(model.parameters()).device) 78 | cast_device = kwargs.pop('_cast_device', True) 79 | cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True) 80 | 81 | kwargs_keys = kwargs.keys() 82 | all_args = (*args, *kwargs.values()) 83 | split_kwargs_index = len(all_args) - len(kwargs_keys) 84 | all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args)) 85 | 86 | if cast_device: 87 | all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args)) 88 | 89 | if cast_deepspeed_precision: 90 | try: 91 | accelerator = model.accelerator 92 | if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED: 93 | cast_type_map = { 94 | "fp16": torch.half, 95 | "bf16": torch.bfloat16, 96 | "no": torch.float 97 | } 98 | precision_type = cast_type_map[accelerator.mixed_precision] 99 | all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args)) 100 | 
except AttributeError: 101 | # Then this model doesn't have an accelerator 102 | pass 103 | 104 | args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:] 105 | kwargs = dict(tuple(zip(kwargs_keys, kwargs_values))) 106 | 107 | out = fn(model, *args, **kwargs) 108 | return out 109 | return inner 110 | 111 | # gradient accumulation functions 112 | 113 | def split_iterable(it, split_size): 114 | accum = [] 115 | for ind in range(ceil(len(it) / split_size)): 116 | start_index = ind * split_size 117 | accum.append(it[start_index: (start_index + split_size)]) 118 | return accum 119 | 120 | def split(t, split_size = None): 121 | if not exists(split_size): 122 | return t 123 | 124 | if isinstance(t, torch.Tensor): 125 | return t.split(split_size, dim = 0) 126 | 127 | if isinstance(t, Iterable): 128 | return split_iterable(t, split_size) 129 | 130 | return TypeError 131 | 132 | def find_first(cond, arr): 133 | for el in arr: 134 | if cond(el): 135 | return el 136 | return None 137 | 138 | def split_args_and_kwargs(*args, split_size = None, **kwargs): 139 | all_args = (*args, *kwargs.values()) 140 | len_all_args = len(all_args) 141 | first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args) 142 | assert exists(first_tensor) 143 | 144 | batch_size = len(first_tensor) 145 | split_size = default(split_size, batch_size) 146 | num_chunks = ceil(batch_size / split_size) 147 | 148 | dict_len = len(kwargs) 149 | dict_keys = kwargs.keys() 150 | split_kwargs_index = len_all_args - dict_len 151 | 152 | split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args] 153 | chunk_sizes = tuple(map(len, split_all_args[0])) 154 | 155 | for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)): 156 | chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:] 157 | chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values))) 158 | chunk_size_frac = chunk_size / batch_size 159 | yield chunk_size_frac, (chunked_args, chunked_kwargs) 160 | 161 | # diffusion prior trainer 162 | 163 | def prior_sample_in_chunks(fn): 164 | @wraps(fn) 165 | def inner(self, *args, max_batch_size = None, **kwargs): 166 | if not exists(max_batch_size): 167 | return fn(self, *args, **kwargs) 168 | 169 | outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)] 170 | return torch.cat(outputs, dim = 0) 171 | return inner 172 | 173 | class DiffusionPriorTrainer(nn.Module): 174 | def __init__( 175 | self, 176 | diffusion_prior, 177 | accelerator = None, 178 | use_ema = True, 179 | lr = 3e-4, 180 | wd = 1e-2, 181 | eps = 1e-6, 182 | max_grad_norm = None, 183 | group_wd_params = True, 184 | warmup_steps = None, 185 | cosine_decay_max_steps = None, 186 | **kwargs 187 | ): 188 | super().__init__() 189 | assert isinstance(diffusion_prior, DiffusionPrior) 190 | 191 | ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) 192 | accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs) 193 | 194 | if not exists(accelerator): 195 | accelerator = Accelerator(**accelerator_kwargs) 196 | 197 | # assign some helpful member vars 198 | 199 | self.accelerator = accelerator 200 | self.text_conditioned = diffusion_prior.condition_on_text_encodings 201 | 202 | # setting the device 203 | 204 | self.device = 
accelerator.device 205 | diffusion_prior.to(self.device) 206 | 207 | # save model 208 | 209 | self.diffusion_prior = diffusion_prior 210 | 211 | # mixed precision checks 212 | 213 | if ( 214 | exists(self.accelerator) 215 | and self.accelerator.distributed_type == DistributedType.DEEPSPEED 216 | and self.diffusion_prior.clip is not None 217 | ): 218 | # Then we need to make sure clip is using the correct precision or else deepspeed will error 219 | cast_type_map = { 220 | "fp16": torch.half, 221 | "bf16": torch.bfloat16, 222 | "no": torch.float 223 | } 224 | precision_type = cast_type_map[accelerator.mixed_precision] 225 | assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip" 226 | self.diffusion_prior.clip.to(precision_type) 227 | 228 | # optimizer stuff 229 | 230 | self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params) 231 | 232 | self.optimizer = get_optimizer( 233 | self.diffusion_prior.parameters(), 234 | **self.optim_kwargs, 235 | **kwargs 236 | ) 237 | 238 | if exists(cosine_decay_max_steps): 239 | self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps) 240 | else: 241 | self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0) 242 | 243 | self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None 244 | 245 | # distribute the model if using HFA 246 | 247 | self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler) 248 | 249 | # exponential moving average stuff 250 | 251 | self.use_ema = use_ema 252 | 253 | if self.use_ema: 254 | self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs) 255 | 256 | # gradient clipping if needed 257 | 258 | self.max_grad_norm = max_grad_norm 259 | 260 | # track steps internally 261 | 262 | self.register_buffer('step', torch.tensor([0], device = self.device)) 263 | 264 | # utility 265 | 266 | def save(self, path, overwrite = True, **kwargs): 267 | 268 | # only save on the main process 269 | if self.accelerator.is_main_process: 270 | print(f"Saving checkpoint at step: {self.step.item()}") 271 | path = Path(path) 272 | assert not (path.exists() and not overwrite) 273 | path.parent.mkdir(parents = True, exist_ok = True) 274 | 275 | # FIXME: LambdaLR can't be saved due to pickling issues 276 | save_obj = dict( 277 | optimizer = self.optimizer.state_dict(), 278 | scheduler = self.scheduler.state_dict(), 279 | warmup_scheduler = self.warmup_scheduler, 280 | model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(), 281 | version = version.parse(__version__), 282 | step = self.step, 283 | **kwargs 284 | ) 285 | 286 | if self.use_ema: 287 | save_obj = { 288 | **save_obj, 289 | 'ema': self.ema_diffusion_prior.state_dict(), 290 | 'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model specifically for easy ema-only reload 291 | } 292 | 293 | torch.save(save_obj, str(path)) 294 | 295 | def load(self, path_or_state, overwrite_lr = True, strict = True): 296 | """ 297 | Load a checkpoint of a diffusion prior trainer. 298 | 299 | Will load the entire trainer, including the optimizer and EMA. 
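        A state dict that has already been loaded into memory (for example the dict returned by `Tracker.recall()`) may be passed in place of a path.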
300 | 301 | Params: 302 | - path_or_state (str | torch): a path to the DiffusionPriorTrainer checkpoint file 303 | - overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer 304 | - strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match 305 | 306 | Returns: 307 | loaded_obj (dict): The loaded checkpoint dictionary 308 | """ 309 | 310 | # all processes need to load checkpoint. no restriction here 311 | if isinstance(path_or_state, str): 312 | path = Path(path_or_state) 313 | assert path.exists() 314 | loaded_obj = torch.load(str(path), map_location=self.device) 315 | 316 | elif isinstance(path_or_state, dict): 317 | loaded_obj = path_or_state 318 | 319 | if version.parse(__version__) != loaded_obj['version']: 320 | print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}') 321 | 322 | # unwrap the model when loading from checkpoint 323 | self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict) 324 | self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device)) 325 | 326 | self.optimizer.load_state_dict(loaded_obj['optimizer']) 327 | self.scheduler.load_state_dict(loaded_obj['scheduler']) 328 | 329 | # set warmupstep 330 | if exists(self.warmup_scheduler): 331 | self.warmup_scheduler.last_step = self.step.item() 332 | 333 | # ensure new lr is used if different from old one 334 | if overwrite_lr: 335 | new_lr = self.optim_kwargs["lr"] 336 | 337 | for group in self.optimizer.param_groups: 338 | group["lr"] = new_lr if group["lr"] > 0.0 else 0.0 339 | 340 | if self.use_ema: 341 | assert 'ema' in loaded_obj 342 | self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict) 343 | # below might not be necessary, but I had a suspicion that this wasn't being loaded correctly 344 | self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"]) 345 | 346 | return loaded_obj 347 | 348 | # model functionality 349 | 350 | def update(self): 351 | 352 | if exists(self.max_grad_norm): 353 | self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm) 354 | 355 | self.optimizer.step() 356 | self.optimizer.zero_grad() 357 | 358 | # accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy" 359 | if not self.accelerator.optimizer_step_was_skipped: 360 | sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext 361 | with sched_context(): 362 | self.scheduler.step() 363 | 364 | if self.use_ema: 365 | self.ema_diffusion_prior.update() 366 | 367 | self.step += 1 368 | 369 | @torch.no_grad() 370 | @cast_torch_tensor 371 | @prior_sample_in_chunks 372 | def p_sample_loop(self, *args, **kwargs): 373 | model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior 374 | return model.p_sample_loop(*args, **kwargs) 375 | 376 | @torch.no_grad() 377 | @cast_torch_tensor 378 | @prior_sample_in_chunks 379 | def sample(self, *args, **kwargs): 380 | model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior 381 | return model.sample(*args, **kwargs) 382 | 383 | @torch.no_grad() 384 | def sample_batch_size(self, *args, **kwargs): 385 | model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior 386 | return model.sample_batch_size(*args, **kwargs) 387 | 388 | @torch.no_grad() 389 | 
@cast_torch_tensor 390 | @prior_sample_in_chunks 391 | def embed_text(self, *args, **kwargs): 392 | return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs) 393 | 394 | @cast_torch_tensor 395 | def forward( 396 | self, 397 | *args, 398 | max_batch_size = None, 399 | **kwargs 400 | ): 401 | total_loss = 0. 402 | 403 | for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): 404 | with self.accelerator.autocast(): 405 | loss = self.diffusion_prior(*chunked_args, **chunked_kwargs) 406 | loss = loss * chunk_size_frac 407 | 408 | total_loss += loss.item() 409 | 410 | if self.training: 411 | self.accelerator.backward(loss) 412 | 413 | return total_loss 414 | 415 | # decoder trainer 416 | 417 | def decoder_sample_in_chunks(fn): 418 | @wraps(fn) 419 | def inner(self, *args, max_batch_size = None, **kwargs): 420 | if not exists(max_batch_size): 421 | return fn(self, *args, **kwargs) 422 | 423 | if self.decoder.unconditional: 424 | batch_size = kwargs.get('batch_size') 425 | batch_sizes = num_to_groups(batch_size, max_batch_size) 426 | outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes] 427 | else: 428 | outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)] 429 | 430 | return torch.cat(outputs, dim = 0) 431 | return inner 432 | 433 | class DecoderTrainer(nn.Module): 434 | def __init__( 435 | self, 436 | decoder, 437 | accelerator = None, 438 | dataloaders = None, 439 | use_ema = True, 440 | lr = 1e-4, 441 | wd = 1e-2, 442 | eps = 1e-8, 443 | warmup_steps = None, 444 | cosine_decay_max_steps = None, 445 | max_grad_norm = 0.5, 446 | amp = False, 447 | group_wd_params = True, 448 | **kwargs 449 | ): 450 | super().__init__() 451 | assert isinstance(decoder, Decoder) 452 | ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) 453 | 454 | self.accelerator = default(accelerator, Accelerator) 455 | 456 | self.num_unets = len(decoder.unets) 457 | 458 | self.use_ema = use_ema 459 | self.ema_unets = nn.ModuleList([]) 460 | 461 | self.amp = amp 462 | 463 | # be able to finely customize learning rate, weight decay 464 | # per unet 465 | 466 | lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps)) 467 | 468 | assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4' 469 | 470 | optimizers = [] 471 | schedulers = [] 472 | warmup_schedulers = [] 473 | 474 | for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps): 475 | if isinstance(unet, nn.Identity): 476 | optimizers.append(None) 477 | schedulers.append(None) 478 | warmup_schedulers.append(None) 479 | else: 480 | optimizer = get_optimizer( 481 | unet.parameters(), 482 | lr = unet_lr, 483 | wd = unet_wd, 484 | eps = unet_eps, 485 | group_wd_params = group_wd_params, 486 | **kwargs 487 | ) 488 | 489 | optimizers.append(optimizer) 490 | 491 | if exists(unet_cosine_decay_max_steps): 492 | scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps) 493 | else: 494 | scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0) 495 | 496 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = 
unet_warmup_steps) if exists(unet_warmup_steps) else None 497 | warmup_schedulers.append(warmup_scheduler) 498 | 499 | schedulers.append(scheduler) 500 | 501 | if self.use_ema: 502 | self.ema_unets.append(EMA(unet, **ema_kwargs)) 503 | 504 | # gradient clipping if needed 505 | 506 | self.max_grad_norm = max_grad_norm 507 | 508 | self.register_buffer('steps', torch.tensor([0] * self.num_unets)) 509 | 510 | if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None: 511 | # Then we need to make sure clip is using the correct precision or else deepspeed will error 512 | cast_type_map = { 513 | "fp16": torch.half, 514 | "bf16": torch.bfloat16, 515 | "no": torch.float 516 | } 517 | precision_type = cast_type_map[accelerator.mixed_precision] 518 | assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip" 519 | clip = decoder.clip 520 | clip.to(precision_type) 521 | 522 | decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers)) 523 | 524 | self.decoder = decoder 525 | 526 | # prepare dataloaders 527 | 528 | train_loader = val_loader = None 529 | if exists(dataloaders): 530 | train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"]) 531 | 532 | self.train_loader = train_loader 533 | self.val_loader = val_loader 534 | 535 | # store optimizers 536 | 537 | for opt_ind, optimizer in zip(range(len(optimizers)), optimizers): 538 | setattr(self, f'optim{opt_ind}', optimizer) 539 | 540 | # store schedulers 541 | 542 | for sched_ind, scheduler in zip(range(len(schedulers)), schedulers): 543 | setattr(self, f'sched{sched_ind}', scheduler) 544 | 545 | # store warmup schedulers 546 | 547 | self.warmup_schedulers = warmup_schedulers 548 | 549 | def validate_and_return_unet_number(self, unet_number = None): 550 | if self.num_unets == 1: 551 | unet_number = default(unet_number, 1) 552 | 553 | assert exists(unet_number) and 1 <= unet_number <= self.num_unets 554 | return unet_number 555 | 556 | def num_steps_taken(self, unet_number = None): 557 | unet_number = self.validate_and_return_unet_number(unet_number) 558 | return self.steps[unet_number - 1].item() 559 | 560 | def save(self, path, overwrite = True, **kwargs): 561 | path = Path(path) 562 | assert not (path.exists() and not overwrite) 563 | path.parent.mkdir(parents = True, exist_ok = True) 564 | 565 | save_obj = dict( 566 | model = self.accelerator.unwrap_model(self.decoder).state_dict(), 567 | version = __version__, 568 | steps = self.steps.cpu(), 569 | **kwargs 570 | ) 571 | 572 | for ind in range(0, self.num_unets): 573 | optimizer_key = f'optim{ind}' 574 | scheduler_key = f'sched{ind}' 575 | 576 | optimizer = getattr(self, optimizer_key) 577 | scheduler = getattr(self, scheduler_key) 578 | 579 | optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None 580 | scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None 581 | 582 | save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict} 583 | 584 | if self.use_ema: 585 | save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()} 586 | 587 | self.accelerator.save(save_obj, str(path)) 588 | 589 | def load_state_dict(self, loaded_obj, only_model = False, strict = True): 590 | if version.parse(__version__) != version.parse(loaded_obj['version']): 591 | self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package 
version is {__version__}') 592 | 593 | self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict) 594 | self.steps.copy_(loaded_obj['steps']) 595 | 596 | if only_model: 597 | return loaded_obj 598 | 599 | for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()): 600 | 601 | optimizer_key = f'optim{ind}' 602 | optimizer = getattr(self, optimizer_key) 603 | 604 | scheduler_key = f'sched{ind}' 605 | scheduler = getattr(self, scheduler_key) 606 | 607 | warmup_scheduler = self.warmup_schedulers[ind] 608 | 609 | if exists(optimizer): 610 | optimizer.load_state_dict(loaded_obj[optimizer_key]) 611 | 612 | if exists(scheduler): 613 | scheduler.load_state_dict(loaded_obj[scheduler_key]) 614 | 615 | if exists(warmup_scheduler): 616 | warmup_scheduler.last_step = last_step 617 | 618 | if self.use_ema: 619 | assert 'ema' in loaded_obj 620 | self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict) 621 | 622 | def load(self, path, only_model = False, strict = True): 623 | path = Path(path) 624 | assert path.exists() 625 | 626 | loaded_obj = torch.load(str(path), map_location = 'cpu') 627 | 628 | self.load_state_dict(loaded_obj, only_model = only_model, strict = strict) 629 | 630 | return loaded_obj 631 | 632 | @property 633 | def unets(self): 634 | return nn.ModuleList([ema.ema_model for ema in self.ema_unets]) 635 | 636 | def increment_step(self, unet_number): 637 | assert 1 <= unet_number <= self.num_unets 638 | 639 | unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device) 640 | self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps)) 641 | 642 | def update(self, unet_number = None): 643 | unet_number = self.validate_and_return_unet_number(unet_number) 644 | index = unet_number - 1 645 | 646 | optimizer = getattr(self, f'optim{index}') 647 | scheduler = getattr(self, f'sched{index}') 648 | 649 | if exists(self.max_grad_norm): 650 | self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients 651 | 652 | optimizer.step() 653 | optimizer.zero_grad() 654 | 655 | warmup_scheduler = self.warmup_schedulers[index] 656 | scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext 657 | 658 | with scheduler_context(): 659 | scheduler.step() 660 | 661 | if self.use_ema: 662 | ema_unet = self.ema_unets[index] 663 | ema_unet.update() 664 | 665 | self.increment_step(unet_number) 666 | 667 | @torch.no_grad() 668 | @cast_torch_tensor 669 | @decoder_sample_in_chunks 670 | def sample(self, *args, **kwargs): 671 | distributed = self.accelerator.num_processes > 1 672 | base_decoder = self.accelerator.unwrap_model(self.decoder) 673 | 674 | was_training = base_decoder.training 675 | base_decoder.eval() 676 | 677 | if kwargs.pop('use_non_ema', False) or not self.use_ema: 678 | out = base_decoder.sample(*args, **kwargs, distributed = distributed) 679 | base_decoder.train(was_training) 680 | return out 681 | 682 | trainable_unets = self.accelerator.unwrap_model(self.decoder).unets 683 | base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling 684 | 685 | output = base_decoder.sample(*args, **kwargs, distributed = distributed) 686 | 687 | base_decoder.unets = trainable_unets # restore original training unets 688 | 689 | # cast the ema_model unets back to original device 690 | for ema in self.ema_unets: 691 | ema.restore_ema_model_device() 692 | 693 | base_decoder.train(was_training) 694 | return output 695 | 
696 | @torch.no_grad() 697 | @cast_torch_tensor 698 | @prior_sample_in_chunks 699 | def embed_text(self, *args, **kwargs): 700 | return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs) 701 | 702 | @torch.no_grad() 703 | @cast_torch_tensor 704 | @prior_sample_in_chunks 705 | def embed_image(self, *args, **kwargs): 706 | return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs) 707 | 708 | @cast_torch_tensor 709 | def forward( 710 | self, 711 | *args, 712 | unet_number = None, 713 | max_batch_size = None, 714 | return_lowres_cond_image=False, 715 | **kwargs 716 | ): 717 | unet_number = self.validate_and_return_unet_number(unet_number) 718 | 719 | total_loss = 0. 720 | cond_images = [] 721 | for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): 722 | with self.accelerator.autocast(): 723 | loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs) 724 | # loss_obj may be a tuple with loss and cond_image 725 | if return_lowres_cond_image: 726 | loss, cond_image = loss_obj 727 | else: 728 | loss = loss_obj 729 | cond_image = None 730 | loss = loss * chunk_size_frac 731 | if cond_image is not None: 732 | cond_images.append(cond_image) 733 | 734 | total_loss += loss.item() 735 | 736 | if self.training: 737 | self.accelerator.backward(loss) 738 | 739 | if return_lowres_cond_image: 740 | return total_loss, torch.stack(cond_images) 741 | else: 742 | return total_loss 743 | -------------------------------------------------------------------------------- /dalle2_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import importlib 3 | 4 | # helper functions 5 | 6 | def exists(val): 7 | return val is not None 8 | 9 | # time helpers 10 | 11 | class Timer: 12 | def __init__(self): 13 | self.reset() 14 | 15 | def reset(self): 16 | self.last_time = time.time() 17 | 18 | def elapsed(self): 19 | return time.time() - self.last_time 20 | 21 | # print helpers 22 | 23 | def print_ribbon(s, symbol = '=', repeat = 40): 24 | flank = symbol * repeat 25 | return f'{flank} {s} {flank}' 26 | 27 | # import helpers 28 | 29 | def import_or_print_error(pkg_name, err_str = None): 30 | try: 31 | return importlib.import_module(pkg_name) 32 | except ModuleNotFoundError as e: 33 | if exists(err_str): 34 | print(err_str) 35 | exit() 36 | -------------------------------------------------------------------------------- /dalle2_pytorch/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.15.6' 2 | -------------------------------------------------------------------------------- /dalle2_pytorch/vqgan_vae.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | from math import sqrt 4 | from functools import partial, wraps 5 | 6 | from vector_quantize_pytorch import VectorQuantize as VQ 7 | 8 | import torch 9 | from torch import nn, einsum 10 | import torch.nn.functional as F 11 | from torch.autograd import grad as torch_grad 12 | import torchvision 13 | 14 | from einops import rearrange, reduce, repeat, pack, unpack 15 | from einops.layers.torch import Rearrange 16 | 17 | # constants 18 | 19 | MList = nn.ModuleList 20 | 21 | # helper functions 22 | 23 | def exists(val): 24 | return val is not None 25 | 26 | def default(val, d): 27 | return val if 
exists(val) else d 28 | 29 | # decorators 30 | 31 | def eval_decorator(fn): 32 | def inner(model, *args, **kwargs): 33 | was_training = model.training 34 | model.eval() 35 | out = fn(model, *args, **kwargs) 36 | model.train(was_training) 37 | return out 38 | return inner 39 | 40 | def remove_vgg(fn): 41 | @wraps(fn) 42 | def inner(self, *args, **kwargs): 43 | has_vgg = hasattr(self, 'vgg') 44 | if has_vgg: 45 | vgg = self.vgg 46 | delattr(self, 'vgg') 47 | 48 | out = fn(self, *args, **kwargs) 49 | 50 | if has_vgg: 51 | self.vgg = vgg 52 | 53 | return out 54 | return inner 55 | 56 | # keyword argument helpers 57 | 58 | def pick_and_pop(keys, d): 59 | values = list(map(lambda key: d.pop(key), keys)) 60 | return dict(zip(keys, values)) 61 | 62 | def group_dict_by_key(cond, d): 63 | return_val = [dict(),dict()] 64 | for key in d.keys(): 65 | match = bool(cond(key)) 66 | ind = int(not match) 67 | return_val[ind][key] = d[key] 68 | return (*return_val,) 69 | 70 | def string_begins_with(prefix, string_input): 71 | return string_input.startswith(prefix) 72 | 73 | def group_by_key_prefix(prefix, d): 74 | return group_dict_by_key(partial(string_begins_with, prefix), d) 75 | 76 | def groupby_prefix_and_trim(prefix, d): 77 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) 78 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 79 | return kwargs_without_prefix, kwargs 80 | 81 | # tensor helper functions 82 | 83 | def log(t, eps = 1e-10): 84 | return torch.log(t + eps) 85 | 86 | def gradient_penalty(images, output, weight = 10): 87 | batch_size = images.shape[0] 88 | gradients = torch_grad(outputs = output, inputs = images, 89 | grad_outputs = torch.ones(output.size(), device = images.device), 90 | create_graph = True, retain_graph = True, only_inputs = True)[0] 91 | 92 | gradients = rearrange(gradients, 'b ... 
-> b (...)') 93 | return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean() 94 | 95 | def l2norm(t): 96 | return F.normalize(t, dim = -1) 97 | 98 | def leaky_relu(p = 0.1): 99 | return nn.LeakyReLU(0.1) 100 | 101 | def stable_softmax(t, dim = -1, alpha = 32 ** 2): 102 | t = t / alpha 103 | t = t - torch.amax(t, dim = dim, keepdim = True).detach() 104 | return (t * alpha).softmax(dim = dim) 105 | 106 | def safe_div(numer, denom, eps = 1e-8): 107 | return numer / (denom + eps) 108 | 109 | # gan losses 110 | 111 | def hinge_discr_loss(fake, real): 112 | return (F.relu(1 + fake) + F.relu(1 - real)).mean() 113 | 114 | def hinge_gen_loss(fake): 115 | return -fake.mean() 116 | 117 | def bce_discr_loss(fake, real): 118 | return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean() 119 | 120 | def bce_gen_loss(fake): 121 | return -log(torch.sigmoid(fake)).mean() 122 | 123 | def grad_layer_wrt_loss(loss, layer): 124 | return torch_grad( 125 | outputs = loss, 126 | inputs = layer, 127 | grad_outputs = torch.ones_like(loss), 128 | retain_graph = True 129 | )[0].detach() 130 | 131 | # vqgan vae 132 | 133 | class LayerNormChan(nn.Module): 134 | def __init__( 135 | self, 136 | dim, 137 | eps = 1e-5 138 | ): 139 | super().__init__() 140 | self.eps = eps 141 | self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1)) 142 | 143 | def forward(self, x): 144 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True) 145 | mean = torch.mean(x, dim = 1, keepdim = True) 146 | return (x - mean) / (var + self.eps).sqrt() * self.gamma 147 | 148 | # discriminator 149 | 150 | class Discriminator(nn.Module): 151 | def __init__( 152 | self, 153 | dims, 154 | channels = 3, 155 | groups = 16, 156 | init_kernel_size = 5 157 | ): 158 | super().__init__() 159 | dim_pairs = zip(dims[:-1], dims[1:]) 160 | 161 | self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())]) 162 | 163 | for dim_in, dim_out in dim_pairs: 164 | self.layers.append(nn.Sequential( 165 | nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), 166 | nn.GroupNorm(groups, dim_out), 167 | leaky_relu() 168 | )) 169 | 170 | dim = dims[-1] 171 | self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training 172 | nn.Conv2d(dim, dim, 1), 173 | leaky_relu(), 174 | nn.Conv2d(dim, 1, 4) 175 | ) 176 | 177 | def forward(self, x): 178 | for net in self.layers: 179 | x = net(x) 180 | 181 | return self.to_logits(x) 182 | 183 | # positional encoding 184 | 185 | class ContinuousPositionBias(nn.Module): 186 | """ from https://arxiv.org/abs/2111.09883 """ 187 | 188 | def __init__(self, *, dim, heads, layers = 2): 189 | super().__init__() 190 | self.net = MList([]) 191 | self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu())) 192 | 193 | for _ in range(layers - 1): 194 | self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu())) 195 | 196 | self.net.append(nn.Linear(dim, heads)) 197 | self.register_buffer('rel_pos', None, persistent = False) 198 | 199 | def forward(self, x): 200 | n, device = x.shape[-1], x.device 201 | fmap_size = int(sqrt(n)) 202 | 203 | if not exists(self.rel_pos): 204 | pos = torch.arange(fmap_size, device = device) 205 | grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij')) 206 | grid = rearrange(grid, 'c i j -> (i j) c') 207 | rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c') 208 | rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1) 209 | self.register_buffer('rel_pos', rel_pos, persistent = 
False) 210 | 211 | rel_pos = self.rel_pos.float() 212 | 213 | for layer in self.net: 214 | rel_pos = layer(rel_pos) 215 | 216 | bias = rearrange(rel_pos, 'i j h -> h i j') 217 | return x + bias 218 | 219 | # resnet encoder / decoder 220 | 221 | class ResnetEncDec(nn.Module): 222 | def __init__( 223 | self, 224 | dim, 225 | *, 226 | channels = 3, 227 | layers = 4, 228 | layer_mults = None, 229 | num_resnet_blocks = 1, 230 | resnet_groups = 16, 231 | first_conv_kernel_size = 5, 232 | use_attn = True, 233 | attn_dim_head = 64, 234 | attn_heads = 8, 235 | attn_dropout = 0., 236 | ): 237 | super().__init__() 238 | assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)' 239 | 240 | self.layers = layers 241 | 242 | self.encoders = MList([]) 243 | self.decoders = MList([]) 244 | 245 | layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers)))) 246 | assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers' 247 | 248 | layer_dims = [dim * mult for mult in layer_mults] 249 | dims = (dim, *layer_dims) 250 | 251 | self.encoded_dim = dims[-1] 252 | 253 | dim_pairs = zip(dims[:-1], dims[1:]) 254 | 255 | append = lambda arr, t: arr.append(t) 256 | prepend = lambda arr, t: arr.insert(0, t) 257 | 258 | if not isinstance(num_resnet_blocks, tuple): 259 | num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks) 260 | 261 | if not isinstance(use_attn, tuple): 262 | use_attn = (*((False,) * (layers - 1)), use_attn) 263 | 264 | assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers' 265 | assert len(use_attn) == layers 266 | 267 | for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn): 268 | append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu())) 269 | prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu())) 270 | 271 | if layer_use_attn: 272 | prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout)) 273 | 274 | for _ in range(layer_num_resnet_blocks): 275 | append(self.encoders, ResBlock(dim_out, groups = resnet_groups)) 276 | prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups)) 277 | 278 | if layer_use_attn: 279 | append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout)) 280 | 281 | prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2)) 282 | append(self.decoders, nn.Conv2d(dim, channels, 1)) 283 | 284 | def get_encoded_fmap_size(self, image_size): 285 | return image_size // (2 ** self.layers) 286 | 287 | @property 288 | def last_dec_layer(self): 289 | return self.decoders[-1].weight 290 | 291 | def encode(self, x): 292 | for enc in self.encoders: 293 | x = enc(x) 294 | return x 295 | 296 | def decode(self, x): 297 | for dec in self.decoders: 298 | x = dec(x) 299 | return x 300 | 301 | class GLUResBlock(nn.Module): 302 | def __init__(self, chan, groups = 16): 303 | super().__init__() 304 | self.net = nn.Sequential( 305 | nn.Conv2d(chan, chan * 2, 3, padding = 1), 306 | nn.GLU(dim = 1), 307 | nn.GroupNorm(groups, chan), 308 | nn.Conv2d(chan, chan * 2, 3, padding = 1), 309 | nn.GLU(dim = 1), 310 | nn.GroupNorm(groups, chan), 311 | nn.Conv2d(chan, chan, 1) 312 | ) 
313 | 314 | def forward(self, x): 315 | return self.net(x) + x 316 | 317 | class ResBlock(nn.Module): 318 | def __init__(self, chan, groups = 16): 319 | super().__init__() 320 | self.net = nn.Sequential( 321 | nn.Conv2d(chan, chan, 3, padding = 1), 322 | nn.GroupNorm(groups, chan), 323 | leaky_relu(), 324 | nn.Conv2d(chan, chan, 3, padding = 1), 325 | nn.GroupNorm(groups, chan), 326 | leaky_relu(), 327 | nn.Conv2d(chan, chan, 1) 328 | ) 329 | 330 | def forward(self, x): 331 | return self.net(x) + x 332 | 333 | # vqgan attention layer 334 | 335 | class VQGanAttention(nn.Module): 336 | def __init__( 337 | self, 338 | *, 339 | dim, 340 | dim_head = 64, 341 | heads = 8, 342 | dropout = 0. 343 | ): 344 | super().__init__() 345 | self.heads = heads 346 | self.scale = dim_head ** -0.5 347 | inner_dim = heads * dim_head 348 | 349 | self.dropout = nn.Dropout(dropout) 350 | self.pre_norm = LayerNormChan(dim) 351 | 352 | self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads) 353 | self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False) 354 | self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False) 355 | 356 | def forward(self, x): 357 | h = self.heads 358 | height, width, residual = *x.shape[-2:], x.clone() 359 | 360 | x = self.pre_norm(x) 361 | 362 | q, k, v = self.to_qkv(x).chunk(3, dim = 1) 363 | 364 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v)) 365 | 366 | sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale 367 | 368 | sim = self.cpb(sim) 369 | 370 | attn = stable_softmax(sim, dim = -1) 371 | attn = self.dropout(attn) 372 | 373 | out = einsum('b h i j, b h c j -> b h c i', attn, v) 374 | out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width) 375 | out = self.to_out(out) 376 | 377 | return out + residual 378 | 379 | # ViT encoder / decoder 380 | 381 | class RearrangeImage(nn.Module): 382 | def forward(self, x): 383 | n = x.shape[1] 384 | w = h = int(sqrt(n)) 385 | return rearrange(x, 'b (h w) ... 
-> b h w ...', h = h, w = w) 386 | 387 | class Attention(nn.Module): 388 | def __init__( 389 | self, 390 | dim, 391 | *, 392 | heads = 8, 393 | dim_head = 32 394 | ): 395 | super().__init__() 396 | self.norm = nn.LayerNorm(dim) 397 | self.heads = heads 398 | self.scale = dim_head ** -0.5 399 | inner_dim = dim_head * heads 400 | 401 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 402 | self.to_out = nn.Linear(inner_dim, dim) 403 | 404 | def forward(self, x): 405 | h = self.heads 406 | 407 | x = self.norm(x) 408 | 409 | q, k, v = self.to_qkv(x).chunk(3, dim = -1) 410 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 411 | 412 | q = q * self.scale 413 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 414 | 415 | sim = sim - sim.amax(dim = -1, keepdim = True).detach() 416 | attn = sim.softmax(dim = -1) 417 | 418 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 419 | 420 | out = rearrange(out, 'b h n d -> b n (h d)') 421 | return self.to_out(out) 422 | 423 | def FeedForward(dim, mult = 4): 424 | return nn.Sequential( 425 | nn.LayerNorm(dim), 426 | nn.Linear(dim, dim * mult, bias = False), 427 | nn.GELU(), 428 | nn.Linear(dim * mult, dim, bias = False) 429 | ) 430 | 431 | class Transformer(nn.Module): 432 | def __init__( 433 | self, 434 | dim, 435 | *, 436 | layers, 437 | dim_head = 32, 438 | heads = 8, 439 | ff_mult = 4 440 | ): 441 | super().__init__() 442 | self.layers = nn.ModuleList([]) 443 | for _ in range(layers): 444 | self.layers.append(nn.ModuleList([ 445 | Attention(dim = dim, dim_head = dim_head, heads = heads), 446 | FeedForward(dim = dim, mult = ff_mult) 447 | ])) 448 | 449 | self.norm = nn.LayerNorm(dim) 450 | 451 | def forward(self, x): 452 | for attn, ff in self.layers: 453 | x = attn(x) + x 454 | x = ff(x) + x 455 | 456 | return self.norm(x) 457 | 458 | class ViTEncDec(nn.Module): 459 | def __init__( 460 | self, 461 | dim, 462 | channels = 3, 463 | layers = 4, 464 | patch_size = 8, 465 | dim_head = 32, 466 | heads = 8, 467 | ff_mult = 4 468 | ): 469 | super().__init__() 470 | self.encoded_dim = dim 471 | self.patch_size = patch_size 472 | 473 | input_dim = channels * (patch_size ** 2) 474 | 475 | self.encoder = nn.Sequential( 476 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), 477 | nn.Linear(input_dim, dim), 478 | Transformer( 479 | dim = dim, 480 | dim_head = dim_head, 481 | heads = heads, 482 | ff_mult = ff_mult, 483 | layers = layers 484 | ), 485 | RearrangeImage(), 486 | Rearrange('b h w c -> b c h w') 487 | ) 488 | 489 | self.decoder = nn.Sequential( 490 | Rearrange('b c h w -> b (h w) c'), 491 | Transformer( 492 | dim = dim, 493 | dim_head = dim_head, 494 | heads = heads, 495 | ff_mult = ff_mult, 496 | layers = layers 497 | ), 498 | nn.Sequential( 499 | nn.Linear(dim, dim * 4, bias = False), 500 | nn.Tanh(), 501 | nn.Linear(dim * 4, input_dim, bias = False), 502 | ), 503 | RearrangeImage(), 504 | Rearrange('b h w (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size) 505 | ) 506 | 507 | def get_encoded_fmap_size(self, image_size): 508 | return image_size // self.patch_size 509 | 510 | @property 511 | def last_dec_layer(self): 512 | return self.decoder[-3][-1].weight 513 | 514 | def encode(self, x): 515 | return self.encoder(x) 516 | 517 | def decode(self, x): 518 | return self.decoder(x) 519 | 520 | # main vqgan-vae classes 521 | 522 | class NullVQGanVAE(nn.Module): 523 | def __init__( 524 | self, 525 | *, 526 | channels 527 | ): 528 | super().__init__() 529 | 
self.encoded_dim = channels 530 | self.layers = 0 531 | 532 | def get_encoded_fmap_size(self, size): 533 | return size 534 | 535 | def copy_for_eval(self): 536 | return self 537 | 538 | def encode(self, x): 539 | return x 540 | 541 | def decode(self, x): 542 | return x 543 | 544 | class VQGanVAE(nn.Module): 545 | def __init__( 546 | self, 547 | *, 548 | dim, 549 | image_size, 550 | channels = 3, 551 | layers = 4, 552 | l2_recon_loss = False, 553 | use_hinge_loss = True, 554 | vgg = None, 555 | vq_codebook_dim = 256, 556 | vq_codebook_size = 512, 557 | vq_decay = 0.8, 558 | vq_commitment_weight = 1., 559 | vq_kmeans_init = True, 560 | vq_use_cosine_sim = True, 561 | use_vgg_and_gan = True, 562 | vae_type = 'resnet', 563 | discr_layers = 4, 564 | **kwargs 565 | ): 566 | super().__init__() 567 | vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs) 568 | encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs) 569 | 570 | self.image_size = image_size 571 | self.channels = channels 572 | self.codebook_size = vq_codebook_size 573 | 574 | if vae_type == 'resnet': 575 | enc_dec_klass = ResnetEncDec 576 | elif vae_type == 'vit': 577 | enc_dec_klass = ViTEncDec 578 | else: 579 | raise ValueError(f'{vae_type} not valid') 580 | 581 | self.enc_dec = enc_dec_klass( 582 | dim = dim, 583 | channels = channels, 584 | layers = layers, 585 | **encdec_kwargs 586 | ) 587 | 588 | self.vq = VQ( 589 | dim = self.enc_dec.encoded_dim, 590 | codebook_dim = vq_codebook_dim, 591 | codebook_size = vq_codebook_size, 592 | decay = vq_decay, 593 | commitment_weight = vq_commitment_weight, 594 | accept_image_fmap = True, 595 | kmeans_init = vq_kmeans_init, 596 | use_cosine_sim = vq_use_cosine_sim, 597 | **vq_kwargs 598 | ) 599 | 600 | # reconstruction loss 601 | 602 | self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss 603 | 604 | # turn off GAN and perceptual loss if grayscale 605 | 606 | self.vgg = None 607 | self.discr = None 608 | self.use_vgg_and_gan = use_vgg_and_gan 609 | 610 | if not use_vgg_and_gan: 611 | return 612 | 613 | # preceptual loss 614 | 615 | if exists(vgg): 616 | self.vgg = vgg 617 | else: 618 | self.vgg = torchvision.models.vgg16(pretrained = True) 619 | self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2]) 620 | 621 | # gan related losses 622 | 623 | layer_mults = list(map(lambda t: 2 ** t, range(discr_layers))) 624 | layer_dims = [dim * mult for mult in layer_mults] 625 | dims = (dim, *layer_dims) 626 | 627 | self.discr = Discriminator(dims = dims, channels = channels) 628 | 629 | self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss 630 | self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss 631 | 632 | @property 633 | def encoded_dim(self): 634 | return self.enc_dec.encoded_dim 635 | 636 | def get_encoded_fmap_size(self, image_size): 637 | return self.enc_dec.get_encoded_fmap_size(image_size) 638 | 639 | def copy_for_eval(self): 640 | device = next(self.parameters()).device 641 | vae_copy = copy.deepcopy(self.cpu()) 642 | 643 | if vae_copy.use_vgg_and_gan: 644 | del vae_copy.discr 645 | del vae_copy.vgg 646 | 647 | vae_copy.eval() 648 | return vae_copy.to(device) 649 | 650 | @remove_vgg 651 | def state_dict(self, *args, **kwargs): 652 | return super().state_dict(*args, **kwargs) 653 | 654 | @remove_vgg 655 | def load_state_dict(self, *args, **kwargs): 656 | return super().load_state_dict(*args, **kwargs) 657 | 658 | @property 659 | def codebook(self): 660 | return self.vq.codebook 661 | 662 | def encode(self, fmap): 663 | fmap = 
self.enc_dec.encode(fmap) 664 | return fmap 665 | 666 | def decode(self, fmap, return_indices_and_loss = False): 667 | fmap, indices, commit_loss = self.vq(fmap) 668 | 669 | fmap = self.enc_dec.decode(fmap) 670 | 671 | if not return_indices_and_loss: 672 | return fmap 673 | 674 | return fmap, indices, commit_loss 675 | 676 | def forward( 677 | self, 678 | img, 679 | return_loss = False, 680 | return_discr_loss = False, 681 | return_recons = False, 682 | add_gradient_penalty = True 683 | ): 684 | batch, channels, height, width, device = *img.shape, img.device 685 | assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}' 686 | assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE' 687 | 688 | fmap = self.encode(img) 689 | 690 | fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True) 691 | 692 | if not return_loss and not return_discr_loss: 693 | return fmap 694 | 695 | assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both' 696 | 697 | # whether to return discriminator loss 698 | 699 | if return_discr_loss: 700 | assert exists(self.discr), 'discriminator must exist to train it' 701 | 702 | fmap.detach_() 703 | img.requires_grad_() 704 | 705 | fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img)) 706 | 707 | discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits) 708 | 709 | if add_gradient_penalty: 710 | gp = gradient_penalty(img, img_discr_logits) 711 | loss = discr_loss + gp 712 | 713 | if return_recons: 714 | return loss, fmap 715 | 716 | return loss 717 | 718 | # reconstruction loss 719 | 720 | recon_loss = self.recon_loss_fn(fmap, img) 721 | 722 | # early return if training on grayscale 723 | 724 | if not self.use_vgg_and_gan: 725 | if return_recons: 726 | return recon_loss, fmap 727 | 728 | return recon_loss 729 | 730 | # perceptual loss 731 | 732 | img_vgg_input = img 733 | fmap_vgg_input = fmap 734 | 735 | if img.shape[1] == 1: 736 | # handle grayscale for vgg 737 | img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... 
-> b c ...', c = 3), (img_vgg_input, fmap_vgg_input)) 738 | 739 | img_vgg_feats = self.vgg(img_vgg_input) 740 | recon_vgg_feats = self.vgg(fmap_vgg_input) 741 | perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats) 742 | 743 | # generator loss 744 | 745 | gen_loss = self.gen_loss(self.discr(fmap)) 746 | 747 | # calculate adaptive weight 748 | 749 | last_dec_layer = self.enc_dec.last_dec_layer 750 | 751 | norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2) 752 | norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2) 753 | 754 | adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss) 755 | adaptive_weight.clamp_(max = 1e4) 756 | 757 | # combine losses 758 | 759 | loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss 760 | 761 | if return_recons: 762 | return loss, fmap 763 | 764 | return loss 765 | -------------------------------------------------------------------------------- /dalle2_pytorch/vqgan_vae_trainer.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | import copy 3 | from random import choice 4 | from pathlib import Path 5 | from shutil import rmtree 6 | from PIL import Image 7 | 8 | import torch 9 | from torch import nn 10 | from torch.cuda.amp import autocast, GradScaler 11 | from torch.utils.data import Dataset, DataLoader, random_split 12 | 13 | import torchvision.transforms as T 14 | from torchvision.datasets import ImageFolder 15 | from torchvision.utils import make_grid, save_image 16 | 17 | from einops import rearrange 18 | 19 | from dalle2_pytorch.vqgan_vae import VQGanVAE 20 | from dalle2_pytorch.optimizer import get_optimizer 21 | 22 | from ema_pytorch import EMA 23 | 24 | # helpers 25 | 26 | def exists(val): 27 | return val is not None 28 | 29 | def noop(*args, **kwargs): 30 | pass 31 | 32 | def cycle(dl): 33 | while True: 34 | for data in dl: 35 | yield data 36 | 37 | def cast_tuple(t): 38 | return t if isinstance(t, (tuple, list)) else (t,) 39 | 40 | def yes_or_no(question): 41 | answer = input(f'{question} (y/n) ') 42 | return answer.lower() in ('yes', 'y') 43 | 44 | def accum_log(log, new_logs): 45 | for key, new_value in new_logs.items(): 46 | old_value = log.get(key, 0.) 
47 | log[key] = old_value + new_value 48 | return log 49 | 50 | # classes 51 | 52 | class ImageDataset(Dataset): 53 | def __init__( 54 | self, 55 | folder, 56 | image_size, 57 | exts = ['jpg', 'jpeg', 'png'] 58 | ): 59 | super().__init__() 60 | self.folder = folder 61 | self.image_size = image_size 62 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] 63 | 64 | print(f'{len(self.paths)} training samples found at {folder}') 65 | 66 | self.transform = T.Compose([ 67 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 68 | T.Resize(image_size), 69 | T.RandomHorizontalFlip(), 70 | T.CenterCrop(image_size), 71 | T.ToTensor() 72 | ]) 73 | 74 | def __len__(self): 75 | return len(self.paths) 76 | 77 | def __getitem__(self, index): 78 | path = self.paths[index] 79 | img = Image.open(path) 80 | return self.transform(img) 81 | 82 | # main trainer class 83 | 84 | class VQGanVAETrainer(nn.Module): 85 | def __init__( 86 | self, 87 | vae, 88 | *, 89 | num_train_steps, 90 | lr, 91 | batch_size, 92 | folder, 93 | grad_accum_every, 94 | wd = 0., 95 | save_results_every = 100, 96 | save_model_every = 1000, 97 | results_folder = './results', 98 | valid_frac = 0.05, 99 | random_split_seed = 42, 100 | ema_beta = 0.995, 101 | ema_update_after_step = 500, 102 | ema_update_every = 10, 103 | apply_grad_penalty_every = 4, 104 | amp = False 105 | ): 106 | super().__init__() 107 | assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE' 108 | image_size = vae.image_size 109 | 110 | self.vae = vae 111 | self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every) 112 | 113 | self.register_buffer('steps', torch.Tensor([0])) 114 | 115 | self.num_train_steps = num_train_steps 116 | self.batch_size = batch_size 117 | self.grad_accum_every = grad_accum_every 118 | 119 | all_parameters = set(vae.parameters()) 120 | discr_parameters = set(vae.discr.parameters()) 121 | vae_parameters = all_parameters - discr_parameters 122 | 123 | self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd) 124 | self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd) 125 | 126 | self.amp = amp 127 | self.scaler = GradScaler(enabled = amp) 128 | self.discr_scaler = GradScaler(enabled = amp) 129 | 130 | # create dataset 131 | 132 | self.ds = ImageDataset(folder, image_size = image_size) 133 | 134 | # split for validation 135 | 136 | if valid_frac > 0: 137 | train_size = int((1 - valid_frac) * len(self.ds)) 138 | valid_size = len(self.ds) - train_size 139 | self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed)) 140 | print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples') 141 | else: 142 | self.valid_ds = self.ds 143 | print(f'training with shared training and valid dataset of {len(self.ds)} samples') 144 | 145 | # dataloader 146 | 147 | self.dl = cycle(DataLoader( 148 | self.ds, 149 | batch_size = batch_size, 150 | shuffle = True 151 | )) 152 | 153 | self.valid_dl = cycle(DataLoader( 154 | self.valid_ds, 155 | batch_size = batch_size, 156 | shuffle = True 157 | )) 158 | 159 | self.save_model_every = save_model_every 160 | self.save_results_every = save_results_every 161 | 162 | self.apply_grad_penalty_every = apply_grad_penalty_every 163 | 164 | self.results_folder = Path(results_folder) 165 | 166 | if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear 
previous experiment checkpoints and results?'): 167 | rmtree(str(self.results_folder)) 168 | 169 | self.results_folder.mkdir(parents = True, exist_ok = True) 170 | 171 | def train_step(self): 172 | device = next(self.vae.parameters()).device 173 | steps = int(self.steps.item()) 174 | apply_grad_penalty = not (steps % self.apply_grad_penalty_every) 175 | 176 | self.vae.train() 177 | 178 | # logs 179 | 180 | logs = {} 181 | 182 | # update vae (generator) 183 | 184 | for _ in range(self.grad_accum_every): 185 | img = next(self.dl) 186 | img = img.to(device) 187 | 188 | with autocast(enabled = self.amp): 189 | loss = self.vae( 190 | img, 191 | return_loss = True, 192 | apply_grad_penalty = apply_grad_penalty 193 | ) 194 | 195 | 196 | self.scaler.scale(loss / self.grad_accum_every).backward() 197 | 198 | accum_log(logs, {'loss': loss.item() / self.grad_accum_every}) 199 | 200 | self.scaler.step(self.optim) 201 | self.scaler.update() 202 | self.optim.zero_grad() 203 | 204 | # update discriminator 205 | 206 | if exists(self.vae.discr): 207 | discr_loss = 0 208 | for _ in range(self.grad_accum_every): 209 | img = next(self.dl) 210 | img = img.to(device) 211 | 212 | with autocast(enabled = self.amp): 213 | loss = self.vae(img, return_discr_loss = True) 214 | 215 | self.discr_scaler.scale(loss / self.grad_accum_every).backward() 216 | 217 | accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every}) 218 | 219 | self.discr_scaler.step(self.discr_optim) 220 | self.discr_scaler.update() 221 | self.discr_optim.zero_grad() 222 | 223 | # log 224 | 225 | print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}") 226 | 227 | # update exponential moving averaged generator 228 | 229 | self.ema_vae.update() 230 | 231 | # sample results every so often 232 | 233 | if not (steps % self.save_results_every): 234 | for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))): 235 | model.eval() 236 | 237 | imgs = next(self.dl) 238 | imgs = imgs.to(device) 239 | 240 | recons = model(imgs) 241 | nrows = int(sqrt(self.batch_size)) 242 | 243 | imgs_and_recons = torch.stack((imgs, recons), dim = 0) 244 | imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...') 245 | 246 | imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.) 
247 | grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))
248 |
249 | logs['reconstructions'] = grid
250 |
251 | save_image(grid, str(self.results_folder / f'{filename}.png'))
252 |
253 | print(f'{steps}: saving to {str(self.results_folder)}')
254 |
255 | # save model every so often
256 |
257 | if not (steps % self.save_model_every):
258 | state_dict = self.vae.state_dict()
259 | model_path = str(self.results_folder / f'vae.{steps}.pt')
260 | torch.save(state_dict, model_path)
261 |
262 | ema_state_dict = self.ema_vae.state_dict()
263 | model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
264 | torch.save(ema_state_dict, model_path)
265 |
266 | print(f'{steps}: saving model to {str(self.results_folder)}')
267 |
268 | self.steps += 1
269 | return logs
270 |
271 | def train(self, log_fn = noop):
272 | device = next(self.vae.parameters()).device
273 |
274 | while self.steps < self.num_train_steps:
275 | logs = self.train_step()
276 | log_fn(logs)
277 |
278 | print('training complete')
279 |
--------------------------------------------------------------------------------
/prior.md:
--------------------------------------------------------------------------------
1 | # Diffusion Prior
2 | This readme serves as an introduction to the diffusion prior.
3 |
4 | ## Intro
5 |
6 | A properly trained prior will allow you to translate between two embedding spaces. If you know *a priori* that two embeddings are connected in some way, then the ability to translate between them could be extremely helpful.
7 |
8 | ### Motivation
9 |
10 | Before we dive into the model, let's look at a quick example of where the model may be helpful.
11 |
12 | For demonstration purposes we will imagine that we wish to generate images from text using CLIP and a Decoder.
13 |
14 | > [CLIP](https://openai.com/blog/clip/) is a contrastive model that learns to maximize the cosine similarity between a given image and caption. However, there is no guarantee that these embeddings are in the same space. While the generated embeddings are ***close***, the image and text embeddings occupy two disjoint sets.
15 |
16 | ```python
17 | # Load Models
18 | clip_model = clip.load("ViT-L/14")
19 | decoder = Decoder(checkpoint="best.pth") # A decoder trained on CLIP Image embeddings
20 |
21 | # Retrieve prompt from user and encode with CLIP
22 | prompt = "A corgi wearing sunglasses"
23 | tokenized_text = tokenize(prompt)
24 | text_embedding = clip_model.encode_text(tokenized_text)
25 |
26 | # Now, pass the text embedding to the decoder
27 | predicted_image = decoder.sample(text_embedding)
28 | ```
29 |
30 | > **Question**: *Can you spot the issue here?*
31 | >
32 | > **Answer**: *We're trying to generate an image from a text embedding!*
33 |
34 | Unfortunately, we run into the issue previously mentioned--the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution.
35 |
36 | ```python
37 | # Load Models
38 | prior = Prior(checkpoint="prior.pth") # A prior trained to go from: text -> clip text emb -> clip img emb
39 | decoder = Decoder(checkpoint="decoder.pth") # A decoder trained on CLIP Image embeddings
40 |
41 | # Retrieve prompt from user and encode with the prior
42 | prompt = "A corgi wearing sunglasses"
43 | tokenized_text = tokenize(prompt)
44 | text_embedding = prior.sample(tokenized_text) # <-- now we get an embedding in the same space as images!
45 |
46 | # Now, pass the predicted image embedding to the decoder
47 | predicted_image = decoder.sample(text_embedding)
48 | ```
49 |
50 | With the prior we are able to successfully generate embeddings *within* CLIP's image space! For this reason, the decoder will perform much better, as it receives input that is much closer to its training data.
51 |
52 | > **You may be asking yourself the following question:**
53 | >
54 | > *"Why don't you just train the decoder on clip text embeddings instead of image embeddings?"*
55 | >
56 | > OpenAI covers this topic in their [DALLE-2 paper](https://arxiv.org/abs/2204.06125). The TL;DR is *"it doesn't work as well as decoders trained on image embeddings"*...also...it's just an example :smile:
57 |
58 | ## Usage
59 |
60 | Using a pre-trained prior is quite simple.
61 |
62 | ### Loading Checkpoints
63 | ```python
64 | import torch
65 | from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
66 | from dalle2_pytorch.trainer import DiffusionPriorTrainer
67 |
68 | def load_diffusion_model(dprior_path, device="cuda"):
69 |
70 | prior_network = DiffusionPriorNetwork(
71 | dim=768,
72 | depth=24,
73 | dim_head=64,
74 | heads=32,
75 | normformer=True,
76 | attn_dropout=5e-2,
77 | ff_dropout=5e-2,
78 | num_time_embeds=1,
79 | num_image_embeds=1,
80 | num_text_embeds=1,
81 | num_timesteps=1000,
82 | ff_mult=4
83 | )
84 |
85 | diffusion_prior = DiffusionPrior(
86 | net=prior_network,
87 | clip=OpenAIClipAdapter("ViT-L/14"),
88 | image_embed_dim=768,
89 | timesteps=1000,
90 | cond_drop_prob=0.1,
91 | loss_type="l2",
92 | condition_on_text_encodings=True,
93 |
94 | )
95 |
96 | trainer = DiffusionPriorTrainer(
97 | diffusion_prior=diffusion_prior,
98 | lr=1.1e-4,
99 | wd=6.02e-2,
100 | max_grad_norm=0.5,
101 | amp=False,
102 | group_wd_params=True,
103 | use_ema=True,
104 | device=device,
105 | accelerator=None,
106 | )
107 |
108 | trainer.load(dprior_path)
109 |
110 | return trainer
111 | ```
112 |
113 | Here we instantiate a model that matches the configuration it was trained with, and then load the weights (*just like any other PyTorch model!*).
114 |
115 | ### Sampling
116 | Once we have a pre-trained model, generating embeddings is quite simple!
117 | ```python
118 | # tokenize the text
119 | tokenized_text = clip.tokenize("<your prompt>")
120 | # predict an embedding
121 | predicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)
122 | ```
123 |
124 | The resulting tensor returned from `.sample()` is of the same shape as your training data along the non-batch dimension(s). For example, a prior trained on `ViT-L/14` embeddings will predict an embedding of shape (1, 768).
125 |
126 | > For CLIP priors, this is quite handy, as it means that you can use `prior.sample(tokenized_text)` as a drop-in replacement for `clip.encode_text()`.
127 |
128 | **Some things to note:**
129 | * It is possible to specify the number of candidate embeddings to sample (the default suggested by OpenAI is `n=2`). Put simply, the idea is to avoid getting unlucky with a single bad embedding by generating two and selecting the one with the higher cosine similarity to the prompt, as sketched after this list.
130 | * You may specify a higher conditioning scale than the default (`1.0`). It is unclear whether OpenAI uses a higher value for the prior specifically, or only on the decoder. Local testing has shown poor results with anything higher than `1.0`, but *ymmv*.
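To make the candidate-selection idea above concrete, here is a minimal, purely illustrative sketch. It assumes the `prior` and `clip_model` objects from the earlier snippets are already loaded; in practice, the `n_samples_per_batch` argument shown above is the intended way to get this behavior from `.sample()` itself.

```python
import clip  # OpenAI CLIP, as in the snippets above
import torch
import torch.nn.functional as F

# assumes `clip_model` and `prior` are already loaded as shown earlier
tokenized_text = clip.tokenize("A corgi wearing sunglasses")

with torch.no_grad():
    text_embedding = clip_model.encode_text(tokenized_text)  # (1, 768) for ViT-L/14
    candidates = torch.cat([prior.sample(tokenized_text) for _ in range(2)], dim=0)  # (2, 768)

# keep the candidate image embedding with the highest cosine similarity to the caption embedding
similarity = F.cosine_similarity(candidates, text_embedding, dim=-1)  # (2,)
best_image_embedding = candidates[similarity.argmax()].unsqueeze(0)  # (1, 768)
```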
131 |
132 | ---
133 |
134 | ## Training
135 |
136 | ### Overview
137 |
138 | Training the prior is a relatively straightforward process thanks to the Trainer base class. The main step required of you is preparing a dataset in the format that `EmbeddingReader` expects. Having pre-computed embeddings massively increases training efficiency and is generally recommended, as you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move on to configuration.
139 |
140 | ## Dataset
141 |
142 | To train the prior, it is highly recommended to use precomputed embeddings for the images. To obtain these for a custom dataset, you can leverage [img2dataset](https://github.com/rom1504/img2dataset) to pull images from a list of URLs and [clip_retrieval](https://github.com/rom1504/clip-retrieval#clip-inference) to generate the actual embeddings that can be used in the prior's dataloader.
143 |
144 | ## Configuration
145 |
146 | The configuration file allows you to easily track and reproduce experiments. It is a simple JSON file that specifies the architecture, dataset, and training parameters. For more information and specifics, please see the configuration README.
147 |
148 | ## Distributed Training
149 |
150 | If you would like to train in a distributed manner, we have opted to leverage Hugging Face's Accelerate library (HFA). HFA makes it extremely simple to distribute work across multiple GPUs and nodes. All that is required of you is to run its simple CLI configuration tool ([more information here](https://huggingface.co/docs/accelerate/accelerator)).
151 |
152 | ## Evaluation
153 |
154 | There are a variety of metrics available to you when training the prior. You can read a brief description of each in the table below:
155 | | Metric | Description | Comments |
156 | | --- | --- | --- |
157 | | Online Model Validation | The validation loss associated with your online model. | Ideally, validation loss will be as low as possible. Using L2 loss, values as low as `0.1` and lower are possible after around 1 billion samples seen. |
158 | | EMA Validation | This metric measures the validation loss associated with your EMA model. | This will likely lag behind your "online" model's validation loss, but should outperform it in the long term. |
159 | | Baseline Similarity | Baseline similarity refers to the similarity between your dataset's prompts and the associated image embeddings. This serves as a guide for your prior's performance in cosine similarity. | Generally, `0.3` is considered a good cosine similarity for caption similarity. |
160 | | Similarity With Original Image | This metric measures the cosine similarity between your prior's predicted image embedding and the actual image that the caption was associated with. This is useful for determining whether your prior is generating embeddings with the right contents. | Values around `0.75`+ are obtainable. This metric should improve rapidly in the early stages of training and plateau with diminishing increases over time. If it takes hundreds of millions of samples to reach above `0.5`/`0.6` similarity, you are likely suffering from some kind of training error or inefficiency (i.e. not using EMA). |
161 | | Difference From Baseline Similarity | Sometimes it's useful to visualize a metric in another light. This metric shows you how your prior's predicted image embeddings match up with the baseline similarity measured in your dataset. | This value should float around `0.0`, with some room for variation. After a billion samples seen, values are within `0.01` +/- of `0.0`. If this climbs too high (~>`0.02`), it may be a sign that your model is overfitting somehow. |
162 | | Similarity With Text | This metric is your bread-and-butter cosine similarity between the predicted image embedding and the original caption given to the prior. Monitoring this metric will be one of your main focuses and is probably the second most important metric behind your loss. | As mentioned, this value should be close to the baseline similarity. We have observed an early rapid increase with diminishing returns as the prior learns to generate valid image embeddings. If this value increases too far beyond the baseline similarity, it could be an indication that your model is overfitting. |
163 | | Similarity With Unrelated Caption | This metric attempts to expose an overfit prior by feeding it arbitrary prompts (from your dataset) and then measuring the similarity of the predicted embedding with some other image. | Early on, we found that a poorly trained/modeled prior could effectively fool CLIP into believing that the cosine similarity between two images was high (when in fact the caption and image were completely unrelated). With this in mind, a low value is ideal; anything below `0.1` is probably safe. |
164 |
165 | ## Launching the script
166 |
167 | Now that you've done all the prep, it's time for the easy part! 🚀
168 |
169 | To actually launch the script, you will either use `accelerate launch train_diffusion_prior.py --config_path <path_to_your_config>` to launch with distributed training via Hugging Face Accelerate, or `python train_diffusion_prior.py` if you would like to train on your GPU/CPU without Hugging Face Accelerate.
170 |
171 | ## Checkpointing
172 |
173 | Checkpoints will be saved to the directory specified in your configuration file.
174 |
175 | Additionally, a final checkpoint is saved before running the test split. This file will be saved to the same directory and titled "latest.pth". This is to avoid problems where your `save_every` configuration does not overlap with the number of steps required to do a complete pass through the data. A short sketch of reloading this checkpoint for sampling is included at the end of this readme.
176 |
177 | ## Things To Keep In Mind
178 |
179 | The prior has not been trained for tasks other than the traditional CLIP embedding translation…at least not yet.
180 |
181 | As we finalize the replication of unCLIP, there will almost assuredly be experiments attempting to apply the prior network to other tasks.
182 |
183 | With that in mind, you are more or less a pioneer in embedding translation if you are reading this and attempting something you don't see documentation for!
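As a final worked example tying the Usage and Checkpointing sections together, below is a minimal sketch of reloading the final checkpoint for sampling. It assumes the `load_diffusion_model` helper from the Loading Checkpoints section above; the checkpoint path and the prompt are placeholders, since the real directory comes from your configuration file.

```python
import clip  # OpenAI CLIP, used here only for tokenization

# "./checkpoints" is a placeholder -- point this at the directory from your configuration file
trainer = load_diffusion_model("./checkpoints/latest.pth")

# predict an image embedding for a prompt, exactly as in the Sampling section
tokenized_text = clip.tokenize("A corgi wearing sunglasses")
image_embedding = trainer.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)
```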
184 | -------------------------------------------------------------------------------- /samples/oxford.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/samples/oxford.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | exec(open('dalle2_pytorch/version.py').read()) 3 | 4 | setup( 5 | name = 'dalle2-pytorch', 6 | packages = find_packages(exclude=[]), 7 | include_package_data = True, 8 | entry_points={ 9 | 'console_scripts': [ 10 | 'dalle2_pytorch = dalle2_pytorch.cli:main', 11 | 'dream = dalle2_pytorch.cli:dream' 12 | ], 13 | }, 14 | version = __version__, 15 | license='MIT', 16 | description = 'DALL-E 2', 17 | author = 'Phil Wang', 18 | author_email = 'lucidrains@gmail.com', 19 | long_description_content_type = 'text/markdown', 20 | url = 'https://github.com/lucidrains/dalle2-pytorch', 21 | keywords = [ 22 | 'artificial intelligence', 23 | 'deep learning', 24 | 'text to image' 25 | ], 26 | install_requires=[ 27 | 'accelerate', 28 | 'click', 29 | 'open-clip-torch>=2.0.0,<3.0.0', 30 | 'clip-anytorch>=2.5.2', 31 | 'coca-pytorch>=0.0.5', 32 | 'ema-pytorch>=0.0.7', 33 | 'einops>=0.7.0', 34 | 'embedding-reader', 35 | 'kornia>=0.5.4', 36 | 'numpy', 37 | 'packaging', 38 | 'pillow', 39 | 'pydantic>=2', 40 | 'pytorch-warmup', 41 | 'resize-right>=0.0.2', 42 | 'rotary-embedding-torch', 43 | 'torch>=1.10', 44 | 'torchvision', 45 | 'tqdm', 46 | 'vector-quantize-pytorch', 47 | 'x-clip>=0.4.4', 48 | 'webdataset>=0.2.5', 49 | 'fsspec>=2022.1.0', 50 | 'torchmetrics[image]>=0.8.0' 51 | ], 52 | classifiers=[ 53 | 'Development Status :: 4 - Beta', 54 | 'Intended Audience :: Developers', 55 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 56 | 'License :: OSI Approved :: MIT License', 57 | 'Programming Language :: Python :: 3.6', 58 | ], 59 | ) 60 | -------------------------------------------------------------------------------- /test_data/0.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/0.tar -------------------------------------------------------------------------------- /test_data/1.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/1.tar -------------------------------------------------------------------------------- /test_data/2.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/2.tar -------------------------------------------------------------------------------- /test_data/3.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/3.tar -------------------------------------------------------------------------------- /test_data/4.tar: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/4.tar -------------------------------------------------------------------------------- /test_data/5.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/5.tar -------------------------------------------------------------------------------- /test_data/6.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/6.tar -------------------------------------------------------------------------------- /test_data/7.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/7.tar -------------------------------------------------------------------------------- /test_data/8.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/8.tar -------------------------------------------------------------------------------- /test_data/9.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/9.tar -------------------------------------------------------------------------------- /train_diffusion_prior.py: -------------------------------------------------------------------------------- 1 | import click 2 | import torch 3 | 4 | from torch import nn 5 | from typing import List 6 | from accelerate import Accelerator 7 | from accelerate.utils import set_seed 8 | from torch.utils.data import DataLoader 9 | from embedding_reader import EmbeddingReader 10 | from accelerate.utils import dataclasses as accelerate_dataclasses 11 | 12 | from dalle2_pytorch.utils import Timer 13 | from dalle2_pytorch.trackers import Tracker 14 | from dalle2_pytorch import DiffusionPriorTrainer 15 | from dalle2_pytorch.dataloaders import get_reader, make_splits 16 | from dalle2_pytorch.train_configs import ( 17 | DiffusionPriorConfig, 18 | DiffusionPriorTrainConfig, 19 | TrainDiffusionPriorConfig, 20 | ) 21 | 22 | 23 | # helpers 24 | 25 | 26 | cos = nn.CosineSimilarity(dim=1, eps=1e-6) 27 | 28 | 29 | def exists(val): 30 | return val is not None 31 | 32 | 33 | def all_between(values: list, lower_bound, upper_bound): 34 | for value in values: 35 | if value < lower_bound or value > upper_bound: 36 | return False 37 | 38 | return True 39 | 40 | 41 | def make_model( 42 | prior_config: DiffusionPriorConfig, 43 | train_config: DiffusionPriorTrainConfig, 44 | device: str = None, 45 | accelerator: Accelerator = None, 46 | ): 47 | # create model from config 48 | diffusion_prior = prior_config.create() 49 | 50 | # instantiate the trainer 51 | trainer = DiffusionPriorTrainer( 52 | diffusion_prior=diffusion_prior, 53 | lr=train_config.lr, 54 | wd=train_config.wd, 55 | max_grad_norm=train_config.max_grad_norm, 56 | amp=train_config.amp, 57 | use_ema=train_config.use_ema, 58 | device=device, 59 | accelerator=accelerator, 60 | warmup_steps=train_config.warmup_steps, 61 | ) 62 | 63 | return trainer 64 | 65 | 66 | def create_tracker( 67 | accelerator: 
Accelerator, 68 | config: TrainDiffusionPriorConfig, 69 | config_path: str, 70 | dummy: bool = False, 71 | ) -> Tracker: 72 | tracker_config = config.tracker 73 | 74 | accelerator_config = { 75 | "Distributed": accelerator.distributed_type 76 | != accelerate_dataclasses.DistributedType.NO, 77 | "DistributedType": accelerator.distributed_type, 78 | "NumProcesses": accelerator.num_processes, 79 | "MixedPrecision": accelerator.mixed_precision, 80 | } 81 | 82 | tracker: Tracker = tracker_config.create( 83 | config, accelerator_config, dummy_mode=dummy 84 | ) 85 | 86 | tracker.save_config(config_path, config_name="prior_config.json") 87 | 88 | return tracker 89 | 90 | 91 | def pad_gather_reduce(trainer: DiffusionPriorTrainer, x, method="mean"): 92 | """ 93 | pad a value or tensor across all processes and gather 94 | 95 | params: 96 | - trainer: a trainer that carries an accelerator object 97 | - x: a number or torch tensor to reduce 98 | - method: "mean", "sum", "max", "min" 99 | 100 | return: 101 | - the average tensor after maskin out 0's 102 | - None if the gather resulted in an empty tensor 103 | """ 104 | 105 | assert method in [ 106 | "mean", 107 | "sum", 108 | "max", 109 | "min", 110 | ], "This function has limited capabilities [sum, mean, max, min]" 111 | assert type(x) is not None, "Cannot reduce a None type object" 112 | 113 | # wait for everyone to arrive here before gathering 114 | 115 | if type(x) is not torch.Tensor: 116 | x = torch.tensor([x]) 117 | 118 | # verify that the tensor is on the proper device 119 | x = x.to(trainer.device) 120 | 121 | # pad across processes 122 | padded_x = trainer.accelerator.pad_across_processes(x, dim=0) 123 | 124 | # gather across all procesess 125 | gathered_x = trainer.accelerator.gather(padded_x) 126 | 127 | # mask out zeros 128 | masked_x = gathered_x[gathered_x != 0] 129 | 130 | # if the tensor is empty, warn and return None 131 | if len(masked_x) == 0: 132 | click.secho( 133 | f"The call to this method resulted in an empty tensor after masking out zeros. 
The gathered tensor was this: {gathered_x} and the original value passed was: {x}.", 134 | fg="red", 135 | ) 136 | return None 137 | 138 | if method == "mean": 139 | return torch.mean(masked_x) 140 | elif method == "sum": 141 | return torch.sum(masked_x) 142 | elif method == "max": 143 | return torch.max(masked_x) 144 | elif method == "min": 145 | return torch.min(masked_x) 146 | 147 | 148 | def save_trainer( 149 | tracker: Tracker, 150 | trainer: DiffusionPriorTrainer, 151 | is_latest: bool, 152 | is_best: bool, 153 | epoch: int, 154 | samples_seen: int, 155 | best_validation_loss: float, 156 | ): 157 | """ 158 | Logs the model with an appropriate method depending on the tracker 159 | """ 160 | trainer.accelerator.wait_for_everyone() 161 | 162 | if trainer.accelerator.is_main_process: 163 | click.secho( 164 | f"RANK:{trainer.accelerator.process_index} | Saving Model | Best={is_best} | Latest={is_latest}", 165 | fg="magenta", 166 | ) 167 | 168 | tracker.save( 169 | trainer=trainer, 170 | is_best=is_best, 171 | is_latest=is_latest, 172 | epoch=int(epoch), 173 | samples_seen=int(samples_seen), 174 | best_validation_loss=best_validation_loss, 175 | ) 176 | 177 | 178 | def recall_trainer(tracker: Tracker, trainer: DiffusionPriorTrainer): 179 | """ 180 | Loads the model with an appropriate method depending on the tracker 181 | """ 182 | 183 | if trainer.accelerator.is_main_process: 184 | click.secho(f"Loading model from {type(tracker.loader).__name__}", fg="yellow") 185 | 186 | state_dict = tracker.recall() 187 | 188 | trainer.load(state_dict, strict=True) 189 | 190 | return ( 191 | int(state_dict.get("epoch", 0)), 192 | state_dict.get("best_validation_loss", 0), 193 | int(state_dict.get("samples_seen", 0)), 194 | ) 195 | 196 | 197 | # eval functions 198 | 199 | 200 | def report_validation_loss( 201 | trainer: DiffusionPriorTrainer, 202 | dataloader: DataLoader, 203 | text_conditioned: bool, 204 | use_ema: bool, 205 | tracker: Tracker, 206 | split: str, 207 | tracker_folder: str, 208 | loss_type: str, 209 | ): 210 | """ 211 | Compute the validation loss on a given subset of data. 
212 | """ 213 | 214 | if trainer.accelerator.is_main_process: 215 | click.secho( 216 | f"Measuring performance on {use_ema}-{split} split", 217 | fg="green", 218 | blink=True, 219 | ) 220 | 221 | total_loss = torch.zeros(1, dtype=torch.float, device=trainer.device) 222 | 223 | for image_embeddings, text_data in dataloader: 224 | image_embeddings = image_embeddings.to(trainer.device) 225 | text_data = text_data.to(trainer.device) 226 | 227 | input_args = dict(image_embed=image_embeddings) 228 | 229 | if text_conditioned: 230 | input_args = dict(**input_args, text=text_data) 231 | else: 232 | input_args = dict(**input_args, text_embed=text_data) 233 | 234 | if use_ema: 235 | loss = trainer.ema_diffusion_prior(**input_args) 236 | else: 237 | loss = trainer(**input_args) 238 | 239 | total_loss += loss 240 | 241 | # compute the average loss across all processes 242 | 243 | avg_loss = pad_gather_reduce(trainer, total_loss, method="mean") 244 | stats = {f"{tracker_folder}/{loss_type}-loss": avg_loss} 245 | 246 | # print and log results on main process 247 | tracker.log(stats, step=trainer.step.item() + 1) 248 | 249 | return avg_loss 250 | 251 | 252 | def report_cosine_sims( 253 | trainer: DiffusionPriorTrainer, 254 | dataloader: DataLoader, 255 | text_conditioned: bool, 256 | tracker: Tracker, 257 | split: str, 258 | timesteps: int, 259 | tracker_folder: str, 260 | ): 261 | trainer.eval() 262 | if trainer.accelerator.is_main_process: 263 | click.secho( 264 | f"Measuring Cosine-Similarity on {split} split with {timesteps} timesteps", 265 | fg="green", 266 | blink=True, 267 | ) 268 | 269 | for test_image_embeddings, text_data in dataloader: 270 | test_image_embeddings = test_image_embeddings.to(trainer.device) 271 | text_data = text_data.to(trainer.device) 272 | 273 | # we are text conditioned, we produce an embedding from the tokenized text 274 | if text_conditioned: 275 | text_embedding, text_encodings = trainer.embed_text(text_data) 276 | text_cond = dict(text_embed=text_embedding, text_encodings=text_encodings) 277 | else: 278 | text_embedding = text_data 279 | text_cond = dict(text_embed=text_embedding) 280 | 281 | # make a copy of the text embeddings for shuffling 282 | text_embed_shuffled = text_embedding.clone() 283 | 284 | # roll the text to simulate "unrelated" captions 285 | rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1) 286 | text_embed_shuffled = text_embed_shuffled[rolled_idx] 287 | text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm( 288 | dim=1, keepdim=True 289 | ) 290 | 291 | if text_conditioned: 292 | text_encodings_shuffled = text_encodings[rolled_idx] 293 | else: 294 | text_encodings_shuffled = None 295 | 296 | text_cond_shuffled = dict( 297 | text_embed=text_embed_shuffled, text_encodings=text_encodings_shuffled 298 | ) 299 | 300 | # prepare the text embedding 301 | text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True) 302 | 303 | # prepare image embeddings 304 | test_image_embeddings = test_image_embeddings / test_image_embeddings.norm( 305 | dim=1, keepdim=True 306 | ) 307 | 308 | # predict on the unshuffled text embeddings 309 | predicted_image_embeddings = trainer.p_sample_loop( 310 | test_image_embeddings.shape, 311 | text_cond, 312 | timesteps=timesteps, 313 | ) 314 | 315 | predicted_image_embeddings = ( 316 | predicted_image_embeddings 317 | / predicted_image_embeddings.norm(dim=1, keepdim=True) 318 | ) 319 | 320 | # predict on the shuffled embeddings 321 | predicted_unrelated_embeddings = trainer.p_sample_loop( 322 | 
test_image_embeddings.shape, 323 | text_cond_shuffled, 324 | timesteps=timesteps, 325 | ) 326 | 327 | predicted_unrelated_embeddings = ( 328 | predicted_unrelated_embeddings 329 | / predicted_unrelated_embeddings.norm(dim=1, keepdim=True) 330 | ) 331 | 332 | # calculate similarities 333 | orig_sim = pad_gather_reduce( 334 | trainer, cos(text_embed, test_image_embeddings), method="mean" 335 | ) 336 | pred_sim = pad_gather_reduce( 337 | trainer, cos(text_embed, predicted_image_embeddings), method="mean" 338 | ) 339 | unrel_sim = pad_gather_reduce( 340 | trainer, cos(text_embed, predicted_unrelated_embeddings), method="mean" 341 | ) 342 | pred_img_sim = pad_gather_reduce( 343 | trainer, 344 | cos(test_image_embeddings, predicted_image_embeddings), 345 | method="mean", 346 | ) 347 | 348 | stats = { 349 | f"{tracker_folder}/baseline similarity [steps={timesteps}]": orig_sim, 350 | f"{tracker_folder}/similarity with text [steps={timesteps}]": pred_sim, 351 | f"{tracker_folder}/similarity with original image [steps={timesteps}]": pred_img_sim, 352 | f"{tracker_folder}/similarity with unrelated caption [steps={timesteps}]": unrel_sim, 353 | f"{tracker_folder}/difference from baseline similarity [steps={timesteps}]": pred_sim 354 | - orig_sim, 355 | } 356 | 357 | tracker.log(stats, step=trainer.step.item() + 1) 358 | 359 | 360 | def eval_model( 361 | trainer: DiffusionPriorTrainer, 362 | dataloader: DataLoader, 363 | text_conditioned: bool, 364 | split: str, 365 | tracker: Tracker, 366 | use_ema: bool, 367 | report_cosine: bool, 368 | report_loss: bool, 369 | timesteps: List[int], 370 | loss_type: str = None, 371 | ): 372 | """ 373 | Run evaluation on a model and track metrics 374 | 375 | returns: loss if requested 376 | """ 377 | trainer.eval() 378 | 379 | use_ema = "ema" if use_ema else "online" 380 | tracker_folder = f"metrics/{use_ema}-{split}" 381 | 382 | # detemine if valid timesteps are passed 383 | 384 | min_timesteps = trainer.accelerator.unwrap_model( 385 | trainer.diffusion_prior 386 | ).sample_timesteps 387 | max_timesteps = trainer.accelerator.unwrap_model( 388 | trainer.diffusion_prior 389 | ).noise_scheduler.num_timesteps 390 | 391 | assert all_between( 392 | timesteps, lower_bound=min_timesteps, upper_bound=max_timesteps 393 | ), f"all timesteps values must be between {min_timesteps} and {max_timesteps}: got {timesteps}" 394 | 395 | # measure cosine metrics across various eta and timesteps 396 | 397 | if report_cosine: 398 | for timestep in timesteps: 399 | report_cosine_sims( 400 | trainer, 401 | dataloader=dataloader, 402 | text_conditioned=text_conditioned, 403 | tracker=tracker, 404 | split=split, 405 | timesteps=timestep, 406 | tracker_folder=tracker_folder, 407 | ) 408 | 409 | # measure loss on a seperate split of data 410 | 411 | if report_loss: 412 | loss = report_validation_loss( 413 | trainer=trainer, 414 | dataloader=dataloader, 415 | text_conditioned=text_conditioned, 416 | use_ema=use_ema, 417 | tracker=tracker, 418 | split=split, 419 | tracker_folder=tracker_folder, 420 | loss_type=loss_type, 421 | ) 422 | 423 | return loss 424 | 425 | 426 | # training script 427 | 428 | 429 | def train( 430 | trainer: DiffusionPriorTrainer, 431 | tracker: Tracker, 432 | train_loader: DataLoader, 433 | eval_loader: DataLoader, 434 | test_loader: DataLoader, 435 | config: DiffusionPriorTrainConfig, 436 | ): 437 | # init timers 438 | save_timer = Timer() # when to save 439 | samples_timer = Timer() # samples/sec 440 | validation_profiler = Timer() # how long is validation taking 441 | 
validation_countdown = Timer() # when to perform evalutation 442 | 443 | # keep track of best validation loss 444 | 445 | best_validation_loss = config.train.best_validation_loss 446 | samples_seen = config.train.num_samples_seen 447 | 448 | # do training 449 | 450 | start_epoch = config.train.current_epoch 451 | 452 | for epoch in range(start_epoch, config.train.epochs): 453 | # if we finished out an old epoch, reset the distribution to be a full epoch 454 | tracker.log({"tracking/epoch": epoch}, step=trainer.step.item()) 455 | 456 | if train_loader.dataset.get_start() > 0 and epoch == start_epoch+1: 457 | if trainer.accelerator.is_main_process: 458 | click.secho(f"Finished resumed epoch...resetting dataloader.") 459 | train_loader.dataset.set_start(0) 460 | 461 | for img, txt in train_loader: 462 | # setup things every step 463 | 464 | trainer.train() 465 | current_step = trainer.step.item() 466 | samples_timer.reset() 467 | 468 | # place data on device 469 | 470 | img = img.to(trainer.device) 471 | txt = txt.to(trainer.device) 472 | 473 | # pass to model 474 | 475 | loss = trainer(text=txt, image_embed=img) 476 | 477 | # perform backprop & apply EMA updates 478 | 479 | trainer.update() 480 | 481 | # gather info about training step 482 | 483 | all_loss = pad_gather_reduce(trainer, loss, method="mean") 484 | num_samples = pad_gather_reduce(trainer, len(txt), method="sum") 485 | samples_per_sec = num_samples / samples_timer.elapsed() 486 | samples_seen += num_samples 487 | ema_decay = trainer.ema_diffusion_prior.get_current_decay() 488 | 489 | # log 490 | 491 | tracker.log( 492 | { 493 | "tracking/samples-sec": samples_per_sec, 494 | "tracking/samples-seen": samples_seen, 495 | "tracking/ema-decay": ema_decay, 496 | f"tracking/training-{config.prior.loss_type}": all_loss, 497 | }, 498 | step=current_step, 499 | ) 500 | 501 | # Metric Tracking @ Timed Intervals 502 | 503 | eval_delta = pad_gather_reduce( 504 | trainer, validation_countdown.elapsed(), method="min" 505 | ) 506 | 507 | if eval_delta != None and eval_delta > config.data.eval_every_seconds: 508 | # begin timing how long this takes 509 | 510 | validation_profiler.reset() 511 | 512 | # package kwargs for evaluation 513 | 514 | eval_kwargs = { 515 | "trainer": trainer, 516 | "tracker": tracker, 517 | "text_conditioned": config.prior.condition_on_text_encodings, 518 | "timesteps": config.train.eval_timesteps, 519 | } 520 | 521 | # ONLINE MODEL : COSINE : LOSS : VALIDATION SPLIT 522 | 523 | eval_model( 524 | dataloader=eval_loader, 525 | loss_type=config.prior.loss_type, 526 | split="validation", 527 | use_ema=False, 528 | report_cosine=False, 529 | report_loss=True, 530 | **eval_kwargs, 531 | ) 532 | 533 | # EMA MODEL : COSINE : LOSS : VALIDATION DATA 534 | 535 | ema_val_loss = eval_model( 536 | dataloader=eval_loader, 537 | loss_type=config.prior.loss_type, 538 | split="validation", 539 | use_ema=True, 540 | report_cosine=True, 541 | report_loss=True, 542 | **eval_kwargs, 543 | ) 544 | 545 | tracker.log( 546 | { 547 | "tracking/validation length (minutes)": validation_profiler.elapsed() 548 | / 60 549 | } 550 | ) 551 | 552 | # check if the ema validation is the lowest seen yet 553 | 554 | if ema_val_loss < best_validation_loss: 555 | best_validation_loss = ema_val_loss 556 | 557 | # go save the model as best 558 | 559 | save_trainer( 560 | trainer=trainer, 561 | tracker=tracker, 562 | is_best=True, 563 | is_latest=False, 564 | samples_seen=samples_seen, 565 | epoch=epoch, 566 | best_validation_loss=best_validation_loss, 567 | ) 568 
| 569 | # reset timer for validaiton 570 | 571 | validation_countdown.reset() 572 | 573 | elif eval_delta is None: 574 | click.secho( 575 | f"Error occured reading the eval time on rank: {trainer.device}", 576 | fg="yellow", 577 | ) 578 | 579 | # save as latest model on schedule 580 | 581 | save_delta = pad_gather_reduce(trainer, save_timer.elapsed(), method="min") 582 | 583 | if save_delta != None and save_delta >= config.train.save_every_seconds: 584 | save_trainer( 585 | trainer=trainer, 586 | tracker=tracker, 587 | is_best=False, 588 | is_latest=True, 589 | samples_seen=samples_seen, 590 | epoch=epoch, 591 | best_validation_loss=best_validation_loss, 592 | ) 593 | 594 | save_timer.reset() 595 | 596 | elif save_delta is None: 597 | click.secho( 598 | f"Error occured reading the save time on rank: {trainer.device}", 599 | fg="yellow", 600 | ) 601 | 602 | # evaluate on test data 603 | 604 | if trainer.accelerator.is_main_process: 605 | click.secho(f"Starting Test", fg="red") 606 | 607 | # save one last time as latest before beginning validation 608 | 609 | save_trainer( 610 | tracker=tracker, 611 | trainer=trainer, 612 | is_best=False, 613 | is_latest=True, 614 | samples_seen=samples_seen, 615 | epoch=epoch, 616 | best_validation_loss=best_validation_loss, 617 | ) 618 | 619 | test_loss = eval_model( 620 | trainer=trainer, 621 | dataloader=test_loader, 622 | text_conditioned=config.prior.condition_on_text_encodings, 623 | split="test", 624 | tracker=tracker, 625 | use_ema=True, 626 | report_cosine=False, 627 | report_loss=True, 628 | timesteps=config.train.eval_timesteps, 629 | loss_type=config.prior.loss_type, 630 | ) 631 | 632 | if test_loss < best_validation_loss: 633 | best_validation_loss = test_loss 634 | 635 | # go save the model as best 636 | 637 | save_trainer( 638 | trainer=trainer, 639 | tracker=tracker, 640 | is_best=True, 641 | is_latest=False, 642 | samples_seen=samples_seen, 643 | epoch=epoch, 644 | best_validation_loss=test_loss, 645 | ) 646 | 647 | 648 | def initialize_training(config_file, accelerator): 649 | """ 650 | Parse the configuration file, and prepare everything necessary for training 651 | """ 652 | # load the configuration file 653 | if accelerator.is_main_process: 654 | click.secho(f"Loading configuration from {config_file}", fg="green") 655 | 656 | config = TrainDiffusionPriorConfig.from_json_path(config_file) 657 | 658 | # seed 659 | 660 | set_seed(config.train.random_seed) 661 | 662 | # get a device 663 | 664 | device = accelerator.device 665 | 666 | # make the trainer (will automatically distribute if possible & configured) 667 | 668 | trainer: DiffusionPriorTrainer = make_model( 669 | config.prior, config.train, device, accelerator 670 | ).to(device) 671 | 672 | # create a tracker 673 | 674 | tracker = create_tracker( 675 | accelerator, config, config_file, dummy=accelerator.process_index != 0 676 | ) 677 | 678 | # reload from chcekpoint 679 | 680 | if tracker.can_recall: 681 | current_epoch, best_validation_loss, samples_seen = recall_trainer( 682 | tracker=tracker, trainer=trainer 683 | ) 684 | 685 | # display best values 686 | if trainer.accelerator.is_main_process: 687 | click.secho(f"Current Epoch: {current_epoch} | Best Val Loss: {best_validation_loss} | Samples Seen: {samples_seen}", fg="yellow") 688 | 689 | # update config to reflect recalled values 690 | config.train.num_samples_seen = samples_seen 691 | config.train.current_epoch = current_epoch 692 | config.train.best_validation_loss = best_validation_loss 693 | 694 | # fetch and prepare data 
695 | 696 | if trainer.accelerator.is_main_process: 697 | click.secho("Grabbing data...", fg="blue", blink=True) 698 | 699 | trainer.accelerator.wait_for_everyone() 700 | img_reader = get_reader( 701 | text_conditioned=trainer.text_conditioned, 702 | img_url=config.data.image_url, 703 | meta_url=config.data.meta_url, 704 | ) 705 | 706 | # calculate start point within epoch 707 | 708 | trainer.accelerator.wait_for_everyone() 709 | 710 | train_loader, eval_loader, test_loader = make_splits( 711 | text_conditioned=trainer.text_conditioned, 712 | batch_size=config.data.batch_size, 713 | num_data_points=config.data.num_data_points, 714 | train_split=config.data.splits.train, 715 | eval_split=config.data.splits.val, 716 | image_reader=img_reader, 717 | rank=accelerator.state.process_index, 718 | world_size=accelerator.state.num_processes, 719 | start=0, 720 | ) 721 | 722 | # update the start point to finish out the epoch on a resumed run 723 | 724 | if tracker.can_recall: 725 | samples_seen = config.train.num_samples_seen 726 | length = ( 727 | config.data.num_data_points 728 | if samples_seen <= img_reader.count 729 | else img_reader.count 730 | ) 731 | scaled_samples = length * config.train.current_epoch 732 | start_point = ( 733 | scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen 734 | ) 735 | 736 | if trainer.accelerator.is_main_process: 737 | click.secho(f"Resuming at sample: {start_point}", fg="yellow") 738 | 739 | train_loader.dataset.set_start(start_point) 740 | 741 | # start training 742 | 743 | if trainer.accelerator.is_main_process: 744 | click.secho( 745 | f"Beginning Prior Training : Distributed={accelerator.state.distributed_type != accelerate_dataclasses.DistributedType.NO}", 746 | fg="yellow", 747 | ) 748 | 749 | train( 750 | trainer=trainer, 751 | tracker=tracker, 752 | train_loader=train_loader, 753 | eval_loader=eval_loader, 754 | test_loader=test_loader, 755 | config=config, 756 | ) 757 | 758 | 759 | @click.command() 760 | @click.option("--config_file", default="configs/train_prior_config.example.json") 761 | def main(config_file): 762 | # start HFA 763 | accelerator = Accelerator() 764 | 765 | # setup training 766 | initialize_training(config_file, accelerator) 767 | 768 | 769 | if __name__ == "__main__": 770 | main() 771 | --------------------------------------------------------------------------------