├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── ldm
│   │   ├── coco_sg2im_ldm_Layout2I_vqgan_f8.yaml
│   │   └── coco_stuff_ldm_T2I_vqgan_f8.yaml
│   └── vqgan
│       └── coco_vqgan_f8.yaml
├── environment.yaml
├── ldm
│   ├── data
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── imagenet.py
│   │   └── lsun.py
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── autoencoder.py
│   │   └── diffusion
│   │       ├── __init__.py
│   │       ├── classifier.py
│   │       ├── ddim.py
│   │       └── ddpm.py
│   ├── modules
│   │   ├── attention.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   └── util.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   └── modules.py
│   │   ├── image_degradation
│   │   │   ├── __init__.py
│   │   │   ├── bsrgan.py
│   │   │   ├── bsrgan_light.py
│   │   │   └── utils_image.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   └── x_transformer.py
│   └── util.py
├── main.py
├── setup.py
├── taming
│   ├── data
│   │   ├── annotated_objects_coco.py
│   │   ├── annotated_objects_dataset.py
│   │   ├── base.py
│   │   ├── coco.py
│   │   ├── conditional_builder
│   │   │   ├── objects_bbox.py
│   │   │   ├── objects_center_points.py
│   │   │   └── utils.py
│   │   ├── custom.py
│   │   ├── helper_types.py
│   │   ├── image_transforms.py
│   │   ├── open_images_helper.py
│   │   ├── sflckr.py
│   │   └── utils.py
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── cond_transformer.py
│   │   ├── dummy_cond_stage.py
│   │   └── vqgan.py
│   ├── modules
│   │   ├── diffusionmodules
│   │   │   └── model.py
│   │   ├── discriminator
│   │   │   └── model.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── lpips.py
│   │   │   ├── segmentation.py
│   │   │   ├── soft_cross_entropy.py
│   │   │   └── vqperceptual.py
│   │   ├── misc
│   │   │   ├── coord.py
│   │   │   └── pos_embed.py
│   │   ├── transformer
│   │   │   ├── mingpt.py
│   │   │   └── permuter.py
│   │   ├── util.py
│   │   └── vqvae
│   │       ├── mapping.py
│   │       └── quantize.py
│   └── util.py
└── tools
    ├── download_datasets.sh
    ├── download_models.sh
    ├── ldm
    │   ├── train_ldm_coco_Layout2I.sh
    │   └── train_ldm_coco_T2I.sh
    └── vqgan
        └── train_vqgan_coco.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # amlt
2 | .amltconfig
3 |
4 | # original
5 | *.swp
6 | *.pt
7 | *.pth
8 | *.ckpt
9 | *.txt
10 | *tfevents*
11 | *.json
12 | *.png
13 | **/__pycache__/**
14 | .dumbo.json
15 | checkpoints/
16 | checkpoint/
17 | .idea/*
18 | **/.ipynb_checkpoints/**
19 | run.sh
20 | output/
21 | output_dir/
22 | output_aml/
23 | weights/
24 | logs/*
25 | exp/*
26 | src/*
27 |
28 | ## Deep speed
29 | *.pyc
30 | *.idea/
31 | *~
32 | *.swp
33 | *.log
34 | *deepspeed/git_version_info_installed.py
35 |
36 | # Build + installation data
37 | *build/
38 | *dist/
39 | *.so
40 | *deepspeed.egg-info/
41 | *build.txt
42 |
43 | # Website
44 | *docs/_site/
45 | *docs/build
46 | *docs/code-docs/source/_build
47 | *docs/code-docs/_build
48 | *docs/code-docs/build
49 | *.sass-cache/
50 | *.jekyll-cache/
51 | *.jekyll-metadata
52 |
53 | # Testing data
54 | *tests/unit/saved_checkpoint/
55 |
56 | # Dev/IDE data
57 | *.vscode
58 | *.theia
59 |
60 | *cache*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 ChirsFan0312
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Latent Diffusion Model (LDM) for Layout-to-image generation
2 |
3 | ---
4 | This is an unofficial repository of [LDM](https://arxiv.org/abs/2112.10752) for layout-to-image generation. The config and code for this task in the official LDM repo are currently incomplete, so this repo aims to reproduce LDM on the layout-to-image generation task. If you find it useful, please cite the original [LDM](https://arxiv.org/abs/2112.10752) paper.
5 |
6 | ---
7 | ## Machine environment
8 | - Ubuntu version: 18.04.5 LTS
9 | - CUDA version: 11.6
10 | - Testing GPU: Nvidia Tesla V100
11 | ---
12 |
13 | ## Requirements
14 | A [conda](https://conda.io/) environment named `ldm_layout` can be created and activated with:
15 |
16 | ```bash
17 | conda env create -f environment.yaml
18 | conda activate ldm_layout
19 | ```
20 | ---
21 |
22 | ## Datasets setup
23 | We provide two approaches to set up the datasets:
24 | ### Auto-download
25 | To automatically download the datasets and save them into the default path (`../`), please use the following script:
26 | ```bash
27 | bash tools/download_datasets.sh
28 | ```
29 | ### Manual setup
30 |
31 | #### Text-to-image generation
32 | - We use the COCO 2014 splits for the text-to-image task, which can be downloaded from the [official COCO website](https://cocodataset.org/#download).
33 | - Please create a folder named `2014` and collect the downloaded data and annotations as follows (a command-line sketch is given after the file structure).
34 |
35 | COCO 2014 file structure
36 |
37 | ```
38 | >2014
39 | ├── annotations
40 | │ └── captions_val2014.json
41 | │ └── ...
42 | └── val2014
43 | └── COCO_val2014_000000000073.jpg
44 | └── ...
45 | ```
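
If you prefer to set COCO 2014 up from the command line, here is a minimal sketch, assuming `wget` and `unzip` are available and using the official COCO download URLs; the target path matches the `data_path` used in the provided configs, so adjust it if your layout differs:

```bash
# Manual COCO 2014 setup (sketch; adjust paths to your own layout)
mkdir -p ../datasets/coco/2014 && cd ../datasets/coco/2014
wget http://images.cocodataset.org/zips/train2014.zip
wget http://images.cocodataset.org/zips/val2014.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
unzip -q train2014.zip && unzip -q val2014.zip && unzip -q annotations_trainval2014.zip
rm -f train2014.zip val2014.zip annotations_trainval2014.zip
```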
46 |
47 |
48 |
49 |
50 | #### Layout-to-image generation
51 | - We use the COCO 2017 splits for the layout-to-image task, which can be downloaded from the [official COCO website](https://cocodataset.org/#download).
52 | - Please create a folder named `2017` and collect the downloaded data and annotations as follows (a command-line sketch is given after the file structure).
53 |
54 | COCO 2017 file structure
55 |
56 | ```
57 | >2017
58 | ├── annotations
59 | │ └── captions_val2017.json
60 | │ └── ...
61 | └── val2017
62 | └── 000000000872.jpg
63 | └── ...
64 | ```
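
COCO 2017 can be fetched the same way; a sketch under the same assumptions:

```bash
# Manual COCO 2017 setup (sketch; adjust paths to your own layout)
mkdir -p ../datasets/coco/2017 && cd ../datasets/coco/2017
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
unzip -q train2017.zip && unzip -q val2017.zip && unzip -q annotations_trainval2017.zip
rm -f train2017.zip val2017.zip annotations_trainval2017.zip
```

Note that the layout-to-image config also reads `annotations/deprecated-challenge2017/{train,val}-ids.txt`; if these id files are missing after extraction, refer to `tools/download_datasets.sh`.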
65 |
66 |
67 |
68 |
69 | #### File structure for dataset and code
70 | Please make sure that the file structure matches the following. Otherwise, modify the dataset paths in the config files to match your layout (see the config excerpt after the structure).
71 |
72 | File structure
73 |
74 | ```
75 | >datasets
76 | ├── coco
77 | │ └── 2014
78 | │ └── annotations
79 | │ └── val2014
80 | │ └── ...
81 | │ └── 2017
82 | │ └── annotations
83 | │ └── val2017
84 | │ └── ...
85 | >ldm_layout
86 | └── configs
87 | │ └── ldm
88 | │ └── ...
89 | └── exp
90 | │ └── ...
91 | └── ldm
92 | └── taming
93 | └── scripts
94 | └── tools
95 | └── ...
96 | ```
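
If the datasets live elsewhere, the entries to change are the path parameters inside the `data` sections of the configs. For reference, these are the defaults excerpted from the provided configs (no new settings are introduced here):

```yaml
# configs/ldm/coco_sg2im_ldm_Layout2I_vqgan_f8.yaml (layout-to-image, COCO 2017)
data_path: ../datasets/coco/2017
img_id_file: ../datasets/coco/2017/annotations/deprecated-challenge2017/train-ids.txt

# configs/ldm/coco_stuff_ldm_T2I_vqgan_f8.yaml (text-to-image, COCO 2014)
data_path: ../datasets/coco/2014
caption_ann_path: ../datasets/coco/2014/annotations/captions_train2014.json
```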
97 |
98 |
99 |
100 | ---
101 |
102 |
103 | ## VQGAN models setup
104 | We provide a script to download the VQGAN-f8 checkpoint released in the [LDM github](https://github.com/CompVis/latent-diffusion).
105 |
106 | To automatically download VQGAN-f8 and save it into the default path (`exp/`), please use the following script:
107 | ```bash
108 | bash tools/download_models.sh
109 | ```
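
If the script cannot be used in your environment, the checkpoint can also be placed manually at the `ckpt_path` expected by the configs (`exp/vqgan/vq-f8/model.ckpt`). A sketch using the link given in the config comment, assuming the share supports Seafile's `?dl=1` direct-download parameter:

```bash
# Manual VQGAN-f8 checkpoint download (sketch; URL taken from the config comment)
mkdir -p exp/vqgan/vq-f8
wget -O exp/vqgan/vq-f8/model.ckpt "https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/?dl=1"
```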
110 |
111 | ## Train LDM for layout-to-image generation
112 | We provide scripts for training LDM on both text-to-image and layout-to-image generation.
113 |
114 | Once the datasets are properly set up, LDM can be trained with the following commands.
115 | ### Text-to-image
116 | ```bash
117 | bash tools/ldm/train_ldm_coco_T2I.sh
118 | ```
119 | - The default output folder is `exp/ldm/T2I`
120 | ### Layout-to-image
121 |
122 | ```bash
123 | bash tools/ldm/train_ldm_coco_Layout2I.sh
124 | ```
125 | - The default output folder is `exp/ldm/Layout2I`
126 |
127 | ### Multi-GPU training
128 |
129 | Change `--gpus` to specify which GPUs to use for training.
130 |
131 | For example, using 4 GPUs:
132 | ```bash
133 |
134 | python main.py --base configs/ldm/coco_sg2im_ldm_Layout2I_vqgan_f8.yaml \
135 | -t True --gpus 0,1,2,3 -log_dir ./exp/ldm/Layout2I \
136 | -n coco_sg2im_ldm_Layout2I_vqgan_f8 --scale_lr False -tb True
137 | ```
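
A note on the learning rate: the configs set `base_learning_rate: 1.e-6` and comment that it is used directly as the target rate when `--scale_lr False` is passed, as in the command above. The sketch below illustrates the scaling convention used by the upstream LDM `main.py`; this fork's `main.py` is not reproduced in this listing, so treat the formula as an assumption and verify against it:

```python
# Effective learning rate under the upstream LDM --scale_lr convention (assumption)
accumulate_grad_batches, n_gpus, batch_size = 1, 4, 4  # 4 GPUs as in the example; batch_size from the config
base_lr = 1.0e-6                                       # base_learning_rate in the configs
scale_lr = False                                       # as passed by the provided scripts
lr = accumulate_grad_batches * n_gpus * batch_size * base_lr if scale_lr else base_lr
print(lr)  # 1e-06 here; it would be 1.6e-05 with --scale_lr True
```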
138 |
139 | ---
140 |
141 | ## Inference
142 |
143 | Change "-t" to identify training or testing phase.
144 | (Note that multi-gpu testing is supported.)
145 |
146 | For example, using 4 gpus for testing
147 | ```bash
148 |
149 | python main.py --base configs/ldm/coco_sg2im_ldm_Layout2I_vqgan_f8.yaml \
150 | -t False --gpus 0,1,2,3 -log_dir ./exp/ldm/Layout2I \
151 | -n coco_sg2im_ldm_Layout2I_vqgan_f8 --scale_lr False -tb True
152 | ```
153 |
154 | ## Acknowledgement
155 | We build the LDM_layout codebase heavily on the codebases of [Latent Diffusion Model (LDM)](https://github.com/CompVis/latent-diffusion) and [VQGAN](https://github.com/CompVis/taming-transformers). We sincerely thank the authors for open-sourcing their work!
156 |
157 | ## Citation
158 | If you find this code useful for your research, please consider citing:
159 | ```bibtex
160 | @misc{rombach2021highresolution,
161 | title={High-Resolution Image Synthesis with Latent Diffusion Models},
162 | author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
163 | year={2021},
164 | eprint={2112.10752},
165 | archivePrefix={arXiv},
166 | primaryClass={cs.CV}
167 | }
168 |
169 | @misc{https://doi.org/10.48550/arxiv.2204.11824,
170 | doi = {10.48550/ARXIV.2204.11824},
171 | url = {https://arxiv.org/abs/2204.11824},
172 | author = {Blattmann, Andreas and Rombach, Robin and Oktay, Kaan and Ommer, Björn},
173 | keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},
174 | title = {Retrieval-Augmented Diffusion Models},
175 | publisher = {arXiv},
176 | year = {2022},
177 | copyright = {arXiv.org perpetual, non-exclusive license}
178 | }
179 |
180 | ```
181 |
--------------------------------------------------------------------------------
/configs/ldm/coco_sg2im_ldm_Layout2I_vqgan_f8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.e-6 # set to target_lr by starting main.py with '--scale_lr False'
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | first_stage_key: image
6 | cond_stage_key: objects_bbox
7 | linear_start: 0.0015
8 | linear_end: 0.0205
9 | num_timesteps_cond: 1
10 | log_every_t: 20
11 | timesteps: 1000
12 | loss_type: l1
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: true
16 | conditioning_key: crossattn
17 | scale_by_std: True
18 | monitor: 'val/loss_simple_ema'
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 32
23 | in_channels: 4
24 | out_channels: 4
25 | model_channels: 192
26 | attention_resolutions:
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 4
35 | num_head_channels: 32
36 | use_spatial_transformer: true
37 | transformer_depth: 2
38 | context_dim: 512
39 |
40 | first_stage_config:
41 | target: taming.models.vqgan.VQModelInterface
42 | params:
43 | ckpt_path: exp/vqgan/vq-f8/model.ckpt # https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/
44 | embed_dim: 4
45 | n_embed: 16384
46 | ddconfig:
47 | double_z: False
48 | z_channels: 4
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult:
54 | - 1
55 | - 2
56 | - 2
57 | - 4
58 | num_res_blocks: 2
59 | attn_resolutions:
60 | - 32
61 | dropout: 0.0
62 | vitconfig:
63 | embed_size: 256
64 | lossconfig:
65 | target: taming.modules.losses.DummyLoss
66 |
67 | cond_stage_config:
68 | target: ldm.modules.encoders.modules.BERTEmbedder
69 | params:
70 | n_embed: 512
71 | n_layer: 16
72 | vocab_size: 16384
73 | max_seq_len: 92
74 | use_tokenizer: False
75 |
76 | plot_sample: False
77 | plot_inpaint: False
78 | plot_denoise_rows: False
79 | plot_progressive_rows: False
80 | plot_diffusion_rows: False
81 | plot_quantize_denoised: True
82 |
83 | data:
84 | target: main.DataModuleFromConfig
85 | params:
86 | batch_size: 4
87 | train:
88 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
89 | params:
90 | data_path: ../datasets/coco/2017 # substitute with path to full dataset
91 | img_id_file: ../datasets/coco/2017/annotations/deprecated-challenge2017/train-ids.txt
92 | split: train
93 | keys: [image, objects_bbox, file_name, annotations]
94 | no_tokens: 1024
95 | target_image_size: 256
96 | min_object_area: 0.02
97 | min_objects_per_image: 3
98 | max_objects_per_image: 8
99 | crop_method: center
100 | random_flip: True
101 | use_group_parameter: true
102 | encode_crop: true
103 | validation:
104 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
105 | params:
106 | data_path: ../datasets/coco/2017 # substitute with path to full dataset
107 | img_id_file: ../datasets/coco/2017/annotations/deprecated-challenge2017/val-ids.txt
108 | split: validation
109 | keys: [image, objects_bbox, file_name, annotations]
110 | no_tokens: 1024
111 | target_image_size: 256
112 | min_object_area: 0.02
113 | min_objects_per_image: 3
114 | max_objects_per_image: 8
115 | crop_method: center
116 | random_flip: false
117 | use_group_parameter: true
118 | encode_crop: true
119 | test:
120 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
121 | params:
122 | data_path: ../datasets/coco/2017 # substitute with path to full dataset
123 | img_id_file: ../datasets/coco/2017/annotations/deprecated-challenge2017/val-ids.txt
124 | split: validation
125 | keys: [image, objects_bbox, file_name, annotations]
126 | no_tokens: 1024
127 | target_image_size: 256
128 | min_object_area: 0.02
129 | min_objects_per_image: 3
130 | max_objects_per_image: 8
131 | crop_method: center
132 | random_flip: false
133 | use_group_parameter: true
134 | encode_crop: true
135 |
136 | lightning:
137 | callbacks:
138 | image_logger:
139 | target: main.ImageLogger
140 | params:
141 | batch_frequency: 1000
142 | max_images: 99
143 | increase_log_steps: False
144 |
145 | trainer:
146 | benchmark: True
147 | max_epochs: 300
148 |
149 |
--------------------------------------------------------------------------------
/configs/ldm/coco_stuff_ldm_T2I_vqgan_f8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.e-6 # set to target_lr by starting main.py with '--scale_lr False'
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | first_stage_key: image
6 | cond_stage_key: caption
7 | linear_start: 0.0015
8 | linear_end: 0.0155
9 | num_timesteps_cond: 1
10 | log_every_t: 100
11 | timesteps: 1000
12 | loss_type: l1
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: true
16 | conditioning_key: crossattn
17 | scale_by_std: True
18 | monitor: 'val/loss_simple_ema'
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 32
23 | in_channels: 4
24 | out_channels: 4
25 | model_channels: 192
26 | attention_resolutions:
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 5
36 | num_head_channels: 32
37 | use_spatial_transformer: true
38 | transformer_depth: 1
39 | context_dim: 640
40 |
41 | first_stage_config:
42 | target: taming.models.vqgan.VQModelInterface
43 | params:
44 | ckpt_path: exp/vqgan/vq-f8/model.ckpt
45 | embed_dim: 4
46 | n_embed: 16384
47 | ddconfig:
48 | double_z: False
49 | z_channels: 4
50 | resolution: 256
51 | in_channels: 3
52 | out_ch: 3
53 | ch: 128
54 | ch_mult:
55 | - 1
56 | - 2
57 | - 2
58 | - 4
59 | num_res_blocks: 2
60 | attn_resolutions:
61 | - 32
62 | dropout: 0.0
63 | vitconfig:
64 | embed_size: 256
65 | lossconfig:
66 | target: taming.modules.losses.DummyLoss
67 |
68 | cond_stage_config:
69 | target: ldm.modules.encoders.modules.BERTEmbedder
70 | params:
71 | n_embed: 640
72 | n_layer: 32
73 |
74 | plot_sample: False
75 | plot_inpaint: False
76 | plot_denoise_rows: False
77 | plot_progressive_rows: False
78 | plot_diffusion_rows: False
79 | plot_quantize_denoised: True
80 |
81 | data:
82 | target: main.DataModuleFromConfig
83 | params:
84 | batch_size: 4
85 | train:
86 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
87 | params:
88 | data_path: ../datasets/coco/2014 # substitute with path to full dataset
89 | caption_ann_path: ../datasets/coco/2014/annotations/captions_train2014.json
90 | use_stuff: False
91 | split: train
92 | keys: [image, caption, file_name, annotations]
93 | no_tokens: 1024
94 | target_image_size: 256
95 | min_object_area: 0.00001
96 | min_objects_per_image: 2
97 | max_objects_per_image: 30
98 | crop_method: random-1d
99 | random_flip: true
100 | use_group_parameter: true
101 | encode_crop: False
102 | validation:
103 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
104 | params:
105 | data_path: ../datasets/coco/2014 # substitute with path to full dataset
106 | caption_ann_path: ../datasets/coco/2014/annotations/captions_val2014.json
107 | use_stuff: False
108 | split: validation
109 | keys: [image, caption, file_name, annotations]
110 | no_tokens: 1024
111 | target_image_size: 256
112 | min_object_area: 0.00001
113 | min_objects_per_image: 2
114 | max_objects_per_image: 30
115 | crop_method: center
116 | random_flip: false
117 | use_group_parameter: true
118 | encode_crop: False
119 | test:
120 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
121 | params:
122 | data_path: ../datasets/coco/2014 # substitute with path to full dataset
123 | caption_ann_path: ../datasets/coco/2014/annotations/captions_val2014.json
124 | use_stuff: False
125 | split: validation
126 | keys: [image, objects, caption, file_name, annotations]
127 | no_tokens: 1024
128 | target_image_size: 256
129 | min_object_area: 0.00001
130 | min_objects_per_image: 2
131 | max_objects_per_image: 30
132 | crop_method: center
133 | random_flip: false
134 | use_group_parameter: true
135 | encode_crop: false
136 |
137 | lightning:
138 | callbacks:
139 | image_logger:
140 | target: main.ImageLogger
141 | params:
142 | batch_frequency: 1000
143 | max_images: 99
144 | increase_log_steps: False
145 |
146 | trainer:
147 | benchmark: True
148 | max_epochs: 300
149 |
150 |
--------------------------------------------------------------------------------
/configs/vqgan/coco_vqgan_f8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 16384
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions: [32]
22 | dropout: 0.0
23 | lossconfig:
24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
25 | params:
26 | disc_conditional: false
27 | disc_in_channels: 3
28 | disc_num_layers: 2
29 | disc_start: 1
30 | disc_weight: 0.6
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 4
37 | num_workers: 24
38 | train:
39 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
40 | params:
41 | data_path: ../datasets/coco/2017 # substitute with path to full dataset
42 | split: train
43 | keys: [image, objects_bbox, file_name, annotations]
44 | no_tokens: 1024
45 | target_image_size: 256
46 | min_object_area: 0.00001
47 | min_objects_per_image: 2
48 | max_objects_per_image: 30
49 | crop_method: random-1d
50 | random_flip: true
51 | use_group_parameter: true
52 | encode_crop: true
53 | validation:
54 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
55 | params:
56 | data_path: ../datasets/coco/2017 # substitute with path to full dataset
57 | split: validation
58 | keys: [image, objects_bbox, file_name, annotations]
59 | no_tokens: 1024
60 | target_image_size: 256
61 | min_object_area: 0.00001
62 | min_objects_per_image: 2
63 | max_objects_per_image: 30
64 | crop_method: center
65 | random_flip: false
66 | use_group_parameter: true
67 | encode_crop: true
68 | test:
69 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
70 | params:
71 | data_path: ../datasets/coco/2017 # substitute with path to full dataset
72 | split: validation
73 | keys: [image, objects_bbox, file_name, annotations]
74 | no_tokens: 1024
75 | target_image_size: 256
76 | min_object_area: 0.0000001
77 | min_objects_per_image: 2
78 | max_objects_per_image: 30
79 | crop_method: center
80 | random_flip: false
81 | use_group_parameter: true
82 | encode_crop: true
83 |
84 | lightning:
85 | callbacks:
86 | image_logger:
87 | target: main.ImageLogger
88 | params:
89 | batch_frequency: 1000
90 | max_images: 99
91 | increase_log_steps: False
92 |
93 |   trainer:
94 |     max_epochs: 50
96 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: ldm_layout
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=10.2
9 | - pytorch=1.7.0
10 | - torchvision=0.8.1
11 | - numpy=1.19.2
12 | - pip:
13 | - opencv-python==4.1.2.30
14 | - albumentations==0.4.3
15 | - einops==0.3.0
16 | - imageio==2.9.0
17 | - imageio-ffmpeg==0.4.2
18 | - matplotlib==3.5.1
19 | - matplotlib-inline==0.1.3
20 | - more-itertools==8.12.0
21 | - numpy==1.19.2
22 | - omegaconf==2.0.0
23 | - opencv-python-headless==4.1.2.30
24 | - pandas==1.4.1
25 | - Pillow==9.0.1
26 | - pudb==2019.2
27 | - Pygments==2.11.2
28 | - Pympler==1.0.1
29 | - python-dateutil==2.8.2
30 | - torch-fidelity==0.3.0
31 | - pytorch-lightning==1.0.8
32 | - PyYAML==6.0
33 | - requests==2.27.1
34 | - requests-oauthlib==1.3.1
35 | - scikit-image==0.19.2
36 | - scipy==1.8.0
37 | - setuptools==58.0.4
38 | - six==1.16.0
39 | - streamlit==1.7.0
40 | - terminado==0.13.3
41 | - test-tube==0.7.5
42 | - timm==0.4.5
43 | - tokenizers==0.10.3
44 | - tornado==6.1
45 | - tqdm==4.63.0
46 | - transformers==4.3.1
47 | - typing-extensions==3.10.0.2
48 | - urllib3==1.26.9
49 | - -e .
--------------------------------------------------------------------------------
/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidhalladay/ldm_layout/49d664158db0c9d51aa057494dd72b8669fe586c/ldm/data/__init__.py
--------------------------------------------------------------------------------
/ldm/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3 |
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | '''
7 | Define an interface to make the IterableDatasets for text2img data chainable
8 | '''
9 | def __init__(self, num_records=0, valid_ids=None, size=256):
10 | super().__init__()
11 | self.num_records = num_records
12 | self.valid_ids = valid_ids
13 | self.sample_ids = valid_ids
14 | self.size = size
15 |
16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17 |
18 | def __len__(self):
19 | return self.num_records
20 |
21 | @abstractmethod
22 | def __iter__(self):
23 | pass
--------------------------------------------------------------------------------
/ldm/data/lsun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | class LSUNBase(Dataset):
10 | def __init__(self,
11 | txt_file,
12 | data_root,
13 | size=None,
14 | interpolation="bicubic",
15 | flip_p=0.5
16 | ):
17 | self.data_paths = txt_file
18 | self.data_root = data_root
19 | with open(self.data_paths, "r") as f:
20 | self.image_paths = f.read().splitlines()
21 | self._length = len(self.image_paths)
22 | self.labels = {
23 | "relative_file_path_": [l for l in self.image_paths],
24 | "file_path_": [os.path.join(self.data_root, l)
25 | for l in self.image_paths],
26 | }
27 |
28 | self.size = size
29 | self.interpolation = {"linear": PIL.Image.LINEAR,
30 | "bilinear": PIL.Image.BILINEAR,
31 | "bicubic": PIL.Image.BICUBIC,
32 | "lanczos": PIL.Image.LANCZOS,
33 | }[interpolation]
34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35 |
36 | def __len__(self):
37 | return self._length
38 |
39 | def __getitem__(self, i):
40 | example = dict((k, self.labels[k][i]) for k in self.labels)
41 | image = Image.open(example["file_path_"])
42 | if not image.mode == "RGB":
43 | image = image.convert("RGB")
44 |
45 | # default to score-sde preprocessing
46 | img = np.array(image).astype(np.uint8)
47 | crop = min(img.shape[0], img.shape[1])
48 | h, w, = img.shape[0], img.shape[1]
49 | img = img[(h - crop) // 2:(h + crop) // 2,
50 | (w - crop) // 2:(w + crop) // 2]
51 |
52 | image = Image.fromarray(img)
53 | if self.size is not None:
54 | image = image.resize((self.size, self.size), resample=self.interpolation)
55 |
56 | image = self.flip(image)
57 | image = np.array(image).astype(np.uint8)
58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59 | return example
60 |
61 |
62 | class LSUNChurchesTrain(LSUNBase):
63 | def __init__(self, **kwargs):
64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65 |
66 |
67 | class LSUNChurchesValidation(LSUNBase):
68 | def __init__(self, flip_p=0., **kwargs):
69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70 | flip_p=flip_p, **kwargs)
71 |
72 |
73 | class LSUNBedroomsTrain(LSUNBase):
74 | def __init__(self, **kwargs):
75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76 |
77 |
78 | class LSUNBedroomsValidation(LSUNBase):
79 | def __init__(self, flip_p=0.0, **kwargs):
80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81 | flip_p=flip_p, **kwargs)
82 |
83 |
84 | class LSUNCatsTrain(LSUNBase):
85 | def __init__(self, **kwargs):
86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87 |
88 |
89 | class LSUNCatsValidation(LSUNBase):
90 | def __init__(self, flip_p=0., **kwargs):
91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92 | flip_p=flip_p, **kwargs)
93 |
--------------------------------------------------------------------------------
/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
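# Example usage (illustrative sketch; `params` is a placeholder, not a name defined in this repo).
# LambdaLR multiplies the optimizer's lr by the value returned from the scheduler, so per the
# "use with a base_lr of 1.0" note above, the optimizer lr is set to 1.0 and lr_max carries the
# target rate:
#
#   import torch
#   sched_fn = LambdaWarmUpCosineScheduler(warm_up_steps=1000, lr_min=1e-7, lr_max=1e-4,
#                                          lr_start=1e-7, max_decay_steps=10_000)
#   opt = torch.optim.AdamW(params, lr=1.0)
#   lr_sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=sched_fn)
#   lr_sched.step()  # call once per optimizer step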
--------------------------------------------------------------------------------
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidhalladay/ldm_layout/49d664158db0c9d51aa057494dd72b8669fe586c/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/ldm/models/diffusion/classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pytorch_lightning as pl
4 | from omegaconf import OmegaConf
5 | from torch.nn import functional as F
6 | from torch.optim import AdamW
7 | from torch.optim.lr_scheduler import LambdaLR
8 | from copy import deepcopy
9 | from einops import rearrange
10 | from glob import glob
11 | from natsort import natsorted
12 |
13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15 |
16 | __models__ = {
17 | 'class_label': EncoderUNetModel,
18 | 'segmentation': UNetModel
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | class NoisyLatentImageClassifier(pl.LightningModule):
29 |
30 | def __init__(self,
31 | diffusion_path,
32 | num_classes,
33 | ckpt_path=None,
34 | pool='attention',
35 | label_key=None,
36 | diffusion_ckpt_path=None,
37 | scheduler_config=None,
38 | weight_decay=1.e-2,
39 | log_steps=10,
40 | monitor='val/loss',
41 | *args,
42 | **kwargs):
43 | super().__init__(*args, **kwargs)
44 | self.num_classes = num_classes
45 | # get latest config of diffusion model
46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47 | self.diffusion_config = OmegaConf.load(diffusion_config).model
48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49 | self.load_diffusion()
50 |
51 | self.monitor = monitor
52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54 | self.log_steps = log_steps
55 |
56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57 | else self.diffusion_model.cond_stage_key
58 |
59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60 |
61 | if self.label_key not in __models__:
62 | raise NotImplementedError()
63 |
64 | self.load_classifier(ckpt_path, pool)
65 |
66 | self.scheduler_config = scheduler_config
67 | self.use_scheduler = self.scheduler_config is not None
68 | self.weight_decay = weight_decay
69 |
70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71 | sd = torch.load(path, map_location="cpu")
72 | if "state_dict" in list(sd.keys()):
73 | sd = sd["state_dict"]
74 | keys = list(sd.keys())
75 | for k in keys:
76 | for ik in ignore_keys:
77 | if k.startswith(ik):
78 | print("Deleting key {} from state_dict.".format(k))
79 | del sd[k]
80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81 | sd, strict=False)
82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83 | if len(missing) > 0:
84 | print(f"Missing Keys: {missing}")
85 | if len(unexpected) > 0:
86 | print(f"Unexpected Keys: {unexpected}")
87 |
88 | def load_diffusion(self):
89 | model = instantiate_from_config(self.diffusion_config)
90 | self.diffusion_model = model.eval()
91 | self.diffusion_model.train = disabled_train
92 | for param in self.diffusion_model.parameters():
93 | param.requires_grad = False
94 |
95 | def load_classifier(self, ckpt_path, pool):
96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98 | model_config.out_channels = self.num_classes
99 | if self.label_key == 'class_label':
100 | model_config.pool = pool
101 |
102 | self.model = __models__[self.label_key](**model_config)
103 | if ckpt_path is not None:
104 | print('#####################################################################')
105 | print(f'load from ckpt "{ckpt_path}"')
106 | print('#####################################################################')
107 | self.init_from_ckpt(ckpt_path)
108 |
109 | @torch.no_grad()
110 | def get_x_noisy(self, x, t, noise=None):
111 | noise = default(noise, lambda: torch.randn_like(x))
112 | continuous_sqrt_alpha_cumprod = None
113 | if self.diffusion_model.use_continuous_noise:
114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115 | # todo: make sure t+1 is correct here
116 |
117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119 |
120 | def forward(self, x_noisy, t, *args, **kwargs):
121 | return self.model(x_noisy, t)
122 |
123 | @torch.no_grad()
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = rearrange(x, 'b h w c -> b c h w')
129 | x = x.to(memory_format=torch.contiguous_format).float()
130 | return x
131 |
132 | @torch.no_grad()
133 | def get_conditioning(self, batch, k=None):
134 | if k is None:
135 | k = self.label_key
136 | assert k is not None, 'Needs to provide label key'
137 |
138 | targets = batch[k].to(self.device)
139 |
140 | if self.label_key == 'segmentation':
141 | targets = rearrange(targets, 'b h w c -> b c h w')
142 | for down in range(self.numd):
143 | h, w = targets.shape[-2:]
144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145 |
146 | # targets = rearrange(targets,'b c h w -> b h w c')
147 |
148 | return targets
149 |
150 | def compute_top_k(self, logits, labels, k, reduction="mean"):
151 | _, top_ks = torch.topk(logits, k, dim=1)
152 | if reduction == "mean":
153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154 | elif reduction == "none":
155 | return (top_ks == labels[:, None]).float().sum(dim=-1)
156 |
157 | def on_train_epoch_start(self):
158 | # save some memory
159 | self.diffusion_model.model.to('cpu')
160 |
161 | @torch.no_grad()
162 | def write_logs(self, loss, logits, targets):
163 | log_prefix = 'train' if self.training else 'val'
164 | log = {}
165 | log[f"{log_prefix}/loss"] = loss.mean()
166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167 | logits, targets, k=1, reduction="mean"
168 | )
169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170 | logits, targets, k=5, reduction="mean"
171 | )
172 |
173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176 | lr = self.optimizers().param_groups[0]['lr']
177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178 |
179 | def shared_step(self, batch, t=None):
180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181 | targets = self.get_conditioning(batch)
182 | if targets.dim() == 4:
183 | targets = targets.argmax(dim=1)
184 | if t is None:
185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186 | else:
187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188 | x_noisy = self.get_x_noisy(x, t)
189 | logits = self(x_noisy, t)
190 |
191 | loss = F.cross_entropy(logits, targets, reduction='none')
192 |
193 | self.write_logs(loss.detach(), logits.detach(), targets.detach())
194 |
195 | loss = loss.mean()
196 | return loss, logits, x_noisy, targets
197 |
198 | def training_step(self, batch, batch_idx):
199 | loss, *_ = self.shared_step(batch)
200 | return loss
201 |
202 | def reset_noise_accs(self):
203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205 |
206 | def on_validation_start(self):
207 | self.reset_noise_accs()
208 |
209 | @torch.no_grad()
210 | def validation_step(self, batch, batch_idx):
211 | loss, *_ = self.shared_step(batch)
212 |
213 | for t in self.noisy_acc:
214 | _, logits, _, targets = self.shared_step(batch, t)
215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217 |
218 | return loss
219 |
220 | def configure_optimizers(self):
221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222 |
223 | if self.use_scheduler:
224 | scheduler = instantiate_from_config(self.scheduler_config)
225 |
226 | print("Setting up LambdaLR scheduler...")
227 | scheduler = [
228 | {
229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230 | 'interval': 'step',
231 | 'frequency': 1
232 | }]
233 | return [optimizer], scheduler
234 |
235 | return optimizer
236 |
237 | @torch.no_grad()
238 | def log_images(self, batch, N=8, *args, **kwargs):
239 | log = dict()
240 | x = self.get_input(batch, self.diffusion_model.first_stage_key)
241 | log['inputs'] = x
242 |
243 | y = self.get_conditioning(batch)
244 |
245 | if self.label_key == 'class_label':
246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247 | log['labels'] = y
248 |
249 | if ismap(y):
250 | log['labels'] = self.diffusion_model.to_rgb(y)
251 |
252 | for step in range(self.log_steps):
253 | current_time = step * self.log_time_interval
254 |
255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256 |
257 | log[f'inputs@t{current_time}'] = x_noisy
258 |
259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260 | pred = rearrange(pred, 'b h w c -> b c h w')
261 |
262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263 |
264 | for key in log:
265 | log[key] = log[key][:N]
266 |
267 | return log
268 |
--------------------------------------------------------------------------------
/ldm/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | from ldm.modules.diffusionmodules.util import checkpoint
9 |
10 | def exists(val):
11 | return val is not None
12 |
13 |
14 | def uniq(arr):
15 | return{el: True for el in arr}.keys()
16 |
17 |
18 | def default(val, d):
19 | if exists(val):
20 | return val
21 | return d() if isfunction(d) else d
22 |
23 |
24 | def max_neg_value(t):
25 | return -torch.finfo(t.dtype).max
26 |
27 |
28 | def init_(tensor):
29 | dim = tensor.shape[-1]
30 | std = 1 / math.sqrt(dim)
31 | tensor.uniform_(-std, std)
32 | return tensor
33 |
34 |
35 | # feedforward
36 | class GEGLU(nn.Module):
37 | def __init__(self, dim_in, dim_out):
38 | super().__init__()
39 | self.proj = nn.Linear(dim_in, dim_out * 2)
40 |
41 | def forward(self, x):
42 | x, gate = self.proj(x).chunk(2, dim=-1)
43 | return x * F.gelu(gate)
44 |
45 |
46 | class FeedForward(nn.Module):
47 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
48 | super().__init__()
49 | inner_dim = int(dim * mult)
50 | dim_out = default(dim_out, dim)
51 | project_in = nn.Sequential(
52 | nn.Linear(dim, inner_dim),
53 | nn.GELU()
54 | ) if not glu else GEGLU(dim, inner_dim)
55 |
56 | self.net = nn.Sequential(
57 | project_in,
58 | nn.Dropout(dropout),
59 | nn.Linear(inner_dim, dim_out)
60 | )
61 |
62 | def forward(self, x):
63 | return self.net(x)
64 |
65 |
66 | def zero_module(module):
67 | """
68 | Zero out the parameters of a module and return it.
69 | """
70 | for p in module.parameters():
71 | p.detach().zero_()
72 | return module
73 |
74 |
75 | def Normalize(in_channels):
76 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
77 |
78 |
79 | class LinearAttention(nn.Module):
80 | def __init__(self, dim, heads=4, dim_head=32):
81 | super().__init__()
82 | self.heads = heads
83 | hidden_dim = dim_head * heads
84 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
85 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
86 |
87 | def forward(self, x):
88 | b, c, h, w = x.shape
89 | qkv = self.to_qkv(x)
90 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
91 | k = k.softmax(dim=-1)
92 | context = torch.einsum('bhdn,bhen->bhde', k, v)
93 | out = torch.einsum('bhde,bhdn->bhen', context, q)
94 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
95 | return self.to_out(out)
96 |
97 |
98 | class SpatialSelfAttention(nn.Module):
99 | def __init__(self, in_channels):
100 | super().__init__()
101 | self.in_channels = in_channels
102 |
103 | self.norm = Normalize(in_channels)
104 | self.q = torch.nn.Conv2d(in_channels,
105 | in_channels,
106 | kernel_size=1,
107 | stride=1,
108 | padding=0)
109 | self.k = torch.nn.Conv2d(in_channels,
110 | in_channels,
111 | kernel_size=1,
112 | stride=1,
113 | padding=0)
114 | self.v = torch.nn.Conv2d(in_channels,
115 | in_channels,
116 | kernel_size=1,
117 | stride=1,
118 | padding=0)
119 | self.proj_out = torch.nn.Conv2d(in_channels,
120 | in_channels,
121 | kernel_size=1,
122 | stride=1,
123 | padding=0)
124 |
125 | def forward(self, x):
126 | h_ = x
127 | h_ = self.norm(h_)
128 | q = self.q(h_)
129 | k = self.k(h_)
130 | v = self.v(h_)
131 |
132 | # compute attention
133 | b,c,h,w = q.shape
134 | q = rearrange(q, 'b c h w -> b (h w) c')
135 | k = rearrange(k, 'b c h w -> b c (h w)')
136 | w_ = torch.einsum('bij,bjk->bik', q, k)
137 |
138 | w_ = w_ * (int(c)**(-0.5))
139 | w_ = torch.nn.functional.softmax(w_, dim=2)
140 |
141 | # attend to values
142 | v = rearrange(v, 'b c h w -> b c (h w)')
143 | w_ = rearrange(w_, 'b i j -> b j i')
144 | h_ = torch.einsum('bij,bjk->bik', v, w_)
145 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
146 | h_ = self.proj_out(h_)
147 |
148 | return x+h_
149 |
150 |
151 | class CrossAttention(nn.Module):
152 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
153 | super().__init__()
154 | inner_dim = dim_head * heads
155 | context_dim = default(context_dim, query_dim)
156 |
157 | self.scale = dim_head ** -0.5
158 | self.heads = heads
159 |
160 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
161 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
162 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
163 |
164 | self.to_out = nn.Sequential(
165 | nn.Linear(inner_dim, query_dim),
166 | nn.Dropout(dropout)
167 | )
168 |
169 | def forward(self, x, context=None, mask=None):
170 | h = self.heads
171 |
172 | q = self.to_q(x)
173 | context = default(context, x)
174 | k = self.to_k(context)
175 | v = self.to_v(context)
176 |
177 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
178 |
179 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
180 |
181 | if exists(mask):
182 | mask = rearrange(mask, 'b ... -> b (...)')
183 | max_neg_value = -torch.finfo(sim.dtype).max
184 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
185 | sim.masked_fill_(~mask, max_neg_value)
186 |
187 | # attention, what we cannot get enough of
188 | attn = sim.softmax(dim=-1)
189 |
190 | out = einsum('b i j, b j d -> b i d', attn, v)
191 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
192 | return self.to_out(out)
193 |
194 |
195 | class BasicTransformerBlock(nn.Module):
196 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, use_mscond=False):
197 | super().__init__()
198 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
199 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
200 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
201 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
202 | self.norm1 = nn.LayerNorm(dim)
203 | self.norm2 = nn.LayerNorm(dim)
204 | self.norm3 = nn.LayerNorm(dim)
205 | self.checkpoint = checkpoint
206 |
207 | if use_mscond:
208 | self.attn_prev = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=0.5)
209 | self.norm_prev = nn.LayerNorm(dim)
210 | self.attn_cross = CrossAttention(query_dim=dim, context_dim=dim,
211 | heads=n_heads, dim_head=d_head, dropout=0.2)
212 | self.norm_cross = nn.LayerNorm(dim)
213 |
214 | def forward(self, x, context=None, x_prev_stage=None):
215 |
216 | if x_prev_stage is None:
217 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
218 | else:
219 | return checkpoint(self._forward_w_prev, (x, context, x_prev_stage), self.parameters(), self.checkpoint)
220 |
221 | def _forward(self, x, context=None):
222 | x_length = x.shape[1]
223 | x_attn = self.attn1(self.norm1(x)) + x
224 | x_attn = self.attn2(self.norm2(x_attn), context=context) + x_attn
225 | x_attn = self.ff(self.norm3(x_attn)) + x_attn
226 | return x_attn
227 |
228 | def _forward_w_prev(self, x, context=None, x_prev_stage=None):
229 |
230 | x_length = x.shape[1]
231 | x_attn = self.attn1(self.norm1(x)) + x
232 |
233 | x_prev_stage = self.attn_prev(self.norm_prev(x_prev_stage)) + x_prev_stage
234 | x_attn = self.attn_cross(self.norm_cross(x_attn), context=x_prev_stage) + x_attn
235 |
236 | x_attn = self.attn2(self.norm2(x_attn), context=context) + x_attn
237 | x_attn = self.ff(self.norm3(x_attn)) + x_attn
238 |
239 | return x_attn
240 |
241 |
242 | class SpatialTransformer(nn.Module):
243 | """
244 | Transformer block for image-like data.
245 | First, project the input (aka embedding)
246 | and reshape to b, t, d.
247 | Then apply standard transformer action.
248 | Finally, reshape to image
249 | """
250 | def __init__(self, in_channels, channels_cond, n_heads, d_head,
251 | depth=1, dropout=0., context_dim=None, use_pos_embed=-1, use_mscond=False, mscond_dim=None):
252 | super().__init__()
253 | self.in_channels = in_channels
254 | self.use_pos_embed = use_pos_embed
255 | self.use_mscond = use_mscond
256 |
257 | inner_dim = n_heads * d_head
258 | self.norm = Normalize(in_channels)
259 |
260 | if use_pos_embed > 0:
261 | self.pos_embed = nn.Embedding(use_pos_embed, in_channels)
262 |
263 | self.proj_in = nn.Conv2d(in_channels,
264 | inner_dim,
265 | kernel_size=1,
266 | stride=1,
267 | padding=0)
268 |
269 | self.transformer_blocks = nn.ModuleList(
270 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, use_mscond=use_mscond)
271 | for d in range(depth)]
272 | )
273 |
274 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
275 | in_channels,
276 | kernel_size=1,
277 | stride=1,
278 | padding=0))
279 |
280 | if self.use_mscond:
281 | self.cond_proj_in = nn.Conv2d(mscond_dim,
282 | inner_dim,
283 | kernel_size=1,
284 | stride=1,
285 | padding=0)
286 |
287 | def forward(self, x, context=None, feat_cond=None):
288 | # note: if no context is given, cross-attention defaults to self-attention
289 | b, c, h, w = x.shape
290 | x_in = x
291 | x = self.norm(x)
292 |
293 | if feat_cond is not None and self.use_mscond:
294 | feat_cond = F.interpolate(feat_cond, size=x.size()[2:], mode='nearest')
295 | feat_cond = self.cond_proj_in(feat_cond)
296 | feat_cond = rearrange(feat_cond, 'b c h w -> b (h w) c')
297 |
298 | x = self.proj_in(x)
299 | x = rearrange(x, 'b c h w -> b (h w) c')
300 |
301 | if self.use_pos_embed > 0:
302 | pos_x = torch.arange(w)
303 | pos_y = torch.arange(h)
304 | grid_x, grid_y = torch.meshgrid(pos_x, pos_y)
305 | grid_x = grid_x.reshape(1, -1).repeat(b, 1).cuda()
306 | grid_y = grid_y.reshape(1, -1).repeat(b, 1).cuda()
307 | emb_pos_x = self.pos_embed(grid_x)
308 | emb_pos_y = self.pos_embed(grid_y)
309 | emb_pos = (emb_pos_x + emb_pos_y) / 2.
310 | x = x + emb_pos
311 |
312 | if feat_cond is not None and self.use_mscond:
313 | for block in self.transformer_blocks:
314 | x = block(x, context=context, x_prev_stage=feat_cond)
315 | else:
316 | for block in self.transformer_blocks:
317 | x = block(x, context=context)
318 |
319 |
320 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
321 | x = self.proj_out(x)
322 | return x + x_in
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidhalladay/ldm_layout/49d664158db0c9d51aa057494dd72b8669fe586c/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adopted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from ldm.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = np.clip(betas, a_min=0, a_max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 | # according the the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if flag:
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 |
126 | with torch.no_grad():
127 | output_tensors = ctx.run_function(*ctx.input_tensors)
128 | return output_tensors
129 |
130 | @staticmethod
131 | def backward(ctx, *output_grads):
132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 | with torch.enable_grad():
134 | # Fixes a bug where the first op in run_function modifies the
135 | # Tensor storage in place, which is not allowed for detach()'d
136 | # Tensors.
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138 | output_tensors = ctx.run_function(*shallow_copies)
139 | input_grads = torch.autograd.grad(
140 | output_tensors,
141 | ctx.input_tensors + ctx.input_params,
142 | output_grads,
143 | allow_unused=True,
144 | )
145 | del ctx.input_tensors
146 | del ctx.input_params
147 | del output_tensors
148 | return (None, None) + input_grads
149 |
150 |
151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152 | """
153 | Create sinusoidal timestep embeddings.
154 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
155 | These may be fractional.
156 | :param dim: the dimension of the output.
157 | :param max_period: controls the minimum frequency of the embeddings.
158 | :return: an [N x dim] Tensor of positional embeddings.
159 | """
160 | if not repeat_only:
161 | half = dim // 2
162 | freqs = torch.exp(
163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164 | ).to(device=timesteps.device)
165 | args = timesteps[:, None].float() * freqs[None]
166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167 | if dim % 2:
168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169 | else:
170 | embedding = repeat(timesteps, 'b -> b d', d=dim)
171 | return embedding
172 |
173 |
174 | def zero_module(module):
175 | """
176 | Zero out the parameters of a module and return it.
177 | """
178 | for p in module.parameters():
179 | p.detach().zero_()
180 | return module
181 |
182 |
183 | def scale_module(module, scale):
184 | """
185 | Scale the parameters of a module and return it.
186 | """
187 | for p in module.parameters():
188 | p.detach().mul_(scale)
189 | return module
190 |
191 |
192 | def mean_flat(tensor):
193 | """
194 | Take the mean over all non-batch dimensions.
195 | """
196 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
197 |
198 |
199 | def normalization(channels):
200 | """
201 | Make a standard normalization layer.
202 | :param channels: number of input channels.
203 | :return: an nn.Module for normalization.
204 | """
205 | return GroupNorm32(32, channels)
206 |
207 |
208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209 | class SiLU(nn.Module):
210 | def forward(self, x):
211 | return x * torch.sigmoid(x)
212 |
213 |
214 | class GroupNorm32(nn.GroupNorm):
215 | def forward(self, x):
216 | return super().forward(x.float()).type(x.dtype)
217 |
218 | def conv_nd(dims, *args, **kwargs):
219 | """
220 | Create a 1D, 2D, or 3D convolution module.
221 | """
222 | if dims == 1:
223 | return nn.Conv1d(*args, **kwargs)
224 | elif dims == 2:
225 | return nn.Conv2d(*args, **kwargs)
226 | elif dims == 3:
227 | return nn.Conv3d(*args, **kwargs)
228 | raise ValueError(f"unsupported dimensions: {dims}")
229 |
230 |
231 | def linear(*args, **kwargs):
232 | """
233 | Create a linear module.
234 | """
235 | return nn.Linear(*args, **kwargs)
236 |
237 |
238 | def avg_pool_nd(dims, *args, **kwargs):
239 | """
240 | Create a 1D, 2D, or 3D average pooling module.
241 | """
242 | if dims == 1:
243 | return nn.AvgPool1d(*args, **kwargs)
244 | elif dims == 2:
245 | return nn.AvgPool2d(*args, **kwargs)
246 | elif dims == 3:
247 | return nn.AvgPool3d(*args, **kwargs)
248 | raise ValueError(f"unsupported dimensions: {dims}")
249 |
250 |
251 | class HybridConditioner(nn.Module):
252 |
253 | def __init__(self, c_concat_config, c_crossattn_config):
254 | super().__init__()
255 | self.concat_conditioner = instantiate_from_config(c_concat_config)
256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257 |
258 | def forward(self, c_concat, c_crossattn):
259 | c_concat = self.concat_conditioner(c_concat)
260 | c_crossattn = self.crossattn_conditioner(c_crossattn)
261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262 |
263 |
264 | def noise_like(shape, device, repeat=False):
265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266 | noise = lambda: torch.randn(shape, device=device)
267 | return repeat_noise() if repeat else noise()
--------------------------------------------------------------------------------
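A minimal sketch of how the DDIM helpers above fit together, assuming they are importable from ldm.modules.diffusionmodules.util (the usual home of this file in the ldm package) and using a stand-in linear beta schedule in place of one taken from a trained DDPM:

import numpy as np
import torch

from ldm.modules.diffusionmodules.util import make_ddim_timesteps, make_ddim_sampling_parameters

# stand-in linear schedule over 1000 DDPM steps; in practice this comes from the trained model
betas = torch.linspace(1e-4, 2e-2, 1000, dtype=torch.float64).numpy()
alphas_cumprod = np.cumprod(1.0 - betas)

ddim_steps = make_ddim_timesteps("uniform", num_ddim_timesteps=50,
                                 num_ddpm_timesteps=1000, verbose=False)
sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
    alphacums=alphas_cumprod, ddim_timesteps=ddim_steps, eta=0.0, verbose=False)

# eta = 0 gives the deterministic DDIM sampler: every sigma_t is zero
assert np.allclose(sigmas, 0.0)
print(ddim_steps[:5], sigmas[:5])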
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidhalladay/ldm_layout/49d664158db0c9d51aa057494dd72b8669fe586c/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
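A minimal usage sketch of DiagonalGaussianDistribution; the shapes are illustrative (in the full pipeline the 2*C-channel parameter tensor is produced by the autoencoder's encoder):

import torch

from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

params = torch.randn(4, 8, 16, 16)      # batch of 4, 2*C = 8 channels -> C = 4 latent channels
posterior = DiagonalGaussianDistribution(params)

z = posterior.sample()                  # stochastic latent, shape (4, 4, 16, 16)
z_det = posterior.mode()                # deterministic latent (the mean)
kl = posterior.kl()                     # KL to a standard normal, one value per sample
print(z.shape, z_det.shape, kl.shape)   # torch.Size([4, 4, 16, 16]) ... torch.Size([4])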
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_updates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_updates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | # remove '.' since it is not allowed in buffer names
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay, torch.true_divide((1 + self.num_updates), (10 + self.num_updates)))
31 | # decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
32 |
33 | one_minus_decay = 1.0 - decay
34 |
35 | with torch.no_grad():
36 | m_param = dict(model.named_parameters())
37 | shadow_params = dict(self.named_buffers())
38 |
39 | for key in m_param:
40 | if m_param[key].requires_grad:
41 | sname = self.m_name2s_name[key]
42 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
43 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
44 | else:
45 | assert key not in self.m_name2s_name
46 |
47 | def copy_to(self, model):
48 | m_param = dict(model.named_parameters())
49 | shadow_params = dict(self.named_buffers())
50 | for key in m_param:
51 | if m_param[key].requires_grad:
52 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
53 | else:
54 | assert key not in self.m_name2s_name
55 |
56 | def store(self, parameters):
57 | """
58 | Save the current parameters for restoring later.
59 | Args:
60 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
61 | temporarily stored.
62 | """
63 | self.collected_params = [param.clone() for param in parameters]
64 |
65 | def restore(self, parameters):
66 | """
67 | Restore the parameters stored with the `store` method.
68 | Useful to validate the model with EMA parameters without affecting the
69 | original optimization process. Store the parameters before the
70 | `copy_to` method. After validation (or model saving), use this to
71 | restore the former parameters.
72 | Args:
73 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
74 | updated with the stored parameters.
75 | """
76 | for c_param, param in zip(self.collected_params, parameters):
77 | param.data.copy_(c_param.data)
78 |
--------------------------------------------------------------------------------
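A minimal sketch of the LitEma update/validation cycle on a toy module; the in-place parameter perturbation below merely stands in for an optimizer step:

import torch
from torch import nn

from ldm.modules.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)

for _ in range(3):
    for p in model.parameters():
        p.data.add_(0.01 * torch.randn_like(p))  # stand-in for an optimizer step
    ema(model)                                    # EMA update of the shadow weights

# evaluate with the EMA weights, then restore the raw training weights
ema.store(model.parameters())
ema.copy_to(model)
# ... run validation here ...
ema.restore(model.parameters())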
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidhalladay/ldm_layout/49d664158db0c9d51aa057494dd72b8669fe586c/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | import clip
5 | from einops import rearrange, repeat
6 | from transformers import CLIPTokenizer, CLIPTextModel
7 | import kornia
8 |
9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
10 |
11 |
12 | class AbstractEncoder(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def encode(self, *args, **kwargs):
17 | raise NotImplementedError
18 |
19 |
20 |
21 | class ClassEmbedder(nn.Module):
22 | def __init__(self, embed_dim, multilabel=False, padding_idx=1023, n_classes=1000, key='class'):
23 | super().__init__()
24 | self.key = key
25 | self.multilabel = multilabel
26 | self.embedding = nn.Embedding(n_classes, embed_dim)
27 |
28 | def forward(self, batch, key=None):
29 | if key is None:
30 | key = self.key
31 | # this is for use in crossattn
32 | if self.multilabel:
33 | c = batch[key].cuda()
34 | c = self.embedding(c)
35 | c = c.max(-2)[0]
36 | else:
37 | c = batch[key][:, None].cuda()
38 | c = self.embedding(c)
39 | return c
40 |
41 |
42 | class TransformerEmbedder(AbstractEncoder):
43 | """Some transformer encoder layers"""
44 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
45 | super().__init__()
46 | self.device = device
47 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
48 | attn_layers=Encoder(dim=n_embed, depth=n_layer))
49 |
50 | def forward(self, tokens):
51 | tokens = tokens.to(self.device) # meh
52 | z = self.transformer(tokens, return_embeddings=True)
53 | return z
54 |
55 | def encode(self, x):
56 | return self(x)
57 |
58 |
59 | class BERTTokenizer(AbstractEncoder):
60 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
61 | def __init__(self, device="cuda", vq_interface=True, max_length=77):
62 | super().__init__()
63 | from transformers import BertTokenizerFast # TODO: add to requirements
64 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
65 | self.device = device
66 | self.vq_interface = vq_interface
67 | self.max_length = max_length
68 |
69 | def forward(self, text):
70 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
71 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
72 | tokens = batch_encoding["input_ids"].to(self.device)
73 | return tokens
74 |
75 | @torch.no_grad()
76 | def encode(self, text):
77 | tokens = self(text)
78 | if not self.vq_interface:
79 | return tokens
80 | return None, None, [None, None, tokens]
81 |
82 | def decode(self, text):
83 | return text
84 |
85 | class BERTEmbedder(AbstractEncoder):
86 | """Uses the BERT tokenizer and adds some transformer encoder layers"""
87 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
88 | device="cuda",use_tokenizer=True, embedding_dropout=0.0, cond_key=''):
89 | super().__init__()
90 | self.use_tknz_fn = use_tokenizer
91 | self.cond_key = cond_key
92 | if self.use_tknz_fn:
93 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
94 | self.device = device
95 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
96 | attn_layers=Encoder(dim=n_embed, depth=n_layer),
97 | emb_dropout=embedding_dropout)
98 |
99 | def forward(self, text, return_token=False):
100 | if self.use_tknz_fn:
101 | tokens = self.tknz_fn(text).to(self.device)
102 | else:
103 | if self.cond_key != '':
104 | text = text[self.cond_key].cuda()
105 | tokens = text.long()
106 |
107 | z = self.transformer(tokens, return_embeddings=True)
108 | if return_token:
109 | return z, tokens
110 | return z
111 |
112 | def encode(self, text):
113 | # output of length 77
114 | return self(text)
115 |
116 | class BERTEmbedderVQTInterface(BERTTokenizer):
117 | def __init__(self, device="cuda", vq_interface=True, max_length=77):
118 | super().__init__(device=device, vq_interface=vq_interface, max_length=max_length)
119 |
120 | def encode(self, c):
121 | tokens = self(c)
122 | return c, None, [None,None,tokens]
123 |
124 | def decode(self, c):
125 | return c
126 |
127 | class SpatialRescaler(nn.Module):
128 | def __init__(self,
129 | n_stages=1,
130 | method='bilinear',
131 | multiplier=0.5,
132 | in_channels=3,
133 | out_channels=None,
134 | bias=False):
135 | super().__init__()
136 | self.n_stages = n_stages
137 | assert self.n_stages >= 0
138 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
139 | self.multiplier = multiplier
140 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
141 | self.remap_output = out_channels is not None
142 | if self.remap_output:
143 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
144 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
145 |
146 | def forward(self,x):
147 | for stage in range(self.n_stages):
148 | x = self.interpolator(x, scale_factor=self.multiplier)
149 |
150 |
151 | if self.remap_output:
152 | x = self.channel_mapper(x)
153 | return x
154 |
155 | def encode(self, x):
156 | return self(x)
157 |
158 |
159 | class FrozenCLIPEmbedder(AbstractEncoder):
160 | """Uses the CLIP transformer encoder for text (from Hugging Face)"""
161 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
162 | super().__init__()
163 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
164 | self.transformer = CLIPTextModel.from_pretrained(version)
165 | self.device = device
166 | self.max_length = max_length
167 | self.freeze()
168 |
169 | def freeze(self):
170 | self.transformer = self.transformer.eval()
171 | for param in self.parameters():
172 | param.requires_grad = False
173 |
174 | def forward(self, text):
175 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
176 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
177 | tokens = batch_encoding["input_ids"].to(self.device)
178 | outputs = self.transformer(input_ids=tokens)
179 |
180 | z = outputs.last_hidden_state
181 | return z
182 |
183 | def encode(self, text):
184 | return self(text)
185 |
186 |
187 | class FrozenCLIPTextEmbedder(nn.Module):
188 | """
189 | Uses the CLIP transformer encoder for text.
190 | """
191 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
192 | super().__init__()
193 | self.model, _ = clip.load(version, jit=False, device="cpu")
194 | self.device = device
195 | self.max_length = max_length
196 | self.n_repeat = n_repeat
197 | self.normalize = normalize
198 | self.use_tknz_fn = True
199 |
200 | def freeze(self):
201 | self.model = self.model.eval()
202 | for param in self.parameters():
203 | param.requires_grad = False
204 |
205 | def forward(self, text):
206 | tokens = clip.tokenize(text).to(self.device)
207 | z = self.model.encode_text(tokens)
208 | if self.normalize:
209 | z = z / torch.linalg.norm(z, dim=1, keepdim=True)
210 | return z
211 |
212 | def encode(self, text):
213 | z = self(text)
214 | if z.ndim==2:
215 | z = z[:, None, :]
216 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
217 | return z
218 |
219 |
220 | class FrozenClipImageEmbedder(nn.Module):
221 | """
222 | Uses the CLIP image encoder.
223 | """
224 | def __init__(
225 | self,
226 | model,
227 | jit=False,
228 | device='cuda' if torch.cuda.is_available() else 'cpu',
229 | antialias=False,
230 | ):
231 | super().__init__()
232 | self.model, _ = clip.load(name=model, device=device, jit=jit)
233 |
234 | self.antialias = antialias
235 |
236 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
237 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
238 |
239 | def preprocess(self, x):
240 | # normalize to [0,1]
241 | x = kornia.geometry.resize(x, (224, 224),
242 | interpolation='bicubic',align_corners=True,
243 | antialias=self.antialias)
244 | x = (x + 1.) / 2.
245 | # renormalize according to clip
246 | x = kornia.enhance.normalize(x, self.mean, self.std)
247 | return x
248 |
249 | def forward(self, x):
250 | # x is assumed to be in range [-1,1]
251 | return self.model.encode_image(self.preprocess(x))
252 |
--------------------------------------------------------------------------------
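A minimal sketch of text conditioning with BERTEmbedder from the file above. It assumes a CUDA device and network access (the bert-base-uncased tokenizer is downloaded on first use); the n_embed and n_layer values are illustrative, not the ones used in the configs:

import torch

from ldm.modules.encoders.modules import BERTEmbedder

embedder = BERTEmbedder(n_embed=640, n_layer=4).cuda()
captions = ["a cat sitting on a couch", "two dogs playing in the snow"]

with torch.no_grad():
    z = embedder.encode(captions)   # tokenize, then run the transformer encoder
print(z.shape)                      # (2, 77, 640): batch x max_seq_len x n_embed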
/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
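A sketch of the two-pass call pattern this loss expects, with optimizer_idx 0 for the autoencoder/generator terms and 1 for the discriminator, as a two-optimizer training step would issue it. Shapes and tensors are illustrative stand-ins, and the LPIPS term downloads its VGG weights on first use:

import torch

from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator

loss_fn = LPIPSWithDiscriminator(disc_start=0, kl_weight=1e-6, disc_weight=0.5)

x = torch.randn(2, 3, 64, 64)                           # "real" images
x_rec = torch.randn(2, 3, 64, 64, requires_grad=True)   # stand-in for the decoder output
posterior = DiagonalGaussianDistribution(torch.randn(2, 8, 8, 8))

# pass 0: reconstruction + KL + adversarial (generator) terms
aeloss, log_ae = loss_fn(x, x_rec, posterior, optimizer_idx=0, global_step=0,
                         last_layer=x_rec, split="train")

# pass 1: hinge loss for the discriminator on real vs. reconstructed images
discloss, log_d = loss_fn(x, x_rec, posterior, optimizer_idx=1, global_step=0,
                          split="train")
print(log_ae["train/total_loss"], log_d["train/disc_loss"])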
/ldm/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 |
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 | if codebook_loss is None:
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
--------------------------------------------------------------------------------
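Two small sketches for the helpers above: adopt_weight keeps the GAN term switched off until disc_start, and measure_perplexity reports how evenly the codebook entries are used (the indices here are random, purely for illustration):

import torch

from ldm.modules.losses.vqperceptual import adopt_weight, measure_perplexity

print(adopt_weight(1.0, global_step=100, threshold=50001))    # 0.0 -> discriminator still off
print(adopt_weight(1.0, global_step=60000, threshold=50001))  # 1.0 -> discriminator active

indices = torch.randint(0, 16, (4, 32 * 32))   # hypothetical codebook indices, 16 codes
perplexity, cluster_use = measure_perplexity(indices, n_embed=16)
print(perplexity.item(), cluster_use.item())   # approaches 16.0 / 16 when all codes are used evenly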
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 |
6 | from inspect import isfunction
7 | from PIL import Image, ImageDraw, ImageFont
8 |
9 |
10 | def log_txt_as_img(wh, xc, size=10):
11 | # wh a tuple of (width, height)
12 | # xc a list of captions to plot
13 | b = len(xc)
14 | txts = list()
15 | for bi in range(b):
16 | txt = Image.new("RGB", wh, color="white")
17 | draw = ImageDraw.Draw(txt)
18 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
19 | nc = int(40 * (wh[0] / 256))
20 | if type(xc[bi]) is list:
21 | xc[bi] = '{}'.format(xc[bi])[1:-1]
22 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
23 |
24 | try:
25 | draw.text((0, 0), lines, fill="black", font=font)
26 | except UnicodeEncodeError:
27 | print("Can't encode string for logging. Skipping.")
28 |
29 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
30 | txts.append(txt)
31 | txts = np.stack(txts)
32 | txts = torch.tensor(txts)
33 | return txts
34 |
35 |
36 | def ismap(x):
37 | if not isinstance(x, torch.Tensor):
38 | return False
39 | return (len(x.shape) == 4) and (x.shape[1] > 3)
40 |
41 |
42 | def isimage(x):
43 | if not isinstance(x,torch.Tensor):
44 | return False
45 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
46 |
47 |
48 | def exists(x):
49 | return x is not None
50 |
51 |
52 | def default(val, d):
53 | if exists(val):
54 | return val
55 | return d() if isfunction(d) else d
56 |
57 |
58 | def mean_flat(tensor):
59 | """
60 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
61 | Take the mean over all non-batch dimensions.
62 | """
63 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
64 |
65 |
66 | def count_params(model, verbose=False):
67 | total_params = sum(p.numel() for p in model.parameters())
68 | if verbose:
69 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
70 | return total_params
71 |
72 |
73 | def instantiate_from_config(config):
74 | if "target" not in config:
75 | if config == '__is_first_stage__':
76 | return None
77 | elif config == "__is_unconditional__":
78 | return None
79 | raise KeyError("Expected key `target` to instantiate.")
80 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
81 |
82 | # TODO: clean this
83 | def instantiate_from_config_main(config, *args, **kwargs):
84 | if "target" not in config:
85 | raise KeyError("Expected key `target` to instantiate.")
86 | return get_obj_from_str(config["target"])(*args, **config.get("params", dict()), **kwargs)
87 |
88 |
89 | def get_obj_from_str(string, reload=False):
90 | module, cls = string.rsplit(".", 1)
91 | if reload:
92 | module_imp = importlib.import_module(module)
93 | importlib.reload(module_imp)
94 | return getattr(importlib.import_module(module, package=None), cls)
--------------------------------------------------------------------------------
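A minimal sketch of instantiate_from_config: once loaded, the YAML files under configs/ reduce to exactly this kind of mapping, a target import path plus optional params. The Conv2d target below is only an illustration:

from ldm.util import instantiate_from_config

config = {
    "target": "torch.nn.Conv2d",   # any importable class works
    "params": {"in_channels": 3, "out_channels": 64, "kernel_size": 3},
}
conv = instantiate_from_config(config)
print(type(conv))                  # <class 'torch.nn.modules.conv.Conv2d'>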
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='ldm_layout',
5 | version='1.0.0',
6 | packages=find_packages(),
7 | install_requires=[
8 | 'torch',
9 | 'numpy',
10 | ],
11 | )
--------------------------------------------------------------------------------
/taming/data/annotated_objects_dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional, List, Callable, Dict, Any, Tuple, Union
3 | import warnings
4 |
5 | import PIL.Image as pil_image
6 | from torch import Tensor
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms
9 |
10 | from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder, ObjectsConditionalBuilder, CaptionsConditionalBuilder
11 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
12 | from taming.data.conditional_builder.utils import load_object_from_string
13 | from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
14 | from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
15 | Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor
16 |
17 |
18 | class AnnotatedObjectsDataset(Dataset):
19 | def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
20 | min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
21 | crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
22 | encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "",
23 | no_object_classes: Optional[int] = None):
24 |
25 | self.data_path = data_path
26 | self.split = split
27 | self.keys = keys
28 | self.target_image_size = target_image_size
29 | self.min_object_area = min_object_area
30 | self.min_objects_per_image = min_objects_per_image
31 | self.max_objects_per_image = max_objects_per_image
32 | self.crop_method = crop_method
33 | self.random_flip = random_flip
34 | self.no_tokens = no_tokens
35 | self.use_group_parameter = use_group_parameter
36 | self.encode_crop = encode_crop
37 |
38 | self.annotations = None
39 | self.image_descriptions = None
40 | self.categories = None
41 | self.category_ids = None
42 | self.category_number = None
43 | self.image_ids = None
44 | self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip)
45 | self.paths = self.build_paths(self.data_path)
46 | self._conditional_builders = None
47 | self.category_allow_list = None
48 | if category_allow_list_target:
49 | allow_list = load_object_from_string(category_allow_list_target)
50 | self.category_allow_list = {name for name, _ in allow_list}
51 | self.category_mapping = {}
52 | if category_mapping_target:
53 | self.category_mapping = load_object_from_string(category_mapping_target)
54 | self.no_object_classes = no_object_classes
55 |
56 | def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
57 | top_level = Path(top_level)
58 | sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()}
59 | for path in sub_paths.values():
60 | if not path.exists():
61 | raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.')
62 | return sub_paths
63 |
64 | @staticmethod
65 | def load_image_from_disk(path: Path) -> Image:
66 | return pil_image.open(path).convert('RGB')
67 |
68 | @staticmethod
69 | def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool):
70 | transform_functions = []
71 | if crop_method == 'none':
72 | transform_functions.append(transforms.Resize((target_image_size, target_image_size)))
73 | # transform_functions.extend([
74 | # transforms.Resize((target_image_size, target_image_size)),
75 | # CenterCropReturnCoordinates(target_image_size)
76 | # ])
77 | elif crop_method == 'center':
78 | transform_functions.extend([
79 | transforms.Resize(target_image_size),
80 | CenterCropReturnCoordinates(target_image_size)
81 | ])
82 | elif crop_method == 'random-1d':
83 | transform_functions.extend([
84 | transforms.Resize(target_image_size),
85 | RandomCrop1dReturnCoordinates(target_image_size)
86 | ])
87 | elif crop_method == 'random-2d':
88 | transform_functions.extend([
89 | Random2dCropReturnCoordinates(target_image_size),
90 | transforms.Resize(target_image_size)
91 | ])
92 | elif crop_method is None:
93 | return None
94 | else:
95 | raise ValueError(f'Received invalid crop method [{crop_method}].')
96 | if random_flip:
97 | transform_functions.append(RandomHorizontalFlipReturn())
98 | transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.))
99 | return transform_functions
100 |
101 | def image_transform(self, x: Tensor) -> Tuple[Optional[BoundingBox], Optional[bool], Tensor]:
102 | crop_bbox = None
103 | flipped = None
104 | for t in self.transform_functions:
105 | if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)):
106 | crop_bbox, x = t(x)
107 | elif isinstance(t, RandomHorizontalFlipReturn):
108 | flipped, x = t(x)
109 | else:
110 | x = t(x)
111 | return crop_bbox, flipped, x
112 |
113 | @property
114 | def no_classes(self) -> int:
115 | return self.no_object_classes if self.no_object_classes else len(self.categories)
116 |
117 | @property
118 | def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
119 | # cannot set this up in init because no_classes is only known after loading data in init of superclass
120 | if self._conditional_builders is None:
121 | self._conditional_builders = {
122 | 'objects_center_points': ObjectsCenterPointsConditionalBuilder(
123 | self.no_classes,
124 | self.max_objects_per_image,
125 | self.no_tokens,
126 | self.encode_crop,
127 | self.use_group_parameter,
128 | getattr(self, 'use_additional_parameters', False)
129 | ),
130 | 'objects_bbox': ObjectsBoundingBoxConditionalBuilder(
131 | self.no_classes,
132 | self.max_objects_per_image,
133 | self.no_tokens,
134 | self.encode_crop,
135 | self.use_group_parameter,
136 | getattr(self, 'use_additional_parameters', False)
137 | ),
138 | 'objects': ObjectsConditionalBuilder(
139 | self.no_classes,
140 | self.max_objects_per_image,
141 | self.no_tokens,
142 | self.encode_crop,
143 | self.use_group_parameter,
144 | getattr(self, 'use_additional_parameters', False)
145 | ),
146 | # 'captions': CaptionsConditionalBuilder(
147 | # self.no_classes,
148 | # self.max_objects_per_image,
149 | # self.no_tokens,
150 | # self.encode_crop,
151 | # self.use_group_parameter,
152 | # getattr(self, 'use_additional_parameters', False)
153 | # ),
154 | }
155 | return self._conditional_builders
156 |
157 | def filter_categories(self) -> None:
158 | if self.category_allow_list:
159 | self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
160 | if self.category_mapping:
161 | self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}
162 | try:
163 | print('Filtering appended categories')
164 | if self.category_allow_list:
165 | self.categories_append = {id_: cat for id_, cat in self.categories_append.items() if cat.name in self.category_allow_list}
166 | if self.category_mapping:
167 | self.categories_append = {id_: cat for id_, cat in self.categories_append.items() if cat.id not in self.category_mapping}
168 | except:
169 | pass
170 |
171 | def setup_category_id_and_number(self) -> None:
172 | self.category_ids = list(self.categories.keys())
173 | self.category_ids.sort()
174 | if '/m/01s55n' in self.category_ids:
175 | self.category_ids.remove('/m/01s55n')
176 | self.category_ids.append('/m/01s55n')
177 | try:
178 | print('Adding appended categories to the main set.')
179 | self.category_ids_append = list(self.categories_append.keys())
180 | self.category_ids_append.sort()
181 | self.category_ids += self.category_ids_append
182 | self.categories = {**self.categories, **self.categories_append}
183 | except:
184 | pass
185 | self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
186 | if self.category_allow_list is not None and self.category_mapping is None \
187 | and len(self.category_ids) != len(self.category_allow_list):
188 | warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
189 | 'Make sure all names in category_allow_list exist.')
190 |
191 | def clean_up_annotations_and_image_descriptions(self) -> None:
192 | image_id_set = set(self.image_ids)
193 | self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set}
194 | self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set}
195 |
196 | @staticmethod
197 | def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float,
198 | min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]:
199 | filtered = {}
200 | for image_id, annotations in all_annotations.items():
201 | annotations_with_min_area = [a for a in annotations if a.area > min_object_area]
202 | if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image:
203 | filtered[image_id] = annotations_with_min_area
204 | return filtered
205 |
206 | def __len__(self):
207 | return len(self.image_ids)
208 |
209 | def __getitem__(self, n: int) -> Dict[str, Any]:
210 | image_id = self.get_image_id(n)
211 | sample = self.get_image_description(image_id)
212 | sample['annotations'] = self.get_annotation(image_id)
213 |
214 | if 'image' in self.keys:
215 | sample['image_path'] = str(self.get_image_path(image_id))
216 | sample['image'] = self.load_image_from_disk(sample['image_path'])
217 | sample['image'] = convert_pil_to_tensor(sample['image'])
218 | sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image'])
219 | sample['image'] = sample['image'].permute(1, 2, 0)
220 |
221 | for conditional, builder in self.conditional_builders.items():
222 | if conditional in self.keys:
223 | sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped'])
224 |
225 | if self.keys:
226 | # only return specified keys
227 | sample = {key: sample[key] for key in self.keys}
228 | return sample
229 |
230 | def get_image_id(self, no: int) -> str:
231 | return self.image_ids[no]
232 |
233 | def get_annotation(self, image_id: str) -> str:
234 | return self.annotations[image_id]
235 |
236 | def get_textual_label_for_category_id(self, category_id: str) -> str:
237 | return self.categories[category_id].name
238 |
239 | def get_textual_label_for_category_no(self, category_no: int) -> str:
240 | return self.categories[self.get_category_id(category_no)].name
241 |
242 | def get_category_number(self, category_id: str) -> int:
243 | return self.category_number[category_id]
244 |
245 | def get_category_id(self, category_no: int) -> str:
246 | return self.category_ids[category_no]
247 |
248 | def get_image_description(self, image_id: str) -> Dict[str, Any]:
249 | raise NotImplementedError()
250 |
251 | def get_path_structure(self):
252 | raise NotImplementedError
253 |
254 | def get_image_path(self, image_id: str) -> Path:
255 | raise NotImplementedError
256 |
--------------------------------------------------------------------------------
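A hypothetical sketch of the hooks a concrete dataset has to supply on top of AnnotatedObjectsDataset; the paths, keys and loading logic below are placeholders, and a working subclass would also have to populate self.annotations, self.image_descriptions, self.categories and self.image_ids in its __init__:

from pathlib import Path
from typing import Any, Dict

from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset


class MyAnnotatedObjects(AnnotatedObjectsDataset):
    def get_path_structure(self) -> Dict[str, str]:
        # relative sub-paths that build_paths() resolves and checks under self.data_path
        return {"files": "images", "metadata": "annotations.json"}

    def get_image_path(self, image_id: str) -> Path:
        return self.paths["files"].joinpath(f"{image_id}.jpg")

    def get_image_description(self, image_id: str) -> Dict[str, Any]:
        # __getitem__ starts from this dict and adds annotations, image and conditionals to it
        return {"description": self.image_descriptions[image_id]}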
/taming/data/base.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import numpy as np
3 | import albumentations
4 | from PIL import Image
5 | from torch.utils.data import Dataset, ConcatDataset
6 |
7 |
8 | class ConcatDatasetWithIndex(ConcatDataset):
9 | """Modified from original pytorch code to return dataset idx"""
10 | def __getitem__(self, idx):
11 | if idx < 0:
12 | if -idx > len(self):
13 | raise ValueError("absolute value of index should not exceed dataset length")
14 | idx = len(self) + idx
15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
16 | if dataset_idx == 0:
17 | sample_idx = idx
18 | else:
19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
20 | return self.datasets[dataset_idx][sample_idx], dataset_idx
21 |
22 |
23 | class ImagePaths(Dataset):
24 | def __init__(self, paths, size=None, random_crop=False, labels=None):
25 | self.size = size
26 | self.random_crop = random_crop
27 |
28 | self.labels = dict() if labels is None else labels
29 | self.labels["file_path_"] = paths
30 | self._length = len(paths)
31 |
32 | if self.size is not None and self.size > 0:
33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
34 | if not self.random_crop:
35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
36 | else:
37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
39 | else:
40 | self.preprocessor = lambda **kwargs: kwargs
41 |
42 | def __len__(self):
43 | return self._length
44 |
45 | def preprocess_image(self, image_path):
46 | image = Image.open(image_path)
47 | if not image.mode == "RGB":
48 | image = image.convert("RGB")
49 | image = np.array(image).astype(np.uint8)
50 | image = self.preprocessor(image=image)["image"]
51 | image = (image/127.5 - 1.0).astype(np.float32)
52 | return image
53 |
54 | def __getitem__(self, i):
55 | example = dict()
56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i])
57 | for k in self.labels:
58 | example[k] = self.labels[k][i]
59 | return example
60 |
61 |
62 | class NumpyPaths(ImagePaths):
63 | def preprocess_image(self, image_path):
64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
65 | image = np.transpose(image, (1,2,0))
66 | image = Image.fromarray(image, mode="RGB")
67 | image = np.array(image).astype(np.uint8)
68 | image = self.preprocessor(image=image)["image"]
69 | image = (image/127.5 - 1.0).astype(np.float32)
70 | return image
71 |
--------------------------------------------------------------------------------
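A minimal usage sketch of ImagePaths; the file path is hypothetical and only needs to point at some RGB image on disk:

from taming.data.base import ImagePaths

dataset = ImagePaths(paths=["data/coco/val2017/example.jpg"],   # hypothetical path
                     size=256, random_crop=False)
example = dataset[0]
print(example["image"].shape)   # (256, 256, 3), float32 scaled to [-1, 1]
print(example["file_path_"])    # the original path is carried along as a label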
/taming/data/coco.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import albumentations
4 | import numpy as np
5 | from PIL import Image
6 | from tqdm import tqdm
7 | from torch.utils.data import Dataset
8 |
9 | from taming.data.sflckr import SegmentationBase # for examples included in repo
10 |
11 |
12 | class Examples(SegmentationBase):
13 | def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
14 | super().__init__(data_csv="data/coco_examples.txt",
15 | data_root="data/coco_images",
16 | segmentation_root="data/coco_segmentations",
17 | size=size, random_crop=random_crop,
18 | interpolation=interpolation,
19 | n_labels=183, shift_segmentation=True)
20 |
21 |
22 | class CocoBase(Dataset):
23 | """needed for (image, caption, segmentation) pairs"""
24 | def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
25 | crop_size=None, force_no_crop=False, given_files=None):
26 | self.split = self.get_split()
27 | self.size = size
28 | if crop_size is None:
29 | self.crop_size = size
30 | else:
31 | self.crop_size = crop_size
32 |
33 | self.onehot = onehot_segmentation # return segmentation as rgb or one hot
34 | self.stuffthing = use_stuffthing # include thing in segmentation
35 | if self.onehot and not self.stuffthing:
36 | raise NotImplementedError("One hot mode is only supported for the "
37 | "stuffthings version because labels are stored "
38 | "a bit differently.")
39 |
40 | data_json = datajson
41 | with open(data_json) as json_file:
42 | self.json_data = json.load(json_file)
43 | self.img_id_to_captions = dict()
44 | self.img_id_to_filepath = dict()
45 | self.img_id_to_segmentation_filepath = dict()
46 |
47 | assert data_json.split("/")[-1] in ["captions_train2017.json",
48 | "captions_val2017.json"]
49 | if self.stuffthing:
50 | self.segmentation_prefix = (
51 | "data/cocostuffthings/val2017" if
52 | data_json.endswith("captions_val2017.json") else
53 | "data/cocostuffthings/train2017")
54 | else:
55 | self.segmentation_prefix = (
56 | "data/coco/annotations/stuff_val2017_pixelmaps" if
57 | data_json.endswith("captions_val2017.json") else
58 | "data/coco/annotations/stuff_train2017_pixelmaps")
59 |
60 | imagedirs = self.json_data["images"]
61 | self.labels = {"image_ids": list()}
62 | for imgdir in tqdm(imagedirs, desc="ImgToPath"):
63 | self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
64 | self.img_id_to_captions[imgdir["id"]] = list()
65 | pngfilename = imgdir["file_name"].replace("jpg", "png")
66 | self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
67 | self.segmentation_prefix, pngfilename)
68 | if given_files is not None:
69 | if pngfilename in given_files:
70 | self.labels["image_ids"].append(imgdir["id"])
71 | else:
72 | self.labels["image_ids"].append(imgdir["id"])
73 |
74 | capdirs = self.json_data["annotations"]
75 | for capdir in tqdm(capdirs, desc="ImgToCaptions"):
76 | # there are on average 5 captions per image
77 | self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
78 |
79 | self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
80 | if self.split=="validation":
81 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
82 | else:
83 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
84 | self.preprocessor = albumentations.Compose(
85 | [self.rescaler, self.cropper],
86 | additional_targets={"segmentation": "image"})
87 | if force_no_crop:
88 | self.rescaler = albumentations.Resize(height=self.size, width=self.size)
89 | self.preprocessor = albumentations.Compose(
90 | [self.rescaler],
91 | additional_targets={"segmentation": "image"})
92 |
93 | def __len__(self):
94 | return len(self.labels["image_ids"])
95 |
96 | def preprocess_image(self, image_path, segmentation_path):
97 | image = Image.open(image_path)
98 | if not image.mode == "RGB":
99 | image = image.convert("RGB")
100 | image = np.array(image).astype(np.uint8)
101 |
102 | segmentation = Image.open(segmentation_path)
103 | if not self.onehot and not segmentation.mode == "RGB":
104 | segmentation = segmentation.convert("RGB")
105 | segmentation = np.array(segmentation).astype(np.uint8)
106 | if self.onehot:
107 | assert self.stuffthing
108 | # stored in caffe format: unlabeled==255. stuff and thing from
109 | # 0-181. to be compatible with the labels in
110 | # https://github.com/nightrome/cocostuff/blob/master/labels.txt
111 | # we shift stuffthing one to the right and put unlabeled in zero
112 | # as long as segmentation is uint8 shifting to right handles the
113 | # latter too
114 | assert segmentation.dtype == np.uint8
115 | segmentation = segmentation + 1
116 |
117 | processed = self.preprocessor(image=image, segmentation=segmentation)
118 | image, segmentation = processed["image"], processed["segmentation"]
119 | image = (image / 127.5 - 1.0).astype(np.float32)
120 |
121 | if self.onehot:
122 | assert segmentation.dtype == np.uint8
123 | # make it one hot
124 | n_labels = 183
125 | flatseg = np.ravel(segmentation)
126 | onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
127 | onehot[np.arange(flatseg.size), flatseg] = True
128 | onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
129 | segmentation = onehot
130 | else:
131 | segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
132 | return image, segmentation
133 |
134 | def __getitem__(self, i):
135 | img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
136 | seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
137 | image, segmentation = self.preprocess_image(img_path, seg_path)
138 | captions = self.img_id_to_captions[self.labels["image_ids"][i]]
139 | # randomly draw one of all available captions per image
140 | caption = captions[np.random.randint(0, len(captions))]
141 | example = {"image": image,
142 | "caption": [str(caption[0])],
143 | "segmentation": segmentation,
144 | "img_path": img_path,
145 | "seg_path": seg_path,
146 | "filename_": img_path.split(os.sep)[-1]
147 | }
148 | return example
149 |
150 |
151 | class CocoImagesAndCaptionsTrain(CocoBase):
152 | """returns a pair of (image, caption)"""
153 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
154 | super().__init__(size=size,
155 | dataroot="data/coco/train2017",
156 | datajson="data/coco/annotations/captions_train2017.json",
157 | onehot_segmentation=onehot_segmentation,
158 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
159 |
160 | def get_split(self):
161 | return "train"
162 |
163 |
164 | class CocoImagesAndCaptionsValidation(CocoBase):
165 | """returns a pair of (image, caption)"""
166 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
167 | given_files=None):
168 | super().__init__(size=size,
169 | dataroot="data/coco/val2017",
170 | datajson="data/coco/annotations/captions_val2017.json",
171 | onehot_segmentation=onehot_segmentation,
172 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
173 | given_files=given_files)
174 |
175 | def get_split(self):
176 | return "validation"
177 |
--------------------------------------------------------------------------------
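
A minimal usage sketch for the caption dataset above (assuming these classes live in taming/data/coco.py and that the COCO 2017 images, caption annotations and segmentation maps referenced by the constructor paths are already in place under data/coco/); sizes are illustrative:

    from taming.data.coco import CocoImagesAndCaptionsTrain

    dataset = CocoImagesAndCaptionsTrain(size=256, crop_size=256)
    example = dataset[0]     # dict with "image", "caption", "segmentation" and path keys
    example["image"].shape   # (256, 256, 3), float32 scaled to [-1, 1]
    example["caption"]       # one caption drawn at random, wrapped in a list
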
/taming/data/conditional_builder/objects_bbox.py:
--------------------------------------------------------------------------------
1 | from itertools import cycle
2 | from typing import List, Tuple, Callable, Optional
3 |
4 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
5 | from more_itertools.recipes import grouper
6 | from taming.data.image_transforms import convert_pil_to_tensor
7 | from torch import LongTensor, Tensor
8 |
9 | from taming.data.helper_types import BoundingBox, Annotation
10 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
11 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
12 | pad_list, get_plot_font_size, absolute_bbox
13 |
14 |
15 | class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
16 | @property
17 | def object_descriptor_length(self) -> int:
18 | return 3
19 |
20 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
21 | object_triples = [
22 | (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
23 | for ann in annotations
24 | ]
25 | empty_triple = (self.none, self.none, self.none)
26 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
27 | return object_triples
28 |
29 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
30 | conditional_list = conditional.tolist()
31 | crop_coordinates = None
32 | if self.encode_crop:
33 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
34 | conditional_list = conditional_list[:-2]
35 | object_triples = grouper(conditional_list, 3)
36 | assert conditional.shape[0] == self.embedding_dim
37 | return [
38 | (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
39 | for object_triple in object_triples if object_triple[0] != self.none
40 | ], crop_coordinates
41 |
42 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
43 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
44 | plot = pil_image.new('RGB', figure_size, WHITE)
45 | draw = pil_img_draw.Draw(plot)
46 | # font = ImageFont.truetype(
47 | # "arial.ttf",
48 | # size=get_plot_font_size(font_size, figure_size)
49 | # )
50 | width, height = plot.size
51 | description, crop_coordinates = self.inverse_build(conditional)
52 | for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
53 | annotation = self.representation_to_annotation(representation)
54 | class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
55 | bbox = absolute_bbox(bbox, width, height)
56 | draw.rectangle(bbox, outline=color, width=line_width)
57 | draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK) #, font=font)
58 | if crop_coordinates is not None:
59 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
60 | return convert_pil_to_tensor(plot) / 127.5 - 1.
61 |
62 |
63 | class ObjectsConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
64 | @property
65 | def object_descriptor_length(self) -> int:
66 | return 1
67 |
68 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
69 | object_triples = [
70 | (self.object_representation(ann),)
71 | for ann in annotations
72 | ]
73 |
74 | empty_triple = (self.none,)
75 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
76 | return object_triples
77 |
78 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
79 | conditional_list = conditional.tolist()
80 | crop_coordinates = None
81 | # if self.encode_crop:
82 | # crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
83 | # conditional_list = conditional_list[:-2]
84 | object_triples = grouper(conditional_list, 1)
85 | assert conditional.shape[0] == self.embedding_dim
86 | return [
87 | (object_triple[0])
88 | for object_triple in object_triples if object_triple[0] != self.none
89 | ], crop_coordinates
90 |
91 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
92 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
93 |
94 | return 0
95 | # plot = pil_image.new('RGB', figure_size, WHITE)
96 | # draw = pil_img_draw.Draw(plot)
97 | # font = ImageFont.truetype(
98 | # "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
99 | # size=get_plot_font_size(font_size, figure_size)
100 | # )
101 | # width, height = plot.size
102 | # description, crop_coordinates = self.inverse_build(conditional)
103 | # for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
104 | # annotation = self.representation_to_annotation(representation)
105 | # class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
106 | # bbox = absolute_bbox(bbox, width, height)
107 | # draw.rectangle(bbox, outline=color, width=line_width)
108 | # draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
109 | # if crop_coordinates is not None:
110 | # draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
111 | # return convert_pil_to_tensor(plot) / 127.5 - 1.
112 |
113 | class CaptionsConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
114 | @property
115 | def object_descriptor_length(self) -> int:
116 | return 1
117 |
118 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
119 | object_triples = [
120 | (self.object_representation(ann),)
121 | for ann in annotations
122 | ]
123 | empty_triple = (self.none,)
124 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
125 | return object_triples
126 |
127 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
128 | conditional_list = conditional.tolist()
129 | crop_coordinates = None
130 | # if self.encode_crop:
131 | # crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
132 | # conditional_list = conditional_list[:-2]
133 | object_triples = grouper(conditional_list, 1)
134 | assert conditional.shape[0] == self.embedding_dim
135 | return [
136 | (object_triple[0])
137 | for object_triple in object_triples if object_triple[0] != self.none
138 | ], crop_coordinates
139 |
140 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
141 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
142 |
143 | return 0
144 | # plot = pil_image.new('RGB', figure_size, WHITE)
145 | # draw = pil_img_draw.Draw(plot)
146 | # font = ImageFont.truetype(
147 | # "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
148 | # size=get_plot_font_size(font_size, figure_size)
149 | # )
150 | # width, height = plot.size
151 | # description, crop_coordinates = self.inverse_build(conditional)
152 | # for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
153 | # annotation = self.representation_to_annotation(representation)
154 | # class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
155 | # bbox = absolute_bbox(bbox, width, height)
156 | # draw.rectangle(bbox, outline=color, width=line_width)
157 | # draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
158 | # if crop_coordinates is not None:
159 | # draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
160 | # return convert_pil_to_tensor(plot) / 127.5 - 1.
--------------------------------------------------------------------------------
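
A sketch of the bounding-box builder above, end to end; the sizes (183 classes, 8 objects, 1024 tokens) are placeholders, not values taken from the configs:

    from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
    from taming.data.helper_types import Annotation

    builder = ObjectsBoundingBoxConditionalBuilder(
        no_object_classes=183, no_max_objects=8, no_tokens=1024,
        encode_crop=False, use_group_parameter=False, use_additional_parameters=False)

    ann = Annotation(area=0.2, image_id="0", bbox=(0.1, 0.2, 0.5, 0.4),
                     category_no=17, category_id="17")
    conditional = builder.build([ann])            # LongTensor of length 8 * 3 = 24
    objects, crop = builder.inverse_build(conditional)
    # objects == [(17, (x0, y0, w, h))] with the box snapped to the token grid; crop is None
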
/taming/data/conditional_builder/objects_center_points.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import warnings
4 | from itertools import cycle
5 | from typing import List, Optional, Tuple, Callable
6 |
7 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
8 | from more_itertools.recipes import grouper
9 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
10 | additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
11 | absolute_bbox, rescale_annotations
12 | from taming.data.helper_types import BoundingBox, Annotation
13 | from taming.data.image_transforms import convert_pil_to_tensor
14 | from torch import LongTensor, Tensor
15 |
16 |
17 | class ObjectsCenterPointsConditionalBuilder:
18 | def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool,
19 | use_group_parameter: bool, use_additional_parameters: bool):
20 | self.no_object_classes = no_object_classes
21 | self.no_max_objects = no_max_objects
22 | self.no_tokens = no_tokens
23 | self.encode_crop = encode_crop
24 | self.no_sections = int(math.sqrt(self.no_tokens))
25 | self.use_group_parameter = use_group_parameter
26 | self.use_additional_parameters = use_additional_parameters
27 |
28 | @property
29 | def none(self) -> int:
30 | return self.no_tokens - 1
31 |
32 | @property
33 | def object_descriptor_length(self) -> int:
34 | return 2
35 |
36 | @property
37 | def embedding_dim(self) -> int:
38 | extra_length = 2 if self.encode_crop else 0
39 | return self.no_max_objects * self.object_descriptor_length + extra_length
40 |
41 | def tokenize_coordinates(self, x: float, y: float) -> int:
42 | """
43 | Express 2d coordinates with one number.
44 | Example: assume self.no_tokens = 16, then no_sections = 4:
45 | 0 0 0 0
46 | 0 0 # 0
47 | 0 0 0 0
48 | 0 0 0 x
49 | Then the # position corresponds to token 6, the x position to token 15.
50 | @param x: float in [0, 1]
51 | @param y: float in [0, 1]
52 | @return: discrete tokenized coordinate
53 | """
54 | x_discrete = int(round(x * (self.no_sections - 1)))
55 | y_discrete = int(round(y * (self.no_sections - 1)))
56 | return y_discrete * self.no_sections + x_discrete
57 |
58 |     def coordinates_from_token(self, token: int) -> Tuple[float, float]:
59 | x = token % self.no_sections
60 | y = token // self.no_sections
61 | return x / (self.no_sections - 1), y / (self.no_sections - 1)
62 |
63 | def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
64 | x0, y0 = self.coordinates_from_token(token1)
65 | x1, y1 = self.coordinates_from_token(token2)
66 | return x0, y0, x1 - x0, y1 - y0
67 |
68 | def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]:
69 | return self.tokenize_coordinates(bbox[0], bbox[1]), \
70 | self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3])
71 |
72 | def inverse_build(self, conditional: LongTensor) \
73 | -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
74 | conditional_list = conditional.tolist()
75 | crop_coordinates = None
76 | if self.encode_crop:
77 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
78 | conditional_list = conditional_list[:-2]
79 | table_of_content = grouper(conditional_list, self.object_descriptor_length)
80 | assert conditional.shape[0] == self.embedding_dim
81 | return [
82 | (object_tuple[0], self.coordinates_from_token(object_tuple[1]))
83 | for object_tuple in table_of_content if object_tuple[0] != self.none
84 | ], crop_coordinates
85 |
86 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
87 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
88 | plot = pil_image.new('RGB', figure_size, WHITE)
89 | draw = pil_img_draw.Draw(plot)
90 | circle_size = get_circle_size(figure_size)
91 | font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
92 | size=get_plot_font_size(font_size, figure_size))
93 | width, height = plot.size
94 | description, crop_coordinates = self.inverse_build(conditional)
95 | for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
96 | x_abs, y_abs = x * width, y * height
97 | ann = self.representation_to_annotation(representation)
98 | label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann)
99 | ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
100 | draw.ellipse(ellipse_bbox, fill=color, width=0)
101 | draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
102 | if crop_coordinates is not None:
103 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
104 | return convert_pil_to_tensor(plot) / 127.5 - 1.
105 |
106 | def object_representation(self, annotation: Annotation) -> int:
107 | modifier = 0
108 | if self.use_group_parameter:
109 | modifier |= 1 * (annotation.is_group_of is True)
110 | if self.use_additional_parameters:
111 | modifier |= 2 * (annotation.is_occluded is True)
112 | modifier |= 4 * (annotation.is_depiction is True)
113 | modifier |= 8 * (annotation.is_inside is True)
114 | return annotation.category_no + self.no_object_classes * modifier
115 |
116 | def representation_to_annotation(self, representation: int) -> Annotation:
117 | category_no = representation % self.no_object_classes
118 | modifier = representation // self.no_object_classes
119 | # noinspection PyTypeChecker
120 | return Annotation(
121 | area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None,
122 | category_no=category_no,
123 | is_group_of=bool((modifier & 1) * self.use_group_parameter),
124 | is_occluded=bool((modifier & 2) * self.use_additional_parameters),
125 | is_depiction=bool((modifier & 4) * self.use_additional_parameters),
126 | is_inside=bool((modifier & 8) * self.use_additional_parameters)
127 | )
128 |
129 | def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]:
130 | return list(self.token_pair_from_bbox(crop_coordinates))
131 |
132 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
133 | object_tuples = [
134 | (self.object_representation(a),
135 | self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2))
136 | for a in annotations
137 | ]
138 | empty_tuple = (self.none, self.none)
139 | object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
140 | return object_tuples
141 |
142 | def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \
143 | -> LongTensor:
144 | if len(annotations) == 0:
145 | warnings.warn('Did not receive any annotations.')
146 | if len(annotations) > self.no_max_objects:
147 | warnings.warn('Received more annotations than allowed.')
148 | annotations = annotations[:self.no_max_objects]
149 |
150 | if not crop_coordinates:
151 | crop_coordinates = FULL_CROP
152 |
153 | random.shuffle(annotations)
154 | annotations = filter_annotations(annotations, crop_coordinates)
155 | if self.encode_crop:
156 | annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip)
157 | if horizontal_flip:
158 | crop_coordinates = horizontally_flip_bbox(crop_coordinates)
159 | extra = self._crop_encoder(crop_coordinates)
160 | else:
161 | annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip)
162 | extra = []
163 |
164 | object_tuples = self._make_object_descriptors(annotations)
165 |
166 | # flatten
167 | flattened = [token for tuple_ in object_tuples for token in tuple_] + extra
168 | assert len(flattened) == self.embedding_dim
169 | assert all(0 <= value < self.no_tokens for value in flattened)
170 | return LongTensor(flattened)
171 |
--------------------------------------------------------------------------------
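
The coordinate tokenization can be checked directly against the 4x4 grid in the tokenize_coordinates docstring; the remaining constructor arguments below are placeholders:

    from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder

    builder = ObjectsCenterPointsConditionalBuilder(
        no_object_classes=183, no_max_objects=8, no_tokens=16,
        encode_crop=False, use_group_parameter=False, use_additional_parameters=False)

    builder.no_sections                          # 4, i.e. a 4x4 grid of positions
    builder.tokenize_coordinates(2 / 3, 1 / 3)   # 6, the '#' cell in the docstring
    builder.coordinates_from_token(15)           # (1.0, 1.0), the 'x' cell (bottom-right)
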
/taming/data/conditional_builder/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from typing import List, Any, Tuple, Optional
3 |
4 | from taming.data.helper_types import BoundingBox, Annotation
5 |
6 | # source: seaborn, color palette tab10
7 | COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
8 | (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
9 | BLACK = (0, 0, 0)
10 | GRAY_75 = (63, 63, 63)
11 | GRAY_50 = (127, 127, 127)
12 | GRAY_25 = (191, 191, 191)
13 | WHITE = (255, 255, 255)
14 | FULL_CROP = (0., 0., 1., 1.)
15 |
16 |
17 | def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
18 | """
19 | Give intersection area of two rectangles.
20 | @param rectangle1: (x0, y0, w, h) of first rectangle
21 | @param rectangle2: (x0, y0, w, h) of second rectangle
22 | """
23 | rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
24 | rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
25 | x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
26 | y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
27 | return x_overlap * y_overlap
28 |
29 |
30 | def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
31 | return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
32 |
33 |
34 | def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
35 | bbox = relative_bbox
36 | bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
37 | return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
38 |
39 |
40 | def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
41 | return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
42 |
43 |
44 | def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
45 | List[Annotation]:
46 | def clamp(x: float):
47 | return max(min(x, 1.), 0.)
48 |
49 | def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
50 | x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
51 | y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
52 | w = min(bbox[2] / crop_coordinates[2], 1 - x0)
53 | h = min(bbox[3] / crop_coordinates[3], 1 - y0)
54 | if flip:
55 | x0 = 1 - (x0 + w)
56 | return x0, y0, w, h
57 |
58 | return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
59 |
60 |
61 | def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
62 | return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
63 |
64 |
65 | def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
66 | sl = slice(1) if short else slice(None)
67 | string = ''
68 | if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
69 | return string
70 | if annotation.is_group_of:
71 | string += 'group'[sl] + ','
72 | if annotation.is_occluded:
73 | string += 'occluded'[sl] + ','
74 | if annotation.is_depiction:
75 | string += 'depiction'[sl] + ','
76 | if annotation.is_inside:
77 | string += 'inside'[sl]
78 | return '(' + string.strip(",") + ')'
79 |
80 |
81 | def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
82 | if font_size is None:
83 | font_size = 10
84 | if max(figure_size) >= 256:
85 | font_size = 12
86 | if max(figure_size) >= 512:
87 | font_size = 15
88 | return font_size
89 |
90 |
91 | def get_circle_size(figure_size: Tuple[int, int]) -> int:
92 | circle_size = 2
93 | if max(figure_size) >= 256:
94 | circle_size = 3
95 | if max(figure_size) >= 512:
96 | circle_size = 4
97 | return circle_size
98 |
99 |
100 | def load_object_from_string(object_string: str) -> Any:
101 | """
102 | Source: https://stackoverflow.com/a/10773699
103 | """
104 | module_name, class_name = object_string.rsplit(".", 1)
105 | return getattr(importlib.import_module(module_name), class_name)
106 |
--------------------------------------------------------------------------------
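
Two of the helpers above in isolation, with made-up boxes in the relative (x0, y0, w, h) convention used throughout this package:

    from taming.data.conditional_builder.utils import intersection_area, absolute_bbox

    intersection_area((0.0, 0.0, 0.5, 0.5), (0.25, 0.25, 0.5, 0.5))   # 0.0625 (a 0.25 x 0.25 overlap)
    absolute_bbox((0.25, 0.25, 0.5, 0.5), width=256, height=256)      # (64, 64, 192, 192) in pixels
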
/taming/data/custom.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import albumentations
4 | from torch.utils.data import Dataset
5 |
6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
7 |
8 |
9 | class CustomBase(Dataset):
10 | def __init__(self, *args, **kwargs):
11 | super().__init__()
12 | self.data = None
13 |
14 | def __len__(self):
15 | return len(self.data)
16 |
17 | def __getitem__(self, i):
18 | example = self.data[i]
19 | return example
20 |
21 |
22 |
23 | class CustomTrain(CustomBase):
24 | def __init__(self, size, training_images_list_file):
25 | super().__init__()
26 | with open(training_images_list_file, "r") as f:
27 | paths = f.read().splitlines()
28 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
29 |
30 |
31 | class CustomTest(CustomBase):
32 | def __init__(self, size, test_images_list_file):
33 | super().__init__()
34 | with open(test_images_list_file, "r") as f:
35 | paths = f.read().splitlines()
36 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
37 |
38 |
39 |
--------------------------------------------------------------------------------
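
Usage sketch for the custom dataset wrappers; the list file below is a placeholder for a plain text file with one image path per line, and the example keys come from taming.data.base.ImagePaths (not shown in this file):

    from taming.data.custom import CustomTrain

    dataset = CustomTrain(size=256, training_images_list_file="data/my_train_images.txt")
    example = dataset[0]   # dict built by ImagePaths, with the image as float32 in [-1, 1]
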
/taming/data/helper_types.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple, Optional, NamedTuple, Union
2 | from PIL.Image import Image as pil_image
3 | from torch import Tensor
4 |
5 | try:
6 | from typing import Literal
7 | except ImportError:
8 | from typing_extensions import Literal
9 |
10 | Image = Union[Tensor, pil_image]
11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h
12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d']
13 | SplitType = Literal['train', 'validation', 'test']
14 |
15 |
16 | class ImageDescription(NamedTuple):
17 | id: int
18 | file_name: str
19 | original_size: Tuple[int, int] # w, h
20 | url: Optional[str] = None
21 | license: Optional[int] = None
22 | coco_url: Optional[str] = None
23 | date_captured: Optional[str] = None
24 | flickr_url: Optional[str] = None
25 | flickr_id: Optional[str] = None
26 | coco_id: Optional[str] = None
27 |
28 |
29 | class Category(NamedTuple):
30 | id: str
31 | super_category: Optional[str]
32 | name: str
33 |
34 |
35 | class Annotation(NamedTuple):
36 | area: float
37 | image_id: str
38 | bbox: BoundingBox
39 | category_no: int
40 | category_id: str
41 | id: Optional[int] = None
42 | source: Optional[str] = None
43 | confidence: Optional[float] = None
44 | is_group_of: Optional[bool] = None
45 | is_truncated: Optional[bool] = None
46 | is_occluded: Optional[bool] = None
47 | is_depiction: Optional[bool] = None
48 | is_inside: Optional[bool] = None
49 | segmentation: Optional[Dict] = None
50 |
--------------------------------------------------------------------------------
/taming/data/image_transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import warnings
3 | from typing import Union
4 |
5 | import torch
6 | from torch import Tensor
7 | from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
8 | from torchvision.transforms.functional import _get_image_size as get_image_size
9 |
10 | from taming.data.helper_types import BoundingBox, Image
11 |
12 | pil_to_tensor = PILToTensor()
13 |
14 |
15 | def convert_pil_to_tensor(image: Image) -> Tensor:
16 | with warnings.catch_warnings():
17 | # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
18 | warnings.simplefilter("ignore")
19 | return pil_to_tensor(image)
20 |
21 |
22 | class RandomCrop1dReturnCoordinates(RandomCrop):
23 | def forward(self, img: Image) -> (BoundingBox, Image):
24 | """
25 | In addition to cropping, returns the relative coordinates of the crop bounding box.
26 | Args:
27 | img (PIL Image or Tensor): Image to be cropped.
28 |
29 | Returns:
30 | Bounding box: x0, y0, w, h
31 | PIL Image or Tensor: Cropped image.
32 |
33 | Based on:
34 | torchvision.transforms.RandomCrop, torchvision 1.7.0
35 | """
36 | if self.padding is not None:
37 | img = F.pad(img, self.padding, self.fill, self.padding_mode)
38 |
39 | width, height = get_image_size(img)
40 | # pad the width if needed
41 | if self.pad_if_needed and width < self.size[1]:
42 | padding = [self.size[1] - width, 0]
43 | img = F.pad(img, padding, self.fill, self.padding_mode)
44 | # pad the height if needed
45 | if self.pad_if_needed and height < self.size[0]:
46 | padding = [0, self.size[0] - height]
47 | img = F.pad(img, padding, self.fill, self.padding_mode)
48 |
49 | i, j, h, w = self.get_params(img, self.size)
50 | bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h
51 | return bbox, F.crop(img, i, j, h, w)
52 |
53 |
54 | class Random2dCropReturnCoordinates(torch.nn.Module):
55 | """
56 | In addition to cropping, returns the relative coordinates of the crop bounding box.
57 | Args:
58 | img (PIL Image or Tensor): Image to be cropped.
59 |
60 | Returns:
61 | Bounding box: x0, y0, w, h
62 | PIL Image or Tensor: Cropped image.
63 |
64 | Based on:
65 | torchvision.transforms.RandomCrop, torchvision 1.7.0
66 | """
67 |
68 | def __init__(self, min_size: int):
69 | super().__init__()
70 | self.min_size = min_size
71 |
72 | def forward(self, img: Image) -> (BoundingBox, Image):
73 | width, height = get_image_size(img)
74 | max_size = min(width, height)
75 | if max_size <= self.min_size:
76 | size = max_size
77 | else:
78 | size = random.randint(self.min_size, max_size)
79 | top = random.randint(0, height - size)
80 | left = random.randint(0, width - size)
81 | bbox = left / width, top / height, size / width, size / height
82 | return bbox, F.crop(img, top, left, size, size)
83 |
84 |
85 | class CenterCropReturnCoordinates(CenterCrop):
86 | @staticmethod
87 | def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
88 | if width > height:
89 | w = height / width
90 | h = 1.0
91 | x0 = 0.5 - w / 2
92 | y0 = 0.
93 | else:
94 | w = 1.0
95 | h = width / height
96 | x0 = 0.
97 | y0 = 0.5 - h / 2
98 | return x0, y0, w, h
99 |
100 | def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
101 | """
102 | In addition to cropping, returns the relative coordinates of the crop bounding box.
103 | Args:
104 | img (PIL Image or Tensor): Image to be cropped.
105 |
106 | Returns:
107 | Bounding box: x0, y0, w, h
108 | PIL Image or Tensor: Cropped image.
109 | Based on:
110 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
111 | """
112 | width, height = get_image_size(img)
113 | return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size)
114 |
115 |
116 | class RandomHorizontalFlipReturn(RandomHorizontalFlip):
117 | def forward(self, img: Image) -> (bool, Image):
118 | """
119 | In addition to flipping, returns a boolean indicating whether the image was flipped.
120 | Args:
121 | img (PIL Image or Tensor): Image to be flipped.
122 |
123 | Returns:
124 | flipped: whether the image was flipped or not
125 | PIL Image or Tensor: Randomly flipped image.
126 |
127 | Based on:
128 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
129 | """
130 | if torch.rand(1) < self.p:
131 | return True, F.hflip(img)
132 | return False, img
133 |
--------------------------------------------------------------------------------
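
A quick check of the coordinate-returning transforms, assuming the torchvision version pinned by the repository (the module imports the private _get_image_size helper, which later torchvision releases renamed):

    from PIL import Image
    from taming.data.image_transforms import CenterCropReturnCoordinates, RandomHorizontalFlipReturn

    # an image whose shorter side already equals the crop size, as the pipelines assume
    img = Image.new("RGB", (512, 256))
    bbox, cropped = CenterCropReturnCoordinates(size=256)(img)
    # bbox == (0.25, 0.0, 0.5, 1.0): relative (x0, y0, w, h) of the 256x256 center window

    flipped, out = RandomHorizontalFlipReturn(p=1.0)(img)
    # with p=1.0 the image is always flipped, so flipped is True
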
/taming/data/open_images_helper.py:
--------------------------------------------------------------------------------
1 | open_images_unify_categories_for_coco = {
2 | '/m/03bt1vf': '/m/01g317',
3 | '/m/04yx4': '/m/01g317',
4 | '/m/05r655': '/m/01g317',
5 | '/m/01bl7v': '/m/01g317',
6 | '/m/0cnyhnx': '/m/01xq0k1',
7 | '/m/01226z': '/m/018xm',
8 | '/m/05ctyq': '/m/018xm',
9 | '/m/058qzx': '/m/04ctx',
10 | '/m/06pcq': '/m/0l515',
11 | '/m/03m3pdh': '/m/02crq1',
12 | '/m/046dlr': '/m/01x3z',
13 | '/m/0h8mzrc': '/m/01x3z',
14 | }
15 |
16 |
17 | top_300_classes_plus_coco_compatibility = [
18 | ('Man', 1060962),
19 | ('Clothing', 986610),
20 | ('Tree', 748162),
21 | ('Woman', 611896),
22 | ('Person', 610294),
23 | ('Human face', 442948),
24 | ('Girl', 175399),
25 | ('Building', 162147),
26 | ('Car', 159135),
27 | ('Plant', 155704),
28 | ('Human body', 137073),
29 | ('Flower', 133128),
30 | ('Window', 127485),
31 | ('Human arm', 118380),
32 | ('House', 114365),
33 | ('Wheel', 111684),
34 | ('Suit', 99054),
35 | ('Human hair', 98089),
36 | ('Human head', 92763),
37 | ('Chair', 88624),
38 | ('Boy', 79849),
39 | ('Table', 73699),
40 | ('Jeans', 57200),
41 | ('Tire', 55725),
42 | ('Skyscraper', 53321),
43 | ('Food', 52400),
44 | ('Footwear', 50335),
45 | ('Dress', 50236),
46 | ('Human leg', 47124),
47 | ('Toy', 46636),
48 | ('Tower', 45605),
49 | ('Boat', 43486),
50 | ('Land vehicle', 40541),
51 | ('Bicycle wheel', 34646),
52 | ('Palm tree', 33729),
53 | ('Fashion accessory', 32914),
54 | ('Glasses', 31940),
55 | ('Bicycle', 31409),
56 | ('Furniture', 30656),
57 | ('Sculpture', 29643),
58 | ('Bottle', 27558),
59 | ('Dog', 26980),
60 | ('Snack', 26796),
61 | ('Human hand', 26664),
62 | ('Bird', 25791),
63 | ('Book', 25415),
64 | ('Guitar', 24386),
65 | ('Jacket', 23998),
66 | ('Poster', 22192),
67 | ('Dessert', 21284),
68 | ('Baked goods', 20657),
69 | ('Drink', 19754),
70 | ('Flag', 18588),
71 | ('Houseplant', 18205),
72 | ('Tableware', 17613),
73 | ('Airplane', 17218),
74 | ('Door', 17195),
75 | ('Sports uniform', 17068),
76 | ('Shelf', 16865),
77 | ('Drum', 16612),
78 | ('Vehicle', 16542),
79 | ('Microphone', 15269),
80 | ('Street light', 14957),
81 | ('Cat', 14879),
82 | ('Fruit', 13684),
83 | ('Fast food', 13536),
84 | ('Animal', 12932),
85 | ('Vegetable', 12534),
86 | ('Train', 12358),
87 | ('Horse', 11948),
88 | ('Flowerpot', 11728),
89 | ('Motorcycle', 11621),
90 | ('Fish', 11517),
91 | ('Desk', 11405),
92 | ('Helmet', 10996),
93 | ('Truck', 10915),
94 | ('Bus', 10695),
95 | ('Hat', 10532),
96 | ('Auto part', 10488),
97 | ('Musical instrument', 10303),
98 | ('Sunglasses', 10207),
99 | ('Picture frame', 10096),
100 | ('Sports equipment', 10015),
101 | ('Shorts', 9999),
102 | ('Wine glass', 9632),
103 | ('Duck', 9242),
104 | ('Wine', 9032),
105 | ('Rose', 8781),
106 | ('Tie', 8693),
107 | ('Butterfly', 8436),
108 | ('Beer', 7978),
109 | ('Cabinetry', 7956),
110 | ('Laptop', 7907),
111 | ('Insect', 7497),
112 | ('Goggles', 7363),
113 | ('Shirt', 7098),
114 | ('Dairy Product', 7021),
115 | ('Marine invertebrates', 7014),
116 | ('Cattle', 7006),
117 | ('Trousers', 6903),
118 | ('Van', 6843),
119 | ('Billboard', 6777),
120 | ('Balloon', 6367),
121 | ('Human nose', 6103),
122 | ('Tent', 6073),
123 | ('Camera', 6014),
124 | ('Doll', 6002),
125 | ('Coat', 5951),
126 | ('Mobile phone', 5758),
127 | ('Swimwear', 5729),
128 | ('Strawberry', 5691),
129 | ('Stairs', 5643),
130 | ('Goose', 5599),
131 | ('Umbrella', 5536),
132 | ('Cake', 5508),
133 | ('Sun hat', 5475),
134 | ('Bench', 5310),
135 | ('Bookcase', 5163),
136 | ('Bee', 5140),
137 | ('Computer monitor', 5078),
138 | ('Hiking equipment', 4983),
139 | ('Office building', 4981),
140 | ('Coffee cup', 4748),
141 | ('Curtain', 4685),
142 | ('Plate', 4651),
143 | ('Box', 4621),
144 | ('Tomato', 4595),
145 | ('Coffee table', 4529),
146 | ('Office supplies', 4473),
147 | ('Maple', 4416),
148 | ('Muffin', 4365),
149 | ('Cocktail', 4234),
150 | ('Castle', 4197),
151 | ('Couch', 4134),
152 | ('Pumpkin', 3983),
153 | ('Computer keyboard', 3960),
154 | ('Human mouth', 3926),
155 | ('Christmas tree', 3893),
156 | ('Mushroom', 3883),
157 | ('Swimming pool', 3809),
158 | ('Pastry', 3799),
159 | ('Lavender (Plant)', 3769),
160 | ('Football helmet', 3732),
161 | ('Bread', 3648),
162 | ('Traffic sign', 3628),
163 | ('Common sunflower', 3597),
164 | ('Television', 3550),
165 | ('Bed', 3525),
166 | ('Cookie', 3485),
167 | ('Fountain', 3484),
168 | ('Paddle', 3447),
169 | ('Bicycle helmet', 3429),
170 | ('Porch', 3420),
171 | ('Deer', 3387),
172 | ('Fedora', 3339),
173 | ('Canoe', 3338),
174 | ('Carnivore', 3266),
175 | ('Bowl', 3202),
176 | ('Human eye', 3166),
177 | ('Ball', 3118),
178 | ('Pillow', 3077),
179 | ('Salad', 3061),
180 | ('Beetle', 3060),
181 | ('Orange', 3050),
182 | ('Drawer', 2958),
183 | ('Platter', 2937),
184 | ('Elephant', 2921),
185 | ('Seafood', 2921),
186 | ('Monkey', 2915),
187 | ('Countertop', 2879),
188 | ('Watercraft', 2831),
189 | ('Helicopter', 2805),
190 | ('Kitchen appliance', 2797),
191 | ('Personal flotation device', 2781),
192 | ('Swan', 2739),
193 | ('Lamp', 2711),
194 | ('Boot', 2695),
195 | ('Bronze sculpture', 2693),
196 | ('Chicken', 2677),
197 | ('Taxi', 2643),
198 | ('Juice', 2615),
199 | ('Cowboy hat', 2604),
200 | ('Apple', 2600),
201 | ('Tin can', 2590),
202 | ('Necklace', 2564),
203 | ('Ice cream', 2560),
204 | ('Human beard', 2539),
205 | ('Coin', 2536),
206 | ('Candle', 2515),
207 | ('Cart', 2512),
208 | ('High heels', 2441),
209 | ('Weapon', 2433),
210 | ('Handbag', 2406),
211 | ('Penguin', 2396),
212 | ('Rifle', 2352),
213 | ('Violin', 2336),
214 | ('Skull', 2304),
215 | ('Lantern', 2285),
216 | ('Scarf', 2269),
217 | ('Saucer', 2225),
218 | ('Sheep', 2215),
219 | ('Vase', 2189),
220 | ('Lily', 2180),
221 | ('Mug', 2154),
222 | ('Parrot', 2140),
223 | ('Human ear', 2137),
224 | ('Sandal', 2115),
225 | ('Lizard', 2100),
226 | ('Kitchen & dining room table', 2063),
227 | ('Spider', 1977),
228 | ('Coffee', 1974),
229 | ('Goat', 1926),
230 | ('Squirrel', 1922),
231 | ('Cello', 1913),
232 | ('Sushi', 1881),
233 | ('Tortoise', 1876),
234 | ('Pizza', 1870),
235 | ('Studio couch', 1864),
236 | ('Barrel', 1862),
237 | ('Cosmetics', 1841),
238 | ('Moths and butterflies', 1841),
239 | ('Convenience store', 1817),
240 | ('Watch', 1792),
241 | ('Home appliance', 1786),
242 | ('Harbor seal', 1780),
243 | ('Luggage and bags', 1756),
244 | ('Vehicle registration plate', 1754),
245 | ('Shrimp', 1751),
246 | ('Jellyfish', 1730),
247 | ('French fries', 1723),
248 | ('Egg (Food)', 1698),
249 | ('Football', 1697),
250 | ('Musical keyboard', 1683),
251 | ('Falcon', 1674),
252 | ('Candy', 1660),
253 | ('Medical equipment', 1654),
254 | ('Eagle', 1651),
255 | ('Dinosaur', 1634),
256 | ('Surfboard', 1630),
257 | ('Tank', 1628),
258 | ('Grape', 1624),
259 | ('Lion', 1624),
260 | ('Owl', 1622),
261 | ('Ski', 1613),
262 | ('Waste container', 1606),
263 | ('Frog', 1591),
264 | ('Sparrow', 1585),
265 | ('Rabbit', 1581),
266 | ('Pen', 1546),
267 | ('Sea lion', 1537),
268 | ('Spoon', 1521),
269 | ('Sink', 1512),
270 | ('Teddy bear', 1507),
271 | ('Bull', 1495),
272 | ('Sofa bed', 1490),
273 | ('Dragonfly', 1479),
274 | ('Brassiere', 1478),
275 | ('Chest of drawers', 1472),
276 | ('Aircraft', 1466),
277 | ('Human foot', 1463),
278 | ('Pig', 1455),
279 | ('Fork', 1454),
280 | ('Antelope', 1438),
281 | ('Tripod', 1427),
282 | ('Tool', 1424),
283 | ('Cheese', 1422),
284 | ('Lemon', 1397),
285 | ('Hamburger', 1393),
286 | ('Dolphin', 1390),
287 | ('Mirror', 1390),
288 | ('Marine mammal', 1387),
289 | ('Giraffe', 1385),
290 | ('Snake', 1368),
291 | ('Gondola', 1364),
292 | ('Wheelchair', 1360),
293 | ('Piano', 1358),
294 | ('Cupboard', 1348),
295 | ('Banana', 1345),
296 | ('Trumpet', 1335),
297 | ('Lighthouse', 1333),
298 | ('Invertebrate', 1317),
299 | ('Carrot', 1268),
300 | ('Sock', 1260),
301 | ('Tiger', 1241),
302 | ('Camel', 1224),
303 | ('Parachute', 1224),
304 | ('Bathroom accessory', 1223),
305 | ('Earrings', 1221),
306 | ('Headphones', 1218),
307 | ('Skirt', 1198),
308 | ('Skateboard', 1190),
309 | ('Sandwich', 1148),
310 | ('Saxophone', 1141),
311 | ('Goldfish', 1136),
312 | ('Stool', 1104),
313 | ('Traffic light', 1097),
314 | ('Shellfish', 1081),
315 | ('Backpack', 1079),
316 | ('Sea turtle', 1078),
317 | ('Cucumber', 1075),
318 | ('Tea', 1051),
319 | ('Toilet', 1047),
320 | ('Roller skates', 1040),
321 | ('Mule', 1039),
322 | ('Bust', 1031),
323 | ('Broccoli', 1030),
324 | ('Crab', 1020),
325 | ('Oyster', 1019),
326 | ('Cannon', 1012),
327 | ('Zebra', 1012),
328 | ('French horn', 1008),
329 | ('Grapefruit', 998),
330 | ('Whiteboard', 997),
331 | ('Zucchini', 997),
332 | ('Crocodile', 992),
333 |
334 | ('Clock', 960),
335 | ('Wall clock', 958),
336 |
337 | ('Doughnut', 869),
338 | ('Snail', 868),
339 |
340 | ('Baseball glove', 859),
341 |
342 | ('Panda', 830),
343 | ('Tennis racket', 830),
344 |
345 | ('Pear', 652),
346 |
347 | ('Bagel', 617),
348 | ('Oven', 616),
349 | ('Ladybug', 615),
350 | ('Shark', 615),
351 | ('Polar bear', 614),
352 | ('Ostrich', 609),
353 |
354 | ('Hot dog', 473),
355 | ('Microwave oven', 467),
356 | ('Fire hydrant', 20),
357 | ('Stop sign', 20),
358 | ('Parking meter', 20),
359 | ('Bear', 20),
360 | ('Flying disc', 20),
361 | ('Snowboard', 20),
362 | ('Tennis ball', 20),
363 | ('Kite', 20),
364 | ('Baseball bat', 20),
365 | ('Kitchen knife', 20),
366 | ('Knife', 20),
367 | ('Submarine sandwich', 20),
368 | ('Computer mouse', 20),
369 | ('Remote control', 20),
370 | ('Toaster', 20),
371 | ('Sink', 20),
372 | ('Refrigerator', 20),
373 | ('Alarm clock', 20),
374 | ('Wall clock', 20),
375 | ('Scissors', 20),
376 | ('Hair dryer', 20),
377 | ('Toothbrush', 20),
378 | ('Suitcase', 20)
379 | ]
380 |
--------------------------------------------------------------------------------
/taming/data/sflckr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import albumentations
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 |
8 |
9 | class SegmentationBase(Dataset):
10 | def __init__(self,
11 | data_csv, data_root, segmentation_root,
12 | size=None, random_crop=False, interpolation="bicubic",
13 | n_labels=182, shift_segmentation=False,
14 | ):
15 | self.n_labels = n_labels
16 | self.shift_segmentation = shift_segmentation
17 | self.data_csv = data_csv
18 | self.data_root = data_root
19 | self.segmentation_root = segmentation_root
20 | with open(self.data_csv, "r") as f:
21 | self.image_paths = f.read().splitlines()
22 | self._length = len(self.image_paths)
23 | self.labels = {
24 | "relative_file_path_": [l for l in self.image_paths],
25 | "file_path_": [os.path.join(self.data_root, l)
26 | for l in self.image_paths],
27 | "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png"))
28 | for l in self.image_paths]
29 | }
30 |
31 | size = None if size is not None and size<=0 else size
32 | self.size = size
33 | if self.size is not None:
34 | self.interpolation = interpolation
35 | self.interpolation = {
36 | "nearest": cv2.INTER_NEAREST,
37 | "bilinear": cv2.INTER_LINEAR,
38 | "bicubic": cv2.INTER_CUBIC,
39 | "area": cv2.INTER_AREA,
40 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
41 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
42 | interpolation=self.interpolation)
43 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
44 | interpolation=cv2.INTER_NEAREST)
45 | self.center_crop = not random_crop
46 | if self.center_crop:
47 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
48 | else:
49 | self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
50 | self.preprocessor = self.cropper
51 |
52 | def __len__(self):
53 | return self._length
54 |
55 | def __getitem__(self, i):
56 | example = dict((k, self.labels[k][i]) for k in self.labels)
57 | image = Image.open(example["file_path_"])
58 | if not image.mode == "RGB":
59 | image = image.convert("RGB")
60 | image = np.array(image).astype(np.uint8)
61 | if self.size is not None:
62 | image = self.image_rescaler(image=image)["image"]
63 | segmentation = Image.open(example["segmentation_path_"])
64 | assert segmentation.mode == "L", segmentation.mode
65 | segmentation = np.array(segmentation).astype(np.uint8)
66 | if self.shift_segmentation:
67 | # used to support segmentations containing unlabeled==255 label
68 | segmentation = segmentation+1
69 | if self.size is not None:
70 | segmentation = self.segmentation_rescaler(image=segmentation)["image"]
71 | if self.size is not None:
72 | processed = self.preprocessor(image=image,
73 | mask=segmentation
74 | )
75 | else:
76 | processed = {"image": image,
77 | "mask": segmentation
78 | }
79 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
80 | segmentation = processed["mask"]
81 | onehot = np.eye(self.n_labels)[segmentation]
82 | example["segmentation"] = onehot
83 | return example
84 |
85 |
86 | class Examples(SegmentationBase):
87 | def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
88 | super().__init__(data_csv="data/sflckr_examples.txt",
89 | data_root="data/sflckr_images",
90 | segmentation_root="data/sflckr_segmentations",
91 | size=size, random_crop=random_crop, interpolation=interpolation)
92 |
--------------------------------------------------------------------------------
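
Sketch for the segmentation examples above, assuming the sflckr example files referenced by the constructor (data/sflckr_examples.txt plus the image and segmentation folders) have been downloaded:

    from taming.data.sflckr import Examples

    dataset = Examples(size=256)
    example = dataset[0]
    example["image"].shape          # (256, 256, 3), float32 in [-1, 1]
    example["segmentation"].shape   # (256, 256, 182): one-hot over the 182 labels
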
/taming/data/utils.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import os
3 | import tarfile
4 | import urllib
5 | import zipfile
6 | from pathlib import Path
7 |
8 | import numpy as np
9 | import torch
10 | from taming.data.helper_types import Annotation
11 | from torch._six import string_classes
12 | from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
13 | from tqdm import tqdm
14 |
15 |
16 | def unpack(path):
17 | if path.endswith("tar.gz"):
18 | with tarfile.open(path, "r:gz") as tar:
19 | tar.extractall(path=os.path.split(path)[0])
20 | elif path.endswith("tar"):
21 | with tarfile.open(path, "r:") as tar:
22 | tar.extractall(path=os.path.split(path)[0])
23 | elif path.endswith("zip"):
24 | with zipfile.ZipFile(path, "r") as f:
25 | f.extractall(path=os.path.split(path)[0])
26 | else:
27 | raise NotImplementedError(
28 | "Unknown file extension: {}".format(os.path.splitext(path)[1])
29 | )
30 |
31 |
32 | def reporthook(bar):
33 | """tqdm progress bar for downloads."""
34 |
35 | def hook(b=1, bsize=1, tsize=None):
36 | if tsize is not None:
37 | bar.total = tsize
38 | bar.update(b * bsize - bar.n)
39 |
40 | return hook
41 |
42 |
43 | def get_root(name):
44 | base = "data/"
45 | root = os.path.join(base, name)
46 | os.makedirs(root, exist_ok=True)
47 | return root
48 |
49 |
50 | def is_prepared(root):
51 | return Path(root).joinpath(".ready").exists()
52 |
53 |
54 | def mark_prepared(root):
55 | Path(root).joinpath(".ready").touch()
56 |
57 |
58 | def prompt_download(file_, source, target_dir, content_dir=None):
59 | targetpath = os.path.join(target_dir, file_)
60 | while not os.path.exists(targetpath):
61 | if content_dir is not None and os.path.exists(
62 | os.path.join(target_dir, content_dir)
63 | ):
64 | break
65 | print(
66 | "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
67 | )
68 | if content_dir is not None:
69 | print(
70 | "Or place its content into '{}'.".format(
71 | os.path.join(target_dir, content_dir)
72 | )
73 | )
74 | input("Press Enter when done...")
75 | return targetpath
76 |
77 |
78 | def download_url(file_, url, target_dir):
79 | targetpath = os.path.join(target_dir, file_)
80 | os.makedirs(target_dir, exist_ok=True)
81 | with tqdm(
82 | unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
83 | ) as bar:
84 | urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
85 | return targetpath
86 |
87 |
88 | def download_urls(urls, target_dir):
89 | paths = dict()
90 | for fname, url in urls.items():
91 | outpath = download_url(fname, url, target_dir)
92 | paths[fname] = outpath
93 | return paths
94 |
95 |
96 | def quadratic_crop(x, bbox, alpha=1.0):
97 | """bbox is xmin, ymin, xmax, ymax"""
98 | im_h, im_w = x.shape[:2]
99 | bbox = np.array(bbox, dtype=np.float32)
100 | bbox = np.clip(bbox, 0, max(im_h, im_w))
101 | center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
102 | w = bbox[2] - bbox[0]
103 | h = bbox[3] - bbox[1]
104 | l = int(alpha * max(w, h))
105 | l = max(l, 2)
106 |
107 | required_padding = -1 * min(
108 | center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
109 | )
110 | required_padding = int(np.ceil(required_padding))
111 | if required_padding > 0:
112 | padding = [
113 | [required_padding, required_padding],
114 | [required_padding, required_padding],
115 | ]
116 | padding += [[0, 0]] * (len(x.shape) - 2)
117 | x = np.pad(x, padding, "reflect")
118 | center = center[0] + required_padding, center[1] + required_padding
119 | xmin = int(center[0] - l / 2)
120 | ymin = int(center[1] - l / 2)
121 | return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
122 |
123 |
124 | def custom_collate(batch):
125 | r"""source: pytorch 1.9.0, only one modification to original code """
126 |
127 | elem = batch[0]
128 | elem_type = type(elem)
129 | if isinstance(elem, torch.Tensor):
130 | out = None
131 | if torch.utils.data.get_worker_info() is not None:
132 | # If we're in a background process, concatenate directly into a
133 | # shared memory tensor to avoid an extra copy
134 | numel = sum([x.numel() for x in batch])
135 | storage = elem.storage()._new_shared(numel)
136 | out = elem.new(storage)
137 | return torch.stack(batch, 0, out=out)
138 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
139 | and elem_type.__name__ != 'string_':
140 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
141 | # array of string classes and object
142 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
143 | raise TypeError(default_collate_err_msg_format.format(elem.dtype))
144 |
145 | return custom_collate([torch.as_tensor(b) for b in batch])
146 | elif elem.shape == (): # scalars
147 | return torch.as_tensor(batch)
148 | elif isinstance(elem, float):
149 | return torch.tensor(batch, dtype=torch.float64)
150 | elif isinstance(elem, int):
151 | return torch.tensor(batch)
152 | elif isinstance(elem, string_classes):
153 | return batch
154 | elif isinstance(elem, collections.abc.Mapping):
155 | return {key: custom_collate([d[key] for d in batch]) for key in elem}
156 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
157 | return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
158 | if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added
159 | return batch # added
160 | elif isinstance(elem, collections.abc.Sequence):
161 | # check to make sure that the elements in batch have consistent size
162 | it = iter(batch)
163 | elem_size = len(next(it))
164 | if not all(len(elem) == elem_size for elem in it):
165 | raise RuntimeError('each element in list of batch should be of equal size')
166 | transposed = zip(*batch)
167 | return [custom_collate(samples) for samples in transposed]
168 |
169 | raise TypeError(default_collate_err_msg_format.format(elem_type))
170 |
--------------------------------------------------------------------------------
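
custom_collate exists mainly for the single marked change: batch elements that are lists of Annotation objects are returned as-is instead of being transposed and re-collated. A minimal check, assuming a torch version that still provides torch._six (which this module imports):

    import torch
    from taming.data.helper_types import Annotation
    from taming.data.utils import custom_collate

    ann = Annotation(area=1.0, image_id="0", bbox=(0.1, 0.1, 0.2, 0.2),
                     category_no=3, category_id="3")
    batch = [
        {"image": torch.zeros(3, 4, 4), "annotations": [ann]},
        {"image": torch.ones(3, 4, 4), "annotations": [ann, ann]},
    ]
    out = custom_collate(batch)
    out["image"].shape   # torch.Size([2, 3, 4, 4]) -- stacked as usual
    out["annotations"]   # the per-sample Annotation lists, passed through unchanged
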
/taming/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n):
33 | return self.schedule(n)
34 |
35 |
--------------------------------------------------------------------------------
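
The scheduler returns an absolute multiplier, hence the note about a base_lr of 1.0 (e.g. when used as the lr_lambda of a torch.optim.lr_scheduler.LambdaLR). Illustrative values:

    from taming.lr_scheduler import LambdaWarmUpCosineScheduler

    sched = LambdaWarmUpCosineScheduler(
        warm_up_steps=1000, lr_min=0.001, lr_max=1.0, lr_start=0.0, max_decay_steps=10000)
    sched(0)        # 0.0    -- start of the linear warm-up
    sched(1000)     # 1.0    -- lr_max is reached right at the end of warm-up
    sched(10000)    # 0.001  -- cosine-decayed down to lr_min
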
/taming/models/dummy_cond_stage.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 |
3 |
4 | class DummyCondStage:
5 | def __init__(self, conditional_key):
6 | self.conditional_key = conditional_key
7 | self.train = None
8 |
9 | def eval(self):
10 | return self
11 |
12 | @staticmethod
13 | def encode(c: Tensor):
14 | return c, None, (None, None, c)
15 |
16 | @staticmethod
17 | def decode(c: Tensor):
18 | return c
19 |
20 | @staticmethod
21 | def to_rgb(c: Tensor):
22 | return c
23 |
--------------------------------------------------------------------------------
/taming/modules/discriminator/model.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch.nn as nn
3 |
4 |
5 | from taming.modules.util import ActNorm
6 |
7 |
8 | def weights_init(m):
9 | classname = m.__class__.__name__
10 | if classname.find('Conv') != -1:
11 | nn.init.normal_(m.weight.data, 0.0, 0.02)
12 | elif classname.find('BatchNorm') != -1:
13 | nn.init.normal_(m.weight.data, 1.0, 0.02)
14 | nn.init.constant_(m.bias.data, 0)
15 |
16 |
17 | class NLayerDiscriminator(nn.Module):
18 | """Defines a PatchGAN discriminator as in Pix2Pix
19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20 | """
21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22 | """Construct a PatchGAN discriminator
23 | Parameters:
24 | input_nc (int) -- the number of channels in input images
25 | ndf (int) -- the number of filters in the last conv layer
26 | n_layers (int) -- the number of conv layers in the discriminator
27 | norm_layer -- normalization layer
28 | """
29 | super(NLayerDiscriminator, self).__init__()
30 | if not use_actnorm:
31 | norm_layer = nn.BatchNorm2d
32 | else:
33 | norm_layer = ActNorm
34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35 | use_bias = norm_layer.func != nn.BatchNorm2d
36 | else:
37 | use_bias = norm_layer != nn.BatchNorm2d
38 |
39 | kw = 4
40 | padw = 1
41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42 | nf_mult = 1
43 | nf_mult_prev = 1
44 | for n in range(1, n_layers): # gradually increase the number of filters
45 | nf_mult_prev = nf_mult
46 | nf_mult = min(2 ** n, 8)
47 | sequence += [
48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49 | norm_layer(ndf * nf_mult),
50 | nn.LeakyReLU(0.2, True)
51 | ]
52 |
53 | nf_mult_prev = nf_mult
54 | nf_mult = min(2 ** n_layers, 8)
55 | sequence += [
56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57 | norm_layer(ndf * nf_mult),
58 | nn.LeakyReLU(0.2, True)
59 | ]
60 |
61 | sequence += [
62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63 | self.main = nn.Sequential(*sequence)
64 |
65 | def forward(self, input):
66 | """Standard forward."""
67 | return self.main(input)
68 |
--------------------------------------------------------------------------------
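
A shape check for the PatchGAN discriminator with its weight initialisation applied; for a 256x256 input and the default three layers the output is a 30x30 map of per-patch logits:

    import torch
    from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False).apply(weights_init)
    logits = disc(torch.randn(1, 3, 256, 256))
    logits.shape   # torch.Size([1, 1, 30, 30])
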
/taming/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from taming.modules.losses.vqperceptual import DummyLoss
2 |
3 |
--------------------------------------------------------------------------------
/taming/modules/losses/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | from collections import namedtuple
7 |
8 | from taming.util import get_ckpt_path
9 |
10 |
11 | class LPIPS(nn.Module):
12 | # Learned perceptual metric
13 | def __init__(self, use_dropout=True):
14 | super().__init__()
15 | self.scaling_layer = ScalingLayer()
16 |         self.chns = [64, 128, 256, 512, 512]  # vgg16 features
17 | self.net = vgg16(pretrained=True, requires_grad=False)
18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
23 | self.load_from_pretrained()
24 | for param in self.parameters():
25 | param.requires_grad = False
26 |
27 | def load_from_pretrained(self, name="vgg_lpips"):
28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
30 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
31 |
32 | @classmethod
33 | def from_pretrained(cls, name="vgg_lpips"):
34 | if name != "vgg_lpips":
35 | raise NotImplementedError
36 | model = cls()
37 | ckpt = get_ckpt_path(name)
38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
39 | return model
40 |
41 | def forward(self, input, target):
42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
43 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
44 | feats0, feats1, diffs = {}, {}, {}
45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
46 | for kk in range(len(self.chns)):
47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
49 |
50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
51 | val = res[0]
52 | for l in range(1, len(self.chns)):
53 | val += res[l]
54 | return val
55 |
56 |
57 | class ScalingLayer(nn.Module):
58 | def __init__(self):
59 | super(ScalingLayer, self).__init__()
60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
62 |
63 | def forward(self, inp):
64 | return (inp - self.shift) / self.scale
65 |
66 |
67 | class NetLinLayer(nn.Module):
68 | """ A single linear layer which does a 1x1 conv """
69 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
70 | super(NetLinLayer, self).__init__()
71 | layers = [nn.Dropout(), ] if (use_dropout) else []
72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
73 | self.model = nn.Sequential(*layers)
74 |
75 |
76 | class vgg16(torch.nn.Module):
77 | def __init__(self, requires_grad=False, pretrained=True):
78 | super(vgg16, self).__init__()
79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
80 | self.slice1 = torch.nn.Sequential()
81 | self.slice2 = torch.nn.Sequential()
82 | self.slice3 = torch.nn.Sequential()
83 | self.slice4 = torch.nn.Sequential()
84 | self.slice5 = torch.nn.Sequential()
85 | self.N_slices = 5
86 | for x in range(4):
87 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
88 | for x in range(4, 9):
89 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
90 | for x in range(9, 16):
91 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
92 | for x in range(16, 23):
93 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
94 | for x in range(23, 30):
95 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
96 | if not requires_grad:
97 | for param in self.parameters():
98 | param.requires_grad = False
99 |
100 | def forward(self, X):
101 | h = self.slice1(X)
102 | h_relu1_2 = h
103 | h = self.slice2(h)
104 | h_relu2_2 = h
105 | h = self.slice3(h)
106 | h_relu3_3 = h
107 | h = self.slice4(h)
108 | h_relu4_3 = h
109 | h = self.slice5(h)
110 | h_relu5_3 = h
111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
113 | return out
114 |
115 |
116 | def normalize_tensor(x,eps=1e-10):
117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
118 | return x/(norm_factor+eps)
119 |
120 |
121 | def spatial_average(x, keepdim=True):
122 | return x.mean([2,3],keepdim=keepdim)
123 |
124 |
--------------------------------------------------------------------------------
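A minimal usage sketch for the LPIPS metric defined above (illustrative only, not part of the repo; it assumes the pretrained VGG weights can be fetched via get_ckpt_path on first use). Inputs are expected in [-1, 1]:

import torch
from taming.modules.losses.lpips import LPIPS

lpips = LPIPS().eval()                        # downloads vgg.pth on first use
x = torch.rand(2, 3, 256, 256) * 2 - 1        # two random images in [-1, 1]
y = torch.rand(2, 3, 256, 256) * 2 - 1
with torch.no_grad():
    d = lpips(x, y)                           # per-sample distances, shape (2, 1, 1, 1)
print(d.squeeze())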
/taming/modules/losses/segmentation.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class BCELoss(nn.Module):
6 | def forward(self, prediction, target):
7 | loss = F.binary_cross_entropy_with_logits(prediction,target)
8 | return loss, {}
9 |
10 |
11 | class BCELossWithQuant(nn.Module):
12 | def __init__(self, codebook_weight=1.):
13 | super().__init__()
14 | self.codebook_weight = codebook_weight
15 |
16 | def forward(self, qloss, target, prediction, split):
17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
18 | loss = bce_loss + self.codebook_weight*qloss
19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
20 | "{}/bce_loss".format(split): bce_loss.detach().mean(),
21 | "{}/quant_loss".format(split): qloss.detach().mean()
22 | }
23 |
--------------------------------------------------------------------------------
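An illustrative call pattern for BCELossWithQuant above; the quantization loss qloss would normally come from a VQ codebook, and the shapes and class count below are arbitrary placeholders:

import torch
from taming.modules.losses.segmentation import BCELossWithQuant

criterion = BCELossWithQuant(codebook_weight=1.0)
logits = torch.randn(2, 182, 64, 64)                      # per-class segmentation logits
target = torch.randint(0, 2, (2, 182, 64, 64)).float()    # binary multi-label targets
qloss = torch.tensor(0.05)                                # placeholder codebook loss
loss, log = criterion(qloss, target, logits, split="train")
print(loss.item(), log)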
/taming/modules/losses/soft_cross_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def SoftCrossEntropy(inputs, target, reduction='sum'):
5 | log_likelihood = -F.log_softmax(inputs, dim=1)
6 | batch = inputs.shape[0]
7 | if reduction == 'average':
8 | loss = torch.sum(torch.mul(log_likelihood, target)) / batch
9 | else:
10 | loss = torch.sum(torch.mul(log_likelihood, target))
11 | return loss
--------------------------------------------------------------------------------
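A quick sanity check (illustrative): with one-hot targets and reduction='average', SoftCrossEntropy reduces to the usual mean cross-entropy:

import torch
import torch.nn.functional as F
from taming.modules.losses.soft_cross_entropy import SoftCrossEntropy

logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
one_hot = F.one_hot(labels, num_classes=10).float()

soft = SoftCrossEntropy(logits, one_hot, reduction='average')
hard = F.cross_entropy(logits, labels)        # mean over the batch by default
assert torch.allclose(soft, hard, atol=1e-6)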
/taming/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 |
6 |
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
9 |
10 |
11 | class DummyLoss(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 |
15 |
16 | def adopt_weight(weight, global_step, threshold=0, value=0.):
17 | if global_step < threshold:
18 | weight = value
19 | return weight
20 |
21 |
22 | def hinge_d_loss(logits_real, logits_fake):
23 | loss_real = torch.mean(F.relu(1. - logits_real))
24 | loss_fake = torch.mean(F.relu(1. + logits_fake))
25 | d_loss = 0.5 * (loss_real + loss_fake)
26 | return d_loss
27 |
28 |
29 | def vanilla_d_loss(logits_real, logits_fake):
30 | d_loss = 0.5 * (
31 | torch.mean(torch.nn.functional.softplus(-logits_real)) +
32 | torch.mean(torch.nn.functional.softplus(logits_fake)))
33 | return d_loss
34 |
35 |
36 | class VQLPIPSWithDiscriminator(nn.Module):
37 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
38 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
39 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
40 | disc_ndf=64, disc_loss="hinge", aux_downscale=4.):
41 | super().__init__()
42 | assert disc_loss in ["hinge", "vanilla"]
43 | self.codebook_weight = codebook_weight
44 | self.pixel_weight = pixelloss_weight
45 | self.perceptual_loss = LPIPS().eval()
46 | self.perceptual_weight = perceptual_weight
47 | self.aux_downscale = aux_downscale
48 |
49 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
50 | n_layers=disc_num_layers,
51 | use_actnorm=use_actnorm,
52 | ndf=disc_ndf
53 | ).apply(weights_init)
54 | self.discriminator_iter_start = disc_start
55 | if disc_loss == "hinge":
56 | self.disc_loss = hinge_d_loss
57 | elif disc_loss == "vanilla":
58 | self.disc_loss = vanilla_d_loss
59 | else:
60 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
61 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
62 | self.disc_factor = disc_factor
63 | self.discriminator_weight = disc_weight
64 | self.disc_conditional = disc_conditional
65 |
66 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
67 | if last_layer is not None:
68 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
69 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
70 | else:
71 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
72 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
73 |
74 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
75 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
76 | d_weight = d_weight * self.discriminator_weight
77 | return d_weight
78 |
79 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
80 | global_step, last_layer=None, cond=None, split="train", xrec_aux=None):
81 |
82 | aux_downscale = self.aux_downscale
83 |
84 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
85 | if self.perceptual_weight > 0:
86 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
87 | rec_loss = rec_loss + self.perceptual_weight * p_loss
88 | else:
89 | p_loss = torch.tensor([0.0])
90 |
91 | if xrec_aux is not None:
92 | # print(aux_downscale)
93 | inputs_aux = F.interpolate(inputs, scale_factor=1./aux_downscale)
94 | inputs_aux = F.interpolate(inputs_aux, scale_factor=aux_downscale, mode='bilinear')
95 | # inputs_cv = torchvision.utils.make_grid(inputs)
96 | # inputs_aux_cv = torchvision.utils.make_grid(inputs_aux)
97 | # torchvision.utils.save_image(inputs_cv, "input.png")
98 | # torchvision.utils.save_image(inputs_aux_cv, "input_aux.png")
99 | rec_aux_loss = torch.abs(inputs_aux.contiguous() - xrec_aux.contiguous())
100 | rec_loss = rec_loss + 0.5 * rec_aux_loss
101 | else:
102 | rec_aux_loss = torch.tensor([0.0])
103 |
104 | nll_loss = rec_loss
105 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
106 | nll_loss = torch.mean(nll_loss)
107 |
108 | # now the GAN part
109 | if optimizer_idx == 0:
110 | # generator update
111 | if cond is None:
112 | assert not self.disc_conditional
113 | logits_fake = self.discriminator(reconstructions.contiguous())
114 | else:
115 | assert self.disc_conditional
116 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
117 | g_loss = -torch.mean(logits_fake)
118 |
119 | try:
120 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
121 | except RuntimeError:
122 | assert not self.training
123 | d_weight = torch.tensor(0.0)
124 |
125 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
126 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
127 |
128 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
129 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
130 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
131 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
132 | "{}/p_loss".format(split): p_loss.detach().mean(),
133 | "{}/rec_aux_loss".format(split): rec_aux_loss.detach().mean(),
134 | "{}/d_weight".format(split): d_weight.detach(),
135 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
136 | "{}/g_loss".format(split): g_loss.detach().mean(),
137 | }
138 | return loss, log
139 |
140 | if optimizer_idx == 1:
141 | # second pass for discriminator update
142 | if cond is None:
143 | logits_real = self.discriminator(inputs.contiguous().detach())
144 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
145 | else:
146 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
147 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
148 |
149 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
150 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
151 |
152 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
153 | "{}/logits_real".format(split): logits_real.detach().mean(),
154 | "{}/logits_fake".format(split): logits_fake.detach().mean()
155 | }
156 | return d_loss, log
157 |
--------------------------------------------------------------------------------
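A rough sketch of how VQLPIPSWithDiscriminator is driven from a two-optimizer training step. This is illustrative only: the autoencoder model, its decoder.conv_out last layer, and the disc_start value are placeholders standing in for the VQ model and configs used elsewhere in this repo:

import torch
from taming.modules.losses.vqperceptual import VQLPIPSWithDiscriminator

loss_fn = VQLPIPSWithDiscriminator(disc_start=10000)   # arbitrary warm-up step

def training_step(model, x, optimizer_idx, global_step):
    # hypothetical: `model` returns (reconstruction, codebook_loss)
    xrec, qloss = model(x)
    last = model.decoder.conv_out.weight               # layer used for the adaptive disc weight
    loss, log = loss_fn(qloss, x, xrec, optimizer_idx, global_step,
                        last_layer=last, split="train")
    return loss, log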
/taming/modules/misc/coord.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class CoordStage(object):
4 | def __init__(self, n_embed, down_factor):
5 | self.n_embed = n_embed
6 | self.down_factor = down_factor
7 |
8 | def eval(self):
9 | return self
10 |
11 | def encode(self, c):
12 | """fake vqmodel interface"""
13 | assert 0.0 <= c.min() and c.max() <= 1.0
14 | b,ch,h,w = c.shape
15 | assert ch == 1
16 |
17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
18 | mode="area")
19 | c = c.clamp(0.0, 1.0)
20 | c = self.n_embed*c
21 | c_quant = c.round()
22 | c_ind = c_quant.to(dtype=torch.long)
23 |
24 | info = None, None, c_ind
25 | return c_quant, None, info
26 |
27 | def decode(self, c):
28 | c = c/self.n_embed
29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
30 | mode="nearest")
31 | return c
32 |
--------------------------------------------------------------------------------
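An illustrative round trip through the fake-VQ interface above, assuming a single-channel coordinate map with values in [0, 1]:

import torch
from taming.modules.misc.coord import CoordStage

stage = CoordStage(n_embed=1024, down_factor=16)
c = torch.rand(1, 1, 256, 256)                  # values in [0, 1]
c_quant, _, (_, _, c_ind) = stage.encode(c)     # 16x16 quantized map plus long indices
c_up = stage.decode(c_quant)                    # nearest-neighbour upsample back to 256x256
print(c_quant.shape, c_ind.shape, c_up.shape)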
/taming/modules/misc/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 |
10 | import numpy as np
11 |
12 | import torch
13 |
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 | """
22 | grid_size: int of the grid height and width
23 | return:
24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 | """
26 | grid_h = np.arange(grid_size, dtype=np.float32)
27 | grid_w = np.arange(grid_size, dtype=np.float32)
28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
29 | grid = np.stack(grid, axis=0)
30 |
31 | grid = grid.reshape([2, 1, grid_size, grid_size])
32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 | if cls_token:
34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 | return pos_embed
36 |
37 |
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 | assert embed_dim % 2 == 0
40 |
41 | # use half of dimensions to encode grid_h
42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
44 |
45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46 | return emb
47 |
48 |
49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50 | """
51 | embed_dim: output dimension for each position
52 | pos: a list of positions to be encoded: size (M,)
53 | out: (M, D)
54 | """
55 | assert embed_dim % 2 == 0
56 |     omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float is removed in recent NumPy
57 | omega /= embed_dim / 2.
58 | omega = 1. / 10000**omega # (D/2,)
59 |
60 | pos = pos.reshape(-1) # (M,)
61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
62 |
63 | emb_sin = np.sin(out) # (M, D/2)
64 | emb_cos = np.cos(out) # (M, D/2)
65 |
66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
67 | return emb
68 |
69 |
70 | # --------------------------------------------------------
71 | # Interpolate position embeddings for high-resolution
72 | # References:
73 | # DeiT: https://github.com/facebookresearch/deit
74 | # --------------------------------------------------------
75 | def interpolate_pos_embed(model, checkpoint_model):
76 | if 'pos_embed' in checkpoint_model:
77 | pos_embed_checkpoint = checkpoint_model['pos_embed']
78 | embedding_size = pos_embed_checkpoint.shape[-1]
79 | num_patches = model.patch_embed.num_patches
80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
81 | # height (== width) for the checkpoint position embedding
82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
83 | # height (== width) for the new position embedding
84 | new_size = int(num_patches ** 0.5)
85 | # class_token and dist_token are kept unchanged
86 | if orig_size != new_size:
87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
89 | # only the position tokens are interpolated
90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
92 | pos_tokens = torch.nn.functional.interpolate(
93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
96 | checkpoint_model['pos_embed'] = new_pos_embed
97 |
--------------------------------------------------------------------------------
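A quick shape check for the sine-cosine embedding helpers above (illustrative):

from taming.modules.misc.pos_embed import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=256, grid_size=16)
pos_cls = get_2d_sincos_pos_embed(embed_dim=256, grid_size=16, cls_token=True)
print(pos.shape, pos_cls.shape)   # (256, 256) and (257, 256): 16*16 positions, 256 dims each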
/taming/modules/transformer/permuter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class AbstractPermuter(nn.Module):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__()
9 | def forward(self, x, reverse=False):
10 | raise NotImplementedError
11 |
12 |
13 | class Identity(AbstractPermuter):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def forward(self, x, reverse=False):
18 | return x
19 |
20 |
21 | class Subsample(AbstractPermuter):
22 | def __init__(self, H, W):
23 | super().__init__()
24 | C = 1
25 | indices = np.arange(H*W).reshape(C,H,W)
26 | while min(H, W) > 1:
27 | indices = indices.reshape(C,H//2,2,W//2,2)
28 | indices = indices.transpose(0,2,4,1,3)
29 | indices = indices.reshape(C*4,H//2, W//2)
30 | H = H//2
31 | W = W//2
32 | C = C*4
33 | assert H == W == 1
34 | idx = torch.tensor(indices.ravel())
35 |         self.register_buffer('forward_shuffle_idx',
36 |                              idx)
37 |         self.register_buffer('backward_shuffle_idx',
38 |                              torch.argsort(idx))
39 |
40 | def forward(self, x, reverse=False):
41 | if not reverse:
42 | return x[:, self.forward_shuffle_idx]
43 | else:
44 | return x[:, self.backward_shuffle_idx]
45 |
46 |
47 | def mortonify(i, j):
48 | """(i,j) index to linear morton code"""
49 | i = np.uint64(i)
50 | j = np.uint64(j)
51 |
52 |     z = np.uint64(0)
53 |
54 | for pos in range(32):
55 | z = (z |
56 | ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
57 | ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
58 | )
59 | return z
60 |
61 |
62 | class ZCurve(AbstractPermuter):
63 | def __init__(self, H, W):
64 | super().__init__()
65 | reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
66 | idx = np.argsort(reverseidx)
67 | idx = torch.tensor(idx)
68 | reverseidx = torch.tensor(reverseidx)
69 | self.register_buffer('forward_shuffle_idx',
70 | idx)
71 | self.register_buffer('backward_shuffle_idx',
72 | reverseidx)
73 |
74 | def forward(self, x, reverse=False):
75 | if not reverse:
76 | return x[:, self.forward_shuffle_idx]
77 | else:
78 | return x[:, self.backward_shuffle_idx]
79 |
80 |
81 | class SpiralOut(AbstractPermuter):
82 | def __init__(self, H, W):
83 | super().__init__()
84 | assert H == W
85 | size = W
86 | indices = np.arange(size*size).reshape(size,size)
87 |
88 | i0 = size//2
89 | j0 = size//2-1
90 |
91 | i = i0
92 | j = j0
93 |
94 | idx = [indices[i0, j0]]
95 | step_mult = 0
96 | for c in range(1, size//2+1):
97 | step_mult += 1
98 | # steps left
99 | for k in range(step_mult):
100 | i = i - 1
101 | j = j
102 | idx.append(indices[i, j])
103 |
104 | # step down
105 | for k in range(step_mult):
106 | i = i
107 | j = j + 1
108 | idx.append(indices[i, j])
109 |
110 | step_mult += 1
111 | if c < size//2:
112 | # step right
113 | for k in range(step_mult):
114 | i = i + 1
115 | j = j
116 | idx.append(indices[i, j])
117 |
118 | # step up
119 | for k in range(step_mult):
120 | i = i
121 | j = j - 1
122 | idx.append(indices[i, j])
123 | else:
124 | # end reached
125 | for k in range(step_mult-1):
126 | i = i + 1
127 | idx.append(indices[i, j])
128 |
129 | assert len(idx) == size*size
130 | idx = torch.tensor(idx)
131 | self.register_buffer('forward_shuffle_idx', idx)
132 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
133 |
134 | def forward(self, x, reverse=False):
135 | if not reverse:
136 | return x[:, self.forward_shuffle_idx]
137 | else:
138 | return x[:, self.backward_shuffle_idx]
139 |
140 |
141 | class SpiralIn(AbstractPermuter):
142 | def __init__(self, H, W):
143 | super().__init__()
144 | assert H == W
145 | size = W
146 | indices = np.arange(size*size).reshape(size,size)
147 |
148 | i0 = size//2
149 | j0 = size//2-1
150 |
151 | i = i0
152 | j = j0
153 |
154 | idx = [indices[i0, j0]]
155 | step_mult = 0
156 | for c in range(1, size//2+1):
157 | step_mult += 1
158 | # steps left
159 | for k in range(step_mult):
160 | i = i - 1
161 | j = j
162 | idx.append(indices[i, j])
163 |
164 | # step down
165 | for k in range(step_mult):
166 | i = i
167 | j = j + 1
168 | idx.append(indices[i, j])
169 |
170 | step_mult += 1
171 | if c < size//2:
172 | # step right
173 | for k in range(step_mult):
174 | i = i + 1
175 | j = j
176 | idx.append(indices[i, j])
177 |
178 | # step up
179 | for k in range(step_mult):
180 | i = i
181 | j = j - 1
182 | idx.append(indices[i, j])
183 | else:
184 | # end reached
185 | for k in range(step_mult-1):
186 | i = i + 1
187 | idx.append(indices[i, j])
188 |
189 | assert len(idx) == size*size
190 | idx = idx[::-1]
191 | idx = torch.tensor(idx)
192 | self.register_buffer('forward_shuffle_idx', idx)
193 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
194 |
195 | def forward(self, x, reverse=False):
196 | if not reverse:
197 | return x[:, self.forward_shuffle_idx]
198 | else:
199 | return x[:, self.backward_shuffle_idx]
200 |
201 |
202 | class Random(nn.Module):
203 | def __init__(self, H, W):
204 | super().__init__()
205 | indices = np.random.RandomState(1).permutation(H*W)
206 | idx = torch.tensor(indices.ravel())
207 | self.register_buffer('forward_shuffle_idx', idx)
208 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
209 |
210 | def forward(self, x, reverse=False):
211 | if not reverse:
212 | return x[:, self.forward_shuffle_idx]
213 | else:
214 | return x[:, self.backward_shuffle_idx]
215 |
216 |
217 | class AlternateParsing(AbstractPermuter):
218 | def __init__(self, H, W):
219 | super().__init__()
220 | indices = np.arange(W*H).reshape(H,W)
221 | for i in range(1, H, 2):
222 | indices[i, :] = indices[i, ::-1]
223 | idx = indices.flatten()
224 | assert len(idx) == H*W
225 | idx = torch.tensor(idx)
226 | self.register_buffer('forward_shuffle_idx', idx)
227 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
228 |
229 | def forward(self, x, reverse=False):
230 | if not reverse:
231 | return x[:, self.forward_shuffle_idx]
232 | else:
233 | return x[:, self.backward_shuffle_idx]
234 |
235 |
236 | if __name__ == "__main__":
237 | p0 = AlternateParsing(16, 16)
238 | print(p0.forward_shuffle_idx)
239 | print(p0.backward_shuffle_idx)
240 |
241 | x = torch.randint(0, 768, size=(11, 256))
242 | y = p0(x)
243 | xre = p0(y, reverse=True)
244 | assert torch.equal(x, xre)
245 |
246 | p1 = SpiralOut(2, 2)
247 | print(p1.forward_shuffle_idx)
248 | print(p1.backward_shuffle_idx)
249 |
--------------------------------------------------------------------------------
/taming/modules/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def count_params(model):
6 | total_params = sum(p.numel() for p in model.parameters())
7 | return total_params
8 |
9 |
10 | class ActNorm(nn.Module):
11 | def __init__(self, num_features, logdet=False, affine=True,
12 | allow_reverse_init=False):
13 | assert affine
14 | super().__init__()
15 | self.logdet = logdet
16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18 | self.allow_reverse_init = allow_reverse_init
19 |
20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21 |
22 | def initialize(self, input):
23 | with torch.no_grad():
24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25 | mean = (
26 | flatten.mean(1)
27 | .unsqueeze(1)
28 | .unsqueeze(2)
29 | .unsqueeze(3)
30 | .permute(1, 0, 2, 3)
31 | )
32 | std = (
33 | flatten.std(1)
34 | .unsqueeze(1)
35 | .unsqueeze(2)
36 | .unsqueeze(3)
37 | .permute(1, 0, 2, 3)
38 | )
39 |
40 | self.loc.data.copy_(-mean)
41 | self.scale.data.copy_(1 / (std + 1e-6))
42 |
43 | def forward(self, input, reverse=False):
44 | if reverse:
45 | return self.reverse(input)
46 | if len(input.shape) == 2:
47 | input = input[:,:,None,None]
48 | squeeze = True
49 | else:
50 | squeeze = False
51 |
52 | _, _, height, width = input.shape
53 |
54 | if self.training and self.initialized.item() == 0:
55 | self.initialize(input)
56 | self.initialized.fill_(1)
57 |
58 | h = self.scale * (input + self.loc)
59 |
60 | if squeeze:
61 | h = h.squeeze(-1).squeeze(-1)
62 |
63 | if self.logdet:
64 | log_abs = torch.log(torch.abs(self.scale))
65 | logdet = height*width*torch.sum(log_abs)
66 | logdet = logdet * torch.ones(input.shape[0]).to(input)
67 | return h, logdet
68 |
69 | return h
70 |
71 | def reverse(self, output):
72 | if self.training and self.initialized.item() == 0:
73 | if not self.allow_reverse_init:
74 | raise RuntimeError(
75 | "Initializing ActNorm in reverse direction is "
76 | "disabled by default. Use allow_reverse_init=True to enable."
77 | )
78 | else:
79 | self.initialize(output)
80 | self.initialized.fill_(1)
81 |
82 | if len(output.shape) == 2:
83 | output = output[:,:,None,None]
84 | squeeze = True
85 | else:
86 | squeeze = False
87 |
88 | h = output / self.scale - self.loc
89 |
90 | if squeeze:
91 | h = h.squeeze(-1).squeeze(-1)
92 | return h
93 |
94 |
95 | class AbstractEncoder(nn.Module):
96 | def __init__(self):
97 | super().__init__()
98 |
99 | def encode(self, *args, **kwargs):
100 | raise NotImplementedError
101 |
102 |
103 | class Labelator(AbstractEncoder):
104 | """Net2Net Interface for Class-Conditional Model"""
105 | def __init__(self, n_classes, quantize_interface=True):
106 | super().__init__()
107 | self.n_classes = n_classes
108 | self.quantize_interface = quantize_interface
109 |
110 | def encode(self, c):
111 | c = c[:,None]
112 | if self.quantize_interface:
113 | return c, None, [None, None, c.long()]
114 | return c
115 |
116 |
117 | class SOSProvider(AbstractEncoder):
118 | # for unconditional training
119 | def __init__(self, sos_token, quantize_interface=True):
120 | super().__init__()
121 | self.sos_token = sos_token
122 | self.quantize_interface = quantize_interface
123 |
124 | def encode(self, x):
125 | # get batch size from data and replicate sos_token
126 | c = torch.ones(x.shape[0], 1)*self.sos_token
127 | c = c.long().to(x.device)
128 | if self.quantize_interface:
129 | return c, None, [None, None, c]
130 | return c
131 |
--------------------------------------------------------------------------------
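A brief illustration of ActNorm's data-dependent initialization (illustrative): the first forward pass in training mode sets loc/scale from the batch statistics, and the transform is invertible via reverse=True:

import torch
from taming.modules.util import ActNorm

norm = ActNorm(num_features=8)
x = torch.randn(4, 8, 32, 32) * 3.0 + 1.0
y = norm(x)                                                   # first call initializes loc/scale
print(round(y.mean().item(), 3), round(y.std().item(), 3))    # roughly 0 and 1
x_back = norm(y, reverse=True)
print(torch.allclose(x, x_back, atol=1e-5))                   # True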
/taming/modules/vqvae/mapping.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | import copy
9 | import datetime
10 | import os
11 | import random
12 | import time
13 | import timeit
14 | import warnings
15 | from collections import OrderedDict
16 |
17 | import numpy as np
18 | import torch
19 | import torch.nn as nn
20 | from torch.nn.functional import interpolate
21 | from torch.nn.modules.sparse import Embedding
22 |
23 | class PixelNormLayer(nn.Module):
24 | def __init__(self, epsilon=1e-8):
25 | super().__init__()
26 | self.epsilon = epsilon
27 |
28 | def forward(self, x):
29 | return x * torch.rsqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)
30 |
31 |
32 | class EqualizedLinear(nn.Module):
33 | """Linear layer with equalized learning rate and custom learning rate multiplier."""
34 |
35 | def __init__(self, input_size, output_size, gain=2 ** 0.5, use_wscale=False, lrmul=1, bias=True):
36 | super().__init__()
37 | he_std = gain * input_size ** (-0.5) # He init
38 | # Equalized learning rate and custom learning rate multiplier.
39 | if use_wscale:
40 | init_std = 1.0 / lrmul
41 | self.w_mul = he_std * lrmul
42 | else:
43 | init_std = he_std / lrmul
44 | self.w_mul = lrmul
45 | self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)
46 | if bias:
47 | self.bias = torch.nn.Parameter(torch.zeros(output_size))
48 | self.b_mul = lrmul
49 | else:
50 | self.bias = None
51 |
52 | def forward(self, x):
53 | bias = self.bias
54 | if bias is not None:
55 | bias = bias * self.b_mul
56 | return F.linear(x, self.weight * self.w_mul, bias)
57 |
58 | class GMapping(nn.Module):
59 |
60 | def __init__(self, latent_size=512, dlatent_size=512, dlatent_broadcast=None,
61 | mapping_layers=8, mapping_fmaps=512, mapping_lrmul=0.01, mapping_nonlinearity='lrelu',
62 | use_wscale=True, normalize_latents=False, **kwargs):
63 | super().__init__()
64 |
65 | self.latent_size = latent_size
66 | self.mapping_fmaps = mapping_fmaps
67 | self.dlatent_size = dlatent_size
68 | self.dlatent_broadcast = dlatent_broadcast
69 |
70 | # Activation function.
71 | act, gain = {'relu': (torch.relu, np.sqrt(2)),
72 | 'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[mapping_nonlinearity]
73 |
74 | # Embed labels and concatenate them with latents.
75 | # TODO
76 |
77 | layers = []
78 | # Normalize latents.
79 | if normalize_latents:
80 | layers.append(('pixel_norm', PixelNormLayer()))
81 |
82 | # Mapping layers. (apply_bias?)
83 | layers.append(('dense0', EqualizedLinear(self.latent_size, self.mapping_fmaps,
84 | gain=gain, lrmul=mapping_lrmul, use_wscale=use_wscale)))
85 | layers.append(('dense0_act', act))
86 | for layer_idx in range(1, mapping_layers):
87 | fmaps_in = self.mapping_fmaps
88 | fmaps_out = self.dlatent_size if layer_idx == mapping_layers - 1 else self.mapping_fmaps
89 | layers.append(
90 | ('dense{:d}'.format(layer_idx),
91 | EqualizedLinear(fmaps_in, fmaps_out, gain=gain, lrmul=mapping_lrmul, use_wscale=use_wscale)))
92 | layers.append(('dense{:d}_act'.format(layer_idx), act))
93 |
94 | # Output.
95 | self.map = nn.Sequential(OrderedDict(layers))
96 |
97 | def forward(self, x):
98 | # First input: Latent vectors (Z) [mini_batch, latent_size].
99 | x = self.map(x)
100 |
101 | # Broadcast -> batch_size * dlatent_broadcast * dlatent_size
102 | if self.dlatent_broadcast is not None:
103 | x = x.unsqueeze(1).expand(-1, self.dlatent_broadcast, -1)
104 | return x
105 |
106 | if __name__=='__main__':
107 |
108 | m = GMapping(latent_size=256, dlatent_size=256)
109 |
110 | x = torch.randn(10, 16, 16, 256)
111 | o=m(x)
112 | print(o.shape)
113 |
--------------------------------------------------------------------------------
/taming/util.py:
--------------------------------------------------------------------------------
1 | import os, hashlib
2 | import requests
3 | from tqdm import tqdm
4 |
5 | URL_MAP = {
6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
7 | }
8 |
9 | CKPT_MAP = {
10 | "vgg_lpips": "vgg.pth"
11 | }
12 |
13 | MD5_MAP = {
14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
15 | }
16 |
17 |
18 |
19 | def download(url, local_path, chunk_size=1024):
20 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
21 | with requests.get(url, stream=True) as r:
22 | total_size = int(r.headers.get("content-length", 0))
23 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
24 | with open(local_path, "wb") as f:
25 | for data in r.iter_content(chunk_size=chunk_size):
26 | if data:
27 | f.write(data)
28 | pbar.update(chunk_size)
29 |
30 |
31 | def md5_hash(path):
32 | with open(path, "rb") as f:
33 | content = f.read()
34 | return hashlib.md5(content).hexdigest()
35 |
36 |
37 | def get_ckpt_path(name, root, check=False):
38 | assert name in URL_MAP
39 | path = os.path.join(root, CKPT_MAP[name])
40 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
41 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
42 | download(URL_MAP[name], path)
43 | md5 = md5_hash(path)
44 | assert md5 == MD5_MAP[name], md5
45 | return path
46 |
47 |
48 | class KeyNotFoundError(Exception):
49 | def __init__(self, cause, keys=None, visited=None):
50 | self.cause = cause
51 | self.keys = keys
52 | self.visited = visited
53 | messages = list()
54 | if keys is not None:
55 | messages.append("Key not found: {}".format(keys))
56 | if visited is not None:
57 | messages.append("Visited: {}".format(visited))
58 | messages.append("Cause:\n{}".format(cause))
59 | message = "\n".join(messages)
60 | super().__init__(message)
61 |
62 |
63 | def retrieve(
64 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
65 | ):
66 | """Given a nested list or dict return the desired value at key expanding
67 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion
68 | is done in-place.
69 |
70 | Parameters
71 | ----------
72 | list_or_dict : list or dict
73 | Possibly nested list or dictionary.
74 | key : str
75 | key/to/value, path like string describing all keys necessary to
76 | consider to get to the desired value. List indices can also be
77 | passed here.
78 | splitval : str
79 | String that defines the delimiter between keys of the
80 | different depth levels in `key`.
81 | default : obj
82 | Value returned if :attr:`key` is not found.
83 | expand : bool
84 | Whether to expand callable nodes on the path or not.
85 |
86 | Returns
87 | -------
88 | The desired value or if :attr:`default` is not ``None`` and the
89 | :attr:`key` is not found returns ``default``.
90 |
91 | Raises
92 | ------
93 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
94 | ``None``.
95 | """
96 |
97 | keys = key.split(splitval)
98 |
99 | success = True
100 | try:
101 | visited = []
102 | parent = None
103 | last_key = None
104 | for key in keys:
105 | if callable(list_or_dict):
106 | if not expand:
107 | raise KeyNotFoundError(
108 | ValueError(
109 | "Trying to get past callable node with expand=False."
110 | ),
111 | keys=keys,
112 | visited=visited,
113 | )
114 | list_or_dict = list_or_dict()
115 | parent[last_key] = list_or_dict
116 |
117 | last_key = key
118 | parent = list_or_dict
119 |
120 | try:
121 | if isinstance(list_or_dict, dict):
122 | list_or_dict = list_or_dict[key]
123 | else:
124 | list_or_dict = list_or_dict[int(key)]
125 | except (KeyError, IndexError, ValueError) as e:
126 | raise KeyNotFoundError(e, keys=keys, visited=visited)
127 |
128 | visited += [key]
129 | # final expansion of retrieved value
130 | if expand and callable(list_or_dict):
131 | list_or_dict = list_or_dict()
132 | parent[last_key] = list_or_dict
133 | except KeyNotFoundError as e:
134 | if default is None:
135 | raise e
136 | else:
137 | list_or_dict = default
138 | success = False
139 |
140 | if not pass_success:
141 | return list_or_dict
142 | else:
143 | return list_or_dict, success
144 |
145 |
146 | if __name__ == "__main__":
147 | config = {"keya": "a",
148 | "keyb": "b",
149 | "keyc":
150 | {"cc1": 1,
151 | "cc2": 2,
152 | }
153 | }
154 | from omegaconf import OmegaConf
155 | config = OmegaConf.create(config)
156 | print(config)
157 | retrieve(config, "keya")
158 |
159 |
--------------------------------------------------------------------------------
/tools/download_datasets.sh:
--------------------------------------------------------------------------------
1 | set -e # exit script if error
2 |
3 | echo "This script requires wget."
4 | echo "Datasets will be set up in \"../datasets\"."
5 |
6 | echo "Create dataset folders and subfolders"
7 | mkdir -p datasets
8 | mkdir -p datasets/coco
9 | mkdir -p datasets/coco/2014
10 | mkdir -p datasets/coco/2017
11 |
12 | echo "Download the COCO 2014 dataset (val split) to \"./datasets/coco/2014\"."
13 | wget http://images.cocodataset.org/zips/val2014.zip
14 | wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
15 |
16 | mv val2014.zip datasets/coco/2014
17 | mv annotations_trainval2014.zip datasets/coco/2014
18 | unzip datasets/coco/2014/val2014.zip -d datasets/coco/2014
19 | unzip datasets/coco/2014/annotations_trainval2014.zip -d datasets/coco/2014
20 |
21 | echo "Download the COCO 2017 dataset (val split) to \"./datasets/coco/2017\"."
22 | wget http://images.cocodataset.org/zips/val2017.zip
23 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
24 | wget http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip
25 |
26 | mv val2017.zip datasets/coco/2017
27 | mv annotations_trainval2017.zip datasets/coco/2017
28 | mv stuff_annotations_trainval2017.zip datasets/coco/2017
29 | unzip datasets/coco/2017/val2017.zip -d datasets/coco/2017
30 | unzip datasets/coco/2017/annotations_trainval2017.zip -d datasets/coco/2017
31 | unzip datasets/coco/2017/stuff_annotations_trainval2017.zip -d datasets/coco/2017
32 |
33 | echo "Remove downloaded zip archives..."
34 | rm -rf datasets/coco/2014/val2014.zip
35 | rm -rf datasets/coco/2014/annotations_trainval2014.zip
36 | rm -rf datasets/coco/2017/val2017.zip
37 | rm -rf datasets/coco/2017/annotations_trainval2017.zip
38 | rm -rf datasets/coco/2017/stuff_annotations_trainval2017.zip
39 |
40 | echo "Move datasets to \"../datasets\""
41 | mv datasets ../
--------------------------------------------------------------------------------
/tools/download_models.sh:
--------------------------------------------------------------------------------
1 | wget https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
2 |
3 | mkdir -p exp
4 | mkdir -p exp/vqgan
5 | mkdir -p exp/vqgan/vq-f8
6 |
7 | mv ./vq-f8.zip exp/vqgan/vq-f8/
8 |
9 | cd exp/vqgan/vq-f8/
10 | unzip vq-f8.zip
11 | rm vq-f8.zip
12 | cd ../../../
--------------------------------------------------------------------------------
/tools/ldm/train_ldm_coco_Layout2I.sh:
--------------------------------------------------------------------------------
1 |
2 | python main.py --base configs/ldm/coco_sg2im_ldm_Layout2I_vqgan_f8.yaml \
3 | -t True --gpus 1 -log_dir ./exp/ldm/Layout2I \
4 | -n coco_sg2im_ldm_Layout2I_vqgan_f8 --scale_lr False -tb True
5 |
--------------------------------------------------------------------------------
/tools/ldm/train_ldm_coco_T2I.sh:
--------------------------------------------------------------------------------
1 |
2 | python main.py --base configs/ldm/coco_stuff_ldm_T2I_vqgan_f8.yaml \
3 | -t True --gpus 1 -log_dir ./exp/ldm/T2I \
4 | -n coco_stuff_ldm_T2I_vqgan_f8 --scale_lr False -tb True
5 |
--------------------------------------------------------------------------------
/tools/vqgan/train_vqgan_coco.sh:
--------------------------------------------------------------------------------
1 | python main.py -t True --base configs/vqgan/coco_vqgan_f8.yaml \
2 | --gpus 1 -log_dir ./exp/vqgan -n coco_vqgan_f8
--------------------------------------------------------------------------------