├── README.md
├── assets
│   ├── foxchain.png
│   ├── image_editing.png
│   ├── imagebart_poster.pdf
│   ├── modelfigure.png
│   └── sample-001.jpg
├── configs
│   ├── ffhq
│   │   ├── 2_scales
│   │   │   ├── ffhq-custom-scale1.yaml
│   │   │   └── ffhq-custom-scale2.yaml
│   │   ├── 4_scales
│   │   │   ├── ffhq-geometric-scale1.yaml
│   │   │   ├── ffhq-geometric-scale2.yaml
│   │   │   ├── ffhq-geometric-scale3.yaml
│   │   │   └── ffhq-geometric-scale4.yaml
│   │   └── ffhq_4-scales_joint-training.yaml
│   ├── imagenet
│   │   ├── 4_scales
│   │   │   ├── imagenet_geometric_scale1.yaml
│   │   │   ├── imagenet_geometric_scale2.yaml
│   │   │   ├── imagenet_geometric_scale3.yaml
│   │   │   └── imagenet_geometric_scale4.yaml
│   │   └── 5_scales
│   │       ├── imagenet-custom-scale1.yaml
│   │       ├── imagenet-custom-scale2.yaml
│   │       ├── imagenet-custom-scale3.yaml
│   │       ├── imagenet-custom-scale4.yaml
│   │       └── imagenet-custom-scale5.yaml
│   ├── lsun-beds
│   │   ├── lsun-beds-scale1.yaml
│   │   ├── lsun-beds-scale2.yaml
│   │   └── lsun-beds-scale3.yaml
│   ├── lsun-cats
│   │   ├── lsun-cats-scale1.yaml
│   │   ├── lsun-cats-scale2.yaml
│   │   └── lsun-cats-scale3.yaml
│   ├── lsun-churches
│   │   ├── lsun-churches-scale1.yaml
│   │   ├── lsun-churches-scale2.yaml
│   │   └── lsun-churches-scale3.yaml
│   └── sampling
│       ├── ffhq
│       │   ├── ffhq_2_scales_custom.yaml
│       │   └── ffhq_4_scales_geometric.yaml
│       ├── imagenet
│       │   ├── imagenet_4_scales_geometric.yaml
│       │   └── imagenet_5_scales_custom.yaml
│       └── lsun
│           ├── beds_3_scales.yaml
│           ├── cats_3_scales.yaml
│           └── churches_3_scales.yaml
├── data
│   ├── DejaVuSans.ttf
│   ├── ffhq_schedule_vs_metric.p
│   ├── ffhqtrain.txt
│   ├── ffhqvalidation.txt
│   ├── imagenet_ids2labels.yaml
│   ├── in_schedule_vs_metric.p
│   └── vqgan_indices
│       ├── bedroom_indices.npy
│       ├── cat_indices.npy
│       ├── church_indices.npy
│       ├── ffhq_indices.npy
│       └── imagenet_indices.npy
├── environment.yaml
├── imagebart
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── lsun.py
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── diffusion.py
│   │   └── vqgan.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── betas.py
│   │   ├── ema.py
│   │   ├── transformer
│   │   │   ├── __init__.py
│   │   │   ├── mingpt.py
│   │   │   ├── vit.py
│   │   │   └── warper.py
│   │   └── xtransformers
│   │       ├── __init__.py
│   │       ├── autoregressive_wrapper.py
│   │       ├── positional_embeddings.py
│   │       └── x_transformer.py
│   └── util.py
├── main.py
├── scripts
│   ├── inpaint_imagebart.py
│   └── sample_imagebart.py
└── setup.py
/README.md:
--------------------------------------------------------------------------------
1 | # ImageBART
2 | #### [NeurIPS 2021](https://nips.cc/)
3 |
4 | 
5 |
6 | [Patrick Esser](https://github.com/pesser)\*,
7 | [Robin Rombach](https://github.com/rromb)\*,
8 | [Andreas Blattmann](https://github.com/ablattmann)\*,
9 | [Björn Ommer](https://ommer-lab.com/)
10 | \* equal contribution
11 |
12 | [arXiv](https://arxiv.org/abs/2108.08827) | [BibTeX](#bibtex) | [Poster](assets/imagebart_poster.pdf)
13 |
14 | ## Requirements
15 | A suitable [conda](https://conda.io/) environment named `imagebart` can be created
16 | and activated with:
17 |
18 | ```
19 | conda env create -f environment.yaml
20 | conda activate imagebart
21 | ```
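
The repository also ships a `setup.py`; in case the conda environment does not already install the package itself, it can typically be added in editable mode (an optional convenience step, not part of the original instructions):

```
pip install -e .
```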
22 |
23 | ## Get the Models
24 |
25 | We provide pretrained weights and hyperparameters for models trained on the following datasets:
26 |
27 | * FFHQ:
28 |   * [4 scales, geometric noise schedule](https://ommer-lab.com/files/ffhq_4_scales_geometric.zip): `wget -c https://ommer-lab.com/files/ffhq_4_scales_geometric.zip`
29 |   * [2 scales, custom noise schedule](https://ommer-lab.com/files/ffhq_2_scales_custom.zip): `wget -c https://ommer-lab.com/files/ffhq_2_scales_custom.zip`
30 | * LSUN, 3 scales, custom noise schedules:
31 |   * [Churches](https://ommer-lab.com/files/churches_3_scales.zip): `wget -c https://ommer-lab.com/files/churches_3_scales.zip`
32 |   * [Bedrooms](https://ommer-lab.com/files/bedrooms_3_scales.zip): `wget -c https://ommer-lab.com/files/bedrooms_3_scales.zip`
33 |   * [Cats](https://ommer-lab.com/files/cats_3_scales.zip): `wget -c https://ommer-lab.com/files/cats_3_scales.zip`
34 | * Class-conditional ImageNet:
35 |   * [5 scales, custom noise schedule](https://ommer-lab.com/files/cin_5_scales_custom.zip): `wget -c https://ommer-lab.com/files/cin_5_scales_custom.zip`
36 |   * [4 scales, geometric noise schedule](https://ommer-lab.com/files/cin_4_scales_geometric.zip): `wget -c https://ommer-lab.com/files/cin_4_scales_geometric.zip`
37 |
38 | Download the respective files and extract their contents to a directory `./models/`.
39 |
40 | Moreover, we provide all the required VQGANs as a .zip at [https://ommer-lab.com/files/vqgan.zip](https://ommer-lab.com/files/vqgan.zip),
41 | whose contents have to be extracted to `./vqgan/`.
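
For example, fetching one of the pretrained models together with the VQGAN archive could look like this (a sketch; pick the zip matching the model you want, and move files if the archive layout does not already place them directly under `./models/` and `./vqgan/`):

```
mkdir -p models vqgan
wget -c https://ommer-lab.com/files/ffhq_4_scales_geometric.zip
unzip ffhq_4_scales_geometric.zip -d models/
wget -c https://ommer-lab.com/files/vqgan.zip
unzip vqgan.zip -d vqgan/
```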
42 |
43 | ## Get the Data
44 | Running the training configs or the [inpainting script](scripts/inpaint_imagebart.py) requires
45 | a dataset available locally. For ImageNet and FFHQ, see the data preparation instructions in the parent repository, [taming-transformers](https://github.com/CompVis/taming-transformers).
46 | The LSUN datasets can be conveniently downloaded via the script available [here](https://github.com/fyu/lsun).
47 | We performed a custom split into training and validation images, and provide the corresponding filenames
48 | at [https://ommer-lab.com/files/lsun.zip](https://ommer-lab.com/files/lsun.zip).
49 | After downloading, extract them to `./data/lsun`. The beds/cats/churches subsets should
50 | also be placed/symlinked at `./data/lsun/bedrooms`/`./data/lsun/cats`/`./data/lsun/churches`, respectively.
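
A possible layout, assuming the LSUN images were downloaded with the script linked above (the `/path/to/...` locations are illustrative):

```
mkdir -p data/lsun
unzip lsun.zip -d data/lsun/                      # train/validation filename lists
ln -s /path/to/lsun/bedrooms  data/lsun/bedrooms
ln -s /path/to/lsun/cats      data/lsun/cats
ln -s /path/to/lsun/churches  data/lsun/churches
```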
51 |
52 | ## Inference
53 |
54 | ### Unconditional Sampling
55 | We provide a script for sampling from unconditional models trained on the LSUN-{churches,bedrooms,cats} and FFHQ datasets.
56 |
57 | #### FFHQ
58 |
59 | On the FFHQ dataset, we provide two distinct pretrained models: one with a chain of length 4 and a geometric noise schedule as proposed by Sohl-Dickstein et al. [[1]](#references), and another with a chain of length 2 and a custom schedule.
60 | Sampling from these models can be started with
61 | ```shell script
62 | CUDA_VISIBLE_DEVICES=<gpu_id> streamlit run scripts/sample_imagebart.py configs/sampling/ffhq/
63 | ```
64 |
65 | #### LSUN
66 | For the models trained on the LSUN datasets, use
67 | ```shell script
68 | CUDA_VISIBLE_DEVICES=<gpu_id> streamlit run scripts/sample_imagebart.py configs/sampling/lsun/
69 | ```
70 |
71 | ### Class Conditional Sampling on ImageNet
72 |
73 |
74 | To sample from class-conditional ImageNet models, use
75 | ```shell script
76 | CUDA_VISIBLE_DEVICES=<gpu_id> streamlit run scripts/sample_imagebart.py configs/sampling/imagenet/
77 | ```
78 |
79 | ### Image Editing with Unconditional Models
80 |
81 | We also provide a script for image editing with our unconditional models. For our FFHQ model with the geometric schedule, this can be started with
82 | ```shell script
83 | CUDA_VISIBLE_DEVICES=<gpu_id> streamlit run scripts/inpaint_imagebart.py configs/sampling/ffhq/ffhq_4_scales_geometric.yaml
84 | ```
85 | resulting in samples similar to the following.
86 | 
87 |
88 |
89 | ## Training
90 | In general, there are two options for training the autoregressive transition probabilities of the
91 | reverse Markov chain: (i) train them jointly, taking into account a weighting of the
92 | individual scale contributions, or (ii) train them independently, which means that each
93 | training process optimizes a single transition and the scales must be stacked after training.
94 | We conduct most of our experiments using the latter option, but provide configurations for both cases.
95 |
96 | ### Training Scales Independently
97 | For training scales independently, each transition requires a separate optimization process, which can be
98 | started via
99 |
100 | ```
101 | CUDA_VISIBLE_DEVICES=<gpu_id> python main.py --base configs/<dataset>/<config>.yaml -t --gpus 0,
102 | ```
103 |
104 | We provide training configs for a four-scale FFHQ training with a geometric schedule, a two-scale FFHQ training with a custom schedule,
105 | four-scale geometric and five-scale custom trainings on ImageNet, and three-scale trainings on the LSUN datasets.
106 | See also the overview of our [pretrained models](#get-the-models).
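
For example, training the first transition of the four-scale geometric FFHQ chain on a single GPU would use the corresponding config from the tree above:

```
CUDA_VISIBLE_DEVICES=<gpu_id> python main.py --base configs/ffhq/4_scales/ffhq-geometric-scale1.yaml -t --gpus 0,
```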
107 |
108 |
109 | ### Training Scales Jointly
110 |
111 | For completeness, we also provide a config to run a joint training with 4 scales on FFHQ.
112 | Training can be started by running
113 |
114 | ```
115 | CUDA_VISIBLE_DEVICES=<gpu_id> python main.py --base configs/ffhq/ffhq_4-scales_joint-training.yaml -t --gpus 0,
116 | ```
117 |
118 |
119 | ## Shout-Outs
120 | Many thanks to all who make their work and implementations publicly available.
121 | For this work, these were in particular:
122 |
123 | - The extremely clear and extensible encoder-decoder transformer implementations by [lucidrains](https://github.com/lucidrains):
124 | https://github.com/lucidrains/x-transformers
125 | - Emiel Hoogeboom et al.'s paper on multinomial diffusion and argmax flows: https://arxiv.org/abs/2102.05379
126 |
127 |
128 | 
129 |
130 | ## References
131 |
132 | [1] Sohl-Dickstein, J., Weiss, E., Maheswaranathan, N., & Ganguli, S. (2015). Deep Unsupervised Learning using Nonequilibrium Thermodynamics. Proceedings of the 32nd International Conference on Machine Learning.
133 |
134 | ## Bibtex
135 |
136 | ```
137 | @article{DBLP:journals/corr/abs-2108-08827,
138 |   author    = {Patrick Esser and
139 |                Robin Rombach and
140 |                Andreas Blattmann and
141 |                Bj{\"{o}}rn Ommer},
142 |   title     = {ImageBART: Bidirectional Context with Multinomial Diffusion for Autoregressive
143 |                Image Synthesis},
144 |   journal   = {CoRR},
145 |   volume    = {abs/2108.08827},
146 |   year      = {2021}
147 | }
148 | ```
--------------------------------------------------------------------------------
/assets/foxchain.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/assets/foxchain.png
--------------------------------------------------------------------------------
/assets/image_editing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/assets/image_editing.png
--------------------------------------------------------------------------------
/assets/imagebart_poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/assets/imagebart_poster.pdf
--------------------------------------------------------------------------------
/assets/modelfigure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/assets/modelfigure.png
--------------------------------------------------------------------------------
/assets/sample-001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/assets/sample-001.jpg
--------------------------------------------------------------------------------
/configs/ffhq/2_scales/ffhq-custom-scale1.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 2
8 | single_scale: 1
9 | top_k: 548
10 | alpha: 0.0
11 | redraw_prob: ffhq_bernoulli_PSIM
12 | use_ema: true
13 |
14 | scheduler_config:
15 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
16 | params:
17 | verbosity_interval: 0
18 | warm_up_steps: 10000
19 | max_decay_steps: 1500001
20 | lr_start: 2.5e-06
21 | lr_max: 0.0001
22 | lr_min: 1.0e-08
23 | transformer_config:
24 | target: imagebart.modules.xtransformers.x_transformer.XTransformer
25 | params:
26 | wrap_decoder: false
27 | dim: 1152
28 | enc_num_tokens: 548
29 | enc_depth: 32
30 | enc_heads: 16
31 | enc_max_seq_len: 257
32 | dec_num_tokens: 548
33 | dec_depth: 6
34 | dec_heads: 16
35 | tie_token_emb: false
36 | dec_max_seq_len: 256
37 | first_stage_config:
38 | target: imagebart.models.vqgan.VQGANWrapper
39 | params:
40 | ckpt_path: vqgan/vqgan-ffhq.ckpt
41 | remap: data/vqgan_indices/ffhq_indices.npy
42 | sane_index_shape: true
43 | embed_dim: 256
44 | n_embed: 1024
45 | ddconfig:
46 | double_z: false
47 | z_channels: 256
48 | resolution: 256
49 | in_channels: 3
50 | out_ch: 3
51 | ch: 128
52 | ch_mult:
53 | - 1
54 | - 1
55 | - 2
56 | - 2
57 | - 4
58 | num_res_blocks: 2
59 | attn_resolutions:
60 | - 16
61 | dropout: 0.0
62 | lossconfig:
63 | target: taming.modules.losses.vqperceptual.DummyLoss
64 |
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 18
69 | num_workers: 32
70 | wrap: false
71 | train:
72 | target: taming.data.faceshq.FFHQTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: taming.data.faceshq.FFHQValidation
77 | params:
78 | size: 256
79 |
--------------------------------------------------------------------------------
/configs/ffhq/2_scales/ffhq-custom-scale2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DecoderOnlyDenoiser
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 2
8 | single_scale: 2
9 | top_k: 548
10 | alpha: 1.0
11 | redraw_prob: ffhq_bernoulli_PSIM
12 | use_ema: true
13 | scheduler_config:
14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
15 | params:
16 | verbosity_interval: 0
17 | warm_up_steps: 10000
18 | max_decay_steps: 1500001
19 | lr_start: 2.5e-06
20 | lr_max: 0.0001
21 | lr_min: 1.0e-08
22 | transformer_config:
23 | target: imagebart.modules.transformer.mingpt.GPT
24 | params:
25 | vocab_size: 548
26 | block_size: 256
27 | n_layer: 36
28 | n_head: 16
29 | n_embd: 1216
30 | first_stage_config:
31 | target: imagebart.models.vqgan.VQGANWrapper
32 | params:
33 | ckpt_path: vqgan/vqgan-ffhq.ckpt
34 | remap: data/vqgan_indices/ffhq_indices.npy
35 | sane_index_shape: true
36 | embed_dim: 256
37 | n_embed: 1024
38 | ddconfig:
39 | double_z: false
40 | z_channels: 256
41 | resolution: 256
42 | in_channels: 3
43 | out_ch: 3
44 | ch: 128
45 | ch_mult:
46 | - 1
47 | - 1
48 | - 2
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions:
53 | - 16
54 | dropout: 0.0
55 | lossconfig:
56 | target: taming.modules.losses.vqperceptual.DummyLoss
57 |
58 | data:
59 | target: main.DataModuleFromConfig
60 | params:
61 | batch_size: 18
62 | wrap: false
63 | train:
64 | target: taming.data.faceshq.FFHQTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: taming.data.faceshq.FFHQValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/configs/ffhq/4_scales/ffhq-geometric-scale1.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: "image"
6 | monitor: "val/loss"
7 | n_scales: 4
8 | single_scale: 1
9 | top_k: 548
10 | alpha: 0.0
11 | redraw_prob: geometric
12 | use_ema: True
13 |
14 | scheduler_config:
15 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
16 | params:
17 | verbosity_interval: 0 # 0 or negative to disable
18 | warm_up_steps: 10000
19 | max_decay_steps: 1500001
20 | lr_start: 2.5e-6
21 | lr_max: 1.0e-4
22 | lr_min: 1.0e-8
23 |
24 | transformer_config:
25 | target: imagebart.modules.xtransformers.x_transformer.ResidualScaledXTransformer
26 | params:
27 | scale_pos: 0
28 | n_scales: 4
29 | xt_start: 1
30 | xt_size: 256 # predict x_{t-1}
31 | wrap_decoder: False
32 | dim: 752
33 | enc_num_tokens: 548
34 | enc_depth: 18
35 | enc_heads: 16
36 | enc_max_seq_len: 257
37 | dec_num_tokens: 548
38 | dec_depth: 6
39 | dec_heads: 16
40 | tie_token_emb: False
41 | dec_max_seq_len: 256
42 |
43 | first_stage_config:
44 | target: imagebart.models.vqgan.VQGANWrapper
45 | params:
46 | ckpt_path: vqgan/vqgan-ffhq.ckpt
47 | remap: "data/vqgan_indices/ffhq_indices.npy"
48 | sane_index_shape: True
49 | embed_dim: 256
50 | n_embed: 1024
51 | ddconfig:
52 | double_z: false
53 | z_channels: 256
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
59 | num_res_blocks: 2
60 | attn_resolutions: [ 16 ]
61 | dropout: 0.0
62 | lossconfig:
63 | target: taming.modules.losses.vqperceptual.DummyLoss
64 |
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 16
69 | num_workers: 32
70 | wrap: False
71 | train:
72 | target: taming.data.faceshq.FFHQTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: taming.data.faceshq.FFHQValidation
77 | params:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 1000
86 | max_images: 4
87 | increase_log_steps: False
88 | trainer:
89 | benchmark: True
90 |
--------------------------------------------------------------------------------
/configs/ffhq/4_scales/ffhq-geometric-scale2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: "image"
6 | monitor: "val/loss"
7 | n_scales: 4
8 | single_scale: 2
9 | top_k: 548
10 | alpha: 0.0
11 | redraw_prob: geometric
12 | use_ema: True
13 |
14 | scheduler_config:
15 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
16 | params:
17 | verbosity_interval: 0 # 0 or negative to disable
18 | warm_up_steps: 10000
19 | max_decay_steps: 1500001
20 | lr_start: 2.5e-6
21 | lr_max: 1.0e-4
22 | lr_min: 1.0e-8
23 |
24 | transformer_config:
25 | target: imagebart.modules.xtransformers.x_transformer.ResidualScaledXTransformer
26 | params:
27 | scale_pos: 0
28 | n_scales: 4
29 | xt_start: 1
30 | xt_size: 256 # predict x_{t-1}
31 | wrap_decoder: False
32 | dim: 752
33 | enc_num_tokens: 548
34 | enc_depth: 18
35 | enc_heads: 16
36 | enc_max_seq_len: 257
37 | dec_num_tokens: 548
38 | dec_depth: 6
39 | dec_heads: 16
40 | tie_token_emb: False
41 | dec_max_seq_len: 256
42 |
43 | first_stage_config:
44 | target: imagebart.models.vqgan.VQGANWrapper
45 | params:
46 | ckpt_path: vqgan/vqgan-ffhq.ckpt
47 | remap: "data/vqgan_indices/ffhq_indices.npy"
48 | sane_index_shape: True
49 | embed_dim: 256
50 | n_embed: 1024
51 | ddconfig:
52 | double_z: false
53 | z_channels: 256
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
59 | num_res_blocks: 2
60 | attn_resolutions: [ 16 ]
61 | dropout: 0.0
62 | lossconfig:
63 | target: taming.modules.losses.vqperceptual.DummyLoss
64 |
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 16
69 | num_workers: 32
70 | wrap: False
71 | train:
72 | target: taming.data.faceshq.FFHQTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: taming.data.faceshq.FFHQValidation
77 | params:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 1000
86 | max_images: 4
87 | increase_log_steps: False
88 | trainer:
89 | benchmark: True
90 |
--------------------------------------------------------------------------------
/configs/ffhq/4_scales/ffhq-geometric-scale3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: "image"
6 | monitor: "val/loss"
7 | n_scales: 4
8 | single_scale: 3
9 | top_k: 548
10 | alpha: 0.0
11 | redraw_prob: geometric
12 | use_ema: True
13 |
14 | scheduler_config:
15 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
16 | params:
17 | verbosity_interval: 0 # 0 or negative to disable
18 | warm_up_steps: 10000
19 | max_decay_steps: 1500001
20 | lr_start: 2.5e-6
21 | lr_max: 1.0e-4
22 | lr_min: 1.0e-8
23 |
24 | transformer_config:
25 | target: imagebart.modules.xtransformers.x_transformer.ResidualScaledXTransformer
26 | params:
27 | scale_pos: 0
28 | n_scales: 4
29 | xt_start: 1
30 | xt_size: 256 # predict x_{t-1}
31 | wrap_decoder: False
32 | dim: 752
33 | enc_num_tokens: 548
34 | enc_depth: 18
35 | enc_heads: 16
36 | enc_max_seq_len: 257
37 | dec_num_tokens: 548
38 | dec_depth: 6
39 | dec_heads: 16
40 | tie_token_emb: False
41 | dec_max_seq_len: 256
42 |
43 | first_stage_config:
44 | target: imagebart.models.vqgan.VQGANWrapper
45 | params:
46 | ckpt_path: vqgan/vqgan-ffhq.ckpt
47 | remap: "data/vqgan_indices/ffhq_indices.npy"
48 | sane_index_shape: True
49 | embed_dim: 256
50 | n_embed: 1024
51 | ddconfig:
52 | double_z: false
53 | z_channels: 256
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
59 | num_res_blocks: 2
60 | attn_resolutions: [ 16 ]
61 | dropout: 0.0
62 | lossconfig:
63 | target: taming.modules.losses.vqperceptual.DummyLoss
64 |
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 16
69 | num_workers: 32
70 | wrap: False
71 | train:
72 | target: taming.data.faceshq.FFHQTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: taming.data.faceshq.FFHQValidation
77 | params:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 1000
86 | max_images: 4
87 | increase_log_steps: False
88 | trainer:
89 | benchmark: True
90 |
--------------------------------------------------------------------------------
/configs/ffhq/4_scales/ffhq-geometric-scale4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DecoderOnlyDenoiser
4 | params:
5 | first_stage_key: "image"
6 | monitor: "val/loss"
7 | n_scales: 4
8 | single_scale: 4
9 | top_k: 548
10 | alpha: 1.0 # soft only
11 | redraw_prob: geometric
12 | use_ema: True
13 |
14 | scheduler_config:
15 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
16 | params:
17 | verbosity_interval: 0 # 0 or negative to disable
18 | warm_up_steps: 10000
19 | max_decay_steps: 1500001
20 | lr_start: 2.5e-6
21 | lr_max: 1.0e-4
22 | lr_min: 1.0e-8
23 |
24 | transformer_config:
25 | target: imagebart.modules.transformer.mingpt.GPT
26 | params:
27 | vocab_size: 548
28 | block_size: 256 # 256 data tokens - 1 + 1 scale token
29 | n_layer: 26
30 | n_head: 16
31 | n_embd: 800
32 |
33 | first_stage_config:
34 | target: imagebart.models.vqgan.VQGANWrapper
35 | params:
36 | ckpt_path: vqgan/vqgan-ffhq.ckpt
37 | remap: "data/vqgan_indices/ffhq_indices.npy"
38 | sane_index_shape: True
39 | embed_dim: 256
40 | n_embed: 1024
41 | ddconfig:
42 | double_z: False
43 | z_channels: 256
44 | resolution: 256
45 | in_channels: 3
46 | out_ch: 3
47 | ch: 128
48 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
49 | num_res_blocks: 2
50 | attn_resolutions: [ 16 ]
51 | dropout: 0.0
52 | lossconfig:
53 | target: taming.modules.losses.vqperceptual.DummyLoss
54 |
55 | data:
56 | target: main.DataModuleFromConfig
57 | params:
58 | batch_size: 16
59 | wrap: False
60 | train:
61 | target: taming.data.faceshq.FFHQTrain
62 | params:
63 | size: 256
64 | validation:
65 | target: taming.data.faceshq.FFHQValidation
66 | params:
67 | size: 256
68 |
69 | lightning:
70 | callbacks:
71 | image_logger:
72 | target: main.ImageLogger
73 | params:
74 | batch_frequency: 1000
75 | max_images: 4
76 | increase_log_steps: False
77 | trainer:
78 | benchmark: True
79 |
--------------------------------------------------------------------------------
/configs/ffhq/ffhq_4-scales_joint-training.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: "image"
6 | monitor: "val/loss"
7 | n_scales: 4
8 | top_k: 548
9 | alpha: 1.0
10 | redraw_prob: geometric
11 | use_ema: True
12 |
13 | scheduler_config:
14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
15 | params:
16 | verbosity_interval: 0 # 0 or negative to disable
17 | warm_up_steps: 10000
18 | max_decay_steps: 1500001
19 | lr_start: 2.5e-6
20 | lr_max: 1.0e-4
21 | lr_min: 1.0e-8
22 |
23 | transformer_config:
24 | target: imagebart.modules.xtransformers.x_transformer.ResidualScaledXTransformer
25 | params:
26 | scale_pos: 0
27 | n_scales: 4
28 | xt_start: 1
29 | xt_size: 256 # predict x_{t-1}
30 | wrap_decoder: False
31 | dim: 752
32 | enc_num_tokens: 548
33 | enc_depth: 18
34 | enc_heads: 16
35 | enc_max_seq_len: 257
36 | dec_num_tokens: 548
37 | dec_depth: 6
38 | dec_heads: 16
39 | tie_token_emb: False
40 | dec_max_seq_len: 256
41 |
42 | first_stage_config:
43 | target: imagebart.models.vqgan.VQGANWrapper
44 | params:
45 | ckpt_path: vqgan/vqgan-ffhq.ckpt
46 | remap: "data/vqgan_indices/ffhq_indices.npy"
47 | sane_index_shape: True
48 | embed_dim: 256
49 | n_embed: 1024
50 | ddconfig:
51 | double_z: false
52 | z_channels: 256
53 | resolution: 256
54 | in_channels: 3
55 | out_ch: 3
56 | ch: 128
57 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
58 | num_res_blocks: 2
59 | attn_resolutions: [ 16 ]
60 | dropout: 0.0
61 | lossconfig:
62 | target: taming.modules.losses.vqperceptual.DummyLoss
63 |
64 | data:
65 | target: main.DataModuleFromConfig
66 | params:
67 | batch_size: 16
68 | num_workers: 32
69 | wrap: False
70 | train:
71 | target: taming.data.faceshq.FFHQTrain
72 | params:
73 | size: 256
74 | validation:
75 | target: taming.data.faceshq.FFHQValidation
76 | params:
77 | size: 256
78 |
79 | lightning:
80 | callbacks:
81 | image_logger:
82 | target: main.ImageLogger
83 | params:
84 | batch_frequency: 1000
85 | max_images: 4
86 | increase_log_steps: False
87 | log_images_kwargs:
88 | sample_full: True
89 | sample_half: True
90 |
91 | trainer:
92 | benchmark: True
93 |
--------------------------------------------------------------------------------
/configs/imagenet/4_scales/imagenet_geometric_scale1.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 4
8 | single_scale: 1
9 | top_k: 973
10 | alpha: 0.0
11 | redraw_prob: geometric
12 | use_ema: true
13 | conditioner_config:
14 | target: imagebart.util.ClassProvider
15 | params:
16 | key: class_label
17 | scheduler_config:
18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
19 | params:
20 | verbosity_interval: 0
21 | warm_up_steps: 10000
22 | max_decay_steps: 10000000
23 | lr_start: 2.5e-06
24 | lr_max: 0.0001
25 | lr_min: 1.0e-08
26 | transformer_config:
27 | target: imagebart.modules.xtransformers.x_transformer.ResidualScaledXTransformer
28 | params:
29 | scale_pos: 0
30 | n_scales: 4
31 | xt_start: 2
32 | xt_size: 256
33 | wrap_decoder: false
34 | dim: 1152
35 | enc_num_tokens: 1000
36 | enc_depth: 32
37 | enc_heads: 16
38 | enc_max_seq_len: 258
39 | dec_num_tokens: 973
40 | dec_depth: 6
41 | dec_heads: 16
42 | tie_token_emb: false
43 | dec_max_seq_len: 256
44 | first_stage_config:
45 | target: imagebart.models.vqgan.VQGANWrapper
46 | params:
47 | ckpt_path: vqgan/vqgan-imagenet.ckpt
48 | remap: data/vqgan_indices/imagenet_indices.npy
49 | sane_index_shape: true
50 | embed_dim: 256
51 | n_embed: 16384
52 | ddconfig:
53 | double_z: false
54 | z_channels: 256
55 | resolution: 256
56 | in_channels: 3
57 | out_ch: 3
58 | ch: 128
59 | ch_mult:
60 | - 1
61 | - 1
62 | - 2
63 | - 2
64 | - 4
65 | num_res_blocks: 2
66 | attn_resolutions:
67 | - 16
68 | dropout: 0.0
69 | tanh_out: true
70 | lossconfig:
71 | target: taming.modules.losses.vqperceptual.DummyLoss
72 |
73 | data:
74 | target: main.DataModuleFromConfig
75 | params:
76 | batch_size: 16
77 | wrap: false
78 | train:
79 | target: taming.data.imagenet.ImageNetTrain
80 | params:
81 | config:
82 | size: 256
83 | validation:
84 | target: taming.data.imagenet.ImageNetValidation
85 | params:
86 | config:
87 | size: 256
88 |
89 | lightning:
90 | callbacks:
91 | image_logger:
92 | target: main.ImageLogger
93 | params:
94 | batch_frequency: 1000
95 | max_images: 4
96 | increase_log_steps: false
97 | trainer:
98 | benchmark: true
99 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/imagenet/4_scales/imagenet_geometric_scale2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 4
8 | single_scale: 2
9 | top_k: 973
10 | alpha: 1.0
11 | redraw_prob: geometric
12 | use_ema: true
13 | conditioner_config:
14 | target: imagebart.util.ClassProvider
15 | params:
16 | key: class_label
17 | scheduler_config:
18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
19 | params:
20 | verbosity_interval: 0
21 | warm_up_steps: 10000
22 | max_decay_steps: 10000000
23 | lr_start: 2.5e-06
24 | lr_max: 0.0001
25 | lr_min: 1.0e-08
26 | transformer_config:
27 | target: imagebart.modules.xtransformers.x_transformer.ResidualScaledXTransformer
28 | params:
29 | scale_pos: 0
30 | n_scales: 4
31 | xt_start: 2
32 | xt_size: 256
33 | wrap_decoder: false
34 | dim: 1152
35 | enc_num_tokens: 1000
36 | enc_depth: 32
37 | enc_heads: 16
38 | enc_max_seq_len: 258
39 | dec_num_tokens: 973
40 | dec_depth: 6
41 | dec_heads: 16
42 | tie_token_emb: false
43 | dec_max_seq_len: 256
44 | first_stage_config:
45 | target: imagebart.models.vqgan.VQGANWrapper
46 | params:
47 | ckpt_path: vqgan/vqgan-imagenet.ckpt
48 | remap: data/vqgan_indices/imagenet_indices.npy
49 | sane_index_shape: true
50 | embed_dim: 256
51 | n_embed: 16384
52 | ddconfig:
53 | double_z: false
54 | z_channels: 256
55 | resolution: 256
56 | in_channels: 3
57 | out_ch: 3
58 | ch: 128
59 | ch_mult:
60 | - 1
61 | - 1
62 | - 2
63 | - 2
64 | - 4
65 | num_res_blocks: 2
66 | attn_resolutions:
67 | - 16
68 | dropout: 0.0
69 | tanh_out: true
70 | lossconfig:
71 | target: taming.modules.losses.vqperceptual.DummyLoss
72 |
73 | data:
74 | target: main.DataModuleFromConfig
75 | params:
76 | batch_size: 16
77 | wrap: false
78 | train:
79 | target: taming.data.imagenet.ImageNetTrain
80 | params:
81 | config:
82 | size: 256
83 | validation:
84 | target: taming.data.imagenet.ImageNetValidation
85 | params:
86 | config:
87 | size: 256
88 |
89 | lightning:
90 | callbacks:
91 | image_logger:
92 | target: main.ImageLogger
93 | params:
94 | batch_frequency: 1000
95 | max_images: 4
96 | increase_log_steps: false
97 | trainer:
98 | benchmark: true
99 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/imagenet/4_scales/imagenet_geometric_scale3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 4
8 | single_scale: 3
9 | top_k: 973
10 | alpha: 1.0
11 | redraw_prob: geometric
12 | use_ema: true
13 | conditioner_config:
14 | target: imagebart.util.ClassProvider
15 | params:
16 | key: class_label
17 | scheduler_config:
18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
19 | params:
20 | verbosity_interval: 0
21 | warm_up_steps: 10000
22 | max_decay_steps: 10000000
23 | lr_start: 2.5e-06
24 | lr_max: 0.0001
25 | lr_min: 1.0e-08
26 | transformer_config:
27 | target: imagebart.modules.xtransformers.x_transformer.ResidualScaledXTransformer
28 | params:
29 | scale_pos: 0
30 | n_scales: 4
31 | xt_start: 2
32 | xt_size: 256
33 | wrap_decoder: false
34 | dim: 1152
35 | enc_num_tokens: 1000
36 | enc_depth: 32
37 | enc_heads: 16
38 | enc_max_seq_len: 258
39 | dec_num_tokens: 973
40 | dec_depth: 6
41 | dec_heads: 16
42 | tie_token_emb: false
43 | dec_max_seq_len: 256
44 | first_stage_config:
45 | target: imagebart.models.vqgan.VQGANWrapper
46 | params:
47 | ckpt_path: vqgan/vqgan-imagenet.ckpt
48 | remap: data/vqgan_indices/imagenet_indices.npy
49 | sane_index_shape: true
50 | embed_dim: 256
51 | n_embed: 16384
52 | ddconfig:
53 | double_z: false
54 | z_channels: 256
55 | resolution: 256
56 | in_channels: 3
57 | out_ch: 3
58 | ch: 128
59 | ch_mult:
60 | - 1
61 | - 1
62 | - 2
63 | - 2
64 | - 4
65 | num_res_blocks: 2
66 | attn_resolutions:
67 | - 16
68 | dropout: 0.0
69 | tanh_out: true
70 | lossconfig:
71 | target: taming.modules.losses.vqperceptual.DummyLoss
72 |
73 | data:
74 | target: main.DataModuleFromConfig
75 | params:
76 | batch_size: 16
77 | wrap: false
78 | train:
79 | target: taming.data.imagenet.ImageNetTrain
80 | params:
81 | config:
82 | size: 256
83 | validation:
84 | target: taming.data.imagenet.ImageNetValidation
85 | params:
86 | config:
87 | size: 256
88 |
89 | lightning:
90 | callbacks:
91 | image_logger:
92 | target: main.ImageLogger
93 | params:
94 | batch_frequency: 1000
95 | max_images: 4
96 | increase_log_steps: false
97 | trainer:
98 | benchmark: true
99 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/imagenet/4_scales/imagenet_geometric_scale4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DecoderOnlyDenoiser
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 4
8 | single_scale: 4
9 | top_k: 973
10 | alpha: 1.0
11 | redraw_prob: geometric
12 | use_ema: true
13 | conditioner_config:
14 | target: imagebart.util.ClassProvider
15 | params:
16 | key: class_label
17 | scheduler_config:
18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
19 | params:
20 | verbosity_interval: 0
21 | warm_up_steps: 10000
22 | max_decay_steps: 10000000
23 | lr_start: 2.5e-06
24 | lr_max: 0.0001
25 | lr_min: 1.0e-08
26 | transformer_config:
27 | target: imagebart.modules.transformer.mingpt.GPT
28 | params:
29 | input_vocab_size: 1000
30 | vocab_size: 973
31 | block_size: 258
32 | n_layer: 36
33 | n_head: 16
34 | n_embd: 1216
35 | n_unmasked: 2
36 | first_stage_config:
37 | target: imagebart.models.vqgan.VQGANWrapper
38 | params:
39 | ckpt_path: vqgan/vqgan-imagenet.ckpt
40 | remap: data/vqgan_indices/imagenet_indices.npy
41 | sane_index_shape: true
42 | embed_dim: 256
43 | n_embed: 16384
44 | ddconfig:
45 | double_z: false
46 | z_channels: 256
47 | resolution: 256
48 | in_channels: 3
49 | out_ch: 3
50 | ch: 128
51 | ch_mult:
52 | - 1
53 | - 1
54 | - 2
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions:
59 | - 16
60 | dropout: 0.0
61 | lossconfig:
62 | target: taming.modules.losses.vqperceptual.DummyLoss
63 |
64 | data:
65 | target: main.DataModuleFromConfig
66 | params:
67 | batch_size: 16
68 | wrap: false
69 | train:
70 | target: taming.data.imagenet.ImageNetTrain
71 | params:
72 | config:
73 | size: 256
74 | validation:
75 | target: taming.data.imagenet.ImageNetValidation
76 | params:
77 | config:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 1000
86 | max_images: 4
87 | increase_log_steps: false
88 | trainer:
89 | benchmark: true
90 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/imagenet/5_scales/imagenet-custom-scale1.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 5
8 | single_scale: 1
9 | top_k: 973
10 | alpha: 0.0
11 | redraw_prob: bernoulli_PSIM
12 | use_ema: true
13 | conditioner_config:
14 | target: imagebart.util.ClassProvider
15 | params:
16 | key: class_label
17 | scheduler_config:
18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
19 | params:
20 | verbosity_interval: 0 # 0 or negative to disable
21 | warm_up_steps: 10000
22 | max_decay_steps: 10000000
23 | lr_start: 2.5e-6
24 | lr_max: 1.0e-4
25 | lr_min: 1.0e-8
26 | transformer_config:
27 | target: imagebart.modules.xtransformers.x_transformer.XTransformer
28 | params:
29 | wrap_decoder: false
30 | dim: 1152
31 | enc_num_tokens: 1000
32 | enc_depth: 32
33 | enc_heads: 16
34 | enc_max_seq_len: 258
35 | dec_num_tokens: 973
36 | dec_depth: 6
37 | dec_heads: 16
38 | tie_token_emb: false
39 | dec_max_seq_len: 256
40 | first_stage_config:
41 | target: imagebart.models.vqgan.VQGANWrapper
42 | params:
43 | ckpt_path: vqgan/vqgan-imagenet.ckpt
44 | remap: data/vqgan_indices/imagenet_indices.npy
45 | sane_index_shape: true
46 | embed_dim: 256
47 | n_embed: 16384
48 | ddconfig:
49 | double_z: false
50 | z_channels: 256
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult:
56 | - 1
57 | - 1
58 | - 2
59 | - 2
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions:
63 | - 16
64 | dropout: 0.0
65 | lossconfig:
66 | target: taming.modules.losses.vqperceptual.DummyLoss
67 | data:
68 | target: main.DataModuleFromConfig
69 | params:
70 | batch_size: 16
71 | wrap: false
72 | train:
73 | target: taming.data.imagenet.ImageNetTrain
74 | params:
75 | config:
76 | size: 256
77 | validation:
78 | target: taming.data.imagenet.ImageNetValidation
79 | params:
80 | config:
81 | size: 256
82 |
83 | lightning:
84 | callbacks:
85 | image_logger:
86 | target: main.ImageLogger
87 | params:
88 | batch_frequency: 1000
89 | max_images: 4
90 | increase_log_steps: false
91 | trainer:
92 | benchmark: true
93 | accumulate_grad_batches: 2
94 |
--------------------------------------------------------------------------------
/configs/imagenet/5_scales/imagenet-custom-scale2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 5
8 | single_scale: 2
9 | top_k: 973
10 | alpha: 1.0
11 | redraw_prob: bernoulli_PSIM
12 | use_ema: true
13 | conditioner_config:
14 | target: imagebart.util.ClassProvider
15 | params:
16 | key: class_label
17 | scheduler_config:
18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
19 | params:
20 | verbosity_interval: 0 # 0 or negative to disable
21 | warm_up_steps: 10000
22 | max_decay_steps: 10000000
23 | lr_start: 2.5e-6
24 | lr_max: 1.0e-4
25 | lr_min: 1.0e-8
26 | transformer_config:
27 | target: imagebart.modules.xtransformers.x_transformer.XTransformer
28 | params:
29 | wrap_decoder: false
30 | dim: 1152
31 | enc_num_tokens: 1000
32 | enc_depth: 32
33 | enc_heads: 16
34 | enc_max_seq_len: 258
35 | dec_num_tokens: 973
36 | dec_depth: 6
37 | dec_heads: 16
38 | tie_token_emb: false
39 | dec_max_seq_len: 256
40 | first_stage_config:
41 | target: imagebart.models.vqgan.VQGANWrapper
42 | params:
43 | ckpt_path: vqgan/vqgan-imagenet.ckpt
44 | remap: data/vqgan_indices/imagenet_indices.npy
45 | sane_index_shape: true
46 | embed_dim: 256
47 | n_embed: 16384
48 | ddconfig:
49 | double_z: false
50 | z_channels: 256
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult:
56 | - 1
57 | - 1
58 | - 2
59 | - 2
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions:
63 | - 16
64 | dropout: 0.0
65 | lossconfig:
66 | target: taming.modules.losses.vqperceptual.DummyLoss
67 | data:
68 | target: main.DataModuleFromConfig
69 | params:
70 | batch_size: 16
71 | wrap: false
72 | train:
73 | target: taming.data.imagenet.ImageNetTrain
74 | params:
75 | config:
76 | size: 256
77 | validation:
78 | target: taming.data.imagenet.ImageNetValidation
79 | params:
80 | config:
81 | size: 256
82 |
83 | lightning:
84 | callbacks:
85 | image_logger:
86 | target: main.ImageLogger
87 | params:
88 | batch_frequency: 1000
89 | max_images: 4
90 | increase_log_steps: false
91 | trainer:
92 | benchmark: true
93 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/imagenet/5_scales/imagenet-custom-scale3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 5
8 | single_scale: 3
9 | top_k: 973
10 | alpha: 1.0
11 | redraw_prob: bernoulli_PSIM
12 | use_ema: true
13 | conditioner_config:
14 | target: imagebart.util.ClassProvider
15 | params:
16 | key: class_label
17 | scheduler_config:
18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
19 | params:
20 | verbosity_interval: 0 # 0 or negative to disable
21 | warm_up_steps: 10000
22 | max_decay_steps: 10000000
23 | lr_start: 2.5e-6
24 | lr_max: 1.0e-4
25 | lr_min: 1.0e-8
26 | transformer_config:
27 | target: imagebart.modules.xtransformers.x_transformer.XTransformer
28 | params:
29 | wrap_decoder: false
30 | dim: 1152
31 | enc_num_tokens: 1000
32 | enc_depth: 32
33 | enc_heads: 16
34 | enc_max_seq_len: 258
35 | dec_num_tokens: 973
36 | dec_depth: 6
37 | dec_heads: 16
38 | tie_token_emb: false
39 | dec_max_seq_len: 256
40 | first_stage_config:
41 | target: imagebart.models.vqgan.VQGANWrapper
42 | params:
43 | ckpt_path: vqgan/vqgan-imagenet.ckpt
44 | remap: data/vqgan_indices/imagenet_indices.npy
45 | sane_index_shape: true
46 | embed_dim: 256
47 | n_embed: 16384
48 | ddconfig:
49 | double_z: false
50 | z_channels: 256
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult:
56 | - 1
57 | - 1
58 | - 2
59 | - 2
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions:
63 | - 16
64 | dropout: 0.0
65 | lossconfig:
66 | target: taming.modules.losses.vqperceptual.DummyLoss
67 | data:
68 | target: main.DataModuleFromConfig
69 | params:
70 | batch_size: 16
71 | wrap: false
72 | train:
73 | target: taming.data.imagenet.ImageNetTrain
74 | params:
75 | config:
76 | size: 256
77 | validation:
78 | target: taming.data.imagenet.ImageNetValidation
79 | params:
80 | config:
81 | size: 256
82 |
83 | lightning:
84 | callbacks:
85 | image_logger:
86 | target: main.ImageLogger
87 | params:
88 | batch_frequency: 1000
89 | max_images: 4
90 | increase_log_steps: false
91 | trainer:
92 | benchmark: true
93 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/imagenet/5_scales/imagenet-custom-scale4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 5
8 | single_scale: 4
9 | top_k: 973
10 | alpha: 1.0
11 | redraw_prob: bernoulli_PSIM
12 | use_ema: true
13 | conditioner_config:
14 | target: imagebart.util.ClassProvider
15 | params:
16 | key: class_label
17 | scheduler_config:
18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
19 | params:
20 | verbosity_interval: 0 # 0 or negative to disable
21 | warm_up_steps: 10000
22 | max_decay_steps: 10000000
23 | lr_start: 2.5e-6
24 | lr_max: 1.0e-4
25 | lr_min: 1.0e-8
26 | transformer_config:
27 | target: imagebart.modules.xtransformers.x_transformer.XTransformer
28 | params:
29 | wrap_decoder: false
30 | dim: 1152
31 | enc_num_tokens: 1000
32 | enc_depth: 32
33 | enc_heads: 16
34 | enc_max_seq_len: 258
35 | dec_num_tokens: 973
36 | dec_depth: 6
37 | dec_heads: 16
38 | tie_token_emb: false
39 | dec_max_seq_len: 256
40 | first_stage_config:
41 | target: imagebart.models.vqgan.VQGANWrapper
42 | params:
43 | ckpt_path: vqgan/vqgan-imagenet.ckpt
44 | remap: data/vqgan_indices/imagenet_indices.npy
45 | sane_index_shape: true
46 | embed_dim: 256
47 | n_embed: 16384
48 | ddconfig:
49 | double_z: false
50 | z_channels: 256
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult:
56 | - 1
57 | - 1
58 | - 2
59 | - 2
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions:
63 | - 16
64 | dropout: 0.0
65 | lossconfig:
66 | target: taming.modules.losses.vqperceptual.DummyLoss
67 | data:
68 | target: main.DataModuleFromConfig
69 | params:
70 | batch_size: 16
71 | wrap: false
72 | train:
73 | target: taming.data.imagenet.ImageNetTrain
74 | params:
75 | config:
76 | size: 256
77 | validation:
78 | target: taming.data.imagenet.ImageNetValidation
79 | params:
80 | config:
81 | size: 256
82 |
83 | lightning:
84 | callbacks:
85 | image_logger:
86 | target: main.ImageLogger
87 | params:
88 | batch_frequency: 1000
89 | max_images: 4
90 | increase_log_steps: false
91 | trainer:
92 | benchmark: true
93 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/imagenet/5_scales/imagenet-custom-scale5.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DecoderOnlyDenoiser
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 5
8 | single_scale: 5
9 | top_k: 973
10 | alpha: 1.0
11 | redraw_prob: bernoulli_PSIM
12 | use_ema: true
13 | conditioner_config:
14 | target: imagebart.util.ClassProvider
15 | params:
16 | key: class_label
17 | scheduler_config:
18 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
19 | params:
20 | verbosity_interval: 0 # 0 or negative to disable
21 | warm_up_steps: 10000
22 | max_decay_steps: 10000000
23 | lr_start: 2.5e-6
24 | lr_max: 1.0e-4
25 | lr_min: 1.0e-8
26 | transformer_config:
27 | target: imagebart.modules.transformer.mingpt.GPT
28 | params:
29 | input_vocab_size: 1000
30 | vocab_size: 973
31 | block_size: 258
32 | n_layer: 36
33 | n_head: 16
34 | n_embd: 1216
35 | n_unmasked: 2
36 | first_stage_config:
37 | target: imagebart.models.vqgan.VQGANWrapper
38 | params:
39 | ckpt_path: vqgan/vqgan-imagenet.ckpt
40 | remap: data/vqgan_indices/imagenet_indices.npy
41 | sane_index_shape: true
42 | embed_dim: 256
43 | n_embed: 16384
44 | ddconfig:
45 | double_z: false
46 | z_channels: 256
47 | resolution: 256
48 | in_channels: 3
49 | out_ch: 3
50 | ch: 128
51 | ch_mult:
52 | - 1
53 | - 1
54 | - 2
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions:
59 | - 16
60 | dropout: 0.0
61 | lossconfig:
62 | target: taming.modules.losses.vqperceptual.DummyLoss
63 | data:
64 | target: main.DataModuleFromConfig
65 | params:
66 | batch_size: 16
67 | wrap: false
68 | train:
69 | target: taming.data.imagenet.ImageNetTrain
70 | params:
71 | config:
72 | size: 256
73 | validation:
74 | target: taming.data.imagenet.ImageNetValidation
75 | params:
76 | config:
77 | size: 256
78 |
79 | lightning:
80 | callbacks:
81 | image_logger:
82 | target: main.ImageLogger
83 | params:
84 | batch_frequency: 1000
85 | max_images: 4
86 | increase_log_steps: false
87 | trainer:
88 | benchmark: true
89 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/lsun-beds/lsun-beds-scale1.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 3
8 | single_scale: 1
9 | top_k: 1017
10 | alpha: 0.0
11 | redraw_prob: bernoulli_PSIM
12 | use_ema: true
13 | scheduler_config:
14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
15 | params:
16 | verbosity_interval: 0
17 | warm_up_steps: 10000
18 | max_decay_steps: 5000000
19 | lr_start: 2.5e-06
20 | lr_max: 0.0001
21 | lr_min: 1.0e-08
22 |
23 | transformer_config:
24 | target: imagebart.modules.xtransformers.x_transformer.XTransformer
25 | params:
26 | wrap_decoder: false
27 | dim: 1152
28 | enc_num_tokens: 1017
29 | enc_depth: 32
30 | enc_heads: 16
31 | enc_max_seq_len: 257
32 | dec_num_tokens: 1017
33 | dec_depth: 6
34 | dec_heads: 16
35 | tie_token_emb: false
36 | dec_max_seq_len: 256
37 | first_stage_config:
38 | target: imagebart.models.vqgan.VQGANWrapper
39 | params:
40 | ckpt_path: vqgan/vqgan-bedrooms.ckpt
41 | remap: data/vqgan_indices/bedroom_indices.npy
42 | sane_index_shape: true
43 | embed_dim: 256
44 | n_embed: 16384
45 | ddconfig:
46 | double_z: false
47 | z_channels: 256
48 | resolution: 256
49 | in_channels: 3
50 | out_ch: 3
51 | ch: 128
52 | ch_mult:
53 | - 1
54 | - 1
55 | - 2
56 | - 2
57 | - 4
58 | num_res_blocks: 2
59 | attn_resolutions:
60 | - 16
61 | dropout: 0.0
62 | tanh_out: true
63 | lossconfig:
64 | target: taming.modules.losses.vqperceptual.DummyLoss
65 |
66 | data:
67 | target: main.DataModuleFromConfig
68 | params:
69 | batch_size: 16
70 | wrap: false
71 | train:
72 | target: imagebart.data.lsun.LSUNBedroomsTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: imagebart.data.lsun.LSUNBedroomsValidation
77 | params:
78 | size: 256
79 |
80 |
81 | lightning:
82 | callbacks:
83 | image_logger:
84 | target: main.ImageLogger
85 | params:
86 | batch_frequency: 1000
87 | max_images: 4
88 | increase_log_steps: false
89 | trainer:
90 | benchmark: true
91 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/lsun-beds/lsun-beds-scale2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 3
8 | single_scale: 2
9 | top_k: 1017
10 | alpha: 1.0
11 | redraw_prob: bernoulli_PSIM
12 | use_ema: true
13 | scheduler_config:
14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
15 | params:
16 | verbosity_interval: 0
17 | warm_up_steps: 10000
18 | max_decay_steps: 5000000
19 | lr_start: 2.5e-06
20 | lr_max: 0.0001
21 | lr_min: 1.0e-08
22 |
23 | transformer_config:
24 | target: imagebart.modules.xtransformers.x_transformer.XTransformer
25 | params:
26 | wrap_decoder: false
27 | dim: 1152
28 | enc_num_tokens: 1017
29 | enc_depth: 32
30 | enc_heads: 16
31 | enc_max_seq_len: 257
32 | dec_num_tokens: 1017
33 | dec_depth: 6
34 | dec_heads: 16
35 | tie_token_emb: false
36 | dec_max_seq_len: 256
37 | first_stage_config:
38 | target: imagebart.models.vqgan.VQGANWrapper
39 | params:
40 | ckpt_path: vqgan/vqgan-bedrooms.ckpt
41 | remap: data/vqgan_indices/bedroom_indices.npy
42 | sane_index_shape: true
43 | embed_dim: 256
44 | n_embed: 16384
45 | ddconfig:
46 | double_z: false
47 | z_channels: 256
48 | resolution: 256
49 | in_channels: 3
50 | out_ch: 3
51 | ch: 128
52 | ch_mult:
53 | - 1
54 | - 1
55 | - 2
56 | - 2
57 | - 4
58 | num_res_blocks: 2
59 | attn_resolutions:
60 | - 16
61 | dropout: 0.0
62 | tanh_out: true
63 | lossconfig:
64 | target: taming.modules.losses.vqperceptual.DummyLoss
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 16
69 | wrap: false
70 | train:
71 | target: imagebart.data.lsun.LSUNBedroomsTrain
72 | params:
73 | size: 256
74 | validation:
75 | target: imagebart.data.lsun.LSUNBedroomsValidation
76 | params:
77 | size: 256
78 |
79 | lightning:
80 | callbacks:
81 | image_logger:
82 | target: main.ImageLogger
83 | params:
84 | batch_frequency: 1000
85 | max_images: 4
86 | increase_log_steps: false
87 | trainer:
88 | benchmark: true
89 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/lsun-beds/lsun-beds-scale3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DecoderOnlyDenoiser
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 3
8 | single_scale: 3
9 | top_k: 1017
10 | alpha: 1.0
11 | redraw_prob: bernoulli_PSIM
12 | use_ema: true
13 | scheduler_config:
14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
15 | params:
16 | verbosity_interval: 0
17 | warm_up_steps: 10000
18 | max_decay_steps: 5000000
19 | lr_start: 2.5e-06
20 | lr_max: 0.0001
21 | lr_min: 1.0e-08
22 | transformer_config:
23 | target: imagebart.modules.transformer.mingpt.GPT
24 | params:
25 | input_vocab_size: 1017
26 | vocab_size: 1017
27 | block_size: 256
28 | n_layer: 36
29 | n_head: 16
30 | n_embd: 1216
31 | first_stage_config:
32 | target: imagebart.models.vqgan.VQGANWrapper
33 | params:
34 | ckpt_path: vqgan/vqgan-bedrooms.ckpt
35 | remap: data/vqgan_indices/bedroom_indices.npy
36 | sane_index_shape: true
37 | embed_dim: 256
38 | n_embed: 16384
39 | ddconfig:
40 | double_z: false
41 | z_channels: 256
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 1
49 | - 2
50 | - 2
51 | - 4
52 | num_res_blocks: 2
53 | attn_resolutions:
54 | - 16
55 | dropout: 0.0
56 | tanh_out: true
57 | lossconfig:
58 | target: taming.modules.losses.vqperceptual.DummyLoss
59 |
60 | data:
61 | target: main.DataModuleFromConfig
62 | params:
63 | batch_size: 16
64 | wrap: false
65 | train:
66 | target: imagebart.data.lsun.LSUNBedroomsTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: imagebart.data.lsun.LSUNBedroomsValidation
71 | params:
72 | size: 256
73 |
74 | lightning:
75 | callbacks:
76 | image_logger:
77 | target: main.ImageLogger
78 | params:
79 | batch_frequency: 1000
80 | max_images: 4
81 | increase_log_steps: false
82 | trainer:
83 | benchmark: true
84 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/lsun-cats/lsun-cats-scale1.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 3
8 | single_scale: 1
9 | top_k: 1014
10 | alpha: 0.0
11 | redraw_prob: bernoulli_PSIM
12 | use_ema: true
13 | scheduler_config:
14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
15 | params:
16 | verbosity_interval: 0
17 | warm_up_steps: 10000
18 | max_decay_steps: 5000000
19 | lr_start: 2.5e-06
20 | lr_max: 0.0001
21 | lr_min: 1.0e-08
22 |
23 | transformer_config:
24 | target: imagebart.modules.xtransformers.x_transformer.XTransformer
25 | params:
26 | wrap_decoder: false
27 | dim: 1152
28 | enc_num_tokens: 1014
29 | enc_depth: 32
30 | enc_heads: 16
31 | enc_max_seq_len: 257
32 | dec_num_tokens: 1014
33 | dec_depth: 6
34 | dec_heads: 16
35 | tie_token_emb: false
36 | dec_max_seq_len: 256
37 | first_stage_config:
38 | target: imagebart.models.vqgan.VQGANWrapper
39 | params:
40 | ckpt_path: vqgan/vqgan-cats.ckpt
41 | remap: data/vqgan_indices/cat_indices.npy
42 | sane_index_shape: true
43 | embed_dim: 256
44 | n_embed: 16384
45 | ddconfig:
46 | double_z: false
47 | z_channels: 256
48 | resolution: 256
49 | in_channels: 3
50 | out_ch: 3
51 | ch: 128
52 | ch_mult:
53 | - 1
54 | - 1
55 | - 2
56 | - 2
57 | - 4
58 | num_res_blocks: 2
59 | attn_resolutions:
60 | - 16
61 | dropout: 0.0
62 | tanh_out: true
63 | lossconfig:
64 | target: taming.modules.losses.vqperceptual.DummyLoss
65 |
66 | data:
67 | target: main.DataModuleFromConfig
68 | params:
69 | batch_size: 16
70 | wrap: false
71 | train:
72 | target: imagebart.data.lsun.LSUNCatsTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: imagebart.data.lsun.LSUNCatsValidation
77 | params:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 1000
86 | max_images: 4
87 | increase_log_steps: false
88 | trainer:
89 | benchmark: true
90 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/lsun-cats/lsun-cats-scale2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 3
8 | single_scale: 2
9 | top_k: 1014
10 | alpha: 1.0
11 | redraw_prob: bernoulli_PSIM
12 |
13 | use_ema: true
14 | scheduler_config:
15 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
16 | params:
17 | verbosity_interval: 0
18 | warm_up_steps: 10000
19 | max_decay_steps: 5000000
20 | lr_start: 2.5e-06
21 | lr_max: 0.0001
22 | lr_min: 1.0e-08
23 | transformer_config:
24 | target: imagebart.modules.xtransformers.x_transformer.XTransformer
25 | params:
26 | wrap_decoder: false
27 | dim: 1152
28 | enc_num_tokens: 1014
29 | enc_depth: 32
30 | enc_heads: 16
31 | enc_max_seq_len: 257
32 | dec_num_tokens: 1014
33 | dec_depth: 6
34 | dec_heads: 16
35 | tie_token_emb: false
36 | dec_max_seq_len: 256
37 | first_stage_config:
38 | target: imagebart.models.vqgan.VQGANWrapper
39 | params:
40 | ckpt_path: vqgan/vqgan-cats.ckpt
41 | remap: data/vqgan_indices/cat_indices.npy
42 | sane_index_shape: true
43 | embed_dim: 256
44 | n_embed: 16384
45 | ddconfig:
46 | double_z: false
47 | z_channels: 256
48 | resolution: 256
49 | in_channels: 3
50 | out_ch: 3
51 | ch: 128
52 | ch_mult:
53 | - 1
54 | - 1
55 | - 2
56 | - 2
57 | - 4
58 | num_res_blocks: 2
59 | attn_resolutions:
60 | - 16
61 | dropout: 0.0
62 | tanh_out: true
63 | lossconfig:
64 | target: taming.modules.losses.vqperceptual.DummyLoss
65 |
66 | data:
67 | target: main.DataModuleFromConfig
68 | params:
69 | batch_size: 16
70 | wrap: false
71 | train:
72 | target: imagebart.data.lsun.LSUNCatsTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: imagebart.data.lsun.LSUNCatsValidation
77 | params:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 1000
86 | max_images: 4
87 | increase_log_steps: false
88 | trainer:
89 | benchmark: true
90 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/lsun-cats/lsun-cats-scale3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DecoderOnlyDenoiser
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 3
8 | single_scale: 3
9 | top_k: 1014
10 | alpha: 1.0
11 | redraw_prob: bernoulli_PSIM
12 | use_ema: true
13 | scheduler_config:
14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
15 | params:
16 | verbosity_interval: 0
17 | warm_up_steps: 10000
18 | max_decay_steps: 5000000
19 | lr_start: 2.5e-06
20 | lr_max: 0.0001
21 | lr_min: 1.0e-08
22 |
23 | transformer_config:
24 | target: imagebart.modules.transformer.mingpt.GPT
25 | params:
26 | input_vocab_size: 1014
27 | vocab_size: 1014
28 | block_size: 256
29 | n_layer: 36
30 | n_head: 16
31 | n_embd: 1216
32 |
33 | first_stage_config:
34 | target: imagebart.models.vqgan.VQGANWrapper
35 | params:
36 | ckpt_path: vqgan/vqgan-cats.ckpt
37 | remap: data/vqgan_indices/cat_indices.npy
38 | sane_index_shape: true
39 | embed_dim: 256
40 | n_embed: 16384
41 | ddconfig:
42 | double_z: false
43 | z_channels: 256
44 | resolution: 256
45 | in_channels: 3
46 | out_ch: 3
47 | ch: 128
48 | ch_mult:
49 | - 1
50 | - 1
51 | - 2
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions:
56 | - 16
57 | dropout: 0.0
58 | tanh_out: true
59 | lossconfig:
60 | target: taming.modules.losses.vqperceptual.DummyLoss
61 |
62 | data:
63 | target: main.DataModuleFromConfig
64 | params:
65 | batch_size: 16
66 | wrap: false
67 | train:
68 | target: imagebart.data.lsun.LSUNCatsTrain
69 | params:
70 | size: 256
71 | validation:
72 | target: imagebart.data.lsun.LSUNCatsValidation
73 | params:
74 | size: 256
75 |
76 |
77 | lightning:
78 | callbacks:
79 | image_logger:
80 | target: main.ImageLogger
81 | params:
82 | batch_frequency: 1000
83 | max_images: 4
84 | increase_log_steps: false
85 | trainer:
86 | benchmark: true
87 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/lsun-churches/lsun-churches-scale1.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 3
8 | single_scale: 1
9 | top_k: 1022
10 | alpha: 0.0
11 | redraw_prob: bernoulli_PSIM
12 | use_ema: true
13 | scheduler_config:
14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
15 | params:
16 | verbosity_interval: 0
17 | warm_up_steps: 10000
18 | max_decay_steps: 1300001
19 | lr_start: 2.5e-06
20 | lr_max: 0.0001
21 | lr_min: 1.0e-08
22 | transformer_config:
23 | target: imagebart.modules.xtransformers.x_transformer.XTransformer
24 | params:
25 | wrap_decoder: false
26 | dim: 1152
27 | enc_num_tokens: 1022
28 | enc_depth: 32
29 | enc_heads: 16
30 | enc_max_seq_len: 258
31 | dec_num_tokens: 1022
32 | dec_depth: 6
33 | dec_heads: 16
34 | tie_token_emb: false
35 | dec_max_seq_len: 256
36 | first_stage_config:
37 | target: imagebart.models.vqgan.VQGANWrapper
38 | params:
39 | ckpt_path: vqgan/vqgan-churches.ckpt
40 | remap: data/vqgan_indices/church_indices.npy
41 | sane_index_shape: true
42 | embed_dim: 256
43 | n_embed: 16384
44 | ddconfig:
45 | double_z: false
46 | z_channels: 256
47 | resolution: 256
48 | in_channels: 3
49 | out_ch: 3
50 | ch: 128
51 | ch_mult:
52 | - 1
53 | - 1
54 | - 2
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions:
59 | - 16
60 | dropout: 0.0
61 | tanh_out: true
62 | lossconfig:
63 | target: taming.modules.losses.vqperceptual.DummyLoss
64 | data:
65 | target: main.DataModuleFromConfig
66 | params:
67 | batch_size: 16
68 | wrap: false
69 | train:
70 | target: imagebart.data.lsun.LSUNChurchesTrain
71 | params:
72 | size: 256
73 | validation:
74 | target: imagebart.data.lsun.LSUNChurchesValidation
75 | params:
76 | size: 256
77 |
78 | lightning:
79 | callbacks:
80 | image_logger:
81 | target: main.ImageLogger
82 | params:
83 | batch_frequency: 1000
84 | max_images: 4
85 | increase_log_steps: false
86 | trainer:
87 | benchmark: true
88 | accumulate_grad_batches: 2
89 |
--------------------------------------------------------------------------------
/configs/lsun-churches/lsun-churches-scale2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DenoisingXTransformer
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 3
8 | single_scale: 2
9 | top_k: 1022
10 | alpha: 1.0
11 | redraw_prob: bernoulli_PSIM
12 | use_ema: true
13 | scheduler_config:
14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
15 | params:
16 | verbosity_interval: 0
17 | warm_up_steps: 10000
18 | max_decay_steps: 1300001
19 | lr_start: 2.5e-06
20 | lr_max: 0.0001
21 | lr_min: 1.0e-08
22 | transformer_config:
23 | target: imagebart.modules.xtransformers.x_transformer.XTransformer
24 | params:
25 | wrap_decoder: false
26 | dim: 1152
27 | enc_num_tokens: 1022
28 | enc_depth: 32
29 | enc_heads: 16
30 | enc_max_seq_len: 258
31 | dec_num_tokens: 1022
32 | dec_depth: 6
33 | dec_heads: 16
34 | tie_token_emb: false
35 | dec_max_seq_len: 256
36 | first_stage_config:
37 | target: imagebart.models.vqgan.VQGANWrapper
38 | params:
39 | ckpt_path: vqgan/vqgan-churches.ckpt
40 | remap: data/vqgan_indices/church_indices.npy
41 | sane_index_shape: true
42 | embed_dim: 256
43 | n_embed: 16384
44 | ddconfig:
45 | double_z: false
46 | z_channels: 256
47 | resolution: 256
48 | in_channels: 3
49 | out_ch: 3
50 | ch: 128
51 | ch_mult:
52 | - 1
53 | - 1
54 | - 2
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions:
59 | - 16
60 | dropout: 0.0
61 | tanh_out: true
62 | lossconfig:
63 | target: taming.modules.losses.vqperceptual.DummyLoss
64 | data:
65 | target: main.DataModuleFromConfig
66 | params:
67 | batch_size: 16
68 | wrap: false
69 | train:
70 | target: imagebart.data.lsun.LSUNChurchesTrain
71 | params:
72 | size: 256
73 | validation:
74 | target: imagebart.data.lsun.LSUNChurchesValidation
75 | params:
76 | size: 256
77 |
78 | lightning:
79 | callbacks:
80 | image_logger:
81 | target: main.ImageLogger
82 | params:
83 | batch_frequency: 1000
84 | max_images: 4
85 | increase_log_steps: false
86 | trainer:
87 | benchmark: true
88 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/lsun-churches/lsun-churches-scale3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: imagebart.models.diffusion.DecoderOnlyDenoiser
4 | params:
5 | first_stage_key: image
6 | monitor: val/loss
7 | n_scales: 3
8 | single_scale: 3
9 | top_k: 1022
10 | alpha: 1.0
11 | redraw_prob: bernoulli_PSIM
12 | use_ema: true
13 | scheduler_config:
14 | target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
15 | params:
16 | verbosity_interval: 0
17 | warm_up_steps: 10000
18 | max_decay_steps: 1300001
19 | lr_start: 2.5e-06
20 | lr_max: 0.0001
21 | lr_min: 1.0e-08
22 | transformer_config:
23 | target: imagebart.modules.transformer.mingpt.GPT
24 | params:
25 | input_vocab_size: 1022
26 | vocab_size: 1022
27 | block_size: 256
28 | n_layer: 36
29 | n_head: 16
30 | n_embd: 1216
31 | first_stage_config:
32 | target: imagebart.models.vqgan.VQGANWrapper
33 | params:
34 | ckpt_path: vqgan/vqgan-churches.ckpt
35 | remap: data/vqgan_indices/church_indices.npy
36 | sane_index_shape: true
37 | embed_dim: 256
38 | n_embed: 16384
39 | ddconfig:
40 | double_z: false
41 | z_channels: 256
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 1
49 | - 2
50 | - 2
51 | - 4
52 | num_res_blocks: 2
53 | attn_resolutions:
54 | - 16
55 | dropout: 0.0
56 | tanh_out: true
57 | lossconfig:
58 | target: taming.modules.losses.vqperceptual.DummyLoss
59 |
60 | data:
61 | target: main.DataModuleFromConfig
62 | params:
63 | batch_size: 16
64 | wrap: false
65 | train:
66 | target: imagebart.data.lsun.LSUNChurchesTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: imagebart.data.lsun.LSUNChurchesValidation
71 | params:
72 | size: 256
73 |
74 | lightning:
75 | callbacks:
76 | image_logger:
77 | target: main.ImageLogger
78 | params:
79 | batch_frequency: 1000
80 | max_images: 4
81 | increase_log_steps: false
82 | trainer:
83 | benchmark: true
84 | accumulate_grad_batches: 2
--------------------------------------------------------------------------------
/configs/sampling/ffhq/ffhq_2_scales_custom.yaml:
--------------------------------------------------------------------------------
1 | checkpoints:
2 | - "models/ffhq_2_scales_custom/scale1.ckpt"
3 | - "models/ffhq_2_scales_custom/scale2.ckpt"
4 |
5 | configs:
6 | - "configs/ffhq/2_scales/ffhq-custom-scale1.yaml"
7 | - "configs/ffhq/2_scales/ffhq-custom-scale2.yaml"
--------------------------------------------------------------------------------
/configs/sampling/ffhq/ffhq_4_scales_geometric.yaml:
--------------------------------------------------------------------------------
1 | checkpoints:
2 | - "models/ffhq_4_scales_geometric/scale1.ckpt"
3 | - "models/ffhq_4_scales_geometric/scale2.ckpt"
4 | - "models/ffhq_4_scales_geometric/scale3.ckpt"
5 | - "models/ffhq_4_scales_geometric/scale4.ckpt"
6 |
7 | configs:
8 | - "configs/ffhq/4_scales/ffhq-geometric-scale1.yaml"
9 | - "configs/ffhq/4_scales/ffhq-geometric-scale2.yaml"
10 | - "configs/ffhq/4_scales/ffhq-geometric-scale3.yaml"
11 | - "configs/ffhq/4_scales/ffhq-geometric-scale4.yaml"
12 |
--------------------------------------------------------------------------------
/configs/sampling/imagenet/imagenet_4_scales_geometric.yaml:
--------------------------------------------------------------------------------
1 | checkpoints:
2 | - "models/cin_4_scales_geometric/scale1.ckpt"
3 | - "models/cin_4_scales_geometric/scale2.ckpt"
4 | - "models/cin_4_scales_geometric/scale3.ckpt"
5 | - "models/cin_4_scales_geometric/scale4.ckpt"
6 |
7 | configs:
8 | - "configs/imagenet/4_scales/imagenet_geometric_scale1.yaml"
9 | - "configs/imagenet/4_scales/imagenet_geometric_scale2.yaml"
10 | - "configs/imagenet/4_scales/imagenet_geometric_scale3.yaml"
11 | - "configs/imagenet/4_scales/imagenet_geometric_scale4.yaml"
12 |
--------------------------------------------------------------------------------
/configs/sampling/imagenet/imagenet_5_scales_custom.yaml:
--------------------------------------------------------------------------------
1 | checkpoints:
2 | - "models/cin_5_scales_custom/scale1.ckpt"
3 | - "models/cin_5_scales_custom/scale2.ckpt"
4 | - "models/cin_5_scales_custom/scale3.ckpt"
5 | - "models/cin_5_scales_custom/scale4.ckpt"
6 | - "models/cin_5_scales_custom/scale5.ckpt"
7 |
8 | configs:
9 | - "configs/imagenet/5_scales/imagenet-custom-scale1.yaml"
10 | - "configs/imagenet/5_scales/imagenet-custom-scale2.yaml"
11 | - "configs/imagenet/5_scales/imagenet-custom-scale3.yaml"
12 | - "configs/imagenet/5_scales/imagenet-custom-scale4.yaml"
13 | - "configs/imagenet/5_scales/imagenet-custom-scale5.yaml"
14 |
15 |
--------------------------------------------------------------------------------
/configs/sampling/lsun/beds_3_scales.yaml:
--------------------------------------------------------------------------------
1 | checkpoints:
2 | - "models/bedrooms_3_scales/scale1.ckpt"
3 | - "models/bedrooms_3_scales/scale2.ckpt"
4 | - "models/bedrooms_3_scales/scale3.ckpt"
5 |
6 | configs:
7 | - "configs/lsun-beds/lsun-beds-scale1.yaml"
8 | - "configs/lsun-beds/lsun-beds-scale2.yaml"
9 | - "configs/lsun-beds/lsun-beds-scale3.yaml"
10 |
--------------------------------------------------------------------------------
/configs/sampling/lsun/cats_3_scales.yaml:
--------------------------------------------------------------------------------
1 | checkpoints:
2 | - "models/cats_3_scales/scale1.ckpt"
3 | - "models/cats_3_scales/scale2.ckpt"
4 | - "models/cats_3_scales/scale3.ckpt"
5 |
6 | configs:
7 | - "configs/lsun-cats/lsun-cats-scale1.yaml"
8 | - "configs/lsun-cats/lsun-cats-scale2.yaml"
9 | - "configs/lsun-cats/lsun-cats-scale3.yaml"
10 |
--------------------------------------------------------------------------------
/configs/sampling/lsun/churches_3_scales.yaml:
--------------------------------------------------------------------------------
1 | checkpoints:
2 | - "models/churches_3_scales/scale1.ckpt"
3 | - "models/churches_3_scales/scale2.ckpt"
4 | - "models/churches_3_scales/scale3.ckpt"
5 |
6 | configs:
7 | - "configs/lsun-churches/lsun-churches-scale1.yaml"
8 | - "configs/lsun-churches/lsun-churches-scale2.yaml"
9 | - "configs/lsun-churches/lsun-churches-scale3.yaml"
10 |
--------------------------------------------------------------------------------
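Each sampling config above is nothing more than an ordered pairing of per-scale checkpoints with the training configs they belong to (scale 1 first). A minimal sketch of reading such a file, assuming only the `omegaconf` dependency pinned in `environment.yaml` (this is an illustration, not the repository's own sampling script):

```
# Illustration only: pair each scale's checkpoint with its training config,
# assuming the structure of configs/sampling/lsun/churches_3_scales.yaml above.
from omegaconf import OmegaConf

sampling_cfg = OmegaConf.load("configs/sampling/lsun/churches_3_scales.yaml")
for ckpt, cfg_path in zip(sampling_cfg.checkpoints, sampling_cfg.configs):
    scale_cfg = OmegaConf.load(cfg_path)
    print(ckpt, "->", scale_cfg.model.target)
```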
/data/DejaVuSans.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/DejaVuSans.ttf
--------------------------------------------------------------------------------
/data/ffhq_schedule_vs_metric.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/ffhq_schedule_vs_metric.p
--------------------------------------------------------------------------------
/data/in_schedule_vs_metric.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/in_schedule_vs_metric.p
--------------------------------------------------------------------------------
/data/vqgan_indices/bedroom_indices.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/vqgan_indices/bedroom_indices.npy
--------------------------------------------------------------------------------
/data/vqgan_indices/cat_indices.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/vqgan_indices/cat_indices.npy
--------------------------------------------------------------------------------
/data/vqgan_indices/church_indices.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/vqgan_indices/church_indices.npy
--------------------------------------------------------------------------------
/data/vqgan_indices/ffhq_indices.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/vqgan_indices/ffhq_indices.npy
--------------------------------------------------------------------------------
/data/vqgan_indices/imagenet_indices.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/data/vqgan_indices/imagenet_indices.npy
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: imagebart
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.0
9 | - pytorch=1.7.0
10 | - torchvision=0.8.1
11 | - numpy=1.19.2
12 | - pip:
13 | - albumentations==0.4.3
14 | - opencv-python==4.1.2.30
15 | - pudb==2019.2
16 | - imageio==2.9.0
17 | - imageio-ffmpeg==0.4.2
18 | - pytorch-lightning==1.4.2
19 | - omegaconf==2.1.1
20 | - test-tube>=0.7.5
21 | - streamlit>=0.73.1
22 | - einops==0.3.0
23 | - transformers==4.3.1
24 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
25 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip
26 | - -e .
27 |
--------------------------------------------------------------------------------
/imagebart/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/imagebart/__init__.py
--------------------------------------------------------------------------------
/imagebart/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/imagebart/data/__init__.py
--------------------------------------------------------------------------------
/imagebart/data/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
5 | import albumentations
6 | import bisect
7 | from abc import abstractmethod
8 |
9 | class Txt2ImgIterableBaseDataset(IterableDataset):
10 | '''
11 | Define an interface to make the IterableDatasets for text2img data chainable
12 | '''
13 | def __init__(self, num_records=0, valid_ids=None, size=256):
14 | super().__init__()
15 | self.num_records = num_records
16 | self.valid_ids = valid_ids
17 | self.sample_ids = valid_ids
18 | self.size = size
19 |
20 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
21 |
22 | def __len__(self):
23 | return self.num_records
24 |
25 | @abstractmethod
26 | def __iter__(self):
27 | pass
--------------------------------------------------------------------------------
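`Txt2ImgIterableBaseDataset` only fixes an interface: a record count, optional valid ids, an image size, and an abstract `__iter__`. A hypothetical toy subclass, purely to illustrate how the interface is meant to be filled in (not part of the repository):

```
import numpy as np
from imagebart.data.base import Txt2ImgIterableBaseDataset


class RandomNoiseIterable(Txt2ImgIterableBaseDataset):
    """Hypothetical example dataset that yields random images in [-1, 1]."""
    def __iter__(self):
        for _ in range(self.num_records):
            yield {"image": np.random.uniform(-1.0, 1.0,
                                              (self.size, self.size, 3)).astype(np.float32)}


ds = RandomNoiseIterable(num_records=4, size=256)
for example in ds:
    assert example["image"].shape == (256, 256, 3)
```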
/imagebart/data/lsun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | class LSUNBase(Dataset):
10 | def __init__(self,
11 | txt_file,
12 | data_root,
13 | size=None,
14 | interpolation="bicubic",
15 | flip_p=0.5
16 | ):
17 | self.data_paths = txt_file
18 | self.data_root = data_root
19 | with open(self.data_paths, "r") as f:
20 | self.image_paths = f.read().splitlines()
21 | self._length = len(self.image_paths)
22 | self.labels = {
23 | "relative_file_path_": [l for l in self.image_paths],
24 | "file_path_": [os.path.join(self.data_root, l)
25 | for l in self.image_paths],
26 | }
27 |
28 | self.size = size
29 | self.interpolation = {"linear": PIL.Image.LINEAR,
30 | "bilinear": PIL.Image.BILINEAR,
31 | "bicubic": PIL.Image.BICUBIC,
32 | "lanczos": PIL.Image.LANCZOS,
33 | }[interpolation]
34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35 |
36 | def __len__(self):
37 | return self._length
38 |
39 | def __getitem__(self, i):
40 | example = dict((k, self.labels[k][i]) for k in self.labels)
41 | image = Image.open(example["file_path_"])
42 | if not image.mode == "RGB":
43 | image = image.convert("RGB")
44 |
45 | # default to score-sde preprocessing
46 | img = np.array(image).astype(np.uint8)
47 | crop = min(img.shape[0], img.shape[1])
48 |         h, w = img.shape[0], img.shape[1]
49 | img = img[(h - crop) // 2:(h + crop) // 2,
50 | (w - crop) // 2:(w + crop) // 2]
51 |
52 | image = Image.fromarray(img)
53 | if self.size is not None:
54 | image = image.resize((self.size, self.size), resample=self.interpolation)
55 |
56 | image = self.flip(image)
57 | image = np.array(image).astype(np.uint8)
58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59 | return example
60 |
61 |
62 | class LSUNChurchesTrain(LSUNBase):
63 | def __init__(self, **kwargs):
64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65 |
66 |
67 | class LSUNChurchesValidation(LSUNBase):
68 | def __init__(self, flip_p=0., **kwargs):
69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70 | flip_p=flip_p, **kwargs)
71 |
72 |
73 | class LSUNBedroomsTrain(LSUNBase):
74 | def __init__(self, **kwargs):
75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76 |
77 |
78 | class LSUNBedroomsValidation(LSUNBase):
79 | def __init__(self, flip_p=0.0, **kwargs):
80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81 | flip_p=flip_p, **kwargs)
82 |
83 |
84 | class LSUNCatsTrain(LSUNBase):
85 | def __init__(self, **kwargs):
86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87 |
88 |
89 | class LSUNCatsValidation(LSUNBase):
90 | def __init__(self, flip_p=0., **kwargs):
91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92 | flip_p=flip_p, **kwargs)
93 |
--------------------------------------------------------------------------------
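The LSUN datasets read a plain-text list of image paths, center-crop to a square, resize, optionally flip, and return channels-last `float32` images scaled to `[-1, 1]` under the key `image` (the `first_stage_key` used in the configs above). A minimal usage sketch, assuming the LSUN index files and images have been placed under `data/lsun/` as the default paths expect:

```
# Usage sketch; assumes data/lsun/church_outdoor_val.txt and the referenced
# images exist at the default locations used by the class above.
from imagebart.data.lsun import LSUNChurchesValidation

ds = LSUNChurchesValidation(size=256)
example = ds[0]
img = example["image"]
print(img.shape, img.dtype, img.min(), img.max())  # (256, 256, 3) float32, values in [-1, 1]
```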
/imagebart/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # https://pytorch-lightning-bolts.readthedocs.io/en/0.2.1/api/pl_bolts.optimizers.lr_scheduler.html
2 |
3 | import math
4 | import warnings
5 | from typing import List
6 |
7 | import torch.nn as nn
8 | from torch.optim import Optimizer, Adam
9 | from torch.optim.lr_scheduler import _LRScheduler
10 | from torch._six import inf
11 | import numpy as np
12 |
13 |
14 | class LinearWarmupCosineAnnealingLR(_LRScheduler):
15 | """
16 | Sets the learning rate of each parameter group to follow a linear warmup schedule
17 | between warmup_start_lr and base_lr followed by a cosine annealing schedule between
18 | base_lr and eta_min.
19 | .. warning::
20 | It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
21 | after each iteration as calling it after each epoch will keep the starting lr at
22 | warmup_start_lr for the first epoch which is 0 in most cases.
23 | .. warning::
24 | passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING.
25 | It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of
26 | :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing
27 | epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling
28 | train and validation methods.
29 | Args:
30 | optimizer (Optimizer): Wrapped optimizer.
31 | warmup_epochs (int): Maximum number of iterations for linear warmup
32 | max_epochs (int): Maximum number of iterations
33 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
34 | eta_min (float): Minimum learning rate. Default: 0.
35 | last_epoch (int): The index of last epoch. Default: -1.
36 | Example:
37 | >>> layer = nn.Linear(10, 1)
38 | >>> optimizer = Adam(layer.parameters(), lr=0.02)
39 | >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
40 | >>> #
41 | >>> # the default case
42 | >>> for epoch in range(40):
43 | ... # train(...)
44 | ... # validate(...)
45 | ... scheduler.step()
46 | >>> #
47 | >>> # passing epoch param case
48 | >>> for epoch in range(40):
49 | ... scheduler.step(epoch)
50 | ... # train(...)
51 | ... # validate(...)
52 | """
53 |
54 | def __init__(
55 | self,
56 | optimizer: Optimizer,
57 | warmup_epochs: int,
58 | max_epochs: int,
59 | warmup_start_lr: float = 0.0,
60 | eta_min: float = 0.0,
61 | last_epoch: int = -1,
62 | ) -> None:
63 |
64 | self.warmup_epochs = warmup_epochs
65 | self.max_epochs = max_epochs
66 | self.warmup_start_lr = warmup_start_lr
67 | self.eta_min = eta_min
68 |
69 | super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
70 |
71 | def get_lr(self) -> List[float]:
72 | """
73 | Compute learning rate using chainable form of the scheduler
74 | """
75 | if not self._get_lr_called_within_step:
76 | warnings.warn(
77 | "To get the last learning rate computed by the scheduler, "
78 | "please use `get_last_lr()`.",
79 | UserWarning,
80 | )
81 |
82 | if self.last_epoch == 0:
83 | return [self.warmup_start_lr] * len(self.base_lrs)
84 | elif self.last_epoch < self.warmup_epochs:
85 | return [
86 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
87 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
88 | ]
89 | elif self.last_epoch == self.warmup_epochs:
90 | return self.base_lrs
91 | elif (self.last_epoch - 1 - self.max_epochs) % (
92 | 2 * (self.max_epochs - self.warmup_epochs)
93 | ) == 0:
94 | return [
95 | group["lr"] + (base_lr - self.eta_min) * (
96 | 1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))
97 | ) / 2 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
98 | ]
99 |
100 | return [(1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) /
101 | (self.max_epochs - self.warmup_epochs))) /
102 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs - 1) /
103 | (self.max_epochs - self.warmup_epochs))) *
104 | (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups
105 | ]
106 |
107 | def _get_closed_form_lr(self) -> List[float]:
108 | """
109 | Called when epoch is passed as a param to the `step` function of the scheduler.
110 | """
111 | if self.last_epoch < self.warmup_epochs:
112 | return [
113 | self.warmup_start_lr + self.last_epoch * (
114 | base_lr - self.warmup_start_lr
115 | ) / (self.warmup_epochs - 1) for base_lr in self.base_lrs
116 | ]
117 |
118 | return [
119 | self.eta_min + 0.5 * (base_lr - self.eta_min) * (
120 | 1 + math.cos(
121 | math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)
122 | )
123 | ) for base_lr in self.base_lrs
124 | ]
125 |
126 |
127 | class ReduceLROnLossPlateau:
128 |
129 | def __init__(self, lr_init, mode='min', factor=0.1, patience=10,
130 | threshold=1e-4, threshold_mode='rel', cooldown=0,
131 | min_lr=0, eps=1e-8, verbose=False):
132 | if factor >= 1.0:
133 | raise ValueError('Factor should be < 1.0.')
134 | self.factor = factor
135 |         # no optimizer is attached here; this scheduler only tracks a multiplicative lr factor
136 |
137 | self.current_factor = 1.
138 | self.min_lr = min_lr
139 | self.current_lr = lr_init
140 |
141 | self.patience = patience
142 | self.verbose = verbose
143 | self.cooldown = cooldown
144 | self.cooldown_counter = 0
145 | self.mode = mode
146 | self.threshold = threshold
147 | self.threshold_mode = threshold_mode
148 | self.best = None
149 | self.num_bad_epochs = None
150 | self.mode_worse = None # the worse value for the chosen mode
151 | self.eps = eps
152 | self.last_it = 0
153 | self._init_is_better(mode=mode, threshold=threshold,
154 | threshold_mode=threshold_mode)
155 |
156 | self._reset()
157 |
158 | def set_factor(self, f):
159 | self.current_factor = f
160 |
161 | def _reset(self):
162 | """Resets num_bad_epochs counter and cooldown counter."""
163 | self.best = self.mode_worse
164 | self.cooldown_counter = 0
165 | self.num_bad_it = 0
166 |
167 | def schedule(self, metrics):
168 | # convert `metrics` to float, in case it's a zero-dim Tensor
169 | current = float(metrics)
170 |
171 | # it = self.last_it + 1
172 |
173 | # self.it = it
174 |
175 | if self.is_better(current, self.best):
176 | self.best = current
177 | self.num_bad_it = 0
178 | else:
179 | self.num_bad_it += 1
180 |
181 | if self.in_cooldown:
182 | self.cooldown_counter -= 1
183 | self.num_bad_it = 0 # ignore any bad epochs in cooldown
184 |
185 | if self.num_bad_it > self.patience and self.current_lr >= self.min_lr:
186 | # reduce lr by factor
187 | new_f = self.current_factor * self.factor
188 | self.set_factor(new_f)
189 | self.current_lr = self.current_lr * self.current_factor
190 |
191 | self.set_factor(new_f)
192 | self.cooldown_counter = self.cooldown
193 | self.num_bad_it = 0
194 |
195 | # self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
196 | return self.current_factor
197 |
198 | @property
199 | def in_cooldown(self):
200 | return self.cooldown_counter > 0
201 |
202 | def is_better(self, a, best):
203 | if self.mode == 'min' and self.threshold_mode == 'rel':
204 | rel_epsilon = 1. - self.threshold
205 | return a < best * rel_epsilon
206 |
207 | elif self.mode == 'min' and self.threshold_mode == 'abs':
208 | return a < best - self.threshold
209 |
210 | elif self.mode == 'max' and self.threshold_mode == 'rel':
211 | rel_epsilon = self.threshold + 1.
212 | return a > best * rel_epsilon
213 |
214 | else: # mode == 'max' and epsilon_mode == 'abs':
215 | return a > best + self.threshold
216 |
217 | def _init_is_better(self, mode, threshold, threshold_mode):
218 | if mode not in {'min', 'max'}:
219 | raise ValueError('mode ' + mode + ' is unknown!')
220 | if threshold_mode not in {'rel', 'abs'}:
221 | raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
222 |
223 | if mode == 'min':
224 | self.mode_worse = inf
225 | else: # mode == 'max':
226 | self.mode_worse = -inf
227 |
228 | self.mode = mode
229 | self.threshold = threshold
230 | self.threshold_mode = threshold_mode
231 |
232 | def __call__(self, metrics):
233 | return self.schedule(metrics)
234 |
235 |
236 | class LambdaWarmUpCosineScheduler:
237 | """
238 | note: use with a base_lr of 1.0
239 | """
240 |
241 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
242 | self.lr_warm_up_steps = warm_up_steps
243 | self.lr_start = lr_start
244 | self.lr_min = lr_min
245 | self.lr_max = lr_max
246 | self.lr_max_decay_steps = max_decay_steps
247 | self.last_lr = 0.
248 | self.verbosity_interval = verbosity_interval
249 |
250 | def schedule(self, n, **kwargs):
251 | if self.verbosity_interval > 0:
252 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
253 | if n < self.lr_warm_up_steps:
254 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
255 | self.last_lr = lr
256 | return lr
257 | else:
258 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
259 | t = min(t, 1.0)
260 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
261 | 1 + np.cos(t * np.pi))
262 | self.last_lr = lr
263 | return lr
264 |
265 | def __call__(self, n, **kwargs):
266 | return self.schedule(n, **kwargs)
267 |
268 |
269 | class LambdaWarmUpCosineScheduler2:
270 | """
271 | supports repeated iterations, configurable via lists
272 | note: use with a base_lr of 1.0.
273 | """
274 |
275 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
276 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
277 | self.lr_warm_up_steps = warm_up_steps
278 | self.f_start = f_start
279 | self.f_min = f_min
280 | self.f_max = f_max
281 | self.cycle_lengths = cycle_lengths
282 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
283 | self.last_f = 0.
284 | self.verbosity_interval = verbosity_interval
285 |
286 | def find_in_interval(self, n):
287 | interval = 0
288 | for cl in self.cum_cycles[1:]:
289 | if n <= cl:
290 | return interval
291 | interval += 1
292 |
293 | def schedule(self, n, **kwargs):
294 | cycle = self.find_in_interval(n)
295 | n = n - self.cum_cycles[cycle]
296 | if self.verbosity_interval > 0:
297 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
298 | f"current cycle {cycle}")
299 | if n < self.lr_warm_up_steps[cycle]:
300 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
301 | self.last_f = f
302 | return f
303 | else:
304 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
305 | t = min(t, 1.0)
306 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
307 | 1 + np.cos(t * np.pi))
308 | self.last_f = f
309 | return f
310 |
311 | def __call__(self, n, **kwargs):
312 | return self.schedule(n, **kwargs)
313 |
314 |
315 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
316 |
317 | def schedule(self, n, **kwargs):
318 | cycle = self.find_in_interval(n)
319 | n = n - self.cum_cycles[cycle]
320 | if self.verbosity_interval > 0:
321 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
322 | f"current cycle {cycle}")
323 |
324 | if n < self.lr_warm_up_steps[cycle]:
325 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
326 | self.last_f = f
327 | return f
328 | else:
329 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (
330 | self.cycle_lengths[cycle])
331 | self.last_f = f
332 | return f
333 |
334 |
335 | if __name__ == "__main__":
336 | from tqdm import trange
337 | import matplotlib.pyplot as plt
338 |
339 | warm_up_steps = [1000, 500, 500]
340 | f_min = [1., 0., 0.]
341 | f_max = [10, 7.5, 5.]
342 | f_start = [0., 1., 0.]
343 | cycle_lengths = [10000, 5000, 4000]
344 | scheduler = LambdaWarmUpCosineScheduler2(warm_up_steps=warm_up_steps, f_min=f_min, f_max=f_max, f_start=f_start,
345 | cycle_lengths=cycle_lengths, verbosity_interval=100)
346 |
347 | schedule = []
348 | for n in trange(int(sum(cycle_lengths)), desc="Iter"):
349 | schedule.append(scheduler(n))
350 |
351 | plt.figure()
352 | plt.plot(schedule)
353 | plt.xlabel("global step")
354 | plt.savefig("scheduler_test.png")
355 | print("done.")
356 |
357 | """
358 | layer = nn.Linear(10, 1)
359 | optimizer = Adam(layer.parameters(), lr=0.02)
360 | scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
361 | #
362 | # the default case
363 | for epoch in range(40):
364 | # train(...)
365 | # validate(...)
366 | scheduler.step()
367 | #
368 | # passing epoch param case
369 | for epoch in range(40):
370 | scheduler.step(epoch)
371 | # train(...)
372 | # validate(...)
373 | """
374 |
--------------------------------------------------------------------------------
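`LambdaWarmUpCosineScheduler` is the schedule referenced by the `scheduler_config` blocks in the training configs above. It returns a learning-rate value per global step rather than touching an optimizer itself, hence the "use with a base_lr of 1.0" note. Below is a sketch of one way to wire it into a torch optimizer via `LambdaLR`; the repository's actual wiring lives in `imagebart.models.diffusion`, which is not reproduced here, so treat this as an assumption:

```
import torch
from torch.optim.lr_scheduler import LambdaLR
from imagebart.lr_scheduler import LambdaWarmUpCosineScheduler

model = torch.nn.Linear(8, 8)
# base lr of 1.0, so the scheduler's output acts as the effective learning rate
opt = torch.optim.AdamW(model.parameters(), lr=1.0)

schedule = LambdaWarmUpCosineScheduler(warm_up_steps=10000, lr_min=1.0e-8, lr_max=1.0e-4,
                                       lr_start=2.5e-6, max_decay_steps=5000000)
sched = LambdaLR(opt, lr_lambda=schedule.schedule)

for step in range(3):   # stepped per training iteration, not per epoch
    opt.step()
    sched.step()
```

The parameter values mirror the `scheduler_config` entries of the LSUN configs shown above.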
/imagebart/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/imagebart/models/__init__.py
--------------------------------------------------------------------------------
/imagebart/models/vqgan.py:
--------------------------------------------------------------------------------
1 | from taming.models.vqgan import VQModel
2 |
3 |
4 | class VQGANWrapper(VQModel):
5 |
6 | def __init__(self,embed_dim,*args,**kwargs):
7 | super().__init__(embed_dim=embed_dim,*args,**kwargs)
8 | self.embed_dim = embed_dim
9 |
10 | def encode_to_prequant(self, x):
11 | h = self.encoder(x)
12 | h = self.quant_conv(h)
13 | return h
--------------------------------------------------------------------------------
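`VQGANWrapper` is a thin shim over taming's `VQModel` that additionally exposes the pre-quantization encoder features. It is worth noting how the `ddconfig` blocks in the configs above fix the token sequence length: each `ch_mult` entry after the first adds one 2x downsampling stage, so the back-of-the-envelope check below (illustration only, not repository code) recovers the 256-token sequences assumed by `dec_max_seq_len`:

```
# Back-of-the-envelope check: ch_mult of length 5 gives a downsampling factor
# of 2**4 = 16, so a 256x256 image becomes a 16x16 grid of 256 VQ tokens.
resolution, ch_mult = 256, [1, 1, 2, 2, 4]
factor = 2 ** (len(ch_mult) - 1)
tokens = (resolution // factor) ** 2
print(factor, tokens)   # 16 256
```

The extra positions in `enc_max_seq_len` (257 or 258 versus 256) presumably carry conditioning tokens such as the scale index, but that logic lives in `imagebart.models.diffusion` and is not shown here.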
/imagebart/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/imagebart/modules/__init__.py
--------------------------------------------------------------------------------
/imagebart/modules/betas.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 |
4 |
5 | def find_roots(x,y):
6 | # https://stackoverflow.com/questions/46909373/how-to-find-the-exact-intersection-of-a-curve-as-np-array-with-y-0/46911822#46911822
7 | s = np.abs(np.diff(np.sign(y))).astype(bool)
8 | return x[:-1][s] + np.diff(x)[s]/(np.abs(y[1:][s]/y[:-1][s])+1)
9 |
10 |
11 | def load_betas(path, dist, metric, n_steps):
12 | with open(path, "rb") as f:
13 | data = pickle.load(f)
14 | alphacum = np.array(data[dist]["alphacum"])
15 | values = np.array(data[dist][metric])[:,0]
16 |
17 | equi_alphacum = list()
18 | for val in np.linspace(values.min(), values.max(), n_steps):
19 | equi_alphacum.append(find_roots(alphacum, values-val)[0])
20 | equi_alphacum = np.array(equi_alphacum)
21 | beta_1toT = 1-equi_alphacum[1:]/equi_alphacum[:-1]
22 | beta_0toT = np.concatenate((np.array([0.0]), beta_1toT))
23 | return beta_0toT.astype(np.float32)
24 |
25 |
26 | if __name__ == "__main__":
27 | import sys
28 | betas = load_betas(sys.argv[1], "bernoulli", "FID", int(sys.argv[2]))
29 | print(betas)
30 | print(np.cumprod(1-betas, 0))
31 |
--------------------------------------------------------------------------------
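`load_betas` converts one of the schedule-vs-metric curves shipped under `data/` into a discrete noise schedule: it picks `n_steps` points whose metric values are equally spaced, reads off the corresponding cumulative alphas, and turns their ratios into betas (with `beta_0 = 0`). A usage sketch mirroring the `__main__` block above; it assumes the pickle contains a `"bernoulli"` entry with an `"FID"` curve, which is what that block expects:

```
import numpy as np
from imagebart.modules.betas import load_betas

betas = load_betas("data/ffhq_schedule_vs_metric.p", "bernoulli", "FID", 4)
print(betas)                        # beta_0 is 0.0 by construction
print(np.cumprod(1 - betas, 0))     # same sanity check as the __main__ block
```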
/imagebart/modules/ema.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch import nn
4 |
5 |
6 | class LitEma(nn.Module):
7 | def __init__(self, model, decay=0.9999, use_num_upates=True):
8 | super().__init__()
9 | if decay < 0.0 or decay > 1.0:
10 | raise ValueError('Decay must be between 0 and 1')
11 |
12 | self.m_name2s_name = {}
13 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
14 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
15 | else torch.tensor(-1,dtype=torch.int))
16 |
17 | for name, p in model.named_parameters():
18 | if p.requires_grad:
19 |                 # remove '.' since it is not allowed in buffer names
20 | s_name = name.replace('.','')
21 | self.m_name2s_name.update({name:s_name})
22 | self.register_buffer(s_name,p.clone().detach().data)
23 |
24 | self.collected_params = []
25 |
26 | def forward(self,model):
27 | decay = self.decay
28 |
29 | if self.num_updates >= 0:
30 | self.num_updates += 1
31 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
32 |
33 | one_minus_decay = 1.0 - decay
34 |
35 | with torch.no_grad():
36 | m_param = dict(model.named_parameters())
37 | shadow_params = dict(self.named_buffers())
38 |
39 | for key in m_param:
40 | if m_param[key].requires_grad:
41 | sname = self.m_name2s_name[key]
42 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
43 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
44 | else:
45 | assert not key in self.m_name2s_name
46 |
47 | def copy_to(self, model):
48 | m_param = dict(model.named_parameters())
49 | shadow_params = dict(self.named_buffers())
50 | for key in m_param:
51 | if m_param[key].requires_grad:
52 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
53 | else:
54 | assert not key in self.m_name2s_name
55 |
56 | def store(self, parameters):
57 | """
58 | Save the current parameters for restoring later.
59 | Args:
60 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
61 | temporarily stored.
62 | """
63 | self.collected_params = [param.clone() for param in parameters]
64 |
65 | def restore(self, parameters):
66 | """
67 | Restore the parameters stored with the `store` method.
68 | Useful to validate the model with EMA parameters without affecting the
69 | original optimization process. Store the parameters before the
70 | `copy_to` method. After validation (or model saving), use this to
71 | restore the former parameters.
72 | Args:
73 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
74 | updated with the stored parameters.
75 | """
76 | for c_param, param in zip(self.collected_params, parameters):
77 | param.data.copy_(c_param.data)
78 |
--------------------------------------------------------------------------------
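`LitEma` keeps a shadow copy of every trainable parameter as registered buffers, and the `store`/`copy_to`/`restore` trio exists so that validation (or checkpointing) can temporarily run with the EMA weights, which is presumably what the `use_ema: true` flag in the configs above relies on. A sketch of the intended cycle (illustration only; the actual hooks live in the Lightning module, which is not reproduced here):

```
import torch
from imagebart.modules.ema import LitEma

model = torch.nn.Linear(4, 4)
ema = LitEma(model, decay=0.9999)

# after every optimizer step: update the shadow parameters
ema(model)

# around validation: swap the EMA weights in, then swap the live weights back
ema.store(model.parameters())     # remember the current (live) weights
ema.copy_to(model)                # evaluate with the EMA weights
# ... run validation ...
ema.restore(model.parameters())   # restore the live weights for training
```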
/imagebart/modules/transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/imagebart/modules/transformer/__init__.py
--------------------------------------------------------------------------------
/imagebart/modules/transformer/mingpt.py:
--------------------------------------------------------------------------------
1 | """
2 | source: https://github.com/karpathy/minGPT/
3 | GPT model:
4 | - the initial stem consists of a combination of token encoding and a positional encoding
5 | - the meat of it is a uniform sequence of Transformer blocks
6 | - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
7 | - all blocks feed into a central residual pathway similar to resnets
8 | - the final decoder is a linear projection into a vanilla Softmax classifier
9 | """
10 |
11 | import math
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | from torch.nn import functional as F
16 | from transformers import top_k_top_p_filtering
17 |
18 | from imagebart.modules.xtransformers.positional_embeddings import apply_rotary_pos_emb
19 |
20 |
21 | def layers_from_width(d, a=5.039, b=5.55e-2):
22 | """as in https://arxiv.org/pdf/2006.12467.pdf"""
23 | return (math.log(d)-a)/b
24 |
25 |
26 | class GPTConfig:
27 | """ base GPT config, params common to all GPT versions """
28 | embd_pdrop = 0.1
29 | resid_pdrop = 0.1
30 | attn_pdrop = 0.1
31 |
32 | def __init__(self, vocab_size, block_size, **kwargs):
33 | self.vocab_size = vocab_size
34 | self.block_size = block_size
35 | for k,v in kwargs.items():
36 | setattr(self, k, v)
37 |
38 |
39 | class CausalSelfAttention(nn.Module):
40 | """
41 | A vanilla multi-head masked self-attention layer with a projection at the end.
42 | It is possible to use torch.nn.MultiheadAttention here but I am including an
43 | explicit implementation here to show that there is nothing too scary here.
44 | """
45 |
46 | def __init__(self, config):
47 | super().__init__()
48 | assert config.n_embd % config.n_head == 0, f"n_embd is {config.n_embd} but n_head is {config.n_head}."
49 | # key, query, value projections for all heads
50 | self.key = nn.Linear(config.n_embd, config.n_embd)
51 | self.query = nn.Linear(config.n_embd, config.n_embd)
52 | self.value = nn.Linear(config.n_embd, config.n_embd)
53 | # regularization
54 | self.attn_drop = nn.Dropout(config.attn_pdrop)
55 | self.resid_drop = nn.Dropout(config.resid_pdrop)
56 | # output projection
57 | self.proj = nn.Linear(config.n_embd, config.n_embd)
58 | # causal mask to ensure that attention is only applied to the left in the input sequence
59 | mask = torch.tril(torch.ones(config.block_size,
60 | config.block_size))
61 | if hasattr(config, "n_unmasked"):
62 | mask[:config.n_unmasked, :config.n_unmasked] = 1
63 | self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
64 | self.n_head = config.n_head
65 |
66 | def forward(self, x, layer_past=None, rotary_pos_emb=None):
67 | B, T, C = x.size()
68 |
69 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
70 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
71 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
72 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
73 |
74 | if rotary_pos_emb is not None:
75 | l = rotary_pos_emb.shape[-1]
76 | (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
77 | ql, kl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl))
78 | q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
79 |
80 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
81 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
82 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
83 | att = F.softmax(att, dim=-1)
84 | att = self.attn_drop(att)
85 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
86 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
87 |
88 | # output projection
89 | y = self.resid_drop(self.proj(y))
90 | return y
91 |
92 | def forward_with_past(self, x, layer_past=None, rotary_pos_emb=None):
93 | assert rotary_pos_emb is None, 'just for debugging'
94 | B, T, C = x.size()
95 |
96 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
97 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
98 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
99 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
100 |
101 | if rotary_pos_emb is not None:
102 | l = rotary_pos_emb.shape[-1]
103 | (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
104 | ql, kl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl))
105 | q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
106 |
107 | present = torch.stack((k, v))
108 | #present = torch.stack((k.clone(), v.clone())) # (2, B, nh, 1, hs)
109 |
110 | if layer_past is not None:
111 | past_key, past_value = layer_past
112 | k = torch.cat((past_key, k), dim=-2)
113 | v = torch.cat((past_value, v), dim=-2)
114 |
115 | # causal self-attention; Self-attend: (B, nh, Tq, hs) x (B, nh, hs, Tk) -> (B, nh, Tq, Tk)
116 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
117 | if layer_past is None:
118 |             att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))  # causal mask, as in forward(); only needed when no past is cached
119 |
120 | att = F.softmax(att, dim=-1)
121 | att = self.attn_drop(att)
122 | y = att @ v # (B, nh, Tq, Tk) x (B, nh, Tk, hs) -> (B, nh, Tq, hs)
123 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
124 |
125 | # output projection
126 | y = self.resid_drop(self.proj(y))
127 | return y, present
128 |
129 |
130 | class CausalCrossAttention(nn.Module):
131 | """
132 | A vanilla multi-head masked self-attention layer with a projection at the end.
133 | It is possible to use torch.nn.MultiheadAttention here but I am including an
134 | explicit implementation here to show that there is nothing too scary here.
135 | """
136 |
137 | def __init__(self, config):
138 | super().__init__()
139 | assert config.n_embd % config.n_head == 0, f"n_embd is {config.n_embd} but n_head is {config.n_head}."
140 | # key, query, value projections for all heads
141 | self.key = nn.Linear(config.n_embd, config.n_embd)
142 | self.query = nn.Linear(config.n_embd, config.n_embd)
143 | self.value = nn.Linear(config.n_embd, config.n_embd)
144 | # regularization
145 | self.attn_drop = nn.Dropout(config.attn_pdrop)
146 | self.resid_drop = nn.Dropout(config.resid_pdrop)
147 | # output projection
148 | self.proj = nn.Linear(config.n_embd, config.n_embd)
149 | # causal mask to ensure that attention is only applied to the left in the input sequence
150 | block_size = config.block_size
151 | cond_length = config.n_unmasked
152 | data_length = block_size-cond_length+1
153 | mask = np.zeros((data_length, block_size), dtype=np.float32)
154 | mask[:,:cond_length]=1 # make conditioning visible
155 | submask=np.tril(np.ones((data_length-1,data_length-1), dtype=np.float32)) # causal submask
156 | mask[1:,cond_length:] = submask
157 | mask = torch.tensor(mask)
158 |
159 | self.register_buffer("mask", mask.view(1, 1, data_length, block_size))
160 | self.n_head = config.n_head
161 |
162 | def forward(self, x_q, x_kv, layer_past=None, rotary_pos_emb=None):
163 | B, T_q, C = x_q.size()
164 | _B, T_kv, _C = x_kv.size()
165 | assert B==_B and C==_C
166 |
167 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
168 | q = self.query(x_q).view(B, T_q, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
169 | k = self.key(x_kv).view(B, T_kv, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
170 | v = self.value(x_kv).view(B, T_kv, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
171 |
172 | if rotary_pos_emb is not None:
173 | l = rotary_pos_emb.shape[-1]
174 | (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
175 | ql, kl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl))
176 | q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
177 |
178 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
179 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
180 | att = att.masked_fill(self.mask[:,:,:T_q,:T_kv] == 0, float('-inf'))
181 | att = F.softmax(att, dim=-1)
182 | att = self.attn_drop(att)
183 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
184 | y = y.transpose(1, 2).contiguous().view(B, T_q, C) # re-assemble all head outputs side by side
185 |
186 | # output projection
187 | y = self.resid_drop(self.proj(y))
188 | return y
189 |
190 | def forward_with_past(self, x, layer_past=None, rotary_pos_emb=None):
191 | raise NotImplementedError(":(")
192 |
193 |
194 | class Block(nn.Module):
195 | """ an unassuming Transformer block """
196 | def __init__(self, config):
197 | super().__init__()
198 | self.ln1 = nn.LayerNorm(config.n_embd)
199 | self.ln2 = nn.LayerNorm(config.n_embd)
200 | self.attn = CausalSelfAttention(config)
201 | self.mlp = nn.Sequential(
202 | nn.Linear(config.n_embd, 4 * config.n_embd),
203 | nn.GELU(), # nice
204 | nn.Linear(4 * config.n_embd, config.n_embd),
205 | nn.Dropout(config.resid_pdrop),
206 | )
207 |
208 | def forward(self, x, rotary_pos_emb=None):
209 | x = x + self.attn(self.ln1(x), rotary_pos_emb=rotary_pos_emb)
210 | x = x + self.mlp(self.ln2(x))
211 | return x
212 |
213 | def forward_with_past(self, x, rotary_pos_emb=None, layer_past=None):
214 | assert rotary_pos_emb is None, 'just for debugging'
215 | attn, present = self.attn.forward_with_past(self.ln1(x), rotary_pos_emb=rotary_pos_emb, layer_past=layer_past)
216 | # layer past: tuple of length two with B, nh, T, hs
217 | x = x + attn
218 | x = x + self.mlp(self.ln2(x))
219 | return x, present
220 |
221 |
222 | class CrossBlock(nn.Module):
223 | """ an unassuming Transformer block """
224 | def __init__(self, config):
225 | super().__init__()
226 | self.lnq = nn.LayerNorm(config.n_embd)
227 | self.ln1 = nn.LayerNorm(config.n_embd)
228 | self.ln2 = nn.LayerNorm(config.n_embd)
229 | self.attn = CausalCrossAttention(config)
230 | self.mlp = nn.Sequential(
231 | nn.Linear(config.n_embd, 4 * config.n_embd),
232 | nn.GELU(), # nice
233 | nn.Linear(4 * config.n_embd, config.n_embd),
234 | nn.Dropout(config.resid_pdrop),
235 | )
236 |
237 | def forward(self, x_q, x_kv, rotary_pos_emb=None):
238 | x = x_q + self.attn(x_q=self.lnq(x_q),
239 | x_kv=self.ln1(x_kv),
240 | rotary_pos_emb=rotary_pos_emb)
241 | x = x + self.mlp(self.ln2(x))
242 | return x
243 |
244 | def forward_with_past(self, x, rotary_pos_emb=None, layer_past=None):
245 | raise NotImplementedError()
246 |
247 |
248 | class GPT(nn.Module):
249 | """ the full GPT language model, with a context size of block_size """
250 | def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
251 | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0,
252 | input_vocab_size=None, autoset_layers=False):
253 | super().__init__()
254 | recommended_layers = int(np.around(layers_from_width(n_embd)))
255 | if autoset_layers:
256 | n_layer = recommended_layers
257 | print(f"Training with a width of n_embed = {n_embd} and L = {n_layer} layers. "
258 | f"https://arxiv.org/pdf/2006.12467.pdf suggest that one should use {recommended_layers} layers.")
259 | config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
260 | embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
261 | n_layer=n_layer, n_head=n_head, n_embd=n_embd,
262 | n_unmasked=n_unmasked)
263 |
264 | # input embedding stem
265 | in_vocab_size = vocab_size if not input_vocab_size else input_vocab_size
266 | self.tok_emb = nn.Embedding(in_vocab_size, config.n_embd)
267 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
268 | self.drop = nn.Dropout(config.embd_pdrop)
269 | # transformer
270 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
271 | # decoder head
272 | self.ln_f = nn.LayerNorm(config.n_embd)
273 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
274 | self.block_size = config.block_size
275 | self.apply(self._init_weights)
276 | self.config = config
277 |
278 | def get_block_size(self):
279 | return self.block_size
280 |
281 | def _init_weights(self, module):
282 | if isinstance(module, (nn.Linear, nn.Embedding)):
283 | module.weight.data.normal_(mean=0.0, std=0.02)
284 | if isinstance(module, nn.Linear) and module.bias is not None:
285 | module.bias.data.zero_()
286 | elif isinstance(module, nn.LayerNorm):
287 | module.bias.data.zero_()
288 | module.weight.data.fill_(1.0)
289 |
290 | def forward(self, idx, embeddings=None, targets=None, return_layers=False, token_embeddings=None):
291 | # forward the GPT model
292 | if token_embeddings is None:
293 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
294 | if embeddings is not None: # prepend explicit embeddings
295 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
296 |
297 | t = token_embeddings.shape[1]
298 | assert t <= self.block_size, "Cannot forward, model block size is exhausted."
299 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
300 | x = self.drop(token_embeddings + position_embeddings)
301 |
302 | if return_layers:
303 | layers = [x]
304 | for block in self.blocks:
305 | x = block(x)
306 | layers.append(x)
307 | return layers
308 |
309 | x = self.blocks(x)
310 | x = self.ln_f(x)
311 | logits = self.head(x)
312 |
313 | # if we are given some desired targets also calculate the loss
314 | loss = None
315 | if targets is not None:
316 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
317 | return logits, loss
318 |
319 | def forward_with_past(self, idx, embeddings=None, targets=None, token_embeddings=None,
320 | past=None, past_length=None):
321 |
322 | if token_embeddings is None:
323 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
324 | if embeddings is not None: # prepend explicit embeddings
325 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
326 | assert not self.training
327 | if past is not None:
328 | assert past_length is not None
329 | past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head
330 | past_shape = list(past.shape)
331 | expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
332 | assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
333 | position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector
334 | else:
335 | position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
336 |
337 | x = self.drop(token_embeddings + position_embeddings)
338 |
339 | presents = [] # accumulate over layers
340 | for i, block in enumerate(self.blocks):
341 | x, present = block.forward_with_past(x, layer_past=past[i, ...] if past is not None else None) # take from layer
342 | presents.append(present)
343 |
344 | x = self.ln_f(x)
345 | logits = self.head(x)
346 | # if we are given some desired targets also calculate the loss
347 | loss = None
348 | if targets is not None:
349 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
350 |
351 | return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head
352 |
353 |
354 | def top_k_logits(logits, k):
355 | v, ix = torch.topk(logits, k)
356 | out = logits.clone()
357 | out[out < v[:, [-1]]] = -float('Inf')
358 | return out
359 |
360 | @torch.no_grad()
361 | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
362 | """
363 |     Take a conditioning sequence of indices x (of shape (b, t)) and predict the next token in
364 |     the sequence, feeding the predictions back into the model each time. Sampling therefore has
365 |     quadratic complexity, unlike an RNN, which is linear, and uses a finite context window
366 |     of block_size, unlike an RNN, whose context is unbounded.
367 | """
368 | block_size = model.get_block_size()
369 | model.eval()
370 | for k in range(steps):
371 | x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
372 | logits, _ = model(x_cond)
373 | # pluck the logits at the final step and scale by temperature
374 | logits = logits[:, -1, :] / temperature
375 | # optionally crop probabilities to only the top k options
376 | if top_k is not None:
377 | logits = top_k_logits(logits, top_k)
378 | # apply softmax to convert to probabilities
379 | probs = F.softmax(logits, dim=-1)
380 | # sample from the distribution or take the most likely
381 | if sample:
382 | ix = torch.multinomial(probs, num_samples=1)
383 | else:
384 | _, ix = torch.topk(probs, k=1, dim=-1)
385 | # append to the sequence and continue
386 | x = torch.cat((x, ix), dim=1)
387 |
388 | return x
389 |
390 |
391 | @torch.no_grad()
392 | def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
393 | top_k=None, callback=None, guide=None, top_p=None,
394 | embeddings=None):
395 | # x is conditioning
396 | sample = x
397 | cond_len = x.shape[1]
398 | if embeddings is not None:
399 | cond_len += embeddings.shape[1]
400 | past = None
401 | for n in range(steps):
402 | if callback is not None:
403 | callback(n)
404 | logits, _, present = model.forward_with_past(x, embeddings=embeddings, past=past, past_length=(n+cond_len-1))
405 | embeddings = None # only pass in first time
406 | if past is None:
407 | past = [present]
408 | else:
409 | past.append(present)
410 | if guide is not None:
411 | logits = logits + guide[:, [n]]
412 | logits = logits[:, -1, :] / temperature
413 | if top_k is not None:
414 | if top_p is not None:
415 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
416 | else:
417 | logits = top_k_logits(logits, top_k)
418 |
419 | probs = F.softmax(logits, dim=-1)
420 | if not sample_logits:
421 | _, x = torch.topk(probs, k=1, dim=-1)
422 | else:
423 | x = torch.multinomial(probs, num_samples=1)
424 | # append to the sequence and continue
425 | sample = torch.cat((sample, x), dim=1)
426 | del past
427 | sample = sample[:, -steps:] # cut conditioning off
428 | return sample
429 |
430 |
431 | @torch.no_grad()
432 | def sample_vanilla(x, model, steps, temperature=1.0, sample_logits=False,
433 | top_k=None, embeddings=None):
434 | block_size = model.get_block_size()
435 | model.eval()
436 | for k in range(steps):
437 | x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
438 | logits, _ = model(x_cond, embeddings=embeddings)
439 | # pluck the logits at the final step and scale by temperature
440 | logits = logits[:, -1, :] / temperature
441 | # optionally crop probabilities to only the top k options
442 | if top_k is not None:
443 | logits = top_k_logits(logits, top_k)
444 | # apply softmax to convert to probabilities
445 | probs = F.softmax(logits, dim=-1)
446 | # sample from the distribution or take the most likely
447 | if sample_logits:
448 | ix = torch.multinomial(probs, num_samples=1)
449 | else:
450 | _, ix = torch.topk(probs, k=1, dim=-1)
451 | # append to the sequence and continue
452 | x = torch.cat((x, ix), dim=1)
453 | x = x[:, -steps:]
454 | return x
455 |
456 |
457 | def seed(seed=42):
458 | np.random.seed(seed)
459 | torch.manual_seed(seed)
460 | torch.cuda.manual_seed(seed)
461 | torch.cuda.manual_seed_all(seed)
462 | torch.backends.cudnn.deterministic = True
463 | torch.backends.cudnn.benchmark = False
464 |
465 |
466 | def test1():
467 | import time
468 | #torch.use_deterministic_algorithms(True)
469 | SEED = 142
470 | h = w = 4
471 | b = 1
472 | clen = 2
473 | cb = lambda n: print(n)
474 | cb = None
475 | elen = 3
476 |
477 | SAMPLE = True
478 | DEVICE = "cpu"
479 |
480 | # test past
481 | device = torch.device(DEVICE)
482 | model = GPT(vocab_size=1024, block_size=h*w+(elen+clen-1), n_embd=32).to(device).eval()
483 | x = torch.randint(0, 1024, size=(b, clen)).to(device) # start
484 | emb = torch.randn(b, elen, 32)
485 |
486 | print(f"in goes: {x}")
487 |
488 | # with past
489 | seed(SEED)
490 | t0 = time.time()
491 | s0 = sample_with_past(x, model, embeddings=emb, steps=h * w, sample_logits=SAMPLE, callback=cb)
492 | t1 = time.time()
493 |
494 | # without past
495 | seed(SEED)
496 | s1 = sample_vanilla(x, model, embeddings=emb, steps=h*w, sample_logits=SAMPLE)
497 | t2 = time.time()
498 |
499 | print(f"s0 (with past): time = {t1-t0:.2f}s")
500 | print(s0)
501 | print(f"s1 (no past): time = {t2-t1:.2f}s")
502 | print(s1)
503 | print("are equal:", torch.equal(s0, s1))
504 | print("done.")
505 |
506 |
507 | if __name__ == "__main__":
508 | test1()
509 |
--------------------------------------------------------------------------------
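A minimal usage sketch of the cached sampler defined above (sizes are illustrative; the file's `test1` performs the same sampling and additionally checks agreement with `sample_vanilla`):

```
import torch
from imagebart.modules.transformer.mingpt import GPT, sample_with_past

# Illustrative sizes: one conditioning token followed by 16 sampled tokens.
model = GPT(vocab_size=1024, block_size=17, n_embd=32).eval()
cond = torch.randint(0, 1024, size=(1, 1))       # (b, t) conditioning indices
out = sample_with_past(cond, model, steps=16)    # (1, 16); the conditioning is cut off
```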
/imagebart/modules/transformer/vit.py:
--------------------------------------------------------------------------------
1 | # source: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_pytorch.py
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from einops import rearrange, repeat
6 | from torch import nn
7 |
8 | MIN_NUM_PATCHES = 16
9 |
10 |
11 | class Residual(nn.Module):
12 | def __init__(self, fn):
13 | super().__init__()
14 | self.fn = fn
15 | def forward(self, x, **kwargs):
16 | return self.fn(x, **kwargs) + x
17 |
18 |
19 | class PreNorm(nn.Module):
20 | def __init__(self, dim, fn):
21 | super().__init__()
22 | self.norm = nn.LayerNorm(dim)
23 | self.fn = fn
24 | def forward(self, x, **kwargs):
25 | return self.fn(self.norm(x), **kwargs)
26 |
27 |
28 | class FeedForward(nn.Module):
29 | def __init__(self, dim, hidden_dim, dropout = 0.):
30 | super().__init__()
31 | self.net = nn.Sequential(
32 | nn.Linear(dim, hidden_dim),
33 | nn.GELU(),
34 | nn.Dropout(dropout),
35 | nn.Linear(hidden_dim, dim),
36 | nn.Dropout(dropout)
37 | )
38 | def forward(self, x):
39 | return self.net(x)
40 |
41 |
42 | class Attention(nn.Module):
43 | def __init__(self, dim, heads = 8, dropout = 0.):
44 | super().__init__()
45 | self.heads = heads
46 | self.scale = dim ** -0.5
47 |
48 | self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
49 | self.to_out = nn.Sequential(
50 | nn.Linear(dim, dim),
51 | nn.Dropout(dropout)
52 | )
53 |
54 | def forward(self, x, mask = None):
55 | b, n, _, h = *x.shape, self.heads
56 | qkv = self.to_qkv(x).chunk(3, dim = -1)
57 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
58 |
59 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
60 | mask_value = -torch.finfo(dots.dtype).max
61 |
62 | if mask is not None:
63 | mask = F.pad(mask.flatten(1), (1, 0), value = True)
64 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
65 | mask = mask[:, None, :] * mask[:, :, None]
66 | dots.masked_fill_(~mask, mask_value)
67 | del mask
68 |
69 | attn = dots.softmax(dim=-1)
70 |
71 | out = torch.einsum('bhij,bhjd->bhid', attn, v)
72 | out = rearrange(out, 'b h n d -> b n (h d)')
73 | out = self.to_out(out)
74 | return out
75 |
76 |
77 | class Transformer(nn.Module):
78 | def __init__(self, dim, depth, heads, mlp_dim, dropout):
79 | super().__init__()
80 | self.layers = nn.ModuleList([])
81 | for _ in range(depth):
82 | self.layers.append(nn.ModuleList([
83 | Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
84 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
85 | ]))
86 | def forward(self, x, mask = None):
87 | for attn, ff in self.layers:
88 | x = attn(x, mask = mask)
89 | x = ff(x)
90 | return x
91 |
92 |
93 | class ViT(nn.Module):
94 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
95 | super().__init__()
96 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
97 | num_patches = (image_size // patch_size) ** 2
98 | patch_dim = channels * patch_size ** 2
99 | assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
100 |
101 | self.patch_size = patch_size
102 |
103 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
104 | self.patch_to_embedding = nn.Linear(patch_dim, dim)
105 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
106 | self.dropout = nn.Dropout(emb_dropout)
107 |
108 | self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
109 | self.to_cls_token = nn.Identity()
110 |
111 | self.mlp_head = nn.Sequential(
112 | nn.LayerNorm(dim),
113 | nn.Linear(dim, mlp_dim),
114 | nn.GELU(),
115 | nn.Dropout(dropout),
116 | nn.Linear(mlp_dim, num_classes)
117 | )
118 |
119 | def forward(self, img, mask = None):
120 | p = self.patch_size
121 |
122 | x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
123 | x = self.patch_to_embedding(x)
124 | b, n, _ = x.shape
125 |
126 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
127 | x = torch.cat((cls_tokens, x), dim=1)
128 | x += self.pos_embedding[:, :(n + 1)]
129 | x = self.dropout(x)
130 |
131 | x = self.transformer(x, mask)
132 |
133 | x = self.to_cls_token(x[:, 0])
134 | return self.mlp_head(x)
135 |
136 |
137 | if __name__ == "__main__":
138 | from torchsummary import summary
139 | v = ViT(
140 | image_size=256,
141 | patch_size=32,
142 | num_classes=1000,
143 | dim=1024,
144 | depth=6,
145 | heads=8,
146 | mlp_dim=2048,
147 | dropout=0.1,
148 | emb_dropout=0.1
149 | ).to("cuda")
150 |
151 | img = torch.randn(1, 3, 256, 256).to("cuda")
152 |
153 | summary(v, img.shape[1:])
154 | mask = torch.ones(1, 8, 8).bool().to("cuda") # optional mask, designating which patch to attend to
155 | preds = v(img, mask=mask) # (1, 1000)
156 |
--------------------------------------------------------------------------------
/imagebart/modules/transformer/warper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from einops import rearrange
5 |
6 | from braket.modules.warp.midas import Midas  # NOTE: 'braket' is an internal package that is not included in this repository
7 | from braket.modules.transformer.util import Block as TransformerBlock
8 | from braket.modules.transformer.util import TinyTransformer
9 |
10 |
11 | def disabled_train(self, mode=True):
12 | """Overwrite model.train with this function to make sure train/eval mode
13 | does not change anymore."""
14 | return self
15 |
16 |
17 | class AbstractWarper(nn.Module):
18 | def __init__(self, *args, **kwargs):
19 | super().__init__()
20 | self._midas = Midas()
21 | self._midas.eval()
22 | self._midas.train = disabled_train
23 | for param in self._midas.parameters():
24 | param.requires_grad = False
25 |
26 | self.n_unmasked = kwargs["n_unmasked"] # length of conditioning
27 | self.n_embd = kwargs["n_embd"]
28 | self.block_size = kwargs["block_size"]
29 | self.size = kwargs["size"] # h, w tuple
30 | self.start_idx = kwargs.get("start_idx", 0) # hint to not modify parts
31 |
32 | self._use_cache = False
33 | self.new_emb = None # cache
34 | self.new_pos = None # cache
35 |
36 | def set_cache(self, value):
37 | self._use_cache = value
38 |
39 | def get_embeddings(self, token_embeddings, position_embeddings, warpkwargs):
40 | if self._use_cache:
41 | assert not self.training, "Do you really want to use caching during training?"
42 | assert self.new_emb is not None
43 | assert self.new_pos is not None
44 | return self.new_emb, self.new_pos
45 | self.new_emb, self.new_pos = self._get_embeddings(token_embeddings,
46 | position_embeddings,
47 | warpkwargs)
48 | return self.new_emb, self.new_pos
49 |
50 | def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs):
51 | raise NotImplementedError()
52 |
53 | def forward(self, token_embeddings, position_embeddings, warpkwargs):
54 | new_emb, new_pos = self.get_embeddings(token_embeddings,
55 | position_embeddings,
56 | warpkwargs)
57 |
58 | new_emb = torch.cat([new_emb, token_embeddings[:,self.n_unmasked:,:]],
59 | dim=1)
60 | b = new_pos.shape[0]
61 | new_pos = torch.cat([new_pos, position_embeddings[:,self.n_unmasked:,:][b*[0],...]],
62 | dim=1)
63 |
64 | return new_emb, new_pos
65 |
66 | def _to_sequence(self, x):
67 | x = rearrange(x, 'b c h w -> b (h w) c')
68 | return x
69 |
70 | def _to_imglike(self, x):
71 | x = rearrange(x, 'b (h w) c -> b c h w', h=self.size[0])
72 | return x
73 |
74 |
75 | class AbstractWarperWithCustomEmbedding(AbstractWarper):
76 | def __init__(self, *args, **kwargs):
77 |         super().__init__(*args, **kwargs)
78 | self.pos_emb = nn.Parameter(torch.zeros(1, self.block_size, self.n_embd))
79 |
80 |
81 | class NoSourceWarper(AbstractWarper):
82 | def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs):
83 | cond_emb = token_embeddings[:,:self.n_unmasked,:]
84 | cond_pos = position_embeddings[:,:self.n_unmasked,:]
85 |
86 | b, seq_length, chn = cond_emb.shape
87 | cond_emb = self._to_imglike(cond_emb)
88 |
89 | cond_pos = self._to_imglike(cond_pos)
90 | cond_pos = cond_pos[b*[0],...]
91 |
92 | new_emb, _ = self._midas.warp_features(f=cond_emb, no_depth_grad=True,
93 | boltzmann_factor=0.0,
94 | **warpkwargs)
95 | new_pos, _ = self._midas.warp_features(f=cond_pos, no_depth_grad=True,
96 | boltzmann_factor=0.0,
97 | **warpkwargs)
98 | new_emb = self._filter_nans(new_emb)
99 | new_pos = self._filter_nans(new_pos)
100 |
101 | new_emb = self._to_sequence(new_emb)
102 | new_pos = self._to_sequence(new_pos)
103 | return new_emb, new_pos
104 |
105 | def _filter_nans(self, x):
106 | x[torch.isnan(x)] = 0.
107 | return x
108 |
109 |
110 | class ConvWarper(AbstractWarper):
111 | def __init__(self, *args, **kwargs):
112 | super().__init__(*args, **kwargs)
113 | self.emb_conv = nn.Conv2d(2*self.n_embd, self.n_embd,
114 | kernel_size=1,
115 | padding=0,
116 | bias=False)
117 | self.pos_conv = nn.Conv2d(2*self.n_embd, self.n_embd,
118 | kernel_size=1,
119 | padding=0,
120 | bias=False)
121 |
122 | def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs):
123 | cond_emb = token_embeddings[:,self.start_idx:self.n_unmasked,:]
124 | cond_pos = position_embeddings[:,self.start_idx:self.n_unmasked,:]
125 |
126 | b, seq_length, chn = cond_emb.shape
127 | cond_emb = cond_emb.reshape(b, self.size[0], self.size[1], chn)
128 | cond_emb = cond_emb.permute(0,3,1,2)
129 |
130 | cond_pos = cond_pos.reshape(1, self.size[0], self.size[1], chn)
131 | cond_pos = cond_pos.permute(0,3,1,2)
132 | cond_pos = cond_pos[b*[0],...]
133 |
134 | with torch.no_grad():
135 | warp_emb, _ = self._midas.warp_features(f=cond_emb, no_depth_grad=True, **warpkwargs)
136 | warp_pos, _ = self._midas.warp_features(f=cond_pos, no_depth_grad=True, **warpkwargs)
137 |
138 | new_emb = self.emb_conv(torch.cat([cond_emb, warp_emb], dim=1))
139 | new_pos = self.pos_conv(torch.cat([cond_pos, warp_pos], dim=1))
140 |
141 | new_emb = new_emb.permute(0,2,3,1)
142 | new_emb = new_emb.reshape(b,seq_length,chn)
143 |
144 | new_pos = new_pos.permute(0,2,3,1)
145 | new_pos = new_pos.reshape(b,seq_length,chn)
146 |
147 | # prepend unmodified ones again
148 | new_emb = torch.cat((token_embeddings[:,:self.start_idx,:], new_emb),
149 | dim=1)
150 | new_pos = torch.cat((position_embeddings[:,:self.start_idx,:][b*[0],...], new_pos),
151 | dim=1)
152 |
153 | return new_emb, new_pos
154 |
155 |
156 | class ParallelAttentionWarper(AbstractWarper):
157 | def __init__(self, *args, **kwargs):
158 | super().__init__(*args, **kwargs)
159 | self.pos_block = nn.ModuleList([TinyTransformer(block_size=self.block_size,
160 | n_layer=2,
161 | n_head=16,
162 | n_embd=self.n_embd,
163 | use_head=False) for _ in range(3)])
164 | self.emb_block = nn.ModuleList([TinyTransformer(block_size=self.block_size,
165 | n_layer=2,
166 | n_head=16,
167 | n_embd=self.n_embd,
168 | use_head=False) for _ in range(3)])
169 |
170 | def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs):
171 | cond_emb = token_embeddings[:,:self.n_unmasked,:]
172 | cond_pos = position_embeddings[:,:self.n_unmasked,:]
173 |
174 | b, seq_length, chn = cond_emb.shape
175 | cond_emb = self._to_imglike(cond_emb)
176 |
177 | cond_pos = self._to_imglike(cond_pos)
178 | cond_pos = cond_pos[b*[0],...]
179 |
180 | with torch.no_grad():
181 | warp_emb, _ = self._midas.warp_features(f=cond_emb, no_depth_grad=True, **warpkwargs)
182 | warp_pos, _ = self._midas.warp_features(f=cond_pos, no_depth_grad=True, **warpkwargs)
183 |
184 | warp_emb = self._to_sequence(warp_emb)
185 | warp_pos = self._to_sequence(warp_pos)
186 |
187 | cond_emb = self._to_sequence(cond_emb)
188 | cond_pos = self._to_sequence(cond_pos)
189 |
190 | cond_emb = self.emb_block[0](idx=None, token_embeddings=cond_emb) + cond_emb
191 | warp_emb = self.emb_block[1](idx=None, token_embeddings=warp_emb) + warp_emb
192 | new_emb = self.emb_block[2](idx=None, token_embeddings=warp_emb+cond_emb)
193 |
194 | cond_pos = self.pos_block[0](idx=None, token_embeddings=cond_pos) + cond_pos
195 | warp_pos = self.pos_block[1](idx=None, token_embeddings=warp_pos) + warp_pos
196 | new_pos = self.pos_block[2](idx=None, token_embeddings=warp_pos+cond_pos)
197 |
198 | return new_emb, new_pos
199 |
--------------------------------------------------------------------------------
/imagebart/modules/xtransformers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/imagebart/c11ea830ebf74dd0f9c1307462c9de24f536a9e7/imagebart/modules/xtransformers/__init__.py
--------------------------------------------------------------------------------
/imagebart/modules/xtransformers/autoregressive_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | try:
6 | from entmax import entmax_bisect
7 | except ModuleNotFoundError:
8 | print("'entmax' not installed. skipping.")
9 | entmax_bisect = None
10 |
11 |
12 | # nucleus
13 |
14 | def top_p(logits, thres=0.9):
15 | sorted_logits, sorted_indices = torch.sort(logits, descending=True)
16 | cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
17 |
18 | sorted_indices_to_remove = cum_probs > (1 - thres)
19 | sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
20 | sorted_indices_to_remove[:, 0] = 0
21 |
22 | sorted_logits[sorted_indices_to_remove] = float('-inf')
23 | return sorted_logits.scatter(1, sorted_indices, sorted_logits)
24 |
25 |
26 | # topk
27 |
28 | def top_k(logits, thres=0.9):
29 | k = int((1 - thres) * logits.shape[-1])
30 | val, ind = torch.topk(logits, k)
31 | probs = torch.full_like(logits, float('-inf'))
32 | probs.scatter_(1, ind, val)
33 | return probs
34 |
35 |
36 | # entmax
37 |
38 | ENTMAX_ALPHA = 1.3
39 | entmax = entmax_bisect
40 |
41 |
42 | class AutoregressiveWrapper(nn.Module):
43 | def __init__(self, net, ignore_index=-100, pad_value=0):
44 | super().__init__()
45 | self.pad_value = pad_value
46 | self.ignore_index = ignore_index
47 |
48 | self.net = net
49 | self.max_seq_len = net.max_seq_len
50 |
51 | @torch.no_grad()
52 | def generate(self, start_tokens, seq_len, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9,
53 | **kwargs):
54 | device = start_tokens.device
55 | was_training = self.net.training
56 | num_dims = len(start_tokens.shape)
57 |
58 | if num_dims == 1:
59 | start_tokens = start_tokens[None, :]
60 |
61 | b, t = start_tokens.shape
62 |
63 | self.net.eval()
64 | out = start_tokens
65 | mask = kwargs.pop('mask', None)
66 |
67 | if mask is None:
68 | mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)
69 |
70 | for _ in range(seq_len):
71 | x = out[:, -self.max_seq_len:]
72 | mask = mask[:, -self.max_seq_len:]
73 |
74 | logits = self.net(x, mask=mask, **kwargs)[:, -1, :]
75 |
76 | if filter_logits_fn in {top_k, top_p}:
77 | filtered_logits = filter_logits_fn(logits, thres=filter_thres)
78 | probs = F.softmax(filtered_logits / temperature, dim=-1)
79 |
80 | elif filter_logits_fn is entmax:
81 | probs = entmax(logits / temperature, alpha=ENTMAX_ALPHA, dim=-1)
82 |
83 | sample = torch.multinomial(probs, 1)
84 |
85 | out = torch.cat((out, sample), dim=-1)
86 | mask = F.pad(mask, (0, 1), value=True)
87 |
88 | if eos_token is not None and (sample == eos_token).all():
89 | break
90 |
91 | out = out[:, t:]
92 |
93 | if num_dims == 1:
94 | out = out.squeeze(0)
95 |
96 | self.net.train(was_training)
97 | return out
98 |
99 | def forward(self, x, return_logits=True, **kwargs):
100 | xi = x[:, :-1]
101 | xo = x[:, 1:]
102 |
103 | # help auto-solve a frequent area of confusion around input masks in auto-regressive
104 | # if user supplies a mask that is only off by one from the source sequence, resolve it for them
105 | mask = kwargs.get('mask', None)
106 | if mask is not None and mask.shape[1] == x.shape[1]:
107 | mask = mask[:, :-1]
108 | kwargs['mask'] = mask
109 |
110 | out = self.net(xi, **kwargs)
111 | loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index=self.ignore_index)
112 | if return_logits:
113 | return out, loss
114 | return loss
115 |
--------------------------------------------------------------------------------
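A small sketch of the filtering helpers above (values are illustrative): `top_k` keeps roughly the top `(1 - thres)` fraction of the vocabulary per row, and `generate` divides the filtered logits by the temperature before sampling:

```
import torch
from imagebart.modules.xtransformers.autoregressive_wrapper import top_k

logits = torch.randn(2, 1024)                    # (batch, vocab)
filtered = top_k(logits, thres=0.9)              # keeps roughly the top 10% of logits per row
probs = torch.softmax(filtered / 0.8, dim=-1)    # temperature 0.8
next_token = torch.multinomial(probs, num_samples=1)
```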
/imagebart/modules/xtransformers/positional_embeddings.py:
--------------------------------------------------------------------------------
1 | # https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
2 |
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | from einops import rearrange
7 | from inspect import isfunction
8 |
9 |
10 | def exists(val):
11 | return val is not None
12 |
13 |
14 | def default(val, d):
15 | if exists(val):
16 | return val
17 | return d() if isfunction(d) else d
18 |
19 |
20 | # positional embeddings
21 | class DepthWiseConv1d(nn.Module):
22 | def __init__(self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True, groups=False):
23 | super().__init__()
24 | groups = default(groups, dim_in)
25 | self.net = nn.Sequential(
26 | nn.Conv1d(dim_in, dim_in, kernel_size=kernel_size, padding=padding, groups=dim_in, stride=stride,
27 | bias=bias),
28 | nn.Conv1d(dim_in, dim_out, 1)
29 | )
30 |
31 | def forward(self, x):
32 | return self.net(x)
33 |
34 |
35 | class AbsolutePositionalEmbedding(nn.Module):
36 | def __init__(self, dim, max_seq_len):
37 | super().__init__()
38 | self.emb = nn.Embedding(max_seq_len, dim)
39 | self.init_()
40 |
41 | def init_(self):
42 | nn.init.normal_(self.emb.weight, std=0.02)
43 |
44 | def forward(self, x):
45 | n = torch.arange(x.shape[1], device=x.device)
46 | return self.emb(n)[None, :, :]
47 |
48 |
49 | class FixedPositionalEmbedding(nn.Module):
50 | def __init__(self, dim):
51 | super().__init__()
52 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
53 | self.register_buffer('inv_freq', inv_freq)
54 |
55 | def forward(self, x, seq_dim=1, offset=0):
56 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
57 | sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
58 | emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
59 | return emb[None, :, :]
60 |
61 |
62 | class RelativePositionBias(nn.Module):
63 | def __init__(self, causal=False, num_buckets=32, max_distance=128, heads=8):
64 | super().__init__()
65 | self.causal = causal
66 | self.num_buckets = num_buckets
67 | self.max_distance = max_distance
68 | self.relative_attention_bias = nn.Embedding(num_buckets, heads)
69 |
70 | @staticmethod
71 | def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
72 | ret = 0
73 | n = -relative_position
74 | if not causal:
75 | num_buckets //= 2
76 | ret += (n < 0).long() * num_buckets
77 | n = torch.abs(n)
78 | else:
79 | n = torch.max(n, torch.zeros_like(n))
80 |
81 | max_exact = num_buckets // 2
82 | is_small = n < max_exact
83 |
84 | val_if_large = max_exact + (
85 | torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
86 | ).long()
87 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
88 |
89 | ret += torch.where(is_small, n, val_if_large)
90 | return ret
91 |
92 | def forward(self, qk_dots):
93 | i, j, device = *qk_dots.shape[-2:], qk_dots.device
94 | q_pos = torch.arange(i, dtype=torch.long, device=device)
95 | k_pos = torch.arange(j, dtype=torch.long, device=device)
96 | rel_pos = k_pos[None, :] - q_pos[:, None]
97 | rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
98 | max_distance=self.max_distance)
99 | values = self.relative_attention_bias(rp_bucket)
100 | bias = rearrange(values, 'i j h -> () h i j')
101 | return qk_dots + bias
102 |
103 |
104 | class RotaryEmbedding(nn.Module):
105 | def __init__(self, dim):
106 | super().__init__()
107 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
108 | self.register_buffer('inv_freq', inv_freq)
109 |
110 | def forward(self, x, seq_dim=1):
111 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
112 | freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
113 | emb = torch.cat((freqs, freqs), dim=-1)
114 | return emb[None, :, :]
115 |
116 |
117 | def rotate_half(x):
118 | x = rearrange(x, '... (j d) -> ... j d', j=2)
119 | x1, x2 = x.unbind(dim=-2)
120 | return torch.cat((-x2, x1), dim=-1)
121 |
122 |
123 | def apply_rotary_pos_emb(t, freqs):
124 | seq_len = t.shape[-2]
125 | freqs = freqs[:, :, -seq_len:]
126 | return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
--------------------------------------------------------------------------------
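A minimal sketch of how the rotary helpers compose (shapes are illustrative; the exact call pattern inside `x_transformer.py` may differ slightly): the frequencies are broadcast over batch and heads, then applied to queries and keys before the attention dot product:

```
import torch
from imagebart.modules.xtransformers.positional_embeddings import (
    RotaryEmbedding, apply_rotary_pos_emb)

b, h, n, d = 2, 8, 16, 64                  # batch, heads, tokens, head dim
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)

rotary = RotaryEmbedding(d)
freqs = rotary(q, seq_dim=-2)              # (1, n, d) rotation frequencies
freqs = freqs[:, None]                     # broadcast over heads -> (1, 1, n, d)

q = apply_rotary_pos_emb(q, freqs)
k = apply_rotary_pos_emb(k, freqs)
```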
/imagebart/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from PIL import Image, ImageDraw, ImageFont
4 |
5 |
6 | class ClassProvider(torch.nn.Module):
7 | def __init__(self, key="class"):
8 | super().__init__()
9 | self.key = key
10 |
11 | def forward(self, batch):
12 | c = batch[self.key][:, None]
13 | return c
14 |
15 |
16 | class BasicTokenizer(torch.nn.Module):
17 | """
18 | Uses the 'simple tokenizer' of the CLIP model
19 | https://github.com/openai/CLIP
20 | """
21 | def __init__(self, device="cuda", key="caption"):
22 | super().__init__()
23 | from clip import tokenize
24 | self.tknz_fn = tokenize
25 | self.device = device
26 | self.key = key
27 |
28 | def forward(self, batch):
29 | text = batch[self.key]
30 | tokens = self.tknz_fn(text).to(self.device)
31 | return tokens
32 |
33 |
34 | class KeyNotFoundError(Exception):
35 | def __init__(self, cause, keys=None, visited=None):
36 | self.cause = cause
37 | self.keys = keys
38 | self.visited = visited
39 | messages = list()
40 | if keys is not None:
41 | messages.append("Key not found: {}".format(keys))
42 | if visited is not None:
43 | messages.append("Visited: {}".format(visited))
44 | messages.append("Cause:\n{}".format(cause))
45 | message = "\n".join(messages)
46 | super().__init__(message)
47 |
48 |
49 | def log_txt_as_img(wh, xc, size=10):
50 | # wh a tuple of (width, height)
51 | # xc a list of captions to plot
52 | b = len(xc)
53 | txts = list()
54 | for bi in range(b):
55 | txt = Image.new("RGB", wh, color="white")
56 | draw = ImageDraw.Draw(txt)
57 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
58 | nc = int(40 * (wh[0]/256))
59 | lines = "\n".join(xc[bi][start:start+nc] for start in range(0, len(xc[bi]), nc))
60 |
61 | try:
62 | draw.text((0,0), lines, fill="black", font=font)
63 | except UnicodeEncodeError:
64 | print("Cant encode string for logging. Skipping.")
65 |
66 | txt = np.array(txt).transpose(2,0,1)/127.5-1.0
67 | txts.append(txt)
68 | txts = np.stack(txts)
69 | txts = torch.tensor(txts)
70 | return txts
71 |
72 |
--------------------------------------------------------------------------------
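A short usage sketch for `log_txt_as_img` (run from the repository root so that `data/DejaVuSans.ttf` resolves): it rasterizes a list of captions into image tensors scaled to [-1, 1], suitable for the image loggers in `main.py`:

```
import torch
from imagebart.util import log_txt_as_img

captions = ["a photo of a cat", "a church at sunset"]
txt_imgs = log_txt_as_img((256, 256), captions, size=12)
print(txt_imgs.shape)   # torch.Size([2, 3, 256, 256])
```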
/main.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, datetime, glob, importlib, csv
2 | import numpy as np
3 | import time
4 | import torch
5 | import torchvision
6 | import pytorch_lightning as pl
7 |
8 | from packaging import version
9 | from omegaconf import OmegaConf
10 | from torch.utils.data import random_split, DataLoader, Dataset, Subset
11 | from functools import partial
12 | from PIL import Image
13 |
14 | from pytorch_lightning import seed_everything
15 | from pytorch_lightning.trainer import Trainer
16 | from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
17 | from pytorch_lightning.utilities.distributed import rank_zero_only
18 | from pytorch_lightning.utilities import rank_zero_info
19 |
20 | from imagebart.data.base import Txt2ImgIterableBaseDataset
21 |
22 |
23 | def get_obj_from_str(string, reload=False):
24 | module, cls = string.rsplit(".", 1)
25 | if reload:
26 | module_imp = importlib.import_module(module)
27 | importlib.reload(module_imp)
28 | return getattr(importlib.import_module(module, package=None), cls)
29 |
30 |
31 | def get_parser(**parser_kwargs):
32 | def str2bool(v):
33 | if isinstance(v, bool):
34 | return v
35 | if v.lower() in ("yes", "true", "t", "y", "1"):
36 | return True
37 | elif v.lower() in ("no", "false", "f", "n", "0"):
38 | return False
39 | else:
40 | raise argparse.ArgumentTypeError("Boolean value expected.")
41 |
42 | parser = argparse.ArgumentParser(**parser_kwargs)
43 | parser.add_argument(
44 | "-n",
45 | "--name",
46 | type=str,
47 | const=True,
48 | default="",
49 | nargs="?",
50 | help="postfix for logdir",
51 | )
52 | parser.add_argument(
53 | "-r",
54 | "--resume",
55 | type=str,
56 | const=True,
57 | default="",
58 | nargs="?",
59 | help="resume from logdir or checkpoint in logdir",
60 | )
61 | parser.add_argument(
62 | "-b",
63 | "--base",
64 | nargs="*",
65 | metavar="base_config.yaml",
66 | help="paths to base configs. Loaded from left-to-right. "
67 | "Parameters can be overwritten or added with command-line options of the form `--key value`.",
68 | default=list(),
69 | )
70 | parser.add_argument(
71 | "-t",
72 | "--train",
73 | type=str2bool,
74 | const=True,
75 | default=False,
76 | nargs="?",
77 | help="train",
78 | )
79 | parser.add_argument(
80 | "--no-test",
81 | type=str2bool,
82 | const=True,
83 | default=False,
84 | nargs="?",
85 | help="disable test",
86 | )
87 | parser.add_argument("-p", "--project", help="name of new or path to existing project")
88 | parser.add_argument(
89 | "-d",
90 | "--debug",
91 | type=str2bool,
92 | nargs="?",
93 | const=True,
94 | default=False,
95 | help="enable post-mortem debugging",
96 | )
97 | parser.add_argument(
98 | "-s",
99 | "--seed",
100 | type=int,
101 | default=23,
102 | help="seed for seed_everything",
103 | )
104 | parser.add_argument(
105 | "-f",
106 | "--postfix",
107 | type=str,
108 | default="",
109 | help="post-postfix for default name",
110 | )
111 | parser.add_argument(
112 | "-l",
113 | "--logdir",
114 | type=str,
115 | default="logs",
116 | help="directory for logging dat shit",
117 | )
118 | parser.add_argument(
119 | "--scale_lr",
120 | type=str2bool,
121 | nargs="?",
122 | const=True,
123 | default=True,
124 | help="scale base-lr by ngpu * batch_size * n_accumulate",
125 | )
126 | return parser
127 |
128 |
129 | def nondefault_trainer_args(opt):
130 | parser = argparse.ArgumentParser()
131 | parser = Trainer.add_argparse_args(parser)
132 | args = parser.parse_args([])
133 | return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
134 |
135 |
136 | def instantiate_from_config(config):
137 |     if "target" not in config:
138 | if config == '__is_first_stage__':
139 | return None
140 | elif config == "__is_unconditional__":
141 | return None
142 | raise KeyError("Expected key `target` to instantiate.")
143 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
144 |
145 |
146 | class WrappedDataset(Dataset):
147 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
148 |
149 | def __init__(self, dataset):
150 | self.data = dataset
151 |
152 | def __len__(self):
153 | return len(self.data)
154 |
155 | def __getitem__(self, idx):
156 | return self.data[idx]
157 |
158 |
159 | def worker_init_fn(_):
160 | worker_info = torch.utils.data.get_worker_info()
161 |
162 | dataset = worker_info.dataset
163 | worker_id = worker_info.id
164 |
165 | if isinstance(dataset, Txt2ImgIterableBaseDataset):
166 | split_size = dataset.num_records // worker_info.num_workers
167 |         # assign each worker a disjoint, equally sized shard of the valid sample ids
168 | dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
169 | current_id = np.random.choice(len(np.random.get_state()[1]), 1)
170 | return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
171 | else:
172 | return np.random.seed(np.random.get_state()[1][0] + worker_id)
173 |
174 |
175 | class DataModuleFromConfig(pl.LightningDataModule):
176 | def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
177 | wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
178 | shuffle_val_dataloader=False):
179 | super().__init__()
180 | self.batch_size = batch_size
181 | self.dataset_configs = dict()
182 | self.num_workers = num_workers if num_workers is not None else batch_size * 2
183 | self.use_worker_init_fn = use_worker_init_fn
184 | if train is not None:
185 | self.dataset_configs["train"] = train
186 | self.train_dataloader = self._train_dataloader
187 | if validation is not None:
188 | self.dataset_configs["validation"] = validation
189 | self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
190 | if test is not None:
191 | self.dataset_configs["test"] = test
192 | self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
193 | if predict is not None:
194 | self.dataset_configs["predict"] = predict
195 | self.predict_dataloader = self._predict_dataloader
196 | self.wrap = wrap
197 |
198 | def prepare_data(self):
199 | for data_cfg in self.dataset_configs.values():
200 | instantiate_from_config(data_cfg)
201 |
202 | def setup(self, stage=None):
203 | self.datasets = dict(
204 | (k, instantiate_from_config(self.dataset_configs[k]))
205 | for k in self.dataset_configs)
206 | if self.wrap:
207 | for k in self.datasets:
208 | self.datasets[k] = WrappedDataset(self.datasets[k])
209 |
210 | def _train_dataloader(self):
211 | is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
212 | if is_iterable_dataset or self.use_worker_init_fn:
213 | init_fn = worker_init_fn
214 | else:
215 | init_fn = None
216 | return DataLoader(self.datasets["train"], batch_size=self.batch_size,
217 | num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
218 | worker_init_fn=init_fn)
219 |
220 | def _val_dataloader(self, shuffle=False):
221 | if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
222 | init_fn = worker_init_fn
223 | else:
224 | init_fn = None
225 | return DataLoader(self.datasets["validation"],
226 | batch_size=self.batch_size,
227 | num_workers=self.num_workers,
228 | worker_init_fn=init_fn,
229 | shuffle=shuffle)
230 |
231 | def _test_dataloader(self, shuffle=False):
232 | is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
233 | if is_iterable_dataset or self.use_worker_init_fn:
234 | init_fn = worker_init_fn
235 | else:
236 | init_fn = None
237 |
238 | # do not shuffle dataloader for iterable dataset
239 | shuffle = shuffle and (not is_iterable_dataset)
240 |
241 | return DataLoader(self.datasets["test"], batch_size=self.batch_size,
242 | num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)
243 |
244 | def _predict_dataloader(self, shuffle=False):
245 | if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
246 | init_fn = worker_init_fn
247 | else:
248 | init_fn = None
249 | return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
250 | num_workers=self.num_workers, worker_init_fn=init_fn)
251 |
252 |
253 | class SetupCallback(Callback):
254 | def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
255 | super().__init__()
256 | self.resume = resume
257 | self.now = now
258 | self.logdir = logdir
259 | self.ckptdir = ckptdir
260 | self.cfgdir = cfgdir
261 | self.config = config
262 | self.lightning_config = lightning_config
263 |
264 | def on_keyboard_interrupt(self, trainer, pl_module):
265 | if trainer.global_rank == 0:
266 | print("Summoning checkpoint.")
267 | ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
268 | trainer.save_checkpoint(ckpt_path)
269 |
270 | def on_pretrain_routine_start(self, trainer, pl_module):
271 | if trainer.global_rank == 0:
272 | # Create logdirs and save configs
273 | os.makedirs(self.logdir, exist_ok=True)
274 | os.makedirs(self.ckptdir, exist_ok=True)
275 | os.makedirs(self.cfgdir, exist_ok=True)
276 |
277 | if "callbacks" in self.lightning_config:
278 | if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
279 | os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
280 | print("Project config")
281 | print(OmegaConf.to_yaml(self.config))
282 | OmegaConf.save(self.config,
283 | os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
284 |
285 | print("Lightning config")
286 | print(OmegaConf.to_yaml(self.lightning_config))
287 | OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
288 | os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
289 |
290 | else:
291 |             # ModelCheckpoint callback created the log directory --- move it aside into child_runs
292 | if not self.resume and os.path.exists(self.logdir):
293 | dst, name = os.path.split(self.logdir)
294 | dst = os.path.join(dst, "child_runs", name)
295 | os.makedirs(os.path.split(dst)[0], exist_ok=True)
296 | try:
297 | os.rename(self.logdir, dst)
298 | except FileNotFoundError:
299 | pass
300 |
301 |
302 | class ImageLogger(Callback):
303 | def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
304 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
305 | log_images_kwargs=None):
306 | super().__init__()
307 | self.rescale = rescale
308 | self.batch_freq = batch_frequency
309 | self.max_images = max_images
310 | self.logger_log_images = {
311 | pl.loggers.TestTubeLogger: self._testtube,
312 | }
313 | self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
314 | if not increase_log_steps:
315 | self.log_steps = [self.batch_freq]
316 | self.clamp = clamp
317 | self.disabled = disabled
318 | self.log_on_batch_idx = log_on_batch_idx
319 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
320 | self.log_first_step = log_first_step
321 |
322 | @rank_zero_only
323 | def _testtube(self, pl_module, images, batch_idx, split):
324 | for k in images:
325 | grid = torchvision.utils.make_grid(images[k])
326 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
327 |
328 | tag = f"{split}/{k}"
329 | pl_module.logger.experiment.add_image(
330 | tag, grid,
331 | global_step=pl_module.global_step)
332 |
333 | @rank_zero_only
334 | def log_local(self, save_dir, split, images,
335 | global_step, current_epoch, batch_idx):
336 | root = os.path.join(save_dir, "images", split)
337 | for k in images:
338 | grid = torchvision.utils.make_grid(images[k], nrow=4)
339 | if self.rescale:
340 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
341 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
342 | grid = grid.numpy()
343 | grid = (grid * 255).astype(np.uint8)
344 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
345 | k,
346 | global_step,
347 | current_epoch,
348 | batch_idx)
349 | path = os.path.join(root, filename)
350 | os.makedirs(os.path.split(path)[0], exist_ok=True)
351 | Image.fromarray(grid).save(path)
352 |
353 | def log_img(self, pl_module, batch, batch_idx, split="train"):
354 | check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
355 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
356 | hasattr(pl_module, "log_images") and
357 | callable(pl_module.log_images) and
358 | self.max_images > 0):
359 | logger = type(pl_module.logger)
360 |
361 | is_train = pl_module.training
362 | if is_train:
363 | pl_module.eval()
364 |
365 | with torch.no_grad():
366 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
367 |
368 | for k in images:
369 | N = min(images[k].shape[0], self.max_images)
370 | images[k] = images[k][:N]
371 | if isinstance(images[k], torch.Tensor):
372 | images[k] = images[k].detach().cpu()
373 | if self.clamp:
374 | images[k] = torch.clamp(images[k], -1., 1.)
375 |
376 | self.log_local(pl_module.logger.save_dir, split, images,
377 | pl_module.global_step, pl_module.current_epoch, batch_idx)
378 |
379 | logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
380 | logger_log_images(pl_module, images, pl_module.global_step, split)
381 |
382 | if is_train:
383 | pl_module.train()
384 |
385 | def check_frequency(self, check_idx):
386 | if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
387 | check_idx > 0 or self.log_first_step):
388 | try:
389 | self.log_steps.pop(0)
390 | except IndexError as e:
391 | print(e)
392 | pass
393 | return True
394 | return False
395 |
396 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
397 | if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
398 | self.log_img(pl_module, batch, batch_idx, split="train")
399 |
400 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
401 | if not self.disabled and pl_module.global_step > 0:
402 | self.log_img(pl_module, batch, batch_idx, split="val")
403 | if hasattr(pl_module, 'calibrate_grad_norm'):
404 | if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
405 | self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
406 |
407 |
408 | class CUDACallback(Callback):
409 | # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
410 | def on_train_epoch_start(self, trainer, pl_module):
411 | # Reset the memory use counter
412 | torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
413 | torch.cuda.synchronize(trainer.root_gpu)
414 | self.start_time = time.time()
415 |
416 | def on_train_epoch_end(self, trainer, pl_module, outputs):
417 | torch.cuda.synchronize(trainer.root_gpu)
418 | max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
419 | epoch_time = time.time() - self.start_time
420 |
421 | try:
422 | max_memory = trainer.training_type_plugin.reduce(max_memory)
423 | epoch_time = trainer.training_type_plugin.reduce(epoch_time)
424 |
425 | rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
426 | rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
427 | except AttributeError:
428 | pass
429 |
430 |
431 | if __name__ == "__main__":
432 | # custom parser to specify config files, train, test and debug mode,
433 | # postfix, resume.
434 | # `--key value` arguments are interpreted as arguments to the trainer.
435 | # `nested.key=value` arguments are interpreted as config parameters.
436 | # configs are merged from left-to-right followed by command line parameters.
437 |
438 | # model:
439 | # base_learning_rate: float
440 | # target: path to lightning module
441 | # params:
442 | # key: value
443 | # data:
444 | # target: main.DataModuleFromConfig
445 | # params:
446 | # batch_size: int
447 | # wrap: bool
448 | # train:
449 | # target: path to train dataset
450 | # params:
451 | # key: value
452 | # validation:
453 | # target: path to validation dataset
454 | # params:
455 | # key: value
456 | # test:
457 | # target: path to test dataset
458 | # params:
459 | # key: value
460 | # lightning: (optional, has sane defaults and can be specified on cmdline)
461 | # trainer:
462 | # additional arguments to trainer
463 | # logger:
464 | # logger to instantiate
465 | # modelcheckpoint:
466 | # modelcheckpoint to instantiate
467 | # callbacks:
468 | # callback1:
469 | # target: importpath
470 | # params:
471 | # key: value
472 |
473 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
474 |
475 | # add cwd for convenience and to make classes in this file available when
476 | # running as `python main.py`
477 | # (in particular `main.DataModuleFromConfig`)
478 | sys.path.append(os.getcwd())
479 |
480 | parser = get_parser()
481 | parser = Trainer.add_argparse_args(parser)
482 |
483 | opt, unknown = parser.parse_known_args()
484 | if opt.name and opt.resume:
485 | raise ValueError(
486 | "-n/--name and -r/--resume cannot be specified both."
487 | "If you want to resume training in a new log folder, "
488 | "use -n/--name in combination with --resume_from_checkpoint"
489 | )
490 | if opt.resume:
491 | if not os.path.exists(opt.resume):
492 | raise ValueError("Cannot find {}".format(opt.resume))
493 | if os.path.isfile(opt.resume):
494 | paths = opt.resume.split("/")
495 | # idx = len(paths)-paths[::-1].index("logs")+1
496 | # logdir = "/".join(paths[:idx])
497 | logdir = "/".join(paths[:-2])
498 | ckpt = opt.resume
499 | else:
500 | assert os.path.isdir(opt.resume), opt.resume
501 | logdir = opt.resume.rstrip("/")
502 | ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
503 |
504 | opt.resume_from_checkpoint = ckpt
505 | base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
506 | opt.base = base_configs + opt.base
507 | _tmp = logdir.split("/")
508 | nowname = _tmp[-1]
509 | else:
510 | if opt.name:
511 | name = "_" + opt.name
512 | elif opt.base:
513 | cfg_fname = os.path.split(opt.base[0])[-1]
514 | cfg_name = os.path.splitext(cfg_fname)[0]
515 | name = "_" + cfg_name
516 | else:
517 | name = ""
518 | nowname = now + name + opt.postfix
519 | logdir = os.path.join(opt.logdir, nowname)
520 |
521 | ckptdir = os.path.join(logdir, "checkpoints")
522 | cfgdir = os.path.join(logdir, "configs")
523 | seed_everything(opt.seed)
524 |
525 | try:
526 | # init and save configs
527 | configs = [OmegaConf.load(cfg) for cfg in opt.base]
528 | cli = OmegaConf.from_dotlist(unknown)
529 | config = OmegaConf.merge(*configs, cli)
530 | lightning_config = config.pop("lightning", OmegaConf.create())
531 | # merge trainer cli with config
532 | trainer_config = lightning_config.get("trainer", OmegaConf.create())
533 | # default to ddp
534 | trainer_config["accelerator"] = "ddp"
535 | for k in nondefault_trainer_args(opt):
536 | trainer_config[k] = getattr(opt, k)
537 |         if "gpus" not in trainer_config:
538 | del trainer_config["accelerator"]
539 | cpu = True
540 | else:
541 | gpuinfo = trainer_config["gpus"]
542 | print(f"Running on GPUs {gpuinfo}")
543 | cpu = False
544 | trainer_opt = argparse.Namespace(**trainer_config)
545 | lightning_config.trainer = trainer_config
546 |
547 | # model
548 | model = instantiate_from_config(config.model)
549 |
550 | # trainer and callbacks
551 | trainer_kwargs = dict()
552 |
553 | # default logger configs
554 | # NOTE wandb < 0.10.0 interferes with shutdown
555 | # wandb >= 0.10.0 seems to fix it but still interferes with pudb
556 | # debugging (wrongly sized pudb ui)
557 | # thus prefer testtube for now
558 | default_logger_cfgs = {
559 | "wandb": {
560 | "target": "pytorch_lightning.loggers.WandbLogger",
561 | "params": {
562 | "name": nowname,
563 | "save_dir": logdir,
564 | "offline": opt.debug,
565 | "id": nowname,
566 | }
567 | },
568 | "testtube": {
569 | "target": "pytorch_lightning.loggers.TestTubeLogger",
570 | "params": {
571 | "name": "testtube",
572 | "save_dir": logdir,
573 | }
574 | },
575 | }
576 | default_logger_cfg = default_logger_cfgs["testtube"]
577 | if 'logger' in lightning_config:
578 | logger_cfg = lightning_config.logger
579 | else:
580 | logger_cfg = OmegaConf.create()
581 | logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
582 | trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
583 |
584 | # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
585 | # specify which metric is used to determine best models
586 | default_modelckpt_cfg = {
587 | "target": "pytorch_lightning.callbacks.ModelCheckpoint",
588 | "params": {
589 | "dirpath": ckptdir,
590 | "filename": "{epoch:06}",
591 | "verbose": True,
592 | "save_last": True,
593 | }
594 | }
595 | if hasattr(model, "monitor"):
596 | print(f"Monitoring {model.monitor} as checkpoint metric.")
597 | default_modelckpt_cfg["params"]["monitor"] = model.monitor
598 | default_modelckpt_cfg["params"]["save_top_k"] = 3
599 |
600 | if 'modelcheckpoint' in lightning_config:
601 | modelckpt_cfg = lightning_config.modelcheckpoint
602 | else:
603 | modelckpt_cfg = OmegaConf.create()
604 | modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
605 | print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
606 | if version.parse(pl.__version__) < version.parse('1.4.0'):
607 | trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
608 |
609 | # add callback which sets up log directory
610 | default_callbacks_cfg = {
611 | "setup_callback": {
612 | "target": "main.SetupCallback",
613 | "params": {
614 | "resume": opt.resume,
615 | "now": now,
616 | "logdir": logdir,
617 | "ckptdir": ckptdir,
618 | "cfgdir": cfgdir,
619 | "config": config,
620 | "lightning_config": lightning_config,
621 | }
622 | },
623 | "image_logger": {
624 | "target": "main.ImageLogger",
625 | "params": {
626 | "batch_frequency": 750,
627 | "max_images": 4,
628 | "clamp": True
629 | }
630 | },
631 | "learning_rate_logger": {
632 | "target": "main.LearningRateMonitor",
633 | "params": {
634 | "logging_interval": "step",
635 | # "log_momentum": True
636 | }
637 | },
638 | "cuda_callback": {
639 | "target": "main.CUDACallback"
640 | },
641 | }
642 | if version.parse(pl.__version__) >= version.parse('1.4.0'):
643 | default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
644 |
645 | if 'callbacks' in lightning_config:
646 | callbacks_cfg = lightning_config.callbacks
647 | else:
648 | callbacks_cfg = OmegaConf.create()
649 |
650 | if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
651 | print(
652 |                 'Caution: saving checkpoints every n train steps without deleting older ones. This can consume significant disk space.')
653 | default_metrics_over_trainsteps_ckpt_dict = {
654 | 'metrics_over_trainsteps_checkpoint':
655 | {"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
656 | 'params': {
657 | "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
658 | "filename": "{epoch:06}-{step:09}",
659 | "verbose": True,
660 | 'save_top_k': -1,
661 | 'every_n_train_steps': 10000,
662 | 'save_weights_only': True
663 | }
664 | }
665 | }
666 | default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
667 |
668 | callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
669 | if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
670 | callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
671 | elif 'ignore_keys_callback' in callbacks_cfg:
672 | del callbacks_cfg['ignore_keys_callback']
673 |
674 | trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
675 |
676 | trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
677 | trainer.logdir = logdir ###
678 |
679 | # data
680 | data = instantiate_from_config(config.data)
681 | # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
682 | # calling these ourselves should not be necessary but it is.
683 | # lightning still takes care of proper multiprocessing though
684 | data.prepare_data()
685 | data.setup()
686 | print("#### Data #####")
687 | for k in data.datasets:
688 | print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
689 |
690 | # configure learning rate
691 | bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
692 | if not cpu:
693 | ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
694 | else:
695 | ngpu = 1
696 | accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
697 | print(f"accumulate_grad_batches = {accumulate_grad_batches}")
698 | lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
699 | if opt.scale_lr:
700 | model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
701 | print(
702 | "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
703 | model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
704 | else:
705 | model.learning_rate = base_lr
706 | print("++++ NOT USING LR SCALING ++++")
707 | print(f"Setting learning rate to {model.learning_rate:.2e}")
708 |
709 |
710 | # allow checkpointing via USR1
711 | def melk(*args, **kwargs):
712 | # run all checkpoint hooks
713 | if trainer.global_rank == 0:
714 | print("Summoning checkpoint.")
715 | ckpt_path = os.path.join(ckptdir, "last.ckpt")
716 | trainer.save_checkpoint(ckpt_path)
717 |
718 |
719 | def divein(*args, **kwargs):
720 | if trainer.global_rank == 0:
721 |             import pudb
722 | pudb.set_trace()
723 |
724 |
725 | import signal
726 |
727 | signal.signal(signal.SIGUSR1, melk)
728 | signal.signal(signal.SIGUSR2, divein)
729 |
730 | # run
731 | if opt.train:
732 | try:
733 | trainer.fit(model, data)
734 | except Exception:
735 | melk()
736 | raise
737 | if not opt.no_test and not trainer.interrupted:
738 | trainer.test(model, data)
739 | except Exception:
740 | if opt.debug and trainer.global_rank == 0:
741 | try:
742 | import pudb as debugger
743 | except ImportError:
744 | import pdb as debugger
745 | debugger.post_mortem()
746 | raise
747 | finally:
748 | # move newly created debug project to debug_runs
749 | if opt.debug and not opt.resume and trainer.global_rank == 0:
750 | dst, name = os.path.split(logdir)
751 | dst = os.path.join(dst, "debug_runs", name)
752 | os.makedirs(os.path.split(dst)[0], exist_ok=True)
753 | os.rename(logdir, dst)
754 | if trainer.global_rank == 0:
755 | print(trainer.profiler.summary())
756 |
--------------------------------------------------------------------------------
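The comment block at the top of `__main__` describes the nested `target`/`params` config layout. As a small illustration (the target here is an arbitrary torch class rather than one of the project configs under `configs/`), `instantiate_from_config` turns such a node into an object:

```
from omegaconf import OmegaConf
from main import instantiate_from_config

# A config node is just {"target": "<import path>", "params": {...}};
# instantiate_from_config imports the target and calls it with the params.
cfg = OmegaConf.create({
    "target": "torch.nn.Linear",
    "params": {"in_features": 8, "out_features": 4},
})
layer = instantiate_from_config(cfg)   # -> nn.Linear(in_features=8, out_features=4)
```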
/scripts/inpaint_imagebart.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import numpy as np
4 | from omegaconf import OmegaConf
5 | import streamlit as st
6 | from streamlit import caching
7 | from PIL import Image
8 | from main import instantiate_from_config
9 | from torch.utils.data.dataloader import default_collate
10 | from torchvision.utils import make_grid
11 | from tqdm import trange
12 | from einops import repeat
13 | from contextlib import contextmanager
14 |
15 |
16 | from scripts.sample_imagebart import get_top_k_schedule, get_temperature_schedule
17 |
18 |
19 | @torch.no_grad()
20 | def sample_unconditional(models, batch_size, chain_schedule, temperature_schedule=None, top_k_schedule=None,
21 | dim_z=256, h=16, w=16,
22 | example=None, mask=None):
23 | ########### opts
24 | use_ema = True
25 | device = torch.device("cuda")
26 | model_index = len(models) - 1
27 | steps = len(models) * [h * w]
28 | #################
29 |
30 | # mask and input
31 | mask = mask.to(device=device)
32 | unmasked_input = torch.tensor(example["image"]).permute(0, 3, 1, 2).to(device=device)
33 | masked_inputs = mask * unmasked_input
34 |
35 | ##### clean for guidance
36 | with on_gpu(models[0]):
37 | pre_quant = models[0].get_h(unmasked_input)
38 | _, _, x_clean, _, _, _, _ = models[0].get_scales(pre_quant=pre_quant)
39 |
40 | model = models[model_index]
41 | c = torch.randint(0,
42 | model.first_stage_model.quantize.re_embed,
43 | x_clean.shape,
44 | device=x_clean.device)
45 |
46 | mask = torch.nn.functional.interpolate(mask, size=(16, 16),
47 | mode="nearest")
48 | mask = mask.reshape(c.shape).to(device=c.device)
49 | orig_dtype = c.dtype
50 | c = (1 - mask) * c + mask * x_clean
51 | c = c.to(dtype=orig_dtype)
52 |
53 | guide = torch.nn.functional.one_hot(c, num_classes=model.first_stage_model.quantize.re_embed).to(
54 | dtype=torch.float32)
55 | guide = torch.log(guide)  # log of a one-hot: 0 at the kept code index, -inf elsewhere
56 | guide[mask < 0.5] = 0  # no guidance inside the region that is to be regenerated
57 |
58 | # start sampling
59 | c_scale_indices = c
60 | scale = model_index
61 | current_scale = (scale * torch.ones(batch_size, 1)).to(device).long()
62 | steppys = st.empty()
63 | cb = lambda x: steppys.write(f"{x}/{h * w}")
64 | scaleinfo = st.empty()
65 | n_scales = len(models)
66 | for scale_n, model in enumerate(models[:model_index + 1][::-1]):
67 | temperature = temperature_schedule[scale]
68 | top_k = top_k_schedule[scale]
69 | n_chains = chain_schedule[scale]
70 | with on_gpu(model):
71 | with ema_scope(model, active=use_ema):
72 | for chain_idx in range(n_chains):
73 | scaleinfo.write(
74 | f"sampling chain {chain_idx + 1}/{n_chains} for scale {n_scales - scale_n}/{n_scales}, "
75 | f"temp = {temperature:.2f}, top-k = {top_k}")
76 |
77 | chain_weight = 1 - chain_idx / n_chains
78 | if chain_idx > 0:
79 | # already reversed, run forward again
80 | origdtype = c_scale_indices.dtype
81 | randindices = torch.randint(
82 | 0,
83 | model.first_stage_model.quantize.re_embed,
84 | c_scale_indices.shape,
85 | device=c_scale_indices.device)
86 | redraw_prob = chain_weight * model.temperature_range[model.single_scale]
87 | redraw = torch.bernoulli(
88 | redraw_prob * torch.ones(c_scale_indices.shape)).to(
89 | device=c_scale_indices.device)
90 | c_scale_indices = (1 - redraw) * c_scale_indices + redraw * randindices
91 | c_scale_indices = c_scale_indices.to(dtype=origdtype)
92 | c_scale_indices = model.sample_single_scale(c_scale_indices,
93 | current_scale + 1,
94 | temp_x=None,
95 | steps=steps[scale],
96 | temperature=temperature,
97 | top_k=top_k,
98 | guide=guide,
99 | callback=cb
100 | )
101 | scale -= 1
102 | current_scale = (scale * torch.ones(batch_size, 1)).to(device).long()
103 |
104 | qzshape = [batch_size, dim_z, h, w]
105 | with on_gpu(model):
106 | sample = model.decode_to_img(c_scale_indices, qzshape)
107 |
108 | log = dict()
109 | log["samples"] = sample
110 | log["inputs"] = unmasked_input
111 | log["masked_inputs"] = masked_inputs
112 | return log
113 |
114 |
115 | def generate_mask(masking_option, shape):
116 | bs, h, w, c = shape
117 | mask = np.array(Image.new('L', (h, w))).astype(bool)  # np.bool is deprecated/removed in recent NumPy; the builtin bool is equivalent
118 |
119 | if masking_option == 'upper-half completion':
120 | mask[h // 2:] = np.logical_not(mask[h // 2:])
121 |
122 | elif masking_option == 'window-inpainting':
123 | window_size = st.number_input(f'Select size of the square window for {masking_option} '
124 | f'(note: divided by 16 in latent space)', min_value=10,
125 | max_value=h // 2, value=h // 4)
126 | mask = np.logical_not(mask)
127 | mask[
128 | (h - window_size) // 2:(h + window_size) // 2,
129 | (w - window_size) // 2:(w + window_size) // 2] = np.logical_not(
130 | mask[(h - window_size) // 2:(h + window_size) // 2,
131 | (w - window_size) // 2:(w + window_size) // 2])
132 |
133 | else:
134 | window_size = st.number_input(f'Select size of the square window for {masking_option} '
135 | f'(note: divided by 16 in latent space)', min_value=h // 2,
136 | max_value=h - 20, value=h // 2)
137 | mask[
138 | (h - window_size) // 2:(h + window_size) // 2,
139 | (w - window_size) // 2:(w + window_size) // 2] = np.logical_not(
140 | mask[(h - window_size) // 2:(h + window_size) // 2,
141 | (w - window_size) // 2:(w + window_size) // 2])
142 | st.warning('With outpainting enabled, you might have to increase the length of the chains')
143 |
144 | # draw a border for display purposes only (inpainting and upper-half completion);
145 | display_mask = mask.copy()  # copy so the border does not end up in the mask used for sampling
146 | for p in [0, h - 1]:
147 | display_mask[p] = False
148 | display_mask[:, p] = False
149 | st.image((255 * display_mask.astype(np.uint8)), f'Selected mask for {masking_option}')
150 |
151 | mask = torch.from_numpy(mask.astype(np.float32)).float()
152 | mask = repeat(mask, 'h w -> b 1 h w', b=batch_size)  # batch_size is the module-level Streamlit input defined in __main__
153 | return mask
154 |
155 |
156 | @torch.no_grad()
157 | def run(models, dset, batch_size, temperature, top_k, chain_schedule, num_runs):
158 | img_spatial = models[0].first_stage_model.encoder.resolution
159 | img_shape = [batch_size, img_spatial, img_spatial, 3]
160 |
161 | masking_option = st.selectbox('Select masking option',
162 | ['upper-half completion', 'window-inpainting', 'window-outpainting'],
163 | index=0)
164 |
165 | mask = generate_mask(masking_option, img_shape)
166 |
167 | if st.button('Sample with chain'):
168 |
169 | for n in trange(num_runs, desc="Data"):
170 | indices = np.random.choice(len(dset), batch_size, replace=False)
171 | example = default_collate([dset[i] for i in indices])
172 | logs = sample_unconditional(models, batch_size=batch_size,
173 | temperature_schedule=temperature, top_k_schedule=top_k,
174 | example=example, mask=mask, chain_schedule=chain_schedule)
175 |
176 | log_to_st(logs, n)
177 |
178 |
179 | def log_to_st(log, n):
180 | keys = ['inputs', 'masked_inputs', 'samples']
181 | bs = log[keys[0]].shape[0]
182 |
183 | flatgrid = torch.cat([torch.clamp(log[k].detach().cpu(), -1., 1.) for k in keys], dim=0)
184 | grid = make_grid(flatgrid, nrow=bs, normalize=True).permute(1, 2, 0).numpy()
185 |
186 | st.image(grid, f'Masked samples #{n + 1} (top: original, mid: masked input, bottom: sample)')
187 |
188 |
189 | @contextmanager
190 | def ema_scope(model, active=False, context=None):
191 | if active:
192 | model.transformer_ema.store(model.transformer.parameters())
193 | model.transformer_ema.copy_to(model.transformer)
194 | if context is not None:
195 | print(f"{context}: Switched to EMA weights")
196 | try:
197 | yield None
198 | finally:
199 | if active:
200 | model.transformer_ema.restore(model.transformer.parameters())
201 | if context is not None:
202 | print(f"{context}: Restored training weights")
203 |
204 |
205 | @contextmanager
206 | def on_gpu(model, context=None):
207 | model = model.cuda()
208 | if context is not None:
209 | print(f"{context}: Moved model to GPU")
210 | try:
211 | yield None
212 | finally:
213 | model = model.cpu()
214 | torch.cuda.empty_cache()
215 | if context is not None:
216 | print(f"{context}: Moved model to CPU")
217 |
218 |
219 | def load_model_from_config(config, sd, gpu=True, eval_mode=True):
220 | print("config:")
221 | print(OmegaConf.to_yaml(config))
222 | model = instantiate_from_config(config["model"])
223 | if sd is not None:
224 | m, u = model.load_state_dict(sd, strict=False)
225 | if gpu:
226 | model.cuda()
227 | if eval_mode:
228 | model.eval()
229 | return {"model": model}
230 |
231 |
232 | @st.cache(allow_output_mutation=True)
233 | def get_data(config):
234 | # get data
235 | try:
236 | if config.data.params.train.target == "braket.data.faceshq.FFHQTrain":
237 | config.data.params.train.params.random_flip = False
238 | print("Disabled random flip for FFHQ train")
239 | except Exception:
240 | pass
241 | data = instantiate_from_config(config.data)
242 | data.prepare_data()
243 | data.setup()
244 | return data
245 |
246 |
247 | def get_config(path):
248 | config = OmegaConf.load(path)
249 | return config
250 |
251 |
252 | @st.cache(allow_output_mutation=True)
253 | def load_models(paths, gpu=False, eval_mode=True):
254 | assert not gpu, 'moving them later'
255 | models = list()
256 | configs = list()
257 | global_steps = list()
258 | for ckpt_path, config_path in zip(paths["checkpoints"], paths["configs"]):
259 | print(f"loading config from {config_path} and model from {ckpt_path}")
260 | config = get_config(config_path)
261 | pl_sd = torch.load(ckpt_path, map_location="cpu")
262 | global_step = pl_sd["global_step"]
263 | model = load_model_from_config(config, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
264 | models.append(model)
265 | configs.append(config)
266 | global_steps.append(global_step)
267 | print(f"loaded model from global step {global_step}")
268 | if "repeat" in paths:
269 | n_models = list()
270 | n_configs = list()
271 | n_global_steps = list()
272 | for i, n in enumerate(paths["repeat"]):
273 | print(f"Repeating model {i} {n}x.")
274 | n_models += n * [models[i]]
275 | n_configs += n * [configs[i]]
276 | n_global_steps += n * [global_steps[i]]
277 | models = n_models
278 | configs = n_configs
279 | global_steps = n_global_steps
280 | return models, configs, global_steps
281 |
282 |
283 | if __name__ == "__main__":
284 | sys.path.append(os.getcwd())
285 | if not st._is_running_with_streamlit:
286 | print("Not running with streamlit. Redefining st functions...")
287 | st.info = print
288 | st.write = print
289 |
290 | yaml_path = sys.argv[1]
291 | paths = OmegaConf.load(yaml_path)
292 | paths = OmegaConf.to_container(paths)
293 |
294 | gpu = True
295 | eval_mode = True
296 |
297 | models, configs, global_steps = load_models(paths, gpu=False, eval_mode=eval_mode)
298 | if models[0].conditioner is not None:
299 | raise NotImplementedError('Currently only available for unconditional models.')
300 | device = torch.device("cuda") if gpu else torch.device("cpu")
301 | dsets = get_data(configs[0])
302 |
303 | split = "validation"
304 | dset = dsets.datasets[split]
305 | print(f"Dataset size: {len(dset)}")
306 |
307 | codebook_size = models[0].first_stage_model.quantize.re_embed
308 |
309 | st.sidebar.write('Sampling options')
310 | n_runs = st.sidebar.number_input('Number of runs', 1, 100, 1)
311 | batch_size = st.sidebar.number_input('Batch size', 1, 20, 4)
312 |
313 | top_k = get_top_k_schedule(len(models), codebook_size=codebook_size)
314 | temperature = get_temperature_schedule(len(models))
315 |
316 | chain_schedule = []
317 | st.write('Define chain schedule')
318 | st.info(
319 | 'The n-th entry in the chain schedule defines the number of successive runs '
320 | 'the n-th AR submodel should perform before passing its output to the next submodel.')
321 | for n in range(len(models)):
322 | if models[n].redraw_prob != 'geometric':
323 | if n == len(models) - 1:
324 | def_chain_len = 1
325 | else:
326 | def_chain_len = 5
327 | else:
328 | def_chain_len = 3
329 | chain_n = st.number_input(f"Chain length for scale #{n + 1}", min_value=1, max_value=100, value=def_chain_len)
330 | chain_schedule.append(chain_n)
331 |
332 | chain_schedule = chain_schedule
333 |
334 | run(models, dset, batch_size, temperature=temperature, top_k=top_k, chain_schedule=chain_schedule, num_runs=n_runs)
335 |
--------------------------------------------------------------------------------
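Both Streamlit scripts (`scripts/inpaint_imagebart.py` above and `scripts/sample_imagebart.py` below) take a single YAML file as `sys.argv[1]`; `load_models()` zips its `checkpoints` and `configs` lists pairwise, the inpainting script additionally honors an optional `repeat` list (include `repeat[i]` copies of the i-th model in the chain), and the sampling script reads an optional `metrics.savepath` that overrides the default `logs/<yaml basename>` log path. A minimal sketch of such a file, assuming a 2-scale model; the paths are placeholders (the shipped files live under `configs/sampling/`), and the list order must match the scale chain, which the sampling loop checks via `current_scale + 1 == model.single_scale`:

```
configs:
  - path/to/scale1-config.yaml
  - path/to/scale2-config.yaml
checkpoints:
  - path/to/scale1/last.ckpt
  - path/to/scale2/last.ckpt
# optional, only read by inpaint_imagebart.py
repeat:
  - 1
  - 1
# optional, only read by sample_imagebart.py
metrics:
  savepath: logs/my_sampling_run
```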
/scripts/sample_imagebart.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import hashlib
3 | import torch
4 | import numpy as np
5 | from omegaconf import OmegaConf
6 | import streamlit as st
7 | from PIL import Image
8 | from main import instantiate_from_config
9 | from torchvision.utils import make_grid
10 | from contextlib import contextmanager
11 |
12 | rescale = lambda x: (x + 1.) / 2.
13 |
14 |
15 | def bchw_to_st(x):
16 | return rescale(x.detach().cpu().numpy().transpose(0,2,3,1))
17 |
18 | def chw_to_st(x):
19 | return rescale(x.detach().cpu().numpy().transpose(1,2,0))
20 |
21 | def chw_to_pillow(x):
22 | return Image.fromarray(chw_to_np(x))
23 |
24 | def chw_to_np(x):
25 | return (255 * rescale(x.detach().cpu().numpy().transpose(1, 2, 0))).clip(0, 255).astype(np.uint8)
26 |
27 |
28 | def computeMD5hash(string):
29 | m = hashlib.md5()
30 | m.update(string.encode('utf-8'))
31 | return m.hexdigest()
32 |
33 |
34 | class L1(torch.nn.Module):
35 | def __init__(self):
36 | super().__init__()
37 |
38 | def forward(self, x, y):
39 | return torch.abs(x-y).sum(dim=[1,2,3]).mean()
40 |
41 |
42 |
43 | @torch.no_grad()
44 | def custom_log_images(model, batch, temperature):
45 | log = dict()
46 | x = model.get_input(batch, model.image_key)
47 | x = x.to(model.device)
48 | # encode
49 | h = model.encoder(x)
50 | h = model.quant_conv(h)
51 | quant, _, _ = model.quantize(h, temp=temperature, rescale_logits=True)
52 | # decode
53 | x_rec = model.decode(quant)
54 | log["inputs"] = x
55 | log["reconstructions"] = x_rec
56 | return log
57 |
58 |
59 | def grid2img(grid):
60 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
61 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
62 | grid = grid.numpy()
63 | grid = (grid * 255).astype(np.uint8)
64 | print(f"grid.max/grid.min {grid.max()}/{grid.min()}")
65 | grid = Image.fromarray(grid)
66 | return grid
67 |
68 |
69 | def get_top_k_schedule(n_steps, codebook_size):
70 | tk_schedule = st.radio("top-k scheduling", ["constant", "linear", "user"], 2)
71 | if tk_schedule == "constant":
72 | tk = st.number_input("Constant Top-K Value", value=codebook_size)
73 | top_k_schedule = (np.ones(n_steps) * tk).astype(int)
74 | elif tk_schedule == "linear":
75 | tk_start = st.number_input("Start Top-K Value", value=codebook_size)
76 | tk_end = st.number_input("End Top-K Value", value=codebook_size//4)
77 | top_k_schedule = np.linspace(tk_start, tk_end, n_steps).astype(int)
78 | elif tk_schedule == "user":
79 | default = f"{codebook_size}," * n_steps
80 | tk_list = st.text_input(f"Top-K Values (comma separated, must be {n_steps} values, counted in sampling order, i.e. from last to first scale)", f"{default[:-1]}")
81 | tks = tk_list.split(",")
82 | top_k_schedule = list()
83 | for tk in tks:
84 | top_k_schedule.append(int(tk))
85 | assert len(top_k_schedule) == n_steps
86 | else:
87 | return None
88 | return top_k_schedule
89 |
90 | def get_temperature_schedule(n_steps):
91 | type = st.radio("temperature scheduling", ["constant", "linear", "user"], 2)
92 | if type == "constant":
93 | tk = st.number_input("Constant Temperature Value", value=1.0)
94 | schedule = (np.ones(n_steps) * tk)
95 | elif type == "linear":
96 | tk_start = st.number_input("Start Temperature Value", value=1.)
97 | tk_end = st.number_input("End Temperature Value", value=0.1)
98 | schedule = np.linspace(tk_start, tk_end, n_steps)
99 | elif type == "user":
100 | default = f"{1.0}," * n_steps
101 | tk_list = st.text_input(f"Temperature Values (comma separated, must be {n_steps} values, counted in sampling order, i.e. from last to first scale)", f"{default[:-1]}")
102 | tks = tk_list.split(",")
103 | schedule = list()
104 | for tk in tks:
105 | schedule.append(float(tk))
106 | assert len(schedule) == n_steps
107 | else:
108 | return None
109 | return schedule
110 |
111 |
112 | @contextmanager
113 | def ema_scope(model, active=False, context=None):
114 | if active:
115 | model.transformer_ema.store(model.transformer.parameters())
116 | model.transformer_ema.copy_to(model.transformer)
117 | if context is not None:
118 | print(f"{context}: Switched to EMA weights")
119 | try:
120 | yield None
121 | finally:
122 | if active:
123 | model.transformer_ema.restore(model.transformer.parameters())
124 | if context is not None:
125 | print(f"{context}: Restored training weights")
126 |
127 |
128 | @contextmanager
129 | def on_gpu(model, context=None):
130 | model = model.cuda()
131 | if context is not None:
132 | print(f"{context}: Moved model to GPU")
133 | try:
134 | yield None
135 | finally:
136 | model = model.cpu()
137 | torch.cuda.empty_cache()
138 | if context is not None:
139 | print(f"{context}: Moved model to CPU")
140 |
141 |
142 | @torch.no_grad()
143 | def run(models, user_conditioning, batch_size, device=torch.device("cuda"), conditional=False):
144 | assert type(models) == list
145 |
146 | n_scales = len(models)
147 | codebook_size = models[0].first_stage_model.quantize.re_embed
148 |
149 | cond = None
150 | start_index = len(models) - 1
151 | model = models[start_index]
152 |
153 |
154 | n_downs = model.first_stage_model.encoder.num_resolutions - 1
155 | h = model.first_stage_model.encoder.resolution // (2**n_downs)
156 | w = model.first_stage_model.encoder.resolution // (2**n_downs)
157 | dim_z = model.first_stage_model.embed_dim
158 |
159 |
160 | index_shape = [batch_size, h * w]
161 | qzshape = [batch_size, dim_z, h, w]
162 |
163 | st.info(f'Latent shape is {qzshape}')
164 | if user_conditioning is not None:
165 | exmpl = {model.conditioner.key: user_conditioning}
166 | cond = model.get_conditioning(exmpl).to(device)
167 | st.sidebar.write(f"cond.shape: {cond.shape}")
168 |
169 |
170 | top_k_schedule = get_top_k_schedule(n_scales, codebook_size=codebook_size)[::-1]
171 | temperature_schedule = get_temperature_schedule(n_scales)[::-1]
172 | st.text(f"top-k schedule: {top_k_schedule}")
173 | st.text(f"temperature schedule: {temperature_schedule}")
174 |
175 | n_batches = st.number_input("number of runs", value=1, min_value=1, max_value=1000)
176 | steps = n_scales * [h*w]
177 |
178 |
179 | if st.button("Sample", False):
180 | grid_ph = st.empty()
181 | final_samples = list()
182 | for n in range(n_batches):
183 |
184 | scaleinfo = st.empty()
185 | scale = start_index
186 | c_scale_indices = torch.randint(0,
187 | model.first_stage_model.quantize.re_embed,
188 | index_shape,
189 | device=model.device)
190 | current_scale = (scale * torch.ones(batch_size, 1)).to(c_scale_indices).to(device)
191 |
192 | steppys = st.empty()
193 | cb = lambda x: steppys.write(f"{x}/{h*w}")
194 | for model_count, model in enumerate(models[::-1]):
195 | with on_gpu(model, "Sampling"):
196 | temperature = temperature_schedule[scale]
197 | top_k = top_k_schedule[scale]
198 | scaleinfo.write(f"sampling scale {scale+1}/{n_scales}, temp = {temperature:.2f}, top-k = {top_k}")
199 |
200 | with ema_scope(model, active=True, context="Sampling"):
201 |
202 | assert (current_scale + 1)[0].item() == model.single_scale, \
203 | f"{(current_scale + 1)[0].item()} =/= {model.single_scale} :("
204 | c_scale_indices = model.sample_single_scale(c_scale_indices.to(device),
205 | (current_scale+1).to(device),
206 | temp_x=None,
207 | steps=steps[scale],
208 | temperature=temperature,
209 | top_k=top_k,
210 | cond=cond.to(model.device) if cond is not None else None,
211 | callback=cb
212 | )
213 |
214 | if model_count == len(models) - 1:
215 | final_samples.append(model.decode_to_img(c_scale_indices,
216 | [batch_size, qzshape[1], qzshape[2], qzshape[3]]))
217 | scale -= 1
218 | current_scale = (scale * torch.ones(batch_size, 1)).to(device).long()
219 |
220 | intermediate_grid = make_grid(final_samples[-1], nrow=batch_size, padding=0)
221 | st.image(chw_to_st(intermediate_grid), clamp=True, output_format='PNG')
222 |
223 | final_samples = torch.cat(final_samples, 0)
224 | grid = make_grid(final_samples, nrow=batch_size, padding=0)
225 | grid_ph.image(chw_to_st(grid), clamp=True, output_format="PNG")
226 |
227 | @torch.no_grad()
228 | def render_as_grid(scale_samples, batch_size, stack=True):
229 | # make a grid
230 | if stack:
231 | scale_samples = torch.stack(scale_samples, dim=0)
232 | assert batch_size == scale_samples.shape[1]
233 | grids = []
234 | for i in range(scale_samples.shape[1]):
235 | grid = make_grid(scale_samples[:, i, ...], nrow=scale_samples.shape[0])
236 | grids.append(grid)
237 |
238 | for k in range(len(grids)):
239 | st.image(chw_to_st(grids[k]), clamp=True, output_format="PNG", use_column_width=grids[k].shape[2] < 500)
240 |
241 |
242 | def load_model_from_config(config, sd, gpu=True, eval_mode=True):
243 | print("config:")
244 | print(OmegaConf.to_yaml(config))
245 | model = instantiate_from_config(config["model"])
246 | if sd is not None:
247 | m, u = model.load_state_dict(sd, strict=False)
248 | if gpu:
249 | model.cuda()
250 | if eval_mode:
251 | model.eval()
252 | return {"model": model}
253 |
254 | @st.cache(allow_output_mutation=True)
255 | def get_data(config):
256 | # get data
257 | data = instantiate_from_config(config.data)
258 | data.prepare_data()
259 | data.setup()
260 | return data
261 |
262 |
263 | def get_config(path):
264 | config = OmegaConf.load(path)
265 | return config
266 |
267 |
268 | @st.cache(allow_output_mutation=True)
269 | def load_models(paths, gpu=False, eval_mode=True):
270 | assert not gpu, 'moving them later'
271 | models = list()
272 | configs = list()
273 | global_steps = list()
274 |
275 | for ckpt_path, config_path in zip(paths["checkpoints"], paths["configs"]):
276 | print(f"loading config from {config_path} and model from {ckpt_path}")
277 | config = get_config(config_path)
278 | pl_sd = torch.load(ckpt_path, map_location="cpu")
279 | global_step = pl_sd["global_step"]
280 | model = load_model_from_config(config, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
281 | models.append(model)
282 | configs.append(config)
283 | global_steps.append(global_step)
284 |
285 | print(f"loaded model from global step {global_step}")
286 |
287 | return models, configs, global_steps
288 |
289 |
290 | if __name__ == "__main__":
291 |
292 | yaml_path = sys.argv[1]
293 | log_path = yaml_path.split(os.sep)[-1][:-5]
294 | paths = OmegaConf.load(yaml_path)
295 | print(OmegaConf.to_yaml(paths))
296 | paths = OmegaConf.to_container(paths)
297 |
298 | log_path = paths["metrics"]["savepath"] if 'metrics' in paths and 'savepath' in paths['metrics'] else os.path.join("logs", log_path)
299 | print(f"loading from .yaml at {yaml_path}")
300 |
301 |
302 | print(f'logging to {log_path}')
303 |
304 | st.sidebar.text("Options")
305 | gpu = st.sidebar.checkbox("GPU", value=True)
306 |
307 | models, configs, global_steps = load_models(paths, gpu=False)
308 | device = torch.device("cuda") if gpu else torch.device("cpu")
309 |
310 | # dsets = get_data(configs[0])
311 | step_info = ""
312 | st.write(step_info[:-2])
313 |
314 | batch_size = st.number_input("Batch size", min_value=1, value=4)
315 | conditional = models[0].conditioner is not None
316 |
317 |
318 | user_conditioning = None
319 | if conditional:
320 | st.info("Detected a conditional model.")
321 | user_inputs = []
322 | conditioner_key = models[0].conditioner.key
323 |
324 | if conditioner_key == "caption":
325 | for n in range(batch_size):
326 | user_input = st.text_input(f"user caption {n}", value=f"Example caption {n}")
327 | user_inputs.append(user_input)
328 |
329 | #example["caption"] = [user_inputs]
330 | user_conditioning = user_inputs
331 |
332 | st.write(f"Selected text-prompts are {user_conditioning}")
333 | elif conditioner_key == "class_label":
334 |
335 | cfd = os.path.dirname(os.path.abspath(__file__))
336 | integer2human = OmegaConf.load(os.path.join(cfd,'../data/imagenet_ids2labels.yaml'))
337 |
338 | format_fn = lambda x: integer2human[x]
339 | for n in range(batch_size):
340 | user_input = st.selectbox(f"user class label {n}", index=144,
341 | options=list(integer2human.keys()),
342 | format_func=format_fn)
343 | user_inputs.append(int(user_input))
344 |
345 | user_conditioning = torch.tensor(user_inputs)
346 |
347 | human_labels = [integer2human[str(l)] for l in user_inputs]
348 | st.write(f"Selected class labels are {human_labels}")
349 | else:
350 | raise NotImplementedError(f"Model with conditioner key {conditioner_key} not yet implemented.")
351 |
352 | run(models, user_conditioning, batch_size, device=device, conditional=conditional)
353 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='imagebart',
5 | version='0.0.1',
6 | description='autoregressive image modification via multinomial diffusion',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
14 |
--------------------------------------------------------------------------------