├── README.md ├── assets ├── foxchain.png ├── image_editing.png ├── imagebart_poster.pdf ├── modelfigure.png └── sample-001.jpg ├── configs ├── ffhq │ ├── 2_scales │ │ ├── ffhq-custom-scale1.yaml │ │ └── ffhq-custom-scale2.yaml │ ├── 4_scales │ │ ├── ffhq-geometric-scale1.yaml │ │ ├── ffhq-geometric-scale2.yaml │ │ ├── ffhq-geometric-scale3.yaml │ │ └── ffhq-geometric-scale4.yaml │ └── ffhq_4-scales_joint-training.yaml ├── imagenet │ ├── 4_scales │ │ ├── imagenet_geometric_scale1.yaml │ │ ├── imagenet_geometric_scale2.yaml │ │ ├── imagenet_geometric_scale3.yaml │ │ └── imagenet_geometric_scale4.yaml │ └── 5_scales │ │ ├── imagenet-custom-scale1.yaml │ │ ├── imagenet-custom-scale2.yaml │ │ ├── imagenet-custom-scale3.yaml │ │ ├── imagenet-custom-scale4.yaml │ │ └── imagenet-custom-scale5.yaml ├── lsun-beds │ ├── lsun-beds-scale1.yaml │ ├── lsun-beds-scale2.yaml │ └── lsun-beds-scale3.yaml ├── lsun-cats │ ├── lsun-cats-scale1.yaml │ ├── lsun-cats-scale2.yaml │ └── lsun-cats-scale3.yaml ├── lsun-churches │ ├── lsun-churches-scale1.yaml │ ├── lsun-churches-scale2.yaml │ └── lsun-churches-scale3.yaml └── sampling │ ├── ffhq │ ├── ffhq_2_scales_custom.yaml │ └── ffhq_4_scales_geometric.yaml │ ├── imagenet │ ├── imagenet_4_scales_geometric.yaml │ └── imagenet_5_scales_custom.yaml │ └── lsun │ ├── beds_3_scales.yaml │ ├── cats_3_scales.yaml │ └── churches_3_scales.yaml ├── data ├── DejaVuSans.ttf ├── ffhq_schedule_vs_metric.p ├── ffhqtrain.txt ├── ffhqvalidation.txt ├── imagenet_ids2labels.yaml ├── in_schedule_vs_metric.p └── vqgan_indices │ ├── bedroom_indices.npy │ ├── cat_indices.npy │ ├── church_indices.npy │ ├── ffhq_indices.npy │ └── imagenet_indices.npy ├── environment.yaml ├── imagebart ├── __init__.py ├── data │ ├── __init__.py │ ├── base.py │ └── lsun.py ├── lr_scheduler.py ├── models │ ├── __init__.py │ ├── diffusion.py │ └── vqgan.py ├── modules │ ├── __init__.py │ ├── betas.py │ ├── ema.py │ ├── transformer │ │ ├── __init__.py │ │ ├── mingpt.py │ │ ├── vit.py │ │ └── warper.py │ └── xtransformers │ │ ├── __init__.py │ │ ├── autoregressive_wrapper.py │ │ ├── positional_embeddings.py │ │ └── x_transformer.py └── util.py ├── main.py ├── scripts ├── inpaint_imagebart.py └── sample_imagebart.py └── setup.py /README.md: -------------------------------------------------------------------------------- 1 | # ImageBART 2 | #### [NeurIPS 2021](https://nips.cc/) 3 | 4 | ![teaser](assets/modelfigure.png) 5 |
6 | [Patrick Esser](https://github.com/pesser)\*, 7 | [Robin Rombach](https://github.com/rromb)\*, 8 | [Andreas Blattmann](https://github.com/ablattmann)\*, 9 | [Björn Ommer](https://ommer-lab.com/)
10 | \* equal contribution 11 | 12 | [arXiv](https://arxiv.org/abs/2108.08827) | [BibTeX](#bibtex) | [Poster](assets/imagebart_poster.pdf) 13 | 14 | ## Requirements 15 | A suitable [conda](https://conda.io/) environment named `imagebart` can be created 16 | and activated with: 17 | 18 | ``` 19 | conda env create -f environment.yaml 20 | conda activate imagebart 21 | ``` 22 | 23 | ## Get the Models 24 | 25 | We provide pretrained weights and hyperparameters for models trained on the following datasets: 26 | 27 | * FFHQ: 28 | * [4 scales, geometric noise schedule](https://ommer-lab.com/files/ffhq_4_scales_geometric.zip): `wget -c https://ommer-lab.com/files/ffhq_4_scales_geometric.zip` 29 | * [2 scales, custom noise schedule](https://ommer-lab.com/files/ffhq_2_scales_custom.zip): `wget -c https://ommer-lab.com/files/ffhq_2_scales_custom.zip` 30 | * LSUN, 3 scales, custom noise schedules: 31 | * [Churches](https://ommer-lab.com/files/churches_3_scales.zip): `wget -c https://ommer-lab.com/files/churches_3_scales.zip` 32 | * [Bedrooms](https://ommer-lab.com/files/bedrooms_3_scales.zip): `wget -c https://ommer-lab.com/files/bedrooms_3_scales.zip` 33 | * [Cats](https://ommer-lab.com/files/cats_3_scales.zip): `wget -c https://ommer-lab.com/files/cats_3_scales.zip` 34 | * Class-conditional ImageNet: 35 | * [5 scales, custom noise schedule](https://ommer-lab.com/files/cin_5_scales_custom.zip): `wget -c https://ommer-lab.com/files/cin_5_scales_custom.zip` 36 | * [4 scales, geometric noise schedule](https://ommer-lab.com/files/cin_4_scales_geometric.zip): `wget -c https://ommer-lab.com/files/cin_4_scales_geometric.zip` 37 | 38 | Download the respective files and extract their contents to a directory `./models/`. 39 | 40 | Moreover, we provide all the required VQGANs as a .zip at [https://ommer-lab.com/files/vqgan.zip](https://ommer-lab.com/files/vqgan.zip), 41 | the contents of which have to be extracted to `./vqgan/`. 42 | 43 | ## Get the Data 44 | Running the training configs or the [inpainting script](scripts/inpaint_imagebart.py) requires 45 | a dataset available locally. For ImageNet and FFHQ, see this repo's parent repository, [taming-transformers](https://github.com/CompVis/taming-transformers). 46 | The LSUN datasets can be conveniently downloaded via the script available [here](https://github.com/fyu/lsun). 47 | We performed a custom split into training and validation images, and provide the corresponding filenames 48 | at [https://ommer-lab.com/files/lsun.zip](https://ommer-lab.com/files/lsun.zip). 49 | After downloading, extract them to `./data/lsun`. The beds/cats/churches subsets should 50 | also be placed/symlinked at `./data/lsun/bedrooms`/`./data/lsun/cats`/`./data/lsun/churches`, respectively. 51 | 52 | ## Inference 53 | 54 | ### Unconditional Sampling 55 | We provide a script for sampling from unconditional models trained on the LSUN (churches, bedrooms, cats) and FFHQ datasets. 56 | 57 | #### FFHQ 58 | 59 | On the FFHQ dataset, we provide two pretrained models: one with a chain of length 4 and a geometric noise schedule as proposed by Sohl-Dickstein et al. [[1]](#references), and another with a chain of length 2 and a custom schedule. 
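Before running the sampling command below, the corresponding checkpoints have to be in place. As a convenience, here is a minimal sketch for fetching the 4-scale geometric FFHQ weights and the VQGAN checkpoints, assuming `wget` and `unzip` are available and that the archives unpack directly into the target folders:

```shell script
# URLs and target folders as described in "Get the Models"
wget -c https://ommer-lab.com/files/ffhq_4_scales_geometric.zip
wget -c https://ommer-lab.com/files/vqgan.zip
mkdir -p models vqgan
unzip ffhq_4_scales_geometric.zip -d models/   # pretrained transition models -> ./models/
unzip vqgan.zip -d vqgan/                      # first-stage VQGANs -> ./vqgan/
```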
60 | Sampling from these models can be started with 61 | ```shell script 62 | CUDA_VISIBLE_DEVICES=<gpu_id> streamlit run scripts/sample_imagebart.py configs/sampling/ffhq/ 63 | ``` 64 | 65 | #### LSUN 66 | For the models trained on the LSUN datasets, use 67 | ```shell script 68 | CUDA_VISIBLE_DEVICES=<gpu_id> streamlit run scripts/sample_imagebart.py configs/sampling/lsun/ 69 | ``` 70 | 71 | ### Class-Conditional Sampling on ImageNet 72 | 73 | 74 | To sample from class-conditional ImageNet models, use 75 | ```shell script 76 | CUDA_VISIBLE_DEVICES=<gpu_id> streamlit run scripts/sample_imagebart.py configs/sampling/imagenet/ 77 | ``` 78 | 79 | ### Image Editing with Unconditional Models 80 | 81 | We also provide a script for image editing with our unconditional models. For our FFHQ model with the geometric schedule, this can be started with 82 | ```shell script 83 | CUDA_VISIBLE_DEVICES=<gpu_id> streamlit run scripts/inpaint_imagebart.py configs/sampling/ffhq/ffhq_4_scales_geometric.yaml 84 | ``` 85 | resulting in samples similar to the following. 86 | ![teaser](assets/image_editing.png) 87 | 88 | 89 | ## Training 90 | In general, there are two options for training the autoregressive transition probabilities of the 91 | reverse Markov chain: (i) train them jointly, taking into account a weighting of the 92 | individual scale contributions, or (ii) train them independently, which means that each 93 | training process optimizes a single transition and the scales must be stacked after training. 94 | We conduct most of our experiments using the latter option, but provide configurations for both cases. 95 | 96 | ### Training Scales Independently 97 | For training scales independently, each transition requires a separate optimization process, which can 98 | be started via 99 | 100 | ``` 101 | CUDA_VISIBLE_DEVICES=<gpu_id> python main.py --base configs/<dataset>/<config>.yaml -t --gpus 0, 102 | ``` 103 | 104 | We provide training configs for a four-scale training on FFHQ using a geometric schedule, 105 | a four-scale geometric training on ImageNet and various three-scale experiments on LSUN. 106 | See also the overview of our [pretrained models](#get-the-models). 107 | 108 | 109 | ### Training Scales Jointly 110 | 111 | For completeness, we also provide a config to run a joint training with 4 scales on FFHQ. 112 | Training can be started by running 113 | 114 | ``` 115 | CUDA_VISIBLE_DEVICES=<gpu_id> python main.py --base configs/ffhq/ffhq_4-scales_joint-training.yaml -t --gpus 0, 116 | ``` 117 | 118 | 119 | ## Shout-Outs 120 | Many thanks to all who make their work and implementations publicly available. 121 | For this work, these were in particular: 122 | 123 | - The extremely clear and extensible encoder-decoder transformer implementations by [lucidrains](https://github.com/lucidrains): 124 | https://github.com/lucidrains/x-transformers 125 | - Emiel Hoogeboom et al.'s paper on multinomial diffusion and argmax flows: https://arxiv.org/abs/2102.05379 126 | 127 | 128 | ![teaser](assets/foxchain.png) 129 | 130 | ## References 131 | 132 | [1] Sohl-Dickstein, J., Weiss, E., Maheswaranathan, N. & Ganguli, S. (2015). Deep Unsupervised Learning using Nonequilibrium Thermodynamics. 
Proceedings of the 32nd International Conference on Machine Learning 133 | 134 | ## Bibtex 135 | 136 | ``` 137 | @article{DBLP:journals/corr/abs-2108-08827, 138 | author = {Patrick Esser and 139 | Robin Rombach and 140 | Andreas Blattmann and 141 | Bj{\"{o}}rn Ommer}, 142 | title = {ImageBART: Bidirectional Context with Multinomial Diffusion for Autoregressive 143 | Image Synthesis}, 144 | journal = {CoRR}, 145 | volume = {abs/2108.08827}, 146 | year = {2021} 147 | } 148 | ``` -------------------------------------------------------------------------------- /assets/foxchain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/assets/foxchain.png -------------------------------------------------------------------------------- /assets/image_editing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/assets/image_editing.png -------------------------------------------------------------------------------- /assets/imagebart_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/assets/imagebart_poster.pdf -------------------------------------------------------------------------------- /assets/modelfigure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/assets/modelfigure.png -------------------------------------------------------------------------------- /assets/sample-001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/assets/sample-001.jpg -------------------------------------------------------------------------------- /configs/ffhq/2_scales/ffhq-custom-scale1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 2 8 | single_scale: 1 9 | top_k: 548 10 | alpha: 0.0 11 | redraw_prob: ffhq_bernoulli_PSIM 12 | use_ema: true 13 | 14 | scheduler_config: 15 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 16 | params: 17 | verbosity_interval: 0 18 | warm_up_steps: 10000 19 | max_decay_steps: 1500001 20 | lr_start: 2.5e-06 21 | lr_max: 0.0001 22 | lr_min: 1.0e-08 23 | transformer_config: 24 | target: imagebart.modules.xtransformers.x_transformer.XTransformer 25 | params: 26 | wrap_decoder: false 27 | dim: 1152 28 | enc_num_tokens: 548 29 | enc_depth: 32 30 | enc_heads: 16 31 | enc_max_seq_len: 257 32 | dec_num_tokens: 548 33 | dec_depth: 6 34 | dec_heads: 16 35 | tie_token_emb: false 36 | dec_max_seq_len: 256 37 | first_stage_config: 38 | target: imagebart.models.vqgan.VQGANWrapper 39 | params: 40 | ckpt_path: vqgan/vqgan-ffhq.ckpt 41 | remap: data/vqgan_indices/ffhq_indices.npy 42 | sane_index_shape: true 43 | embed_dim: 256 44 | n_embed: 1024 45 | ddconfig: 46 | double_z: false 47 | z_channels: 256 48 | resolution: 256 49 | in_channels: 3 50 | out_ch: 3 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 1 55 | - 2 56 | - 2 57 | - 4 58 | num_res_blocks: 2 
59 | attn_resolutions: 60 | - 16 61 | dropout: 0.0 62 | lossconfig: 63 | target: taming.modules.losses.vqperceptual.DummyLoss 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 18 69 | num_workers: 32 70 | wrap: false 71 | train: 72 | target: taming.data.faceshq.FFHQTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: taming.data.faceshq.FFHQValidation 77 | params: 78 | size: 256 79 | -------------------------------------------------------------------------------- /configs/ffhq/2_scales/ffhq-custom-scale2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DecoderOnlyDenoiser 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 2 8 | single_scale: 2 9 | top_k: 548 10 | alpha: 1.0 11 | redraw_prob: ffhq_bernoulli_PSIM 12 | use_ema: true 13 | scheduler_config: 14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 15 | params: 16 | verbosity_interval: 0 17 | warm_up_steps: 10000 18 | max_decay_steps: 1500001 19 | lr_start: 2.5e-06 20 | lr_max: 0.0001 21 | lr_min: 1.0e-08 22 | transformer_config: 23 | target: imagebart.modules.transformer.mingpt.GPT 24 | params: 25 | vocab_size: 548 26 | block_size: 256 27 | n_layer: 36 28 | n_head: 16 29 | n_embd: 1216 30 | first_stage_config: 31 | target: imagebart.models.vqgan.VQGANWrapper 32 | params: 33 | ckpt_path: vqgan/vqgan-ffhq.ckpt 34 | remap: data/vqgan_indices/ffhq_indices.npy 35 | sane_index_shape: true 36 | embed_dim: 256 37 | n_embed: 1024 38 | ddconfig: 39 | double_z: false 40 | z_channels: 256 41 | resolution: 256 42 | in_channels: 3 43 | out_ch: 3 44 | ch: 128 45 | ch_mult: 46 | - 1 47 | - 1 48 | - 2 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: 53 | - 16 54 | dropout: 0.0 55 | lossconfig: 56 | target: taming.modules.losses.vqperceptual.DummyLoss 57 | 58 | data: 59 | target: main.DataModuleFromConfig 60 | params: 61 | batch_size: 18 62 | wrap: false 63 | train: 64 | target: taming.data.faceshq.FFHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: taming.data.faceshq.FFHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /configs/ffhq/4_scales/ffhq-geometric-scale1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: "image" 6 | monitor: "val/loss" 7 | n_scales: 4 8 | single_scale: 1 9 | top_k: 548 10 | alpha: 0.0 11 | redraw_prob: geometric 12 | use_ema: True 13 | 14 | scheduler_config: 15 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 16 | params: 17 | verbosity_interval: 0 # 0 or negative to disable 18 | warm_up_steps: 10000 19 | max_decay_steps: 1500001 20 | lr_start: 2.5e-6 21 | lr_max: 1.0e-4 22 | lr_min: 1.0e-8 23 | 24 | transformer_config: 25 | target: imagebart.modules.xtransformers.x_transformer.ResidualScaledXTransformer 26 | params: 27 | scale_pos: 0 28 | n_scales: 4 29 | xt_start: 1 30 | xt_size: 256 # predict x_{t-1} 31 | wrap_decoder: False 32 | dim: 752 33 | enc_num_tokens: 548 34 | enc_depth: 18 35 | enc_heads: 16 36 | enc_max_seq_len: 257 37 | dec_num_tokens: 548 38 | dec_depth: 6 39 | dec_heads: 16 40 | tie_token_emb: False 41 | dec_max_seq_len: 256 42 | 43 | first_stage_config: 44 | target: imagebart.models.vqgan.VQGANWrapper 45 | params: 46 | 
ckpt_path: vqgan/vqgan-ffhq.ckpt 47 | remap: "data/vqgan_indices/ffhq_indices.npy" 48 | sane_index_shape: True 49 | embed_dim: 256 50 | n_embed: 1024 51 | ddconfig: 52 | double_z: false 53 | z_channels: 256 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 59 | num_res_blocks: 2 60 | attn_resolutions: [ 16 ] 61 | dropout: 0.0 62 | lossconfig: 63 | target: taming.modules.losses.vqperceptual.DummyLoss 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 16 69 | num_workers: 32 70 | wrap: False 71 | train: 72 | target: taming.data.faceshq.FFHQTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: taming.data.faceshq.FFHQValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 1000 86 | max_images: 4 87 | increase_log_steps: False 88 | trainer: 89 | benchmark: True 90 | -------------------------------------------------------------------------------- /configs/ffhq/4_scales/ffhq-geometric-scale2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: "image" 6 | monitor: "val/loss" 7 | n_scales: 4 8 | single_scale: 2 9 | top_k: 548 10 | alpha: 0.0 11 | redraw_prob: geometric 12 | use_ema: True 13 | 14 | scheduler_config: 15 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 16 | params: 17 | verbosity_interval: 0 # 0 or negative to disable 18 | warm_up_steps: 10000 19 | max_decay_steps: 1500001 20 | lr_start: 2.5e-6 21 | lr_max: 1.0e-4 22 | lr_min: 1.0e-8 23 | 24 | transformer_config: 25 | target: imagebart.modules.xtransformers.x_transformer.ResidualScaledXTransformer 26 | params: 27 | scale_pos: 0 28 | n_scales: 4 29 | xt_start: 1 30 | xt_size: 256 # predict x_{t-1} 31 | wrap_decoder: False 32 | dim: 752 33 | enc_num_tokens: 548 34 | enc_depth: 18 35 | enc_heads: 16 36 | enc_max_seq_len: 257 37 | dec_num_tokens: 548 38 | dec_depth: 6 39 | dec_heads: 16 40 | tie_token_emb: False 41 | dec_max_seq_len: 256 42 | 43 | first_stage_config: 44 | target: imagebart.models.vqgan.VQGANWrapper 45 | params: 46 | ckpt_path: vqgan/vqgan-ffhq.ckpt 47 | remap: "data/vqgan_indices/ffhq_indices.npy" 48 | sane_index_shape: True 49 | embed_dim: 256 50 | n_embed: 1024 51 | ddconfig: 52 | double_z: false 53 | z_channels: 256 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 59 | num_res_blocks: 2 60 | attn_resolutions: [ 16 ] 61 | dropout: 0.0 62 | lossconfig: 63 | target: taming.modules.losses.vqperceptual.DummyLoss 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 16 69 | num_workers: 32 70 | wrap: False 71 | train: 72 | target: taming.data.faceshq.FFHQTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: taming.data.faceshq.FFHQValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 1000 86 | max_images: 4 87 | increase_log_steps: False 88 | trainer: 89 | benchmark: True 90 | -------------------------------------------------------------------------------- /configs/ffhq/4_scales/ffhq-geometric-scale3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | 
base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: "image" 6 | monitor: "val/loss" 7 | n_scales: 4 8 | single_scale: 3 9 | top_k: 548 10 | alpha: 0.0 11 | redraw_prob: geometric 12 | use_ema: True 13 | 14 | scheduler_config: 15 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 16 | params: 17 | verbosity_interval: 0 # 0 or negative to disable 18 | warm_up_steps: 10000 19 | max_decay_steps: 1500001 20 | lr_start: 2.5e-6 21 | lr_max: 1.0e-4 22 | lr_min: 1.0e-8 23 | 24 | transformer_config: 25 | target: imagebart.modules.xtransformers.x_transformer.ResidualScaledXTransformer 26 | params: 27 | scale_pos: 0 28 | n_scales: 4 29 | xt_start: 1 30 | xt_size: 256 # predict x_{t-1} 31 | wrap_decoder: False 32 | dim: 752 33 | enc_num_tokens: 548 34 | enc_depth: 18 35 | enc_heads: 16 36 | enc_max_seq_len: 257 37 | dec_num_tokens: 548 38 | dec_depth: 6 39 | dec_heads: 16 40 | tie_token_emb: False 41 | dec_max_seq_len: 256 42 | 43 | first_stage_config: 44 | target: imagebart.models.vqgan.VQGANWrapper 45 | params: 46 | ckpt_path: vqgan/vqgan-ffhq.ckpt 47 | remap: "data/vqgan_indices/ffhq_indices.npy" 48 | sane_index_shape: True 49 | embed_dim: 256 50 | n_embed: 1024 51 | ddconfig: 52 | double_z: false 53 | z_channels: 256 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 59 | num_res_blocks: 2 60 | attn_resolutions: [ 16 ] 61 | dropout: 0.0 62 | lossconfig: 63 | target: taming.modules.losses.vqperceptual.DummyLoss 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 16 69 | num_workers: 32 70 | wrap: False 71 | train: 72 | target: taming.data.faceshq.FFHQTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: taming.data.faceshq.FFHQValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 1000 86 | max_images: 4 87 | increase_log_steps: False 88 | trainer: 89 | benchmark: True 90 | -------------------------------------------------------------------------------- /configs/ffhq/4_scales/ffhq-geometric-scale4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DecoderOnlyDenoiser 4 | params: 5 | first_stage_key: "image" 6 | monitor: "val/loss" 7 | n_scales: 4 8 | single_scale: 4 9 | top_k: 548 10 | alpha: 1.0 # soft only 11 | redraw_prob: geometric 12 | use_ema: True 13 | 14 | scheduler_config: 15 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 16 | params: 17 | verbosity_interval: 0 # 0 or negative to disable 18 | warm_up_steps: 10000 19 | max_decay_steps: 1500001 20 | lr_start: 2.5e-6 21 | lr_max: 1.0e-4 22 | lr_min: 1.0e-8 23 | 24 | transformer_config: 25 | target: imagebart.modules.transformer.mingpt.GPT 26 | params: 27 | vocab_size: 548 28 | block_size: 256 # 256 data tokens - 1 + 1 scale token 29 | n_layer: 26 30 | n_head: 16 31 | n_embd: 800 32 | 33 | first_stage_config: 34 | target: imagebart.models.vqgan.VQGANWrapper 35 | params: 36 | ckpt_path: vqgan/vqgan-ffhq.ckpt 37 | remap: "data/vqgan_indices/ffhq_indices.npy" 38 | sane_index_shape: True 39 | embed_dim: 256 40 | n_embed: 1024 41 | ddconfig: 42 | double_z: False 43 | z_channels: 256 44 | resolution: 256 45 | in_channels: 3 46 | out_ch: 3 47 | ch: 128 48 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 49 | num_res_blocks: 
2 50 | attn_resolutions: [ 16 ] 51 | dropout: 0.0 52 | lossconfig: 53 | target: taming.modules.losses.vqperceptual.DummyLoss 54 | 55 | data: 56 | target: main.DataModuleFromConfig 57 | params: 58 | batch_size: 16 59 | wrap: False 60 | train: 61 | target: taming.data.faceshq.FFHQTrain 62 | params: 63 | size: 256 64 | validation: 65 | target: taming.data.faceshq.FFHQValidation 66 | params: 67 | size: 256 68 | 69 | lightning: 70 | callbacks: 71 | image_logger: 72 | target: main.ImageLogger 73 | params: 74 | batch_frequency: 1000 75 | max_images: 4 76 | increase_log_steps: False 77 | trainer: 78 | benchmark: True 79 | -------------------------------------------------------------------------------- /configs/ffhq/ffhq_4-scales_joint-training.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: "image" 6 | monitor: "val/loss" 7 | n_scales: 4 8 | top_k: 548 9 | alpha: 1.0 10 | redraw_prob: geometric 11 | use_ema: True 12 | 13 | scheduler_config: 14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 15 | params: 16 | verbosity_interval: 0 # 0 or negative to disable 17 | warm_up_steps: 10000 18 | max_decay_steps: 1500001 19 | lr_start: 2.5e-6 20 | lr_max: 1.0e-4 21 | lr_min: 1.0e-8 22 | 23 | transformer_config: 24 | target: imagebart.modules.xtransformers.x_transformer.ResidualScaledXTransformer 25 | params: 26 | scale_pos: 0 27 | n_scales: 4 28 | xt_start: 1 29 | xt_size: 256 # predict x_{t-1} 30 | wrap_decoder: False 31 | dim: 752 32 | enc_num_tokens: 548 33 | enc_depth: 18 34 | enc_heads: 16 35 | enc_max_seq_len: 257 36 | dec_num_tokens: 548 37 | dec_depth: 6 38 | dec_heads: 16 39 | tie_token_emb: False 40 | dec_max_seq_len: 256 41 | 42 | first_stage_config: 43 | target: imagebart.models.vqgan.VQGANWrapper 44 | params: 45 | ckpt_path: vqgan/vqgan-ffhq.ckpt 46 | remap: "data/vqgan_indices/ffhq_indices.npy" 47 | sane_index_shape: True 48 | embed_dim: 256 49 | n_embed: 1024 50 | ddconfig: 51 | double_z: false 52 | z_channels: 256 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 58 | num_res_blocks: 2 59 | attn_resolutions: [ 16 ] 60 | dropout: 0.0 61 | lossconfig: 62 | target: taming.modules.losses.vqperceptual.DummyLoss 63 | 64 | data: 65 | target: main.DataModuleFromConfig 66 | params: 67 | batch_size: 16 68 | num_workers: 32 69 | wrap: False 70 | train: 71 | target: taming.data.faceshq.FFHQTrain 72 | params: 73 | size: 256 74 | validation: 75 | target: taming.data.faceshq.FFHQValidation 76 | params: 77 | size: 256 78 | 79 | lightning: 80 | callbacks: 81 | image_logger: 82 | target: main.ImageLogger 83 | params: 84 | batch_frequency: 1000 85 | max_images: 4 86 | increase_log_steps: False 87 | log_images_kwargs: 88 | sample_full: True 89 | sample_half: True 90 | 91 | trainer: 92 | benchmark: True 93 | -------------------------------------------------------------------------------- /configs/imagenet/4_scales/imagenet_geometric_scale1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 4 8 | single_scale: 1 9 | top_k: 973 10 | alpha: 0.0 11 | redraw_prob: geometric 12 | use_ema: true 13 | conditioner_config: 14 | target: 
imagebart.util.ClassProvider 15 | params: 16 | key: class_label 17 | scheduler_config: 18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 19 | params: 20 | verbosity_interval: 0 21 | warm_up_steps: 10000 22 | max_decay_steps: 10000000 23 | lr_start: 2.5e-06 24 | lr_max: 0.0001 25 | lr_min: 1.0e-08 26 | transformer_config: 27 | target: imagebart.modules.xtransformers.x_transformer.ResidualScaledXTransformer 28 | params: 29 | scale_pos: 0 30 | n_scales: 4 31 | xt_start: 2 32 | xt_size: 256 33 | wrap_decoder: false 34 | dim: 1152 35 | enc_num_tokens: 1000 36 | enc_depth: 32 37 | enc_heads: 16 38 | enc_max_seq_len: 258 39 | dec_num_tokens: 973 40 | dec_depth: 6 41 | dec_heads: 16 42 | tie_token_emb: false 43 | dec_max_seq_len: 256 44 | first_stage_config: 45 | target: imagebart.models.vqgan.VQGANWrapper 46 | params: 47 | ckpt_path: vqgan/vqgan-imagenet.ckpt 48 | remap: data/vqgan_indices/imagenet_indices.npy 49 | sane_index_shape: true 50 | embed_dim: 256 51 | n_embed: 16384 52 | ddconfig: 53 | double_z: false 54 | z_channels: 256 55 | resolution: 256 56 | in_channels: 3 57 | out_ch: 3 58 | ch: 128 59 | ch_mult: 60 | - 1 61 | - 1 62 | - 2 63 | - 2 64 | - 4 65 | num_res_blocks: 2 66 | attn_resolutions: 67 | - 16 68 | dropout: 0.0 69 | tanh_out: true 70 | lossconfig: 71 | target: taming.modules.losses.vqperceptual.DummyLoss 72 | 73 | data: 74 | target: main.DataModuleFromConfig 75 | params: 76 | batch_size: 16 77 | wrap: false 78 | train: 79 | target: taming.data.imagenet.ImageNetTrain 80 | params: 81 | config: 82 | size: 256 83 | validation: 84 | target: taming.data.imagenet.ImageNetValidation 85 | params: 86 | config: 87 | size: 256 88 | 89 | lightning: 90 | callbacks: 91 | image_logger: 92 | target: main.ImageLogger 93 | params: 94 | batch_frequency: 1000 95 | max_images: 4 96 | increase_log_steps: false 97 | trainer: 98 | benchmark: true 99 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/imagenet/4_scales/imagenet_geometric_scale2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 4 8 | single_scale: 2 9 | top_k: 973 10 | alpha: 1.0 11 | redraw_prob: geometric 12 | use_ema: true 13 | conditioner_config: 14 | target: imagebart.util.ClassProvider 15 | params: 16 | key: class_label 17 | scheduler_config: 18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 19 | params: 20 | verbosity_interval: 0 21 | warm_up_steps: 10000 22 | max_decay_steps: 10000000 23 | lr_start: 2.5e-06 24 | lr_max: 0.0001 25 | lr_min: 1.0e-08 26 | transformer_config: 27 | target: imagebart.modules.xtransformers.x_transformer.ResidualScaledXTransformer 28 | params: 29 | scale_pos: 0 30 | n_scales: 4 31 | xt_start: 2 32 | xt_size: 256 33 | wrap_decoder: false 34 | dim: 1152 35 | enc_num_tokens: 1000 36 | enc_depth: 32 37 | enc_heads: 16 38 | enc_max_seq_len: 258 39 | dec_num_tokens: 973 40 | dec_depth: 6 41 | dec_heads: 16 42 | tie_token_emb: false 43 | dec_max_seq_len: 256 44 | first_stage_config: 45 | target: imagebart.models.vqgan.VQGANWrapper 46 | params: 47 | ckpt_path: vqgan/vqgan-imagenet.ckpt 48 | remap: data/vqgan_indices/imagenet_indices.npy 49 | sane_index_shape: true 50 | embed_dim: 256 51 | n_embed: 16384 52 | ddconfig: 53 | double_z: false 54 | z_channels: 256 55 | resolution: 256 56 | in_channels: 3 57 | 
out_ch: 3 58 | ch: 128 59 | ch_mult: 60 | - 1 61 | - 1 62 | - 2 63 | - 2 64 | - 4 65 | num_res_blocks: 2 66 | attn_resolutions: 67 | - 16 68 | dropout: 0.0 69 | tanh_out: true 70 | lossconfig: 71 | target: taming.modules.losses.vqperceptual.DummyLoss 72 | 73 | data: 74 | target: main.DataModuleFromConfig 75 | params: 76 | batch_size: 16 77 | wrap: false 78 | train: 79 | target: taming.data.imagenet.ImageNetTrain 80 | params: 81 | config: 82 | size: 256 83 | validation: 84 | target: taming.data.imagenet.ImageNetValidation 85 | params: 86 | config: 87 | size: 256 88 | 89 | lightning: 90 | callbacks: 91 | image_logger: 92 | target: main.ImageLogger 93 | params: 94 | batch_frequency: 1000 95 | max_images: 4 96 | increase_log_steps: false 97 | trainer: 98 | benchmark: true 99 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/imagenet/4_scales/imagenet_geometric_scale3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 4 8 | single_scale: 3 9 | top_k: 973 10 | alpha: 1.0 11 | redraw_prob: geometric 12 | use_ema: true 13 | conditioner_config: 14 | target: imagebart.util.ClassProvider 15 | params: 16 | key: class_label 17 | scheduler_config: 18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 19 | params: 20 | verbosity_interval: 0 21 | warm_up_steps: 10000 22 | max_decay_steps: 10000000 23 | lr_start: 2.5e-06 24 | lr_max: 0.0001 25 | lr_min: 1.0e-08 26 | transformer_config: 27 | target: imagebart.modules.xtransformers.x_transformer.ResidualScaledXTransformer 28 | params: 29 | scale_pos: 0 30 | n_scales: 4 31 | xt_start: 2 32 | xt_size: 256 33 | wrap_decoder: false 34 | dim: 1152 35 | enc_num_tokens: 1000 36 | enc_depth: 32 37 | enc_heads: 16 38 | enc_max_seq_len: 258 39 | dec_num_tokens: 973 40 | dec_depth: 6 41 | dec_heads: 16 42 | tie_token_emb: false 43 | dec_max_seq_len: 256 44 | first_stage_config: 45 | target: imagebart.models.vqgan.VQGANWrapper 46 | params: 47 | ckpt_path: vqgan/vqgan-imagenet.ckpt 48 | remap: data/vqgan_indices/imagenet_indices.npy 49 | sane_index_shape: true 50 | embed_dim: 256 51 | n_embed: 16384 52 | ddconfig: 53 | double_z: false 54 | z_channels: 256 55 | resolution: 256 56 | in_channels: 3 57 | out_ch: 3 58 | ch: 128 59 | ch_mult: 60 | - 1 61 | - 1 62 | - 2 63 | - 2 64 | - 4 65 | num_res_blocks: 2 66 | attn_resolutions: 67 | - 16 68 | dropout: 0.0 69 | tanh_out: true 70 | lossconfig: 71 | target: taming.modules.losses.vqperceptual.DummyLoss 72 | 73 | data: 74 | target: main.DataModuleFromConfig 75 | params: 76 | batch_size: 16 77 | wrap: false 78 | train: 79 | target: taming.data.imagenet.ImageNetTrain 80 | params: 81 | config: 82 | size: 256 83 | validation: 84 | target: taming.data.imagenet.ImageNetValidation 85 | params: 86 | config: 87 | size: 256 88 | 89 | lightning: 90 | callbacks: 91 | image_logger: 92 | target: main.ImageLogger 93 | params: 94 | batch_frequency: 1000 95 | max_images: 4 96 | increase_log_steps: false 97 | trainer: 98 | benchmark: true 99 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/imagenet/4_scales/imagenet_geometric_scale4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: 
imagebart.models.diffusion.DecoderOnlyDenoiser 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 4 8 | single_scale: 4 9 | top_k: 973 10 | alpha: 1.0 11 | redraw_prob: geometric 12 | use_ema: true 13 | conditioner_config: 14 | target: imagebart.util.ClassProvider 15 | params: 16 | key: class_label 17 | scheduler_config: 18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 19 | params: 20 | verbosity_interval: 0 21 | warm_up_steps: 10000 22 | max_decay_steps: 10000000 23 | lr_start: 2.5e-06 24 | lr_max: 0.0001 25 | lr_min: 1.0e-08 26 | transformer_config: 27 | target: imagebart.modules.transformer.mingpt.GPT 28 | params: 29 | input_vocab_size: 1000 30 | vocab_size: 973 31 | block_size: 258 32 | n_layer: 36 33 | n_head: 16 34 | n_embd: 1216 35 | n_unmasked: 2 36 | first_stage_config: 37 | target: imagebart.models.vqgan.VQGANWrapper 38 | params: 39 | ckpt_path: vqgan/vqgan-imagenet.ckpt 40 | remap: data/vqgan_indices/imagenet_indices.npy 41 | sane_index_shape: true 42 | embed_dim: 256 43 | n_embed: 16384 44 | ddconfig: 45 | double_z: false 46 | z_channels: 256 47 | resolution: 256 48 | in_channels: 3 49 | out_ch: 3 50 | ch: 128 51 | ch_mult: 52 | - 1 53 | - 1 54 | - 2 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: 59 | - 16 60 | dropout: 0.0 61 | lossconfig: 62 | target: taming.modules.losses.vqperceptual.DummyLoss 63 | 64 | data: 65 | target: main.DataModuleFromConfig 66 | params: 67 | batch_size: 16 68 | wrap: false 69 | train: 70 | target: taming.data.imagenet.ImageNetTrain 71 | params: 72 | config: 73 | size: 256 74 | validation: 75 | target: taming.data.imagenet.ImageNetValidation 76 | params: 77 | config: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 1000 86 | max_images: 4 87 | increase_log_steps: false 88 | trainer: 89 | benchmark: true 90 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/imagenet/5_scales/imagenet-custom-scale1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 5 8 | single_scale: 1 9 | top_k: 973 10 | alpha: 0.0 11 | redraw_prob: bernoulli_PSIM 12 | use_ema: true 13 | conditioner_config: 14 | target: imagebart.util.ClassProvider 15 | params: 16 | key: class_label 17 | scheduler_config: 18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 19 | params: 20 | verbosity_interval: 0 # 0 or negative to disable 21 | warm_up_steps: 10000 22 | max_decay_steps: 10000000 23 | lr_start: 2.5e-6 24 | lr_max: 1.0e-4 25 | lr_min: 1.0e-8 26 | transformer_config: 27 | target: imagebart.modules.xtransformers.x_transformer.XTransformer 28 | params: 29 | wrap_decoder: false 30 | dim: 1152 31 | enc_num_tokens: 1000 32 | enc_depth: 32 33 | enc_heads: 16 34 | enc_max_seq_len: 258 35 | dec_num_tokens: 973 36 | dec_depth: 6 37 | dec_heads: 16 38 | tie_token_emb: false 39 | dec_max_seq_len: 256 40 | first_stage_config: 41 | target: imagebart.models.vqgan.VQGANWrapper 42 | params: 43 | ckpt_path: vqgan/vqgan-imagenet.ckpt 44 | remap: data/vqgan_indices/imagenet_indices.npy 45 | sane_index_shape: true 46 | embed_dim: 256 47 | n_embed: 16384 48 | ddconfig: 49 | double_z: false 50 | z_channels: 256 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 
| ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: taming.modules.losses.vqperceptual.DummyLoss 67 | data: 68 | target: main.DataModuleFromConfig 69 | params: 70 | batch_size: 16 71 | wrap: false 72 | train: 73 | target: taming.data.imagenet.ImageNetTrain 74 | params: 75 | config: 76 | size: 256 77 | validation: 78 | target: taming.data.imagenet.ImageNetValidation 79 | params: 80 | config: 81 | size: 256 82 | 83 | lightning: 84 | callbacks: 85 | image_logger: 86 | target: main.ImageLogger 87 | params: 88 | batch_frequency: 1000 89 | max_images: 4 90 | increase_log_steps: false 91 | trainer: 92 | benchmark: true 93 | accumulate_grad_batches: 2 94 | -------------------------------------------------------------------------------- /configs/imagenet/5_scales/imagenet-custom-scale2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 5 8 | single_scale: 2 9 | top_k: 973 10 | alpha: 1.0 11 | redraw_prob: bernoulli_PSIM 12 | use_ema: true 13 | conditioner_config: 14 | target: imagebart.util.ClassProvider 15 | params: 16 | key: class_label 17 | scheduler_config: 18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 19 | params: 20 | verbosity_interval: 0 # 0 or negative to disable 21 | warm_up_steps: 10000 22 | max_decay_steps: 10000000 23 | lr_start: 2.5e-6 24 | lr_max: 1.0e-4 25 | lr_min: 1.0e-8 26 | transformer_config: 27 | target: imagebart.modules.xtransformers.x_transformer.XTransformer 28 | params: 29 | wrap_decoder: false 30 | dim: 1152 31 | enc_num_tokens: 1000 32 | enc_depth: 32 33 | enc_heads: 16 34 | enc_max_seq_len: 258 35 | dec_num_tokens: 973 36 | dec_depth: 6 37 | dec_heads: 16 38 | tie_token_emb: false 39 | dec_max_seq_len: 256 40 | first_stage_config: 41 | target: imagebart.models.vqgan.VQGANWrapper 42 | params: 43 | ckpt_path: vqgan/vqgan-imagenet.ckpt 44 | remap: data/vqgan_indices/imagenet_indices.npy 45 | sane_index_shape: true 46 | embed_dim: 256 47 | n_embed: 16384 48 | ddconfig: 49 | double_z: false 50 | z_channels: 256 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: taming.modules.losses.vqperceptual.DummyLoss 67 | data: 68 | target: main.DataModuleFromConfig 69 | params: 70 | batch_size: 16 71 | wrap: false 72 | train: 73 | target: taming.data.imagenet.ImageNetTrain 74 | params: 75 | config: 76 | size: 256 77 | validation: 78 | target: taming.data.imagenet.ImageNetValidation 79 | params: 80 | config: 81 | size: 256 82 | 83 | lightning: 84 | callbacks: 85 | image_logger: 86 | target: main.ImageLogger 87 | params: 88 | batch_frequency: 1000 89 | max_images: 4 90 | increase_log_steps: false 91 | trainer: 92 | benchmark: true 93 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/imagenet/5_scales/imagenet-custom-scale3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 5 8 | single_scale: 3 9 | 
top_k: 973 10 | alpha: 1.0 11 | redraw_prob: bernoulli_PSIM 12 | use_ema: true 13 | conditioner_config: 14 | target: imagebart.util.ClassProvider 15 | params: 16 | key: class_label 17 | scheduler_config: 18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 19 | params: 20 | verbosity_interval: 0 # 0 or negative to disable 21 | warm_up_steps: 10000 22 | max_decay_steps: 10000000 23 | lr_start: 2.5e-6 24 | lr_max: 1.0e-4 25 | lr_min: 1.0e-8 26 | transformer_config: 27 | target: imagebart.modules.xtransformers.x_transformer.XTransformer 28 | params: 29 | wrap_decoder: false 30 | dim: 1152 31 | enc_num_tokens: 1000 32 | enc_depth: 32 33 | enc_heads: 16 34 | enc_max_seq_len: 258 35 | dec_num_tokens: 973 36 | dec_depth: 6 37 | dec_heads: 16 38 | tie_token_emb: false 39 | dec_max_seq_len: 256 40 | first_stage_config: 41 | target: imagebart.models.vqgan.VQGANWrapper 42 | params: 43 | ckpt_path: vqgan/vqgan-imagenet.ckpt 44 | remap: data/vqgan_indices/imagenet_indices.npy 45 | sane_index_shape: true 46 | embed_dim: 256 47 | n_embed: 16384 48 | ddconfig: 49 | double_z: false 50 | z_channels: 256 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: taming.modules.losses.vqperceptual.DummyLoss 67 | data: 68 | target: main.DataModuleFromConfig 69 | params: 70 | batch_size: 16 71 | wrap: false 72 | train: 73 | target: taming.data.imagenet.ImageNetTrain 74 | params: 75 | config: 76 | size: 256 77 | validation: 78 | target: taming.data.imagenet.ImageNetValidation 79 | params: 80 | config: 81 | size: 256 82 | 83 | lightning: 84 | callbacks: 85 | image_logger: 86 | target: main.ImageLogger 87 | params: 88 | batch_frequency: 1000 89 | max_images: 4 90 | increase_log_steps: false 91 | trainer: 92 | benchmark: true 93 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/imagenet/5_scales/imagenet-custom-scale4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 5 8 | single_scale: 4 9 | top_k: 973 10 | alpha: 1.0 11 | redraw_prob: bernoulli_PSIM 12 | use_ema: true 13 | conditioner_config: 14 | target: imagebart.util.ClassProvider 15 | params: 16 | key: class_label 17 | scheduler_config: 18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 19 | params: 20 | verbosity_interval: 0 # 0 or negative to disable 21 | warm_up_steps: 10000 22 | max_decay_steps: 10000000 23 | lr_start: 2.5e-6 24 | lr_max: 1.0e-4 25 | lr_min: 1.0e-8 26 | transformer_config: 27 | target: imagebart.modules.xtransformers.x_transformer.XTransformer 28 | params: 29 | wrap_decoder: false 30 | dim: 1152 31 | enc_num_tokens: 1000 32 | enc_depth: 32 33 | enc_heads: 16 34 | enc_max_seq_len: 258 35 | dec_num_tokens: 973 36 | dec_depth: 6 37 | dec_heads: 16 38 | tie_token_emb: false 39 | dec_max_seq_len: 256 40 | first_stage_config: 41 | target: imagebart.models.vqgan.VQGANWrapper 42 | params: 43 | ckpt_path: vqgan/vqgan-imagenet.ckpt 44 | remap: data/vqgan_indices/imagenet_indices.npy 45 | sane_index_shape: true 46 | embed_dim: 256 47 | n_embed: 16384 48 | ddconfig: 49 | double_z: false 50 | z_channels: 256 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 
| ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: taming.modules.losses.vqperceptual.DummyLoss 67 | data: 68 | target: main.DataModuleFromConfig 69 | params: 70 | batch_size: 16 71 | wrap: false 72 | train: 73 | target: taming.data.imagenet.ImageNetTrain 74 | params: 75 | config: 76 | size: 256 77 | validation: 78 | target: taming.data.imagenet.ImageNetValidation 79 | params: 80 | config: 81 | size: 256 82 | 83 | lightning: 84 | callbacks: 85 | image_logger: 86 | target: main.ImageLogger 87 | params: 88 | batch_frequency: 1000 89 | max_images: 4 90 | increase_log_steps: false 91 | trainer: 92 | benchmark: true 93 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/imagenet/5_scales/imagenet-custom-scale5.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DecoderOnlyDenoiser 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 5 8 | single_scale: 5 9 | top_k: 973 10 | alpha: 1.0 11 | redraw_prob: bernoulli_PSIM 12 | use_ema: true 13 | conditioner_config: 14 | target: imagebart.util.ClassProvider 15 | params: 16 | key: class_label 17 | scheduler_config: 18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 19 | params: 20 | verbosity_interval: 0 # 0 or negative to disable 21 | warm_up_steps: 10000 22 | max_decay_steps: 10000000 23 | lr_start: 2.5e-6 24 | lr_max: 1.0e-4 25 | lr_min: 1.0e-8 26 | transformer_config: 27 | target: imagebart.modules.transformer.mingpt.GPT 28 | params: 29 | input_vocab_size: 1000 30 | vocab_size: 973 31 | block_size: 258 32 | n_layer: 36 33 | n_head: 16 34 | n_embd: 1216 35 | n_unmasked: 2 36 | first_stage_config: 37 | target: imagebart.models.vqgan.VQGANWrapper 38 | params: 39 | ckpt_path: vqgan/vqgan-imagenet.ckpt 40 | remap: data/vqgan_indices/imagenet_indices.npy 41 | sane_index_shape: true 42 | embed_dim: 256 43 | n_embed: 16384 44 | ddconfig: 45 | double_z: false 46 | z_channels: 256 47 | resolution: 256 48 | in_channels: 3 49 | out_ch: 3 50 | ch: 128 51 | ch_mult: 52 | - 1 53 | - 1 54 | - 2 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: 59 | - 16 60 | dropout: 0.0 61 | lossconfig: 62 | target: taming.modules.losses.vqperceptual.DummyLoss 63 | data: 64 | target: main.DataModuleFromConfig 65 | params: 66 | batch_size: 16 67 | wrap: false 68 | train: 69 | target: taming.data.imagenet.ImageNetTrain 70 | params: 71 | config: 72 | size: 256 73 | validation: 74 | target: taming.data.imagenet.ImageNetValidation 75 | params: 76 | config: 77 | size: 256 78 | 79 | lightning: 80 | callbacks: 81 | image_logger: 82 | target: main.ImageLogger 83 | params: 84 | batch_frequency: 1000 85 | max_images: 4 86 | increase_log_steps: false 87 | trainer: 88 | benchmark: true 89 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/lsun-beds/lsun-beds-scale1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 3 8 | single_scale: 1 9 | top_k: 1017 10 | alpha: 0.0 11 | redraw_prob: bernoulli_PSIM 12 | use_ema: true 13 | scheduler_config: 14 | target: 
imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 15 | params: 16 | verbosity_interval: 0 17 | warm_up_steps: 10000 18 | max_decay_steps: 5000000 19 | lr_start: 2.5e-06 20 | lr_max: 0.0001 21 | lr_min: 1.0e-08 22 | 23 | transformer_config: 24 | target: imagebart.modules.xtransformers.x_transformer.XTransformer 25 | params: 26 | wrap_decoder: false 27 | dim: 1152 28 | enc_num_tokens: 1017 29 | enc_depth: 32 30 | enc_heads: 16 31 | enc_max_seq_len: 257 32 | dec_num_tokens: 1017 33 | dec_depth: 6 34 | dec_heads: 16 35 | tie_token_emb: false 36 | dec_max_seq_len: 256 37 | first_stage_config: 38 | target: imagebart.models.vqgan.VQGANWrapper 39 | params: 40 | ckpt_path: vqgan/vqgan-bedrooms.ckpt 41 | remap: data/vqgan_indices/bedroom_indices.npy 42 | sane_index_shape: true 43 | embed_dim: 256 44 | n_embed: 16384 45 | ddconfig: 46 | double_z: false 47 | z_channels: 256 48 | resolution: 256 49 | in_channels: 3 50 | out_ch: 3 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 1 55 | - 2 56 | - 2 57 | - 4 58 | num_res_blocks: 2 59 | attn_resolutions: 60 | - 16 61 | dropout: 0.0 62 | tanh_out: true 63 | lossconfig: 64 | target: taming.modules.losses.vqperceptual.DummyLoss 65 | 66 | data: 67 | target: main.DataModuleFromConfig 68 | params: 69 | batch_size: 16 70 | wrap: false 71 | train: 72 | target: imagebart.data.lsun.LSUNBedroomsTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: imagebart.data.lsun.LSUNBedroomsValidation 77 | params: 78 | size: 256 79 | 80 | 81 | lightning: 82 | callbacks: 83 | image_logger: 84 | target: main.ImageLogger 85 | params: 86 | batch_frequency: 1000 87 | max_images: 4 88 | increase_log_steps: false 89 | trainer: 90 | benchmark: true 91 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/lsun-beds/lsun-beds-scale2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 3 8 | single_scale: 2 9 | top_k: 1017 10 | alpha: 1.0 11 | redraw_prob: bernoulli_PSIM 12 | use_ema: true 13 | scheduler_config: 14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 15 | params: 16 | verbosity_interval: 0 17 | warm_up_steps: 10000 18 | max_decay_steps: 5000000 19 | lr_start: 2.5e-06 20 | lr_max: 0.0001 21 | lr_min: 1.0e-08 22 | 23 | transformer_config: 24 | target: imagebart.modules.xtransformers.x_transformer.XTransformer 25 | params: 26 | wrap_decoder: false 27 | dim: 1152 28 | enc_num_tokens: 1017 29 | enc_depth: 32 30 | enc_heads: 16 31 | enc_max_seq_len: 257 32 | dec_num_tokens: 1017 33 | dec_depth: 6 34 | dec_heads: 16 35 | tie_token_emb: false 36 | dec_max_seq_len: 256 37 | first_stage_config: 38 | target: imagebart.models.vqgan.VQGANWrapper 39 | params: 40 | ckpt_path: vqgan/vqgan-bedrooms.ckpt 41 | remap: data/vqgan_indices/bedroom_indices.npy 42 | sane_index_shape: true 43 | embed_dim: 256 44 | n_embed: 16384 45 | ddconfig: 46 | double_z: false 47 | z_channels: 256 48 | resolution: 256 49 | in_channels: 3 50 | out_ch: 3 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 1 55 | - 2 56 | - 2 57 | - 4 58 | num_res_blocks: 2 59 | attn_resolutions: 60 | - 16 61 | dropout: 0.0 62 | tanh_out: true 63 | lossconfig: 64 | target: taming.modules.losses.vqperceptual.DummyLoss 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 16 69 | wrap: false 70 | train: 71 | target: 
imagebart.data.lsun.LSUNBedroomsTrain 72 | params: 73 | size: 256 74 | validation: 75 | target: imagebart.data.lsun.LSUNBedroomsValidation 76 | params: 77 | size: 256 78 | 79 | lightning: 80 | callbacks: 81 | image_logger: 82 | target: main.ImageLogger 83 | params: 84 | batch_frequency: 1000 85 | max_images: 4 86 | increase_log_steps: false 87 | trainer: 88 | benchmark: true 89 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/lsun-beds/lsun-beds-scale3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DecoderOnlyDenoiser 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 3 8 | single_scale: 3 9 | top_k: 1017 10 | alpha: 1.0 11 | redraw_prob: bernoulli_PSIM 12 | use_ema: true 13 | scheduler_config: 14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 15 | params: 16 | verbosity_interval: 0 17 | warm_up_steps: 10000 18 | max_decay_steps: 5000000 19 | lr_start: 2.5e-06 20 | lr_max: 0.0001 21 | lr_min: 1.0e-08 22 | transformer_config: 23 | target: imagebart.modules.transformer.mingpt.GPT 24 | params: 25 | input_vocab_size: 1017 26 | vocab_size: 1017 27 | block_size: 256 28 | n_layer: 36 29 | n_head: 16 30 | n_embd: 1216 31 | first_stage_config: 32 | target: imagebart.models.vqgan.VQGANWrapper 33 | params: 34 | ckpt_path: vqgan/vqgan-bedrooms.ckpt 35 | remap: data/vqgan_indices/bedroom_indices.npy 36 | sane_index_shape: true 37 | embed_dim: 256 38 | n_embed: 16384 39 | ddconfig: 40 | double_z: false 41 | z_channels: 256 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 1 49 | - 2 50 | - 2 51 | - 4 52 | num_res_blocks: 2 53 | attn_resolutions: 54 | - 16 55 | dropout: 0.0 56 | tanh_out: true 57 | lossconfig: 58 | target: taming.modules.losses.vqperceptual.DummyLoss 59 | 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 16 64 | wrap: false 65 | train: 66 | target: imagebart.data.lsun.LSUNBedroomsTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: imagebart.data.lsun.LSUNBedroomsValidation 71 | params: 72 | size: 256 73 | 74 | lightning: 75 | callbacks: 76 | image_logger: 77 | target: main.ImageLogger 78 | params: 79 | batch_frequency: 1000 80 | max_images: 4 81 | increase_log_steps: false 82 | trainer: 83 | benchmark: true 84 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/lsun-cats/lsun-cats-scale1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 3 8 | single_scale: 1 9 | top_k: 1014 10 | alpha: 0.0 11 | redraw_prob: bernoulli_PSIM 12 | use_ema: true 13 | scheduler_config: 14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 15 | params: 16 | verbosity_interval: 0 17 | warm_up_steps: 10000 18 | max_decay_steps: 5000000 19 | lr_start: 2.5e-06 20 | lr_max: 0.0001 21 | lr_min: 1.0e-08 22 | 23 | transformer_config: 24 | target: imagebart.modules.xtransformers.x_transformer.XTransformer 25 | params: 26 | wrap_decoder: false 27 | dim: 1152 28 | enc_num_tokens: 1014 29 | enc_depth: 32 30 | enc_heads: 16 31 | enc_max_seq_len: 257 32 | dec_num_tokens: 1014 33 | dec_depth: 6 34 | dec_heads: 16 35 | 
tie_token_emb: false 36 | dec_max_seq_len: 256 37 | first_stage_config: 38 | target: imagebart.models.vqgan.VQGANWrapper 39 | params: 40 | ckpt_path: vqgan/vqgan-cats.ckpt 41 | remap: data/vqgan_indices/cat_indices.npy 42 | sane_index_shape: true 43 | embed_dim: 256 44 | n_embed: 16384 45 | ddconfig: 46 | double_z: false 47 | z_channels: 256 48 | resolution: 256 49 | in_channels: 3 50 | out_ch: 3 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 1 55 | - 2 56 | - 2 57 | - 4 58 | num_res_blocks: 2 59 | attn_resolutions: 60 | - 16 61 | dropout: 0.0 62 | tanh_out: true 63 | lossconfig: 64 | target: taming.modules.losses.vqperceptual.DummyLoss 65 | 66 | data: 67 | target: main.DataModuleFromConfig 68 | params: 69 | batch_size: 16 70 | wrap: false 71 | train: 72 | target: imagebart.data.lsun.LSUNCatsTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: imagebart.data.lsun.LSUNCatsValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 1000 86 | max_images: 4 87 | increase_log_steps: false 88 | trainer: 89 | benchmark: true 90 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/lsun-cats/lsun-cats-scale2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 3 8 | single_scale: 2 9 | top_k: 1014 10 | alpha: 1.0 11 | redraw_prob: bernoulli_PSIM 12 | 13 | use_ema: true 14 | scheduler_config: 15 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 16 | params: 17 | verbosity_interval: 0 18 | warm_up_steps: 10000 19 | max_decay_steps: 5000000 20 | lr_start: 2.5e-06 21 | lr_max: 0.0001 22 | lr_min: 1.0e-08 23 | transformer_config: 24 | target: imagebart.modules.xtransformers.x_transformer.XTransformer 25 | params: 26 | wrap_decoder: false 27 | dim: 1152 28 | enc_num_tokens: 1014 29 | enc_depth: 32 30 | enc_heads: 16 31 | enc_max_seq_len: 257 32 | dec_num_tokens: 1014 33 | dec_depth: 6 34 | dec_heads: 16 35 | tie_token_emb: false 36 | dec_max_seq_len: 256 37 | first_stage_config: 38 | target: imagebart.models.vqgan.VQGANWrapper 39 | params: 40 | ckpt_path: vqgan/vqgan-cats.ckpt 41 | remap: data/vqgan_indices/cat_indices.npy 42 | sane_index_shape: true 43 | embed_dim: 256 44 | n_embed: 16384 45 | ddconfig: 46 | double_z: false 47 | z_channels: 256 48 | resolution: 256 49 | in_channels: 3 50 | out_ch: 3 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 1 55 | - 2 56 | - 2 57 | - 4 58 | num_res_blocks: 2 59 | attn_resolutions: 60 | - 16 61 | dropout: 0.0 62 | tanh_out: true 63 | lossconfig: 64 | target: taming.modules.losses.vqperceptual.DummyLoss 65 | 66 | data: 67 | target: main.DataModuleFromConfig 68 | params: 69 | batch_size: 16 70 | wrap: false 71 | train: 72 | target: imagebart.data.lsun.LSUNCatsTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: imagebart.data.lsun.LSUNCatsValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 1000 86 | max_images: 4 87 | increase_log_steps: false 88 | trainer: 89 | benchmark: true 90 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/lsun-cats/lsun-cats-scale3.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DecoderOnlyDenoiser 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 3 8 | single_scale: 3 9 | top_k: 1014 10 | alpha: 1.0 11 | redraw_prob: bernoulli_PSIM 12 | use_ema: true 13 | scheduler_config: 14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 15 | params: 16 | verbosity_interval: 0 17 | warm_up_steps: 10000 18 | max_decay_steps: 5000000 19 | lr_start: 2.5e-06 20 | lr_max: 0.0001 21 | lr_min: 1.0e-08 22 | 23 | transformer_config: 24 | target: imagebart.modules.transformer.mingpt.GPT 25 | params: 26 | input_vocab_size: 1014 27 | vocab_size: 1014 28 | block_size: 256 29 | n_layer: 36 30 | n_head: 16 31 | n_embd: 1216 32 | 33 | first_stage_config: 34 | target: imagebart.models.vqgan.VQGANWrapper 35 | params: 36 | ckpt_path: vqgan/vqgan-cats.ckpt 37 | remap: data/vqgan_indices/cat_indices.npy 38 | sane_index_shape: true 39 | embed_dim: 256 40 | n_embed: 16384 41 | ddconfig: 42 | double_z: false 43 | z_channels: 256 44 | resolution: 256 45 | in_channels: 3 46 | out_ch: 3 47 | ch: 128 48 | ch_mult: 49 | - 1 50 | - 1 51 | - 2 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: 56 | - 16 57 | dropout: 0.0 58 | tanh_out: true 59 | lossconfig: 60 | target: taming.modules.losses.vqperceptual.DummyLoss 61 | 62 | data: 63 | target: main.DataModuleFromConfig 64 | params: 65 | batch_size: 16 66 | wrap: false 67 | train: 68 | target: imagebart.data.lsun.LSUNCatsTrain 69 | params: 70 | size: 256 71 | validation: 72 | target: imagebart.data.lsun.LSUNCatsValidation 73 | params: 74 | size: 256 75 | 76 | 77 | lightning: 78 | callbacks: 79 | image_logger: 80 | target: main.ImageLogger 81 | params: 82 | batch_frequency: 1000 83 | max_images: 4 84 | increase_log_steps: false 85 | trainer: 86 | benchmark: true 87 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/lsun-churches/lsun-churches-scale1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 3 8 | single_scale: 1 9 | top_k: 1022 10 | alpha: 0.0 11 | redraw_prob: bernoulli_PSIM 12 | use_ema: true 13 | scheduler_config: 14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 15 | params: 16 | verbosity_interval: 0 17 | warm_up_steps: 10000 18 | max_decay_steps: 1300001 19 | lr_start: 2.5e-06 20 | lr_max: 0.0001 21 | lr_min: 1.0e-08 22 | transformer_config: 23 | target: imagebart.modules.xtransformers.x_transformer.XTransformer 24 | params: 25 | wrap_decoder: false 26 | dim: 1152 27 | enc_num_tokens: 1022 28 | enc_depth: 32 29 | enc_heads: 16 30 | enc_max_seq_len: 258 31 | dec_num_tokens: 1022 32 | dec_depth: 6 33 | dec_heads: 16 34 | tie_token_emb: false 35 | dec_max_seq_len: 256 36 | first_stage_config: 37 | target: imagebart.models.vqgan.VQGANWrapper 38 | params: 39 | ckpt_path: vqgan/vqgan-churches.ckpt 40 | remap: data/vqgan_indices/church_indices.npy 41 | sane_index_shape: true 42 | embed_dim: 256 43 | n_embed: 16384 44 | ddconfig: 45 | double_z: false 46 | z_channels: 256 47 | resolution: 256 48 | in_channels: 3 49 | out_ch: 3 50 | ch: 128 51 | ch_mult: 52 | - 1 53 | - 1 54 | - 2 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: 59 | - 16 60 
| dropout: 0.0 61 | tanh_out: true 62 | lossconfig: 63 | target: taming.modules.losses.vqperceptual.DummyLoss 64 | data: 65 | target: main.DataModuleFromConfig 66 | params: 67 | batch_size: 16 68 | wrap: false 69 | train: 70 | target: imagebart.data.lsun.LSUNChurchesTrain 71 | params: 72 | size: 256 73 | validation: 74 | target: imagebart.data.lsun.LSUNChurchesValidation 75 | params: 76 | size: 256 77 | 78 | lightning: 79 | callbacks: 80 | image_logger: 81 | target: main.ImageLogger 82 | params: 83 | batch_frequency: 1000 84 | max_images: 4 85 | increase_log_steps: false 86 | trainer: 87 | benchmark: true 88 | accumulate_grad_batches: 2 89 | -------------------------------------------------------------------------------- /configs/lsun-churches/lsun-churches-scale2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DenoisingXTransformer 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 3 8 | single_scale: 2 9 | top_k: 1022 10 | alpha: 1.0 11 | redraw_prob: bernoulli_PSIM 12 | use_ema: true 13 | scheduler_config: 14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 15 | params: 16 | verbosity_interval: 0 17 | warm_up_steps: 10000 18 | max_decay_steps: 1300001 19 | lr_start: 2.5e-06 20 | lr_max: 0.0001 21 | lr_min: 1.0e-08 22 | transformer_config: 23 | target: imagebart.modules.xtransformers.x_transformer.XTransformer 24 | params: 25 | wrap_decoder: false 26 | dim: 1152 27 | enc_num_tokens: 1022 28 | enc_depth: 32 29 | enc_heads: 16 30 | enc_max_seq_len: 258 31 | dec_num_tokens: 1022 32 | dec_depth: 6 33 | dec_heads: 16 34 | tie_token_emb: false 35 | dec_max_seq_len: 256 36 | first_stage_config: 37 | target: imagebart.models.vqgan.VQGANWrapper 38 | params: 39 | ckpt_path: vqgan/vqgan-churches.ckpt 40 | remap: data/vqgan_indices/church_indices.npy 41 | sane_index_shape: true 42 | embed_dim: 256 43 | n_embed: 16384 44 | ddconfig: 45 | double_z: false 46 | z_channels: 256 47 | resolution: 256 48 | in_channels: 3 49 | out_ch: 3 50 | ch: 128 51 | ch_mult: 52 | - 1 53 | - 1 54 | - 2 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: 59 | - 16 60 | dropout: 0.0 61 | tanh_out: true 62 | lossconfig: 63 | target: taming.modules.losses.vqperceptual.DummyLoss 64 | data: 65 | target: main.DataModuleFromConfig 66 | params: 67 | batch_size: 16 68 | wrap: false 69 | train: 70 | target: imagebart.data.lsun.LSUNChurchesTrain 71 | params: 72 | size: 256 73 | validation: 74 | target: imagebart.data.lsun.LSUNChurchesValidation 75 | params: 76 | size: 256 77 | 78 | lightning: 79 | callbacks: 80 | image_logger: 81 | target: main.ImageLogger 82 | params: 83 | batch_frequency: 1000 84 | max_images: 4 85 | increase_log_steps: false 86 | trainer: 87 | benchmark: true 88 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/lsun-churches/lsun-churches-scale3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: imagebart.models.diffusion.DecoderOnlyDenoiser 4 | params: 5 | first_stage_key: image 6 | monitor: val/loss 7 | n_scales: 3 8 | single_scale: 3 9 | top_k: 1022 10 | alpha: 1.0 11 | redraw_prob: bernoulli_PSIM 12 | use_ema: true 13 | scheduler_config: 14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler 15 | params: 16 | verbosity_interval: 0 17 | warm_up_steps: 10000 18 | max_decay_steps: 
1300001 19 | lr_start: 2.5e-06 20 | lr_max: 0.0001 21 | lr_min: 1.0e-08 22 | transformer_config: 23 | target: imagebart.modules.transformer.mingpt.GPT 24 | params: 25 | input_vocab_size: 1022 26 | vocab_size: 1022 27 | block_size: 256 28 | n_layer: 36 29 | n_head: 16 30 | n_embd: 1216 31 | first_stage_config: 32 | target: imagebart.models.vqgan.VQGANWrapper 33 | params: 34 | ckpt_path: vqgan/vqgan-churches.ckpt 35 | remap: data/vqgan_indices/church_indices.npy 36 | sane_index_shape: true 37 | embed_dim: 256 38 | n_embed: 16384 39 | ddconfig: 40 | double_z: false 41 | z_channels: 256 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 1 49 | - 2 50 | - 2 51 | - 4 52 | num_res_blocks: 2 53 | attn_resolutions: 54 | - 16 55 | dropout: 0.0 56 | tanh_out: true 57 | lossconfig: 58 | target: taming.modules.losses.vqperceptual.DummyLoss 59 | 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 16 64 | wrap: false 65 | train: 66 | target: imagebart.data.lsun.LSUNChurchesTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: imagebart.data.lsun.LSUNChurchesValidation 71 | params: 72 | size: 256 73 | 74 | lightning: 75 | callbacks: 76 | image_logger: 77 | target: main.ImageLogger 78 | params: 79 | batch_frequency: 1000 80 | max_images: 4 81 | increase_log_steps: false 82 | trainer: 83 | benchmark: true 84 | accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /configs/sampling/ffhq/ffhq_2_scales_custom.yaml: -------------------------------------------------------------------------------- 1 | checkpoints: 2 | - "models/ffhq_2_scales_custom/scale1.ckpt" 3 | - "models/ffhq_2_scales_custom/scale2.ckpt" 4 | 5 | configs: 6 | - "configs/ffhq/2_scales/ffhq-custom-scale1.yaml" 7 | - "configs/ffhq/2_scales/ffhq-custom-scale2.yaml" -------------------------------------------------------------------------------- /configs/sampling/ffhq/ffhq_4_scales_geometric.yaml: -------------------------------------------------------------------------------- 1 | checkpoints: 2 | - "models/ffhq_4_scales_geometric/scale1.ckpt" 3 | - "models/ffhq_4_scales_geometric/scale2.ckpt" 4 | - "models/ffhq_4_scales_geometric/scale3.ckpt" 5 | - "models/ffhq_4_scales_geometric/scale4.ckpt" 6 | 7 | configs: 8 | - "configs/ffhq/4_scales/ffhq-geometric-scale1.yaml" 9 | - "configs/ffhq/4_scales/ffhq-geometric-scale2.yaml" 10 | - "configs/ffhq/4_scales/ffhq-geometric-scale3.yaml" 11 | - "configs/ffhq/4_scales/ffhq-geometric-scale4.yaml" 12 | -------------------------------------------------------------------------------- /configs/sampling/imagenet/imagenet_4_scales_geometric.yaml: -------------------------------------------------------------------------------- 1 | checkpoints: 2 | - "models/cin_4_scales_geometric/scale1.ckpt" 3 | - "models/cin_4_scales_geometric/scale2.ckpt" 4 | - "models/cin_4_scales_geometric/scale3.ckpt" 5 | - "models/cin_4_scales_geometric/scale4.ckpt" 6 | 7 | configs: 8 | - "configs/imagenet/4_scales/imagenet_geometric_scale1.yaml" 9 | - "configs/imagenet/4_scales/imagenet_geometric_scale2.yaml" 10 | - "configs/imagenet/4_scales/imagenet_geometric_scale3.yaml" 11 | - "configs/imagenet/4_scales/imagenet_geometric_scale4.yaml" 12 | -------------------------------------------------------------------------------- /configs/sampling/imagenet/imagenet_5_scales_custom.yaml: -------------------------------------------------------------------------------- 1 | checkpoints: 2 | - 
"models/cin_5_scales_custom/scale1.ckpt" 3 | - "models/cin_5_scales_custom/scale2.ckpt" 4 | - "models/cin_5_scales_custom/scale3.ckpt" 5 | - "models/cin_5_scales_custom/scale4.ckpt" 6 | - "models/cin_5_scales_custom/scale5.ckpt" 7 | 8 | configs: 9 | - "configs/imagenet/5_scales/imagenet-custom-scale1.yaml" 10 | - "configs/imagenet/5_scales/imagenet-custom-scale2.yaml" 11 | - "configs/imagenet/5_scales/imagenet-custom-scale3.yaml" 12 | - "configs/imagenet/5_scales/imagenet-custom-scale4.yaml" 13 | - "configs/imagenet/5_scales/imagenet-custom-scale5.yaml" 14 | > 15 | -------------------------------------------------------------------------------- /configs/sampling/lsun/beds_3_scales.yaml: -------------------------------------------------------------------------------- 1 | checkpoints: 2 | - "models/bedrooms_3_scales/scale1.ckpt" 3 | - "models/bedrooms_3_scales/scale2.ckpt" 4 | - "models/bedrooms_3_scales/scale3.ckpt" 5 | 6 | configs: 7 | - "configs/lsun-beds/lsun-beds-scale1.yaml" 8 | - "configs/lsun-beds/lsun-beds-scale2.yaml" 9 | - "configs/lsun-beds/lsun-beds-scale3.yaml" 10 | -------------------------------------------------------------------------------- /configs/sampling/lsun/cats_3_scales.yaml: -------------------------------------------------------------------------------- 1 | checkpoints: 2 | - "models/cats_3_scales/scale1.ckpt" 3 | - "models/cats_3_scales/scale2.ckpt" 4 | - "models/cats_3_scales/scale3.ckpt" 5 | 6 | configs: 7 | - "configs/lsun-cats/lsun-cats-scale1.yaml" 8 | - "configs/lsun-cats/lsun-cats-scale2.yaml" 9 | - "configs/lsun-cats/lsun-cats-scale3.yaml" 10 | -------------------------------------------------------------------------------- /configs/sampling/lsun/churches_3_scales.yaml: -------------------------------------------------------------------------------- 1 | checkpoints: 2 | - "models/churches_3_scales/scale1.ckpt" 3 | - "models/churches_3_scales/scale2.ckpt" 4 | - "models/churches_3_scales/scale3.ckpt" 5 | 6 | configs: 7 | - "configs/lsun-churches/lsun-churches-scale1.yaml" 8 | - "configs/lsun-churches/lsun-churches-scale2.yaml" 9 | - "configs/lsun-churches/lsun-churches-scale3.yaml" 10 | -------------------------------------------------------------------------------- /data/DejaVuSans.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/DejaVuSans.ttf -------------------------------------------------------------------------------- /data/ffhq_schedule_vs_metric.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/ffhq_schedule_vs_metric.p -------------------------------------------------------------------------------- /data/in_schedule_vs_metric.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/in_schedule_vs_metric.p -------------------------------------------------------------------------------- /data/vqgan_indices/bedroom_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/vqgan_indices/bedroom_indices.npy -------------------------------------------------------------------------------- /data/vqgan_indices/cat_indices.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/vqgan_indices/cat_indices.npy -------------------------------------------------------------------------------- /data/vqgan_indices/church_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/vqgan_indices/church_indices.npy -------------------------------------------------------------------------------- /data/vqgan_indices/ffhq_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/vqgan_indices/ffhq_indices.npy -------------------------------------------------------------------------------- /data/vqgan_indices/imagenet_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/vqgan_indices/imagenet_indices.npy -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: imagebart 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.0 9 | - pytorch=1.7.0 10 | - torchvision=0.8.1 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - opencv-python==4.1.2.30 15 | - pudb==2019.2 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.4.2 19 | - omegaconf==2.1.1 20 | - test-tube>=0.7.5 21 | - streamlit>=0.73.1 22 | - einops==0.3.0 23 | - transformers==4.3.1 24 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 25 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 26 | - -e . 
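      # editable install of this repo itself (see setup.py), so the `imagebart` package is importable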
27 | -------------------------------------------------------------------------------- /imagebart/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/imagebart/__init__.py -------------------------------------------------------------------------------- /imagebart/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/imagebart/data/__init__.py -------------------------------------------------------------------------------- /imagebart/data/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 5 | import albumentations 6 | import bisect 7 | from abc import abstractmethod 8 | 9 | class Txt2ImgIterableBaseDataset(IterableDataset): 10 | ''' 11 | Define an interface to make the IterableDatasets for text2img data chainable 12 | ''' 13 | def __init__(self, num_records=0, valid_ids=None, size=256): 14 | super().__init__() 15 | self.num_records = num_records 16 | self.valid_ids = valid_ids 17 | self.sample_ids = valid_ids 18 | self.size = size 19 | 20 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 21 | 22 | def __len__(self): 23 | return self.num_records 24 | 25 | @abstractmethod 26 | def __iter__(self): 27 | pass -------------------------------------------------------------------------------- /imagebart/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 
127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /imagebart/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning-bolts.readthedocs.io/en/0.2.1/api/pl_bolts.optimizers.lr_scheduler.html 2 | 3 | import math 4 | import warnings 5 | from typing import List 6 | 7 | import torch.nn as nn 8 | from torch.optim import Optimizer, Adam 9 | from torch.optim.lr_scheduler import _LRScheduler 10 | from torch._six import inf 11 | import numpy as np 12 | 13 | 14 | class LinearWarmupCosineAnnealingLR(_LRScheduler): 15 | """ 16 | Sets the learning rate of each parameter group to follow a linear warmup schedule 17 | between warmup_start_lr and base_lr followed by a cosine annealing schedule between 18 | base_lr and eta_min. 19 | .. warning:: 20 | It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR` 21 | after each iteration as calling it after each epoch will keep the starting lr at 22 | warmup_start_lr for the first epoch which is 0 in most cases. 23 | .. warning:: 24 | passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING. 25 | It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of 26 | :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing 27 | epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling 28 | train and validation methods. 29 | Args: 30 | optimizer (Optimizer): Wrapped optimizer. 31 | warmup_epochs (int): Maximum number of iterations for linear warmup 32 | max_epochs (int): Maximum number of iterations 33 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. 34 | eta_min (float): Minimum learning rate. Default: 0. 35 | last_epoch (int): The index of last epoch. Default: -1. 36 | Example: 37 | >>> layer = nn.Linear(10, 1) 38 | >>> optimizer = Adam(layer.parameters(), lr=0.02) 39 | >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40) 40 | >>> # 41 | >>> # the default case 42 | >>> for epoch in range(40): 43 | ... # train(...) 44 | ... # validate(...) 45 | ... 
scheduler.step() 46 | >>> # 47 | >>> # passing epoch param case 48 | >>> for epoch in range(40): 49 | ... scheduler.step(epoch) 50 | ... # train(...) 51 | ... # validate(...) 52 | """ 53 | 54 | def __init__( 55 | self, 56 | optimizer: Optimizer, 57 | warmup_epochs: int, 58 | max_epochs: int, 59 | warmup_start_lr: float = 0.0, 60 | eta_min: float = 0.0, 61 | last_epoch: int = -1, 62 | ) -> None: 63 | 64 | self.warmup_epochs = warmup_epochs 65 | self.max_epochs = max_epochs 66 | self.warmup_start_lr = warmup_start_lr 67 | self.eta_min = eta_min 68 | 69 | super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) 70 | 71 | def get_lr(self) -> List[float]: 72 | """ 73 | Compute learning rate using chainable form of the scheduler 74 | """ 75 | if not self._get_lr_called_within_step: 76 | warnings.warn( 77 | "To get the last learning rate computed by the scheduler, " 78 | "please use `get_last_lr()`.", 79 | UserWarning, 80 | ) 81 | 82 | if self.last_epoch == 0: 83 | return [self.warmup_start_lr] * len(self.base_lrs) 84 | elif self.last_epoch < self.warmup_epochs: 85 | return [ 86 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 87 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 88 | ] 89 | elif self.last_epoch == self.warmup_epochs: 90 | return self.base_lrs 91 | elif (self.last_epoch - 1 - self.max_epochs) % ( 92 | 2 * (self.max_epochs - self.warmup_epochs) 93 | ) == 0: 94 | return [ 95 | group["lr"] + (base_lr - self.eta_min) * ( 96 | 1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs)) 97 | ) / 2 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 98 | ] 99 | 100 | return [(1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / 101 | (self.max_epochs - self.warmup_epochs))) / 102 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs - 1) / 103 | (self.max_epochs - self.warmup_epochs))) * 104 | (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups 105 | ] 106 | 107 | def _get_closed_form_lr(self) -> List[float]: 108 | """ 109 | Called when epoch is passed as a param to the `step` function of the scheduler. 110 | """ 111 | if self.last_epoch < self.warmup_epochs: 112 | return [ 113 | self.warmup_start_lr + self.last_epoch * ( 114 | base_lr - self.warmup_start_lr 115 | ) / (self.warmup_epochs - 1) for base_lr in self.base_lrs 116 | ] 117 | 118 | return [ 119 | self.eta_min + 0.5 * (base_lr - self.eta_min) * ( 120 | 1 + math.cos( 121 | math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs) 122 | ) 123 | ) for base_lr in self.base_lrs 124 | ] 125 | 126 | 127 | class ReduceLROnLossPlateau: 128 | 129 | def __init__(self, lr_init, mode='min', factor=0.1, patience=10, 130 | threshold=1e-4, threshold_mode='rel', cooldown=0, 131 | min_lr=0, eps=1e-8, verbose=False): 132 | if factor >= 1.0: 133 | raise ValueError('Factor should be < 1.0.') 134 | self.factor = factor 135 | # Attach optimize 136 | 137 | self.current_factor = 1. 
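        # running multiplicative factor; unlike torch's ReduceLROnPlateau, no optimizer is attached here --
        # schedule()/__call__ simply return the factor and the caller rescales the learning rate itself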
138 | self.min_lr = min_lr 139 | self.current_lr = lr_init 140 | 141 | self.patience = patience 142 | self.verbose = verbose 143 | self.cooldown = cooldown 144 | self.cooldown_counter = 0 145 | self.mode = mode 146 | self.threshold = threshold 147 | self.threshold_mode = threshold_mode 148 | self.best = None 149 | self.num_bad_epochs = None 150 | self.mode_worse = None # the worse value for the chosen mode 151 | self.eps = eps 152 | self.last_it = 0 153 | self._init_is_better(mode=mode, threshold=threshold, 154 | threshold_mode=threshold_mode) 155 | 156 | self._reset() 157 | 158 | def set_factor(self, f): 159 | self.current_factor = f 160 | 161 | def _reset(self): 162 | """Resets num_bad_epochs counter and cooldown counter.""" 163 | self.best = self.mode_worse 164 | self.cooldown_counter = 0 165 | self.num_bad_it = 0 166 | 167 | def schedule(self, metrics): 168 | # convert `metrics` to float, in case it's a zero-dim Tensor 169 | current = float(metrics) 170 | 171 | # it = self.last_it + 1 172 | 173 | # self.it = it 174 | 175 | if self.is_better(current, self.best): 176 | self.best = current 177 | self.num_bad_it = 0 178 | else: 179 | self.num_bad_it += 1 180 | 181 | if self.in_cooldown: 182 | self.cooldown_counter -= 1 183 | self.num_bad_it = 0 # ignore any bad epochs in cooldown 184 | 185 | if self.num_bad_it > self.patience and self.current_lr >= self.min_lr: 186 | # reduce lr by factor 187 | new_f = self.current_factor * self.factor 188 | self.set_factor(new_f) 189 | self.current_lr = self.current_lr * self.current_factor 190 | 191 | self.set_factor(new_f) 192 | self.cooldown_counter = self.cooldown 193 | self.num_bad_it = 0 194 | 195 | # self._last_lr = [group['lr'] for group in self.optimizer.param_groups] 196 | return self.current_factor 197 | 198 | @property 199 | def in_cooldown(self): 200 | return self.cooldown_counter > 0 201 | 202 | def is_better(self, a, best): 203 | if self.mode == 'min' and self.threshold_mode == 'rel': 204 | rel_epsilon = 1. - self.threshold 205 | return a < best * rel_epsilon 206 | 207 | elif self.mode == 'min' and self.threshold_mode == 'abs': 208 | return a < best - self.threshold 209 | 210 | elif self.mode == 'max' and self.threshold_mode == 'rel': 211 | rel_epsilon = self.threshold + 1. 212 | return a > best * rel_epsilon 213 | 214 | else: # mode == 'max' and epsilon_mode == 'abs': 215 | return a > best + self.threshold 216 | 217 | def _init_is_better(self, mode, threshold, threshold_mode): 218 | if mode not in {'min', 'max'}: 219 | raise ValueError('mode ' + mode + ' is unknown!') 220 | if threshold_mode not in {'rel', 'abs'}: 221 | raise ValueError('threshold mode ' + threshold_mode + ' is unknown!') 222 | 223 | if mode == 'min': 224 | self.mode_worse = inf 225 | else: # mode == 'max': 226 | self.mode_worse = -inf 227 | 228 | self.mode = mode 229 | self.threshold = threshold 230 | self.threshold_mode = threshold_mode 231 | 232 | def __call__(self, metrics): 233 | return self.schedule(metrics) 234 | 235 | 236 | class LambdaWarmUpCosineScheduler: 237 | """ 238 | note: use with a base_lr of 1.0 239 | """ 240 | 241 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 242 | self.lr_warm_up_steps = warm_up_steps 243 | self.lr_start = lr_start 244 | self.lr_min = lr_min 245 | self.lr_max = lr_max 246 | self.lr_max_decay_steps = max_decay_steps 247 | self.last_lr = 0. 
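        # schedule(n): linear warm-up from lr_start to lr_max over warm_up_steps, then half-cosine decay
        # towards lr_min, which is reached at max_decay_steps and held afterwards (t is clamped to 1.0)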
248 | self.verbosity_interval = verbosity_interval 249 | 250 | def schedule(self, n, **kwargs): 251 | if self.verbosity_interval > 0: 252 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 253 | if n < self.lr_warm_up_steps: 254 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 255 | self.last_lr = lr 256 | return lr 257 | else: 258 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 259 | t = min(t, 1.0) 260 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 261 | 1 + np.cos(t * np.pi)) 262 | self.last_lr = lr 263 | return lr 264 | 265 | def __call__(self, n, **kwargs): 266 | return self.schedule(n, **kwargs) 267 | 268 | 269 | class LambdaWarmUpCosineScheduler2: 270 | """ 271 | supports repeated iterations, configurable via lists 272 | note: use with a base_lr of 1.0. 273 | """ 274 | 275 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 276 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 277 | self.lr_warm_up_steps = warm_up_steps 278 | self.f_start = f_start 279 | self.f_min = f_min 280 | self.f_max = f_max 281 | self.cycle_lengths = cycle_lengths 282 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 283 | self.last_f = 0. 284 | self.verbosity_interval = verbosity_interval 285 | 286 | def find_in_interval(self, n): 287 | interval = 0 288 | for cl in self.cum_cycles[1:]: 289 | if n <= cl: 290 | return interval 291 | interval += 1 292 | 293 | def schedule(self, n, **kwargs): 294 | cycle = self.find_in_interval(n) 295 | n = n - self.cum_cycles[cycle] 296 | if self.verbosity_interval > 0: 297 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 298 | f"current cycle {cycle}") 299 | if n < self.lr_warm_up_steps[cycle]: 300 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 301 | self.last_f = f 302 | return f 303 | else: 304 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 305 | t = min(t, 1.0) 306 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 307 | 1 + np.cos(t * np.pi)) 308 | self.last_f = f 309 | return f 310 | 311 | def __call__(self, n, **kwargs): 312 | return self.schedule(n, **kwargs) 313 | 314 | 315 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 316 | 317 | def schedule(self, n, **kwargs): 318 | cycle = self.find_in_interval(n) 319 | n = n - self.cum_cycles[cycle] 320 | if self.verbosity_interval > 0: 321 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 322 | f"current cycle {cycle}") 323 | 324 | if n < self.lr_warm_up_steps[cycle]: 325 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 326 | self.last_f = f 327 | return f 328 | else: 329 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / ( 330 | self.cycle_lengths[cycle]) 331 | self.last_f = f 332 | return f 333 | 334 | 335 | if __name__ == "__main__": 336 | from tqdm import trange 337 | import matplotlib.pyplot as plt 338 | 339 | warm_up_steps = [1000, 500, 500] 340 | f_min = [1., 0., 0.] 341 | f_max = [10, 7.5, 5.] 342 | f_start = [0., 1., 0.] 
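    # three back-to-back cycles: cycle i warms up from f_start[i] to f_max[i] over warm_up_steps[i] steps,
    # then decays with a cosine towards f_min[i] by the end of cycle_lengths[i]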
343 | cycle_lengths = [10000, 5000, 4000] 344 | scheduler = LambdaWarmUpCosineScheduler2(warm_up_steps=warm_up_steps, f_min=f_min, f_max=f_max, f_start=f_start, 345 | cycle_lengths=cycle_lengths, verbosity_interval=100) 346 | 347 | schedule = [] 348 | for n in trange(int(sum(cycle_lengths)), desc="Iter"): 349 | schedule.append(scheduler(n)) 350 | 351 | plt.figure() 352 | plt.plot(schedule) 353 | plt.xlabel("global step") 354 | plt.savefig("scheduler_test.png") 355 | print("done.") 356 | 357 | """ 358 | layer = nn.Linear(10, 1) 359 | optimizer = Adam(layer.parameters(), lr=0.02) 360 | scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40) 361 | # 362 | # the default case 363 | for epoch in range(40): 364 | # train(...) 365 | # validate(...) 366 | scheduler.step() 367 | # 368 | # passing epoch param case 369 | for epoch in range(40): 370 | scheduler.step(epoch) 371 | # train(...) 372 | # validate(...) 373 | """ 374 | -------------------------------------------------------------------------------- /imagebart/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/imagebart/models/__init__.py -------------------------------------------------------------------------------- /imagebart/models/vqgan.py: -------------------------------------------------------------------------------- 1 | from taming.models.vqgan import VQModel 2 | 3 | 4 | class VQGANWrapper(VQModel): 5 | 6 | def __init__(self,embed_dim,*args,**kwargs): 7 | super().__init__(embed_dim=embed_dim,*args,**kwargs) 8 | self.embed_dim = embed_dim 9 | 10 | def encode_to_prequant(self, x): 11 | h = self.encoder(x) 12 | h = self.quant_conv(h) 13 | return h -------------------------------------------------------------------------------- /imagebart/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/imagebart/modules/__init__.py -------------------------------------------------------------------------------- /imagebart/modules/betas.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | 4 | 5 | def find_roots(x,y): 6 | # https://stackoverflow.com/questions/46909373/how-to-find-the-exact-intersection-of-a-curve-as-np-array-with-y-0/46911822#46911822 7 | s = np.abs(np.diff(np.sign(y))).astype(bool) 8 | return x[:-1][s] + np.diff(x)[s]/(np.abs(y[1:][s]/y[:-1][s])+1) 9 | 10 | 11 | def load_betas(path, dist, metric, n_steps): 12 | with open(path, "rb") as f: 13 | data = pickle.load(f) 14 | alphacum = np.array(data[dist]["alphacum"]) 15 | values = np.array(data[dist][metric])[:,0] 16 | 17 | equi_alphacum = list() 18 | for val in np.linspace(values.min(), values.max(), n_steps): 19 | equi_alphacum.append(find_roots(alphacum, values-val)[0]) 20 | equi_alphacum = np.array(equi_alphacum) 21 | beta_1toT = 1-equi_alphacum[1:]/equi_alphacum[:-1] 22 | beta_0toT = np.concatenate((np.array([0.0]), beta_1toT)) 23 | return beta_0toT.astype(np.float32) 24 | 25 | 26 | if __name__ == "__main__": 27 | import sys 28 | betas = load_betas(sys.argv[1], "bernoulli", "FID", int(sys.argv[2])) 29 | print(betas) 30 | print(np.cumprod(1-betas, 0)) 31 | -------------------------------------------------------------------------------- /imagebart/modules/ema.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class LitEma(nn.Module): 7 | def __init__(self, model, decay=0.9999, use_num_upates=True): 8 | super().__init__() 9 | if decay < 0.0 or decay > 1.0: 10 | raise ValueError('Decay must be between 0 and 1') 11 | 12 | self.m_name2s_name = {} 13 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 14 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 15 | else torch.tensor(-1,dtype=torch.int)) 16 | 17 | for name, p in model.named_parameters(): 18 | if p.requires_grad: 19 | #remove as '.'-character is not allowed in buffers 20 | s_name = name.replace('.','') 21 | self.m_name2s_name.update({name:s_name}) 22 | self.register_buffer(s_name,p.clone().detach().data) 23 | 24 | self.collected_params = [] 25 | 26 | def forward(self,model): 27 | decay = self.decay 28 | 29 | if self.num_updates >= 0: 30 | self.num_updates += 1 31 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 32 | 33 | one_minus_decay = 1.0 - decay 34 | 35 | with torch.no_grad(): 36 | m_param = dict(model.named_parameters()) 37 | shadow_params = dict(self.named_buffers()) 38 | 39 | for key in m_param: 40 | if m_param[key].requires_grad: 41 | sname = self.m_name2s_name[key] 42 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 43 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 44 | else: 45 | assert not key in self.m_name2s_name 46 | 47 | def copy_to(self, model): 48 | m_param = dict(model.named_parameters()) 49 | shadow_params = dict(self.named_buffers()) 50 | for key in m_param: 51 | if m_param[key].requires_grad: 52 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def store(self, parameters): 57 | """ 58 | Save the current parameters for restoring later. 59 | Args: 60 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 61 | temporarily stored. 62 | """ 63 | self.collected_params = [param.clone() for param in parameters] 64 | 65 | def restore(self, parameters): 66 | """ 67 | Restore the parameters stored with the `store` method. 68 | Useful to validate the model with EMA parameters without affecting the 69 | original optimization process. Store the parameters before the 70 | `copy_to` method. After validation (or model saving), use this to 71 | restore the former parameters. 72 | Args: 73 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 74 | updated with the stored parameters. 
75 | """ 76 | for c_param, param in zip(self.collected_params, parameters): 77 | param.data.copy_(c_param.data) 78 | -------------------------------------------------------------------------------- /imagebart/modules/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/imagebart/modules/transformer/__init__.py -------------------------------------------------------------------------------- /imagebart/modules/transformer/mingpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | source: https://github.com/karpathy/minGPT/ 3 | GPT model: 4 | - the initial stem consists of a combination of token encoding and a positional encoding 5 | - the meat of it is a uniform sequence of Transformer blocks 6 | - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block 7 | - all blocks feed into a central residual pathway similar to resnets 8 | - the final decoder is a linear projection into a vanilla Softmax classifier 9 | """ 10 | 11 | import math 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | from torch.nn import functional as F 16 | from transformers import top_k_top_p_filtering 17 | 18 | from imagebart.modules.xtransformers.positional_embeddings import apply_rotary_pos_emb 19 | 20 | 21 | def layers_from_width(d, a=5.039, b=5.55e-2): 22 | """as in https://arxiv.org/pdf/2006.12467.pdf""" 23 | return (math.log(d)-a)/b 24 | 25 | 26 | class GPTConfig: 27 | """ base GPT config, params common to all GPT versions """ 28 | embd_pdrop = 0.1 29 | resid_pdrop = 0.1 30 | attn_pdrop = 0.1 31 | 32 | def __init__(self, vocab_size, block_size, **kwargs): 33 | self.vocab_size = vocab_size 34 | self.block_size = block_size 35 | for k,v in kwargs.items(): 36 | setattr(self, k, v) 37 | 38 | 39 | class CausalSelfAttention(nn.Module): 40 | """ 41 | A vanilla multi-head masked self-attention layer with a projection at the end. 42 | It is possible to use torch.nn.MultiheadAttention here but I am including an 43 | explicit implementation here to show that there is nothing too scary here. 44 | """ 45 | 46 | def __init__(self, config): 47 | super().__init__() 48 | assert config.n_embd % config.n_head == 0, f"n_embd is {config.n_embd} but n_head is {config.n_head}." 
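        # each of the n_head heads attends over vectors of size n_embd // n_head (hence the assert above)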
49 | # key, query, value projections for all heads 50 | self.key = nn.Linear(config.n_embd, config.n_embd) 51 | self.query = nn.Linear(config.n_embd, config.n_embd) 52 | self.value = nn.Linear(config.n_embd, config.n_embd) 53 | # regularization 54 | self.attn_drop = nn.Dropout(config.attn_pdrop) 55 | self.resid_drop = nn.Dropout(config.resid_pdrop) 56 | # output projection 57 | self.proj = nn.Linear(config.n_embd, config.n_embd) 58 | # causal mask to ensure that attention is only applied to the left in the input sequence 59 | mask = torch.tril(torch.ones(config.block_size, 60 | config.block_size)) 61 | if hasattr(config, "n_unmasked"): 62 | mask[:config.n_unmasked, :config.n_unmasked] = 1 63 | self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) 64 | self.n_head = config.n_head 65 | 66 | def forward(self, x, layer_past=None, rotary_pos_emb=None): 67 | B, T, C = x.size() 68 | 69 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 70 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 71 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 72 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 73 | 74 | if rotary_pos_emb is not None: 75 | l = rotary_pos_emb.shape[-1] 76 | (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) 77 | ql, kl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl)) 78 | q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))) 79 | 80 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 81 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 82 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) 83 | att = F.softmax(att, dim=-1) 84 | att = self.attn_drop(att) 85 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 86 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 87 | 88 | # output projection 89 | y = self.resid_drop(self.proj(y)) 90 | return y 91 | 92 | def forward_with_past(self, x, layer_past=None, rotary_pos_emb=None): 93 | assert rotary_pos_emb is None, 'just for debugging' 94 | B, T, C = x.size() 95 | 96 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 97 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 98 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 99 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 100 | 101 | if rotary_pos_emb is not None: 102 | l = rotary_pos_emb.shape[-1] 103 | (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) 104 | ql, kl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl)) 105 | q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))) 106 | 107 | present = torch.stack((k, v)) 108 | #present = torch.stack((k.clone(), v.clone())) # (2, B, nh, 1, hs) 109 | 110 | if layer_past is not None: 111 | past_key, past_value = layer_past 112 | k = torch.cat((past_key, k), dim=-2) 113 | v = torch.cat((past_value, v), dim=-2) 114 | 115 | # causal self-attention; Self-attend: (B, nh, Tq, hs) x (B, nh, hs, Tk) -> (B, nh, Tq, Tk) 116 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 117 | if layer_past is 
None: 118 | pass 119 | 120 | att = F.softmax(att, dim=-1) 121 | att = self.attn_drop(att) 122 | y = att @ v # (B, nh, Tq, Tk) x (B, nh, Tk, hs) -> (B, nh, Tq, hs) 123 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 124 | 125 | # output projection 126 | y = self.resid_drop(self.proj(y)) 127 | return y, present 128 | 129 | 130 | class CausalCrossAttention(nn.Module): 131 | """ 132 | A vanilla multi-head masked self-attention layer with a projection at the end. 133 | It is possible to use torch.nn.MultiheadAttention here but I am including an 134 | explicit implementation here to show that there is nothing too scary here. 135 | """ 136 | 137 | def __init__(self, config): 138 | super().__init__() 139 | assert config.n_embd % config.n_head == 0, f"n_embd is {config.n_embd} but n_head is {config.n_head}." 140 | # key, query, value projections for all heads 141 | self.key = nn.Linear(config.n_embd, config.n_embd) 142 | self.query = nn.Linear(config.n_embd, config.n_embd) 143 | self.value = nn.Linear(config.n_embd, config.n_embd) 144 | # regularization 145 | self.attn_drop = nn.Dropout(config.attn_pdrop) 146 | self.resid_drop = nn.Dropout(config.resid_pdrop) 147 | # output projection 148 | self.proj = nn.Linear(config.n_embd, config.n_embd) 149 | # causal mask to ensure that attention is only applied to the left in the input sequence 150 | block_size = config.block_size 151 | cond_length = config.n_unmasked 152 | data_length = block_size-cond_length+1 153 | mask = np.zeros((data_length, block_size), dtype=np.float32) 154 | mask[:,:cond_length]=1 # make conditioning visible 155 | submask=np.tril(np.ones((data_length-1,data_length-1), dtype=np.float32)) # causal submask 156 | mask[1:,cond_length:] = submask 157 | mask = torch.tensor(mask) 158 | 159 | self.register_buffer("mask", mask.view(1, 1, data_length, block_size)) 160 | self.n_head = config.n_head 161 | 162 | def forward(self, x_q, x_kv, layer_past=None, rotary_pos_emb=None): 163 | B, T_q, C = x_q.size() 164 | _B, T_kv, _C = x_kv.size() 165 | assert B==_B and C==_C 166 | 167 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 168 | q = self.query(x_q).view(B, T_q, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 169 | k = self.key(x_kv).view(B, T_kv, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 170 | v = self.value(x_kv).view(B, T_kv, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 171 | 172 | if rotary_pos_emb is not None: 173 | l = rotary_pos_emb.shape[-1] 174 | (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) 175 | ql, kl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl)) 176 | q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))) 177 | 178 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 179 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 180 | att = att.masked_fill(self.mask[:,:,:T_q,:T_kv] == 0, float('-inf')) 181 | att = F.softmax(att, dim=-1) 182 | att = self.attn_drop(att) 183 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 184 | y = y.transpose(1, 2).contiguous().view(B, T_q, C) # re-assemble all head outputs side by side 185 | 186 | # output projection 187 | y = self.resid_drop(self.proj(y)) 188 | return y 189 | 190 | def forward_with_past(self, x, layer_past=None, rotary_pos_emb=None): 191 | raise NotImplementedError(":(") 192 | 193 | 194 | class 
Block(nn.Module): 195 | """ an unassuming Transformer block """ 196 | def __init__(self, config): 197 | super().__init__() 198 | self.ln1 = nn.LayerNorm(config.n_embd) 199 | self.ln2 = nn.LayerNorm(config.n_embd) 200 | self.attn = CausalSelfAttention(config) 201 | self.mlp = nn.Sequential( 202 | nn.Linear(config.n_embd, 4 * config.n_embd), 203 | nn.GELU(), # nice 204 | nn.Linear(4 * config.n_embd, config.n_embd), 205 | nn.Dropout(config.resid_pdrop), 206 | ) 207 | 208 | def forward(self, x, rotary_pos_emb=None): 209 | x = x + self.attn(self.ln1(x), rotary_pos_emb=rotary_pos_emb) 210 | x = x + self.mlp(self.ln2(x)) 211 | return x 212 | 213 | def forward_with_past(self, x, rotary_pos_emb=None, layer_past=None): 214 | assert rotary_pos_emb is None, 'just for debugging' 215 | attn, present = self.attn.forward_with_past(self.ln1(x), rotary_pos_emb=rotary_pos_emb, layer_past=layer_past) 216 | # layer past: tuple of length two with B, nh, T, hs 217 | x = x + attn 218 | x = x + self.mlp(self.ln2(x)) 219 | return x, present 220 | 221 | 222 | class CrossBlock(nn.Module): 223 | """ an unassuming Transformer block """ 224 | def __init__(self, config): 225 | super().__init__() 226 | self.lnq = nn.LayerNorm(config.n_embd) 227 | self.ln1 = nn.LayerNorm(config.n_embd) 228 | self.ln2 = nn.LayerNorm(config.n_embd) 229 | self.attn = CausalCrossAttention(config) 230 | self.mlp = nn.Sequential( 231 | nn.Linear(config.n_embd, 4 * config.n_embd), 232 | nn.GELU(), # nice 233 | nn.Linear(4 * config.n_embd, config.n_embd), 234 | nn.Dropout(config.resid_pdrop), 235 | ) 236 | 237 | def forward(self, x_q, x_kv, rotary_pos_emb=None): 238 | x = x_q + self.attn(x_q=self.lnq(x_q), 239 | x_kv=self.ln1(x_kv), 240 | rotary_pos_emb=rotary_pos_emb) 241 | x = x + self.mlp(self.ln2(x)) 242 | return x 243 | 244 | def forward_with_past(self, x, rotary_pos_emb=None, layer_past=None): 245 | raise NotImplementedError() 246 | 247 | 248 | class GPT(nn.Module): 249 | """ the full GPT language model, with a context size of block_size """ 250 | def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256, 251 | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0, 252 | input_vocab_size=None, autoset_layers=False): 253 | super().__init__() 254 | recommended_layers = int(np.around(layers_from_width(n_embd))) 255 | if autoset_layers: 256 | n_layer = recommended_layers 257 | print(f"Training with a width of n_embed = {n_embd} and L = {n_layer} layers. 
" 258 | f"https://arxiv.org/pdf/2006.12467.pdf suggest that one should use {recommended_layers} layers.") 259 | config = GPTConfig(vocab_size=vocab_size, block_size=block_size, 260 | embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, 261 | n_layer=n_layer, n_head=n_head, n_embd=n_embd, 262 | n_unmasked=n_unmasked) 263 | 264 | # input embedding stem 265 | in_vocab_size = vocab_size if not input_vocab_size else input_vocab_size 266 | self.tok_emb = nn.Embedding(in_vocab_size, config.n_embd) 267 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) 268 | self.drop = nn.Dropout(config.embd_pdrop) 269 | # transformer 270 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) 271 | # decoder head 272 | self.ln_f = nn.LayerNorm(config.n_embd) 273 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 274 | self.block_size = config.block_size 275 | self.apply(self._init_weights) 276 | self.config = config 277 | 278 | def get_block_size(self): 279 | return self.block_size 280 | 281 | def _init_weights(self, module): 282 | if isinstance(module, (nn.Linear, nn.Embedding)): 283 | module.weight.data.normal_(mean=0.0, std=0.02) 284 | if isinstance(module, nn.Linear) and module.bias is not None: 285 | module.bias.data.zero_() 286 | elif isinstance(module, nn.LayerNorm): 287 | module.bias.data.zero_() 288 | module.weight.data.fill_(1.0) 289 | 290 | def forward(self, idx, embeddings=None, targets=None, return_layers=False, token_embeddings=None): 291 | # forward the GPT model 292 | if token_embeddings is None: 293 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 294 | if embeddings is not None: # prepend explicit embeddings 295 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) 296 | 297 | t = token_embeddings.shape[1] 298 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 
299 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector 300 | x = self.drop(token_embeddings + position_embeddings) 301 | 302 | if return_layers: 303 | layers = [x] 304 | for block in self.blocks: 305 | x = block(x) 306 | layers.append(x) 307 | return layers 308 | 309 | x = self.blocks(x) 310 | x = self.ln_f(x) 311 | logits = self.head(x) 312 | 313 | # if we are given some desired targets also calculate the loss 314 | loss = None 315 | if targets is not None: 316 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) 317 | return logits, loss 318 | 319 | def forward_with_past(self, idx, embeddings=None, targets=None, token_embeddings=None, 320 | past=None, past_length=None): 321 | 322 | if token_embeddings is None: 323 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 324 | if embeddings is not None: # prepend explicit embeddings 325 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) 326 | assert not self.training 327 | if past is not None: 328 | assert past_length is not None 329 | past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head 330 | past_shape = list(past.shape) 331 | expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head] 332 | assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}" 333 | position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector 334 | else: 335 | position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :] 336 | 337 | x = self.drop(token_embeddings + position_embeddings) 338 | 339 | presents = [] # accumulate over layers 340 | for i, block in enumerate(self.blocks): 341 | x, present = block.forward_with_past(x, layer_past=past[i, ...] if past is not None else None) # take from layer 342 | presents.append(present) 343 | 344 | x = self.ln_f(x) 345 | logits = self.head(x) 346 | # if we are given some desired targets also calculate the loss 347 | loss = None 348 | if targets is not None: 349 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) 350 | 351 | return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head 352 | 353 | 354 | def top_k_logits(logits, k): 355 | v, ix = torch.topk(logits, k) 356 | out = logits.clone() 357 | out[out < v[:, [-1]]] = -float('Inf') 358 | return out 359 | 360 | @torch.no_grad() 361 | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): 362 | """ 363 | take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in 364 | the sequence, feeding the predictions back into the model each time. Clearly the sampling 365 | has quadratic complexity unlike an RNN that is only linear, and has a finite context window 366 | of block_size, unlike an RNN that has an infinite context window. 
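    Note that the whole (cropped) sequence is re-encoded at every step; sample_with_past below caches
    keys/values to avoid this.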
367 | """ 368 | block_size = model.get_block_size() 369 | model.eval() 370 | for k in range(steps): 371 | x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed 372 | logits, _ = model(x_cond) 373 | # pluck the logits at the final step and scale by temperature 374 | logits = logits[:, -1, :] / temperature 375 | # optionally crop probabilities to only the top k options 376 | if top_k is not None: 377 | logits = top_k_logits(logits, top_k) 378 | # apply softmax to convert to probabilities 379 | probs = F.softmax(logits, dim=-1) 380 | # sample from the distribution or take the most likely 381 | if sample: 382 | ix = torch.multinomial(probs, num_samples=1) 383 | else: 384 | _, ix = torch.topk(probs, k=1, dim=-1) 385 | # append to the sequence and continue 386 | x = torch.cat((x, ix), dim=1) 387 | 388 | return x 389 | 390 | 391 | @torch.no_grad() 392 | def sample_with_past(x, model, steps, temperature=1., sample_logits=True, 393 | top_k=None, callback=None, guide=None, top_p=None, 394 | embeddings=None): 395 | # x is conditioning 396 | sample = x 397 | cond_len = x.shape[1] 398 | if embeddings is not None: 399 | cond_len += embeddings.shape[1] 400 | past = None 401 | for n in range(steps): 402 | if callback is not None: 403 | callback(n) 404 | logits, _, present = model.forward_with_past(x, embeddings=embeddings, past=past, past_length=(n+cond_len-1)) 405 | embeddings = None # only pass in first time 406 | if past is None: 407 | past = [present] 408 | else: 409 | past.append(present) 410 | if guide is not None: 411 | logits = logits + guide[:, [n]] 412 | logits = logits[:, -1, :] / temperature 413 | if top_k is not None: 414 | if top_p is not None: 415 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) 416 | else: 417 | logits = top_k_logits(logits, top_k) 418 | 419 | probs = F.softmax(logits, dim=-1) 420 | if not sample_logits: 421 | _, x = torch.topk(probs, k=1, dim=-1) 422 | else: 423 | x = torch.multinomial(probs, num_samples=1) 424 | # append to the sequence and continue 425 | sample = torch.cat((sample, x), dim=1) 426 | del past 427 | sample = sample[:, -steps:] # cut conditioning off 428 | return sample 429 | 430 | 431 | @torch.no_grad() 432 | def sample_vanilla(x, model, steps, temperature=1.0, sample_logits=False, 433 | top_k=None, embeddings=None): 434 | block_size = model.get_block_size() 435 | model.eval() 436 | for k in range(steps): 437 | x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed 438 | logits, _ = model(x_cond, embeddings=embeddings) 439 | # pluck the logits at the final step and scale by temperature 440 | logits = logits[:, -1, :] / temperature 441 | # optionally crop probabilities to only the top k options 442 | if top_k is not None: 443 | logits = top_k_logits(logits, top_k) 444 | # apply softmax to convert to probabilities 445 | probs = F.softmax(logits, dim=-1) 446 | # sample from the distribution or take the most likely 447 | if sample_logits: 448 | ix = torch.multinomial(probs, num_samples=1) 449 | else: 450 | _, ix = torch.topk(probs, k=1, dim=-1) 451 | # append to the sequence and continue 452 | x = torch.cat((x, ix), dim=1) 453 | x = x[:, -steps:] 454 | return x 455 | 456 | 457 | def seed(seed=42): 458 | np.random.seed(seed) 459 | torch.manual_seed(seed) 460 | torch.cuda.manual_seed(seed) 461 | torch.cuda.manual_seed_all(seed) 462 | torch.backends.cudnn.deterministic = True 463 | torch.backends.cudnn.benchmark = False 464 | 465 | 466 | def test1(): 467 | import time 468 | 
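    # sanity check: with identical seeds, sampling with cached keys/values (sample_with_past) must
    # reproduce the plain autoregressive loop (sample_vanilla) token for token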
#torch.use_deterministic_algorithms(True) 469 | SEED = 142 470 | h = w = 4 471 | b = 1 472 | clen = 2 473 | cb = lambda n: print(n) 474 | cb = None 475 | elen = 3 476 | 477 | SAMPLE = True 478 | DEVICE = "cpu" 479 | 480 | # test past 481 | device = torch.device(DEVICE) 482 | model = GPT(vocab_size=1024, block_size=h*w+(elen+clen-1), n_embd=32).to(device).eval() 483 | x = torch.randint(0, 1024, size=(b, clen)).to(device) # start 484 | emb = torch.randn(b, elen, 32) 485 | 486 | print(f"in goes: {x}") 487 | 488 | # with past 489 | seed(SEED) 490 | t0 = time.time() 491 | s0 = sample_with_past(x, model, embeddings=emb, steps=h * w, sample_logits=SAMPLE, callback=cb) 492 | t1 = time.time() 493 | 494 | # without past 495 | seed(SEED) 496 | s1 = sample_vanilla(x, model, embeddings=emb, steps=h*w, sample_logits=SAMPLE) 497 | t2 = time.time() 498 | 499 | print(f"s0 (with past): time = {t1-t0:.2f}s") 500 | print(s0) 501 | print(f"s1 (no past): time = {t2-t1:.2f}s") 502 | print(s1) 503 | print("are equal:", torch.equal(s0, s1)) 504 | print("done.") 505 | 506 | 507 | if __name__ == "__main__": 508 | test1() 509 | -------------------------------------------------------------------------------- /imagebart/modules/transformer/vit.py: -------------------------------------------------------------------------------- 1 | # source: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_pytorch.py 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange, repeat 6 | from torch import nn 7 | 8 | MIN_NUM_PATCHES = 16 9 | 10 | 11 | class Residual(nn.Module): 12 | def __init__(self, fn): 13 | super().__init__() 14 | self.fn = fn 15 | def forward(self, x, **kwargs): 16 | return self.fn(x, **kwargs) + x 17 | 18 | 19 | class PreNorm(nn.Module): 20 | def __init__(self, dim, fn): 21 | super().__init__() 22 | self.norm = nn.LayerNorm(dim) 23 | self.fn = fn 24 | def forward(self, x, **kwargs): 25 | return self.fn(self.norm(x), **kwargs) 26 | 27 | 28 | class FeedForward(nn.Module): 29 | def __init__(self, dim, hidden_dim, dropout = 0.): 30 | super().__init__() 31 | self.net = nn.Sequential( 32 | nn.Linear(dim, hidden_dim), 33 | nn.GELU(), 34 | nn.Dropout(dropout), 35 | nn.Linear(hidden_dim, dim), 36 | nn.Dropout(dropout) 37 | ) 38 | def forward(self, x): 39 | return self.net(x) 40 | 41 | 42 | class Attention(nn.Module): 43 | def __init__(self, dim, heads = 8, dropout = 0.): 44 | super().__init__() 45 | self.heads = heads 46 | self.scale = dim ** -0.5 47 | 48 | self.to_qkv = nn.Linear(dim, dim * 3, bias = False) 49 | self.to_out = nn.Sequential( 50 | nn.Linear(dim, dim), 51 | nn.Dropout(dropout) 52 | ) 53 | 54 | def forward(self, x, mask = None): 55 | b, n, _, h = *x.shape, self.heads 56 | qkv = self.to_qkv(x).chunk(3, dim = -1) 57 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 58 | 59 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 60 | mask_value = -torch.finfo(dots.dtype).max 61 | 62 | if mask is not None: 63 | mask = F.pad(mask.flatten(1), (1, 0), value = True) 64 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' 65 | mask = mask[:, None, :] * mask[:, :, None] 66 | dots.masked_fill_(~mask, mask_value) 67 | del mask 68 | 69 | attn = dots.softmax(dim=-1) 70 | 71 | out = torch.einsum('bhij,bhjd->bhid', attn, v) 72 | out = rearrange(out, 'b h n d -> b n (h d)') 73 | out = self.to_out(out) 74 | return out 75 | 76 | 77 | class Transformer(nn.Module): 78 | def __init__(self, dim, depth, heads, mlp_dim, dropout): 79 | 
super().__init__() 80 | self.layers = nn.ModuleList([]) 81 | for _ in range(depth): 82 | self.layers.append(nn.ModuleList([ 83 | Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))), 84 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) 85 | ])) 86 | def forward(self, x, mask = None): 87 | for attn, ff in self.layers: 88 | x = attn(x, mask = mask) 89 | x = ff(x) 90 | return x 91 | 92 | 93 | class ViT(nn.Module): 94 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.): 95 | super().__init__() 96 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 97 | num_patches = (image_size // patch_size) ** 2 98 | patch_dim = channels * patch_size ** 2 99 | assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size' 100 | 101 | self.patch_size = patch_size 102 | 103 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 104 | self.patch_to_embedding = nn.Linear(patch_dim, dim) 105 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 106 | self.dropout = nn.Dropout(emb_dropout) 107 | 108 | self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout) 109 | self.to_cls_token = nn.Identity() 110 | 111 | self.mlp_head = nn.Sequential( 112 | nn.LayerNorm(dim), 113 | nn.Linear(dim, mlp_dim), 114 | nn.GELU(), 115 | nn.Dropout(dropout), 116 | nn.Linear(mlp_dim, num_classes) 117 | ) 118 | 119 | def forward(self, img, mask = None): 120 | p = self.patch_size 121 | 122 | x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) 123 | x = self.patch_to_embedding(x) 124 | b, n, _ = x.shape 125 | 126 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 127 | x = torch.cat((cls_tokens, x), dim=1) 128 | x += self.pos_embedding[:, :(n + 1)] 129 | x = self.dropout(x) 130 | 131 | x = self.transformer(x, mask) 132 | 133 | x = self.to_cls_token(x[:, 0]) 134 | return self.mlp_head(x) 135 | 136 | 137 | if __name__ == "__main__": 138 | from torchsummary import summary 139 | v = ViT( 140 | image_size=256, 141 | patch_size=32, 142 | num_classes=1000, 143 | dim=1024, 144 | depth=6, 145 | heads=8, 146 | mlp_dim=2048, 147 | dropout=0.1, 148 | emb_dropout=0.1 149 | ).to("cuda") 150 | 151 | img = torch.randn(1, 3, 256, 256).to("cuda") 152 | 153 | summary(v, img.shape[1:]) 154 | mask = torch.ones(1, 8, 8).bool().to("cuda") # optional mask, designating which patch to attend to 155 | preds = v(img, mask=mask) # (1, 1000) 156 | -------------------------------------------------------------------------------- /imagebart/modules/transformer/warper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from einops import rearrange 5 | 6 | from braket.modules.warp.midas import Midas 7 | from braket.modules.transformer.util import Block as TransformerBlock 8 | from braket.modules.transformer.util import TinyTransformer 9 | 10 | 11 | def disabled_train(self, mode=True): 12 | """Overwrite model.train with this function to make sure train/eval mode 13 | does not change anymore.""" 14 | return self 15 | 16 | 17 | class AbstractWarper(nn.Module): 18 | def __init__(self, *args, **kwargs): 19 | super().__init__() 20 | self._midas = Midas() 21 | self._midas.eval() 22 | self._midas.train = disabled_train 23 | for param in 
self._midas.parameters(): 24 | param.requires_grad = False 25 | 26 | self.n_unmasked = kwargs["n_unmasked"] # length of conditioning 27 | self.n_embd = kwargs["n_embd"] 28 | self.block_size = kwargs["block_size"] 29 | self.size = kwargs["size"] # h, w tuple 30 | self.start_idx = kwargs.get("start_idx", 0) # hint to not modify parts 31 | 32 | self._use_cache = False 33 | self.new_emb = None # cache 34 | self.new_pos = None # cache 35 | 36 | def set_cache(self, value): 37 | self._use_cache = value 38 | 39 | def get_embeddings(self, token_embeddings, position_embeddings, warpkwargs): 40 | if self._use_cache: 41 | assert not self.training, "Do you really want to use caching during training?" 42 | assert self.new_emb is not None 43 | assert self.new_pos is not None 44 | return self.new_emb, self.new_pos 45 | self.new_emb, self.new_pos = self._get_embeddings(token_embeddings, 46 | position_embeddings, 47 | warpkwargs) 48 | return self.new_emb, self.new_pos 49 | 50 | def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs): 51 | raise NotImplementedError() 52 | 53 | def forward(self, token_embeddings, position_embeddings, warpkwargs): 54 | new_emb, new_pos = self.get_embeddings(token_embeddings, 55 | position_embeddings, 56 | warpkwargs) 57 | 58 | new_emb = torch.cat([new_emb, token_embeddings[:,self.n_unmasked:,:]], 59 | dim=1) 60 | b = new_pos.shape[0] 61 | new_pos = torch.cat([new_pos, position_embeddings[:,self.n_unmasked:,:][b*[0],...]], 62 | dim=1) 63 | 64 | return new_emb, new_pos 65 | 66 | def _to_sequence(self, x): 67 | x = rearrange(x, 'b c h w -> b (h w) c') 68 | return x 69 | 70 | def _to_imglike(self, x): 71 | x = rearrange(x, 'b (h w) c -> b c h w', h=self.size[0]) 72 | return x 73 | 74 | 75 | class AbstractWarperWithCustomEmbedding(AbstractWarper): 76 | def __init__(self, *args, **kwargs): 77 | super().__init__() 78 | self.pos_emb = nn.Parameter(torch.zeros(1, self.block_size, self.n_embd)) 79 | 80 | 81 | class NoSourceWarper(AbstractWarper): 82 | def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs): 83 | cond_emb = token_embeddings[:,:self.n_unmasked,:] 84 | cond_pos = position_embeddings[:,:self.n_unmasked,:] 85 | 86 | b, seq_length, chn = cond_emb.shape 87 | cond_emb = self._to_imglike(cond_emb) 88 | 89 | cond_pos = self._to_imglike(cond_pos) 90 | cond_pos = cond_pos[b*[0],...] 91 | 92 | new_emb, _ = self._midas.warp_features(f=cond_emb, no_depth_grad=True, 93 | boltzmann_factor=0.0, 94 | **warpkwargs) 95 | new_pos, _ = self._midas.warp_features(f=cond_pos, no_depth_grad=True, 96 | boltzmann_factor=0.0, 97 | **warpkwargs) 98 | new_emb = self._filter_nans(new_emb) 99 | new_pos = self._filter_nans(new_pos) 100 | 101 | new_emb = self._to_sequence(new_emb) 102 | new_pos = self._to_sequence(new_pos) 103 | return new_emb, new_pos 104 | 105 | def _filter_nans(self, x): 106 | x[torch.isnan(x)] = 0. 
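# note (descriptive comment, assumption hedged): the warped feature maps can contain NaNs, presumably where the warp finds no valid source pixel, so they are zeroed before being flattened back into a sequence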
107 | return x 108 | 109 | 110 | class ConvWarper(AbstractWarper): 111 | def __init__(self, *args, **kwargs): 112 | super().__init__(*args, **kwargs) 113 | self.emb_conv = nn.Conv2d(2*self.n_embd, self.n_embd, 114 | kernel_size=1, 115 | padding=0, 116 | bias=False) 117 | self.pos_conv = nn.Conv2d(2*self.n_embd, self.n_embd, 118 | kernel_size=1, 119 | padding=0, 120 | bias=False) 121 | 122 | def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs): 123 | cond_emb = token_embeddings[:,self.start_idx:self.n_unmasked,:] 124 | cond_pos = position_embeddings[:,self.start_idx:self.n_unmasked,:] 125 | 126 | b, seq_length, chn = cond_emb.shape 127 | cond_emb = cond_emb.reshape(b, self.size[0], self.size[1], chn) 128 | cond_emb = cond_emb.permute(0,3,1,2) 129 | 130 | cond_pos = cond_pos.reshape(1, self.size[0], self.size[1], chn) 131 | cond_pos = cond_pos.permute(0,3,1,2) 132 | cond_pos = cond_pos[b*[0],...] 133 | 134 | with torch.no_grad(): 135 | warp_emb, _ = self._midas.warp_features(f=cond_emb, no_depth_grad=True, **warpkwargs) 136 | warp_pos, _ = self._midas.warp_features(f=cond_pos, no_depth_grad=True, **warpkwargs) 137 | 138 | new_emb = self.emb_conv(torch.cat([cond_emb, warp_emb], dim=1)) 139 | new_pos = self.pos_conv(torch.cat([cond_pos, warp_pos], dim=1)) 140 | 141 | new_emb = new_emb.permute(0,2,3,1) 142 | new_emb = new_emb.reshape(b,seq_length,chn) 143 | 144 | new_pos = new_pos.permute(0,2,3,1) 145 | new_pos = new_pos.reshape(b,seq_length,chn) 146 | 147 | # prepend unmodified ones again 148 | new_emb = torch.cat((token_embeddings[:,:self.start_idx,:], new_emb), 149 | dim=1) 150 | new_pos = torch.cat((position_embeddings[:,:self.start_idx,:][b*[0],...], new_pos), 151 | dim=1) 152 | 153 | return new_emb, new_pos 154 | 155 | 156 | class ParallelAttentionWarper(AbstractWarper): 157 | def __init__(self, *args, **kwargs): 158 | super().__init__(*args, **kwargs) 159 | self.pos_block = nn.ModuleList([TinyTransformer(block_size=self.block_size, 160 | n_layer=2, 161 | n_head=16, 162 | n_embd=self.n_embd, 163 | use_head=False) for _ in range(3)]) 164 | self.emb_block = nn.ModuleList([TinyTransformer(block_size=self.block_size, 165 | n_layer=2, 166 | n_head=16, 167 | n_embd=self.n_embd, 168 | use_head=False) for _ in range(3)]) 169 | 170 | def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs): 171 | cond_emb = token_embeddings[:,:self.n_unmasked,:] 172 | cond_pos = position_embeddings[:,:self.n_unmasked,:] 173 | 174 | b, seq_length, chn = cond_emb.shape 175 | cond_emb = self._to_imglike(cond_emb) 176 | 177 | cond_pos = self._to_imglike(cond_pos) 178 | cond_pos = cond_pos[b*[0],...] 
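# descriptive note: indexing with b*[0] repeats the single shared positional-embedding map b times so the positions match the batch size before warping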
179 | 180 | with torch.no_grad(): 181 | warp_emb, _ = self._midas.warp_features(f=cond_emb, no_depth_grad=True, **warpkwargs) 182 | warp_pos, _ = self._midas.warp_features(f=cond_pos, no_depth_grad=True, **warpkwargs) 183 | 184 | warp_emb = self._to_sequence(warp_emb) 185 | warp_pos = self._to_sequence(warp_pos) 186 | 187 | cond_emb = self._to_sequence(cond_emb) 188 | cond_pos = self._to_sequence(cond_pos) 189 | 190 | cond_emb = self.emb_block[0](idx=None, token_embeddings=cond_emb) + cond_emb 191 | warp_emb = self.emb_block[1](idx=None, token_embeddings=warp_emb) + warp_emb 192 | new_emb = self.emb_block[2](idx=None, token_embeddings=warp_emb+cond_emb) 193 | 194 | cond_pos = self.pos_block[0](idx=None, token_embeddings=cond_pos) + cond_pos 195 | warp_pos = self.pos_block[1](idx=None, token_embeddings=warp_pos) + warp_pos 196 | new_pos = self.pos_block[2](idx=None, token_embeddings=warp_pos+cond_pos) 197 | 198 | return new_emb, new_pos 199 | -------------------------------------------------------------------------------- /imagebart/modules/xtransformers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/imagebart/modules/xtransformers/__init__.py -------------------------------------------------------------------------------- /imagebart/modules/xtransformers/autoregressive_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | try: 6 | from entmax import entmax_bisect 7 | except ModuleNotFoundError: 8 | print("'entmax' not installed. skipping.") 9 | entmax_bisect = None 10 | 11 | 12 | # nucleus 13 | 14 | def top_p(logits, thres=0.9): 15 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 16 | cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 17 | 18 | sorted_indices_to_remove = cum_probs > (1 - thres) 19 | sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() 20 | sorted_indices_to_remove[:, 0] = 0 21 | 22 | sorted_logits[sorted_indices_to_remove] = float('-inf') 23 | return sorted_logits.scatter(1, sorted_indices, sorted_logits) 24 | 25 | 26 | # topk 27 | 28 | def top_k(logits, thres=0.9): 29 | k = int((1 - thres) * logits.shape[-1]) 30 | val, ind = torch.topk(logits, k) 31 | probs = torch.full_like(logits, float('-inf')) 32 | probs.scatter_(1, ind, val) 33 | return probs 34 | 35 | 36 | # entmax 37 | 38 | ENTMAX_ALPHA = 1.3 39 | entmax = entmax_bisect 40 | 41 | 42 | class AutoregressiveWrapper(nn.Module): 43 | def __init__(self, net, ignore_index=-100, pad_value=0): 44 | super().__init__() 45 | self.pad_value = pad_value 46 | self.ignore_index = ignore_index 47 | 48 | self.net = net 49 | self.max_seq_len = net.max_seq_len 50 | 51 | @torch.no_grad() 52 | def generate(self, start_tokens, seq_len, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9, 53 | **kwargs): 54 | device = start_tokens.device 55 | was_training = self.net.training 56 | num_dims = len(start_tokens.shape) 57 | 58 | if num_dims == 1: 59 | start_tokens = start_tokens[None, :] 60 | 61 | b, t = start_tokens.shape 62 | 63 | self.net.eval() 64 | out = start_tokens 65 | mask = kwargs.pop('mask', None) 66 | 67 | if mask is None: 68 | mask = torch.full_like(out, True, dtype=torch.bool, device=out.device) 69 | 70 | for _ in range(seq_len): 71 | x = out[:, -self.max_seq_len:] 72 | mask = mask[:, 
-self.max_seq_len:] 73 | 74 | logits = self.net(x, mask=mask, **kwargs)[:, -1, :] 75 | 76 | if filter_logits_fn in {top_k, top_p}: 77 | filtered_logits = filter_logits_fn(logits, thres=filter_thres) 78 | probs = F.softmax(filtered_logits / temperature, dim=-1) 79 | 80 | elif filter_logits_fn is entmax: 81 | probs = entmax(logits / temperature, alpha=ENTMAX_ALPHA, dim=-1) 82 | 83 | sample = torch.multinomial(probs, 1) 84 | 85 | out = torch.cat((out, sample), dim=-1) 86 | mask = F.pad(mask, (0, 1), value=True) 87 | 88 | if eos_token is not None and (sample == eos_token).all(): 89 | break 90 | 91 | out = out[:, t:] 92 | 93 | if num_dims == 1: 94 | out = out.squeeze(0) 95 | 96 | self.net.train(was_training) 97 | return out 98 | 99 | def forward(self, x, return_logits=True, **kwargs): 100 | xi = x[:, :-1] 101 | xo = x[:, 1:] 102 | 103 | # help auto-solve a frequent area of confusion around input masks in auto-regressive 104 | # if user supplies a mask that is only off by one from the source sequence, resolve it for them 105 | mask = kwargs.get('mask', None) 106 | if mask is not None and mask.shape[1] == x.shape[1]: 107 | mask = mask[:, :-1] 108 | kwargs['mask'] = mask 109 | 110 | out = self.net(xi, **kwargs) 111 | loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index=self.ignore_index) 112 | if return_logits: 113 | return out, loss 114 | return loss 115 | -------------------------------------------------------------------------------- /imagebart/modules/xtransformers/positional_embeddings.py: -------------------------------------------------------------------------------- 1 | # https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from einops import rearrange 7 | from inspect import isfunction 8 | 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | 14 | def default(val, d): 15 | if exists(val): 16 | return val 17 | return d() if isfunction(d) else d 18 | 19 | 20 | # positional embeddings 21 | class DepthWiseConv1d(nn.Module): 22 | def __init__(self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True, groups=False): 23 | super().__init__() 24 | groups = default(groups, dim_in) 25 | self.net = nn.Sequential( 26 | nn.Conv1d(dim_in, dim_in, kernel_size=kernel_size, padding=padding, groups=dim_in, stride=stride, 27 | bias=bias), 28 | nn.Conv1d(dim_in, dim_out, 1) 29 | ) 30 | 31 | def forward(self, x): 32 | return self.net(x) 33 | 34 | 35 | class AbsolutePositionalEmbedding(nn.Module): 36 | def __init__(self, dim, max_seq_len): 37 | super().__init__() 38 | self.emb = nn.Embedding(max_seq_len, dim) 39 | self.init_() 40 | 41 | def init_(self): 42 | nn.init.normal_(self.emb.weight, std=0.02) 43 | 44 | def forward(self, x): 45 | n = torch.arange(x.shape[1], device=x.device) 46 | return self.emb(n)[None, :, :] 47 | 48 | 49 | class FixedPositionalEmbedding(nn.Module): 50 | def __init__(self, dim): 51 | super().__init__() 52 | inv_freq = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim)) 53 | self.register_buffer('inv_freq', inv_freq) 54 | 55 | def forward(self, x, seq_dim=1, offset=0): 56 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset 57 | sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) 58 | emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) 59 | return emb[None, :, :] 60 | 61 | 62 | class RelativePositionBias(nn.Module): 63 | def __init__(self, causal=False, num_buckets=32, max_distance=128, heads=8): 64 | super().__init__() 65 | self.causal = causal 66 | self.num_buckets = num_buckets 67 | self.max_distance = max_distance 68 | self.relative_attention_bias = nn.Embedding(num_buckets, heads) 69 | 70 | @staticmethod 71 | def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128): 72 | ret = 0 73 | n = -relative_position 74 | if not causal: 75 | num_buckets //= 2 76 | ret += (n < 0).long() * num_buckets 77 | n = torch.abs(n) 78 | else: 79 | n = torch.max(n, torch.zeros_like(n)) 80 | 81 | max_exact = num_buckets // 2 82 | is_small = n < max_exact 83 | 84 | val_if_large = max_exact + ( 85 | torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) 86 | ).long() 87 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) 88 | 89 | ret += torch.where(is_small, n, val_if_large) 90 | return ret 91 | 92 | def forward(self, qk_dots): 93 | i, j, device = *qk_dots.shape[-2:], qk_dots.device 94 | q_pos = torch.arange(i, dtype=torch.long, device=device) 95 | k_pos = torch.arange(j, dtype=torch.long, device=device) 96 | rel_pos = k_pos[None, :] - q_pos[:, None] 97 | rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets, 98 | max_distance=self.max_distance) 99 | values = self.relative_attention_bias(rp_bucket) 100 | bias = rearrange(values, 'i j h -> () h i j') 101 | return qk_dots + bias 102 | 103 | 104 | class RotaryEmbedding(nn.Module): 105 | def __init__(self, dim): 106 | super().__init__() 107 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 108 | self.register_buffer('inv_freq', inv_freq) 109 | 110 | def forward(self, x, seq_dim=1): 111 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) 112 | freqs = torch.einsum('i , j -> i j', t, self.inv_freq) 113 | emb = torch.cat((freqs, freqs), dim=-1) 114 | return emb[None, :, :] 115 | 116 | 117 | def rotate_half(x): 118 | x = rearrange(x, '... (j d) -> ... 
j d', j=2) 119 | x1, x2 = x.unbind(dim=-2) 120 | return torch.cat((-x2, x1), dim=-1) 121 | 122 | 123 | def apply_rotary_pos_emb(t, freqs): 124 | seq_len = t.shape[-2] 125 | freqs = freqs[:, :, -seq_len:] 126 | return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) -------------------------------------------------------------------------------- /imagebart/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image, ImageDraw, ImageFont 4 | 5 | 6 | class ClassProvider(torch.nn.Module): 7 | def __init__(self, key="class"): 8 | super().__init__() 9 | self.key = key 10 | 11 | def forward(self, batch): 12 | c = batch[self.key][:, None] 13 | return c 14 | 15 | 16 | class BasicTokenizer(torch.nn.Module): 17 | """ 18 | Uses the 'simple tokenizer' of the CLIP model 19 | https://github.com/openai/CLIP 20 | """ 21 | def __init__(self, device="cuda", key="caption"): 22 | super().__init__() 23 | from clip import tokenize 24 | self.tknz_fn = tokenize 25 | self.device = device 26 | self.key = key 27 | 28 | def forward(self, batch): 29 | text = batch[self.key] 30 | tokens = self.tknz_fn(text).to(self.device) 31 | return tokens 32 | 33 | 34 | class KeyNotFoundError(Exception): 35 | def __init__(self, cause, keys=None, visited=None): 36 | self.cause = cause 37 | self.keys = keys 38 | self.visited = visited 39 | messages = list() 40 | if keys is not None: 41 | messages.append("Key not found: {}".format(keys)) 42 | if visited is not None: 43 | messages.append("Visited: {}".format(visited)) 44 | messages.append("Cause:\n{}".format(cause)) 45 | message = "\n".join(messages) 46 | super().__init__(message) 47 | 48 | 49 | def log_txt_as_img(wh, xc, size=10): 50 | # wh a tuple of (width, height) 51 | # xc a list of captions to plot 52 | b = len(xc) 53 | txts = list() 54 | for bi in range(b): 55 | txt = Image.new("RGB", wh, color="white") 56 | draw = ImageDraw.Draw(txt) 57 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 58 | nc = int(40 * (wh[0]/256)) 59 | lines = "\n".join(xc[bi][start:start+nc] for start in range(0, len(xc[bi]), nc)) 60 | 61 | try: 62 | draw.text((0,0), lines, fill="black", font=font) 63 | except UnicodeEncodeError: 64 | print("Cant encode string for logging. 
Skipping.") 65 | 66 | txt = np.array(txt).transpose(2,0,1)/127.5-1.0 67 | txts.append(txt) 68 | txts = np.stack(txts) 69 | txts = torch.tensor(txts) 70 | return txts 71 | 72 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, datetime, glob, importlib, csv 2 | import numpy as np 3 | import time 4 | import torch 5 | import torchvision 6 | import pytorch_lightning as pl 7 | 8 | from packaging import version 9 | from omegaconf import OmegaConf 10 | from torch.utils.data import random_split, DataLoader, Dataset, Subset 11 | from functools import partial 12 | from PIL import Image 13 | 14 | from pytorch_lightning import seed_everything 15 | from pytorch_lightning.trainer import Trainer 16 | from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor 17 | from pytorch_lightning.utilities.distributed import rank_zero_only 18 | from pytorch_lightning.utilities import rank_zero_info 19 | 20 | from imagebart.data.base import Txt2ImgIterableBaseDataset 21 | 22 | 23 | def get_obj_from_str(string, reload=False): 24 | module, cls = string.rsplit(".", 1) 25 | if reload: 26 | module_imp = importlib.import_module(module) 27 | importlib.reload(module_imp) 28 | return getattr(importlib.import_module(module, package=None), cls) 29 | 30 | 31 | def get_parser(**parser_kwargs): 32 | def str2bool(v): 33 | if isinstance(v, bool): 34 | return v 35 | if v.lower() in ("yes", "true", "t", "y", "1"): 36 | return True 37 | elif v.lower() in ("no", "false", "f", "n", "0"): 38 | return False 39 | else: 40 | raise argparse.ArgumentTypeError("Boolean value expected.") 41 | 42 | parser = argparse.ArgumentParser(**parser_kwargs) 43 | parser.add_argument( 44 | "-n", 45 | "--name", 46 | type=str, 47 | const=True, 48 | default="", 49 | nargs="?", 50 | help="postfix for logdir", 51 | ) 52 | parser.add_argument( 53 | "-r", 54 | "--resume", 55 | type=str, 56 | const=True, 57 | default="", 58 | nargs="?", 59 | help="resume from logdir or checkpoint in logdir", 60 | ) 61 | parser.add_argument( 62 | "-b", 63 | "--base", 64 | nargs="*", 65 | metavar="base_config.yaml", 66 | help="paths to base configs. Loaded from left-to-right. 
" 67 | "Parameters can be overwritten or added with command-line options of the form `--key value`.", 68 | default=list(), 69 | ) 70 | parser.add_argument( 71 | "-t", 72 | "--train", 73 | type=str2bool, 74 | const=True, 75 | default=False, 76 | nargs="?", 77 | help="train", 78 | ) 79 | parser.add_argument( 80 | "--no-test", 81 | type=str2bool, 82 | const=True, 83 | default=False, 84 | nargs="?", 85 | help="disable test", 86 | ) 87 | parser.add_argument("-p", "--project", help="name of new or path to existing project") 88 | parser.add_argument( 89 | "-d", 90 | "--debug", 91 | type=str2bool, 92 | nargs="?", 93 | const=True, 94 | default=False, 95 | help="enable post-mortem debugging", 96 | ) 97 | parser.add_argument( 98 | "-s", 99 | "--seed", 100 | type=int, 101 | default=23, 102 | help="seed for seed_everything", 103 | ) 104 | parser.add_argument( 105 | "-f", 106 | "--postfix", 107 | type=str, 108 | default="", 109 | help="post-postfix for default name", 110 | ) 111 | parser.add_argument( 112 | "-l", 113 | "--logdir", 114 | type=str, 115 | default="logs", 116 | help="directory for logging dat shit", 117 | ) 118 | parser.add_argument( 119 | "--scale_lr", 120 | type=str2bool, 121 | nargs="?", 122 | const=True, 123 | default=True, 124 | help="scale base-lr by ngpu * batch_size * n_accumulate", 125 | ) 126 | return parser 127 | 128 | 129 | def nondefault_trainer_args(opt): 130 | parser = argparse.ArgumentParser() 131 | parser = Trainer.add_argparse_args(parser) 132 | args = parser.parse_args([]) 133 | return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) 134 | 135 | 136 | def instantiate_from_config(config): 137 | if not "target" in config: 138 | if config == '__is_first_stage__': 139 | return None 140 | elif config == "__is_unconditional__": 141 | return None 142 | raise KeyError("Expected key `target` to instantiate.") 143 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 144 | 145 | 146 | class WrappedDataset(Dataset): 147 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" 148 | 149 | def __init__(self, dataset): 150 | self.data = dataset 151 | 152 | def __len__(self): 153 | return len(self.data) 154 | 155 | def __getitem__(self, idx): 156 | return self.data[idx] 157 | 158 | 159 | def worker_init_fn(_): 160 | worker_info = torch.utils.data.get_worker_info() 161 | 162 | dataset = worker_info.dataset 163 | worker_id = worker_info.id 164 | 165 | if isinstance(dataset, Txt2ImgIterableBaseDataset): 166 | split_size = dataset.num_records // worker_info.num_workers 167 | # reset num_records to the true number to retain reliable length information 168 | dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] 169 | current_id = np.random.choice(len(np.random.get_state()[1]), 1) 170 | return np.random.seed(np.random.get_state()[1][current_id] + worker_id) 171 | else: 172 | return np.random.seed(np.random.get_state()[1][0] + worker_id) 173 | 174 | 175 | class DataModuleFromConfig(pl.LightningDataModule): 176 | def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, 177 | wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False, 178 | shuffle_val_dataloader=False): 179 | super().__init__() 180 | self.batch_size = batch_size 181 | self.dataset_configs = dict() 182 | self.num_workers = num_workers if num_workers is not None else batch_size * 2 183 | self.use_worker_init_fn = use_worker_init_fn 184 | if train is not None: 185 | 
self.dataset_configs["train"] = train 186 | self.train_dataloader = self._train_dataloader 187 | if validation is not None: 188 | self.dataset_configs["validation"] = validation 189 | self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) 190 | if test is not None: 191 | self.dataset_configs["test"] = test 192 | self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) 193 | if predict is not None: 194 | self.dataset_configs["predict"] = predict 195 | self.predict_dataloader = self._predict_dataloader 196 | self.wrap = wrap 197 | 198 | def prepare_data(self): 199 | for data_cfg in self.dataset_configs.values(): 200 | instantiate_from_config(data_cfg) 201 | 202 | def setup(self, stage=None): 203 | self.datasets = dict( 204 | (k, instantiate_from_config(self.dataset_configs[k])) 205 | for k in self.dataset_configs) 206 | if self.wrap: 207 | for k in self.datasets: 208 | self.datasets[k] = WrappedDataset(self.datasets[k]) 209 | 210 | def _train_dataloader(self): 211 | is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) 212 | if is_iterable_dataset or self.use_worker_init_fn: 213 | init_fn = worker_init_fn 214 | else: 215 | init_fn = None 216 | return DataLoader(self.datasets["train"], batch_size=self.batch_size, 217 | num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True, 218 | worker_init_fn=init_fn) 219 | 220 | def _val_dataloader(self, shuffle=False): 221 | if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: 222 | init_fn = worker_init_fn 223 | else: 224 | init_fn = None 225 | return DataLoader(self.datasets["validation"], 226 | batch_size=self.batch_size, 227 | num_workers=self.num_workers, 228 | worker_init_fn=init_fn, 229 | shuffle=shuffle) 230 | 231 | def _test_dataloader(self, shuffle=False): 232 | is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) 233 | if is_iterable_dataset or self.use_worker_init_fn: 234 | init_fn = worker_init_fn 235 | else: 236 | init_fn = None 237 | 238 | # do not shuffle dataloader for iterable dataset 239 | shuffle = shuffle and (not is_iterable_dataset) 240 | 241 | return DataLoader(self.datasets["test"], batch_size=self.batch_size, 242 | num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle) 243 | 244 | def _predict_dataloader(self, shuffle=False): 245 | if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: 246 | init_fn = worker_init_fn 247 | else: 248 | init_fn = None 249 | return DataLoader(self.datasets["predict"], batch_size=self.batch_size, 250 | num_workers=self.num_workers, worker_init_fn=init_fn) 251 | 252 | 253 | class SetupCallback(Callback): 254 | def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): 255 | super().__init__() 256 | self.resume = resume 257 | self.now = now 258 | self.logdir = logdir 259 | self.ckptdir = ckptdir 260 | self.cfgdir = cfgdir 261 | self.config = config 262 | self.lightning_config = lightning_config 263 | 264 | def on_keyboard_interrupt(self, trainer, pl_module): 265 | if trainer.global_rank == 0: 266 | print("Summoning checkpoint.") 267 | ckpt_path = os.path.join(self.ckptdir, "last.ckpt") 268 | trainer.save_checkpoint(ckpt_path) 269 | 270 | def on_pretrain_routine_start(self, trainer, pl_module): 271 | if trainer.global_rank == 0: 272 | # Create logdirs and save configs 273 | os.makedirs(self.logdir, exist_ok=True) 274 | 
os.makedirs(self.ckptdir, exist_ok=True) 275 | os.makedirs(self.cfgdir, exist_ok=True) 276 | 277 | if "callbacks" in self.lightning_config: 278 | if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']: 279 | os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) 280 | print("Project config") 281 | print(OmegaConf.to_yaml(self.config)) 282 | OmegaConf.save(self.config, 283 | os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) 284 | 285 | print("Lightning config") 286 | print(OmegaConf.to_yaml(self.lightning_config)) 287 | OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), 288 | os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) 289 | 290 | else: 291 | # ModelCheckpoint callback created log directory --- remove it 292 | if not self.resume and os.path.exists(self.logdir): 293 | dst, name = os.path.split(self.logdir) 294 | dst = os.path.join(dst, "child_runs", name) 295 | os.makedirs(os.path.split(dst)[0], exist_ok=True) 296 | try: 297 | os.rename(self.logdir, dst) 298 | except FileNotFoundError: 299 | pass 300 | 301 | 302 | class ImageLogger(Callback): 303 | def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True, 304 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, 305 | log_images_kwargs=None): 306 | super().__init__() 307 | self.rescale = rescale 308 | self.batch_freq = batch_frequency 309 | self.max_images = max_images 310 | self.logger_log_images = { 311 | pl.loggers.TestTubeLogger: self._testtube, 312 | } 313 | self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)] 314 | if not increase_log_steps: 315 | self.log_steps = [self.batch_freq] 316 | self.clamp = clamp 317 | self.disabled = disabled 318 | self.log_on_batch_idx = log_on_batch_idx 319 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 320 | self.log_first_step = log_first_step 321 | 322 | @rank_zero_only 323 | def _testtube(self, pl_module, images, batch_idx, split): 324 | for k in images: 325 | grid = torchvision.utils.make_grid(images[k]) 326 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 327 | 328 | tag = f"{split}/{k}" 329 | pl_module.logger.experiment.add_image( 330 | tag, grid, 331 | global_step=pl_module.global_step) 332 | 333 | @rank_zero_only 334 | def log_local(self, save_dir, split, images, 335 | global_step, current_epoch, batch_idx): 336 | root = os.path.join(save_dir, "images", split) 337 | for k in images: 338 | grid = torchvision.utils.make_grid(images[k], nrow=4) 339 | if self.rescale: 340 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 341 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 342 | grid = grid.numpy() 343 | grid = (grid * 255).astype(np.uint8) 344 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( 345 | k, 346 | global_step, 347 | current_epoch, 348 | batch_idx) 349 | path = os.path.join(root, filename) 350 | os.makedirs(os.path.split(path)[0], exist_ok=True) 351 | Image.fromarray(grid).save(path) 352 | 353 | def log_img(self, pl_module, batch, batch_idx, split="train"): 354 | check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step 355 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 356 | hasattr(pl_module, "log_images") and 357 | callable(pl_module.log_images) and 358 | self.max_images > 0): 359 | logger = type(pl_module.logger) 360 | 361 | is_train = pl_module.training 362 | if is_train: 363 | pl_module.eval() 364 | 365 | with torch.no_grad(): 
366 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) 367 | 368 | for k in images: 369 | N = min(images[k].shape[0], self.max_images) 370 | images[k] = images[k][:N] 371 | if isinstance(images[k], torch.Tensor): 372 | images[k] = images[k].detach().cpu() 373 | if self.clamp: 374 | images[k] = torch.clamp(images[k], -1., 1.) 375 | 376 | self.log_local(pl_module.logger.save_dir, split, images, 377 | pl_module.global_step, pl_module.current_epoch, batch_idx) 378 | 379 | logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) 380 | logger_log_images(pl_module, images, pl_module.global_step, split) 381 | 382 | if is_train: 383 | pl_module.train() 384 | 385 | def check_frequency(self, check_idx): 386 | if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( 387 | check_idx > 0 or self.log_first_step): 388 | try: 389 | self.log_steps.pop(0) 390 | except IndexError as e: 391 | print(e) 392 | pass 393 | return True 394 | return False 395 | 396 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 397 | if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): 398 | self.log_img(pl_module, batch, batch_idx, split="train") 399 | 400 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 401 | if not self.disabled and pl_module.global_step > 0: 402 | self.log_img(pl_module, batch, batch_idx, split="val") 403 | if hasattr(pl_module, 'calibrate_grad_norm'): 404 | if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: 405 | self.log_gradients(trainer, pl_module, batch_idx=batch_idx) 406 | 407 | 408 | class CUDACallback(Callback): 409 | # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py 410 | def on_train_epoch_start(self, trainer, pl_module): 411 | # Reset the memory use counter 412 | torch.cuda.reset_peak_memory_stats(trainer.root_gpu) 413 | torch.cuda.synchronize(trainer.root_gpu) 414 | self.start_time = time.time() 415 | 416 | def on_train_epoch_end(self, trainer, pl_module, outputs): 417 | torch.cuda.synchronize(trainer.root_gpu) 418 | max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 419 | epoch_time = time.time() - self.start_time 420 | 421 | try: 422 | max_memory = trainer.training_type_plugin.reduce(max_memory) 423 | epoch_time = trainer.training_type_plugin.reduce(epoch_time) 424 | 425 | rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") 426 | rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") 427 | except AttributeError: 428 | pass 429 | 430 | 431 | if __name__ == "__main__": 432 | # custom parser to specify config files, train, test and debug mode, 433 | # postfix, resume. 434 | # `--key value` arguments are interpreted as arguments to the trainer. 435 | # `nested.key=value` arguments are interpreted as config parameters. 436 | # configs are merged from left-to-right followed by command line parameters. 
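# a minimal example invocation (illustrative only; substitute any config from ./configs and your own GPU ids / overrides):
#   python main.py --base configs/ffhq/4_scales/ffhq-geometric-scale1.yaml -t --gpus 0, model.base_learning_rate=4.5e-6
#   (model.base_learning_rate=4.5e-6 is just an example of a dot-list override, not a recommended value)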
437 | 438 | # model: 439 | # base_learning_rate: float 440 | # target: path to lightning module 441 | # params: 442 | # key: value 443 | # data: 444 | # target: main.DataModuleFromConfig 445 | # params: 446 | # batch_size: int 447 | # wrap: bool 448 | # train: 449 | # target: path to train dataset 450 | # params: 451 | # key: value 452 | # validation: 453 | # target: path to validation dataset 454 | # params: 455 | # key: value 456 | # test: 457 | # target: path to test dataset 458 | # params: 459 | # key: value 460 | # lightning: (optional, has sane defaults and can be specified on cmdline) 461 | # trainer: 462 | # additional arguments to trainer 463 | # logger: 464 | # logger to instantiate 465 | # modelcheckpoint: 466 | # modelcheckpoint to instantiate 467 | # callbacks: 468 | # callback1: 469 | # target: importpath 470 | # params: 471 | # key: value 472 | 473 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 474 | 475 | # add cwd for convenience and to make classes in this file available when 476 | # running as `python main.py` 477 | # (in particular `main.DataModuleFromConfig`) 478 | sys.path.append(os.getcwd()) 479 | 480 | parser = get_parser() 481 | parser = Trainer.add_argparse_args(parser) 482 | 483 | opt, unknown = parser.parse_known_args() 484 | if opt.name and opt.resume: 485 | raise ValueError( 486 | "-n/--name and -r/--resume cannot be specified both." 487 | "If you want to resume training in a new log folder, " 488 | "use -n/--name in combination with --resume_from_checkpoint" 489 | ) 490 | if opt.resume: 491 | if not os.path.exists(opt.resume): 492 | raise ValueError("Cannot find {}".format(opt.resume)) 493 | if os.path.isfile(opt.resume): 494 | paths = opt.resume.split("/") 495 | # idx = len(paths)-paths[::-1].index("logs")+1 496 | # logdir = "/".join(paths[:idx]) 497 | logdir = "/".join(paths[:-2]) 498 | ckpt = opt.resume 499 | else: 500 | assert os.path.isdir(opt.resume), opt.resume 501 | logdir = opt.resume.rstrip("/") 502 | ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") 503 | 504 | opt.resume_from_checkpoint = ckpt 505 | base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) 506 | opt.base = base_configs + opt.base 507 | _tmp = logdir.split("/") 508 | nowname = _tmp[-1] 509 | else: 510 | if opt.name: 511 | name = "_" + opt.name 512 | elif opt.base: 513 | cfg_fname = os.path.split(opt.base[0])[-1] 514 | cfg_name = os.path.splitext(cfg_fname)[0] 515 | name = "_" + cfg_name 516 | else: 517 | name = "" 518 | nowname = now + name + opt.postfix 519 | logdir = os.path.join(opt.logdir, nowname) 520 | 521 | ckptdir = os.path.join(logdir, "checkpoints") 522 | cfgdir = os.path.join(logdir, "configs") 523 | seed_everything(opt.seed) 524 | 525 | try: 526 | # init and save configs 527 | configs = [OmegaConf.load(cfg) for cfg in opt.base] 528 | cli = OmegaConf.from_dotlist(unknown) 529 | config = OmegaConf.merge(*configs, cli) 530 | lightning_config = config.pop("lightning", OmegaConf.create()) 531 | # merge trainer cli with config 532 | trainer_config = lightning_config.get("trainer", OmegaConf.create()) 533 | # default to ddp 534 | trainer_config["accelerator"] = "ddp" 535 | for k in nondefault_trainer_args(opt): 536 | trainer_config[k] = getattr(opt, k) 537 | if not "gpus" in trainer_config: 538 | del trainer_config["accelerator"] 539 | cpu = True 540 | else: 541 | gpuinfo = trainer_config["gpus"] 542 | print(f"Running on GPUs {gpuinfo}") 543 | cpu = False 544 | trainer_opt = argparse.Namespace(**trainer_config) 545 | lightning_config.trainer 
= trainer_config 546 | 547 | # model 548 | model = instantiate_from_config(config.model) 549 | 550 | # trainer and callbacks 551 | trainer_kwargs = dict() 552 | 553 | # default logger configs 554 | # NOTE wandb < 0.10.0 interferes with shutdown 555 | # wandb >= 0.10.0 seems to fix it but still interferes with pudb 556 | # debugging (wrongly sized pudb ui) 557 | # thus prefer testtube for now 558 | default_logger_cfgs = { 559 | "wandb": { 560 | "target": "pytorch_lightning.loggers.WandbLogger", 561 | "params": { 562 | "name": nowname, 563 | "save_dir": logdir, 564 | "offline": opt.debug, 565 | "id": nowname, 566 | } 567 | }, 568 | "testtube": { 569 | "target": "pytorch_lightning.loggers.TestTubeLogger", 570 | "params": { 571 | "name": "testtube", 572 | "save_dir": logdir, 573 | } 574 | }, 575 | } 576 | default_logger_cfg = default_logger_cfgs["testtube"] 577 | if 'logger' in lightning_config: 578 | logger_cfg = lightning_config.logger 579 | else: 580 | logger_cfg = OmegaConf.create() 581 | logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) 582 | trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) 583 | 584 | # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to 585 | # specify which metric is used to determine best models 586 | default_modelckpt_cfg = { 587 | "target": "pytorch_lightning.callbacks.ModelCheckpoint", 588 | "params": { 589 | "dirpath": ckptdir, 590 | "filename": "{epoch:06}", 591 | "verbose": True, 592 | "save_last": True, 593 | } 594 | } 595 | if hasattr(model, "monitor"): 596 | print(f"Monitoring {model.monitor} as checkpoint metric.") 597 | default_modelckpt_cfg["params"]["monitor"] = model.monitor 598 | default_modelckpt_cfg["params"]["save_top_k"] = 3 599 | 600 | if 'modelcheckpoint' in lightning_config: 601 | modelckpt_cfg = lightning_config.modelcheckpoint 602 | else: 603 | modelckpt_cfg = OmegaConf.create() 604 | modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) 605 | print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") 606 | if version.parse(pl.__version__) < version.parse('1.4.0'): 607 | trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) 608 | 609 | # add callback which sets up log directory 610 | default_callbacks_cfg = { 611 | "setup_callback": { 612 | "target": "main.SetupCallback", 613 | "params": { 614 | "resume": opt.resume, 615 | "now": now, 616 | "logdir": logdir, 617 | "ckptdir": ckptdir, 618 | "cfgdir": cfgdir, 619 | "config": config, 620 | "lightning_config": lightning_config, 621 | } 622 | }, 623 | "image_logger": { 624 | "target": "main.ImageLogger", 625 | "params": { 626 | "batch_frequency": 750, 627 | "max_images": 4, 628 | "clamp": True 629 | } 630 | }, 631 | "learning_rate_logger": { 632 | "target": "main.LearningRateMonitor", 633 | "params": { 634 | "logging_interval": "step", 635 | # "log_momentum": True 636 | } 637 | }, 638 | "cuda_callback": { 639 | "target": "main.CUDACallback" 640 | }, 641 | } 642 | if version.parse(pl.__version__) >= version.parse('1.4.0'): 643 | default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg}) 644 | 645 | if 'callbacks' in lightning_config: 646 | callbacks_cfg = lightning_config.callbacks 647 | else: 648 | callbacks_cfg = OmegaConf.create() 649 | 650 | if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg: 651 | print( 652 | 'Caution: Saving checkpoints every n train steps without deleting. 
This might require some free space.') 653 | default_metrics_over_trainsteps_ckpt_dict = { 654 | 'metrics_over_trainsteps_checkpoint': 655 | {"target": 'pytorch_lightning.callbacks.ModelCheckpoint', 656 | 'params': { 657 | "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), 658 | "filename": "{epoch:06}-{step:09}", 659 | "verbose": True, 660 | 'save_top_k': -1, 661 | 'every_n_train_steps': 10000, 662 | 'save_weights_only': True 663 | } 664 | } 665 | } 666 | default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) 667 | 668 | callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) 669 | if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'): 670 | callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint 671 | elif 'ignore_keys_callback' in callbacks_cfg: 672 | del callbacks_cfg['ignore_keys_callback'] 673 | 674 | trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] 675 | 676 | trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) 677 | trainer.logdir = logdir ### 678 | 679 | # data 680 | data = instantiate_from_config(config.data) 681 | # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html 682 | # calling these ourselves should not be necessary but it is. 683 | # lightning still takes care of proper multiprocessing though 684 | data.prepare_data() 685 | data.setup() 686 | print("#### Data #####") 687 | for k in data.datasets: 688 | print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") 689 | 690 | # configure learning rate 691 | bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate 692 | if not cpu: 693 | ngpu = len(lightning_config.trainer.gpus.strip(",").split(',')) 694 | else: 695 | ngpu = 1 696 | accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1 697 | print(f"accumulate_grad_batches = {accumulate_grad_batches}") 698 | lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches 699 | if opt.scale_lr: 700 | model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr 701 | print( 702 | "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( 703 | model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) 704 | else: 705 | model.learning_rate = base_lr 706 | print("++++ NOT USING LR SCALING ++++") 707 | print(f"Setting learning rate to {model.learning_rate:.2e}") 708 | 709 | 710 | # allow checkpointing via USR1 711 | def melk(*args, **kwargs): 712 | # run all checkpoint hooks 713 | if trainer.global_rank == 0: 714 | print("Summoning checkpoint.") 715 | ckpt_path = os.path.join(ckptdir, "last.ckpt") 716 | trainer.save_checkpoint(ckpt_path) 717 | 718 | 719 | def divein(*args, **kwargs): 720 | if trainer.global_rank == 0: 721 | import pudb; 722 | pudb.set_trace() 723 | 724 | 725 | import signal 726 | 727 | signal.signal(signal.SIGUSR1, melk) 728 | signal.signal(signal.SIGUSR2, divein) 729 | 730 | # run 731 | if opt.train: 732 | try: 733 | trainer.fit(model, data) 734 | except Exception: 735 | melk() 736 | raise 737 | if not opt.no_test and not trainer.interrupted: 738 | trainer.test(model, data) 739 | except Exception: 740 | if opt.debug and trainer.global_rank == 0: 741 | try: 742 | import pudb as debugger 743 | except ImportError: 744 | import pdb as debugger 745 | debugger.post_mortem() 746 | raise 747 | finally: 748 | # 
move newly created debug project to debug_runs 749 | if opt.debug and not opt.resume and trainer.global_rank == 0: 750 | dst, name = os.path.split(logdir) 751 | dst = os.path.join(dst, "debug_runs", name) 752 | os.makedirs(os.path.split(dst)[0], exist_ok=True) 753 | os.rename(logdir, dst) 754 | if trainer.global_rank == 0: 755 | print(trainer.profiler.summary()) 756 | -------------------------------------------------------------------------------- /scripts/inpaint_imagebart.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import numpy as np 4 | from omegaconf import OmegaConf 5 | import streamlit as st 6 | from streamlit import caching 7 | from PIL import Image 8 | from main import instantiate_from_config 9 | from torch.utils.data.dataloader import default_collate 10 | from torchvision.utils import make_grid 11 | from tqdm import trange 12 | from einops import repeat 13 | from contextlib import contextmanager 14 | 15 | 16 | from scripts.sample_imagebart import get_top_k_schedule, get_temperature_schedule 17 | 18 | 19 | @torch.no_grad() 20 | def sample_unconditional(models, batch_size, chain_schedule, temperature_schedule=None, top_k_schedule=None, 21 | dim_z=256, h=16, w=16, 22 | example=None, mask=None): 23 | ########### opts 24 | use_ema = True 25 | device = torch.device("cuda") 26 | model_index = len(models) - 1 27 | steps = len(models) * [h * w] 28 | ################# 29 | 30 | # mask and input 31 | mask = mask.to(device=device) 32 | unmasked_input = torch.tensor(example["image"]).permute(0, 3, 1, 2).to(device=device) 33 | masked_inputs = mask * unmasked_input 34 | 35 | ##### clean for guidance 36 | with on_gpu(models[0]): 37 | pre_quant = models[0].get_h(unmasked_input) 38 | _, _, x_clean, _, _, _, _ = models[0].get_scales(pre_quant=pre_quant) 39 | 40 | model = models[model_index] 41 | c = torch.randint(0, 42 | model.first_stage_model.quantize.re_embed, 43 | x_clean.shape, 44 | device=x_clean.device) 45 | 46 | mask = torch.nn.functional.interpolate(mask, size=(16, 16), 47 | mode="nearest") 48 | mask = mask.reshape(c.shape).to(device=c.device) 49 | orig_dtype = c.dtype 50 | c = (1 - mask) * c + mask * x_clean 51 | c = c.to(dtype=orig_dtype) 52 | 53 | guide = torch.nn.functional.one_hot(c, num_classes=model.first_stage_model.quantize.re_embed).to( 54 | dtype=torch.float32) 55 | guide = torch.log(guide) 56 | guide[mask < 0.5] = 0 57 | 58 | # start sampling 59 | c_scale_indices = c 60 | scale = model_index 61 | current_scale = (scale * torch.ones(batch_size, 1)).to(device).long() 62 | steppys = st.empty() 63 | cb = lambda x: steppys.write(f"{x}/{h * w}") 64 | scaleinfo = st.empty() 65 | n_scales = len(models) 66 | for scale_n, model in enumerate(models[:model_index + 1][::-1]): 67 | temperature = temperature_schedule[scale] 68 | top_k = top_k_schedule[scale] 69 | n_chains = chain_schedule[scale] 70 | with on_gpu(model): 71 | with ema_scope(model, active=use_ema): 72 | for chain_idx in range(n_chains): 73 | scaleinfo.write( 74 | f"sampling chain {chain_idx + 1}/{n_chains} for scale {n_scales - scale_n}/{n_scales}, " 75 | f"temp = {temperature:.2f}, top-k = {top_k}") 76 | 77 | chain_weight = 1 - chain_idx / n_chains 78 | if chain_idx > 0: 79 | # already reversed, run forward again 80 | origdtype = c_scale_indices.dtype 81 | randindices = torch.randint( 82 | 0, 83 | model.first_stage_model.quantize.re_embed, 84 | c_scale_indices.shape, 85 | device=c_scale_indices.device) 86 | redraw_prob = chain_weight * 
model.temperature_range[model.single_scale] 87 | redraw = torch.bernoulli( 88 | redraw_prob * torch.ones(c_scale_indices.shape)).to( 89 | device=c_scale_indices.device) 90 | c_scale_indices = (1 - redraw) * c_scale_indices + redraw * randindices 91 | c_scale_indices = c_scale_indices.to(dtype=origdtype) 92 | c_scale_indices = model.sample_single_scale(c_scale_indices, 93 | current_scale + 1, 94 | temp_x=None, 95 | steps=steps[scale], 96 | temperature=temperature, 97 | top_k=top_k, 98 | guide=guide, 99 | callback=cb 100 | ) 101 | scale -= 1 102 | current_scale = (scale * torch.ones(batch_size, 1)).to(device).long() 103 | 104 | qzshape = [batch_size, dim_z, h, w] 105 | with on_gpu(model): 106 | sample = model.decode_to_img(c_scale_indices, qzshape) 107 | 108 | log = dict() 109 | log["samples"] = sample 110 | log["inputs"] = unmasked_input 111 | log["masked_inputs"] = masked_inputs 112 | return log 113 | 114 | 115 | def generate_mask(masking_option, shape): 116 | bs, h, w, c = shape 117 | mask = np.array(Image.new('L', (h, w))).astype(np.bool) 118 | 119 | if masking_option == 'upper-half completion': 120 | mask[h // 2:] = np.logical_not(mask[h // 2:]) 121 | 122 | elif masking_option == 'window-inpainting': 123 | window_size = st.number_input(f'Select size of quadratic window for {masking_option} ' 124 | f'(note: divided by 16 in latent space)', min_value=10, 125 | max_value=h // 2, value=h // 4) 126 | mask = np.logical_not(mask) 127 | mask[ 128 | (h - window_size) // 2:(h + window_size) // 2, 129 | (w - window_size) // 2:(w + window_size) // 2] = np.logical_not( 130 | mask[(h - window_size) // 2:(h + window_size) // 2, 131 | (w - window_size) // 2:(w + window_size) // 2]) 132 | 133 | else: 134 | window_size = st.number_input(f'Select size of quadratic window for {masking_option} ' 135 | f'(note: divided by 16 in latent space)', min_value=h // 2, 136 | max_value=h - 20, value=h // 2) 137 | mask[ 138 | (h - window_size) // 2:(h + window_size) // 2, 139 | (w - window_size) // 2:(w + window_size) // 2] = np.logical_not( 140 | mask[(h - window_size) // 2:(h + window_size) // 2, 141 | (w - window_size) // 2:(w + window_size) // 2]) 142 | st.warning('With outpainting enabled, you might have to increase the length of the chains') 143 | 144 | display_mask = mask 145 | # only for displaying reasons incase of inpainting and upper half completion 146 | for p in [0, h - 1]: 147 | display_mask[p] = False 148 | display_mask[:, p] = False 149 | st.image((255 * mask.astype(np.uint8)), f'Selected mask for {masking_option}') 150 | 151 | mask = torch.from_numpy(mask.astype(np.float32)).float() 152 | mask = repeat(mask, 'h w -> b 1 h w', b=batch_size) 153 | return mask 154 | 155 | 156 | @torch.no_grad() 157 | def run(models, dset, batch_size, temperature, top_k, chain_schedule, num_runs): 158 | img_spatial = models[0].first_stage_model.encoder.resolution 159 | img_shape = [batch_size, img_spatial, img_spatial, 3] 160 | 161 | masking_option = st.selectbox('Select masking option', 162 | ['upper-half completion', 'window-inpainting', 'window-outpainting'], 163 | index=0) 164 | 165 | mask = generate_mask(masking_option, img_shape) 166 | 167 | if st.button('Sample with chain'): 168 | 169 | for n in trange(num_runs, desc="Data"): 170 | indices = np.random.choice(len(dset), batch_size, replace=False) 171 | example = default_collate([dset[i] for i in indices]) 172 | logs = sample_unconditional(models, batch_size=batch_size, 173 | temperature_schedule=temperature, top_k_schedule=top_k, 174 | example=example, mask=mask, 
chain_schedule=chain_schedule) 175 | 176 | log_to_st(logs, n) 177 | 178 | 179 | def log_to_st(log, n): 180 | keys = ['inputs', 'masked_inputs', 'samples'] 181 | bs = log[keys[0]].shape[0] 182 | 183 | flatgrid = torch.cat([torch.clamp(log[k].detach().cpu(), -1., 1.) for k in keys], dim=0) 184 | grid = make_grid(flatgrid, nrow=bs, normalize=True).permute(1, 2, 0).numpy() 185 | 186 | st.image(grid, f'Masked samples #{n + 1} (top: original, mid: masked input, bottom: sample)') 187 | 188 | 189 | @contextmanager 190 | def ema_scope(model, active=False, context=None): 191 | if active: 192 | model.transformer_ema.store(model.transformer.parameters()) 193 | model.transformer_ema.copy_to(model.transformer) 194 | if context is not None: 195 | print(f"{context}: Switched to EMA weights") 196 | try: 197 | yield None 198 | finally: 199 | if active: 200 | model.transformer_ema.restore(model.transformer.parameters()) 201 | if context is not None: 202 | print(f"{context}: Restored training weights") 203 | 204 | 205 | @contextmanager 206 | def on_gpu(model, context=None): 207 | model = model.cuda() 208 | if context is not None: 209 | print(f"{context}: Moved model to GPU") 210 | try: 211 | yield None 212 | finally: 213 | model = model.cpu() 214 | torch.cuda.empty_cache() 215 | if context is not None: 216 | print(f"{context}: Moved model to CPU") 217 | 218 | 219 | def load_model_from_config(config, sd, gpu=True, eval_mode=True): 220 | print("config:") 221 | print(OmegaConf.to_yaml(config)) 222 | model = instantiate_from_config(config["model"]) 223 | if sd is not None: 224 | m, u = model.load_state_dict(sd, strict=False) 225 | if gpu: 226 | model.cuda() 227 | if eval_mode: 228 | model.eval() 229 | return {"model": model} 230 | 231 | 232 | @st.cache(allow_output_mutation=True) 233 | def get_data(config): 234 | # get data 235 | try: 236 | if config.data.params.train.target == "braket.data.faceshq.FFHQTrain": 237 | config.data.params.train.params.random_flip = False 238 | print("Disabled random flip for FFHQ train") 239 | except Exception: 240 | pass 241 | data = instantiate_from_config(config.data) 242 | data.prepare_data() 243 | data.setup() 244 | return data 245 | 246 | 247 | def get_config(path): 248 | config = OmegaConf.load(path) 249 | return config 250 | 251 | 252 | @st.cache(allow_output_mutation=True) 253 | def load_models(paths, gpu=False, eval_mode=True): 254 | assert not gpu, 'moving them later' 255 | models = list() 256 | configs = list() 257 | global_steps = list() 258 | for ckpt_path, config_path in zip(paths["checkpoints"], paths["configs"]): 259 | print(f"loading config from {config_path} and model from {ckpt_path}") 260 | config = get_config(config_path) 261 | pl_sd = torch.load(ckpt_path, map_location="cpu") 262 | global_step = pl_sd["global_step"] 263 | model = load_model_from_config(config, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"] 264 | models.append(model) 265 | configs.append(config) 266 | global_steps.append(global_step) 267 | print(f"loaded model from global step {global_step}") 268 | if "repeat" in paths: 269 | n_models = list() 270 | n_configs = list() 271 | n_global_steps = list() 272 | for i, n in enumerate(paths["repeat"]): 273 | print(f"Repeating model {i} {n}x.") 274 | n_models += n * [models[i]] 275 | n_configs += n * [configs[i]] 276 | n_global_steps += n * [global_steps[i]] 277 | models = n_models 278 | configs = n_configs 279 | global_steps = n_global_steps 280 | return models, configs, global_steps 281 | 282 | 283 | if __name__ == "__main__": 284 | 
sys.path.append(os.getcwd()) 285 | if not st._is_running_with_streamlit: 286 | print("Not running with streamlit. Redefining st functions...") 287 | st.info = print 288 | st.write = print 289 | 290 | yaml_path = sys.argv[1] 291 | paths = OmegaConf.load(yaml_path) 292 | paths = OmegaConf.to_container(paths) 293 | 294 | gpu = True 295 | eval_mode = True 296 | 297 | models, configs, global_steps = load_models(paths, gpu=False, eval_mode=eval_mode) 298 | if models[0].conditioner is not None: 299 | raise NotImplementedError('Currently only available for unconditional models.') 300 | device = torch.device("cuda") if gpu else torch.device("cpu") 301 | dsets = get_data(configs[0]) 302 | 303 | split = "validation" 304 | dset = dsets.datasets[split] 305 | print(f"Dataset size: {len(dset)}") 306 | 307 | codebook_size = models[0].first_stage_model.quantize.re_embed 308 | 309 | st.sidebar.write('Sampling options') 310 | n_runs = st.sidebar.number_input('Number of runs', 1, 100, 1) 311 | batch_size = st.sidebar.number_input('Batch size', 1, 20, 4) 312 | 313 | top_k = get_top_k_schedule(len(models), codebook_size=codebook_size) 314 | temperature = get_temperature_schedule(len(models)) 315 | 316 | chain_schedule = [] 317 | st.write('Define chain schedule') 318 | st.info( 319 | f'The n-th entry in the chain schedule defines the number of sucessive runs, ' 320 | f'the n-th AR submodel should perfom before passing the output to the next submodel.') 321 | for n in range(len(models)): 322 | if models[n].redraw_prob != 'geometric': 323 | if n == len(models) - 1: 324 | def_chain_len = 1 325 | else: 326 | def_chain_len = 5 327 | else: 328 | def_chain_len = 3 329 | chain_n = st.number_input(f"Chain length for scale #{n + 1}", min_value=1, max_value=100, value=def_chain_len) 330 | chain_schedule.append(chain_n) 331 | 332 | chain_schedule = chain_schedule 333 | 334 | run(models, dset, batch_size, temperature=temperature, top_k=top_k, chain_schedule=chain_schedule, num_runs=n_runs) 335 | -------------------------------------------------------------------------------- /scripts/sample_imagebart.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import hashlib 3 | import torch 4 | import numpy as np 5 | from omegaconf import OmegaConf 6 | import streamlit as st 7 | from PIL import Image 8 | from main import instantiate_from_config 9 | from torchvision.utils import make_grid 10 | from contextlib import contextmanager 11 | 12 | rescale = lambda x: (x + 1.) / 2. 
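# Display helpers: decoded samples live in [-1, 1], and `rescale` maps them to [0, 1].
# bchw_to_st / chw_to_st prepare (batched) tensors for st.image by detaching, moving to
# CPU and reordering channels to HWC; chw_to_np / chw_to_pillow additionally convert to
# uint8 for use with PIL.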
13 | 14 | 15 | def bchw_to_st(x): 16 | return rescale(x.detach().cpu().numpy().transpose(0,2,3,1)) 17 | 18 | def chw_to_st(x): 19 | return rescale(x.detach().cpu().numpy().transpose(1,2,0)) 20 | 21 | def chw_to_pillow(x): 22 | return Image.fromarray(chw_to_np(x)) 23 | 24 | def chw_to_np(x): 25 | return (255 * rescale(x.detach().cpu().numpy().transpose(1, 2, 0))).clip(0, 255).astype(np.uint8) 26 | 27 | 28 | def computeMD5hash(string): 29 | m = hashlib.md5() 30 | m.update(string.encode('utf-8')) 31 | return m.hexdigest() 32 | 33 | 34 | class L1(torch.nn.Module): 35 | def __init__(self): 36 | super().__init__() 37 | 38 | def forward(self, x, y): 39 | return torch.abs(x-y).sum(dim=[1,2,3]).mean() 40 | 41 | 42 | 43 | @torch.no_grad() 44 | def custom_log_images(model, batch, temperature): 45 | log = dict() 46 | x = model.get_input(batch, model.image_key) 47 | x = x.to(model.device) 48 | # encode 49 | h = model.encoder(x) 50 | h = model.quant_conv(h) 51 | quant, _, _ = model.quantize(h, temp=temperature, rescale_logits=True) 52 | # decode 53 | x_rec = model.decode(quant) 54 | log["inputs"] = x 55 | log["reconstructions"] = x_rec 56 | return log 57 | 58 | 59 | def grid2img(grid): 60 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 61 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 62 | grid = grid.numpy() 63 | grid = (grid * 255).astype(np.uint8) 64 | print(f"grid.max/grid.min {grid.max()}/{grid.min()}") 65 | grid = Image.fromarray(grid) 66 | return grid 67 | 68 | 69 | def get_top_k_schedule(n_steps, codebook_size): 70 | tk_schedule = st.radio("top-k scheduling", ["constant", "linear", "user"], 2) 71 | if tk_schedule == "constant": 72 | tk = st.number_input("Constant Top-K Value", value=codebook_size) 73 | top_k_schedule = (np.ones(n_steps) * tk).astype(int) 74 | elif tk_schedule == "linear": 75 | tk_start = st.number_input("Start Top-K Value", value=codebook_size) 76 | tk_end = st.number_input("End Top-K Value", value=codebook_size//4) 77 | top_k_schedule = np.linspace(tk_start, tk_end, n_steps).astype(int) 78 | elif tk_schedule == "user": 79 | default = f"{codebook_size}," * n_steps 80 | tk_list = st.text_input(f"Top-K Values (comma separated, must be {n_steps}, counted in sampling order, i.e. from last to first scale)", f"{default[:-1]}") 81 | tks = tk_list.split(",") 82 | top_k_schedule = list() 83 | for tk in tks: 84 | top_k_schedule.append(int(tk)) 85 | assert len(top_k_schedule) == n_steps 86 | else: 87 | return None 88 | return top_k_schedule 89 | 90 | def get_temperature_schedule(n_steps): 91 | type = st.radio("temperature scheduling", ["constant", "linear", "user"], 2) 92 | if type == "constant": 93 | tk = st.number_input("Constant Temperature Value", value=1.0) 94 | schedule = (np.ones(n_steps) * tk) 95 | elif type == "linear": 96 | tk_start = st.number_input("Start Temperature Value", value=1.) 97 | tk_end = st.number_input("End Top-Temperature Value", value=0.1) 98 | schedule = np.linspace(tk_start, tk_end, n_steps) 99 | elif type == "user": 100 | default = f"{1.0}," * n_steps 101 | tk_list = st.text_input(f"Temperature Values (comma separated, must be {n_steps}, counted in sampling order, i.e. 
from last to first scale)", f"{default[:-1]}") 102 | tks = tk_list.split(",") 103 | schedule = list() 104 | for tk in tks: 105 | schedule.append(float(tk)) 106 | assert len(schedule) == n_steps 107 | else: 108 | return None 109 | return schedule 110 | 111 | 112 | @contextmanager 113 | def ema_scope(model, active=False, context=None): 114 | if active: 115 | model.transformer_ema.store(model.transformer.parameters()) 116 | model.transformer_ema.copy_to(model.transformer) 117 | if context is not None: 118 | print(f"{context}: Switched to EMA weights") 119 | try: 120 | yield None 121 | finally: 122 | if active: 123 | model.transformer_ema.restore(model.transformer.parameters()) 124 | if context is not None: 125 | print(f"{context}: Restored training weights") 126 | 127 | 128 | @contextmanager 129 | def on_gpu(model, context=None): 130 | model = model.cuda() 131 | if context is not None: 132 | print(f"{context}: Moved model to GPU") 133 | try: 134 | yield None 135 | finally: 136 | model = model.cpu() 137 | torch.cuda.empty_cache() 138 | if context is not None: 139 | print(f"{context}: Moved model to CPU") 140 | 141 | 142 | @torch.no_grad() 143 | def run(models, user_conditioning, batch_size, device=torch.device("cuda"),conditional=False): 144 | assert type(models) == list 145 | 146 | n_scales = len(models) 147 | codebook_size = models[0].first_stage_model.quantize.re_embed 148 | 149 | cond = None 150 | start_index = len(models) - 1 151 | model = models[start_index] 152 | 153 | 154 | n_downs= model.first_stage_model.encoder.num_resolutions - 1 155 | h = model.first_stage_model.encoder.resolution // (2**n_downs) 156 | w = model.first_stage_model.encoder.resolution // (2**n_downs) 157 | dim_z = model.first_stage_model.embed_dim 158 | 159 | 160 | index_shape = [batch_size, h * w] 161 | qzshape = [batch_size, dim_z, h, w] 162 | 163 | st.info(f'Latent shape is {qzshape}') 164 | if user_conditioning is not None: 165 | exmpl = {model.conditioner.key: user_conditioning} 166 | cond = model.get_conditioning(exmpl).to(device) 167 | st.sidebar.write(f"cond.shape: {cond.shape}") 168 | 169 | 170 | top_k_schedule = get_top_k_schedule(n_scales, codebook_size=codebook_size)[::-1] 171 | temperature_schedule = get_temperature_schedule(n_scales)[::-1] 172 | st.text(f"top-k schedule: {top_k_schedule}") 173 | st.text(f"temperature schedule: {temperature_schedule}") 174 | 175 | n_batches = st.number_input("number runs", value=1, min_value=1, max_value=1000) 176 | steps = n_scales * [h*w] 177 | 178 | 179 | if st.button("Sample", False): 180 | grid_ph = st.empty() 181 | final_samples = list() 182 | for n in range(n_batches): 183 | 184 | scaleinfo = st.empty() 185 | scale = start_index 186 | c_scale_indices = torch.randint(0, 187 | model.first_stage_model.quantize.re_embed, 188 | index_shape, 189 | device=model.device) 190 | current_scale = (scale * torch.ones(batch_size, 1)).to(c_scale_indices).to(device) 191 | 192 | steppys = st.empty() 193 | cb = lambda x: steppys.write(f"{x}/{h*w}") 194 | for model_count, model in enumerate(models[::-1]): 195 | with on_gpu(model, "Sampling"): 196 | temperature = temperature_schedule[scale] 197 | top_k = top_k_schedule[scale] 198 | scaleinfo.write(f"sampling scale {scale+1}/{n_scales}, temp = {temperature:.2f}, top-k = {top_k}") 199 | 200 | with ema_scope(model, active=True, context="Sampling"): 201 | 202 | assert (current_scale + 1)[0].item() == model.single_scale, \ 203 | f"{(current_scale + 1)[0].item()} =/= {model.single_scale} :(" 204 | c_scale_indices = 
model.sample_single_scale(c_scale_indices.to(device), 205 | (current_scale+1).to(device), 206 | temp_x=None, 207 | steps=steps[scale], 208 | temperature=temperature, 209 | top_k=top_k, 210 | cond=cond.to(model.device) if cond is not None else None, 211 | callback=cb 212 | ) 213 | 214 | if model_count == len(models) -1: 215 | final_samples.append(model.decode_to_img(c_scale_indices, 216 | [batch_size, qzshape[1], qzshape[2], qzshape[3]])) 217 | scale -= 1 218 | current_scale = (scale * torch.ones(batch_size, 1)).to(device).long() 219 | 220 | intermediate_grid = make_grid(final_samples[-1],nrow=batch_size,padding=0) 221 | st.image(chw_to_st(intermediate_grid),clamp=True,output_format='PNG') 222 | 223 | final_samples = torch.cat(final_samples, 0) 224 | grid = make_grid(final_samples, nrow=batch_size, padding=0) 225 | grid_ph.image(chw_to_st(grid), clamp=True, output_format="PNG") 226 | 227 | @torch.no_grad() 228 | def render_as_grid(scale_samples, batch_size, stack=True): 229 | # make a grid 230 | if stack: 231 | scale_samples = torch.stack(scale_samples, dim=0) 232 | assert batch_size == scale_samples.shape[1] 233 | grids = [] 234 | for i in range(scale_samples.shape[1]): 235 | grid = make_grid(scale_samples[:, i, ...], nrow=scale_samples.shape[0]) 236 | grids.append(grid) 237 | 238 | for k in range(len(grids)): 239 | st.image(chw_to_st(grids[k]), clamp=True, output_format="PNG", use_column_width=grids[k].shape[2] < 500) 240 | 241 | 242 | def load_model_from_config(config, sd, gpu=True, eval_mode=True): 243 | print("config:") 244 | print(OmegaConf.to_yaml(config)) 245 | model = instantiate_from_config(config["model"]) 246 | if sd is not None: 247 | m, u = model.load_state_dict(sd, strict=False) 248 | if gpu: 249 | model.cuda() 250 | if eval_mode: 251 | model.eval() 252 | return {"model": model} 253 | 254 | @st.cache(allow_output_mutation=True) 255 | def get_data(config): 256 | # get data 257 | data = instantiate_from_config(config.data) 258 | data.prepare_data() 259 | data.setup() 260 | return data 261 | 262 | 263 | def get_config(path): 264 | config = OmegaConf.load(path) 265 | return config 266 | 267 | 268 | @st.cache(allow_output_mutation=True) 269 | def load_models(paths, gpu=False, eval_mode=True): 270 | assert not gpu, 'moving them later' 271 | models = list() 272 | configs = list() 273 | global_steps = list() 274 | 275 | for ckpt_path, config_path in zip(paths["checkpoints"], paths["configs"]): 276 | print(f"loading config from {config_path} and model from {ckpt_path}") 277 | config = get_config(config_path) 278 | pl_sd = torch.load(ckpt_path, map_location="cpu") 279 | global_step = pl_sd["global_step"] 280 | model = load_model_from_config(config, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"] 281 | models.append(model) 282 | configs.append(config) 283 | global_steps.append(global_step) 284 | 285 | print(f"loaded model from global step {global_step}") 286 | 287 | return models, configs, global_steps 288 | 289 | 290 | if __name__ == "__main__": 291 | 292 | yaml_path = sys.argv[1] 293 | log_path = yaml_path.split(os.sep)[-1][:-5] 294 | paths = OmegaConf.load(yaml_path) 295 | print(OmegaConf.to_yaml(paths)) 296 | paths = OmegaConf.to_container(paths) 297 | 298 | log_path =paths["metrics"]["savepath"] if 'metrics' in paths and 'savepath' in paths['metrics'] else os.path.join("logs", log_path) 299 | print(f"loading from .yaml at {yaml_path}") 300 | 301 | 302 | print(f'logging to {log_path}') 303 | 304 | st.sidebar.text("Options") 305 | gpu = st.sidebar.checkbox("GPU", 
value=True) 306 | 307 | models, configs, global_steps = load_models(paths, gpu=False) 308 | device = torch.device("cuda") if gpu else torch.device("cpu") 309 | 310 | # dsets = get_data(configs[0]) 311 | step_info = "" 312 | st.write(step_info[:-2]) 313 | 314 | batch_size = st.number_input("Batch size", min_value=1, value=4) 315 | conditional = models[0].conditioner is not None 316 | 317 | 318 | user_conditioning = None 319 | if conditional: 320 | st.info("Detected a conditional model.") 321 | user_inputs = [] 322 | conditioner_key = models[0].conditioner.key 323 | 324 | if conditioner_key == "caption": 325 | for n in range(batch_size): 326 | user_input = st.text_input(f"user caption {n}", value=f"Example caption {n}") 327 | user_inputs.append(user_input) 328 | 329 | #example["caption"] = [user_inputs] 330 | user_conditioning = user_inputs 331 | 332 | st.write(f"Selected text-prompts are {user_conditioning}") 333 | elif conditioner_key == "class_label": 334 | 335 | cfd = os.path.dirname(os.path.abspath(__file__)) 336 | integer2human = OmegaConf.load(os.path.join(cfd,'../data/imagenet_ids2labels.yaml')) 337 | 338 | format_fn = lambda x: integer2human[x] 339 | for n in range(batch_size): 340 | user_input = st.selectbox(f"user class label {n}", index=144, 341 | options=list(integer2human.keys()), 342 | format_func=format_fn) 343 | user_inputs.append(int(user_input)) 344 | 345 | user_conditioning = torch.tensor(user_inputs) 346 | 347 | human_labels = [integer2human[str(l)] for l in user_inputs] 348 | st.write(f"Selected class labels are {human_labels}") 349 | else: 350 | raise NotImplementedError(f"Model with conditoner key {conditioner_key} not yet implemented.") 351 | 352 | run(models, user_conditioning, batch_size, device=device,conditional=conditional) 353 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='imagebart', 5 | version='0.0.1', 6 | description='autoregressive image modification via multinomial diffusion', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) 14 | --------------------------------------------------------------------------------
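Both streamlit entry points above take a single yaml file as their only argument (for example one of the files under `configs/sampling/`). Its top-level `checkpoints` and `configs` keys are zipped in `load_models`, so they must be parallel lists with one entry per scale; `scripts/inpaint_imagebart.py` additionally honours an optional `repeat` list for reusing a model across several scales, and `scripts/sample_imagebart.py` reads an optional `metrics.savepath` as its log path. The sketch below only restates that expectation; the checkpoint locations are placeholders, not values shipped with the repository.

```
from omegaconf import OmegaConf

# Hypothetical stand-in for a configs/sampling/*.yaml file: parallel per-scale
# lists, as load_models() in both scripts consumes them.
paths = OmegaConf.create({
    "checkpoints": [  # placeholder checkpoint locations
        "models/ffhq_2_scales_custom/scale_1.ckpt",
        "models/ffhq_2_scales_custom/scale_2.ckpt",
    ],
    "configs": [  # per-scale training configs from this repository
        "configs/ffhq/2_scales/ffhq-custom-scale1.yaml",
        "configs/ffhq/2_scales/ffhq-custom-scale2.yaml",
    ],
})
paths = OmegaConf.to_container(paths)
assert len(paths["checkpoints"]) == len(paths["configs"])  # zipped pairwise in load_models
```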