├── .github
│   ├── FUNDING.yml
│   └── workflows
│       ├── ci.yml
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── configs
│   ├── README.md
│   ├── train_decoder_config.example.json
│   ├── train_decoder_config.test.json
│   └── train_prior_config.example.json
├── dalle2.png
├── dalle2_pytorch
│   ├── __init__.py
│   ├── cli.py
│   ├── dalle2_pytorch.py
│   ├── data
│   │   └── bpe_simple_vocab_16e6.txt
│   ├── dataloaders
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── decoder_loader.py
│   │   ├── prior_loader.py
│   │   └── simple_image_only_dataloader.py
│   ├── optimizer.py
│   ├── tokenizer.py
│   ├── trackers.py
│   ├── train_configs.py
│   ├── trainer.py
│   ├── utils.py
│   ├── version.py
│   ├── vqgan_vae.py
│   └── vqgan_vae_trainer.py
├── prior.md
├── samples
│   └── oxford.png
├── setup.py
├── test_data
│   ├── 0.tar
│   ├── 1.tar
│   ├── 2.tar
│   ├── 3.tar
│   ├── 4.tar
│   ├── 5.tar
│   ├── 6.tar
│   ├── 7.tar
│   ├── 8.tar
│   └── 9.tar
├── train_decoder.py
└── train_diffusion_prior.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: [nousr, Veldrovive, lucidrains]
2 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: Continuous integration
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | branches:
9 | - main
10 |
11 | jobs:
12 | tests:
13 | runs-on: ubuntu-latest
14 | strategy:
15 | matrix:
16 | python-version: [3.8]
17 |
18 | steps:
19 | - uses: actions/checkout@v2
20 | - name: Set up Python ${{ matrix.python-version }}
21 | uses: actions/setup-python@v2
22 | with:
23 | python-version: ${{ matrix.python-version }}
24 | - name: Install
25 | run: |
26 | python3 -m venv .env
27 | source .env/bin/activate
28 | make install
29 | - name: Tests
30 | run: |
31 | source .env/bin/activate
32 | make test
33 |
34 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # default experiment tracker data
2 | .tracker-data/
3 |
4 | # Configuration Files
5 | configs/*
6 | !configs/*.example
7 | !configs/*_defaults.py
8 | !configs/README.md
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | pip-wheel-metadata/
33 | share/python-wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .nox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | *.py,cover
60 | .hypothesis/
61 | .pytest_cache/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104 | __pypackages__/
105 |
106 | # Celery stuff
107 | celerybeat-schedule
108 | celerybeat.pid
109 |
110 | # SageMath parsed files
111 | *.sage.py
112 |
113 | # Environments
114 | .env
115 | .venv
116 | env/
117 | venv/
118 | ENV/
119 | env.bak/
120 | venv.bak/
121 |
122 | # Spyder project settings
123 | .spyderproject
124 | .spyproject
125 |
126 | # Rope project settings
127 | .ropeproject
128 |
129 | # mkdocs documentation
130 | /site
131 |
132 | # mypy
133 | .mypy_cache/
134 | .dmypy.json
135 | dmypy.json
136 |
137 | # Pyre type checker
138 | .pyre/
139 | .tracker_data
140 | *.pth
141 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include dalle2_pytorch *.txt
2 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | install:
2 | pip install -U pip
3 | pip install -e .
4 |
5 | test:
6 | CUDA_VISIBLE_DEVICES= python train_decoder.py --config_file configs/train_decoder_config.test.json
7 |
--------------------------------------------------------------------------------
/configs/README.md:
--------------------------------------------------------------------------------
1 | ## DALLE2 Training Configurations
2 |
3 | For more complex configuration, we provide the option of using a configuration file instead of command line arguments.
4 |
5 | ### Decoder Trainer
6 |
7 | The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
8 |
9 | **Unet:**
10 |
11 | This is the config for a single unet. Unets are nested under the decoder config as an array under the `unets` key.
12 |
13 | | Option | Required | Default | Description |
14 | | ------ | -------- | ------- | ----------- |
15 | | `dim` | Yes | N/A | The starting channels of the unet. |
16 | | `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. |
17 | | `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. |
18 |
19 | Any parameter from the `Unet` constructor can also be given here.
20 |
21 | **Decoder:**
22 |
23 | Defines the configuration options for the decoder model. The unets defined above will automatically be inserted; a rough sketch of the resulting constructor calls is shown below the table.
24 | | Option | Required | Default | Description |
25 | | ------ | -------- | ------- | ----------- |
26 | | `unets` | Yes | N/A | A list of unets, using the configuration above |
27 | | `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
28 | | `image_size` | Yes | N/A | Not used. Can be any number. |
29 | | `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
30 | | `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
31 | | `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
32 | | `learned_variance` | No | `True` | Whether to learn the variance. |
33 | | `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. |
34 |
35 | Any parameter from the `Decoder` constructor can also be given here.
36 |
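For reference, here is a rough sketch of the constructor calls that the unet and decoder configs above describe. This is an editorial illustration only; the actual wiring is done by the training script and `train_configs.py`, which are not shown in this section.

```python
from dalle2_pytorch import Unet, Decoder

# a single unet config, mirroring the `unets` entry above
unet = Unet(
    dim = 128,                # the starting channels of the unet
    image_embed_dim = 768,    # the dimension of the image embeddings
    dim_mults = (1, 2, 4, 8)  # the growth factors of the channels
)

# the decoder config wraps the unet(s) defined above
decoder = Decoder(
    unet = (unet,),
    image_sizes = (64,),      # one output resolution per unet
    timesteps = 1000,
    learned_variance = True
)
```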
37 | **Data:**
38 |
39 | Settings for creation of the dataloaders.
40 | | Option | Required | Default | Description |
41 | | ------ | -------- | ------- | ----------- |
42 | | `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
43 | | `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
44 | | `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
45 | | `num_workers` | No | `4` | The number of workers used in the dataloader. |
46 | | `batch_size` | No | `64` | The batch size. |
47 | | `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
48 | | `end_shard` | No | `9999999` | Defines the end of the shard range the dataset will recall. |
49 | | `shard_width` | No | `6` | Defines the width of one webdataset shard number[^2]. |
50 | | `index_width` | No | `4` | Defines the width of the index of a file inside a shard[^3]. |
51 | | `splits` | No | `{ "train": 0.75, "val": 0.15, "test": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. |
52 | | `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. |
53 | | `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the datasets making the epoch length infinite if a limit is not set. Cannot be enabled if `shuffle_train` is enabled. |
54 | | `preprocessing` | No | `{ "ToTensor": True }` | Defines preprocessing applied to images from the datasets. |
55 |
56 | [^1]: If your shard files have the paths `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example `pipe:s3cmd get s3://bucket/path/{}.tar -`.
57 |
58 | [^2]: This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename `00104.tar`, your shard width is 5.
59 |
60 | [^3]: Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard, and 4 refer to the index of the file in the webdataset (the shard is `00104` and the index is `5945`). The `index_width` in this case is 4. See the sketch below for how a key splits into shard and index.
61 |
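To make `shard_width` and `index_width` concrete, here is a minimal sketch of how a sample key splits into its shard and index. It mirrors the slicing used in `decoder_loader.py`; the helper name `split_key` is purely illustrative.

```python
def split_key(key: str, index_width: int):
    """Split a webdataset sample key such as '001045945' into (shard, index).

    The last `index_width` characters are the index; everything before them
    is the shard number.
    """
    return key[:-index_width], key[-index_width:]

# with shard_width = 5 and index_width = 4, the file 001045945.jpg
# belongs to shard '00104' at index '5945'
print(split_key("001045945", index_width = 4))  # ('00104', '5945')
```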
62 | **Train:**
63 |
64 | Settings for controlling the training hyperparameters.
65 | | Option | Required | Default | Description |
66 | | ------ | -------- | ------- | ----------- |
67 | | `epochs` | No | `20` | The number of epochs in the training run. |
68 | | `lr` | No | `1e-4` | The learning rate. |
69 | | `wd` | No | `0.01` | The weight decay. |
70 | | `max_grad_norm`| No | `0.5` | The grad norm clipping. |
71 | | `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
72 | | `cond_scale` | No | `1.0` | Conditioning scale to use for sampling. Can also be an array of values, one for each unet. |
73 | | `device` | No | `cuda:0` | The device to train on. |
74 | | `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
75 | | `validation_samples` | No | `None` | The number of samples to use for validation. None means the entire validation set. |
76 | | `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
77 | | `ema_beta` | No | `0.99` | The ema coefficient. |
78 | | `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |
79 |
80 | **Evaluate:**
81 |
82 | Defines which evaluation metrics will be used to test the model.
83 | Each metric can be enabled by setting its configuration. The configuration keys for each metric are passed to the linked torchmetrics constructors; a rough sketch follows the table below.
84 | | Option | Required | Default | Description |
85 | | ------ | -------- | ------- | ----------- |
86 | | `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |
87 | | `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
88 | | `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
89 | | `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
90 | | `LPIPS` | No | `None` | Setting to an object enables the [Learned Perceptual Image Patch Similarity](https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html) metric. |
91 |
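As a rough illustration of how these configuration objects are consumed (a sketch only; the exact construction happens in the evaluation code, which is not shown in this section), each metric's config is passed as keyword arguments to the linked torchmetrics constructor:

```python
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance

# "FID": {"feature": 64} roughly becomes
fid = FrechetInceptionDistance(feature = 64)

# "IS": {"feature": 64, "splits": 10} roughly becomes
inception_score = InceptionScore(feature = 64, splits = 10)

# "KID": {"feature": 64, "subset_size": 10} roughly becomes
kid = KernelInceptionDistance(feature = 64, subset_size = 10)
```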
92 | **Tracker:**
93 |
94 | Selects how the experiment will be tracked.
95 | | Option | Required | Default | Description |
96 | | ------ | -------- | ------- | ----------- |
97 | | `data_path` | No | `./.tracker-data` | The path to the folder where temporary tracker data will be saved. |
98 | | `overwrite_data_path` | No | `False` | If true, the data path will be overwritten. Otherwise, you need to delete it yourself. |
99 | | `log` | Yes | N/A | Logging configuration. |
100 | | `load` | No | `None` | Checkpoint loading configuration. |
101 | | `save` | Yes | N/A | Checkpoint/Model saving configuration. |
102 | Tracking is split up into three sections:
103 | * Log: Where to save run metadata and image output. Options are `console` or `wandb`.
104 | * Load: Where to load a checkpoint from. Options are `local`, `url`, or `wandb`.
105 | * Save: Where to save a checkpoint to. Options are `local`, `huggingface`, or `wandb`.
106 |
107 | **Logging:**
108 |
109 | All loggers have the following keys:
110 | | Option | Required | Default | Description |
111 | | ------ | -------- | ------- | ----------- |
112 | | `log_type` | Yes | N/A | The type of logger class to use. |
113 | | `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using manually input parameters. |
114 | | `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. |
115 |
116 | If using `console` there is no further configuration than setting `log_type` to `console`.
117 | | Option | Required | Default | Description |
118 | | ------ | -------- | ------- | ----------- |
119 | | `log_type` | Yes | N/A | Must be `console`. |
120 |
121 | If using `wandb`
122 | | Option | Required | Default | Description |
123 | | ------ | -------- | ------- | ----------- |
124 | | `log_type` | Yes | N/A | Must be `wandb`. |
125 | | `wandb_entity` | Yes | N/A | The wandb entity to log to. |
126 | | `wandb_project` | Yes | N/A | The wandb project to save the run to. |
127 | | `wandb_run_name` | No | `None` | The wandb run name. |
128 | | `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
129 |
130 | **Loading:**
131 |
132 | All loaders have the following keys:
133 | | Option | Required | Default | Description |
134 | | ------ | -------- | ------- | ----------- |
135 | | `load_from` | Yes | N/A | The type of loader class to use. |
136 | | `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. |
137 |
138 | If using `local`
139 | | Option | Required | Default | Description |
140 | | ------ | -------- | ------- | ----------- |
141 | | `load_from` | Yes | N/A | Must be `local`. |
142 | | `file_path` | Yes | N/A | The path to the checkpoint file. |
143 |
144 | If using `url`
145 | | Option | Required | Default | Description |
146 | | ------ | -------- | ------- | ----------- |
147 | | `load_from` | Yes | N/A | Must be `url`. |
148 | | `url` | Yes | N/A | The url of the checkpoint file. |
149 |
150 | If using `wandb`
151 | | Option | Required | Default | Description |
152 | | ------ | -------- | ------- | ----------- |
153 | | `load_from` | Yes | N/A | Must be `wandb`. |
154 | | `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the run that is being resumed. |
155 | | `wandb_file_path` | Yes | N/A | The path to the checkpoint file in the W&B file system. |
156 |
157 | **Saving:**
158 | Unlike `log` and `load`, `save` may be an array of options so that you can save to different locations in a run.
159 |
160 | All save locations have these configuration options
161 | | Option | Required | Default | Description |
162 | | ------ | -------- | ------- | ----------- |
163 | | `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
164 | | `save_latest_to` | No | `None` | Sets the relative path to save the latest model to. |
165 | | `save_best_to` | No | `None` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
166 | | `save_meta_to` | No | `None` | The path to save metadata files in. This includes the config files used to start the training. |
167 | | `save_type` | No | `checkpoint` | The type of save. `checkpoint` saves a checkpoint, `model` saves a model without any fluff (Saves with ema if ema is enabled). |
168 |
169 | If using `local`
170 | | Option | Required | Default | Description |
171 | | ------ | -------- | ------- | ----------- |
172 | | `save_to` | Yes | N/A | Must be `local`. |
173 |
174 | If using `huggingface`
175 | | Option | Required | Default | Description |
176 | | ------ | -------- | ------- | ----------- |
177 | | `save_to` | Yes | N/A | Must be `huggingface`. |
178 | | `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
179 | | `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
180 |
181 | If using `wandb`
182 | | Option | Required | Default | Description |
183 | | ------ | -------- | ------- | ----------- |
184 | | `save_to` | Yes | N/A | Must be `wandb`. |
185 | | `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the current run. You will almost always want this to be `None`. |
186 |
--------------------------------------------------------------------------------
/configs/train_decoder_config.example.json:
--------------------------------------------------------------------------------
1 | {
2 | "decoder": {
3 | "unets": [
4 | {
5 | "dim": 128,
6 | "image_embed_dim": 768,
7 | "cond_dim": 64,
8 | "channels": 3,
9 | "dim_mults": [1, 2, 4, 8],
10 | "attn_dim_head": 32,
11 | "attn_heads": 16
12 | }
13 | ],
14 | "image_sizes": [64],
15 | "channels": 3,
16 | "timesteps": 1000,
17 | "loss_type": "l2",
18 | "beta_schedule": ["cosine"],
19 | "learned_variance": true
20 | },
21 | "data": {
22 | "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
23 | "img_embeddings_url": "s3://bucket/img_embeddings/path/",
24 | "num_workers": 4,
25 | "batch_size": 64,
26 | "start_shard": 0,
27 | "end_shard": 9999999,
28 | "shard_width": 6,
29 | "index_width": 4,
30 | "splits": {
31 | "train": 0.75,
32 | "val": 0.15,
33 | "test": 0.1
34 | },
35 | "shuffle_train": true,
36 | "resample_train": false,
37 | "preprocessing": {
38 | "RandomResizedCrop": {
39 | "size": [128, 128],
40 | "scale": [0.75, 1.0],
41 | "ratio": [1.0, 1.0]
42 | },
43 | "ToTensor": true
44 | }
45 | },
46 | "train": {
47 | "epochs": 20,
48 | "lr": 1e-4,
49 | "wd": 0.01,
50 | "max_grad_norm": 0.5,
51 | "save_every_n_samples": 100000,
52 | "n_sample_images": 6,
53 | "device": "cuda:0",
54 | "epoch_samples": null,
55 | "validation_samples": null,
56 | "use_ema": true,
57 | "ema_beta": 0.99,
58 | "amp": false,
59 | "unet_training_mask": [true]
60 | },
61 | "evaluate": {
62 | "n_evaluation_samples": 1000,
63 | "FID": {
64 | "feature": 64
65 | },
66 | "IS": {
67 | "feature": 64,
68 | "splits": 10
69 | },
70 | "KID": {
71 | "feature": 64,
72 | "subset_size": 10
73 | },
74 | "LPIPS": {
75 | "net_type": "vgg",
76 | "reduction": "mean"
77 | }
78 | },
79 | "tracker": {
80 | "overwrite_data_path": true,
81 |
82 | "log": {
83 | "log_type": "wandb",
84 |
85 | "wandb_entity": "your_wandb",
86 | "wandb_project": "your_project",
87 |
88 | "verbose": true
89 | },
90 |
91 | "load": {
92 | "load_from": null
93 | },
94 |
95 | "save": [{
96 | "save_to": "wandb",
97 | "save_latest_to": "latest.pth"
98 | }, {
99 | "save_to": "huggingface",
100 | "huggingface_repo": "Veldrovive/test_model",
101 |
102 | "save_latest_to": "path/to/model_dir/latest.pth",
103 | "save_best_to": "path/to/model_dir/best.pth",
104 | "save_meta_to": "path/to/directory/for/assorted/files",
105 |
106 | "save_type": "model"
107 | }]
108 | }
109 | }
110 |
--------------------------------------------------------------------------------
/configs/train_decoder_config.test.json:
--------------------------------------------------------------------------------
1 | {
2 | "decoder": {
3 | "unets": [
4 | {
5 | "dim": 16,
6 | "image_embed_dim": 768,
7 | "cond_dim": 16,
8 | "channels": 3,
9 | "dim_mults": [1, 2, 4, 8],
10 | "attn_dim_head": 16,
11 | "attn_heads": 4,
12 | "self_attn": [false, true, true, true]
13 | }
14 | ],
15 | "clip": {
16 | "make": "openai",
17 | "model": "ViT-L/14"
18 | },
19 |
20 | "timesteps": 10,
21 | "image_sizes": [64],
22 | "channels": 3,
23 | "loss_type": "l2",
24 | "beta_schedule": ["cosine"],
25 | "learned_variance": true
26 | },
27 | "data": {
28 | "webdataset_base_url": "test_data/{}.tar",
29 | "num_workers": 4,
30 | "batch_size": 4,
31 | "start_shard": 0,
32 | "end_shard": 9,
33 | "shard_width": 1,
34 | "index_width": 1,
35 | "splits": {
36 | "train": 0.75,
37 | "val": 0.15,
38 | "test": 0.1
39 | },
40 | "shuffle_train": false,
41 | "resample_train": true,
42 | "preprocessing": {
43 | "RandomResizedCrop": {
44 | "size": [224, 224],
45 | "scale": [0.75, 1.0],
46 | "ratio": [1.0, 1.0]
47 | },
48 | "ToTensor": true
49 | }
50 | },
51 | "train": {
52 | "epochs": 1,
53 | "lr": 1e-16,
54 | "wd": 0.01,
55 | "max_grad_norm": 0.5,
56 | "save_every_n_samples": 100,
57 | "n_sample_images": 1,
58 | "device": "cpu",
59 | "epoch_samples": 50,
60 | "validation_samples": 5,
61 | "use_ema": true,
62 | "ema_beta": 0.99,
63 | "amp": false,
64 | "unet_training_mask": [true]
65 | },
66 | "evaluate": {
67 | "n_evaluation_samples": 2,
68 | "FID": {
69 | "feature": 64
70 | },
71 | "IS": {
72 | "feature": 64,
73 | "splits": 10
74 | },
75 | "KID": {
76 | "feature": 64,
77 | "subset_size": 2
78 | },
79 | "LPIPS": {
80 | "net_type": "vgg",
81 | "reduction": "mean"
82 | }
83 | },
84 | "tracker": {
85 | "overwrite_data_path": true,
86 |
87 | "log": {
88 | "log_type": "console"
89 | },
90 |
91 | "load": {
92 | "load_from": null
93 | },
94 |
95 | "save": [{
96 | "save_to": "local",
97 | "save_latest_to": "latest.pth"
98 | }]
99 | }
100 | }
101 |
--------------------------------------------------------------------------------
/configs/train_prior_config.example.json:
--------------------------------------------------------------------------------
1 | {
2 | "prior": {
3 | "clip": {
4 | "make": "openai",
5 | "model": "ViT-L/14"
6 | },
7 | "net": {
8 | "dim": 768,
9 | "depth": 12,
10 | "num_timesteps": 1000,
11 | "max_text_len": 77,
12 | "num_time_embeds": 1,
13 | "num_image_embeds": 1,
14 | "num_text_embeds": 1,
15 | "dim_head": 64,
16 | "heads": 12,
17 | "ff_mult": 4,
18 | "norm_out": true,
19 | "attn_dropout": 0.05,
20 | "ff_dropout": 0.05,
21 | "final_proj": true,
22 | "normformer": true,
23 | "rotary_emb": true
24 | },
25 | "image_embed_dim": 768,
26 | "image_size": 224,
27 | "image_channels": 3,
28 | "timesteps": 1000,
29 | "sample_timesteps": 64,
30 | "cond_drop_prob": 0.1,
31 | "loss_type": "l2",
32 | "predict_x_start": true,
33 | "beta_schedule": "cosine",
34 | "condition_on_text_encodings": true
35 | },
36 | "data": {
37 | "batch_size": 128,
38 | "num_data_points": 100000,
39 | "eval_every_seconds": 1600,
40 | "image_url": "",
41 | "meta_url": "",
42 | "splits": {
43 | "train": 0.8,
44 | "val": 0.1,
45 | "test": 0.1
46 | }
47 | },
48 | "train": {
49 | "epochs": 5,
50 | "lr": 1.1e-4,
51 | "wd": 6.02e-2,
52 | "max_grad_norm": 0.5,
53 | "use_ema": true,
54 | "ema_beta": 0.9999,
55 | "ema_update_after_step": 50,
56 | "warmup_steps": 50,
57 | "amp": false,
58 | "save_every_seconds": 3600,
59 | "eval_timesteps": [64, 1000],
60 | "random_seed": 84513
61 | },
62 | "tracker": {
63 | "data_path": ".prior",
64 | "overwrite_data_path": true,
65 | "log": {
66 | "log_type": "wandb",
67 | "wandb_entity": "",
68 | "wandb_project": "prior_debugging",
69 | "wandb_resume": false,
70 | "verbose": true
71 | },
72 | "save": [
73 | {
74 | "save_to": "local",
75 | "save_type": "checkpoint",
76 | "save_latest_to": ".prior/latest_checkpoint.pth",
77 | "save_best_to": ".prior/best_checkpoint.pth"
78 | }
79 | ]
80 | }
81 | }
82 |
--------------------------------------------------------------------------------
/dalle2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/dalle2.png
--------------------------------------------------------------------------------
/dalle2_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from dalle2_pytorch.version import __version__
2 | from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
3 | from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter
4 | from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
5 |
6 | from dalle2_pytorch.vqgan_vae import VQGanVAE
7 | from x_clip import CLIP
8 |
--------------------------------------------------------------------------------
/dalle2_pytorch/cli.py:
--------------------------------------------------------------------------------
1 | import click
2 | import torch
3 | import torchvision.transforms as T
4 | from functools import reduce
5 | from pathlib import Path
6 |
7 | from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
8 |
9 | def safeget(dictionary, keys, default = None):
10 | return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
11 |
12 | def simple_slugify(text, max_length = 255):
13 | return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]
14 |
15 | def get_pkg_version():
16 | from pkg_resources import get_distribution
17 | return get_distribution('dalle2_pytorch').version
18 |
19 | def main():
20 | pass
21 |
22 | @click.command()
23 | @click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
24 | @click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
25 | @click.argument('text')
26 | def dream(
27 | model,
28 | cond_scale,
29 | text
30 | ):
31 | model_path = Path(model)
32 | full_model_path = str(model_path.resolve())
33 | assert model_path.exists(), f'model not found at {full_model_path}'
34 | loaded = torch.load(str(model_path))
35 |
36 | version = safeget(loaded, 'version')
37 | print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')
38 |
39 | prior_init_params = safeget(loaded, 'init_params.prior')
40 | decoder_init_params = safeget(loaded, 'init_params.decoder')
41 | model_params = safeget(loaded, 'model_params')
42 |
43 | prior = DiffusionPrior(**prior_init_params)
44 | decoder = Decoder(**decoder_init_params)
45 |
46 | dalle2 = DALLE2(prior, decoder)
47 | dalle2.load_state_dict(model_params)
48 |
49 | image = dalle2(text, cond_scale = cond_scale)
50 |
51 | pil_image = T.ToPILImage()(image)
52 | return pil_image.save(f'./{simple_slugify(text)}.png')
53 |
--------------------------------------------------------------------------------
/dalle2_pytorch/dataloaders/README.md:
--------------------------------------------------------------------------------
1 | ## Dataloaders
2 | In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
3 |
4 | ### Decoder: Image Embedding Dataset
5 | When training the decoder (and upsamplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by an `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
6 |
7 | Generating a dataset of this type:
8 | 1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
9 | 2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
10 | 3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
11 |
12 | Usage:
13 | ```python
14 | from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
15 |
16 | # Create a dataloader directly.
17 | dataloader = create_image_embedding_dataloader(
18 | tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
19 | embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
20 | num_workers=4,
21 | batch_size=32,
22 | shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
23 | shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
24 | shuffle_shards=True, # Shuffle the order the shards are read in
25 | resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
26 | )
27 | for img, emb in dataloader:
28 | print(img.shape) # torch.Size([32, 3, 256, 256])
29 | print(emb.shape) # torch.Size([32, 512])
30 | # Train decoder only as shown above
31 |
32 | # Or create a dataset without a loader so you can configure it manually
33 | dataset = ImageEmbeddingDataset(
34 | urls="/path/or/url/to/webdataset/{0000..9999}.tar",
35 | embedding_folder_url="path/or/url/to/embeddings/folder",
36 | shard_width=4,
37 | shuffle_shards=True,
38 | resample=False
39 | )
40 | ```
41 |
42 | ### Diffusion Prior: Prior Embedding Dataset
43 | When training the prior it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.
44 |
45 | To utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()` which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects for your training run.
46 |
47 | If you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters.
48 |
49 | Usage:
50 | ```python
51 | from dalle2_pytorch.dataloaders import get_reader, make_splits
52 |
53 | # grab embeddings from some specified location
54 | IMG_URL = "data/img_emb/"
55 | META_URL = "data/meta/"
56 |
57 | reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)
58 |
59 | # some config for training
60 | TRAIN_ARGS = {
61 | "world_size": 3,
62 | "text_conditioned": True,
63 | "start": 0,
64 | "num_data_points": 10000,
65 | "batch_size": 2,
66 | "train_split": 0.5,
67 | "eval_split": 0.25,
68 | "image_reader": reader,
69 | }
70 |
71 | # specifying a rank will handle allocation internally
72 | rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)
73 | rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)
74 | rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)
75 | ```
76 |
--------------------------------------------------------------------------------
/dalle2_pytorch/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
2 | from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset
3 |
--------------------------------------------------------------------------------
/dalle2_pytorch/dataloaders/decoder_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import webdataset as wds
3 | import torch
4 | from torch.utils.data import DataLoader
5 | import numpy as np
6 | import fsspec
7 | import shutil
8 |
9 | def get_shard(filename):
10 | """
11 | Filenames with shards in them have a consistent structure that we can take advantage of
12 | Standard structure: path/to/file/prefix_string_00001.ext
13 | """
14 | try:
15 | return filename.split("_")[-1].split(".")[0]
16 | except ValueError:
17 | raise RuntimeError(f"Could not find shard for filename {filename}")
18 |
19 | def get_example_file(fs, path, file_format):
20 | """
21 | Given a file system and a file extension, return the example file
22 | """
23 | return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
24 |
25 | def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):
26 |     """Given a datum of {"__key__": str, "__url__": str, ...} adds the corresponding embedding and yields"""
27 | previous_tar_url = None
28 | current_embeddings = None
29 | # Get a reference to an abstract file system where the embeddings are stored
30 | embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
31 | example_embedding_file = get_example_file(embeddings_fs, embeddings_path, "npy")
32 | example_embedding_shard = get_shard(example_embedding_file)
33 | emb_shard_width = len(example_embedding_shard)
34 | # Easier to get the basename without the shard once than search through for the correct file every time
35 | embedding_file_basename = '_'.join(example_embedding_file.split("_")[:-1]) + "_"
36 |
37 | def load_corresponding_embeds(tar_url):
38 | """Finds and reads the npy files that contains embeddings for the given webdataset tar"""
39 | shard = int(tar_url.split("/")[-1].split(".")[0])
40 | embedding_url = embedding_file_basename + str(shard).zfill(emb_shard_width) + '.npy'
41 | with embeddings_fs.open(embedding_url) as f:
42 | data = np.load(f)
43 | return torch.from_numpy(data)
44 |
45 | for sample in samples:
46 | try:
47 | tar_url = sample["__url__"]
48 | key = sample["__key__"]
49 | if tar_url != previous_tar_url:
50 | # If the tar changed, we need to download new embeddings
51 | # This means if we shuffle before inserting it will load many more files than we expect and be very inefficient.
52 | previous_tar_url = tar_url
53 | current_embeddings = load_corresponding_embeds(tar_url)
54 |
55 | embedding_index = int(key[-index_width:])
56 | embedding = current_embeddings[embedding_index]
57 |             # We need to check whether this embedding is all zeros. If it is, no valid embedding exists for this sample and we raise an error
58 | if torch.count_nonzero(embedding) == 0:
59 | raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
60 | sample[sample_key] = embedding
61 | yield sample
62 | except Exception as exn: # From wds implementation
63 | if handler(exn):
64 | continue
65 | else:
66 | break
67 | insert_embedding = wds.filters.pipelinefilter(embedding_inserter)
68 |
69 | def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception):
70 |     """Finds whether there is a corresponding embedding for the tarfile at { url: [URL] }"""
71 | embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
72 | embedding_files = embeddings_fs.ls(embeddings_path)
73 | get_embedding_shard = lambda embedding_file: int(embedding_file.split("_")[-1].split(".")[0])
74 | embedding_shards = set([get_embedding_shard(filename) for filename in embedding_files]) # Sets have O(1) check for member
75 |
76 | get_tar_shard = lambda tar_file: int(tar_file.split("/")[-1].split(".")[0])
77 | for tarfile in tarfiles:
78 | try:
79 | webdataset_shard = get_tar_shard(tarfile["url"])
80 | # If this shard has an associated embeddings file, we pass it through. Otherwise we iterate until we do have one
81 | if webdataset_shard in embedding_shards:
82 | yield tarfile
83 | except Exception as exn: # From wds implementation
84 | if handler(exn):
85 | continue
86 | else:
87 | break
88 | skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
89 |
90 | def join_embeddings(samples, handler=wds.handlers.reraise_exception):
91 | """
92 | Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb }
93 | either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist
94 | """
95 | for sample in samples:
96 | try:
97 | sample['emb'] = {}
98 | if 'text_emb' in sample:
99 | sample['emb']['text'] = sample['text_emb']
100 | if 'img_emb' in sample:
101 | sample['emb']['img'] = sample['img_emb']
102 | yield sample
103 | except Exception as exn: # From wds implementation
104 | if handler(exn):
105 | continue
106 | else:
107 | break
108 |
109 | def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):
110 | """
111 | Requires that both the image and embedding are present in the sample
112 | This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
113 | """
114 | for sample in samples:
115 | try:
116 | for key in required_keys:
117 | assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}"
118 | yield sample
119 | except Exception as exn: # From wds implementation
120 | if handler(exn):
121 | continue
122 | else:
123 | break
124 | key_verifier = wds.filters.pipelinefilter(verify_keys)
125 |
126 | class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
127 | """
128 |     A fluid interface wrapper for DataPipeline that returns image-embedding pairs
129 | Reads embeddings as npy files from the webdataset if they exist. If embedding_folder_url is set, they will be inserted in from the alternate source.
130 | """
131 |
132 | def __init__(
133 | self,
134 | urls,
135 | img_embedding_folder_url=None,
136 | text_embedding_folder_url=None,
137 | index_width=None,
138 | img_preproc=None,
139 | extra_keys=[],
140 | handler=wds.handlers.reraise_exception,
141 | resample=False,
142 | shuffle_shards=True
143 | ):
144 | """
145 | Modeled directly off of the WebDataset constructor
146 |
147 | :param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
148 |     :param img_embedding_folder_url: Required if the webdataset does not contain image embeddings. A url pointing to the npy files of the embeddings (the same applies to text_embedding_folder_url for text embeddings). Should have the same number of shards as the webdataset.
149 | Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
150 | :param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
151 |         For example, if a file in webdataset shard 3 is named 0003039.jpg, the shard takes up the first 4 digits and the index_width is 3 (the last 3 digits are the index).
152 | :param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to torch tensor.
153 | :param handler: A webdataset handler.
154 | :param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
155 |     :param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
156 |
157 |
158 | """
159 | super().__init__()
160 | keys = ["jpg", "emb"] + extra_keys
161 | # if img_embedding_folder_url is not None:
162 | # keys.append("img_emb")
163 | # if text_embedding_folder_url is not None:
164 | # keys.append("text_emb")
165 | # keys.extend(extra_keys)
166 | self.key_map = {key: i for i, key in enumerate(keys)}
167 | self.resampling = resample
168 | self.img_preproc = img_preproc
169 | # If s3, check if s3fs is installed and s3cmd is installed and check if the data is piped instead of straight up
170 | if (isinstance(urls, str) and "s3:" in urls) or (isinstance(urls, list) and any(["s3:" in url for url in urls])):
171 | # Then this has an s3 link for the webdataset and we need extra packages
172 | if shutil.which("s3cmd") is None:
173 | raise RuntimeError("s3cmd is required for s3 webdataset")
174 | if (img_embedding_folder_url is not None and "s3:" in img_embedding_folder_url) or (text_embedding_folder_url is not None and "s3:" in text_embedding_folder_url):
175 | # Then the embeddings are being loaded from s3 and fsspec requires s3fs
176 | try:
177 | import s3fs
178 | except ImportError:
179 | raise RuntimeError("s3fs is required to load embeddings from s3")
180 | # Add the shardList and randomize or resample if requested
181 | if resample:
182 | assert not shuffle_shards, "Cannot both resample and shuffle"
183 | self.append(wds.ResampledShards(urls))
184 | else:
185 | self.append(wds.SimpleShardList(urls))
186 | if shuffle_shards:
187 | self.append(wds.filters.shuffle(1000))
188 |
189 | if img_embedding_folder_url is not None:
190 | # There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
191 | self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler))
192 | if text_embedding_folder_url is not None:
193 | self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler))
194 |
195 | self.append(wds.tarfile_to_samples(handler=handler))
196 | self.append(wds.decode("pilrgb", handler=handler))
197 | if img_embedding_folder_url is not None:
198 |             # Then we are loading image embeddings from a remote source
199 | assert index_width is not None, "Reading embeddings separately requires index width length to be given"
200 | self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler))
201 | if text_embedding_folder_url is not None:
202 |             # Then we are loading text embeddings from a remote source
203 | assert index_width is not None, "Reading embeddings separately requires index width length to be given"
204 | self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler))
205 | self.append(join_embeddings)
206 | self.append(key_verifier(required_keys=keys, handler=handler))
207 | # Apply preprocessing
208 | self.append(wds.map(self.preproc))
209 | self.append(wds.to_tuple(*keys))
210 |
211 | def preproc(self, sample):
212 | """Applies the preprocessing for images"""
213 | if self.img_preproc is not None:
214 | sample["jpg"] = self.img_preproc(sample["jpg"])
215 | return sample
216 |
217 | def create_image_embedding_dataloader(
218 | tar_url,
219 | num_workers,
220 | batch_size,
221 | img_embeddings_url=None,
222 | text_embeddings_url=None,
223 | index_width=None,
224 | shuffle_num = None,
225 | shuffle_shards = True,
226 | resample_shards = False,
227 | img_preproc=None,
228 | extra_keys=[],
229 | handler=wds.handlers.reraise_exception#warn_and_continue
230 | ):
231 | """
232 |     Convenience function to create an image embedding dataset and dataloader in one line
233 |
234 | :param tar_url: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
235 | :param num_workers: The number of workers to use for the dataloader
236 | :param batch_size: The batch size to use for the dataloader
237 |     :param img_embeddings_url: Required if the webdataset does not contain image embeddings. A url pointing to the npy files of the embeddings (the same applies to text_embeddings_url for text embeddings). Should have the same number of shards as the webdataset.
238 | Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
239 | :param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
240 |         For example, if a file in webdataset shard 3 is named 0003039.jpg, the shard takes up the first 4 digits and the index_width is 3 (the last 3 digits are the index).
241 | :param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.
242 | :param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
243 | :param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
244 | :param handler: A webdataset handler.
245 | """
246 | ds = ImageEmbeddingDataset(
247 | tar_url,
248 | img_embedding_folder_url=img_embeddings_url,
249 | text_embedding_folder_url=text_embeddings_url,
250 | index_width=index_width,
251 | shuffle_shards=shuffle_shards,
252 | resample=resample_shards,
253 | extra_keys=extra_keys,
254 | img_preproc=img_preproc,
255 | handler=handler
256 | )
257 | if shuffle_num is not None and shuffle_num > 0:
258 | ds.shuffle(1000)
259 | return DataLoader(
260 | ds,
261 | num_workers=num_workers,
262 | batch_size=batch_size,
263 | prefetch_factor=2, # This might be good to have high so the next npy file is prefetched
264 | pin_memory=True,
265 | shuffle=False
266 | )
267 |
--------------------------------------------------------------------------------
/dalle2_pytorch/dataloaders/prior_loader.py:
--------------------------------------------------------------------------------
1 | from math import ceil
2 | from clip import tokenize
3 | from embedding_reader import EmbeddingReader
4 | from torch import from_numpy
5 | from torch.utils.data import IterableDataset, DataLoader
6 |
7 |
8 | class PriorEmbeddingDataset(IterableDataset):
9 | """
10 | PriorEmbeddingDataset is a wrapper of EmbeddingReader.
11 |
12 | It enables one to simplify the logic necessary to yield samples from
13 | the different EmbeddingReader configurations available.
14 | """
15 |
16 | def __init__(
17 | self,
18 | text_conditioned: bool,
19 | batch_size: int,
20 | start: int,
21 | stop: int,
22 | image_reader,
23 | text_reader: EmbeddingReader = None,
24 | ) -> None:
25 | super(PriorEmbeddingDataset).__init__()
26 |
27 | self.text_conditioned = text_conditioned
28 |
29 | if not self.text_conditioned:
30 | self.text_reader = text_reader
31 |
32 | self.image_reader = image_reader
33 | self.start = start
34 | self.stop = stop
35 | self.batch_size = batch_size
36 |
37 | def __len__(self):
38 | return self.stop - self.start
39 |
40 | def __iter__(self):
41 | # D.R.Y loader args
42 | loader_args = dict(
43 | batch_size=self.batch_size,
44 | start=self.start,
45 | end=self.stop,
46 | show_progress=False,
47 | )
48 |
49 | # if the data requested is text conditioned, only load images
50 | if self.text_conditioned:
51 | self.loader = self.image_reader(**loader_args)
52 | # otherwise, include text embeddings and bypass metadata
53 | else:
54 | self.loader = zip(
55 | self.image_reader(**loader_args), self.text_reader(**loader_args)
56 | )
57 |
58 | # return the data loader in its formatted state
59 | return self
60 |
61 | def __next__(self):
62 | try:
63 | return self.get_sample()
64 | except StopIteration:
65 | raise StopIteration
66 |
67 | def __str__(self):
68 |         return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
69 |
70 | def set_start(self, start):
71 | """
72 | Adjust the starting point within the reader, useful for resuming an epoch
73 | """
74 | self.start = start
75 |
76 | def get_start(self):
77 | return self.start
78 |
79 | def get_sample(self):
80 | """
81 |         pre-process data from either reader into a common format
82 | """
83 | if self.text_conditioned:
84 | image_embedding, caption = next(self.loader)
85 |
86 | image_embedding = from_numpy(image_embedding)
87 | tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True)
88 |
89 | return image_embedding, tokenized_caption
90 |
91 | else:
92 | (image_embedding, _), (text_embedding, _) = next(self.loader)
93 |
94 | image_embedding = from_numpy(image_embedding)
95 | text_embedding = from_numpy(text_embedding)
96 |
97 | return image_embedding, text_embedding
98 |
99 |
100 | # helper functions
101 |
102 |
103 | def distribute_to_rank(start, stop, rank, world_size):
104 | """
105 | Distribute data to each rank given the world size.
106 |
107 | Return:
108 | - New start and stop points for this rank.
109 | """
110 | num_samples = int(stop - start)
111 |
112 | per_rank = int(ceil((num_samples) / float(world_size)))
113 |
114 | assert (
115 | per_rank > 0
116 | ), f"Number of samples per rank must be larger than 0, (found: {per_rank})"
117 |
118 | rank_start = start + rank * per_rank
119 |
120 | rank_stop = min(rank_start + per_rank, stop)
121 |
122 | new_length = rank_stop - rank_start
123 |
124 | assert (
125 | new_length > 0
126 | ), "Calculated start and stop points result in a length of zero for this rank."
127 |
128 | return rank_start, rank_stop
129 |
130 |
131 | def get_reader(
132 | text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
133 | ):
134 | """
135 | Create an EmbeddingReader object from the specified URLs
136 |
137 | get_reader() will always expect a url to image embeddings.
138 |
139 | If text-conditioned, it will also expect a meta_url for the captions.
140 | Otherwise, it will need txt_url for the matching text embeddings.
141 |
142 | Returns an image_reader object if text-conditioned.
143 | Otherwise it returns both an image_reader and a text_reader
144 | """
145 |
146 | assert img_url is not None, "Must supply a image url"
147 |
148 | if text_conditioned:
149 | assert meta_url is not None, "Must supply meta url if text-conditioned"
150 |
151 | image_reader = EmbeddingReader(
152 | embeddings_folder=img_url,
153 | file_format="parquet_npy",
154 | # will assume the caption column exists and is the only one requested
155 | meta_columns=["caption"],
156 | metadata_folder=meta_url,
157 | )
158 |
159 | return image_reader
160 |
161 | # otherwise we will require text embeddings as well and return two readers
162 | assert (
163 | txt_url is not None
164 | ), "Must supply text embedding url if not text-conditioning"
165 |
166 | image_reader = EmbeddingReader(img_url, file_format="npy")
167 | text_reader = EmbeddingReader(txt_url, file_format="npy")
168 |
169 | return image_reader, text_reader
170 |
171 |
172 | def make_splits(
173 | text_conditioned: bool,
174 | batch_size: int,
175 | num_data_points: int,
176 | train_split: float,
177 | eval_split: float,
178 | image_reader: EmbeddingReader,
179 | text_reader: EmbeddingReader = None,
180 | start=0,
181 | rank=0,
182 | world_size=1,
183 | ):
184 | """
185 | Split an embedding reader object as needed.
186 |
187 | NOTE: make_splits() will infer the test set size from your train and eval.
188 |
189 | Input:
190 | - text_conditioned: whether to prepare text-conditioned training data
191 | - batch_size: the batch size for a single gpu
192 | - num_data_points: the total number of data points you wish to train on
193 | - train_split: the percentage of data you wish to train on
194 | - eval_split: the percentage of data you wish to validate on
195 | - image_reader: the image_reader you wish to split
196 | - text_reader: the text_reader you want to split (if !text_conditioned)
197 | - start: the starting point within your dataset
198 | - rank: the rank of your worker
199 | - world_size: the total world size of your distributed training run
200 |
201 | Returns:
202 | - PyTorch Dataloaders that yield tuples of (img, txt) data.
203 | """
204 |
205 | assert start < image_reader.count, "start position cannot exceed reader count."
206 |
207 | # verify that the num_data_points does not exceed the max points
208 | if num_data_points > (image_reader.count - start):
209 | print(
210 | "Specified count is larger than what's available...defaulting to reader's count."
211 | )
212 | num_data_points = image_reader.count
213 |
214 | # compute split points
215 | train_set_size = int(train_split * num_data_points)
216 | eval_set_size = int(eval_split * num_data_points)
217 | eval_start = train_set_size
218 | eval_stop = int(eval_start + eval_set_size)
219 |
220 | assert (
221 | train_split + eval_split
222 | ) < 1.0, "Specified train and eval split is too large to infer a test split."
223 |
224 | # distribute to rank
225 | rank_train_start, rank_train_stop = distribute_to_rank(
226 | start, train_set_size, rank, world_size
227 | )
228 | rank_eval_start, rank_eval_stop = distribute_to_rank(
229 | train_set_size, eval_stop, rank, world_size
230 | )
231 | rank_test_start, rank_test_stop = distribute_to_rank(
232 | eval_stop, num_data_points, rank, world_size
233 | )
234 |
235 | # wrap up splits into a dict
236 | train_split_args = dict(
237 | start=rank_train_start, stop=rank_train_stop, batch_size=batch_size
238 | )
239 | eval_split_args = dict(
240 | start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size
241 | )
242 | test_split_args = dict(
243 | start=rank_test_start, stop=rank_test_stop, batch_size=batch_size
244 | )
245 |
246 | if text_conditioned:
247 | # add the text-conditioned args to a unified dict
248 | reader_args = dict(
249 | text_conditioned=text_conditioned,
250 | image_reader=image_reader,
251 | )
252 |
253 | train_split_args = dict(**reader_args, **train_split_args)
254 | eval_split_args = dict(**reader_args, **eval_split_args)
255 | test_split_args = dict(**reader_args, **test_split_args)
256 |
257 | train = PriorEmbeddingDataset(**train_split_args)
258 | val = PriorEmbeddingDataset(**eval_split_args)
259 | test = PriorEmbeddingDataset(**test_split_args)
260 |
261 | else:
262 | # add the non-conditioned args to a unified dict
263 | reader_args = dict(
264 | text_conditioned=text_conditioned,
265 | image_reader=image_reader,
266 | text_reader=text_reader,
267 | )
268 |
269 | train_split_args = dict(**reader_args, **train_split_args)
270 | eval_split_args = dict(**reader_args, **eval_split_args)
271 | test_split_args = dict(**reader_args, **test_split_args)
272 |
273 | train = PriorEmbeddingDataset(**train_split_args)
274 | val = PriorEmbeddingDataset(**eval_split_args)
275 | test = PriorEmbeddingDataset(**test_split_args)
276 |
277 |     # true batch size is specified in the PriorEmbeddingDataset
278 | train_loader = DataLoader(train, batch_size=None)
279 | eval_loader = DataLoader(val, batch_size=None)
280 | test_loader = DataLoader(test, batch_size=None)
281 |
282 | return train_loader, eval_loader, test_loader
283 |
--------------------------------------------------------------------------------
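For orientation, a minimal usage sketch of make_splits from the prior loader above; the embedding URLs, data-point count, and split fractions are placeholders rather than recommended values:

from embedding_reader import EmbeddingReader
from dalle2_pytorch.dataloaders.prior_loader import make_splits

# non-text-conditioned case: separate image and text embedding readers
image_reader = EmbeddingReader("s3://bucket/image-embeddings", file_format="npy")
text_reader = EmbeddingReader("s3://bucket/text-embeddings", file_format="npy")

train_loader, eval_loader, test_loader = make_splits(
    text_conditioned=False,
    batch_size=64,
    num_data_points=1_000_000,
    train_split=0.8,
    eval_split=0.1,          # the remaining 0.1 is inferred as the test split
    image_reader=image_reader,
    text_reader=text_reader,
)
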
/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | from torch.utils import data
5 | from torchvision import transforms, utils
6 |
7 | from PIL import Image
8 |
9 | # helper functions
10 |
11 | def cycle(dl):
12 | while True:
13 | for data in dl:
14 | yield data
15 |
16 | # dataset and dataloader
17 |
18 | class Dataset(data.Dataset):
19 | def __init__(
20 | self,
21 | folder,
22 | image_size,
23 | exts = ['jpg', 'jpeg', 'png']
24 | ):
25 | super().__init__()
26 | self.folder = folder
27 | self.image_size = image_size
28 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
29 |
30 | self.transform = transforms.Compose([
31 | transforms.Resize(image_size),
32 | transforms.RandomHorizontalFlip(),
33 | transforms.CenterCrop(image_size),
34 | transforms.ToTensor()
35 | ])
36 |
37 | def __len__(self):
38 | return len(self.paths)
39 |
40 | def __getitem__(self, index):
41 | path = self.paths[index]
42 | img = Image.open(path)
43 | return self.transform(img)
44 |
45 | def get_images_dataloader(
46 | folder,
47 | *,
48 | batch_size,
49 | image_size,
50 | shuffle = True,
51 | cycle_dl = True,
52 | pin_memory = True
53 | ):
54 | ds = Dataset(folder, image_size)
55 | dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)
56 |
57 | if cycle_dl:
58 | dl = cycle(dl)
59 | return dl
60 |
--------------------------------------------------------------------------------
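A short usage sketch of get_images_dataloader; the folder path is a placeholder and the images are assumed to share the same channel count:

from dalle2_pytorch.dataloaders.simple_image_only_dataloader import get_images_dataloader

# with cycle_dl = True (the default) the returned object is an infinite generator of batches
dataloader = get_images_dataloader('/path/to/images', batch_size = 16, image_size = 256)
images = next(dataloader)   # a (16, 3, 256, 256) tensor for RGB inputs
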
/dalle2_pytorch/optimizer.py:
--------------------------------------------------------------------------------
1 | from torch.optim import AdamW, Adam
2 |
3 | def separate_weight_decayable_params(params):
4 | wd_params, no_wd_params = [], []
5 | for param in params:
6 | param_list = no_wd_params if param.ndim < 2 else wd_params
7 | param_list.append(param)
8 | return wd_params, no_wd_params
9 |
10 | def get_optimizer(
11 | params,
12 | lr = 1e-4,
13 | wd = 1e-2,
14 | betas = (0.9, 0.99),
15 | eps = 1e-8,
16 | filter_by_requires_grad = False,
17 | group_wd_params = True,
18 | **kwargs
19 | ):
20 | if filter_by_requires_grad:
21 | params = list(filter(lambda t: t.requires_grad, params))
22 |
23 | if wd == 0:
24 | return Adam(params, lr = lr, betas = betas, eps = eps)
25 |
26 | if group_wd_params:
27 | wd_params, no_wd_params = separate_weight_decayable_params(params)
28 |
29 | params = [
30 | {'params': wd_params},
31 | {'params': no_wd_params, 'weight_decay': 0},
32 | ]
33 |
34 | return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
35 |
--------------------------------------------------------------------------------
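A small sketch of how get_optimizer groups parameters; the toy model is illustrative only:

from torch import nn
from dalle2_pytorch.optimizer import get_optimizer

model = nn.Sequential(nn.Linear(512, 512), nn.LayerNorm(512))

# with group_wd_params = True (the default), parameters with ndim < 2
# (biases, norm scales) are placed in a group with weight decay disabled
optimizer = get_optimizer(model.parameters(), lr = 1e-4, wd = 1e-2)
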
/dalle2_pytorch/tokenizer.py:
--------------------------------------------------------------------------------
1 | # taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
2 | # to give users a quick and easy start to training DALL-E without doing BPE
3 |
4 | import torch
5 |
6 | import html
7 | import os
8 | import ftfy
9 | import regex as re
10 | from functools import lru_cache
11 | from pathlib import Path
12 |
13 | from dalle2_pytorch.utils import import_or_print_error
14 |
15 | # OpenAI simple tokenizer
16 |
17 | @lru_cache()
18 | def default_bpe():
19 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt")
20 |
21 | @lru_cache()
22 | def bytes_to_unicode():
23 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
24 | cs = bs[:]
25 | n = 0
26 | for b in range(2 ** 8):
27 | if b not in bs:
28 | bs.append(b)
29 | cs.append(2 ** 8 + n)
30 | n += 1
31 | cs = [chr(n) for n in cs]
32 | return dict(zip(bs, cs))
33 |
34 | def get_pairs(word):
35 | pairs = set()
36 | prev_char = word[0]
37 | for char in word[1:]:
38 | pairs.add((prev_char, char))
39 | prev_char = char
40 | return pairs
41 |
42 | def basic_clean(text):
43 | text = ftfy.fix_text(text)
44 | text = html.unescape(html.unescape(text))
45 | return text.strip()
46 |
47 | def whitespace_clean(text):
48 | text = re.sub(r'\s+', ' ', text)
49 | text = text.strip()
50 | return text
51 |
52 | class SimpleTokenizer(object):
53 | def __init__(self, bpe_path = default_bpe()):
54 | self.byte_encoder = bytes_to_unicode()
55 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
56 | merges = Path(bpe_path).read_text(encoding='utf8').split('\n')
57 | merges = merges[1:49152 - 256 - 2 + 1]
58 | merges = [tuple(merge.split()) for merge in merges]
59 | vocab = list(bytes_to_unicode().values())
60 |         vocab = vocab + [v + '</w>' for v in vocab]
61 | for merge in merges:
62 | vocab.append(''.join(merge))
63 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
64 |
65 | self.vocab_size = 49408
66 |
67 | self.encoder = dict(zip(vocab, range(len(vocab))))
68 | self.decoder = {v: k for k, v in self.encoder.items()}
69 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
70 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
71 | self.pat = re.compile(
72 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
73 | re.IGNORECASE)
74 |
75 | def bpe(self, token):
76 | if token in self.cache:
77 | return self.cache[token]
78 |         word = tuple(token[:-1]) + (token[-1] + '</w>',)
79 | pairs = get_pairs(word)
80 |
81 | if not pairs:
82 |             return token + '</w>'
83 |
84 | while True:
85 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
86 | if bigram not in self.bpe_ranks:
87 | break
88 | first, second = bigram
89 | new_word = []
90 | i = 0
91 | while i < len(word):
92 | try:
93 | j = word.index(first, i)
94 | new_word.extend(word[i:j])
95 | i = j
96 | except:
97 | new_word.extend(word[i:])
98 | break
99 |
100 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
101 | new_word.append(first + second)
102 | i += 2
103 | else:
104 | new_word.append(word[i])
105 | i += 1
106 | new_word = tuple(new_word)
107 | word = new_word
108 | if len(word) == 1:
109 | break
110 | else:
111 | pairs = get_pairs(word)
112 | word = ' '.join(word)
113 | self.cache[token] = word
114 | return word
115 |
116 | def encode(self, text):
117 | bpe_tokens = []
118 | text = whitespace_clean(basic_clean(text)).lower()
119 | for token in re.findall(self.pat, text):
120 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
121 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
122 | return bpe_tokens
123 |
124 | def decode(self, tokens, remove_start_end = True, pad_tokens = set()):
125 | if torch.is_tensor(tokens):
126 | tokens = tokens.tolist()
127 |
128 | if remove_start_end:
129 |             tokens = [token for token in tokens if token not in (49406, 49407, 0)]
130 | text = ''.join([self.decoder[token] for token in tokens if token not in pad_tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
134 | def tokenize(self, texts, context_length = 256, truncate_text = False):
135 | if isinstance(texts, str):
136 | texts = [texts]
137 |
138 | all_tokens = [self.encode(text) for text in texts]
139 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
140 |
141 | for i, tokens in enumerate(all_tokens):
142 | if len(tokens) > context_length:
143 | if truncate_text:
144 | tokens = tokens[:context_length]
145 | else:
146 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
147 | result[i, :len(tokens)] = torch.tensor(tokens)
148 |
149 | return result
150 |
151 | tokenizer = SimpleTokenizer()
152 |
153 | # YTTM tokenizer
154 |
155 | class YttmTokenizer:
156 | def __init__(self, bpe_path = None):
157 | bpe_path = Path(bpe_path)
158 | assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
159 |
160 | self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
161 |
162 | tokenizer = self.yttm.BPE(model = str(bpe_path))
163 | self.tokenizer = tokenizer
164 | self.vocab_size = tokenizer.vocab_size()
165 |
166 | def decode(self, tokens, pad_tokens = set()):
167 | if torch.is_tensor(tokens):
168 | tokens = tokens.tolist()
169 |
170 | return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
171 |
172 | def encode(self, texts):
173 | encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
174 | return list(map(torch.tensor, encoded))
175 |
176 | def tokenize(self, texts, context_length = 256, truncate_text = False):
177 | if isinstance(texts, str):
178 | texts = [texts]
179 |
180 | all_tokens = self.encode(texts)
181 |
182 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
183 | for i, tokens in enumerate(all_tokens):
184 | if len(tokens) > context_length:
185 | if truncate_text:
186 | tokens = tokens[:context_length]
187 | else:
188 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
189 | result[i, :len(tokens)] = torch.tensor(tokens)
190 |
191 | return result
192 |
--------------------------------------------------------------------------------
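A quick sketch of the module-level SimpleTokenizer instance defined above; the captions are arbitrary examples:

from dalle2_pytorch.tokenizer import tokenizer

token_ids = tokenizer.tokenize(
    ['a corgi running in a field', 'an armchair in the shape of an avocado'],
    context_length = 256
)
print(token_ids.shape)              # torch.Size([2, 256]), zero-padded on the right
print(tokenizer.decode(token_ids[0]))
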
/dalle2_pytorch/trackers.py:
--------------------------------------------------------------------------------
1 | import urllib.request
2 | import os
3 | import json
4 | from pathlib import Path
5 | import shutil
6 | from itertools import zip_longest
7 | from typing import Any, Optional, List, Union
8 | from pydantic import BaseModel
9 |
10 | import torch
11 | from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
12 | from dalle2_pytorch.utils import import_or_print_error
13 | from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
14 | from dalle2_pytorch.version import __version__
15 | from packaging import version
16 |
17 | # constants
18 |
19 | DEFAULT_DATA_PATH = './.tracker-data'
20 |
21 | # helper functions
22 |
23 | def exists(val):
24 | return val is not None
25 |
26 | class BaseLogger:
27 | """
28 | An abstract class representing an object that can log data.
29 | Parameters:
30 | data_path (str): A file path for storing temporary data.
31 |         verbose (bool): Whether or not to always print logs to the console.
32 | """
33 | def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):
34 | self.data_path = Path(data_path)
35 | self.resume = resume
36 | self.auto_resume = auto_resume
37 | self.verbose = verbose
38 |
39 | def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
40 | """
41 | Initializes the logger.
42 | Errors if the logger is invalid.
43 |         full_config is the config file dict, while extra_config is anything else from the script that is not defined in the config file.
44 | """
45 | raise NotImplementedError
46 |
47 | def log(self, log, **kwargs) -> None:
48 | raise NotImplementedError
49 |
50 | def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
51 | raise NotImplementedError
52 |
53 | def log_file(self, file_path, **kwargs) -> None:
54 | raise NotImplementedError
55 |
56 | def log_error(self, error_string, **kwargs) -> None:
57 | raise NotImplementedError
58 |
59 | def get_resume_data(self, **kwargs) -> dict:
60 | """
61 |         Returns the data that, along with { "resume": True }, will be used to resume training.
62 | It is assumed that after init is called this data will be complete.
63 | If the logger does not have any resume functionality, it should return an empty dict.
64 | """
65 | raise NotImplementedError
66 |
67 | class ConsoleLogger(BaseLogger):
68 | def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
69 | print("Logging to console")
70 |
71 | def log(self, log, **kwargs) -> None:
72 | print(log)
73 |
74 | def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
75 | pass
76 |
77 | def log_file(self, file_path, **kwargs) -> None:
78 | pass
79 |
80 | def log_error(self, error_string, **kwargs) -> None:
81 | print(error_string)
82 |
83 | def get_resume_data(self, **kwargs) -> dict:
84 | return {}
85 |
86 | class WandbLogger(BaseLogger):
87 | """
88 | Logs to a wandb run.
89 | Parameters:
90 | data_path (str): A file path for storing temporary data.
91 | wandb_entity (str): The wandb entity to log to.
92 | wandb_project (str): The wandb project to log to.
93 | wandb_run_id (str): The wandb run id to resume.
94 | wandb_run_name (str): The wandb run name to use.
95 | """
96 | def __init__(self,
97 | data_path: str,
98 | wandb_entity: str,
99 | wandb_project: str,
100 | wandb_run_id: Optional[str] = None,
101 | wandb_run_name: Optional[str] = None,
102 | **kwargs
103 | ):
104 | super().__init__(data_path, **kwargs)
105 | self.entity = wandb_entity
106 | self.project = wandb_project
107 | self.run_id = wandb_run_id
108 | self.run_name = wandb_run_name
109 |
110 | def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
111 | assert self.entity is not None, "wandb_entity must be specified for wandb logger"
112 | assert self.project is not None, "wandb_project must be specified for wandb logger"
113 | self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
114 | os.environ["WANDB_SILENT"] = "true"
115 | # Initializes the wandb run
116 | init_object = {
117 | "entity": self.entity,
118 | "project": self.project,
119 | "config": {**full_config.dict(), **extra_config}
120 | }
121 | if self.run_name is not None:
122 | init_object['name'] = self.run_name
123 | if self.resume:
124 | assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True'
125 | if self.run_name is not None:
126 | print("You are renaming a run. I hope that is what you intended.")
127 | init_object['resume'] = 'must'
128 | init_object['id'] = self.run_id
129 |
130 | self.wandb.init(**init_object)
131 | print(f"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}")
132 |
133 | def log(self, log, **kwargs) -> None:
134 | if self.verbose:
135 | print(log)
136 | self.wandb.log(log, **kwargs)
137 |
138 | def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
139 | """
140 | Takes a tensor of images and a list of captions and logs them to wandb.
141 | """
142 | wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
143 | self.wandb.log({ image_section: wandb_images }, **kwargs)
144 |
145 | def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None:
146 | if base_path is None:
147 | # Then we take the basepath as the parent of the file_path
148 | base_path = Path(file_path).parent
149 | self.wandb.save(str(file_path), base_path = str(base_path))
150 |
151 | def log_error(self, error_string, step=None, **kwargs) -> None:
152 | if self.verbose:
153 | print(error_string)
154 | self.wandb.log({"error": error_string, **kwargs}, step=step)
155 |
156 | def get_resume_data(self, **kwargs) -> dict:
157 | # In order to resume, we need wandb_entity, wandb_project, and wandb_run_id
158 | return {
159 | "entity": self.entity,
160 | "project": self.project,
161 | "run_id": self.wandb.run.id
162 | }
163 |
164 | logger_type_map = {
165 | 'console': ConsoleLogger,
166 | 'wandb': WandbLogger,
167 | }
168 | def create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger:
169 | if logger_type == 'custom':
170 | raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.')
171 | try:
172 | logger_class = logger_type_map[logger_type]
173 | except KeyError:
174 | raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')
175 | return logger_class(data_path, **kwargs)
176 |
177 | class BaseLoader:
178 | """
179 | An abstract class representing an object that can load a model checkpoint.
180 | Parameters:
181 | data_path (str): A file path for storing temporary data.
182 | """
183 | def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):
184 | self.data_path = Path(data_path)
185 | self.only_auto_resume = only_auto_resume
186 |
187 | def init(self, logger: BaseLogger, **kwargs) -> None:
188 | raise NotImplementedError
189 |
190 |     def recall(self) -> dict:
191 | raise NotImplementedError
192 |
193 | class UrlLoader(BaseLoader):
194 | """
195 | A loader that downloads the file from a url and loads it
196 | Parameters:
197 | data_path (str): A file path for storing temporary data.
198 | url (str): The url to download the file from.
199 | """
200 | def __init__(self, data_path: str, url: str, **kwargs):
201 | super().__init__(data_path, **kwargs)
202 | self.url = url
203 |
204 | def init(self, logger: BaseLogger, **kwargs) -> None:
205 | # Makes sure the file exists to be downloaded
206 | pass # TODO: Actually implement that
207 |
208 | def recall(self) -> dict:
209 | # Download the file
210 | save_path = self.data_path / 'loaded_checkpoint.pth'
211 | urllib.request.urlretrieve(self.url, str(save_path))
212 | # Load the file
213 | return torch.load(str(save_path), map_location='cpu')
214 |
215 |
216 | class LocalLoader(BaseLoader):
217 | """
218 | A loader that loads a file from a local path
219 | Parameters:
220 | data_path (str): A file path for storing temporary data.
221 | file_path (str): The path to the file to load.
222 | """
223 | def __init__(self, data_path: str, file_path: str, **kwargs):
224 | super().__init__(data_path, **kwargs)
225 | self.file_path = Path(file_path)
226 |
227 | def init(self, logger: BaseLogger, **kwargs) -> None:
228 | # Makes sure the file exists to be loaded
229 | if not self.file_path.exists() and not self.only_auto_resume:
230 | raise FileNotFoundError(f'Model not found at {self.file_path}')
231 |
232 | def recall(self) -> dict:
233 | # Load the file
234 | return torch.load(str(self.file_path), map_location='cpu')
235 |
236 | class WandbLoader(BaseLoader):
237 | """
238 | A loader that loads a model from an existing wandb run
239 | """
240 | def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs):
241 | super().__init__(data_path, **kwargs)
242 | self.run_path = wandb_run_path
243 | self.file_path = wandb_file_path
244 |
245 | def init(self, logger: BaseLogger, **kwargs) -> None:
246 | self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
247 | # Make sure the file can be downloaded
248 | if self.wandb.run is not None and self.run_path is None:
249 | self.run_path = self.wandb.run.path
250 |         assert self.run_path is not None, 'wandb run was not found to load from. If not using the wandb logger, you must specify the `wandb_run_path`.'
251 | assert self.run_path is not None, '`wandb_run_path` must be provided for the wandb loader'
252 | assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader'
253 |
254 | os.environ["WANDB_SILENT"] = "true"
255 | pass # TODO: Actually implement that
256 |
257 | def recall(self) -> dict:
258 | file_reference = self.wandb.restore(self.file_path, run_path=self.run_path)
259 | return torch.load(file_reference.name, map_location='cpu')
260 |
261 | loader_type_map = {
262 | 'url': UrlLoader,
263 | 'local': LocalLoader,
264 | 'wandb': WandbLoader,
265 | }
266 | def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
267 | if loader_type == 'custom':
268 | raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.')
269 | try:
270 | loader_class = loader_type_map[loader_type]
271 | except KeyError:
272 | raise ValueError(f'Unknown loader type: {loader_type}. Must be one of {list(loader_type_map.keys())}')
273 | return loader_class(data_path, **kwargs)
274 |
275 | class BaseSaver:
276 | def __init__(self,
277 | data_path: str,
278 | save_latest_to: Optional[Union[str, bool]] = None,
279 | save_best_to: Optional[Union[str, bool]] = None,
280 | save_meta_to: Optional[str] = None,
281 | save_type: str = 'checkpoint',
282 | **kwargs
283 | ):
284 | self.data_path = Path(data_path)
285 | self.save_latest_to = save_latest_to
286 | self.saving_latest = save_latest_to is not None and save_latest_to is not False
287 | self.save_best_to = save_best_to
288 | self.saving_best = save_best_to is not None and save_best_to is not False
289 | self.save_meta_to = save_meta_to
290 | self.saving_meta = save_meta_to is not None
291 | self.save_type = save_type
292 | assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
293 | assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified'
294 |
295 | def init(self, logger: BaseLogger, **kwargs) -> None:
296 | raise NotImplementedError
297 |
298 | def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None:
299 | """
300 | Save a general file under save_meta_to
301 | """
302 | raise NotImplementedError
303 |
304 | class LocalSaver(BaseSaver):
305 | def __init__(self,
306 | data_path: str,
307 | **kwargs
308 | ):
309 | super().__init__(data_path, **kwargs)
310 |
311 | def init(self, logger: BaseLogger, **kwargs) -> None:
312 | # Makes sure the directory exists to be saved to
313 | print(f"Saving {self.save_type} locally")
314 | if not self.data_path.exists():
315 | self.data_path.mkdir(parents=True)
316 |
317 | def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
318 | # Copy the file to save_path
319 | save_path_file_name = Path(save_path).name
320 | # Make sure parent directory exists
321 | save_path_parent = Path(save_path).parent
322 | if not save_path_parent.exists():
323 | save_path_parent.mkdir(parents=True)
324 | print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
325 | shutil.copy(local_path, save_path)
326 |
327 | class WandbSaver(BaseSaver):
328 | def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs):
329 | super().__init__(data_path, **kwargs)
330 | self.run_path = wandb_run_path
331 |
332 | def init(self, logger: BaseLogger, **kwargs) -> None:
333 | self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
334 | os.environ["WANDB_SILENT"] = "true"
335 |         # Makes sure that the user can upload to this run
336 | if self.run_path is not None:
337 | entity, project, run_id = self.run_path.split("/")
338 | self.run = self.wandb.init(entity=entity, project=project, id=run_id)
339 | else:
340 | assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`'
341 | self.run = self.wandb.run
342 | # TODO: Now actually check if upload is possible
343 | print(f"Saving to wandb run {self.run.path}-{self.run.name}")
344 |
345 | def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
346 | # In order to log something in the correct place in wandb, we need to have the same file structure here
347 | save_path_file_name = Path(save_path).name
348 | print(f"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}")
349 | save_path = Path(self.data_path) / save_path
350 | save_path.parent.mkdir(parents=True, exist_ok=True)
351 | shutil.copy(local_path, save_path)
352 | self.run.save(str(save_path), base_path = str(self.data_path), policy='now')
353 |
354 | class HuggingfaceSaver(BaseSaver):
355 | def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs):
356 | super().__init__(data_path, **kwargs)
357 | self.huggingface_repo = huggingface_repo
358 | self.token_path = token_path
359 |
360 | def init(self, logger: BaseLogger, **kwargs):
361 | # Makes sure this user can upload to the repo
362 | self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver')
363 | try:
364 | identity = self.hub.whoami() # Errors if not logged in
365 | # Then we are logged in
366 | except:
367 | # We are not logged in. Use the token_path to set the token.
368 | if not os.path.exists(self.token_path):
369 | raise Exception("Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.")
370 | with open(self.token_path, "r") as f:
371 | token = f.read().strip()
372 | self.hub.HfApi.set_access_token(token)
373 | identity = self.hub.whoami()
374 | print(f"Saving to huggingface repo {self.huggingface_repo}")
375 |
376 | def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
377 | # Saving to huggingface is easy, we just need to upload the file with the correct name
378 | save_path_file_name = Path(save_path).name
379 | print(f"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}")
380 | self.hub.upload_file(
381 | path_or_fileobj=str(local_path),
382 | path_in_repo=str(save_path),
383 | repo_id=self.huggingface_repo
384 | )
385 |
386 | saver_type_map = {
387 | 'local': LocalSaver,
388 | 'wandb': WandbSaver,
389 | 'huggingface': HuggingfaceSaver
390 | }
391 | def create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver:
392 | if saver_type == 'custom':
393 | raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.')
394 | try:
395 | saver_class = saver_type_map[saver_type]
396 | except KeyError:
397 | raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}')
398 | return saver_class(data_path, **kwargs)
399 |
400 |
401 | class Tracker:
402 | def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
403 | self.data_path = Path(data_path)
404 | if not dummy_mode:
405 | if not overwrite_data_path:
406 | assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
407 | if not self.data_path.exists():
408 | self.data_path.mkdir(parents=True)
409 | self.logger: BaseLogger = None
410 | self.loader: Optional[BaseLoader] = None
411 | self.savers: List[BaseSaver]= []
412 | self.dummy_mode = dummy_mode
413 |
414 | def _load_auto_resume(self) -> bool:
415 | # If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run.
416 | if not self.auto_resume_path.exists():
417 | if self.logger.auto_resume:
418 | print("Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.")
419 | return False
420 |
421 | # Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time
422 | if not self.logger.auto_resume:
423 | print(f'Removing auto_resume.json because auto_resume is not enabled in the config')
424 | self.auto_resume_path.unlink()
425 | return False
426 |
427 |         # Otherwise we read the json into a dictionary that will override parts of logger.__dict__
428 | with open(self.auto_resume_path, 'r') as f:
429 | auto_resume_dict = json.load(f)
430 | # Check if the logger is of the same type as the autoresume save
431 | if auto_resume_dict["logger_type"] != self.logger.__class__.__name__:
432 | raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')
433 | # Then we are ready to override the logger with the autoresume save
434 | self.logger.__dict__["resume"] = True
435 | print(f"Updating {self.logger.__dict__} with {auto_resume_dict}")
436 | self.logger.__dict__.update(auto_resume_dict)
437 | return True
438 |
439 | def _save_auto_resume(self):
440 | # Gets the autoresume dict from the logger and adds "logger_type" to it then saves it to the auto_resume file
441 | auto_resume_dict = self.logger.get_resume_data()
442 | auto_resume_dict['logger_type'] = self.logger.__class__.__name__
443 | with open(self.auto_resume_path, 'w') as f:
444 | json.dump(auto_resume_dict, f)
445 |
446 | def init(self, full_config: BaseModel, extra_config: dict):
447 | self.auto_resume_path = self.data_path / 'auto_resume.json'
448 | # Check for resuming the run
449 | self.did_auto_resume = self._load_auto_resume()
450 | if self.did_auto_resume:
451 | print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
452 | print(f"New logger config: {self.logger.__dict__}")
453 |
454 | self.save_metadata = dict(
455 | version = version.parse(__version__)
456 | ) # Data that will be saved alongside the checkpoint or model
457 | self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps'] # These keys would cause us to error if we try to save them as metadata
458 |
459 | assert self.logger is not None, '`logger` must be set before `init` is called'
460 | if self.dummy_mode:
461 | # The only thing we need is a loader
462 | if self.loader is not None:
463 | self.loader.init(self.logger)
464 | return
465 | assert len(self.savers) > 0, '`savers` must be set before `init` is called'
466 |
467 | self.logger.init(full_config, extra_config)
468 | if self.loader is not None:
469 | self.loader.init(self.logger)
470 | for saver in self.savers:
471 | saver.init(self.logger)
472 |
473 | if self.logger.auto_resume:
474 | # Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved.
475 | self._save_auto_resume()
476 |
477 | def add_logger(self, logger: BaseLogger):
478 | self.logger = logger
479 |
480 | def add_loader(self, loader: BaseLoader):
481 | self.loader = loader
482 |
483 | def add_saver(self, saver: BaseSaver):
484 | self.savers.append(saver)
485 |
486 | def log(self, *args, **kwargs):
487 | if self.dummy_mode:
488 | return
489 | self.logger.log(*args, **kwargs)
490 |
491 | def log_images(self, *args, **kwargs):
492 | if self.dummy_mode:
493 | return
494 | self.logger.log_images(*args, **kwargs)
495 |
496 | def log_file(self, *args, **kwargs):
497 | if self.dummy_mode:
498 | return
499 | self.logger.log_file(*args, **kwargs)
500 |
501 | def save_config(self, current_config_path: str, config_name = 'config.json'):
502 | if self.dummy_mode:
503 | return
504 | # Save the config under config_name in the root folder of data_path
505 | shutil.copy(current_config_path, self.data_path / config_name)
506 | for saver in self.savers:
507 | if saver.saving_meta:
508 | remote_path = Path(saver.save_meta_to) / config_name
509 | saver.save_file(current_config_path, str(remote_path))
510 |
511 | def add_save_metadata(self, state_dict_key: str, metadata: Any):
512 | """
513 | Adds a new piece of metadata that will be saved along with the model or decoder.
514 | """
515 | self.save_metadata[state_dict_key] = metadata
516 |
517 | def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
518 | """
519 | Gets the state dict to be saved and writes it to file_path.
520 | If save_type is 'checkpoint', we save the entire trainer state dict.
521 | If save_type is 'model', we save only the model state dict.
522 | """
523 | assert save_type in ['checkpoint', 'model']
524 | if save_type == 'checkpoint':
525 | # Create a metadata dict without the blacklisted keys so we do not error when we create the state dict
526 | metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys}
527 | trainer.save(file_path, overwrite=True, **kwargs, **metadata)
528 | elif save_type == 'model':
529 | if isinstance(trainer, DiffusionPriorTrainer):
530 | prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
531 | prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior)
532 | # Remove CLIP if it is part of the model
533 | original_clip = prior.clip
534 | prior.clip = None
535 | model_state_dict = prior.state_dict()
536 | prior.clip = original_clip
537 | elif isinstance(trainer, DecoderTrainer):
538 | decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
539 | # Remove CLIP if it is part of the model
540 | original_clip = decoder.clip
541 | decoder.clip = None
542 | if trainer.use_ema:
543 | trainable_unets = decoder.unets
544 | decoder.unets = trainer.unets # Swap EMA unets in
545 | model_state_dict = decoder.state_dict()
546 | decoder.unets = trainable_unets # Swap back
547 | else:
548 | model_state_dict = decoder.state_dict()
549 | decoder.clip = original_clip
550 | else:
551 | raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
552 | state_dict = {
553 | **self.save_metadata,
554 | 'model': model_state_dict
555 | }
556 | torch.save(state_dict, file_path)
557 | return Path(file_path)
558 |
559 | def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
560 | if self.dummy_mode:
561 | return
562 | if not is_best and not is_latest:
563 | # Nothing to do
564 | return
565 | # Save the checkpoint and model to data_path
566 | checkpoint_path = self.data_path / 'checkpoint.pth'
567 | self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs)
568 | model_path = self.data_path / 'model.pth'
569 | self._save_state_dict(trainer, 'model', model_path, **kwargs)
570 | print("Saved cached models")
571 | # Call the save methods on the savers
572 | for saver in self.savers:
573 | local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path
574 | if saver.saving_latest and is_latest:
575 | latest_checkpoint_path = saver.save_latest_to.format(**kwargs)
576 | try:
577 | saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs)
578 | except Exception as e:
579 | self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
580 | print(f'Error saving checkpoint: {e}')
581 | if saver.saving_best and is_best:
582 | best_checkpoint_path = saver.save_best_to.format(**kwargs)
583 | try:
584 | saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs)
585 | except Exception as e:
586 | self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
587 | print(f'Error saving checkpoint: {e}')
588 |
589 | @property
590 | def can_recall(self):
591 | # Defines whether a recall can be performed.
592 | return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)
593 |
594 | def recall(self):
595 | if self.can_recall:
596 | return self.loader.recall()
597 | else:
598 | raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')
599 |
600 |
601 |
--------------------------------------------------------------------------------
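A minimal sketch of wiring a Tracker by hand; in the training scripts this is normally done through TrackerConfig.create (see train_configs.py below), and the paths, config class, and values here are placeholders:

from pydantic import BaseModel
from dalle2_pytorch.trackers import Tracker, create_logger, create_saver

class MinimalConfig(BaseModel):   # stand-in for the real training config
    seed: int = 0

tracker = Tracker('.tracker-data/example', overwrite_data_path = True)
tracker.add_logger(create_logger('console', tracker.data_path))
tracker.add_saver(create_saver('local', tracker.data_path, save_latest_to = 'checkpoints/latest.pth'))
tracker.init(full_config = MinimalConfig(), extra_config = {'script_arg': 'example'})

tracker.log({'loss': 0.123}, step = 1)
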
/dalle2_pytorch/train_configs.py:
--------------------------------------------------------------------------------
1 | import json
2 | from torchvision import transforms as T
3 | from pydantic import BaseModel, validator, model_validator
4 | from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
5 |
6 | from x_clip import CLIP as XCLIP
7 | from open_clip import list_pretrained
8 | from coca_pytorch import CoCa
9 |
10 | from dalle2_pytorch.dalle2_pytorch import (
11 | CoCaAdapter,
12 | OpenAIClipAdapter,
13 | OpenClipAdapter,
14 | Unet,
15 | Decoder,
16 | DiffusionPrior,
17 | DiffusionPriorNetwork,
18 | XClipAdapter
19 | )
20 | from dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver
21 |
22 | # helper functions
23 |
24 | def exists(val):
25 | return val is not None
26 |
27 | def default(val, d):
28 | return val if exists(val) else d
29 |
30 | InnerType = TypeVar('InnerType')
31 | ListOrTuple = Union[List[InnerType], Tuple[InnerType]]
32 | SingularOrIterable = Union[InnerType, ListOrTuple[InnerType]]
33 |
34 | # general pydantic classes
35 |
36 | class TrainSplitConfig(BaseModel):
37 | train: float = 0.75
38 | val: float = 0.15
39 | test: float = 0.1
40 |
41 | @model_validator(mode = 'after')
42 | def validate_all(self, m):
43 | actual_sum = sum([*dict(self).values()])
44 | if actual_sum != 1.:
45 | raise ValueError(f'{dict(self).keys()} must sum to 1.0. Found: {actual_sum}')
46 | return self
47 |
48 | class TrackerLogConfig(BaseModel):
49 | log_type: str = 'console'
50 | resume: bool = False # For logs that are saved to unique locations, resume a previous run
51 | auto_resume: bool = False # If the process crashes and restarts, resume from the run that crashed
52 | verbose: bool = False
53 |
54 | class Config:
55 |         # Each individual log type has its own arguments that will be passed through the config
56 | extra = "allow"
57 |
58 | def create(self, data_path: str):
59 | kwargs = self.dict()
60 | return create_logger(self.log_type, data_path, **kwargs)
61 |
62 |
63 | class TrackerLoadConfig(BaseModel):
64 | load_from: Optional[str] = None
65 | only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming
66 |
67 | class Config:
68 | extra = "allow"
69 |
70 | def create(self, data_path: str):
71 | kwargs = self.dict()
72 | if self.load_from is None:
73 | return None
74 | return create_loader(self.load_from, data_path, **kwargs)
75 |
76 | class TrackerSaveConfig(BaseModel):
77 | save_to: str = 'local'
78 | save_all: bool = False
79 | save_latest: bool = True
80 | save_best: bool = True
81 |
82 | class Config:
83 | extra = "allow"
84 |
85 | def create(self, data_path: str):
86 | kwargs = self.dict()
87 | return create_saver(self.save_to, data_path, **kwargs)
88 |
89 | class TrackerConfig(BaseModel):
90 | data_path: str = '.tracker_data'
91 | overwrite_data_path: bool = False
92 | log: TrackerLogConfig
93 | load: Optional[TrackerLoadConfig] = None
94 | save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
95 |
96 | def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
97 | tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path)
98 | # Add the logger
99 | tracker.add_logger(self.log.create(self.data_path))
100 | # Add the loader
101 | if self.load is not None:
102 | tracker.add_loader(self.load.create(self.data_path))
103 | # Add the saver or savers
104 | if isinstance(self.save, list):
105 | for save_config in self.save:
106 | tracker.add_saver(save_config.create(self.data_path))
107 | else:
108 | tracker.add_saver(self.save.create(self.data_path))
109 | # Initialize all the components and verify that all data is valid
110 | tracker.init(full_config, extra_config)
111 | return tracker
112 |
113 | # diffusion prior pydantic classes
114 |
115 | class AdapterConfig(BaseModel):
116 | make: str = "openai"
117 | model: str = "ViT-L/14"
118 | base_model_kwargs: Optional[Dict[str, Any]] = None
119 |
120 | def create(self):
121 | if self.make == "openai":
122 | return OpenAIClipAdapter(self.model)
123 | elif self.make == "open_clip":
124 | pretrained = dict(list_pretrained())
125 | checkpoint = pretrained[self.model]
126 | return OpenClipAdapter(name=self.model, pretrained=checkpoint)
127 | elif self.make == "x-clip":
128 | return XClipAdapter(XCLIP(**self.base_model_kwargs))
129 | elif self.make == "coca":
130 | return CoCaAdapter(CoCa(**self.base_model_kwargs))
131 | else:
132 | raise AttributeError("No adapter with that name is available.")
133 |
134 | class DiffusionPriorNetworkConfig(BaseModel):
135 | dim: int
136 | depth: int
137 | max_text_len: Optional[int] = None
138 | num_timesteps: Optional[int] = None
139 | num_time_embeds: int = 1
140 | num_image_embeds: int = 1
141 | num_text_embeds: int = 1
142 | dim_head: int = 64
143 | heads: int = 8
144 | ff_mult: int = 4
145 | norm_in: bool = False
146 | norm_out: bool = True
147 | attn_dropout: float = 0.
148 | ff_dropout: float = 0.
149 | final_proj: bool = True
150 | normformer: bool = False
151 | rotary_emb: bool = True
152 |
153 | class Config:
154 | extra = "allow"
155 |
156 | def create(self):
157 | kwargs = self.dict()
158 | return DiffusionPriorNetwork(**kwargs)
159 |
160 | class DiffusionPriorConfig(BaseModel):
161 | clip: Optional[AdapterConfig] = None
162 | net: DiffusionPriorNetworkConfig
163 | image_embed_dim: int
164 | image_size: int
165 | image_channels: int = 3
166 | timesteps: int = 1000
167 | sample_timesteps: Optional[int] = None
168 | cond_drop_prob: float = 0.
169 | loss_type: str = 'l2'
170 | predict_x_start: bool = True
171 | beta_schedule: str = 'cosine'
172 | condition_on_text_encodings: bool = True
173 |
174 | class Config:
175 | extra = "allow"
176 |
177 | def create(self):
178 | kwargs = self.dict()
179 |
180 | has_clip = exists(kwargs.pop('clip'))
181 | kwargs.pop('net')
182 |
183 | clip = None
184 | if has_clip:
185 | clip = self.clip.create()
186 |
187 | diffusion_prior_network = self.net.create()
188 | return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)
189 |
190 | class DiffusionPriorTrainConfig(BaseModel):
191 | epochs: int = 1
192 | lr: float = 1.1e-4
193 | wd: float = 6.02e-2
194 | max_grad_norm: float = 0.5
195 | use_ema: bool = True
196 | ema_beta: float = 0.99
197 | amp: bool = False
198 | warmup_steps: Optional[int] = None # number of warmup steps
199 | save_every_seconds: int = 3600 # how often to save
200 | eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
201 |     best_validation_loss: float = 1e9                # the current best validation loss observed
202 | current_epoch: int = 0 # the current epoch
203 | num_samples_seen: int = 0 # the current number of samples seen
204 | random_seed: int = 0 # manual seed for torch
205 |
206 | class DiffusionPriorDataConfig(BaseModel):
207 | image_url: str # path to embeddings folder
208 | meta_url: str # path to metadata (captions) for images
209 | splits: TrainSplitConfig # define train, validation, test splits for your dataset
210 | batch_size: int # per-gpu batch size used to train the model
211 | num_data_points: int = 25e7 # total number of datapoints to train on
212 | eval_every_seconds: int = 3600 # validation statistics will be performed this often
213 |
214 | class TrainDiffusionPriorConfig(BaseModel):
215 | prior: DiffusionPriorConfig
216 | data: DiffusionPriorDataConfig
217 | train: DiffusionPriorTrainConfig
218 | tracker: TrackerConfig
219 |
220 | @classmethod
221 | def from_json_path(cls, json_path):
222 | with open(json_path) as f:
223 | config = json.load(f)
224 | return cls(**config)
225 |
226 | # decoder pydantic classes
227 |
228 | class UnetConfig(BaseModel):
229 | dim: int
230 | dim_mults: ListOrTuple[int]
231 | image_embed_dim: Optional[int] = None
232 | text_embed_dim: Optional[int] = None
233 | cond_on_text_encodings: Optional[bool] = None
234 | cond_dim: Optional[int] = None
235 | channels: int = 3
236 | self_attn: SingularOrIterable[bool] = False
237 | attn_dim_head: int = 32
238 | attn_heads: int = 16
239 | init_cross_embed: bool = True
240 |
241 | class Config:
242 | extra = "allow"
243 |
244 | class DecoderConfig(BaseModel):
245 | unets: ListOrTuple[UnetConfig]
246 | image_size: Optional[int] = None
247 | image_sizes: ListOrTuple[int] = None
248 | clip: Optional[AdapterConfig] = None # The clip model to use if embeddings are not provided
249 | channels: int = 3
250 | timesteps: int = 1000
251 | sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
252 | loss_type: str = 'l2'
253 | beta_schedule: Optional[ListOrTuple[str]] = None # None means all cosine
254 | learned_variance: SingularOrIterable[bool] = True
255 | image_cond_drop_prob: float = 0.1
256 | text_cond_drop_prob: float = 0.5
257 |
258 | def create(self):
259 | decoder_kwargs = self.dict()
260 |
261 | unet_configs = decoder_kwargs.pop('unets')
262 | unets = [Unet(**config) for config in unet_configs]
263 |
264 | has_clip = exists(decoder_kwargs.pop('clip'))
265 | clip = None
266 | if has_clip:
267 | clip = self.clip.create()
268 |
269 | return Decoder(unets, clip=clip, **decoder_kwargs)
270 |
271 | @validator('image_sizes')
272 | def check_image_sizes(cls, image_sizes, values):
273 | if exists(values.get('image_size')) ^ exists(image_sizes):
274 | return image_sizes
275 | raise ValueError('either image_size or image_sizes is required, but not both')
276 |
277 | class Config:
278 | extra = "allow"
279 |
280 | class DecoderDataConfig(BaseModel):
281 | webdataset_base_url: str # path to a webdataset with jpg images
282 | img_embeddings_url: Optional[str] = None # path to .npy files with embeddings
283 | text_embeddings_url: Optional[str] = None # path to .npy files with embeddings
284 | num_workers: int = 4
285 | batch_size: int = 64
286 | start_shard: int = 0
287 | end_shard: int = 9999999
288 | shard_width: int = 6
289 | index_width: int = 4
290 | splits: TrainSplitConfig
291 | shuffle_train: bool = True
292 | resample_train: bool = False
293 | preprocessing: Dict[str, Any] = {'ToTensor': True}
294 |
295 | @property
296 | def img_preproc(self):
297 | def _get_transformation(transformation_name, **kwargs):
298 | if transformation_name == "RandomResizedCrop":
299 | return T.RandomResizedCrop(**kwargs)
300 | elif transformation_name == "RandomHorizontalFlip":
301 | return T.RandomHorizontalFlip()
302 | elif transformation_name == "ToTensor":
303 | return T.ToTensor()
304 |
305 | transforms = []
306 | for transform_name, transform_kwargs_or_bool in self.preprocessing.items():
307 | transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
308 | transforms.append(_get_transformation(transform_name, **transform_kwargs))
309 | return T.Compose(transforms)
310 |
311 | class DecoderTrainConfig(BaseModel):
312 | epochs: int = 20
313 | lr: SingularOrIterable[float] = 1e-4
314 | wd: SingularOrIterable[float] = 0.01
315 | warmup_steps: Optional[SingularOrIterable[int]] = None
316 | find_unused_parameters: bool = True
317 | static_graph: bool = True
318 | max_grad_norm: SingularOrIterable[float] = 0.5
319 | save_every_n_samples: int = 100000
320 | n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
321 | cond_scale: Union[float, List[float]] = 1.0
322 | device: str = 'cuda:0'
323 | epoch_samples: Optional[int] = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
324 | validation_samples: Optional[int] = None # Same as above but for validation.
325 | save_immediately: bool = False
326 | use_ema: bool = True
327 | ema_beta: float = 0.999
328 | amp: bool = False
329 | unet_training_mask: Optional[ListOrTuple[bool]] = None # If None, use all unets
330 |
331 | class DecoderEvaluateConfig(BaseModel):
332 | n_evaluation_samples: int = 1000
333 | FID: Optional[Dict[str, Any]] = None
334 | IS: Optional[Dict[str, Any]] = None
335 | KID: Optional[Dict[str, Any]] = None
336 | LPIPS: Optional[Dict[str, Any]] = None
337 |
338 | class TrainDecoderConfig(BaseModel):
339 | decoder: DecoderConfig
340 | data: DecoderDataConfig
341 | train: DecoderTrainConfig
342 | evaluate: DecoderEvaluateConfig
343 | tracker: TrackerConfig
344 | seed: int = 0
345 |
346 | @classmethod
347 | def from_json_path(cls, json_path):
348 | with open(json_path) as f:
349 | config = json.load(f)
350 | print(config)
351 | return cls(**config)
352 |
353 | @model_validator(mode = 'after')
354 | def check_has_embeddings(self, m):
355 | # Makes sure that enough information is provided to get the embeddings specified for training
356 | values = dict(self)
357 |
358 | data_config, decoder_config = values.get('data'), values.get('decoder')
359 |
360 | if not exists(data_config) or not exists(decoder_config):
361 | # Then something else errored and we should just pass through
362 | return values
363 |
364 | using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
365 | using_clip = exists(decoder_config.clip)
366 | img_emb_url = data_config.img_embeddings_url
367 | text_emb_url = data_config.text_embeddings_url
368 |
369 | if using_text_embeddings:
370 | # Then we need some way to get the embeddings
371 | assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'
372 |
373 | if using_clip:
374 | if using_text_embeddings:
375 | assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
376 | else:
377 | assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
378 |
379 | if text_emb_url:
380 | assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
381 |
382 | return m
383 |
--------------------------------------------------------------------------------
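A short sketch of how these config classes are consumed; the JSON path is a placeholder and must point at a file matching the schema above:

from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig

config = TrainDiffusionPriorConfig.from_json_path('configs/train_prior_config.example.json')

diffusion_prior = config.prior.create()   # builds the DiffusionPriorNetwork and DiffusionPrior (plus a CLIP adapter if configured)
tracker = config.tracker.create(config, extra_config = {'seed': config.train.random_seed})

Note that tracker.create initializes the logger, loader, and savers immediately, so it should only be called once the run is actually starting.
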
/dalle2_pytorch/trainer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import copy
3 | from pathlib import Path
4 | from math import ceil
5 | from functools import partial, wraps
6 | from contextlib import nullcontext
7 | from collections.abc import Iterable
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import nn
12 | from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
13 | from torch.cuda.amp import autocast, GradScaler
14 |
15 | from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
16 | from dalle2_pytorch.optimizer import get_optimizer
17 | from dalle2_pytorch.version import __version__
18 | from packaging import version
19 |
20 | import pytorch_warmup as warmup
21 |
22 | from ema_pytorch import EMA
23 |
24 | from accelerate import Accelerator, DistributedType
25 |
26 | import numpy as np
27 |
28 | # helper functions
29 |
30 | def exists(val):
31 | return val is not None
32 |
33 | def default(val, d):
34 | if exists(val):
35 | return val
36 | return d() if callable(d) else d
37 |
38 | def cast_tuple(val, length = 1):
39 | return val if isinstance(val, tuple) else ((val,) * length)
40 |
41 | def pick_and_pop(keys, d):
42 | values = list(map(lambda key: d.pop(key), keys))
43 | return dict(zip(keys, values))
44 |
45 | def group_dict_by_key(cond, d):
46 | return_val = [dict(),dict()]
47 | for key in d.keys():
48 | match = bool(cond(key))
49 | ind = int(not match)
50 | return_val[ind][key] = d[key]
51 | return (*return_val,)
52 |
53 | def string_begins_with(prefix, str):
54 | return str.startswith(prefix)
55 |
56 | def group_by_key_prefix(prefix, d):
57 | return group_dict_by_key(partial(string_begins_with, prefix), d)
58 |
59 | def groupby_prefix_and_trim(prefix, d):
60 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
61 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
62 | return kwargs_without_prefix, kwargs
63 |
64 | def num_to_groups(num, divisor):
65 | groups = num // divisor
66 | remainder = num % divisor
67 | arr = [divisor] * groups
68 | if remainder > 0:
69 | arr.append(remainder)
70 | return arr
71 |
72 | # decorators
73 |
74 | def cast_torch_tensor(fn):
75 | @wraps(fn)
76 | def inner(model, *args, **kwargs):
77 | device = kwargs.pop('_device', next(model.parameters()).device)
78 | cast_device = kwargs.pop('_cast_device', True)
79 | cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)
80 |
81 | kwargs_keys = kwargs.keys()
82 | all_args = (*args, *kwargs.values())
83 | split_kwargs_index = len(all_args) - len(kwargs_keys)
84 | all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
85 |
86 | if cast_device:
87 | all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
88 |
89 | if cast_deepspeed_precision:
90 | try:
91 | accelerator = model.accelerator
92 | if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
93 | cast_type_map = {
94 | "fp16": torch.half,
95 | "bf16": torch.bfloat16,
96 | "no": torch.float
97 | }
98 | precision_type = cast_type_map[accelerator.mixed_precision]
99 | all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
100 | except AttributeError:
101 | # Then this model doesn't have an accelerator
102 | pass
103 |
104 | args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
105 | kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
106 |
107 | out = fn(model, *args, **kwargs)
108 | return out
109 | return inner
110 |
111 | # gradient accumulation functions
112 |
113 | def split_iterable(it, split_size):
114 | accum = []
115 | for ind in range(ceil(len(it) / split_size)):
116 | start_index = ind * split_size
117 | accum.append(it[start_index: (start_index + split_size)])
118 | return accum
119 |
120 | def split(t, split_size = None):
121 | if not exists(split_size):
122 | return t
123 |
124 | if isinstance(t, torch.Tensor):
125 | return t.split(split_size, dim = 0)
126 |
127 | if isinstance(t, Iterable):
128 | return split_iterable(t, split_size)
129 |
130 |     raise TypeError('split expects a torch.Tensor or an Iterable')
131 |
132 | def find_first(cond, arr):
133 | for el in arr:
134 | if cond(el):
135 | return el
136 | return None
137 |
138 | def split_args_and_kwargs(*args, split_size = None, **kwargs):
139 | all_args = (*args, *kwargs.values())
140 | len_all_args = len(all_args)
141 | first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
142 | assert exists(first_tensor)
143 |
144 | batch_size = len(first_tensor)
145 | split_size = default(split_size, batch_size)
146 | num_chunks = ceil(batch_size / split_size)
147 |
148 | dict_len = len(kwargs)
149 | dict_keys = kwargs.keys()
150 | split_kwargs_index = len_all_args - dict_len
151 |
152 | split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
153 | chunk_sizes = tuple(map(len, split_all_args[0]))
154 |
155 | for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
156 | chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
157 | chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
158 | chunk_size_frac = chunk_size / batch_size
159 | yield chunk_size_frac, (chunked_args, chunked_kwargs)
160 |
161 | # diffusion prior trainer
162 |
163 | def prior_sample_in_chunks(fn):
164 | @wraps(fn)
165 | def inner(self, *args, max_batch_size = None, **kwargs):
166 | if not exists(max_batch_size):
167 | return fn(self, *args, **kwargs)
168 |
169 | outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
170 | return torch.cat(outputs, dim = 0)
171 | return inner
172 |
173 | class DiffusionPriorTrainer(nn.Module):
174 | def __init__(
175 | self,
176 | diffusion_prior,
177 | accelerator = None,
178 | use_ema = True,
179 | lr = 3e-4,
180 | wd = 1e-2,
181 | eps = 1e-6,
182 | max_grad_norm = None,
183 | group_wd_params = True,
184 | warmup_steps = None,
185 | cosine_decay_max_steps = None,
186 | **kwargs
187 | ):
188 | super().__init__()
189 | assert isinstance(diffusion_prior, DiffusionPrior)
190 |
191 | ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
192 | accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
193 |
194 | if not exists(accelerator):
195 | accelerator = Accelerator(**accelerator_kwargs)
196 |
197 | # assign some helpful member vars
198 |
199 | self.accelerator = accelerator
200 | self.text_conditioned = diffusion_prior.condition_on_text_encodings
201 |
202 | # setting the device
203 |
204 | self.device = accelerator.device
205 | diffusion_prior.to(self.device)
206 |
207 | # save model
208 |
209 | self.diffusion_prior = diffusion_prior
210 |
211 | # mixed precision checks
212 |
213 | if (
214 | exists(self.accelerator)
215 | and self.accelerator.distributed_type == DistributedType.DEEPSPEED
216 | and self.diffusion_prior.clip is not None
217 | ):
218 | # Then we need to make sure clip is using the correct precision or else deepspeed will error
219 | cast_type_map = {
220 | "fp16": torch.half,
221 | "bf16": torch.bfloat16,
222 | "no": torch.float
223 | }
224 | precision_type = cast_type_map[accelerator.mixed_precision]
225 | assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
226 | self.diffusion_prior.clip.to(precision_type)
227 |
228 | # optimizer stuff
229 |
230 | self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
231 |
232 | self.optimizer = get_optimizer(
233 | self.diffusion_prior.parameters(),
234 | **self.optim_kwargs,
235 | **kwargs
236 | )
237 |
238 | if exists(cosine_decay_max_steps):
239 | self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
240 | else:
241 | self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
242 |
243 | self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
244 |
245 | # distribute the model if using HFA
246 |
247 | self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler)
248 |
249 | # exponential moving average stuff
250 |
251 | self.use_ema = use_ema
252 |
253 | if self.use_ema:
254 | self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs)
255 |
256 | # gradient clipping if needed
257 |
258 | self.max_grad_norm = max_grad_norm
259 |
260 | # track steps internally
261 |
262 | self.register_buffer('step', torch.tensor([0], device = self.device))
263 |
264 | # utility
265 |
266 | def save(self, path, overwrite = True, **kwargs):
267 |
268 | # only save on the main process
269 | if self.accelerator.is_main_process:
270 | print(f"Saving checkpoint at step: {self.step.item()}")
271 | path = Path(path)
272 | assert not (path.exists() and not overwrite)
273 | path.parent.mkdir(parents = True, exist_ok = True)
274 |
275 | # FIXME: LambdaLR can't be saved due to pickling issues
276 | save_obj = dict(
277 | optimizer = self.optimizer.state_dict(),
278 | scheduler = self.scheduler.state_dict(),
279 | warmup_scheduler = self.warmup_scheduler,
280 | model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
281 | version = version.parse(__version__),
282 | step = self.step,
283 | **kwargs
284 | )
285 |
286 | if self.use_ema:
287 | save_obj = {
288 | **save_obj,
289 | 'ema': self.ema_diffusion_prior.state_dict(),
290 | 'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model specifically for easy ema-only reload
291 | }
292 |
293 | torch.save(save_obj, str(path))
294 |
295 | def load(self, path_or_state, overwrite_lr = True, strict = True):
296 | """
297 | Load a checkpoint of a diffusion prior trainer.
298 |
299 | Will load the entire trainer, including the optimizer and EMA.
300 |
301 | Params:
302 |             - path_or_state (str | dict): a path to a DiffusionPriorTrainer checkpoint file, or an already loaded checkpoint dictionary
303 |             - overwrite_lr (bool): whether or not to overwrite the stored LR with the LR specified in the new trainer
304 | - strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
305 |
306 | Returns:
307 | loaded_obj (dict): The loaded checkpoint dictionary
308 | """
309 |
310 | # all processes need to load checkpoint. no restriction here
311 | if isinstance(path_or_state, str):
312 | path = Path(path_or_state)
313 | assert path.exists()
314 | loaded_obj = torch.load(str(path), map_location=self.device)
315 |
316 | elif isinstance(path_or_state, dict):
317 | loaded_obj = path_or_state
318 |
319 | if version.parse(__version__) != loaded_obj['version']:
320 | print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
321 |
322 | # unwrap the model when loading from checkpoint
323 | self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
324 | self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
325 |
326 | self.optimizer.load_state_dict(loaded_obj['optimizer'])
327 | self.scheduler.load_state_dict(loaded_obj['scheduler'])
328 |
329 |         # set warmup step
330 | if exists(self.warmup_scheduler):
331 | self.warmup_scheduler.last_step = self.step.item()
332 |
333 | # ensure new lr is used if different from old one
334 | if overwrite_lr:
335 | new_lr = self.optim_kwargs["lr"]
336 |
337 | for group in self.optimizer.param_groups:
338 | group["lr"] = new_lr if group["lr"] > 0.0 else 0.0
339 |
340 | if self.use_ema:
341 | assert 'ema' in loaded_obj
342 | self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
343 | # below might not be necessary, but I had a suspicion that this wasn't being loaded correctly
344 | self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
345 |
346 | return loaded_obj
347 |
348 | # model functionality
349 |
350 | def update(self):
351 |
352 | if exists(self.max_grad_norm):
353 | self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
354 |
355 | self.optimizer.step()
356 | self.optimizer.zero_grad()
357 |
358 |         # accelerator will occasionally skip optimizer steps in a "dynamic loss scaling strategy"
359 | if not self.accelerator.optimizer_step_was_skipped:
360 | sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
361 | with sched_context():
362 | self.scheduler.step()
363 |
364 | if self.use_ema:
365 | self.ema_diffusion_prior.update()
366 |
367 | self.step += 1
368 |
369 | @torch.no_grad()
370 | @cast_torch_tensor
371 | @prior_sample_in_chunks
372 | def p_sample_loop(self, *args, **kwargs):
373 | model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
374 | return model.p_sample_loop(*args, **kwargs)
375 |
376 | @torch.no_grad()
377 | @cast_torch_tensor
378 | @prior_sample_in_chunks
379 | def sample(self, *args, **kwargs):
380 | model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
381 | return model.sample(*args, **kwargs)
382 |
383 | @torch.no_grad()
384 | def sample_batch_size(self, *args, **kwargs):
385 | model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
386 | return model.sample_batch_size(*args, **kwargs)
387 |
388 | @torch.no_grad()
389 | @cast_torch_tensor
390 | @prior_sample_in_chunks
391 | def embed_text(self, *args, **kwargs):
392 | return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
393 |
394 | @cast_torch_tensor
395 | def forward(
396 | self,
397 | *args,
398 | max_batch_size = None,
399 | **kwargs
400 | ):
401 | total_loss = 0.
402 |
403 | for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
404 | with self.accelerator.autocast():
405 | loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
406 | loss = loss * chunk_size_frac
407 |
408 | total_loss += loss.item()
409 |
410 | if self.training:
411 | self.accelerator.backward(loss)
412 |
413 | return total_loss
414 |
415 | # decoder trainer
416 |
417 | def decoder_sample_in_chunks(fn):
418 | @wraps(fn)
419 | def inner(self, *args, max_batch_size = None, **kwargs):
420 | if not exists(max_batch_size):
421 | return fn(self, *args, **kwargs)
422 |
423 | if self.decoder.unconditional:
424 | batch_size = kwargs.get('batch_size')
425 | batch_sizes = num_to_groups(batch_size, max_batch_size)
426 | outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
427 | else:
428 | outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
429 |
430 | return torch.cat(outputs, dim = 0)
431 | return inner
432 |
433 | class DecoderTrainer(nn.Module):
434 | def __init__(
435 | self,
436 | decoder,
437 | accelerator = None,
438 | dataloaders = None,
439 | use_ema = True,
440 | lr = 1e-4,
441 | wd = 1e-2,
442 | eps = 1e-8,
443 | warmup_steps = None,
444 | cosine_decay_max_steps = None,
445 | max_grad_norm = 0.5,
446 | amp = False,
447 | group_wd_params = True,
448 | **kwargs
449 | ):
450 | super().__init__()
451 | assert isinstance(decoder, Decoder)
452 | ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
453 |
454 | self.accelerator = default(accelerator, Accelerator)
455 |
456 | self.num_unets = len(decoder.unets)
457 |
458 | self.use_ema = use_ema
459 | self.ema_unets = nn.ModuleList([])
460 |
461 | self.amp = amp
462 |
463 | # be able to finely customize learning rate, weight decay
464 | # per unet
465 |
466 | lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
467 |
468 | assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
469 |
470 | optimizers = []
471 | schedulers = []
472 | warmup_schedulers = []
473 |
474 | for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
475 | if isinstance(unet, nn.Identity):
476 | optimizers.append(None)
477 | schedulers.append(None)
478 | warmup_schedulers.append(None)
479 | else:
480 | optimizer = get_optimizer(
481 | unet.parameters(),
482 | lr = unet_lr,
483 | wd = unet_wd,
484 | eps = unet_eps,
485 | group_wd_params = group_wd_params,
486 | **kwargs
487 | )
488 |
489 | optimizers.append(optimizer)
490 |
491 | if exists(unet_cosine_decay_max_steps):
492 | scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
493 | else:
494 | scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
495 |
496 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
497 | warmup_schedulers.append(warmup_scheduler)
498 |
499 | schedulers.append(scheduler)
500 |
501 | if self.use_ema:
502 | self.ema_unets.append(EMA(unet, **ema_kwargs))
503 |
504 | # gradient clipping if needed
505 |
506 | self.max_grad_norm = max_grad_norm
507 |
508 | self.register_buffer('steps', torch.tensor([0] * self.num_unets))
509 |
510 | if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
511 | # Then we need to make sure clip is using the correct precision or else deepspeed will error
512 | cast_type_map = {
513 | "fp16": torch.half,
514 | "bf16": torch.bfloat16,
515 | "no": torch.float
516 | }
517 |             precision_type = cast_type_map[self.accelerator.mixed_precision]
518 | assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
519 | clip = decoder.clip
520 | clip.to(precision_type)
521 |
522 | decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
523 |
524 | self.decoder = decoder
525 |
526 | # prepare dataloaders
527 |
528 | train_loader = val_loader = None
529 | if exists(dataloaders):
530 | train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
531 |
532 | self.train_loader = train_loader
533 | self.val_loader = val_loader
534 |
535 | # store optimizers
536 |
537 | for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
538 | setattr(self, f'optim{opt_ind}', optimizer)
539 |
540 | # store schedulers
541 |
542 | for sched_ind, scheduler in zip(range(len(schedulers)), schedulers):
543 | setattr(self, f'sched{sched_ind}', scheduler)
544 |
545 | # store warmup schedulers
546 |
547 | self.warmup_schedulers = warmup_schedulers
548 |
549 | def validate_and_return_unet_number(self, unet_number = None):
550 | if self.num_unets == 1:
551 | unet_number = default(unet_number, 1)
552 |
553 | assert exists(unet_number) and 1 <= unet_number <= self.num_unets
554 | return unet_number
555 |
556 | def num_steps_taken(self, unet_number = None):
557 | unet_number = self.validate_and_return_unet_number(unet_number)
558 | return self.steps[unet_number - 1].item()
559 |
560 | def save(self, path, overwrite = True, **kwargs):
561 | path = Path(path)
562 | assert not (path.exists() and not overwrite)
563 | path.parent.mkdir(parents = True, exist_ok = True)
564 |
565 | save_obj = dict(
566 | model = self.accelerator.unwrap_model(self.decoder).state_dict(),
567 | version = __version__,
568 | steps = self.steps.cpu(),
569 | **kwargs
570 | )
571 |
572 | for ind in range(0, self.num_unets):
573 | optimizer_key = f'optim{ind}'
574 | scheduler_key = f'sched{ind}'
575 |
576 | optimizer = getattr(self, optimizer_key)
577 | scheduler = getattr(self, scheduler_key)
578 |
579 | optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
580 | scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
581 |
582 | save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
583 |
584 | if self.use_ema:
585 | save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
586 |
587 | self.accelerator.save(save_obj, str(path))
588 |
589 | def load_state_dict(self, loaded_obj, only_model = False, strict = True):
590 | if version.parse(__version__) != version.parse(loaded_obj['version']):
591 | self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
592 |
593 | self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
594 | self.steps.copy_(loaded_obj['steps'])
595 |
596 | if only_model:
597 | return loaded_obj
598 |
599 | for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()):
600 |
601 | optimizer_key = f'optim{ind}'
602 | optimizer = getattr(self, optimizer_key)
603 |
604 | scheduler_key = f'sched{ind}'
605 | scheduler = getattr(self, scheduler_key)
606 |
607 | warmup_scheduler = self.warmup_schedulers[ind]
608 |
609 | if exists(optimizer):
610 | optimizer.load_state_dict(loaded_obj[optimizer_key])
611 |
612 | if exists(scheduler):
613 | scheduler.load_state_dict(loaded_obj[scheduler_key])
614 |
615 | if exists(warmup_scheduler):
616 | warmup_scheduler.last_step = last_step
617 |
618 | if self.use_ema:
619 | assert 'ema' in loaded_obj
620 | self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
621 |
622 | def load(self, path, only_model = False, strict = True):
623 | path = Path(path)
624 | assert path.exists()
625 |
626 | loaded_obj = torch.load(str(path), map_location = 'cpu')
627 |
628 | self.load_state_dict(loaded_obj, only_model = only_model, strict = strict)
629 |
630 | return loaded_obj
631 |
632 | @property
633 | def unets(self):
634 | return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
635 |
636 | def increment_step(self, unet_number):
637 | assert 1 <= unet_number <= self.num_unets
638 |
639 | unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device)
640 | self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
641 |
642 | def update(self, unet_number = None):
643 | unet_number = self.validate_and_return_unet_number(unet_number)
644 | index = unet_number - 1
645 |
646 | optimizer = getattr(self, f'optim{index}')
647 | scheduler = getattr(self, f'sched{index}')
648 |
649 | if exists(self.max_grad_norm):
650 | self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
651 |
652 | optimizer.step()
653 | optimizer.zero_grad()
654 |
655 | warmup_scheduler = self.warmup_schedulers[index]
656 | scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext
657 |
658 | with scheduler_context():
659 | scheduler.step()
660 |
661 | if self.use_ema:
662 | ema_unet = self.ema_unets[index]
663 | ema_unet.update()
664 |
665 | self.increment_step(unet_number)
666 |
667 | @torch.no_grad()
668 | @cast_torch_tensor
669 | @decoder_sample_in_chunks
670 | def sample(self, *args, **kwargs):
671 | distributed = self.accelerator.num_processes > 1
672 | base_decoder = self.accelerator.unwrap_model(self.decoder)
673 |
674 | was_training = base_decoder.training
675 | base_decoder.eval()
676 |
677 | if kwargs.pop('use_non_ema', False) or not self.use_ema:
678 | out = base_decoder.sample(*args, **kwargs, distributed = distributed)
679 | base_decoder.train(was_training)
680 | return out
681 |
682 | trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
683 | base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
684 |
685 | output = base_decoder.sample(*args, **kwargs, distributed = distributed)
686 |
687 | base_decoder.unets = trainable_unets # restore original training unets
688 |
689 | # cast the ema_model unets back to original device
690 | for ema in self.ema_unets:
691 | ema.restore_ema_model_device()
692 |
693 | base_decoder.train(was_training)
694 | return output
695 |
696 | @torch.no_grad()
697 | @cast_torch_tensor
698 | @prior_sample_in_chunks
699 | def embed_text(self, *args, **kwargs):
700 | return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)
701 |
702 | @torch.no_grad()
703 | @cast_torch_tensor
704 | @prior_sample_in_chunks
705 | def embed_image(self, *args, **kwargs):
706 | return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)
707 |
708 | @cast_torch_tensor
709 | def forward(
710 | self,
711 | *args,
712 | unet_number = None,
713 | max_batch_size = None,
714 | return_lowres_cond_image=False,
715 | **kwargs
716 | ):
717 | unet_number = self.validate_and_return_unet_number(unet_number)
718 |
719 | total_loss = 0.
720 | cond_images = []
721 | for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
722 | with self.accelerator.autocast():
723 | loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)
724 | # loss_obj may be a tuple with loss and cond_image
725 | if return_lowres_cond_image:
726 | loss, cond_image = loss_obj
727 | else:
728 | loss = loss_obj
729 | cond_image = None
730 | loss = loss * chunk_size_frac
731 | if cond_image is not None:
732 | cond_images.append(cond_image)
733 |
734 | total_loss += loss.item()
735 |
736 | if self.training:
737 | self.accelerator.backward(loss)
738 |
739 | if return_lowres_cond_image:
740 | return total_loss, torch.stack(cond_images)
741 | else:
742 | return total_loss
743 |
--------------------------------------------------------------------------------
/dalle2_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import importlib
3 |
4 | # helper functions
5 |
6 | def exists(val):
7 | return val is not None
8 |
9 | # time helpers
10 |
11 | class Timer:
12 | def __init__(self):
13 | self.reset()
14 |
15 | def reset(self):
16 | self.last_time = time.time()
17 |
18 | def elapsed(self):
19 | return time.time() - self.last_time
20 |
21 | # print helpers
22 |
23 | def print_ribbon(s, symbol = '=', repeat = 40):
24 | flank = symbol * repeat
25 | return f'{flank} {s} {flank}'
26 |
27 | # import helpers
28 |
29 | def import_or_print_error(pkg_name, err_str = None):
30 | try:
31 | return importlib.import_module(pkg_name)
32 | except ModuleNotFoundError as e:
33 | if exists(err_str):
34 | print(err_str)
35 | exit()
36 |
--------------------------------------------------------------------------------
/dalle2_pytorch/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.15.6'
2 |
--------------------------------------------------------------------------------
/dalle2_pytorch/vqgan_vae.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | from math import sqrt
4 | from functools import partial, wraps
5 |
6 | from vector_quantize_pytorch import VectorQuantize as VQ
7 |
8 | import torch
9 | from torch import nn, einsum
10 | import torch.nn.functional as F
11 | from torch.autograd import grad as torch_grad
12 | import torchvision
13 |
14 | from einops import rearrange, reduce, repeat, pack, unpack
15 | from einops.layers.torch import Rearrange
16 |
17 | # constants
18 |
19 | MList = nn.ModuleList
20 |
21 | # helper functions
22 |
23 | def exists(val):
24 | return val is not None
25 |
26 | def default(val, d):
27 | return val if exists(val) else d
28 |
29 | # decorators
30 |
31 | def eval_decorator(fn):
32 | def inner(model, *args, **kwargs):
33 | was_training = model.training
34 | model.eval()
35 | out = fn(model, *args, **kwargs)
36 | model.train(was_training)
37 | return out
38 | return inner
39 |
40 | def remove_vgg(fn):
41 | @wraps(fn)
42 | def inner(self, *args, **kwargs):
43 | has_vgg = hasattr(self, 'vgg')
44 | if has_vgg:
45 | vgg = self.vgg
46 | delattr(self, 'vgg')
47 |
48 | out = fn(self, *args, **kwargs)
49 |
50 | if has_vgg:
51 | self.vgg = vgg
52 |
53 | return out
54 | return inner
55 |
56 | # keyword argument helpers
57 |
58 | def pick_and_pop(keys, d):
59 | values = list(map(lambda key: d.pop(key), keys))
60 | return dict(zip(keys, values))
61 |
62 | def group_dict_by_key(cond, d):
63 | return_val = [dict(),dict()]
64 | for key in d.keys():
65 | match = bool(cond(key))
66 | ind = int(not match)
67 | return_val[ind][key] = d[key]
68 | return (*return_val,)
69 |
70 | def string_begins_with(prefix, string_input):
71 | return string_input.startswith(prefix)
72 |
73 | def group_by_key_prefix(prefix, d):
74 | return group_dict_by_key(partial(string_begins_with, prefix), d)
75 |
76 | def groupby_prefix_and_trim(prefix, d):
77 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
78 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
79 | return kwargs_without_prefix, kwargs
80 |
81 | # tensor helper functions
82 |
83 | def log(t, eps = 1e-10):
84 | return torch.log(t + eps)
85 |
86 | def gradient_penalty(images, output, weight = 10):
87 | batch_size = images.shape[0]
88 | gradients = torch_grad(outputs = output, inputs = images,
89 | grad_outputs = torch.ones(output.size(), device = images.device),
90 | create_graph = True, retain_graph = True, only_inputs = True)[0]
91 |
92 | gradients = rearrange(gradients, 'b ... -> b (...)')
93 | return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
94 |
95 | def l2norm(t):
96 | return F.normalize(t, dim = -1)
97 |
98 | def leaky_relu(p = 0.1):
99 |     return nn.LeakyReLU(p)
100 |
101 | def stable_softmax(t, dim = -1, alpha = 32 ** 2):
102 | t = t / alpha
103 | t = t - torch.amax(t, dim = dim, keepdim = True).detach()
104 | return (t * alpha).softmax(dim = dim)
105 |
106 | def safe_div(numer, denom, eps = 1e-8):
107 | return numer / (denom + eps)
108 |
109 | # gan losses
110 |
111 | def hinge_discr_loss(fake, real):
112 | return (F.relu(1 + fake) + F.relu(1 - real)).mean()
113 |
114 | def hinge_gen_loss(fake):
115 | return -fake.mean()
116 |
117 | def bce_discr_loss(fake, real):
118 | return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
119 |
120 | def bce_gen_loss(fake):
121 | return -log(torch.sigmoid(fake)).mean()
122 |
123 | def grad_layer_wrt_loss(loss, layer):
124 | return torch_grad(
125 | outputs = loss,
126 | inputs = layer,
127 | grad_outputs = torch.ones_like(loss),
128 | retain_graph = True
129 | )[0].detach()
130 |
131 | # vqgan vae
132 |
133 | class LayerNormChan(nn.Module):
134 | def __init__(
135 | self,
136 | dim,
137 | eps = 1e-5
138 | ):
139 | super().__init__()
140 | self.eps = eps
141 | self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
142 |
143 | def forward(self, x):
144 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
145 | mean = torch.mean(x, dim = 1, keepdim = True)
146 | return (x - mean) / (var + self.eps).sqrt() * self.gamma
147 |
148 | # discriminator
149 |
150 | class Discriminator(nn.Module):
151 | def __init__(
152 | self,
153 | dims,
154 | channels = 3,
155 | groups = 16,
156 | init_kernel_size = 5
157 | ):
158 | super().__init__()
159 | dim_pairs = zip(dims[:-1], dims[1:])
160 |
161 | self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])
162 |
163 | for dim_in, dim_out in dim_pairs:
164 | self.layers.append(nn.Sequential(
165 | nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
166 | nn.GroupNorm(groups, dim_out),
167 | leaky_relu()
168 | ))
169 |
170 | dim = dims[-1]
171 | self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
172 | nn.Conv2d(dim, dim, 1),
173 | leaky_relu(),
174 | nn.Conv2d(dim, 1, 4)
175 | )
176 |
177 | def forward(self, x):
178 | for net in self.layers:
179 | x = net(x)
180 |
181 | return self.to_logits(x)
182 |
183 | # positional encoding
184 |
185 | class ContinuousPositionBias(nn.Module):
186 | """ from https://arxiv.org/abs/2111.09883 """
187 |
188 | def __init__(self, *, dim, heads, layers = 2):
189 | super().__init__()
190 | self.net = MList([])
191 | self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))
192 |
193 | for _ in range(layers - 1):
194 | self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))
195 |
196 | self.net.append(nn.Linear(dim, heads))
197 | self.register_buffer('rel_pos', None, persistent = False)
198 |
199 | def forward(self, x):
200 | n, device = x.shape[-1], x.device
201 | fmap_size = int(sqrt(n))
202 |
203 | if not exists(self.rel_pos):
204 | pos = torch.arange(fmap_size, device = device)
205 | grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
206 | grid = rearrange(grid, 'c i j -> (i j) c')
207 | rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
208 | rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
209 | self.register_buffer('rel_pos', rel_pos, persistent = False)
210 |
211 | rel_pos = self.rel_pos.float()
212 |
213 | for layer in self.net:
214 | rel_pos = layer(rel_pos)
215 |
216 | bias = rearrange(rel_pos, 'i j h -> h i j')
217 | return x + bias
218 |
219 | # resnet encoder / decoder
220 |
221 | class ResnetEncDec(nn.Module):
222 | def __init__(
223 | self,
224 | dim,
225 | *,
226 | channels = 3,
227 | layers = 4,
228 | layer_mults = None,
229 | num_resnet_blocks = 1,
230 | resnet_groups = 16,
231 | first_conv_kernel_size = 5,
232 | use_attn = True,
233 | attn_dim_head = 64,
234 | attn_heads = 8,
235 | attn_dropout = 0.,
236 | ):
237 | super().__init__()
238 | assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
239 |
240 | self.layers = layers
241 |
242 | self.encoders = MList([])
243 | self.decoders = MList([])
244 |
245 | layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
246 | assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
247 |
248 | layer_dims = [dim * mult for mult in layer_mults]
249 | dims = (dim, *layer_dims)
250 |
251 | self.encoded_dim = dims[-1]
252 |
253 | dim_pairs = zip(dims[:-1], dims[1:])
254 |
255 | append = lambda arr, t: arr.append(t)
256 | prepend = lambda arr, t: arr.insert(0, t)
257 |
258 | if not isinstance(num_resnet_blocks, tuple):
259 | num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
260 |
261 | if not isinstance(use_attn, tuple):
262 | use_attn = (*((False,) * (layers - 1)), use_attn)
263 |
264 | assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
265 | assert len(use_attn) == layers
266 |
267 | for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
268 | append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
269 | prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
270 |
271 | if layer_use_attn:
272 | prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
273 |
274 | for _ in range(layer_num_resnet_blocks):
275 | append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
276 | prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
277 |
278 | if layer_use_attn:
279 | append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
280 |
281 | prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
282 | append(self.decoders, nn.Conv2d(dim, channels, 1))
283 |
284 | def get_encoded_fmap_size(self, image_size):
285 | return image_size // (2 ** self.layers)
286 |
287 | @property
288 | def last_dec_layer(self):
289 | return self.decoders[-1].weight
290 |
291 | def encode(self, x):
292 | for enc in self.encoders:
293 | x = enc(x)
294 | return x
295 |
296 | def decode(self, x):
297 | for dec in self.decoders:
298 | x = dec(x)
299 | return x
300 |
301 | class GLUResBlock(nn.Module):
302 | def __init__(self, chan, groups = 16):
303 | super().__init__()
304 | self.net = nn.Sequential(
305 | nn.Conv2d(chan, chan * 2, 3, padding = 1),
306 | nn.GLU(dim = 1),
307 | nn.GroupNorm(groups, chan),
308 | nn.Conv2d(chan, chan * 2, 3, padding = 1),
309 | nn.GLU(dim = 1),
310 | nn.GroupNorm(groups, chan),
311 | nn.Conv2d(chan, chan, 1)
312 | )
313 |
314 | def forward(self, x):
315 | return self.net(x) + x
316 |
317 | class ResBlock(nn.Module):
318 | def __init__(self, chan, groups = 16):
319 | super().__init__()
320 | self.net = nn.Sequential(
321 | nn.Conv2d(chan, chan, 3, padding = 1),
322 | nn.GroupNorm(groups, chan),
323 | leaky_relu(),
324 | nn.Conv2d(chan, chan, 3, padding = 1),
325 | nn.GroupNorm(groups, chan),
326 | leaky_relu(),
327 | nn.Conv2d(chan, chan, 1)
328 | )
329 |
330 | def forward(self, x):
331 | return self.net(x) + x
332 |
333 | # vqgan attention layer
334 |
335 | class VQGanAttention(nn.Module):
336 | def __init__(
337 | self,
338 | *,
339 | dim,
340 | dim_head = 64,
341 | heads = 8,
342 | dropout = 0.
343 | ):
344 | super().__init__()
345 | self.heads = heads
346 | self.scale = dim_head ** -0.5
347 | inner_dim = heads * dim_head
348 |
349 | self.dropout = nn.Dropout(dropout)
350 | self.pre_norm = LayerNormChan(dim)
351 |
352 | self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
353 | self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
354 | self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
355 |
356 | def forward(self, x):
357 | h = self.heads
358 | height, width, residual = *x.shape[-2:], x.clone()
359 |
360 | x = self.pre_norm(x)
361 |
362 | q, k, v = self.to_qkv(x).chunk(3, dim = 1)
363 |
364 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))
365 |
366 | sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale
367 |
368 | sim = self.cpb(sim)
369 |
370 | attn = stable_softmax(sim, dim = -1)
371 | attn = self.dropout(attn)
372 |
373 | out = einsum('b h i j, b h c j -> b h c i', attn, v)
374 | out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
375 | out = self.to_out(out)
376 |
377 | return out + residual
378 |
379 | # ViT encoder / decoder
380 |
381 | class RearrangeImage(nn.Module):
382 | def forward(self, x):
383 | n = x.shape[1]
384 | w = h = int(sqrt(n))
385 | return rearrange(x, 'b (h w) ... -> b h w ...', h = h, w = w)
386 |
387 | class Attention(nn.Module):
388 | def __init__(
389 | self,
390 | dim,
391 | *,
392 | heads = 8,
393 | dim_head = 32
394 | ):
395 | super().__init__()
396 | self.norm = nn.LayerNorm(dim)
397 | self.heads = heads
398 | self.scale = dim_head ** -0.5
399 | inner_dim = dim_head * heads
400 |
401 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
402 | self.to_out = nn.Linear(inner_dim, dim)
403 |
404 | def forward(self, x):
405 | h = self.heads
406 |
407 | x = self.norm(x)
408 |
409 | q, k, v = self.to_qkv(x).chunk(3, dim = -1)
410 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
411 |
412 | q = q * self.scale
413 | sim = einsum('b h i d, b h j d -> b h i j', q, k)
414 |
415 | sim = sim - sim.amax(dim = -1, keepdim = True).detach()
416 | attn = sim.softmax(dim = -1)
417 |
418 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
419 |
420 | out = rearrange(out, 'b h n d -> b n (h d)')
421 | return self.to_out(out)
422 |
423 | def FeedForward(dim, mult = 4):
424 | return nn.Sequential(
425 | nn.LayerNorm(dim),
426 | nn.Linear(dim, dim * mult, bias = False),
427 | nn.GELU(),
428 | nn.Linear(dim * mult, dim, bias = False)
429 | )
430 |
431 | class Transformer(nn.Module):
432 | def __init__(
433 | self,
434 | dim,
435 | *,
436 | layers,
437 | dim_head = 32,
438 | heads = 8,
439 | ff_mult = 4
440 | ):
441 | super().__init__()
442 | self.layers = nn.ModuleList([])
443 | for _ in range(layers):
444 | self.layers.append(nn.ModuleList([
445 | Attention(dim = dim, dim_head = dim_head, heads = heads),
446 | FeedForward(dim = dim, mult = ff_mult)
447 | ]))
448 |
449 | self.norm = nn.LayerNorm(dim)
450 |
451 | def forward(self, x):
452 | for attn, ff in self.layers:
453 | x = attn(x) + x
454 | x = ff(x) + x
455 |
456 | return self.norm(x)
457 |
458 | class ViTEncDec(nn.Module):
459 | def __init__(
460 | self,
461 | dim,
462 | channels = 3,
463 | layers = 4,
464 | patch_size = 8,
465 | dim_head = 32,
466 | heads = 8,
467 | ff_mult = 4
468 | ):
469 | super().__init__()
470 | self.encoded_dim = dim
471 | self.patch_size = patch_size
472 |
473 | input_dim = channels * (patch_size ** 2)
474 |
475 | self.encoder = nn.Sequential(
476 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
477 | nn.Linear(input_dim, dim),
478 | Transformer(
479 | dim = dim,
480 | dim_head = dim_head,
481 | heads = heads,
482 | ff_mult = ff_mult,
483 | layers = layers
484 | ),
485 | RearrangeImage(),
486 | Rearrange('b h w c -> b c h w')
487 | )
488 |
489 | self.decoder = nn.Sequential(
490 | Rearrange('b c h w -> b (h w) c'),
491 | Transformer(
492 | dim = dim,
493 | dim_head = dim_head,
494 | heads = heads,
495 | ff_mult = ff_mult,
496 | layers = layers
497 | ),
498 | nn.Sequential(
499 | nn.Linear(dim, dim * 4, bias = False),
500 | nn.Tanh(),
501 | nn.Linear(dim * 4, input_dim, bias = False),
502 | ),
503 | RearrangeImage(),
504 | Rearrange('b h w (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
505 | )
506 |
507 | def get_encoded_fmap_size(self, image_size):
508 | return image_size // self.patch_size
509 |
510 | @property
511 | def last_dec_layer(self):
512 | return self.decoder[-3][-1].weight
513 |
514 | def encode(self, x):
515 | return self.encoder(x)
516 |
517 | def decode(self, x):
518 | return self.decoder(x)
519 |
520 | # main vqgan-vae classes
521 |
522 | class NullVQGanVAE(nn.Module):
523 | def __init__(
524 | self,
525 | *,
526 | channels
527 | ):
528 | super().__init__()
529 | self.encoded_dim = channels
530 | self.layers = 0
531 |
532 | def get_encoded_fmap_size(self, size):
533 | return size
534 |
535 | def copy_for_eval(self):
536 | return self
537 |
538 | def encode(self, x):
539 | return x
540 |
541 | def decode(self, x):
542 | return x
543 |
544 | class VQGanVAE(nn.Module):
545 | def __init__(
546 | self,
547 | *,
548 | dim,
549 | image_size,
550 | channels = 3,
551 | layers = 4,
552 | l2_recon_loss = False,
553 | use_hinge_loss = True,
554 | vgg = None,
555 | vq_codebook_dim = 256,
556 | vq_codebook_size = 512,
557 | vq_decay = 0.8,
558 | vq_commitment_weight = 1.,
559 | vq_kmeans_init = True,
560 | vq_use_cosine_sim = True,
561 | use_vgg_and_gan = True,
562 | vae_type = 'resnet',
563 | discr_layers = 4,
564 | **kwargs
565 | ):
566 | super().__init__()
567 | vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
568 | encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)
569 |
570 | self.image_size = image_size
571 | self.channels = channels
572 | self.codebook_size = vq_codebook_size
573 |
574 | if vae_type == 'resnet':
575 | enc_dec_klass = ResnetEncDec
576 | elif vae_type == 'vit':
577 | enc_dec_klass = ViTEncDec
578 | else:
579 | raise ValueError(f'{vae_type} not valid')
580 |
581 | self.enc_dec = enc_dec_klass(
582 | dim = dim,
583 | channels = channels,
584 | layers = layers,
585 | **encdec_kwargs
586 | )
587 |
588 | self.vq = VQ(
589 | dim = self.enc_dec.encoded_dim,
590 | codebook_dim = vq_codebook_dim,
591 | codebook_size = vq_codebook_size,
592 | decay = vq_decay,
593 | commitment_weight = vq_commitment_weight,
594 | accept_image_fmap = True,
595 | kmeans_init = vq_kmeans_init,
596 | use_cosine_sim = vq_use_cosine_sim,
597 | **vq_kwargs
598 | )
599 |
600 | # reconstruction loss
601 |
602 | self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
603 |
604 | # turn off GAN and perceptual loss if grayscale
605 |
606 | self.vgg = None
607 | self.discr = None
608 | self.use_vgg_and_gan = use_vgg_and_gan
609 |
610 | if not use_vgg_and_gan:
611 | return
612 |
613 |         # perceptual loss
614 |
615 | if exists(vgg):
616 | self.vgg = vgg
617 | else:
618 | self.vgg = torchvision.models.vgg16(pretrained = True)
619 | self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])
620 |
621 | # gan related losses
622 |
623 | layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
624 | layer_dims = [dim * mult for mult in layer_mults]
625 | dims = (dim, *layer_dims)
626 |
627 | self.discr = Discriminator(dims = dims, channels = channels)
628 |
629 | self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
630 | self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
631 |
632 | @property
633 | def encoded_dim(self):
634 | return self.enc_dec.encoded_dim
635 |
636 | def get_encoded_fmap_size(self, image_size):
637 | return self.enc_dec.get_encoded_fmap_size(image_size)
638 |
639 | def copy_for_eval(self):
640 | device = next(self.parameters()).device
641 | vae_copy = copy.deepcopy(self.cpu())
642 |
643 | if vae_copy.use_vgg_and_gan:
644 | del vae_copy.discr
645 | del vae_copy.vgg
646 |
647 | vae_copy.eval()
648 | return vae_copy.to(device)
649 |
650 | @remove_vgg
651 | def state_dict(self, *args, **kwargs):
652 | return super().state_dict(*args, **kwargs)
653 |
654 | @remove_vgg
655 | def load_state_dict(self, *args, **kwargs):
656 | return super().load_state_dict(*args, **kwargs)
657 |
658 | @property
659 | def codebook(self):
660 | return self.vq.codebook
661 |
662 | def encode(self, fmap):
663 | fmap = self.enc_dec.encode(fmap)
664 | return fmap
665 |
666 | def decode(self, fmap, return_indices_and_loss = False):
667 | fmap, indices, commit_loss = self.vq(fmap)
668 |
669 | fmap = self.enc_dec.decode(fmap)
670 |
671 | if not return_indices_and_loss:
672 | return fmap
673 |
674 | return fmap, indices, commit_loss
675 |
676 | def forward(
677 | self,
678 | img,
679 | return_loss = False,
680 | return_discr_loss = False,
681 | return_recons = False,
682 | add_gradient_penalty = True
683 | ):
684 | batch, channels, height, width, device = *img.shape, img.device
685 |         assert height == self.image_size and width == self.image_size, f'height and width of input image must be equal to {self.image_size}'
686 | assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
687 |
688 | fmap = self.encode(img)
689 |
690 | fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
691 |
692 | if not return_loss and not return_discr_loss:
693 | return fmap
694 |
695 | assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
696 |
697 | # whether to return discriminator loss
698 |
699 | if return_discr_loss:
700 | assert exists(self.discr), 'discriminator must exist to train it'
701 |
702 | fmap.detach_()
703 | img.requires_grad_()
704 |
705 | fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
706 |
707 | discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
708 |
709 |             loss = discr_loss
710 |             if add_gradient_penalty:
711 |                 loss = loss + gradient_penalty(img, img_discr_logits)
712 |
713 | if return_recons:
714 | return loss, fmap
715 |
716 | return loss
717 |
718 | # reconstruction loss
719 |
720 | recon_loss = self.recon_loss_fn(fmap, img)
721 |
722 | # early return if training on grayscale
723 |
724 | if not self.use_vgg_and_gan:
725 | if return_recons:
726 | return recon_loss, fmap
727 |
728 | return recon_loss
729 |
730 | # perceptual loss
731 |
732 | img_vgg_input = img
733 | fmap_vgg_input = fmap
734 |
735 | if img.shape[1] == 1:
736 | # handle grayscale for vgg
737 | img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))
738 |
739 | img_vgg_feats = self.vgg(img_vgg_input)
740 | recon_vgg_feats = self.vgg(fmap_vgg_input)
741 | perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
742 |
743 | # generator loss
744 |
745 | gen_loss = self.gen_loss(self.discr(fmap))
746 |
747 | # calculate adaptive weight
748 |
749 | last_dec_layer = self.enc_dec.last_dec_layer
750 |
751 | norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
752 | norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
753 |
754 | adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
755 | adaptive_weight.clamp_(max = 1e4)
756 |
757 | # combine losses
758 |
759 | loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
760 |
761 | if return_recons:
762 | return loss, fmap
763 |
764 | return loss
765 |
--------------------------------------------------------------------------------
/dalle2_pytorch/vqgan_vae_trainer.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | import copy
3 | from random import choice
4 | from pathlib import Path
5 | from shutil import rmtree
6 | from PIL import Image
7 |
8 | import torch
9 | from torch import nn
10 | from torch.cuda.amp import autocast, GradScaler
11 | from torch.utils.data import Dataset, DataLoader, random_split
12 |
13 | import torchvision.transforms as T
14 | from torchvision.datasets import ImageFolder
15 | from torchvision.utils import make_grid, save_image
16 |
17 | from einops import rearrange
18 |
19 | from dalle2_pytorch.vqgan_vae import VQGanVAE
20 | from dalle2_pytorch.optimizer import get_optimizer
21 |
22 | from ema_pytorch import EMA
23 |
24 | # helpers
25 |
26 | def exists(val):
27 | return val is not None
28 |
29 | def noop(*args, **kwargs):
30 | pass
31 |
32 | def cycle(dl):
33 | while True:
34 | for data in dl:
35 | yield data
36 |
37 | def cast_tuple(t):
38 | return t if isinstance(t, (tuple, list)) else (t,)
39 |
40 | def yes_or_no(question):
41 | answer = input(f'{question} (y/n) ')
42 | return answer.lower() in ('yes', 'y')
43 |
44 | def accum_log(log, new_logs):
45 | for key, new_value in new_logs.items():
46 | old_value = log.get(key, 0.)
47 | log[key] = old_value + new_value
48 | return log
49 |
50 | # classes
51 |
52 | class ImageDataset(Dataset):
53 | def __init__(
54 | self,
55 | folder,
56 | image_size,
57 | exts = ['jpg', 'jpeg', 'png']
58 | ):
59 | super().__init__()
60 | self.folder = folder
61 | self.image_size = image_size
62 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
63 |
64 | print(f'{len(self.paths)} training samples found at {folder}')
65 |
66 | self.transform = T.Compose([
67 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
68 | T.Resize(image_size),
69 | T.RandomHorizontalFlip(),
70 | T.CenterCrop(image_size),
71 | T.ToTensor()
72 | ])
73 |
74 | def __len__(self):
75 | return len(self.paths)
76 |
77 | def __getitem__(self, index):
78 | path = self.paths[index]
79 | img = Image.open(path)
80 | return self.transform(img)
81 |
82 | # main trainer class
83 |
84 | class VQGanVAETrainer(nn.Module):
85 | def __init__(
86 | self,
87 | vae,
88 | *,
89 | num_train_steps,
90 | lr,
91 | batch_size,
92 | folder,
93 | grad_accum_every,
94 | wd = 0.,
95 | save_results_every = 100,
96 | save_model_every = 1000,
97 | results_folder = './results',
98 | valid_frac = 0.05,
99 | random_split_seed = 42,
100 | ema_beta = 0.995,
101 | ema_update_after_step = 500,
102 | ema_update_every = 10,
103 | apply_grad_penalty_every = 4,
104 | amp = False
105 | ):
106 | super().__init__()
107 | assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
108 | image_size = vae.image_size
109 |
110 | self.vae = vae
111 | self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)
112 |
113 | self.register_buffer('steps', torch.Tensor([0]))
114 |
115 | self.num_train_steps = num_train_steps
116 | self.batch_size = batch_size
117 | self.grad_accum_every = grad_accum_every
118 |
119 | all_parameters = set(vae.parameters())
120 | discr_parameters = set(vae.discr.parameters())
121 | vae_parameters = all_parameters - discr_parameters
122 |
123 | self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
124 | self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
125 |
126 | self.amp = amp
127 | self.scaler = GradScaler(enabled = amp)
128 | self.discr_scaler = GradScaler(enabled = amp)
129 |
130 | # create dataset
131 |
132 | self.ds = ImageDataset(folder, image_size = image_size)
133 |
134 | # split for validation
135 |
136 | if valid_frac > 0:
137 | train_size = int((1 - valid_frac) * len(self.ds))
138 | valid_size = len(self.ds) - train_size
139 | self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
140 |             print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
141 | else:
142 | self.valid_ds = self.ds
143 | print(f'training with shared training and valid dataset of {len(self.ds)} samples')
144 |
145 | # dataloader
146 |
147 | self.dl = cycle(DataLoader(
148 | self.ds,
149 | batch_size = batch_size,
150 | shuffle = True
151 | ))
152 |
153 | self.valid_dl = cycle(DataLoader(
154 | self.valid_ds,
155 | batch_size = batch_size,
156 | shuffle = True
157 | ))
158 |
159 | self.save_model_every = save_model_every
160 | self.save_results_every = save_results_every
161 |
162 | self.apply_grad_penalty_every = apply_grad_penalty_every
163 |
164 | self.results_folder = Path(results_folder)
165 |
166 | if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
167 | rmtree(str(self.results_folder))
168 |
169 | self.results_folder.mkdir(parents = True, exist_ok = True)
170 |
171 | def train_step(self):
172 | device = next(self.vae.parameters()).device
173 | steps = int(self.steps.item())
174 | apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
175 |
176 | self.vae.train()
177 |
178 | # logs
179 |
180 | logs = {}
181 |
182 | # update vae (generator)
183 |
184 | for _ in range(self.grad_accum_every):
185 | img = next(self.dl)
186 | img = img.to(device)
187 |
188 | with autocast(enabled = self.amp):
189 | loss = self.vae(
190 | img,
191 |                     return_loss = True
192 |                 )
194 |
195 |
196 | self.scaler.scale(loss / self.grad_accum_every).backward()
197 |
198 | accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
199 |
200 | self.scaler.step(self.optim)
201 | self.scaler.update()
202 | self.optim.zero_grad()
203 |
204 | # update discriminator
205 |
206 | if exists(self.vae.discr):
207 | discr_loss = 0
208 | for _ in range(self.grad_accum_every):
209 | img = next(self.dl)
210 | img = img.to(device)
211 |
212 | with autocast(enabled = self.amp):
213 |                     loss = self.vae(img, return_discr_loss = True, add_gradient_penalty = apply_grad_penalty)
214 |
215 | self.discr_scaler.scale(loss / self.grad_accum_every).backward()
216 |
217 | accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
218 |
219 | self.discr_scaler.step(self.discr_optim)
220 | self.discr_scaler.update()
221 | self.discr_optim.zero_grad()
222 |
223 | # log
224 |
225 | print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")
226 |
227 | # update exponential moving averaged generator
228 |
229 | self.ema_vae.update()
230 |
231 | # sample results every so often
232 |
233 | if not (steps % self.save_results_every):
234 | for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
235 | model.eval()
236 |
237 | imgs = next(self.dl)
238 | imgs = imgs.to(device)
239 |
240 | recons = model(imgs)
241 | nrows = int(sqrt(self.batch_size))
242 |
243 | imgs_and_recons = torch.stack((imgs, recons), dim = 0)
244 | imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')
245 |
246 | imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
247 | grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))
248 |
249 | logs['reconstructions'] = grid
250 |
251 | save_image(grid, str(self.results_folder / f'{filename}.png'))
252 |
253 | print(f'{steps}: saving to {str(self.results_folder)}')
254 |
255 | # save model every so often
256 |
257 | if not (steps % self.save_model_every):
258 | state_dict = self.vae.state_dict()
259 | model_path = str(self.results_folder / f'vae.{steps}.pt')
260 | torch.save(state_dict, model_path)
261 |
262 | ema_state_dict = self.ema_vae.state_dict()
263 | model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
264 | torch.save(ema_state_dict, model_path)
265 |
266 | print(f'{steps}: saving model to {str(self.results_folder)}')
267 |
268 | self.steps += 1
269 | return logs
270 |
271 | def train(self, log_fn = noop):
272 | device = next(self.vae.parameters()).device
273 |
274 | while self.steps < self.num_train_steps:
275 | logs = self.train_step()
276 | log_fn(logs)
277 |
278 | print('training complete')
279 |
--------------------------------------------------------------------------------
/prior.md:
--------------------------------------------------------------------------------
1 | # Diffusion Prior
2 | This readme serves as an introduction to the diffusion prior.
3 |
4 | ## Intro
5 |
6 | A properly trained prior will allow you to translate between two embedding spaces. If you know *a priori* that two embeddings are connected in some way, then the ability to translate between them could be extremely helpful.
7 |
8 | ### Motivation
9 |
10 | Before we dive into the model, let’s look at a quick example of where the model may be helpful.
11 |
12 | For demonstration purposes we will imagine that we wish to generate images from text using CLIP and a Decoder.
13 |
14 | > [CLIP](https://openai.com/blog/clip/) is a contrastive model that learns to maximize the cosine similarity between a given image and its caption; however, there is no guarantee that these embeddings live in the same space. While the generated embeddings are ***close***, the image and text embeddings occupy two disjoint sets.
15 |
16 | ```python
17 | # Load Models
18 | clip_model = clip.load("ViT-L/14")
19 | decoder = Decoder(checkpoint="best.pth") # A decoder trained on CLIP Image embeddings
20 |
21 | # Retrieve prompt from user and encode with CLIP
22 | prompt = "A corgi wearing sunglasses"
23 | tokenized_text = tokenize(prompt)
24 | text_embedding = clip_model.encode_text(tokenized_text)
25 |
26 | # Now, pass the text embedding to the decoder
27 | predicted_image = decoder.sample(text_embedding)
28 | ```
29 |
30 | > **Question**: *Can you spot the issue here?*
31 | >
32 | > **Answer**: *We’re trying to generate an image from a text embedding!*
33 |
34 | Unfortunately, we run into the issue mentioned previously: the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution.
35 |
36 | ```python
37 | # Load Models
38 | prior = Prior(checkpoint="prior.pth") # A prior trained to go from: text -> clip text emb -> clip img emb
39 | decoder = Decoder(checkpoint="decoder.pth") # A decoder trained on CLIP Image embeddings
40 |
41 | # Retrieve prompt from user and encode with a prior
42 | prompt = "A corgi wearing sunglasses"
43 | tokenized_text = tokenize(prompt)
44 | text_embedding = prior.sample(tokenized_text) # <-- now we get an embedding in the same space as images!
45 |
46 | # Now, pass the predicted image embedding to the decoder
47 | predicted_image = decoder.sample(text_embedding)
48 | ```
49 |
50 | With the prior we are able to successfully generate embeddings *within* CLIP's image space! For this reason, the decoder will perform much better as it receives input that is much closer to its training data.
51 |
52 | > **You may be asking yourself the following question:**
53 | >
54 | > *"Why don't you just train the decoder on clip text embeddings instead of image embeddings?"*
55 | >
56 | > OpenAI covers this topic in their [DALLE-2 paper](https://arxiv.org/abs/2204.06125). The TL;DR is *"it doesn't work as well as decoders trained on image embeddings"*...also...it's just an example :smile:
57 |
58 | ## Usage
59 |
60 | Utilizing a pre-trained prior is quite simple.
61 |
62 | ### Loading Checkpoints
63 | ```python
64 | import torch
65 | from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
66 | from dalle2_pytorch.trainer import DiffusionPriorTrainer
67 |
68 | def load_diffusion_model(dprior_path, device):
69 |
70 | prior_network = DiffusionPriorNetwork(
71 | dim=768,
72 | depth=24,
73 | dim_head=64,
74 | heads=32,
75 | normformer=True,
76 | attn_dropout=5e-2,
77 | ff_dropout=5e-2,
78 | num_time_embeds=1,
79 | num_image_embeds=1,
80 | num_text_embeds=1,
81 | num_timesteps=1000,
82 | ff_mult=4
83 | )
84 |
85 | diffusion_prior = DiffusionPrior(
86 | net=prior_network,
87 | clip=OpenAIClipAdapter("ViT-L/14"),
88 | image_embed_dim=768,
89 | timesteps=1000,
90 | cond_drop_prob=0.1,
91 | loss_type="l2",
92 | condition_on_text_encodings=True,
93 |
94 | )
95 |
96 | trainer = DiffusionPriorTrainer(
97 | diffusion_prior=diffusion_prior,
98 | lr=1.1e-4,
99 | wd=6.02e-2,
100 | max_grad_norm=0.5,
101 | amp=False,
102 | group_wd_params=True,
103 | use_ema=True,
104 | device=device,
105 | accelerator=None,
106 | )
107 |
108 | trainer.load(dprior_path)
109 |
110 | return trainer
111 | ```
112 |
113 | Here we instantiate a model that matches the configuration it was trained with, and then load the weights (*just like any other PyTorch model!*).
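
For example, assuming a checkpoint on disk (the path below is a placeholder), loading looks like this:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prior = load_diffusion_model("prior_checkpoint.pth", device)  # placeholder checkpoint path
```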
114 |
115 | ### Sampling
116 | Once we have a pre-trained model, generating embeddings is quite simple!
117 | ```python
118 | # tokenize the text
119 | tokenized_text = clip.tokenize("<your prompt>")
120 | # predict an embedding
121 | predicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)
122 | ```
123 |
124 | The resulting tensor returned from `.sample()` is of the same shape as your training data along the non-batch dimension(s). For example, a prior trained on `ViT-L/14` embeddings will predict an embedding of shape (1, 768).
125 |
126 | > For CLIP priors, this is quite handy as it means you can use `prior.sample(tokenized_text)` as a drop-in replacement for `clip.encode_text()`.
127 |
128 | **Some things to note:**
129 | * It is possible to specify the number of candidate embeddings to sample (the default suggested by OpenAI is `n=2`). Put simply, the idea is to avoid getting unlucky with a bad generation by sampling two candidates and keeping the one with the higher cosine similarity to the prompt (see the sketch after this list).
130 | * You may specify a higher conditioning scale than the default (`1.0`). It is unclear whether OpenAI uses a higher value for the prior specifically, or only on the decoder. Local testing has shown poor results with anything higher than `1.0` but *ymmv*.
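
Below is a minimal sketch of that best-of-*n* idea done by hand, assuming `prior` is the trainer loaded above and using the standard OpenAI CLIP package for the text encoder. Note that `prior.sample(..., n_samples_per_batch=2)` already performs this selection internally, so this is purely illustrative:

```python
import clip
import torch
import torch.nn.functional as F

# standard OpenAI CLIP text encoder (the same ViT-L/14 the prior above expects)
clip_model, _ = clip.load("ViT-L/14", device=prior.device)
tokens = clip.tokenize(["A corgi wearing sunglasses"]).to(prior.device)

with torch.no_grad():
    text_embed = clip_model.encode_text(tokens).float()               # (1, 768)
    candidates = torch.cat([prior.sample(tokens) for _ in range(2)])  # (2, 768)

# keep whichever candidate lands closest to the prompt in CLIP space
sims = F.cosine_similarity(text_embed, candidates)                    # (2,)
best_embedding = candidates[sims.argmax()].unsqueeze(0)               # (1, 768)
```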
131 |
132 | ---
133 |
134 | ## Training
135 |
136 | ### Overview
137 |
138 | Training the prior is a relatively straightforward process thanks to the Trainer base class. The major step required of you is preparing a dataset in the format that EmbeddingReader expects. Having pre-computed embeddings massively increases training efficiency and is generally recommended, as you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move on to configuration.
139 |
140 | ## Dataset
141 |
142 | To train the prior, it is highly recommended to use precomputed embeddings for the images. To obtain these for a custom dataset, you can leverage [img2dataset](https://github.com/rom1504/img2dataset) to pull images from a list of URLs and [clip_retrieval](https://github.com/rom1504/clip-retrieval#clip-inference) to generate the actual embeddings that are used in the prior's dataloader.
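
As a quick sanity check on the resulting data, here is a small sketch of reading those precomputed embeddings with `EmbeddingReader` (the same reader the prior's dataloaders are built on). The folder paths and the `caption` column are placeholders for whatever your clip_retrieval run produced:

```python
from embedding_reader import EmbeddingReader

# placeholders: point these at the output folders of your clip_retrieval inference run
reader = EmbeddingReader(
    embeddings_folder="my_dataset/img_emb",   # .npy shards of CLIP image embeddings
    metadata_folder="my_dataset/metadata",    # parquet files holding the captions
    meta_columns=["caption"],
    file_format="parquet_npy",
)

print(f"total embeddings available: {reader.count}")

# stream the data in large batches rather than loading everything into memory
for image_embeddings, metadata in reader(batch_size=10_000, start=0, end=reader.count):
    print(image_embeddings.shape, metadata["caption"].iloc[0])
    break
```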
143 |
144 | ## Configuration
145 |
146 | The configuration file allows you to easily track and reproduce experiments. It is a simple JSON file that specifies the architecture, dataset, and training parameters. For more information and specifics please see the configuration README.
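
Because the config is parsed by the same pydantic models the training script uses, you can sanity-check a file from Python before launching a run. The sketch below points at the bundled example config; swap in your own path:

```python
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig

# parses and validates the JSON; a malformed field raises a validation error here
config = TrainDiffusionPriorConfig.from_json_path("configs/train_prior_config.example.json")

print(config.prior.loss_type)                       # e.g. "l2"
print(config.train.epochs, config.data.batch_size)  # training-loop and dataloader settings
```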
147 |
148 | ## Distributed Training
149 |
150 | If you would like to train in a distributed manner, we have opted to leverage Hugging Face's Accelerate library. Accelerate makes it extremely simple to distribute work across multiple GPUs and nodes. All that is required of you is to run its simple CLI configuration tool ([more information here](https://huggingface.co/docs/accelerate/accelerator)).
151 |
152 | ## Evaluation
153 |
154 | There are a variety of metrics available to you when training the prior. You can read a brief description of each in the table below:
155 | | Metric | Description | Comments |
156 | | ----------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
157 | | Online Model Validation | The validation loss associated with your online model. | Ideally validation loss will be as low as possible. Using L2 loss, values of `0.1` or lower are achievable after around 1 billion samples seen. |
158 | | EMA Validation | This metric measures the validation loss associated with your EMA model. | This will likely lag behind your "online" model's validation loss, but should outperform in the long-term. |
159 | | Baseline Similarity | Baseline similarity refers to the similarity between your dataset's prompts and associated image embeddings. This will serve as a guide for your prior's performance in cosine similarity. | Generally `0.3` is considered a good cosine similarity for caption similarity. |
160 | | Similarity With Original Image | This metric will measure the cosine similarity between your prior's predicted image embedding and the actual image that the caption was associated with. This is useful for determining whether your prior is generating embeddings with the right contents. | Values around `0.75`+ are obtainable. This metric should improve rapidly in the early stages of training and plateau with diminishing increases over time. If it takes hundreds of millions of samples to reach above `0.5`/`0.6` similarity--then you are likely suffering from some kind of training error or inefficiency (i.e. not using EMA). |
161 | | Difference From Baseline Similarity | Sometimes it's useful to visualize a metric in another light. This metric shows how your prior's predicted image embeddings match up with the baseline similarity measured in your dataset. | This value should float around `0.0` with some room for variation. After a billion samples seen, values are within `0.01`+/- of `0.0`. If this climbs too high (~>`0.02`) then this may be a sign that your model is overfitting somehow. |
162 | | Similarity With Text | This metric is your bread-and-butter cosine similarity between the predicted image embedding and the original caption given to the prior. Monitoring this metric will be one of your main focuses and it is probably the second most important metric behind your loss. | As mentioned, this value should be close to the baseline similarity. We have observed early rapid increase with diminishing returns as the prior learns to generate valid image embeddings. If this value increases too far beyond the baseline similarity--it could be an indication that your model is overfitting. |
163 | | Similarity With Unrelated Caption | This metric attempts to expose an overfit prior by feeding it arbitrary prompts (from your dataset) and then measuring the similarity of the predicted embedding with some other image. | Early on we found that a poorly trained/modeled prior could effectively fool CLIP into believing that the cosine similarity between two images was high (when in fact the caption and image were completely unrelated). With this in mind--a low value is ideal, anything below `0.1` is probably safe. |
164 |
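For intuition, the baseline similarity above is just the mean cosine similarity between your dataset's caption embeddings and image embeddings. A tiny sketch with stand-in tensors (real data would come from your embedding reader):

```python
import torch
import torch.nn.functional as F
from torch import nn

cos = nn.CosineSimilarity(dim=1, eps=1e-6)

# stand-ins for a batch of L2-normalized CLIP text and image embeddings from your dataset
text_embeds = F.normalize(torch.randn(32, 768), dim=1)
image_embeds = F.normalize(torch.randn(32, 768), dim=1)

baseline_similarity = cos(text_embeds, image_embeds).mean()  # ~0.3 on real data, ~0.0 on random vectors
print(baseline_similarity.item())
```
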
165 | ## Launching the script
166 |
167 | Now that you’ve done all the prep it’s time for the easy part! 🚀
168 |
169 | To actually launch the script, you will either use `accelerate launch train_diffusion_prior.py --config_file <path-to-your-config>` to launch with distributed training via Hugging Face Accelerate, or `python train_diffusion_prior.py --config_file <path-to-your-config>` if you would like to train on your GPU/CPU without Accelerate (if no config is given, the script falls back to `configs/train_prior_config.example.json`).
170 |
171 | ## Checkpointing
172 |
173 | Checkpoints will be saved to the directory specified in your configuration file.
174 |
175 | Additionally, a final checkpoint is saved before running the test split. This file will be saved to the same directory and titled “latest.pth”. This is to avoid problems where your `save_every_seconds` configuration does not overlap with the number of steps required to do a complete pass through the data.
176 |
177 | ## Things To Keep In Mind
178 |
179 | The prior has not been trained for tasks other than the traditional CLIP embedding translation…at least yet.
180 |
181 | As we finalize the replication of unCLIP, there will almost assuredly be experiments attempting to apply the prior network to other tasks.
182 |
183 | With that in mind, you are more or less a pioneer in embedding-translation if you are reading this and attempting something you don’t see documentation for!
184 |
--------------------------------------------------------------------------------
/samples/oxford.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/samples/oxford.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | exec(open('dalle2_pytorch/version.py').read())
3 |
4 | setup(
5 | name = 'dalle2-pytorch',
6 | packages = find_packages(exclude=[]),
7 | include_package_data = True,
8 | entry_points={
9 | 'console_scripts': [
10 | 'dalle2_pytorch = dalle2_pytorch.cli:main',
11 | 'dream = dalle2_pytorch.cli:dream'
12 | ],
13 | },
14 | version = __version__,
15 | license='MIT',
16 | description = 'DALL-E 2',
17 | author = 'Phil Wang',
18 | author_email = 'lucidrains@gmail.com',
19 | long_description_content_type = 'text/markdown',
20 | url = 'https://github.com/lucidrains/dalle2-pytorch',
21 | keywords = [
22 | 'artificial intelligence',
23 | 'deep learning',
24 | 'text to image'
25 | ],
26 | install_requires=[
27 | 'accelerate',
28 | 'click',
29 | 'open-clip-torch>=2.0.0,<3.0.0',
30 | 'clip-anytorch>=2.5.2',
31 | 'coca-pytorch>=0.0.5',
32 | 'ema-pytorch>=0.0.7',
33 | 'einops>=0.7.0',
34 | 'embedding-reader',
35 | 'kornia>=0.5.4',
36 | 'numpy',
37 | 'packaging',
38 | 'pillow',
39 | 'pydantic>=2',
40 | 'pytorch-warmup',
41 | 'resize-right>=0.0.2',
42 | 'rotary-embedding-torch',
43 | 'torch>=1.10',
44 | 'torchvision',
45 | 'tqdm',
46 | 'vector-quantize-pytorch',
47 | 'x-clip>=0.4.4',
48 | 'webdataset>=0.2.5',
49 | 'fsspec>=2022.1.0',
50 | 'torchmetrics[image]>=0.8.0'
51 | ],
52 | classifiers=[
53 | 'Development Status :: 4 - Beta',
54 | 'Intended Audience :: Developers',
55 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
56 | 'License :: OSI Approved :: MIT License',
57 | 'Programming Language :: Python :: 3.6',
58 | ],
59 | )
60 |
--------------------------------------------------------------------------------
/test_data/0.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/0.tar
--------------------------------------------------------------------------------
/test_data/1.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/1.tar
--------------------------------------------------------------------------------
/test_data/2.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/2.tar
--------------------------------------------------------------------------------
/test_data/3.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/3.tar
--------------------------------------------------------------------------------
/test_data/4.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/4.tar
--------------------------------------------------------------------------------
/test_data/5.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/5.tar
--------------------------------------------------------------------------------
/test_data/6.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/6.tar
--------------------------------------------------------------------------------
/test_data/7.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/7.tar
--------------------------------------------------------------------------------
/test_data/8.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/8.tar
--------------------------------------------------------------------------------
/test_data/9.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/DALLE2-pytorch/680dfc4d93b70f9ab23c814a22ca18017a738ef6/test_data/9.tar
--------------------------------------------------------------------------------
/train_diffusion_prior.py:
--------------------------------------------------------------------------------
1 | import click
2 | import torch
3 |
4 | from torch import nn
5 | from typing import List
6 | from accelerate import Accelerator
7 | from accelerate.utils import set_seed
8 | from torch.utils.data import DataLoader
9 | from embedding_reader import EmbeddingReader
10 | from accelerate.utils import dataclasses as accelerate_dataclasses
11 |
12 | from dalle2_pytorch.utils import Timer
13 | from dalle2_pytorch.trackers import Tracker
14 | from dalle2_pytorch import DiffusionPriorTrainer
15 | from dalle2_pytorch.dataloaders import get_reader, make_splits
16 | from dalle2_pytorch.train_configs import (
17 | DiffusionPriorConfig,
18 | DiffusionPriorTrainConfig,
19 | TrainDiffusionPriorConfig,
20 | )
21 |
22 |
23 | # helpers
24 |
25 |
26 | cos = nn.CosineSimilarity(dim=1, eps=1e-6)
27 |
28 |
29 | def exists(val):
30 | return val is not None
31 |
32 |
33 | def all_between(values: list, lower_bound, upper_bound):
34 | for value in values:
35 | if value < lower_bound or value > upper_bound:
36 | return False
37 |
38 | return True
39 |
40 |
41 | def make_model(
42 | prior_config: DiffusionPriorConfig,
43 | train_config: DiffusionPriorTrainConfig,
44 | device: str = None,
45 | accelerator: Accelerator = None,
46 | ):
47 | # create model from config
48 | diffusion_prior = prior_config.create()
49 |
50 | # instantiate the trainer
51 | trainer = DiffusionPriorTrainer(
52 | diffusion_prior=diffusion_prior,
53 | lr=train_config.lr,
54 | wd=train_config.wd,
55 | max_grad_norm=train_config.max_grad_norm,
56 | amp=train_config.amp,
57 | use_ema=train_config.use_ema,
58 | device=device,
59 | accelerator=accelerator,
60 | warmup_steps=train_config.warmup_steps,
61 | )
62 |
63 | return trainer
64 |
65 |
66 | def create_tracker(
67 | accelerator: Accelerator,
68 | config: TrainDiffusionPriorConfig,
69 | config_path: str,
70 | dummy: bool = False,
71 | ) -> Tracker:
72 | tracker_config = config.tracker
73 |
74 | accelerator_config = {
75 | "Distributed": accelerator.distributed_type
76 | != accelerate_dataclasses.DistributedType.NO,
77 | "DistributedType": accelerator.distributed_type,
78 | "NumProcesses": accelerator.num_processes,
79 | "MixedPrecision": accelerator.mixed_precision,
80 | }
81 |
82 | tracker: Tracker = tracker_config.create(
83 | config, accelerator_config, dummy_mode=dummy
84 | )
85 |
86 | tracker.save_config(config_path, config_name="prior_config.json")
87 |
88 | return tracker
89 |
90 |
91 | def pad_gather_reduce(trainer: DiffusionPriorTrainer, x, method="mean"):
92 | """
93 | pad a value or tensor across all processes and gather
94 |
95 | params:
96 | - trainer: a trainer that carries an accelerator object
97 | - x: a number or torch tensor to reduce
98 | - method: "mean", "sum", "max", "min"
99 |
100 | return:
101 |     - the average tensor after masking out 0's
102 | - None if the gather resulted in an empty tensor
103 | """
104 |
105 | assert method in [
106 | "mean",
107 | "sum",
108 | "max",
109 | "min",
110 | ], "This function has limited capabilities [sum, mean, max, min]"
111 |     assert x is not None, "Cannot reduce a None type object"
112 |
113 | # wait for everyone to arrive here before gathering
114 |
115 | if type(x) is not torch.Tensor:
116 | x = torch.tensor([x])
117 |
118 | # verify that the tensor is on the proper device
119 | x = x.to(trainer.device)
120 |
121 | # pad across processes
122 | padded_x = trainer.accelerator.pad_across_processes(x, dim=0)
123 |
124 |     # gather across all processes
125 | gathered_x = trainer.accelerator.gather(padded_x)
126 |
127 | # mask out zeros
128 | masked_x = gathered_x[gathered_x != 0]
129 |
130 | # if the tensor is empty, warn and return None
131 | if len(masked_x) == 0:
132 | click.secho(
133 | f"The call to this method resulted in an empty tensor after masking out zeros. The gathered tensor was this: {gathered_x} and the original value passed was: {x}.",
134 | fg="red",
135 | )
136 | return None
137 |
138 | if method == "mean":
139 | return torch.mean(masked_x)
140 | elif method == "sum":
141 | return torch.sum(masked_x)
142 | elif method == "max":
143 | return torch.max(masked_x)
144 | elif method == "min":
145 | return torch.min(masked_x)
146 |
147 |
148 | def save_trainer(
149 | tracker: Tracker,
150 | trainer: DiffusionPriorTrainer,
151 | is_latest: bool,
152 | is_best: bool,
153 | epoch: int,
154 | samples_seen: int,
155 | best_validation_loss: float,
156 | ):
157 | """
158 | Logs the model with an appropriate method depending on the tracker
159 | """
160 | trainer.accelerator.wait_for_everyone()
161 |
162 | if trainer.accelerator.is_main_process:
163 | click.secho(
164 | f"RANK:{trainer.accelerator.process_index} | Saving Model | Best={is_best} | Latest={is_latest}",
165 | fg="magenta",
166 | )
167 |
168 | tracker.save(
169 | trainer=trainer,
170 | is_best=is_best,
171 | is_latest=is_latest,
172 | epoch=int(epoch),
173 | samples_seen=int(samples_seen),
174 | best_validation_loss=best_validation_loss,
175 | )
176 |
177 |
178 | def recall_trainer(tracker: Tracker, trainer: DiffusionPriorTrainer):
179 | """
180 | Loads the model with an appropriate method depending on the tracker
181 | """
182 |
183 | if trainer.accelerator.is_main_process:
184 | click.secho(f"Loading model from {type(tracker.loader).__name__}", fg="yellow")
185 |
186 | state_dict = tracker.recall()
187 |
188 | trainer.load(state_dict, strict=True)
189 |
190 | return (
191 | int(state_dict.get("epoch", 0)),
192 | state_dict.get("best_validation_loss", 0),
193 | int(state_dict.get("samples_seen", 0)),
194 | )
195 |
196 |
197 | # eval functions
198 |
199 |
200 | def report_validation_loss(
201 | trainer: DiffusionPriorTrainer,
202 | dataloader: DataLoader,
203 | text_conditioned: bool,
204 | use_ema: bool,
205 | tracker: Tracker,
206 | split: str,
207 | tracker_folder: str,
208 | loss_type: str,
209 | ):
210 | """
211 | Compute the validation loss on a given subset of data.
212 | """
213 |
214 | if trainer.accelerator.is_main_process:
215 | click.secho(
216 | f"Measuring performance on {use_ema}-{split} split",
217 | fg="green",
218 | blink=True,
219 | )
220 |
221 | total_loss = torch.zeros(1, dtype=torch.float, device=trainer.device)
222 |
223 | for image_embeddings, text_data in dataloader:
224 | image_embeddings = image_embeddings.to(trainer.device)
225 | text_data = text_data.to(trainer.device)
226 |
227 | input_args = dict(image_embed=image_embeddings)
228 |
229 | if text_conditioned:
230 | input_args = dict(**input_args, text=text_data)
231 | else:
232 | input_args = dict(**input_args, text_embed=text_data)
233 |
234 | if use_ema:
235 | loss = trainer.ema_diffusion_prior(**input_args)
236 | else:
237 | loss = trainer(**input_args)
238 |
239 | total_loss += loss
240 |
241 | # compute the average loss across all processes
242 |
243 | avg_loss = pad_gather_reduce(trainer, total_loss, method="mean")
244 | stats = {f"{tracker_folder}/{loss_type}-loss": avg_loss}
245 |
246 | # print and log results on main process
247 | tracker.log(stats, step=trainer.step.item() + 1)
248 |
249 | return avg_loss
250 |
251 |
252 | def report_cosine_sims(
253 | trainer: DiffusionPriorTrainer,
254 | dataloader: DataLoader,
255 | text_conditioned: bool,
256 | tracker: Tracker,
257 | split: str,
258 | timesteps: int,
259 | tracker_folder: str,
260 | ):
261 | trainer.eval()
262 | if trainer.accelerator.is_main_process:
263 | click.secho(
264 | f"Measuring Cosine-Similarity on {split} split with {timesteps} timesteps",
265 | fg="green",
266 | blink=True,
267 | )
268 |
269 | for test_image_embeddings, text_data in dataloader:
270 | test_image_embeddings = test_image_embeddings.to(trainer.device)
271 | text_data = text_data.to(trainer.device)
272 |
273 | # we are text conditioned, we produce an embedding from the tokenized text
274 | if text_conditioned:
275 | text_embedding, text_encodings = trainer.embed_text(text_data)
276 | text_cond = dict(text_embed=text_embedding, text_encodings=text_encodings)
277 | else:
278 | text_embedding = text_data
279 | text_cond = dict(text_embed=text_embedding)
280 |
281 | # make a copy of the text embeddings for shuffling
282 | text_embed_shuffled = text_embedding.clone()
283 |
284 | # roll the text to simulate "unrelated" captions
285 | rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
286 | text_embed_shuffled = text_embed_shuffled[rolled_idx]
287 | text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm(
288 | dim=1, keepdim=True
289 | )
290 |
291 | if text_conditioned:
292 | text_encodings_shuffled = text_encodings[rolled_idx]
293 | else:
294 | text_encodings_shuffled = None
295 |
296 | text_cond_shuffled = dict(
297 | text_embed=text_embed_shuffled, text_encodings=text_encodings_shuffled
298 | )
299 |
300 | # prepare the text embedding
301 | text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
302 |
303 | # prepare image embeddings
304 | test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(
305 | dim=1, keepdim=True
306 | )
307 |
308 | # predict on the unshuffled text embeddings
309 | predicted_image_embeddings = trainer.p_sample_loop(
310 | test_image_embeddings.shape,
311 | text_cond,
312 | timesteps=timesteps,
313 | )
314 |
315 | predicted_image_embeddings = (
316 | predicted_image_embeddings
317 | / predicted_image_embeddings.norm(dim=1, keepdim=True)
318 | )
319 |
320 | # predict on the shuffled embeddings
321 | predicted_unrelated_embeddings = trainer.p_sample_loop(
322 | test_image_embeddings.shape,
323 | text_cond_shuffled,
324 | timesteps=timesteps,
325 | )
326 |
327 | predicted_unrelated_embeddings = (
328 | predicted_unrelated_embeddings
329 | / predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
330 | )
331 |
332 | # calculate similarities
333 | orig_sim = pad_gather_reduce(
334 | trainer, cos(text_embed, test_image_embeddings), method="mean"
335 | )
336 | pred_sim = pad_gather_reduce(
337 | trainer, cos(text_embed, predicted_image_embeddings), method="mean"
338 | )
339 | unrel_sim = pad_gather_reduce(
340 | trainer, cos(text_embed, predicted_unrelated_embeddings), method="mean"
341 | )
342 | pred_img_sim = pad_gather_reduce(
343 | trainer,
344 | cos(test_image_embeddings, predicted_image_embeddings),
345 | method="mean",
346 | )
347 |
348 | stats = {
349 | f"{tracker_folder}/baseline similarity [steps={timesteps}]": orig_sim,
350 | f"{tracker_folder}/similarity with text [steps={timesteps}]": pred_sim,
351 | f"{tracker_folder}/similarity with original image [steps={timesteps}]": pred_img_sim,
352 | f"{tracker_folder}/similarity with unrelated caption [steps={timesteps}]": unrel_sim,
353 | f"{tracker_folder}/difference from baseline similarity [steps={timesteps}]": pred_sim
354 | - orig_sim,
355 | }
356 |
357 | tracker.log(stats, step=trainer.step.item() + 1)
358 |
359 |
360 | def eval_model(
361 | trainer: DiffusionPriorTrainer,
362 | dataloader: DataLoader,
363 | text_conditioned: bool,
364 | split: str,
365 | tracker: Tracker,
366 | use_ema: bool,
367 | report_cosine: bool,
368 | report_loss: bool,
369 | timesteps: List[int],
370 | loss_type: str = None,
371 | ):
372 | """
373 | Run evaluation on a model and track metrics
374 |
375 | returns: loss if requested
376 | """
377 | trainer.eval()
378 |
379 | use_ema = "ema" if use_ema else "online"
380 | tracker_folder = f"metrics/{use_ema}-{split}"
381 |
382 |     # determine if valid timesteps are passed
383 |
384 | min_timesteps = trainer.accelerator.unwrap_model(
385 | trainer.diffusion_prior
386 | ).sample_timesteps
387 | max_timesteps = trainer.accelerator.unwrap_model(
388 | trainer.diffusion_prior
389 | ).noise_scheduler.num_timesteps
390 |
391 | assert all_between(
392 | timesteps, lower_bound=min_timesteps, upper_bound=max_timesteps
393 | ), f"all timesteps values must be between {min_timesteps} and {max_timesteps}: got {timesteps}"
394 |
395 | # measure cosine metrics across various eta and timesteps
396 |
397 | if report_cosine:
398 | for timestep in timesteps:
399 | report_cosine_sims(
400 | trainer,
401 | dataloader=dataloader,
402 | text_conditioned=text_conditioned,
403 | tracker=tracker,
404 | split=split,
405 | timesteps=timestep,
406 | tracker_folder=tracker_folder,
407 | )
408 |
409 |     # measure loss on a separate split of data
410 |
411 | if report_loss:
412 | loss = report_validation_loss(
413 | trainer=trainer,
414 | dataloader=dataloader,
415 | text_conditioned=text_conditioned,
416 | use_ema=use_ema,
417 | tracker=tracker,
418 | split=split,
419 | tracker_folder=tracker_folder,
420 | loss_type=loss_type,
421 | )
422 |
423 | return loss
424 |
425 |
426 | # training script
427 |
428 |
429 | def train(
430 | trainer: DiffusionPriorTrainer,
431 | tracker: Tracker,
432 | train_loader: DataLoader,
433 | eval_loader: DataLoader,
434 | test_loader: DataLoader,
435 | config: DiffusionPriorTrainConfig,
436 | ):
437 | # init timers
438 | save_timer = Timer() # when to save
439 | samples_timer = Timer() # samples/sec
440 | validation_profiler = Timer() # how long is validation taking
441 |     validation_countdown = Timer()  # when to perform evaluation
442 |
443 | # keep track of best validation loss
444 |
445 | best_validation_loss = config.train.best_validation_loss
446 | samples_seen = config.train.num_samples_seen
447 |
448 | # do training
449 |
450 | start_epoch = config.train.current_epoch
451 |
452 | for epoch in range(start_epoch, config.train.epochs):
453 | # if we finished out an old epoch, reset the distribution to be a full epoch
454 | tracker.log({"tracking/epoch": epoch}, step=trainer.step.item())
455 |
456 | if train_loader.dataset.get_start() > 0 and epoch == start_epoch+1:
457 | if trainer.accelerator.is_main_process:
458 | click.secho(f"Finished resumed epoch...resetting dataloader.")
459 | train_loader.dataset.set_start(0)
460 |
461 | for img, txt in train_loader:
462 | # setup things every step
463 |
464 | trainer.train()
465 | current_step = trainer.step.item()
466 | samples_timer.reset()
467 |
468 | # place data on device
469 |
470 | img = img.to(trainer.device)
471 | txt = txt.to(trainer.device)
472 |
473 | # pass to model
474 |
475 | loss = trainer(text=txt, image_embed=img)
476 |
477 | # perform backprop & apply EMA updates
478 |
479 | trainer.update()
480 |
481 | # gather info about training step
482 |
483 | all_loss = pad_gather_reduce(trainer, loss, method="mean")
484 | num_samples = pad_gather_reduce(trainer, len(txt), method="sum")
485 | samples_per_sec = num_samples / samples_timer.elapsed()
486 | samples_seen += num_samples
487 | ema_decay = trainer.ema_diffusion_prior.get_current_decay()
488 |
489 | # log
490 |
491 | tracker.log(
492 | {
493 | "tracking/samples-sec": samples_per_sec,
494 | "tracking/samples-seen": samples_seen,
495 | "tracking/ema-decay": ema_decay,
496 | f"tracking/training-{config.prior.loss_type}": all_loss,
497 | },
498 | step=current_step,
499 | )
500 |
501 | # Metric Tracking @ Timed Intervals
502 |
503 | eval_delta = pad_gather_reduce(
504 | trainer, validation_countdown.elapsed(), method="min"
505 | )
506 |
507 |             if eval_delta is not None and eval_delta > config.data.eval_every_seconds:
508 | # begin timing how long this takes
509 |
510 | validation_profiler.reset()
511 |
512 | # package kwargs for evaluation
513 |
514 | eval_kwargs = {
515 | "trainer": trainer,
516 | "tracker": tracker,
517 | "text_conditioned": config.prior.condition_on_text_encodings,
518 | "timesteps": config.train.eval_timesteps,
519 | }
520 |
521 | # ONLINE MODEL : COSINE : LOSS : VALIDATION SPLIT
522 |
523 | eval_model(
524 | dataloader=eval_loader,
525 | loss_type=config.prior.loss_type,
526 | split="validation",
527 | use_ema=False,
528 | report_cosine=False,
529 | report_loss=True,
530 | **eval_kwargs,
531 | )
532 |
533 | # EMA MODEL : COSINE : LOSS : VALIDATION DATA
534 |
535 | ema_val_loss = eval_model(
536 | dataloader=eval_loader,
537 | loss_type=config.prior.loss_type,
538 | split="validation",
539 | use_ema=True,
540 | report_cosine=True,
541 | report_loss=True,
542 | **eval_kwargs,
543 | )
544 |
545 | tracker.log(
546 | {
547 | "tracking/validation length (minutes)": validation_profiler.elapsed()
548 | / 60
549 | }
550 | )
551 |
552 | # check if the ema validation is the lowest seen yet
553 |
554 | if ema_val_loss < best_validation_loss:
555 | best_validation_loss = ema_val_loss
556 |
557 | # go save the model as best
558 |
559 | save_trainer(
560 | trainer=trainer,
561 | tracker=tracker,
562 | is_best=True,
563 | is_latest=False,
564 | samples_seen=samples_seen,
565 | epoch=epoch,
566 | best_validation_loss=best_validation_loss,
567 | )
568 |
569 |                 # reset timer for validation
570 |
571 | validation_countdown.reset()
572 |
573 | elif eval_delta is None:
574 | click.secho(
575 | f"Error occured reading the eval time on rank: {trainer.device}",
576 | fg="yellow",
577 | )
578 |
579 | # save as latest model on schedule
580 |
581 | save_delta = pad_gather_reduce(trainer, save_timer.elapsed(), method="min")
582 |
583 |             if save_delta is not None and save_delta >= config.train.save_every_seconds:
584 | save_trainer(
585 | trainer=trainer,
586 | tracker=tracker,
587 | is_best=False,
588 | is_latest=True,
589 | samples_seen=samples_seen,
590 | epoch=epoch,
591 | best_validation_loss=best_validation_loss,
592 | )
593 |
594 | save_timer.reset()
595 |
596 | elif save_delta is None:
597 | click.secho(
598 | f"Error occured reading the save time on rank: {trainer.device}",
599 | fg="yellow",
600 | )
601 |
602 | # evaluate on test data
603 |
604 | if trainer.accelerator.is_main_process:
605 | click.secho(f"Starting Test", fg="red")
606 |
607 | # save one last time as latest before beginning validation
608 |
609 | save_trainer(
610 | tracker=tracker,
611 | trainer=trainer,
612 | is_best=False,
613 | is_latest=True,
614 | samples_seen=samples_seen,
615 | epoch=epoch,
616 | best_validation_loss=best_validation_loss,
617 | )
618 |
619 | test_loss = eval_model(
620 | trainer=trainer,
621 | dataloader=test_loader,
622 | text_conditioned=config.prior.condition_on_text_encodings,
623 | split="test",
624 | tracker=tracker,
625 | use_ema=True,
626 | report_cosine=False,
627 | report_loss=True,
628 | timesteps=config.train.eval_timesteps,
629 | loss_type=config.prior.loss_type,
630 | )
631 |
632 | if test_loss < best_validation_loss:
633 | best_validation_loss = test_loss
634 |
635 | # go save the model as best
636 |
637 | save_trainer(
638 | trainer=trainer,
639 | tracker=tracker,
640 | is_best=True,
641 | is_latest=False,
642 | samples_seen=samples_seen,
643 | epoch=epoch,
644 | best_validation_loss=test_loss,
645 | )
646 |
647 |
648 | def initialize_training(config_file, accelerator):
649 | """
650 | Parse the configuration file, and prepare everything necessary for training
651 | """
652 | # load the configuration file
653 | if accelerator.is_main_process:
654 | click.secho(f"Loading configuration from {config_file}", fg="green")
655 |
656 | config = TrainDiffusionPriorConfig.from_json_path(config_file)
657 |
658 | # seed
659 |
660 | set_seed(config.train.random_seed)
661 |
662 | # get a device
663 |
664 | device = accelerator.device
665 |
666 | # make the trainer (will automatically distribute if possible & configured)
667 |
668 | trainer: DiffusionPriorTrainer = make_model(
669 | config.prior, config.train, device, accelerator
670 | ).to(device)
671 |
672 | # create a tracker
673 |
674 | tracker = create_tracker(
675 | accelerator, config, config_file, dummy=accelerator.process_index != 0
676 | )
677 |
678 |     # reload from checkpoint
679 |
680 | if tracker.can_recall:
681 | current_epoch, best_validation_loss, samples_seen = recall_trainer(
682 | tracker=tracker, trainer=trainer
683 | )
684 |
685 | # display best values
686 | if trainer.accelerator.is_main_process:
687 | click.secho(f"Current Epoch: {current_epoch} | Best Val Loss: {best_validation_loss} | Samples Seen: {samples_seen}", fg="yellow")
688 |
689 | # update config to reflect recalled values
690 | config.train.num_samples_seen = samples_seen
691 | config.train.current_epoch = current_epoch
692 | config.train.best_validation_loss = best_validation_loss
693 |
694 | # fetch and prepare data
695 |
696 | if trainer.accelerator.is_main_process:
697 | click.secho("Grabbing data...", fg="blue", blink=True)
698 |
699 | trainer.accelerator.wait_for_everyone()
700 | img_reader = get_reader(
701 | text_conditioned=trainer.text_conditioned,
702 | img_url=config.data.image_url,
703 | meta_url=config.data.meta_url,
704 | )
705 |
706 | # calculate start point within epoch
707 |
708 | trainer.accelerator.wait_for_everyone()
709 |
710 | train_loader, eval_loader, test_loader = make_splits(
711 | text_conditioned=trainer.text_conditioned,
712 | batch_size=config.data.batch_size,
713 | num_data_points=config.data.num_data_points,
714 | train_split=config.data.splits.train,
715 | eval_split=config.data.splits.val,
716 | image_reader=img_reader,
717 | rank=accelerator.state.process_index,
718 | world_size=accelerator.state.num_processes,
719 | start=0,
720 | )
721 |
722 | # update the start point to finish out the epoch on a resumed run
723 |
724 | if tracker.can_recall:
725 | samples_seen = config.train.num_samples_seen
726 | length = (
727 | config.data.num_data_points
728 | if samples_seen <= img_reader.count
729 | else img_reader.count
730 | )
731 | scaled_samples = length * config.train.current_epoch
732 | start_point = (
733 | scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen
734 | )
735 |
736 | if trainer.accelerator.is_main_process:
737 | click.secho(f"Resuming at sample: {start_point}", fg="yellow")
738 |
739 | train_loader.dataset.set_start(start_point)
740 |
741 | # start training
742 |
743 | if trainer.accelerator.is_main_process:
744 | click.secho(
745 | f"Beginning Prior Training : Distributed={accelerator.state.distributed_type != accelerate_dataclasses.DistributedType.NO}",
746 | fg="yellow",
747 | )
748 |
749 | train(
750 | trainer=trainer,
751 | tracker=tracker,
752 | train_loader=train_loader,
753 | eval_loader=eval_loader,
754 | test_loader=test_loader,
755 | config=config,
756 | )
757 |
758 |
759 | @click.command()
760 | @click.option("--config_file", default="configs/train_prior_config.example.json")
761 | def main(config_file):
762 | # start HFA
763 | accelerator = Accelerator()
764 |
765 | # setup training
766 | initialize_training(config_file, accelerator)
767 |
768 |
769 | if __name__ == "__main__":
770 | main()
771 |
--------------------------------------------------------------------------------