├── README.md
├── __init__.py
├── configs
├── autoencoder
│ ├── autoencoder_kl_16x16x16.yaml
│ ├── autoencoder_kl_32x32x4.yaml
│ ├── autoencoder_kl_64x64x3.yaml
│ └── autoencoder_kl_8x8x64.yaml
├── latent-diffusion
│ ├── celebahq-ldm-vq-4.yaml
│ ├── cin-ldm-vq-f8.yaml
│ ├── cin256-v2.yaml
│ ├── ffhq-ldm-vq-4.yaml
│ ├── lsun_churches-ldm-kl-8.yaml
│ └── txt2img-1p4B-eval.yaml
├── retrieval-augmented-diffusion
│ └── 768x768.yaml
├── stable-diffusion
│ └── v1-inference.yaml
└── tasks
│ ├── box_inpainting_config.yaml
│ ├── gaussian_deblur_config.yaml
│ ├── hdr_config.yaml
│ ├── motion_deblur_config.yaml
│ ├── nonlinear_deblur_config.yaml
│ ├── phase_retrieval_config.yaml
│ ├── rand_inpainting_config.yaml
│ └── super_resolution_config.yaml
├── data
├── __pycache__
│ └── dataloader.cpython-38.pyc
├── dataloader.py
└── samples
│ ├── 00000.png
│ └── 00003.png
├── diffstategrad_sample_condition.py
├── diffstategrad_utils.py
├── environment.yaml
├── figures
├── hdr_example.png
├── hdr_short.png
├── manifold_diffstategrad.pdf
├── manifold_diffstategrad.png
└── phase_example.png
├── ldm
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-38.pyc
│ └── util.cpython-38.pyc
├── data
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ └── lsun.cpython-38.pyc
│ ├── base.py
│ ├── imagenet.py
│ └── lsun.py
├── lr_scheduler.py
├── models
│ ├── __pycache__
│ │ └── autoencoder.cpython-38.pyc
│ ├── autoencoder.py
│ └── diffusion
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── ddim.cpython-38.pyc
│ │ ├── ddpm.cpython-38.pyc
│ │ └── plms.cpython-38.pyc
│ │ ├── classifier.py
│ │ ├── ddpm.py
│ │ ├── diffstategrad_ddim.py
│ │ ├── dpm_solver
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── dpm_solver.cpython-38.pyc
│ │ │ └── sampler.cpython-38.pyc
│ │ ├── dpm_solver.py
│ │ └── sampler.py
│ │ ├── plms.py
│ │ └── utils.py
├── modules
│ ├── __pycache__
│ │ ├── attention.cpython-38.pyc
│ │ └── ema.cpython-38.pyc
│ ├── attention.py
│ ├── diffusionmodules
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── model.cpython-38.pyc
│ │ │ ├── openaimodel.cpython-38.pyc
│ │ │ └── util.cpython-38.pyc
│ │ ├── model.py
│ │ ├── openaimodel.py
│ │ └── util.py
│ ├── distributions
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ └── distributions.cpython-38.pyc
│ │ └── distributions.py
│ ├── ema.py
│ ├── encoders
│ │ ├── __init__.py
│ │ └── modules.py
│ ├── image_degradation
│ │ ├── __init__.py
│ │ ├── bsrgan.py
│ │ ├── bsrgan_light.py
│ │ ├── utils
│ │ │ └── test.png
│ │ └── utils_image.py
│ ├── losses
│ │ ├── __init__.py
│ │ ├── contperceptual.py
│ │ └── vqperceptual.py
│ └── x_transformer.py
└── util.py
├── ldm_inverse
├── __pycache__
│ ├── condition_methods.cpython-38.pyc
│ └── measurements.cpython-38.pyc
├── condition_methods.py
└── measurements.py
├── model_loader.py
├── samples
├── 60004.png
└── 60074.png
├── scripts
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-38.pyc
│ └── utils.cpython-38.pyc
├── download_first_stages.sh
├── download_models.sh
├── img2img.py
├── inpaint.py
├── inv.py
├── knn2img.py
├── sample_diffusion.py
├── tests
│ └── test_watermark.py
├── train_searcher.py
├── txt2img.py
└── utils.py
├── src
├── clip
│ ├── clip.egg-info
│ │ └── PKG-INFO
│ ├── clip
│ │ ├── __init__.py
│ │ ├── clip.py
│ │ ├── model.py
│ │ └── simple_tokenizer.py
│ ├── hubconf.py
│ ├── setup.py
│ └── tests
│ │ └── test_consistency.py
└── taming-transformers
│ ├── main.py
│ ├── scripts
│ ├── extract_depth.py
│ ├── extract_segmentation.py
│ ├── extract_submodel.py
│ ├── make_samples.py
│ ├── make_scene_samples.py
│ ├── sample_conditional.py
│ └── sample_fast.py
│ ├── setup.py
│ ├── taming
│ ├── data
│ │ ├── ade20k.py
│ │ ├── annotated_objects_coco.py
│ │ ├── annotated_objects_dataset.py
│ │ ├── annotated_objects_open_images.py
│ │ ├── base.py
│ │ ├── coco.py
│ │ ├── conditional_builder
│ │ │ ├── objects_bbox.py
│ │ │ ├── objects_center_points.py
│ │ │ └── utils.py
│ │ ├── custom.py
│ │ ├── faceshq.py
│ │ ├── helper_types.py
│ │ ├── image_transforms.py
│ │ ├── imagenet.py
│ │ ├── open_images_helper.py
│ │ ├── sflckr.py
│ │ └── utils.py
│ ├── lr_scheduler.py
│ ├── models
│ │ ├── cond_transformer.py
│ │ ├── dummy_cond_stage.py
│ │ └── vqgan.py
│ ├── modules
│ │ ├── diffusionmodules
│ │ │ └── model.py
│ │ ├── discriminator
│ │ │ └── model.py
│ │ ├── losses
│ │ │ ├── __init__.py
│ │ │ ├── lpips.py
│ │ │ ├── segmentation.py
│ │ │ └── vqperceptual.py
│ │ ├── misc
│ │ │ └── coord.py
│ │ ├── transformer
│ │ │ ├── mingpt.py
│ │ │ └── permuter.py
│ │ ├── util.py
│ │ └── vqvae
│ │ │ ├── __pycache__
│ │ │ └── quantize.cpython-38.pyc
│ │ │ └── quantize.py
│ └── util.py
│ └── taming_transformers.egg-info
│ └── PKG-INFO
└── util
├── __pycache__
├── fastmri_utils.cpython-38.pyc
├── img_utils.cpython-38.pyc
└── resizer.cpython-38.pyc
├── compute_metric.py
├── fastmri_utils.py
├── img_utils.py
├── logger.py
├── resizer.py
└── tools.py
/README.md:
--------------------------------------------------------------------------------
1 | # Diffusion State-Guided Projected Gradient for Inverse Problems (ICLR 2025)
2 |
3 | 
4 |
5 | 
6 |
7 | ## Abstract
8 |
9 | In this work, we propose DiffStateGrad, a novel approach that enhances diffusion-based inverse problem solvers by projecting measurement guidance gradients onto a data-driven low-rank subspace defined by intermediate diffusion states. Our algorithm addresses the challenge of maintaining manifold consistency by performing singular value decomposition on intermediate diffusion states to define a projection matrix that captures local data statistics. This projection ensures that measurement guidance remains aligned with the learned data manifold while filtering out artifact-inducing components, leading to improved robustness and performance across various inverse problems. In this repository, we demonstrate the effectiveness of DiffStateGrad by applying it to ReSample's framework.
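In short (using the notation of the helpers in `diffstategrad_utils.py`): take a per-channel SVD of the intermediate diffusion state, $z_t = U \Sigma V^\top$, keep only the leading $r$ singular vectors $U_r$ and $V_r$ (with $r$ chosen adaptively so that the retained singular values explain a target fraction of the variance, e.g. 99%), and replace the measurement-guidance gradient $g$ by its projection

$$
\tilde{g} \;=\; U_r U_r^\top \, g \, V_r V_r^\top ,
$$

which keeps the components of $g$ aligned with the dominant directions of the current diffusion state and discards the rest.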
10 |
11 | 
12 |
13 | ## Implementation
14 |
15 | This repository provides a modified version of the ReSample codebase that incorporates our DiffStateGrad method. The implementation maintains the core functionality of ReSample while adding our enhancements for improved performance and stability.
16 |
17 | Our main contributions can be found in `diffstategrad_sample_condition.py` and `ldm/models/diffusion/diffstategrad_ddim.py`.
18 |
19 | ### DiffStateGrad Helper Methods
20 |
21 | The core utilities of DiffStateGrad are implemented in `diffstategrad_utils.py`. They are designed as a plug-and-play module that can be dropped into other diffusion-based inverse problem solvers. The implementation includes three main functions:
22 |
23 | 1. `compute_rank_for_explained_variance`: Determines the rank needed to explain a target variance percentage across channels
24 | 2. `compute_svd_and_adaptive_rank`: Performs SVD on diffusion state and computes adaptive rank based on variance cutoff
25 | 3. `apply_diffstategrad`: Computes the projected gradient using our DiffStateGrad algorithm
26 |
27 | ### Example Usage
28 |
29 | ```python
30 | from diffstategrad_utils import compute_svd_and_adaptive_rank, apply_diffstategrad
31 |
32 | # During optimization:
33 | if iteration_count % period == 0:
34 | # Compute SVD and adaptive rank when needed
35 | U, s, Vh, adaptive_rank = compute_svd_and_adaptive_rank(z_t, var_cutoff=0.99)
36 |
37 | # Apply DiffStateGrad to the normalized gradient
38 | projected_grad = apply_diffstategrad(
39 | norm_grad=normalized_gradient,
40 | iteration_count=iteration_count,
41 | period=period,
42 | U=U, s=s, Vh=Vh,
43 | adaptive_rank=adaptive_rank
44 | )
45 |
46 | # Update diffusion state with projected gradient
47 | z_t = z_t - eta * projected_grad
48 | ```
49 |
50 | For complete implementation details, please refer to [`diffstategrad_utils.py`](https://github.com/rzirvi1665/DiffStateGrad/blob/main/diffstategrad_utils.py) in our repository.
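As a quick, self-contained smoke test of these helpers, the sketch below runs them on random tensors. The shapes are arbitrary and merely mimic a `1 x C x H x W` latent; `torch` and `numpy` (both listed in `environment.yaml`) are assumed to be installed.

```python
import torch

from diffstategrad_utils import compute_svd_and_adaptive_rank, apply_diffstategrad

# Dummy stand-ins for the diffusion state and the normalized measurement gradient
# (batch size 1, three channels, 64x64 latent -- random values, illustration only).
z_t = torch.randn(1, 3, 64, 64)
norm_grad = torch.randn(1, 3, 64, 64)

# SVD of the current state and the variance-based adaptive rank
U, s, Vh, adaptive_rank = compute_svd_and_adaptive_rank(z_t, var_cutoff=0.99)

# iteration_count % period == 0 here, so the gradient is actually projected
projected_grad = apply_diffstategrad(
    norm_grad, iteration_count=0, period=10,
    U=U, s=s, Vh=Vh, adaptive_rank=adaptive_rank,
)

assert projected_grad.shape == norm_grad.shape  # (1, 3, 64, 64)
```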
51 |
52 | ## Getting Started
53 |
54 | ### 1) Clone the repository
55 |
56 | ```
57 | git clone https://github.com/rzirvi1665/DiffStateGrad.git
58 |
59 | cd DiffStateGrad
60 | ```
61 |
62 |
63 |
64 | ### 2) Download pretrained checkpoints (first-stage autoencoder and diffusion model)
65 |
66 | ```
67 | mkdir -p models/ldm
68 | wget https://ommer-lab.com/files/latent-diffusion/ffhq.zip -P ./models/ldm
69 | unzip models/ldm/ffhq.zip -d ./models/ldm
70 |
71 | mkdir -p models/first_stage_models/vq-f4
72 | wget https://ommer-lab.com/files/latent-diffusion/vq-f4.zip -P ./models/first_stage_models/vq-f4
73 | unzip models/first_stage_models/vq-f4/vq-f4.zip -d ./models/first_stage_models/vq-f4
74 | ```
75 |
76 |
77 |
78 | ### 3) Set environment
79 |
80 | Following the DPS codebase, we use external code for motion blurring and non-linear deblurring:
81 |
82 | ```
83 | git clone https://github.com/VinAIResearch/blur-kernel-space-exploring bkse
84 |
85 | git clone https://github.com/LeviBorodenko/motionblur motionblur
86 | ```
87 |
88 | Install dependencies via
89 |
90 | ```
91 | conda env create -f environment.yaml
92 | ```
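The environment is named `LDM` (see the `name:` field in `environment.yaml`), so activate it with `conda activate LDM` before running inference.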
93 |
94 |
95 |
96 | ### 4) Inference
97 |
98 | ```
99 | python3 diffstategrad_sample_condition.py
100 | ```
101 |
102 | The code is currently configured to run inference on FFHQ. To use other datasets and models, download the corresponding checkpoints from https://github.com/CompVis/latent-diffusion/tree/main and update the checkpoint paths in the configs accordingly.
103 |
104 |
105 |
106 |
107 | ## Task Configurations
108 |
109 | ```
110 | # Linear inverse problems
111 | - configs/tasks/super_resolution_config.yaml
112 | - configs/tasks/gaussian_deblur_config.yaml
113 | - configs/tasks/motion_deblur_config.yaml
114 | - configs/tasks/box_inpainting_config.yaml
115 | - configs/tasks/rand_inpainting_config.yaml
116 |
117 | # Non-linear inverse problems
118 | - configs/tasks/nonlinear_deblur_config.yaml
119 | - configs/tasks/phase_retrieval_config.yaml
120 | - configs/tasks/hdr_config.yaml
121 | ```
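The task configs are plain YAML, so a variant task can be created by copying and editing one of the files above. A small illustrative sketch using PyYAML (included in `environment.yaml`); the output file name is made up, and keys/values should be adjusted to the task you actually want to run:

```python
import yaml

# Load an existing task config, raise the measurement-noise level, and save a copy.
with open("configs/tasks/gaussian_deblur_config.yaml") as f:
    task_cfg = yaml.safe_load(f)

task_cfg["noise"]["sigma"] = 0.05                        # repo default: 0.01
task_cfg["measurement"]["operator"]["intensity"] = 1.0   # repo default: 3.0

with open("configs/tasks/gaussian_deblur_noisy_config.yaml", "w") as f:
    yaml.safe_dump(task_cfg, f)
```

Note that `yaml.safe_dump` does not preserve the comments from the original file.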
122 |
123 |
124 |
125 | ## Citation
126 | If you find our work interesting, please consider citing:
127 |
128 | ```
129 | @inproceedings{
130 | zirvi2025diffusion,
131 | title={Diffusion State-Guided Projected Gradient for Inverse Problems},
132 | author={Rayhan Zirvi and Bahareh Tolooshams and Anima Anandkumar},
133 | booktitle={The Thirteenth International Conference on Learning Representations},
134 | year={2025}
135 | }
136 | ```
137 |
138 | ## MIT License
139 |
140 | This code is released under the MIT License unless otherwise stated by applicable licenses.
141 | Third-party components included in this code remain under their original licenses and attributions.
142 |
143 |
144 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/__init__.py
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_16x16x16.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 16
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_32x32x4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 4
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_64x64x3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 3
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_8x8x64.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 64
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16,8]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/configs/latent-diffusion/celebahq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 |
15 | unet_config:
16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
17 | params:
18 | image_size: 64
19 | in_channels: 3
20 | out_channels: 3
21 | model_channels: 224
22 | attention_resolutions:
23 | # note: this isn't actually the resolution but
24 | # the downsampling factor, i.e. this corresponds to
25 | # attention on spatial resolution 8,16,32, as the
26 | # spatial resolution of the latents is 64 for f4
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | num_head_channels: 32
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config: __is_unconditional__
60 | data:
61 | target: main.DataModuleFromConfig
62 | params:
63 | batch_size: 48
64 | num_workers: 5
65 | wrap: false
66 | train:
67 | target: taming.data.faceshq.CelebAHQTrain
68 | params:
69 | size: 256
70 | validation:
71 | target: taming.data.faceshq.CelebAHQValidation
72 | params:
73 | size: 256
74 |
75 |
76 | lightning:
77 | callbacks:
78 | image_logger:
79 | target: main.ImageLogger
80 | params:
81 | batch_frequency: 5000
82 | max_images: 8
83 | increase_log_steps: False
84 |
85 | trainer:
86 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/cin-ldm-vq-f8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 | # note: this isn't actually the resolution but
26 | # the downsampling factor, i.e. this corresponds to
27 | # attention on spatial resolution 8,16,32, as the
28 | # spatial resolution of the latents is 32 for f8
29 | - 4
30 | - 2
31 | - 1
32 | num_res_blocks: 2
33 | channel_mult:
34 | - 1
35 | - 2
36 | - 4
37 | num_head_channels: 32
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 512
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 4
45 | n_embed: 16384
46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml
47 | ddconfig:
48 | double_z: false
49 | z_channels: 4
50 | resolution: 256
51 | in_channels: 3
52 | out_ch: 3
53 | ch: 128
54 | ch_mult:
55 | - 1
56 | - 2
57 | - 2
58 | - 4
59 | num_res_blocks: 2
60 | attn_resolutions:
61 | - 32
62 | dropout: 0.0
63 | lossconfig:
64 | target: torch.nn.Identity
65 | cond_stage_config:
66 | target: ldm.modules.encoders.modules.ClassEmbedder
67 | params:
68 | embed_dim: 512
69 | key: class_label
70 | data:
71 | target: main.DataModuleFromConfig
72 | params:
73 | batch_size: 64
74 | num_workers: 12
75 | wrap: false
76 | train:
77 | target: ldm.data.imagenet.ImageNetTrain
78 | params:
79 | config:
80 | size: 256
81 | validation:
82 | target: ldm.data.imagenet.ImageNetValidation
83 | params:
84 | config:
85 | size: 256
86 |
87 |
88 | lightning:
89 | callbacks:
90 | image_logger:
91 | target: main.ImageLogger
92 | params:
93 | batch_frequency: 5000
94 | max_images: 8
95 | increase_log_steps: False
96 |
97 | trainer:
98 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/cin256-v2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss
17 | use_ema: False
18 |
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 64
23 | in_channels: 3
24 | out_channels: 3
25 | model_channels: 192
26 | attention_resolutions:
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 5
36 | num_heads: 1
37 | use_spatial_transformer: true
38 | transformer_depth: 1
39 | context_dim: 512
40 |
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 3
45 | n_embed: 8192
46 | ddconfig:
47 | double_z: false
48 | z_channels: 3
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult:
54 | - 1
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: []
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.ClassEmbedder
65 | params:
66 | n_classes: 1001
67 | embed_dim: 512
68 | key: class_label
69 |
--------------------------------------------------------------------------------
/configs/latent-diffusion/ffhq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 | # note: this isn't actually the resolution but
23 | # the downsampling factor, i.e. this corresponds to
24 | # attention on spatial resolution 8,16,32, as the
25 | # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | embed_dim: 3
40 | n_embed: 8192
41 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 42
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: taming.data.faceshq.FFHQTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: taming.data.faceshq.FFHQValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: "image"
12 | cond_stage_key: "image"
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: False
16 | concat_mode: False
17 | scale_by_std: True
18 | monitor: 'val/loss_simple_ema'
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [10000]
24 | cycle_lengths: [10000000000000]
25 | f_start: [1.e-6]
26 | f_max: [1.]
27 | f_min: [ 1.]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 192
36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
37 | num_res_blocks: 2
38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
39 | num_heads: 8
40 | use_scale_shift_norm: True
41 | resblock_updown: True
42 |
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | embed_dim: 4
47 | monitor: "val/rec_loss"
48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
49 | ddconfig:
50 | double_z: True
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
57 | num_res_blocks: 2
58 | attn_resolutions: [ ]
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config: "__is_unconditional__"
64 |
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 96
69 | num_workers: 5
70 | wrap: False
71 | train:
72 | target: ldm.data.lsun.LSUNChurchesTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: ldm.data.lsun.LSUNChurchesValidation
77 | params:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 5000
86 | max_images: 8
87 | increase_log_steps: False
88 |
89 |
90 | trainer:
91 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/txt2img-1p4B-eval.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | image_size: 32
24 | in_channels: 4
25 | out_channels: 4
26 | model_channels: 320
27 | attention_resolutions:
28 | - 4
29 | - 2
30 | - 1
31 | num_res_blocks: 2
32 | channel_mult:
33 | - 1
34 | - 2
35 | - 4
36 | - 4
37 | num_heads: 8
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 1280
41 | use_checkpoint: true
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 4
48 | monitor: val/rec_loss
49 | ddconfig:
50 | double_z: true
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult:
57 | - 1
58 | - 2
59 | - 4
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions: []
63 | dropout: 0.0
64 | lossconfig:
65 | target: torch.nn.Identity
66 |
67 | cond_stage_config:
68 | target: ldm.modules.encoders.modules.BERTEmbedder
69 | params:
70 | n_embed: 1280
71 | n_layer: 32
72 |
--------------------------------------------------------------------------------
/configs/retrieval-augmented-diffusion/768x768.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.015
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: jpg
11 | cond_stage_key: nix
12 | image_size: 48
13 | channels: 16
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_by_std: false
18 | scale_factor: 0.22765929
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 48
23 | in_channels: 16
24 | out_channels: 16
25 | model_channels: 448
26 | attention_resolutions:
27 | - 4
28 | - 2
29 | - 1
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | use_scale_shift_norm: false
37 | resblock_updown: false
38 | num_head_channels: 32
39 | use_spatial_transformer: true
40 | transformer_depth: 1
41 | context_dim: 768
42 | use_checkpoint: true
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | monitor: val/rec_loss
47 | embed_dim: 16
48 | ddconfig:
49 | double_z: true
50 | z_channels: 16
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult:
56 | - 1
57 | - 1
58 | - 2
59 | - 2
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions:
63 | - 16
64 | dropout: 0.0
65 | lossconfig:
66 | target: torch.nn.Identity
67 | cond_stage_config:
68 | target: torch.nn.Identity
--------------------------------------------------------------------------------
/configs/stable-diffusion/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
--------------------------------------------------------------------------------
/configs/tasks/box_inpainting_config.yaml:
--------------------------------------------------------------------------------
1 | conditioning:
2 | main_sampler: resample
3 | method: ps # Do not touch
4 | params:
5 | scale: 0.5
6 |
7 | data:
8 | name: celeb
9 | root: ./data/samples/
10 |
11 | measurement:
12 | operator:
13 | name: inpainting
14 | mask_opt:
15 | mask_type: box
16 | mask_len_range: !!python/tuple [128, 129] # for box
17 | # mask_prob_range: !!python/tuple [0.7, 0.71]
18 | image_size: 256
19 |
20 |
21 | noise:
22 | name: gaussian
23 | sigma: 0.01
24 |
--------------------------------------------------------------------------------
/configs/tasks/gaussian_deblur_config.yaml:
--------------------------------------------------------------------------------
1 | conditioning:
2 | main_sampler: resample
3 | method: ps # Do not touch
4 | params:
5 | scale: 0.5
6 |
7 | data:
8 | name: celeb
9 | root: ./data/samples/
10 |
11 | measurement:
12 | operator:
13 | name: gaussian_blur
14 | kernel_size: 61
15 | intensity: 3.0
16 |
17 | noise:
18 | name: gaussian
19 | sigma: 0.01
20 |
--------------------------------------------------------------------------------
/configs/tasks/hdr_config.yaml:
--------------------------------------------------------------------------------
1 | conditioning:
2 | main_sampler: resample
3 | method: ps # Do not touch
4 | params:
5 | scale: 1.0
6 |
7 | data:
8 | name: celeb
9 | root: ./data/samples/
10 |
11 | measurement:
12 | operator:
13 | name: high_dynamic_range
14 | scale: 2.0
15 |
16 | noise:
17 | name: gaussian
18 | sigma: 0.01
19 |
--------------------------------------------------------------------------------
/configs/tasks/motion_deblur_config.yaml:
--------------------------------------------------------------------------------
1 | conditioning:
2 | main_sampler: resample
3 | method: ps # Do not touch
4 | params:
5 | scale: 1.0
6 |
7 | data:
8 | name: celeb
9 | root: ./data/samples/
10 |
11 | measurement:
12 | operator:
13 | name: motion_blur
14 | kernel_size: 61
15 | intensity: 0.5
16 |
17 | noise:
18 | name: gaussian
19 | sigma: 0.01
20 |
--------------------------------------------------------------------------------
/configs/tasks/nonlinear_deblur_config.yaml:
--------------------------------------------------------------------------------
1 | conditioning:
2 | main_sampler: dps
3 | method: ps # Do not touch
4 | params:
5 | scale: 0.3
6 |
7 | data:
8 | name: celeb
9 | root: ./data/samples/
10 |
11 | measurement:
12 | operator:
13 | name: nonlinear_blur
14 | opt_yml_path: ./bkse/options/generate_blur/default.yml
15 |
16 | noise:
17 | name: gaussian
18 | sigma: 0.01
19 |
--------------------------------------------------------------------------------
/configs/tasks/phase_retrieval_config.yaml:
--------------------------------------------------------------------------------
1 | conditioning:
2 | main_sampler: resample
3 | method: ps # Do not touch
4 | params:
5 | scale: 1.0
6 |
7 | data:
8 | name: celeb
9 | root: ./data/samples/
10 |
11 | measurement:
12 | operator:
13 | name: phase_retrieval
14 | oversample: 2.0
15 |
16 | noise:
17 | name: gaussian
18 | sigma: 0.01
19 |
--------------------------------------------------------------------------------
/configs/tasks/rand_inpainting_config.yaml:
--------------------------------------------------------------------------------
1 | conditioning:
2 | main_sampler: resample
3 | method: ps # Do not touch
4 | params:
5 | scale: 0.5
6 |
7 | data:
8 | name: celeb
9 | root: ./data/samples/
10 |
11 | measurement:
12 | operator:
13 | name: inpainting
14 | mask_opt:
15 | mask_type: random
16 | mask_prob_range: !!python/tuple [0.7, 0.71]
17 | image_size: 256
18 |
19 |
20 | noise:
21 | name: gaussian
22 | sigma: 0.01
23 |
--------------------------------------------------------------------------------
/configs/tasks/super_resolution_config.yaml:
--------------------------------------------------------------------------------
1 | conditioning:
2 | main_sampler: resample
3 | method: ps # Do not touch
4 | params:
5 | scale: 0.1 # Try changing this
6 |
7 | data:
8 | name: ffhq
9 | root: ./data/samples/
10 |
11 | measurement:
12 | operator:
13 | name: super_resolution
14 | in_shape: !!python/tuple [1, 3, 256, 256]
15 | scale_factor: 4
16 |
17 | noise:
18 | name: gaussian
19 | sigma: 0.01
20 |
--------------------------------------------------------------------------------
/data/__pycache__/dataloader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/data/__pycache__/dataloader.cpython-38.pyc
--------------------------------------------------------------------------------
/data/dataloader.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | from PIL import Image
3 | from typing import Callable, Optional
4 | from torch.utils.data import DataLoader
5 | from torchvision.datasets import VisionDataset
6 |
7 |
8 | __DATASET__ = {}
9 |
10 | def register_dataset(name: str):
11 | def wrapper(cls):
12 | if __DATASET__.get(name, None):
13 | raise NameError(f"Name {name} is already registered!")
14 | __DATASET__[name] = cls
15 | return cls
16 | return wrapper
17 |
18 |
19 | def get_dataset(name: str, root: str, **kwargs):
20 | if __DATASET__.get(name, None) is None:
21 | raise NameError(f"Dataset {name} is not defined.")
22 | return __DATASET__[name](root=root, **kwargs)
23 |
24 |
25 | def get_dataloader(dataset: VisionDataset,
26 | batch_size: int,
27 | num_workers: int,
28 | train: bool):
29 | dataloader = DataLoader(dataset,
30 | batch_size,
31 | shuffle=train,
32 | num_workers=num_workers,
33 | drop_last=train)
34 | return dataloader
35 |
36 |
37 | @register_dataset(name='celeb')
38 | class CELEBDataset(VisionDataset):
39 | def __init__(self, root: str, transforms: Optional[Callable]=None):
40 | super().__init__(root, transforms)
41 |
42 | self.fpaths = sorted(glob(root + '/**/*.png', recursive=True))
43 | assert len(self.fpaths) > 0, "File list is empty. Check the root."
44 |
45 | def __len__(self):
46 | return len(self.fpaths)
47 |
48 | def __getitem__(self, index: int):
49 | fpath = self.fpaths[index]
50 | img = Image.open(fpath).convert('RGB')
51 |
52 | if self.transforms is not None:
53 | img = self.transforms(img)
54 |
55 | return img
56 |
57 |
58 | @register_dataset(name='ffhq')
59 | class FFHQDataset(VisionDataset):
60 | def __init__(self, root: str, transforms: Optional[Callable]=None):
61 | super().__init__(root, transforms)
62 |
63 | self.fpaths = sorted(glob(root + '/**/*.png', recursive=True))
64 | assert len(self.fpaths) > 0, "File list is empty. Check the root."
65 |
66 | def __len__(self):
67 | return len(self.fpaths)
68 |
69 | def __getitem__(self, index: int):
70 | fpath = self.fpaths[index]
71 | img = Image.open(fpath).convert('RGB')
72 |
73 | if self.transforms is not None:
74 | img = self.transforms(img)
75 |
76 | return img
77 |
78 |
--------------------------------------------------------------------------------
/data/samples/00000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/data/samples/00000.png
--------------------------------------------------------------------------------
/data/samples/00003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/data/samples/00003.png
--------------------------------------------------------------------------------
/diffstategrad_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np; import torch  # imports required by the DiffStateGrad helper methods below
2 | def compute_rank_for_explained_variance(singular_values, explained_variance_cutoff):
3 | """
4 | Computes average rank needed across channels to explain target variance percentage.
5 |
6 | Args:
7 | singular_values: List of arrays containing singular values per channel
8 | explained_variance_cutoff: Target explained variance ratio (0-1)
9 |
10 | Returns:
11 | int: Average rank needed across RGB channels
12 | """
13 | total_rank = 0
14 | for channel_singular_values in singular_values:
15 | squared_singular_values = channel_singular_values ** 2
16 | cumulative_variance = np.cumsum(squared_singular_values) / np.sum(squared_singular_values)
17 | rank = np.searchsorted(cumulative_variance, explained_variance_cutoff) + 1
18 | total_rank += rank
19 | return int(total_rank / 3)
20 |
21 | def compute_svd_and_adaptive_rank(z_t, var_cutoff):
22 | """
23 | Compute SVD and adaptive rank for the input tensor.
24 |
25 | Args:
26 | z_t: Input tensor (current image representation at time step t)
27 | var_cutoff: Variance cutoff for rank adaptation
28 |
29 | Returns:
30 | tuple: (U, s, Vh, adaptive_rank) where U, s, Vh are SVD components
31 | and adaptive_rank is the computed rank
32 | """
33 | # Compute SVD of current image representation
34 | U, s, Vh = torch.linalg.svd(z_t[0], full_matrices=False)
35 |
36 | # Compute adaptive rank
37 | s_numpy = s.detach().cpu().numpy()
38 |
39 | adaptive_rank = compute_rank_for_explained_variance([s_numpy], var_cutoff)
40 |
41 | return U, s, Vh, adaptive_rank
42 |
43 | def apply_diffstategrad(norm_grad, iteration_count, period, U=None, s=None, Vh=None, adaptive_rank=None):
44 | """
45 | Compute projected gradient using DiffStateGrad algorithm.
46 |
47 | Args:
48 | norm_grad: Normalized gradient
49 | iteration_count: Current iteration count
50 | period: Period of SVD projection
51 | U: Left singular vectors from SVD
52 | s: Singular values from SVD
53 | Vh: Right singular vectors from SVD
54 | adaptive_rank: Computed adaptive rank
55 |
56 | Returns:
57 | torch.Tensor: Projected gradient if period condition is met, otherwise original gradient
58 | """
59 | if period != 0 and iteration_count % period == 0:
60 | if any(param is None for param in [U, s, Vh, adaptive_rank]):
61 | raise ValueError("SVD components and adaptive_rank must be provided when iteration_count % period == 0")
62 |
63 | # Project gradient
64 | A = U[:, :, :adaptive_rank]
65 | B = Vh[:, :adaptive_rank, :]
66 |
67 | low_rank_grad = torch.matmul(A.permute(0, 2, 1), norm_grad[0]) @ B.permute(0, 2, 1)
68 | projected_grad = torch.matmul(A, low_rank_grad) @ B
69 |
70 | # Reshape projected gradient to match original shape
71 | projected_grad = projected_grad.float().unsqueeze(0) # Add batch dimension back
72 |
73 | return projected_grad
74 |
75 | return norm_grad
76 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: LDM
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=5.1=1_gnu
8 | - blas=1.0=mkl
9 | - brotli-python=1.0.9=py38h6a678d5_7
10 | - bzip2=1.0.8=h5eee18b_5
11 | - ca-certificates=2023.12.12=h06a4308_0
12 | - certifi=2024.2.2=py38h06a4308_0
13 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
14 | - cudatoolkit=11.3.1=h2bc3f7f_2
15 | - ffmpeg=4.3=hf484d3e_0
16 | - freetype=2.12.1=h4a9f257_0
17 | - gmp=6.2.1=h295c915_3
18 | - gnutls=3.6.15=he1e5248_0
19 | - idna=3.4=py38h06a4308_0
20 | - intel-openmp=2021.4.0=h06a4308_3561
21 | - jpeg=9e=h5eee18b_1
22 | - lame=3.100=h7b6447c_0
23 | - lcms2=2.12=h3be6417_0
24 | - ld_impl_linux-64=2.38=h1181459_1
25 | - lerc=3.0=h295c915_0
26 | - libdeflate=1.17=h5eee18b_1
27 | - libffi=3.3=he6710b0_2
28 | - libgcc-ng=11.2.0=h1234567_1
29 | - libgfortran-ng=11.2.0=h00389a5_1
30 | - libgfortran5=11.2.0=h1234567_1
31 | - libgomp=11.2.0=h1234567_1
32 | - libiconv=1.16=h7f8727e_2
33 | - libidn2=2.3.4=h5eee18b_0
34 | - libpng=1.6.39=h5eee18b_0
35 | - libstdcxx-ng=11.2.0=h1234567_1
36 | - libtasn1=4.19.0=h5eee18b_0
37 | - libtiff=4.5.1=h6a678d5_0
38 | - libunistring=0.9.10=h27cfd23_0
39 | - libuv=1.44.2=h5eee18b_0
40 | - libwebp-base=1.3.2=h5eee18b_0
41 | - lz4-c=1.9.4=h6a678d5_0
42 | - mkl=2021.4.0=h06a4308_640
43 | - mkl-service=2.4.0=py38h7f8727e_0
44 | - mkl_fft=1.3.1=py38hd3c417c_0
45 | - mkl_random=1.2.2=py38h51133e4_0
46 | - ncurses=6.4=h6a678d5_0
47 | - nettle=3.7.3=hbbd107a_1
48 | - openh264=2.1.1=h4ff587b_0
49 | - openjpeg=2.4.0=h3ad879b_0
50 | - openssl=1.1.1w=h7f8727e_0
51 | - pillow=10.2.0=py38h5eee18b_0
52 | - pip=20.3.3=py38h06a4308_0
53 | - pysocks=1.7.1=py38h06a4308_0
54 | - python=3.8.5=h7579374_1
55 | - pytorch=1.11.0=py3.8_cuda11.3_cudnn8.2.0_0
56 | - pytorch-mutex=1.0=cuda
57 | - readline=8.2=h5eee18b_0
58 | - requests=2.31.0=py38h06a4308_1
59 | - setuptools=68.2.2=py38h06a4308_0
60 | - six=1.16.0=pyhd3eb1b0_1
61 | - sqlite=3.41.2=h5eee18b_0
62 | - tk=8.6.12=h1ccaba5_0
63 | - torchvision=0.12.0=py38_cu113
64 | - typing_extensions=4.9.0=py38h06a4308_1
65 | - urllib3=2.1.0=py38h06a4308_1
66 | - wheel=0.41.2=py38h06a4308_0
67 | - xz=5.4.6=h5eee18b_0
68 | - zlib=1.2.13=h5eee18b_0
69 | - zstd=1.5.5=hc292b87_0
70 | - pip:
71 | - absl-py==2.1.0
72 | - aiohttp==3.9.3
73 | - aiosignal==1.3.1
74 | - antlr4-python3-runtime==4.9.3
75 | - async-timeout==4.0.3
76 | - attrs==23.2.0
77 | - cachetools==5.3.3
78 | - contourpy==1.1.1
79 | - cycler==0.12.1
80 | - einops==0.7.0
81 | - fonttools==4.49.0
82 | - frozenlist==1.4.1
83 | - fsspec==2024.2.0
84 | - future==1.0.0
85 | - google-auth==2.28.1
86 | - google-auth-oauthlib==1.0.0
87 | - grpcio==1.62.0
88 | - imageio==2.34.0
89 | - importlib-metadata==7.0.1
90 | - kiwisolver==1.4.5
91 | - lazy-loader==0.3
92 | - lightning-utilities==0.10.1
93 | - markdown==3.5.2
94 | - markupsafe==2.1.5
95 | - matplotlib==3.6.0
96 | - multidict==6.0.5
97 | - networkx==3.1
98 | - numpy==1.24.4
99 | - oauthlib==3.2.2
100 | - omegaconf==2.3.0
101 | - packaging==23.2
102 | - protobuf==4.25.3
103 | - pyasn1==0.5.1
104 | - pyasn1-modules==0.3.0
105 | - pydeprecate==0.3.1
106 | - pyparsing==3.1.2
107 | - python-dateutil==2.9.0.post0
108 | - pytorch-lightning==1.4.2
109 | - pywavelets==1.4.1
110 | - pyyaml==6.0
111 | - requests-oauthlib==1.3.1
112 | - rsa==4.9
113 | - scikit-image==0.21.0
114 | - scipy==1.10.1
115 | - taming-transformers==0.0.1
116 | - taming-transformers-rom1504==0.0.6
117 | - tensorboard==2.14.0
118 | - tensorboard-data-server==0.7.2
119 | - tifffile==2023.7.10
120 | - torch-fidelity==0.3.0
121 | - torchmetrics==0.7.0
122 | - tqdm==4.64.1
123 | - werkzeug==3.0.1
124 | - yarl==1.9.4
125 | - zipp==3.17.0
126 | prefix: /home/kwonsm/.conda/envs/LDM
127 |
--------------------------------------------------------------------------------
/figures/hdr_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/figures/hdr_example.png
--------------------------------------------------------------------------------
/figures/hdr_short.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/figures/hdr_short.png
--------------------------------------------------------------------------------
/figures/manifold_diffstategrad.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/figures/manifold_diffstategrad.pdf
--------------------------------------------------------------------------------
/figures/manifold_diffstategrad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/figures/manifold_diffstategrad.png
--------------------------------------------------------------------------------
/figures/phase_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/figures/phase_example.png
--------------------------------------------------------------------------------
/ldm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/__init__.py
--------------------------------------------------------------------------------
/ldm/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/data/__init__.py
--------------------------------------------------------------------------------
/ldm/data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/data/__pycache__/lsun.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/data/__pycache__/lsun.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3 |
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | '''
7 | Define an interface to make the IterableDatasets for text2img data chainable
8 | '''
9 | def __init__(self, num_records=0, valid_ids=None, size=256):
10 | super().__init__()
11 | self.num_records = num_records
12 | self.valid_ids = valid_ids
13 | self.sample_ids = valid_ids
14 | self.size = size
15 |
16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17 |
18 | def __len__(self):
19 | return self.num_records
20 |
21 | @abstractmethod
22 | def __iter__(self):
23 | pass
--------------------------------------------------------------------------------
/ldm/data/lsun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | class LSUNBase(Dataset):
10 | def __init__(self,
11 | txt_file,
12 | data_root,
13 | size=None,
14 | interpolation="bicubic",
15 | flip_p=0.5
16 | ):
17 | self.data_paths = txt_file
18 | self.data_root = data_root
19 | with open(self.data_paths, "r") as f:
20 | self.image_paths = f.read().splitlines()
21 | self._length = len(self.image_paths)
22 | self.labels = {
23 | "relative_file_path_": [l for l in self.image_paths],
24 | "file_path_": [os.path.join(self.data_root, l)
25 | for l in self.image_paths],
26 | }
27 |
28 | self.size = size
29 | self.interpolation = {"linear": PIL.Image.LINEAR,
30 | "bilinear": PIL.Image.BILINEAR,
31 | "bicubic": PIL.Image.BICUBIC,
32 | "lanczos": PIL.Image.LANCZOS,
33 | }[interpolation]
34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35 |
36 | def __len__(self):
37 | return self._length
38 |
39 | def __getitem__(self, i):
40 | example = dict((k, self.labels[k][i]) for k in self.labels)
41 | image = Image.open(example["file_path_"])
42 | if not image.mode == "RGB":
43 | image = image.convert("RGB")
44 |
45 | # default to score-sde preprocessing
46 | img = np.array(image).astype(np.uint8)
47 | crop = min(img.shape[0], img.shape[1])
48 | h, w, = img.shape[0], img.shape[1]
49 | img = img[(h - crop) // 2:(h + crop) // 2,
50 | (w - crop) // 2:(w + crop) // 2]
51 |
52 | image = Image.fromarray(img)
53 | if self.size is not None:
54 | image = image.resize((self.size, self.size), resample=self.interpolation)
55 |
56 | image = self.flip(image)
57 | image = np.array(image).astype(np.uint8)
58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59 | return example
60 |
61 |
62 | class LSUNChurchesTrain(LSUNBase):
63 | def __init__(self, **kwargs):
64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65 |
66 |
67 | class LSUNChurchesValidation(LSUNBase):
68 | def __init__(self, flip_p=0., **kwargs):
69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70 | flip_p=flip_p, **kwargs)
71 |
72 |
73 | class LSUNBedroomsTrain(LSUNBase):
74 | def __init__(self, **kwargs):
75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76 |
77 |
78 | class LSUNBedroomsValidation(LSUNBase):
79 | def __init__(self, flip_p=0.0, **kwargs):
80 | super().__init__(txt_file="data/lsun/bedrooms_val_100_images.txt", data_root="data/lsun/bedrooms",
81 | flip_p=flip_p, **kwargs)
82 |
83 |
84 | class LSUNCatsTrain(LSUNBase):
85 | def __init__(self, **kwargs):
86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87 |
88 |
89 | class LSUNCatsValidation(LSUNBase):
90 | def __init__(self, flip_p=0., **kwargs):
91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92 | flip_p=flip_p, **kwargs)
93 |
--------------------------------------------------------------------------------
/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
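Both schedulers return a learning-rate multiplier rather than stepping an optimizer themselves; a minimal wiring sketch (hyperparameters below are illustrative, and the optimizer's base lr is set to 1.0 as the docstrings require) looks like:

    import torch

    model = torch.nn.Linear(4, 4)                            # stand-in module
    optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)

    schedule = LambdaWarmUpCosineScheduler(
        warm_up_steps=1000, lr_min=1e-6, lr_max=1e-4,
        lr_start=1e-6, max_decay_steps=100_000)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

    for step in range(10):                                   # training-loop stub
        optimizer.step()
        lr_scheduler.step()                                  # calls schedule(step) internally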
--------------------------------------------------------------------------------
/ldm/models/__pycache__/autoencoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/__pycache__/autoencoder.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import DPMSolverSampler
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 |
5 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
6 |
7 |
8 | class DPMSolverSampler(object):
9 | def __init__(self, model, **kwargs):
10 | super().__init__()
11 | self.model = model
12 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
13 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
14 |
15 | def register_buffer(self, name, attr):
16 | if type(attr) == torch.Tensor:
17 | if attr.device != torch.device("cuda"):
18 | attr = attr.to(torch.device("cuda"))
19 | setattr(self, name, attr)
20 |
21 | @torch.no_grad()
22 | def sample(self,
23 | S,
24 | batch_size,
25 | shape,
26 | conditioning=None,
27 | callback=None,
28 | normals_sequence=None,
29 | img_callback=None,
30 | quantize_x0=False,
31 | eta=0.,
32 | mask=None,
33 | x0=None,
34 | temperature=1.,
35 | noise_dropout=0.,
36 | score_corrector=None,
37 | corrector_kwargs=None,
38 | verbose=True,
39 | x_T=None,
40 | log_every_t=100,
41 | unconditional_guidance_scale=1.,
42 | unconditional_conditioning=None,
43 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
44 | **kwargs
45 | ):
46 | if conditioning is not None:
47 | if isinstance(conditioning, dict):
48 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
49 | if cbs != batch_size:
50 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
51 | else:
52 | if conditioning.shape[0] != batch_size:
53 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
54 |
55 | # sampling
56 | C, H, W = shape
57 | size = (batch_size, C, H, W)
58 |
59 | # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
60 |
61 | device = self.model.betas.device
62 | if x_T is None:
63 | img = torch.randn(size, device=device)
64 | else:
65 | img = x_T
66 |
67 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
68 |
69 | model_fn = model_wrapper(
70 | lambda x, t, c: self.model.apply_model(x, t, c),
71 | ns,
72 | model_type="noise",
73 | guidance_type="classifier-free",
74 | condition=conditioning,
75 | unconditional_condition=unconditional_conditioning,
76 | guidance_scale=unconditional_guidance_scale,
77 | )
78 |
79 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
80 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
81 |
82 | return x.to(device), None
83 |
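A hedged sketch of the call pattern (assumes `model` is an already-loaded LatentDiffusion instance and that `c` / `uc` are conditionings produced by its cond_stage_model; none of these are defined here):

    sampler = DPMSolverSampler(model)
    samples, _ = sampler.sample(
        S=20,                                 # number of solver steps
        batch_size=4,
        shape=(3, 64, 64),                    # latent shape (C, H, W)
        conditioning=c,
        unconditional_conditioning=uc,
        unconditional_guidance_scale=5.0,
    )
    images = model.decode_first_stage(samples)  # back to pixel space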
--------------------------------------------------------------------------------
/ldm/models/diffusion/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import math
4 | import numpy as np
5 |
6 | import torch
7 | from torch.utils.data import DataLoader
8 | import torchvision.utils as vutils
9 |
10 |
11 | def get_config(config):
12 | with open(config, 'r') as stream:
13 |         return yaml.load(stream, Loader=yaml.FullLoader)
14 |
15 | def prepare_sub_folder(output_directory):
16 | image_directory = os.path.join(output_directory, 'images')
17 | if not os.path.exists(image_directory):
18 | print("Creating directory: {}".format(image_directory))
19 | os.makedirs(image_directory)
20 | checkpoint_directory = os.path.join(output_directory, 'checkpoints')
21 | if not os.path.exists(checkpoint_directory):
22 | print("Creating directory: {}".format(checkpoint_directory))
23 | os.makedirs(checkpoint_directory)
24 | return checkpoint_directory, image_directory
25 |
26 |
27 |
28 |
29 | def save_image_3d(tensor, slice_idx, file_name):
30 | '''
31 | tensor: [bs, c, h, w, 1]
32 | '''
33 | image_num = len(slice_idx)
34 | tensor = tensor[0, slice_idx, ...].permute(0, 3, 1, 2).cpu().data # [c, 1, h, w]
35 | image_grid = vutils.make_grid(tensor, nrow=image_num, padding=0, normalize=True, scale_each=True)
36 | vutils.save_image(image_grid, file_name, nrow=1)
37 |
38 |
39 |
40 | def map_coordinates(input, coordinates):
41 | ''' PyTorch version of scipy.ndimage.interpolation.map_coordinates
42 | input: (B, H, W, C)
43 | coordinates: (2, ...)
44 | '''
45 | bs, h, w, c = input.size()
46 |
47 | def _coordinates_pad_wrap(h, w, coordinates):
48 | coordinates[0] = coordinates[0] % h
49 | coordinates[1] = coordinates[1] % w
50 | return coordinates
51 |
52 | co_floor = torch.floor(coordinates).long()
53 | co_ceil = torch.ceil(coordinates).long()
54 | d1 = (coordinates[1] - co_floor[1].float())
55 | d2 = (coordinates[0] - co_floor[0].float())
56 | co_floor = _coordinates_pad_wrap(h, w, co_floor)
57 | co_ceil = _coordinates_pad_wrap(h, w, co_ceil)
58 |
59 | f00 = input[:, co_floor[0], co_floor[1], :]
60 | f10 = input[:, co_floor[0], co_ceil[1], :]
61 | f01 = input[:, co_ceil[0], co_floor[1], :]
62 | f11 = input[:, co_ceil[0], co_ceil[1], :]
63 | d1 = d1[None, :, :, None].expand(bs, -1, -1, c)
64 | d2 = d2[None, :, :, None].expand(bs, -1, -1, c)
65 |
66 | fx1 = f00 + d1 * (f10 - f00)
67 | fx2 = f01 + d1 * (f11 - f01)
68 |
69 | return fx1 + d2 * (fx2 - fx1)
70 |
71 |
72 | def ct_parallel_project_2d(img, theta):
73 | bs, h, w, c = img.size()
74 |
75 | # (y, x)=(i, j): [0, w] -> [-0.5, 0.5]
76 | y, x = torch.meshgrid([torch.arange(h, dtype=torch.float32) / h - 0.5,
77 | torch.arange(w, dtype=torch.float32) / w - 0.5])
78 |
79 | # Rotation transform matrix: simulate parallel projection rays
80 | x_rot = x * torch.cos(theta) - y * torch.sin(theta)
81 | y_rot = x * torch.sin(theta) + y * torch.cos(theta)
82 |
83 | # Reverse back to index [0, w]
84 | x_rot = (x_rot + 0.5) * w
85 | y_rot = (y_rot + 0.5) * h
86 |
87 | # Resample (x, y) index of the pixel on the projection ray-theta
88 | sample_coords = torch.stack([y_rot, x_rot], dim=0).cuda() # [2, h, w]
89 | img_resampled = map_coordinates(img, sample_coords) # [b, h, w, c]
90 |
91 | # Compute integral projections along rays
92 | proj = torch.mean(img_resampled, dim=1, keepdim=True) # [b, 1, w, c]
93 |
94 | return proj
95 |
96 |
97 | def ct_parallel_project_2d_batch(img, thetas):
98 | '''
99 | img: input tensor [B, H, W, C]
100 | thetas: list of projection angles
101 | '''
102 | projs = []
103 | for theta in thetas:
104 | proj = ct_parallel_project_2d(img, theta)
105 | projs.append(proj)
106 | projs = torch.cat(projs, dim=1) # [b, num, w, c]
107 |
108 | return projs
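The projector maps a `[B, H, W, C]` image to one sinogram row per angle; a toy sketch (CUDA is assumed, since the sample coordinates are placed on the GPU inside `ct_parallel_project_2d`):

    import numpy as np
    import torch

    # Project a random 64x64 image along 25 parallel-beam angles.
    img = torch.rand(1, 64, 64, 1).cuda()                        # [B, H, W, C]
    angles = torch.tensor(np.linspace(0, np.pi, 25, endpoint=False))
    sinogram = ct_parallel_project_2d_batch(img, angles)         # [1, 25, 64, 1]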
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/attention.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/__pycache__/attention.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/ema.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/__pycache__/ema.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
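A small sketch of how the diagonal Gaussian is used: the parameters tensor packs mean and log-variance along the channel dimension, so an 8-channel input yields a 4-channel latent.

    import torch

    params = torch.randn(2, 8, 16, 16)            # mean and logvar stacked on dim=1
    posterior = DiagonalGaussianDistribution(params)

    z = posterior.sample()                        # [2, 4, 16, 16] reparameterized sample
    kl = posterior.kl()                           # [2] KL to a standard normal, summed over C, H, W
    mode = posterior.mode()                       # the mean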
--------------------------------------------------------------------------------
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
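A minimal lifecycle sketch: update the shadow weights after every optimizer step, then temporarily swap them in for evaluation and restore the raw weights afterwards.

    import torch

    model = torch.nn.Linear(8, 8)
    ema = LitEma(model)

    # after every optimizer step:
    ema(model)                                    # update shadow parameters

    # evaluate with EMA weights, then restore the training weights:
    ema.store(model.parameters())
    ema.copy_to(model)
    # ... run validation ...
    ema.restore(model.parameters())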
--------------------------------------------------------------------------------
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
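The loss expects a two-optimizer setup: `optimizer_idx == 0` returns the autoencoder/generator objective, `optimizer_idx == 1` the discriminator objective. A hedged sketch of the call pattern (the tensors `inputs`, `reconstructions`, `posterior`, the step counter and the decoder's last layer are assumed to exist; the constructor arguments are illustrative):

    loss_fn = LPIPSWithDiscriminator(disc_start=50001, kl_weight=1e-6, disc_weight=0.5)

    # autoencoder / generator pass
    ae_loss, log = loss_fn(inputs, reconstructions, posterior,
                           optimizer_idx=0, global_step=step,
                           last_layer=last_decoder_layer, split="train")

    # discriminator pass
    d_loss, d_log = loss_fn(inputs, reconstructions, posterior,
                            optimizer_idx=1, global_step=step, split="train")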
--------------------------------------------------------------------------------
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new("RGB", wh, color="white")
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26 | nc = int(40 * (wh[0] / 256))
27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 |             print("Can't encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config):
79 | if not "target" in config:
80 | if config == '__is_first_stage__':
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
86 |
87 |
88 | def get_obj_from_str(string, reload=False):
89 | module, cls = string.rsplit(".", 1)
90 | if reload:
91 | module_imp = importlib.import_module(module)
92 | importlib.reload(module_imp)
93 | return getattr(importlib.import_module(module, package=None), cls)
94 |
95 |
96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97 | # create dummy dataset instance
98 |
99 | # run prefetching
100 | if idx_to_fn:
101 | res = func(data, worker_id=idx)
102 | else:
103 | res = func(data)
104 | Q.put([idx, res])
105 | Q.put("Done")
106 |
107 |
108 | def parallel_data_prefetch(
109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110 | ):
111 | # if target_data_type not in ["ndarray", "list"]:
112 | # raise ValueError(
113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114 | # )
115 | if isinstance(data, np.ndarray) and target_data_type == "list":
116 | raise ValueError("list expected but function got ndarray.")
117 | elif isinstance(data, abc.Iterable):
118 | if isinstance(data, dict):
119 | print(
120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
121 | )
122 | data = list(data.values())
123 | if target_data_type == "ndarray":
124 | data = np.asarray(data)
125 | else:
126 | data = list(data)
127 | else:
128 | raise TypeError(
129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130 | )
131 |
132 | if cpu_intensive:
133 | Q = mp.Queue(1000)
134 | proc = mp.Process
135 | else:
136 | Q = Queue(1000)
137 | proc = Thread
138 | # spawn processes
139 | if target_data_type == "ndarray":
140 | arguments = [
141 | [func, Q, part, i, use_worker_id]
142 | for i, part in enumerate(np.array_split(data, n_proc))
143 | ]
144 | else:
145 | step = (
146 | int(len(data) / n_proc + 1)
147 | if len(data) % n_proc != 0
148 | else int(len(data) / n_proc)
149 | )
150 | arguments = [
151 | [func, Q, part, i, use_worker_id]
152 | for i, part in enumerate(
153 | [data[i: i + step] for i in range(0, len(data), step)]
154 | )
155 | ]
156 | processes = []
157 | for i in range(n_proc):
158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159 | processes += [p]
160 |
161 | # start processes
162 | print(f"Start prefetching...")
163 | import time
164 |
165 | start = time.time()
166 | gather_res = [[] for _ in range(n_proc)]
167 | try:
168 | for p in processes:
169 | p.start()
170 |
171 | k = 0
172 | while k < n_proc:
173 | # get result
174 | res = Q.get()
175 | if res == "Done":
176 | k += 1
177 | else:
178 | gather_res[res[0]] = res[1]
179 |
180 | except Exception as e:
181 | print("Exception: ", e)
182 | for p in processes:
183 | p.terminate()
184 |
185 | raise e
186 | finally:
187 | for p in processes:
188 | p.join()
189 | print(f"Prefetching complete. [{time.time() - start} sec.]")
190 |
191 | if target_data_type == 'ndarray':
192 | if not isinstance(gather_res[0], np.ndarray):
193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194 |
195 | # order outputs
196 | return np.concatenate(gather_res, axis=0)
197 | elif target_data_type == 'list':
198 | out = []
199 | for r in gather_res:
200 | out.extend(r)
201 | return out
202 | else:
203 | return gather_res
204 |
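`instantiate_from_config` is the glue behind every yaml config in this repo: `target` is a dotted import path and `params` become constructor keyword arguments. A minimal self-contained illustration:

    config = {
        "target": "torch.nn.Conv2d",
        "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3},
    }
    conv = instantiate_from_config(config)    # equivalent to torch.nn.Conv2d(3, 8, 3)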
--------------------------------------------------------------------------------
/ldm_inverse/__pycache__/condition_methods.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm_inverse/__pycache__/condition_methods.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm_inverse/__pycache__/measurements.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm_inverse/__pycache__/measurements.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm_inverse/condition_methods.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import torch
3 |
4 | __CONDITIONING_METHOD__ = {}
5 |
6 | def register_conditioning_method(name: str):
7 | def wrapper(cls):
8 | if __CONDITIONING_METHOD__.get(name, None):
9 | raise NameError(f"Name {name} is already registered!")
10 | __CONDITIONING_METHOD__[name] = cls
11 | return cls
12 | return wrapper
13 |
14 | def get_conditioning_method(name: str, model, operator, noiser, **kwargs):
15 | if __CONDITIONING_METHOD__.get(name, None) is None:
16 | raise NameError(f"Name {name} is not defined!")
17 | return __CONDITIONING_METHOD__[name](model=model, operator=operator, noiser=noiser, **kwargs)
18 |
19 |
20 | class ConditioningMethod(ABC):
21 | def __init__(self, model, operator, noiser, **kwargs):
22 | self.model = model
23 | self.operator = operator
24 | self.noiser = noiser
25 |
26 | def project(self, data, noisy_measurement, **kwargs):
27 | return self.operator.project(data=data, measurement=noisy_measurement, **kwargs)
28 |
29 | def grad_and_value(self, x_prev, x_0_hat, measurement, **kwargs):
30 | if self.noiser.__name__ == 'gaussian':
31 | difference = measurement - self.operator.forward(self.model.differentiable_decode_first_stage( x_0_hat ), **kwargs)
32 | norm = torch.linalg.norm(difference)
33 | norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
34 | elif self.noiser.__name__ == 'poisson':
35 | Ax = self.operator.forward(self.model.differentiable_decode_first_stage(x_0_hat), **kwargs)
36 | difference = measurement-Ax
37 | norm = torch.linalg.norm(difference) / measurement.abs()
38 | norm = norm.mean()
39 | norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
40 |
41 | else:
42 | raise NotImplementedError
43 |
44 | return norm_grad, norm
45 |
46 |
47 | @abstractmethod
48 | def conditioning(self, x_t, measurement, noisy_measurement=None, **kwargs):
49 | pass
50 |
51 |
52 | @register_conditioning_method(name='ps')
53 | class PosteriorSampling(ConditioningMethod):
54 | def __init__(self, model, operator, noiser, **kwargs):
55 | super().__init__(model, operator, noiser)
56 | self.operator = operator
57 |
58 | def conditioning(self, x_prev, x_t, x_0_hat, measurement, scale=None, **kwargs):
59 | if scale is None:
60 | scale = 0.3
61 |
62 | norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat, measurement=measurement, **kwargs)
63 | x_t -= norm_grad * scale
64 | return x_t, norm
65 |
66 |
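A sketch of how the registry is meant to be used (the `model`, `operator` and `noiser` objects are assumed to come from model_loader.py and ldm_inverse/measurements.py via a task config; they are not defined here):

    cond_method = get_conditioning_method('ps', model, operator, noiser)

    # one guidance step inside the sampling loop:
    x_t, norm = cond_method.conditioning(
        x_prev=x_prev, x_t=x_t, x_0_hat=x_0_hat,
        measurement=y, scale=0.3)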
--------------------------------------------------------------------------------
/model_loader.py:
--------------------------------------------------------------------------------
1 | from ldm.util import instantiate_from_config
2 | import yaml
3 | import torch
4 |
5 | def load_yaml(file_path: str) -> dict:
6 | with open(file_path) as f:
7 | config = yaml.load(f, Loader=yaml.FullLoader)
8 | return config
9 |
10 |
11 | def load_model_from_config(config, ckpt, train=False):
12 | print(f"Loading model from {ckpt}")
13 | pl_sd = torch.load(ckpt)#, map_location="cpu")
14 | sd = pl_sd["state_dict"]
15 | model = instantiate_from_config(config.model)
16 | _, _ = model.load_state_dict(sd, strict=False)
17 |
18 | model.cuda()
19 |
20 | if train:
21 | model.train()
22 | else:
23 | model.eval()
24 |
25 | return model
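Typical usage pairs an OmegaConf config (so that `config.model` attribute access works) with a downloaded checkpoint; the checkpoint path below is illustrative.

    from omegaconf import OmegaConf

    config = OmegaConf.load("configs/latent-diffusion/ffhq-ldm-vq-4.yaml")
    model = load_model_from_config(config, "models/ldm/ffhq256/model.ckpt")  # example path
    model.eval()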
--------------------------------------------------------------------------------
/samples/60004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/samples/60004.png
--------------------------------------------------------------------------------
/samples/60074.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/samples/60074.png
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/scripts/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/scripts/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/scripts/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/scripts/download_first_stages.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Usage:
4 | # ./scripts/download_first_stages.sh kl-f4 kl-f8
5 |
6 | MODELS=("kl-f4" "kl-f8" "kl-f16" "kl-f32" "vq-f4" "vq-f4-noattn" "vq-f8" "vq-f8-n256" "vq-f16")
7 | DOWNLOAD_PATH="https://ommer-lab.com/files/latent-diffusion"
8 |
9 | function download_first_stages() {
10 | local list=("$@")
11 |
12 | for arg in "${list[@]}"; do
13 | for model in "${MODELS[@]}"; do
14 | if [[ "$model" == "$arg" ]]; then
15 | echo "Downloading $model"
16 | model_dir="./models/first_stage_models/$arg"
17 | if [ ! -d "$model_dir" ]; then
18 | mkdir -p "$model_dir"
19 | echo "Directory created: $model_dir"
20 | else
21 | echo "Directory already exists: $model_dir"
22 | fi
23 | wget -O "$model_dir/model.zip" "$DOWNLOAD_PATH/$arg.zip"
24 | unzip -o "$model_dir/model.zip" -d "$model_dir"
25 | rm -rf "$model_dir/model.zip"
26 | fi
27 | done
28 | done
29 | }
30 |
31 | download_first_stages "$@"
32 |
--------------------------------------------------------------------------------
/scripts/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Usage:
4 | # ./scripts/download_models.sh celeba ffhq
5 |
6 | MODELS=("celeba" "ffhq" "lsun_churches" "lsun_bedrooms" "text2img" "cin" "semantic_synthesis" "semantic_synthesis256" "sr_bsr" "layout2img_model" "inpainting_big")
7 | DOWNLOAD_PATH="https://ommer-lab.com/files/latent-diffusion"
8 |
9 | function download_models() {
10 | local list=("$@")
11 |
12 | for arg in "${list[@]}"; do
13 | for model in "${MODELS[@]}"; do
14 | if [[ "$model" == "$arg" ]]; then
15 | echo "Downloading $model"
16 | model_dir="./models/ldm/$arg"
17 | if [ ! -d "$model_dir" ]; then
18 | mkdir -p "$model_dir"
19 | echo "Directory created: $model_dir"
20 | else
21 | echo "Directory already exists: $model_dir"
22 | fi
23 | wget -O "$model_dir/model.zip" "$DOWNLOAD_PATH/$arg.zip"
24 | unzip -o "$model_dir/model.zip" -d "$model_dir"
25 | rm -rf "$model_dir/model.zip"
26 | fi
27 | done
28 | done
29 | }
30 |
31 | download_models "$@"
32 |
--------------------------------------------------------------------------------
/scripts/inpaint.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob
2 | from omegaconf import OmegaConf
3 | from PIL import Image
4 | from tqdm import tqdm
5 | import numpy as np
6 | import torch
7 | from main import instantiate_from_config
8 | from ldm.models.diffusion.ddim import DDIMSampler
9 |
10 |
11 | def make_batch(image, mask, device):
12 | image = np.array(Image.open(image).convert("RGB"))
13 | image = image.astype(np.float32)/255.0
14 | image = image[None].transpose(0,3,1,2)
15 | image = torch.from_numpy(image)
16 |
17 | mask = np.array(Image.open(mask).convert("L"))
18 | mask = mask.astype(np.float32)/255.0
19 | mask = mask[None,None]
20 | mask[mask < 0.5] = 0
21 | mask[mask >= 0.5] = 1
22 | mask = torch.from_numpy(mask)
23 |
24 | masked_image = (1-mask)*image
25 |
26 | batch = {"image": image, "mask": mask, "masked_image": masked_image}
27 | for k in batch:
28 | batch[k] = batch[k].to(device=device)
29 | batch[k] = batch[k]*2.0-1.0
30 | return batch
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument(
36 | "--indir",
37 | type=str,
38 | nargs="?",
39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
40 | )
41 | parser.add_argument(
42 | "--outdir",
43 | type=str,
44 | nargs="?",
45 | help="dir to write results to",
46 | )
47 | parser.add_argument(
48 | "--steps",
49 | type=int,
50 | default=50,
51 | help="number of ddim sampling steps",
52 | )
53 | opt = parser.parse_args()
54 |
55 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
56 | images = [x.replace("_mask.png", ".png") for x in masks]
57 | print(f"Found {len(masks)} inputs.")
58 |
59 | config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
60 | model = instantiate_from_config(config.model)
61 | model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
62 | strict=False)
63 |
64 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
65 | model = model.to(device)
66 | sampler = DDIMSampler(model)
67 |
68 | os.makedirs(opt.outdir, exist_ok=True)
69 | with torch.no_grad():
70 | with model.ema_scope():
71 | for image, mask in tqdm(zip(images, masks)):
72 | outpath = os.path.join(opt.outdir, os.path.split(image)[1])
73 | batch = make_batch(image, mask, device=device)
74 |
75 | # encode masked image and concat downsampled mask
76 | c = model.cond_stage_model.encode(batch["masked_image"])
77 | cc = torch.nn.functional.interpolate(batch["mask"],
78 | size=c.shape[-2:])
79 | c = torch.cat((c, cc), dim=1)
80 |
81 | shape = (c.shape[1]-1,)+c.shape[2:]
82 | samples_ddim, _ = sampler.sample(S=opt.steps,
83 | conditioning=c,
84 | batch_size=c.shape[0],
85 | shape=shape,
86 | verbose=False)
87 | x_samples_ddim = model.decode_first_stage(samples_ddim)
88 |
89 | image = torch.clamp((batch["image"]+1.0)/2.0,
90 | min=0.0, max=1.0)
91 | mask = torch.clamp((batch["mask"]+1.0)/2.0,
92 | min=0.0, max=1.0)
93 | predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
94 | min=0.0, max=1.0)
95 |
96 | inpainted = (1-mask)*image+mask*predicted_image
97 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
98 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
99 |
--------------------------------------------------------------------------------
/scripts/inv.py:
--------------------------------------------------------------------------------
1 | from taming.models import vqgan
2 | from ldm.models.diffusion.ddim import DDIMSampler
3 | #@title loading utils
4 | import torch
5 | from omegaconf import OmegaConf
6 | from ldm.util import instantiate_from_config
7 | import numpy as np
8 | from PIL import Image
9 | from einops import rearrange
10 | from torchvision.utils import make_grid
11 | from einops import rearrange, repeat
12 | from utils import *
13 | import PIL
14 | import argparse
15 |
16 | def load_model_from_config(config, ckpt, train=False):
17 | print(f"Loading model from {ckpt}")
18 | pl_sd = torch.load(ckpt)#, map_location="cpu")
19 | sd = pl_sd["state_dict"]
20 | model = instantiate_from_config(config.model)
21 | m, u = model.load_state_dict(sd, strict=False)
22 |
23 | model.cuda()
24 | # model.train()
25 | if train:
26 | model.train()
27 | else:
28 | model.eval()
29 | return model
30 |
31 |
32 | def get_model(model_type = None):
33 | if model_type is None:
34 | config = OmegaConf.load("configs/latent-diffusion/celebahq-ldm-vq-4.yaml")
35 | model = load_model_from_config(config, "models/ldm/celeba256/model.ckpt")
36 | elif model_type == "celeb":
37 | config = OmegaConf.load("configs/latent-diffusion/celebahq-ldm-vq-4.yaml")
38 | model = load_model_from_config(config, "models/ldm/celeba256/model.ckpt")
39 | else:
40 | model = None
41 | return model
42 |
43 |
44 |
45 | def load_img(path):
46 | image = Image.open(path).convert("RGB")
47 | w, h = image.size
48 | print(f"loaded input image of size ({w}, {h}) from {path}")
49 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
50 | image = image.resize((w, h), resample=PIL.Image.LANCZOS)
51 | image = np.array(image).astype(np.float32) / 255.0
52 | image = image[None].transpose(0, 3, 1, 2)
53 | image = torch.from_numpy(image)
54 | return 2.*image - 1.
55 |
56 |
57 |
58 | device = torch.device("cuda")
59 | model = get_model()
60 | model.learning_rate = 5e-3
61 | ddim_steps = 100
62 | num_timesteps = 1000
63 | shape=[1, 3, 64, 64]
64 | z = torch.randn(shape, device=device)
65 | z.requires_grad = True
66 | init_image = load_img("/content/drive/MyDrive/stable-diffusion/ldct/test_recon.png").to(device)
67 | init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))
68 | opt = torch.optim.AdamW([z], lr = model.learning_rate)
69 | loss = torch.nn.MSELoss()
70 | angles = torch.tensor(np.linspace(0, np.pi, 25, endpoint=False))
71 | sampler = DDIMSampler(model)
72 | sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0, verbose = False)
73 | projection_orig = ct_parallel_project_2d_batch(init_image.permute(0,2,3,1), angles)
74 |
75 |
76 |
77 | for i in range(2000):
78 | opt.zero_grad()
79 | decoded_z = sampler.ddecode(z, t_start = 100, temp = 0)
80 | decoded_img = model.differentiable_decode_first_stage(decoded_z)
81 | # decoded_img = model.differentiable_decode_first_stage(z)
82 | projection_recon = ct_parallel_project_2d_batch(decoded_img.permute(0,2,3,1), angles)
83 | output = loss(projection_orig, projection_recon)
84 | # output = loss(init_image, decoded_img)
85 | output.backward()
86 | opt.step()
87 | loss_val = output.detach().cpu().numpy()
88 | if i % 5 == 0:
89 | print(loss_val, "loss ", str(i), " iter")
90 | if loss_val < 6.5e-5:
91 | break
92 |
93 | # sampler.make_schedule(ddim_num_steps=50, ddim_eta=0, verbose = False)
94 |
95 |
96 | print("start encoding decoding")
97 | # decoded_z = sampler.decode(z, cond = None, t_start = 499)
98 | t = repeat(torch.tensor([250]), '1 -> b', b=1)
99 | t = t.to(device).long()
100 | # np.save("CSGM_nopnp_500iter.npy", model.decode_first_stage(z).detach().cpu().numpy())
101 | np.save("CSGM_recon_2000iter_100_tstep_norandomness.npy", model.decode_first_stage(sampler.ddecode(z,cond=None,t_start=100, temp = 0)).detach().cpu().numpy())
102 | # print(z)
103 | # encoded = sampler.stochastic_encode(z, t)
104 | # decoded = sampler.decode(encoded, None, 250)
105 | # np.save("PnP_CSGM_recon500iter.npy", model.decode_first_stage(decoded).detach().cpu().numpy())
106 |
107 |
--------------------------------------------------------------------------------
/scripts/tests/test_watermark.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import fire
3 | from imwatermark import WatermarkDecoder
4 |
5 |
6 | def testit(img_path):
7 | bgr = cv2.imread(img_path)
8 | decoder = WatermarkDecoder('bytes', 136)
9 | watermark = decoder.decode(bgr, 'dwtDct')
10 | try:
11 | dec = watermark.decode('utf-8')
12 | except:
13 | dec = "null"
14 | print(dec)
15 |
16 |
17 | if __name__ == "__main__":
18 | fire.Fire(testit)
--------------------------------------------------------------------------------
/scripts/train_searcher.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | import scann
4 | import argparse
5 | import glob
6 | from multiprocessing import cpu_count
7 | from tqdm import tqdm
8 |
9 | from ldm.util import parallel_data_prefetch
10 |
11 |
12 | def search_bruteforce(searcher):
13 | return searcher.score_brute_force().build()
14 |
15 |
16 | def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
17 | partioning_trainsize, num_leaves, num_leaves_to_search):
18 | return searcher.tree(num_leaves=num_leaves,
19 | num_leaves_to_search=num_leaves_to_search,
20 | training_sample_size=partioning_trainsize). \
21 | score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
22 |
23 |
24 | def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
25 | return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
26 | reorder_k).build()
27 |
28 | def load_datapool(dpath):
29 |
30 |
31 | def load_single_file(saved_embeddings):
32 | compressed = np.load(saved_embeddings)
33 | database = {key: compressed[key] for key in compressed.files}
34 | return database
35 |
36 | def load_multi_files(data_archive):
37 | database = {key: [] for key in data_archive[0].files}
38 | for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
39 | for key in d.files:
40 | database[key].append(d[key])
41 |
42 | return database
43 |
44 | print(f'Load saved patch embedding from "{dpath}"')
45 | file_content = glob.glob(os.path.join(dpath, '*.npz'))
46 |
47 | if len(file_content) == 1:
48 | data_pool = load_single_file(file_content[0])
49 | elif len(file_content) > 1:
50 | data = [np.load(f) for f in file_content]
51 | prefetched_data = parallel_data_prefetch(load_multi_files, data,
52 | n_proc=min(len(data), cpu_count()), target_data_type='dict')
53 |
54 | data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
55 | else:
56 |         raise ValueError(f'No npz files found in the specified path "{dpath}". Does this directory exist?')
57 |
58 | print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
59 | return data_pool
60 |
61 |
62 | def train_searcher(opt,
63 | metric='dot_product',
64 | partioning_trainsize=None,
65 | reorder_k=None,
66 | # todo tune
67 | aiq_thld=0.2,
68 | dims_per_block=2,
69 | num_leaves=None,
70 | num_leaves_to_search=None,):
71 |
72 | data_pool = load_datapool(opt.database)
73 | k = opt.knn
74 |
75 | if not reorder_k:
76 | reorder_k = 2 * k
77 |
78 | # normalize
79 | # embeddings =
80 | searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
81 | pool_size = data_pool['embedding'].shape[0]
82 |
83 | print(*(['#'] * 100))
84 | print('Initializing scaNN searcher with the following values:')
85 | print(f'k: {k}')
86 | print(f'metric: {metric}')
87 | print(f'reorder_k: {reorder_k}')
88 | print(f'anisotropic_quantization_threshold: {aiq_thld}')
89 | print(f'dims_per_block: {dims_per_block}')
90 | print(*(['#'] * 100))
91 | print('Start training searcher....')
92 | print(f'N samples in pool is {pool_size}')
93 |
94 | # this reflects the recommended design choices proposed at
95 | # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
96 | if pool_size < 2e4:
97 | print('Using brute force search.')
98 | searcher = search_bruteforce(searcher)
99 | elif 2e4 <= pool_size and pool_size < 1e5:
100 | print('Using asymmetric hashing search and reordering.')
101 | searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
102 | else:
103 |         print('Using partitioning, asymmetric hashing search and reordering.')
104 |
105 | if not partioning_trainsize:
106 | partioning_trainsize = data_pool['embedding'].shape[0] // 10
107 | if not num_leaves:
108 | num_leaves = int(np.sqrt(pool_size))
109 |
110 | if not num_leaves_to_search:
111 | num_leaves_to_search = max(num_leaves // 20, 1)
112 |
113 | print('Partitioning params:')
114 | print(f'num_leaves: {num_leaves}')
115 | print(f'num_leaves_to_search: {num_leaves_to_search}')
116 | # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
117 | searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
118 | partioning_trainsize, num_leaves, num_leaves_to_search)
119 |
120 | print('Finish training searcher')
121 | searcher_savedir = opt.target_path
122 | os.makedirs(searcher_savedir, exist_ok=True)
123 | searcher.serialize(searcher_savedir)
124 | print(f'Saved trained searcher under "{searcher_savedir}"')
125 |
126 | if __name__ == '__main__':
127 | sys.path.append(os.getcwd())
128 | parser = argparse.ArgumentParser()
129 | parser.add_argument('--database',
130 | '-d',
131 | default='data/rdm/retrieval_databases/openimages',
132 | type=str,
133 | help='path to folder containing the clip feature of the database')
134 | parser.add_argument('--target_path',
135 | '-t',
136 | default='data/rdm/searchers/openimages',
137 | type=str,
138 | help='path to the target folder where the searcher shall be stored.')
139 | parser.add_argument('--knn',
140 | '-k',
141 | default=20,
142 | type=int,
143 | help='number of nearest neighbors, for which the searcher shall be optimized')
144 |
145 | opt, _ = parser.parse_known_args()
146 |
147 | train_searcher(opt,)
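Besides the CLI, the searcher can be trained programmatically with a Namespace mirroring the arguments above (a sketch; the paths match the CLI defaults and assume the embedding database exists locally):

    from argparse import Namespace

    opt = Namespace(database='data/rdm/retrieval_databases/openimages',
                    target_path='data/rdm/searchers/openimages',
                    knn=20)
    train_searcher(opt)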
--------------------------------------------------------------------------------
/src/clip/clip.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: clip
3 | Version: 1.0
4 | Author: OpenAI
5 | Provides-Extra: dev
6 | License-File: LICENSE
7 |
--------------------------------------------------------------------------------
/src/clip/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/src/clip/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 |             return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
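
A minimal usage sketch for the tokenizer above, assuming the bundled bpe_simple_vocab_16e6.txt.gz sits next to this module and the clip package is importable:

from clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()                 # loads the default BPE merge file
ids = tokenizer.encode("a diagram of a cat")  # byte-level BPE ids; no <|startoftext|>/<|endoftext|> added here
text = tokenizer.decode(ids)                  # the '</w>' end-of-word markers decode back to spaces
assert text.strip() == "a diagram of a cat"
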
--------------------------------------------------------------------------------
/src/clip/hubconf.py:
--------------------------------------------------------------------------------
1 | from clip.clip import tokenize as _tokenize, load as _load, available_models as _available_models
2 | import re
3 | import string
4 |
5 | dependencies = ["torch", "torchvision", "ftfy", "regex", "tqdm"]
6 |
7 | # For compatibility (cannot include special characters in function name)
8 | model_functions = { model: re.sub(f'[{string.punctuation}]', '_', model) for model in _available_models()}
9 |
10 | def _create_hub_entrypoint(model):
11 | def entrypoint(**kwargs):
12 | return _load(model, **kwargs)
13 |
14 | entrypoint.__doc__ = f"""Loads the {model} CLIP model
15 |
16 | Parameters
17 | ----------
18 | device : Union[str, torch.device]
19 | The device to put the loaded model
20 |
21 | jit : bool
22 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
23 |
24 | download_root: str
25 | path to download the model files; by default, it uses "~/.cache/clip"
26 |
27 | Returns
28 | -------
29 | model : torch.nn.Module
30 | The {model} CLIP model
31 |
32 | preprocess : Callable[[PIL.Image], torch.Tensor]
33 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
34 | """
35 | return entrypoint
36 |
37 | def tokenize():
38 | return _tokenize
39 |
40 | _entrypoints = {model_functions[model]: _create_hub_entrypoint(model) for model in _available_models()}
41 |
42 | globals().update(_entrypoints)
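
The punctuation-to-underscore mapping above turns CLIP model names into valid hub entrypoint names; a quick check, using "ViT-B/32" as an example model name:

import re
import string

# "ViT-B/32" -> "ViT_B_32": '-' and '/' are replaced so the name is a legal identifier
assert re.sub(f'[{string.punctuation}]', '_', "ViT-B/32") == "ViT_B_32"
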
--------------------------------------------------------------------------------
/src/clip/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pkg_resources
4 | from setuptools import setup, find_packages
5 |
6 | setup(
7 | name="clip",
8 | py_modules=["clip"],
9 | version="1.0",
10 | description="",
11 | author="OpenAI",
12 | packages=find_packages(exclude=["tests*"]),
13 | install_requires=[
14 | str(r)
15 | for r in pkg_resources.parse_requirements(
16 | open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
17 | )
18 | ],
19 | include_package_data=True,
20 | extras_require={'dev': ['pytest']},
21 | )
22 |
--------------------------------------------------------------------------------
/src/clip/tests/test_consistency.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 | from PIL import Image
5 |
6 | import clip
7 |
8 |
9 | @pytest.mark.parametrize('model_name', clip.available_models())
10 | def test_consistency(model_name):
11 | device = "cpu"
12 | jit_model, transform = clip.load(model_name, device=device, jit=True)
13 | py_model, _ = clip.load(model_name, device=device, jit=False)
14 |
15 | image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device)
16 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
17 |
18 | with torch.no_grad():
19 | logits_per_image, _ = jit_model(image, text)
20 | jit_probs = logits_per_image.softmax(dim=-1).cpu().numpy()
21 |
22 | logits_per_image, _ = py_model(image, text)
23 | py_probs = logits_per_image.softmax(dim=-1).cpu().numpy()
24 |
25 | assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1)
26 |
--------------------------------------------------------------------------------
/src/taming-transformers/scripts/extract_depth.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from tqdm import trange
5 | from PIL import Image
6 |
7 |
8 | def get_state(gpu):
9 | import torch
10 | midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
11 | if gpu:
12 | midas.cuda()
13 | midas.eval()
14 |
15 | midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
16 | transform = midas_transforms.default_transform
17 |
18 | state = {"model": midas,
19 | "transform": transform}
20 | return state
21 |
22 |
23 | def depth_to_rgba(x):
24 | assert x.dtype == np.float32
25 | assert len(x.shape) == 2
26 | y = x.copy()
27 | y.dtype = np.uint8
28 | y = y.reshape(x.shape+(4,))
29 | return np.ascontiguousarray(y)
30 |
31 |
32 | def rgba_to_depth(x):
33 | assert x.dtype == np.uint8
34 | assert len(x.shape) == 3 and x.shape[2] == 4
35 | y = x.copy()
36 | y.dtype = np.float32
37 | y = y.reshape(x.shape[:2])
38 | return np.ascontiguousarray(y)
39 |
40 |
41 | def run(x, state):
42 | model = state["model"]
43 | transform = state["transform"]
44 | hw = x.shape[:2]
45 | with torch.no_grad():
46 | prediction = model(transform((x + 1.0) * 127.5).cuda())
47 | prediction = torch.nn.functional.interpolate(
48 | prediction.unsqueeze(1),
49 | size=hw,
50 | mode="bicubic",
51 | align_corners=False,
52 | ).squeeze()
53 | output = prediction.cpu().numpy()
54 | return output
55 |
56 |
57 | def get_filename(relpath, level=-2):
58 | # save class folder structure and filename:
59 | fn = relpath.split(os.sep)[level:]
60 | folder = fn[-2]
61 | file = fn[-1].split('.')[0]
62 | return folder, file
63 |
64 |
65 | def save_depth(dataset, path, debug=False):
66 | os.makedirs(path)
67 |     N = len(dataset)
68 | if debug:
69 | N = 10
70 | state = get_state(gpu=True)
71 | for idx in trange(N, desc="Data"):
72 | ex = dataset[idx]
73 | image, relpath = ex["image"], ex["relpath"]
74 | folder, filename = get_filename(relpath)
75 | # prepare
76 | folderabspath = os.path.join(path, folder)
77 | os.makedirs(folderabspath, exist_ok=True)
78 | savepath = os.path.join(folderabspath, filename)
79 | # run model
80 | xout = run(image, state)
81 | I = depth_to_rgba(xout)
82 | Image.fromarray(I).save("{}.png".format(savepath))
83 |
84 |
85 | if __name__ == "__main__":
86 | from taming.data.imagenet import ImageNetTrain, ImageNetValidation
87 | out = "data/imagenet_depth"
88 | if not os.path.exists(out):
89 | print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
90 | "(be prepared that the output size will be larger than ImageNet itself).")
91 | exit(1)
92 |
93 | # go
94 | dset = ImageNetValidation()
95 | abspath = os.path.join(out, "val")
96 | if os.path.exists(abspath):
97 | print("{} exists - not doing anything.".format(abspath))
98 | else:
99 | print("preparing {}".format(abspath))
100 | save_depth(dset, abspath)
101 | print("done with validation split")
102 |
103 | dset = ImageNetTrain()
104 | abspath = os.path.join(out, "train")
105 | if os.path.exists(abspath):
106 | print("{} exists - not doing anything.".format(abspath))
107 | else:
108 | print("preparing {}".format(abspath))
109 | save_depth(dset, abspath)
110 | print("done with train split")
111 |
112 | print("done done.")
113 |
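
The two helpers above reinterpret the float32 depth buffer bit-for-bit as an H x W x 4 uint8 image and back, so saving it as a PNG is lossless. A small sketch, assuming depth_to_rgba and rgba_to_depth from this script are in scope:

import numpy as np

depth = np.random.rand(4, 4).astype(np.float32)   # toy depth map
rgba = depth_to_rgba(depth)                       # (4, 4, 4) uint8 view of the same bytes
restored = rgba_to_depth(rgba)                    # (4, 4) float32 again
assert np.array_equal(restored, depth)            # exact round trip, not approximate
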
--------------------------------------------------------------------------------
/src/taming-transformers/scripts/extract_segmentation.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | import numpy as np
3 | import scipy
4 | import torch
5 | import torch.nn as nn
6 | from scipy import ndimage
7 | from tqdm import tqdm, trange
8 | from PIL import Image
9 | import torch.hub
10 | import torchvision
11 | import torch.nn.functional as F
12 |
13 | # download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from
14 | # https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth
15 | # and put the path here
16 | CKPT_PATH = "TODO"
17 |
18 | rescale = lambda x: (x + 1.) / 2.
19 |
20 | def rescale_bgr(x):
21 | x = (x+1)*127.5
22 | x = torch.flip(x, dims=[0])
23 | return x
24 |
25 |
26 | class COCOStuffSegmenter(nn.Module):
27 | def __init__(self, config):
28 | super().__init__()
29 | self.config = config
30 | self.n_labels = 182
31 | model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels)
32 | ckpt_path = CKPT_PATH
33 | model.load_state_dict(torch.load(ckpt_path))
34 | self.model = model
35 |
36 | normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
37 | self.image_transform = torchvision.transforms.Compose([
38 | torchvision.transforms.Lambda(lambda image: torch.stack(
39 | [normalize(rescale_bgr(x)) for x in image]))
40 | ])
41 |
42 | def forward(self, x, upsample=None):
43 | x = self._pre_process(x)
44 | x = self.model(x)
45 | if upsample is not None:
46 | x = torch.nn.functional.upsample_bilinear(x, size=upsample)
47 | return x
48 |
49 | def _pre_process(self, x):
50 | x = self.image_transform(x)
51 | return x
52 |
53 | @property
54 | def mean(self):
55 | # bgr
56 | return [104.008, 116.669, 122.675]
57 |
58 | @property
59 | def std(self):
60 | return [1.0, 1.0, 1.0]
61 |
62 | @property
63 | def input_size(self):
64 | return [3, 224, 224]
65 |
66 |
67 | def run_model(img, model):
68 | model = model.eval()
69 | with torch.no_grad():
70 | segmentation = model(img, upsample=(img.shape[2], img.shape[3]))
71 | segmentation = torch.argmax(segmentation, dim=1, keepdim=True)
72 | return segmentation.detach().cpu()
73 |
74 |
75 | def get_input(batch, k):
76 | x = batch[k]
77 | if len(x.shape) == 3:
78 | x = x[..., None]
79 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
80 | return x.float()
81 |
82 |
83 | def save_segmentation(segmentation, path):
84 | # --> class label to uint8, save as png
85 | os.makedirs(os.path.dirname(path), exist_ok=True)
86 | assert len(segmentation.shape)==4
87 | assert segmentation.shape[0]==1
88 | for seg in segmentation:
89 | seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8)
90 | seg = Image.fromarray(seg)
91 | seg.save(path)
92 |
93 |
94 | def iterate_dataset(dataloader, destpath, model):
95 | os.makedirs(destpath, exist_ok=True)
96 | num_processed = 0
97 | for i, batch in tqdm(enumerate(dataloader), desc="Data"):
98 | try:
99 | img = get_input(batch, "image")
100 | img = img.cuda()
101 | seg = run_model(img, model)
102 |
103 | path = batch["relative_file_path_"][0]
104 | path = os.path.splitext(path)[0]
105 |
106 | path = os.path.join(destpath, path + ".png")
107 | save_segmentation(seg, path)
108 | num_processed += 1
109 | except Exception as e:
110 | print(e)
111 | print("but anyhow..")
112 |
113 | print("Processed {} files. Bye.".format(num_processed))
114 |
115 |
116 | from taming.data.sflckr import Examples
117 | from torch.utils.data import DataLoader
118 |
119 | if __name__ == "__main__":
120 | dest = sys.argv[1]
121 | batchsize = 1
122 | print("Running with batch-size {}, saving to {}...".format(batchsize, dest))
123 |
124 | model = COCOStuffSegmenter({}).cuda()
125 | print("Instantiated model.")
126 |
127 | dataset = Examples()
128 | dloader = DataLoader(dataset, batch_size=batchsize)
129 | iterate_dataset(dataloader=dloader, destpath=dest, model=model)
130 | print("done.")
131 |
--------------------------------------------------------------------------------
/src/taming-transformers/scripts/extract_submodel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys
3 |
4 | if __name__ == "__main__":
5 | inpath = sys.argv[1]
6 | outpath = sys.argv[2]
7 | submodel = "cond_stage_model"
8 | if len(sys.argv) > 3:
9 | submodel = sys.argv[3]
10 |
11 | print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))
12 |
13 | sd = torch.load(inpath, map_location="cpu")
14 | new_sd = {"state_dict": dict((k.split(".", 1)[-1],v)
15 | for k,v in sd["state_dict"].items()
16 | if k.startswith("cond_stage_model"))}
17 | torch.save(new_sd, outpath)
18 |
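
The dict comprehension above keeps only keys under the chosen submodel and strips that leading prefix; with a hypothetical checkpoint key:

key = "cond_stage_model.encoder.conv_in.weight"   # hypothetical checkpoint key
assert key.split(".", 1)[-1] == "encoder.conv_in.weight"
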
--------------------------------------------------------------------------------
/src/taming-transformers/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='taming-transformers',
5 | version='0.0.1',
6 | description='Taming Transformers for High-Resolution Image Synthesis',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
14 |
--------------------------------------------------------------------------------
/src/taming-transformers/taming/data/ade20k.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import albumentations
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 |
8 | from taming.data.sflckr import SegmentationBase # for examples included in repo
9 |
10 |
11 | class Examples(SegmentationBase):
12 | def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
13 | super().__init__(data_csv="data/ade20k_examples.txt",
14 | data_root="data/ade20k_images",
15 | segmentation_root="data/ade20k_segmentations",
16 | size=size, random_crop=random_crop,
17 | interpolation=interpolation,
18 | n_labels=151, shift_segmentation=False)
19 |
20 |
21 | # With semantic map and scene label
22 | class ADE20kBase(Dataset):
23 | def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
24 | self.split = self.get_split()
25 | self.n_labels = 151 # unknown + 150
26 | self.data_csv = {"train": "data/ade20k_train.txt",
27 | "validation": "data/ade20k_test.txt"}[self.split]
28 | self.data_root = "data/ade20k_root"
29 | with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
30 | self.scene_categories = f.read().splitlines()
31 | self.scene_categories = dict(line.split() for line in self.scene_categories)
32 | with open(self.data_csv, "r") as f:
33 | self.image_paths = f.read().splitlines()
34 | self._length = len(self.image_paths)
35 | self.labels = {
36 | "relative_file_path_": [l for l in self.image_paths],
37 | "file_path_": [os.path.join(self.data_root, "images", l)
38 | for l in self.image_paths],
39 | "relative_segmentation_path_": [l.replace(".jpg", ".png")
40 | for l in self.image_paths],
41 | "segmentation_path_": [os.path.join(self.data_root, "annotations",
42 | l.replace(".jpg", ".png"))
43 | for l in self.image_paths],
44 | "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
45 | for l in self.image_paths],
46 | }
47 |
48 | size = None if size is not None and size<=0 else size
49 | self.size = size
50 | if crop_size is None:
51 | self.crop_size = size if size is not None else None
52 | else:
53 | self.crop_size = crop_size
54 | if self.size is not None:
55 | self.interpolation = interpolation
56 | self.interpolation = {
57 | "nearest": cv2.INTER_NEAREST,
58 | "bilinear": cv2.INTER_LINEAR,
59 | "bicubic": cv2.INTER_CUBIC,
60 | "area": cv2.INTER_AREA,
61 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
62 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
63 | interpolation=self.interpolation)
64 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
65 | interpolation=cv2.INTER_NEAREST)
66 |
67 | if crop_size is not None:
68 | self.center_crop = not random_crop
69 | if self.center_crop:
70 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
71 | else:
72 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
73 | self.preprocessor = self.cropper
74 |
75 | def __len__(self):
76 | return self._length
77 |
78 | def __getitem__(self, i):
79 | example = dict((k, self.labels[k][i]) for k in self.labels)
80 | image = Image.open(example["file_path_"])
81 | if not image.mode == "RGB":
82 | image = image.convert("RGB")
83 | image = np.array(image).astype(np.uint8)
84 | if self.size is not None:
85 | image = self.image_rescaler(image=image)["image"]
86 | segmentation = Image.open(example["segmentation_path_"])
87 | segmentation = np.array(segmentation).astype(np.uint8)
88 | if self.size is not None:
89 | segmentation = self.segmentation_rescaler(image=segmentation)["image"]
90 | if self.size is not None:
91 | processed = self.preprocessor(image=image, mask=segmentation)
92 | else:
93 | processed = {"image": image, "mask": segmentation}
94 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
95 | segmentation = processed["mask"]
96 | onehot = np.eye(self.n_labels)[segmentation]
97 | example["segmentation"] = onehot
98 | return example
99 |
100 |
101 | class ADE20kTrain(ADE20kBase):
102 | # default to random_crop=True
103 | def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
104 | super().__init__(config=config, size=size, random_crop=random_crop,
105 | interpolation=interpolation, crop_size=crop_size)
106 |
107 | def get_split(self):
108 | return "train"
109 |
110 |
111 | class ADE20kValidation(ADE20kBase):
112 | def get_split(self):
113 | return "validation"
114 |
115 |
116 | if __name__ == "__main__":
117 | dset = ADE20kValidation()
118 | ex = dset[0]
119 | for k in ["image", "scene_category", "segmentation"]:
120 | print(type(ex[k]))
121 | try:
122 | print(ex[k].shape)
123 | except:
124 | print(ex[k])
125 |
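
The np.eye(self.n_labels)[segmentation] line in __getitem__ turns an H x W integer label map into an H x W x n_labels one-hot volume; a toy sketch with 3 classes:

import numpy as np

seg = np.array([[0, 2],
                [1, 0]], dtype=np.uint8)   # 2x2 label map
onehot = np.eye(3)[seg]                    # shape (2, 2, 3)
assert onehot.shape == (2, 2, 3)
assert onehot[0, 1, 2] == 1.0              # pixel (0, 1) carries class 2
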
--------------------------------------------------------------------------------
/src/taming-transformers/taming/data/annotated_objects_coco.py:
--------------------------------------------------------------------------------
1 | import json
2 | from itertools import chain
3 | from pathlib import Path
4 | from typing import Iterable, Dict, List, Callable, Any
5 | from collections import defaultdict
6 |
7 | from tqdm import tqdm
8 |
9 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
10 | from taming.data.helper_types import Annotation, ImageDescription, Category
11 |
12 | COCO_PATH_STRUCTURE = {
13 | 'train': {
14 | 'top_level': '',
15 | 'instances_annotations': 'annotations/instances_train2017.json',
16 | 'stuff_annotations': 'annotations/stuff_train2017.json',
17 | 'files': 'train2017'
18 | },
19 | 'validation': {
20 | 'top_level': '',
21 | 'instances_annotations': 'annotations/instances_val2017.json',
22 | 'stuff_annotations': 'annotations/stuff_val2017.json',
23 | 'files': 'val2017'
24 | }
25 | }
26 |
27 |
28 | def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
29 | return {
30 | str(img['id']): ImageDescription(
31 | id=img['id'],
32 | license=img.get('license'),
33 | file_name=img['file_name'],
34 | coco_url=img['coco_url'],
35 | original_size=(img['width'], img['height']),
36 | date_captured=img.get('date_captured'),
37 | flickr_url=img.get('flickr_url')
38 | )
39 | for img in description_json
40 | }
41 |
42 |
43 | def load_categories(category_json: Iterable) -> Dict[str, Category]:
44 | return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
45 | for cat in category_json if cat['name'] != 'other'}
46 |
47 |
48 | def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
49 | category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
50 | annotations = defaultdict(list)
51 | total = sum(len(a) for a in annotations_json)
52 | for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
53 | image_id = str(ann['image_id'])
54 | if image_id not in image_descriptions:
55 | raise ValueError(f'image_id [{image_id}] has no image description.')
56 | category_id = ann['category_id']
57 | try:
58 | category_no = category_no_for_id(str(category_id))
59 | except KeyError:
60 | continue
61 |
62 | width, height = image_descriptions[image_id].original_size
63 | bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)
64 |
65 | annotations[image_id].append(
66 | Annotation(
67 | id=ann['id'],
68 | area=bbox[2]*bbox[3], # use bbox area
69 | is_group_of=ann['iscrowd'],
70 | image_id=ann['image_id'],
71 | bbox=bbox,
72 | category_id=str(category_id),
73 | category_no=category_no
74 | )
75 | )
76 | return dict(annotations)
77 |
78 |
79 | class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
80 | def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
81 | """
82 | @param data_path: is the path to the following folder structure:
83 | coco/
84 | ├── annotations
85 | │ ├── instances_train2017.json
86 | │ ├── instances_val2017.json
87 | │ ├── stuff_train2017.json
88 | │ └── stuff_val2017.json
89 | ├── train2017
90 | │ ├── 000000000009.jpg
91 | │ ├── 000000000025.jpg
92 | │ └── ...
93 | ├── val2017
94 | │ ├── 000000000139.jpg
95 | │ ├── 000000000285.jpg
96 | │ └── ...
97 | @param: split: one of 'train' or 'validation'
98 |         @param: desired image size (returns square images)
99 | """
100 | super().__init__(**kwargs)
101 | self.use_things = use_things
102 | self.use_stuff = use_stuff
103 |
104 | with open(self.paths['instances_annotations']) as f:
105 | inst_data_json = json.load(f)
106 | with open(self.paths['stuff_annotations']) as f:
107 | stuff_data_json = json.load(f)
108 |
109 | category_jsons = []
110 | annotation_jsons = []
111 | if self.use_things:
112 | category_jsons.append(inst_data_json['categories'])
113 | annotation_jsons.append(inst_data_json['annotations'])
114 | if self.use_stuff:
115 | category_jsons.append(stuff_data_json['categories'])
116 | annotation_jsons.append(stuff_data_json['annotations'])
117 |
118 | self.categories = load_categories(chain(*category_jsons))
119 | self.filter_categories()
120 | self.setup_category_id_and_number()
121 |
122 | self.image_descriptions = load_image_descriptions(inst_data_json['images'])
123 | annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
124 | self.annotations = self.filter_object_number(annotations, self.min_object_area,
125 | self.min_objects_per_image, self.max_objects_per_image)
126 | self.image_ids = list(self.annotations.keys())
127 | self.clean_up_annotations_and_image_descriptions()
128 |
129 | def get_path_structure(self) -> Dict[str, str]:
130 | if self.split not in COCO_PATH_STRUCTURE:
131 |             raise ValueError(f'Split [{self.split}] does not exist for COCO data.')
132 | return COCO_PATH_STRUCTURE[self.split]
133 |
134 | def get_image_path(self, image_id: str) -> Path:
135 | return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)
136 |
137 | def get_image_description(self, image_id: str) -> Dict[str, Any]:
138 | # noinspection PyProtectedMember
139 | return self.image_descriptions[image_id]._asdict()
140 |
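
In load_annotations the absolute COCO box (x, y, w, h) in pixels is stored relative to the image size; with hypothetical numbers:

width, height = 640, 480                 # hypothetical image size
x, y, w, h = 64.0, 48.0, 320.0, 240.0    # hypothetical COCO bbox in pixels
bbox = (x / width, y / height, w / width, h / height)
assert bbox == (0.1, 0.1, 0.5, 0.5)
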
--------------------------------------------------------------------------------
/src/taming-transformers/taming/data/annotated_objects_open_images.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from csv import DictReader, reader as TupleReader
3 | from pathlib import Path
4 | from typing import Dict, List, Any
5 | import warnings
6 |
7 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
8 | from taming.data.helper_types import Annotation, Category
9 | from tqdm import tqdm
10 |
11 | OPEN_IMAGES_STRUCTURE = {
12 | 'train': {
13 | 'top_level': '',
14 | 'class_descriptions': 'class-descriptions-boxable.csv',
15 | 'annotations': 'oidv6-train-annotations-bbox.csv',
16 | 'file_list': 'train-images-boxable.csv',
17 | 'files': 'train'
18 | },
19 | 'validation': {
20 | 'top_level': '',
21 | 'class_descriptions': 'class-descriptions-boxable.csv',
22 | 'annotations': 'validation-annotations-bbox.csv',
23 | 'file_list': 'validation-images.csv',
24 | 'files': 'validation'
25 | },
26 | 'test': {
27 | 'top_level': '',
28 | 'class_descriptions': 'class-descriptions-boxable.csv',
29 | 'annotations': 'test-annotations-bbox.csv',
30 | 'file_list': 'test-images.csv',
31 | 'files': 'test'
32 | }
33 | }
34 |
35 |
36 | def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str],
37 | category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]:
38 | annotations: Dict[str, List[Annotation]] = defaultdict(list)
39 | with open(descriptor_path) as file:
40 | reader = DictReader(file)
41 | for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'):
42 | width = float(row['XMax']) - float(row['XMin'])
43 | height = float(row['YMax']) - float(row['YMin'])
44 | area = width * height
45 | category_id = row['LabelName']
46 | if category_id in category_mapping:
47 | category_id = category_mapping[category_id]
48 | if area >= min_object_area and category_id in category_no_for_id:
49 | annotations[row['ImageID']].append(
50 | Annotation(
51 | id=i,
52 | image_id=row['ImageID'],
53 | source=row['Source'],
54 | category_id=category_id,
55 | category_no=category_no_for_id[category_id],
56 | confidence=float(row['Confidence']),
57 | bbox=(float(row['XMin']), float(row['YMin']), width, height),
58 | area=area,
59 | is_occluded=bool(int(row['IsOccluded'])),
60 | is_truncated=bool(int(row['IsTruncated'])),
61 | is_group_of=bool(int(row['IsGroupOf'])),
62 | is_depiction=bool(int(row['IsDepiction'])),
63 | is_inside=bool(int(row['IsInside']))
64 | )
65 | )
66 | if 'train' in str(descriptor_path) and i < 14000000:
67 | warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].')
68 | return dict(annotations)
69 |
70 |
71 | def load_image_ids(csv_path: Path) -> List[str]:
72 | with open(csv_path) as file:
73 | reader = DictReader(file)
74 | return [row['image_name'] for row in reader]
75 |
76 |
77 | def load_categories(csv_path: Path) -> Dict[str, Category]:
78 | with open(csv_path) as file:
79 | reader = TupleReader(file)
80 | return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader}
81 |
82 |
83 | class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset):
84 | def __init__(self, use_additional_parameters: bool, **kwargs):
85 | """
86 | @param data_path: is the path to the following folder structure:
87 | open_images/
88 | │ oidv6-train-annotations-bbox.csv
89 | ├── class-descriptions-boxable.csv
90 | ├── oidv6-train-annotations-bbox.csv
91 | ├── test
92 | │ ├── 000026e7ee790996.jpg
93 | │ ├── 000062a39995e348.jpg
94 | │ └── ...
95 | ├── test-annotations-bbox.csv
96 | ├── test-images.csv
97 | ├── train
98 | │ ├── 000002b66c9c498e.jpg
99 | │ ├── 000002b97e5471a0.jpg
100 | │ └── ...
101 | ├── train-images-boxable.csv
102 | ├── validation
103 | │ ├── 0001eeaf4aed83f9.jpg
104 | │ ├── 0004886b7d043cfd.jpg
105 | │ └── ...
106 | ├── validation-annotations-bbox.csv
107 | └── validation-images.csv
108 | @param: split: one of 'train', 'validation' or 'test'
109 | @param: desired image size (returns square images)
110 | """
111 |
112 | super().__init__(**kwargs)
113 | self.use_additional_parameters = use_additional_parameters
114 |
115 | self.categories = load_categories(self.paths['class_descriptions'])
116 | self.filter_categories()
117 | self.setup_category_id_and_number()
118 |
119 | self.image_descriptions = {}
120 | annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping,
121 | self.category_number)
122 | self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image,
123 | self.max_objects_per_image)
124 | self.image_ids = list(self.annotations.keys())
125 | self.clean_up_annotations_and_image_descriptions()
126 |
127 | def get_path_structure(self) -> Dict[str, str]:
128 | if self.split not in OPEN_IMAGES_STRUCTURE:
129 |             raise ValueError(f'Split [{self.split}] does not exist for Open Images data.')
130 | return OPEN_IMAGES_STRUCTURE[self.split]
131 |
132 | def get_image_path(self, image_id: str) -> Path:
133 | return self.paths['files'].joinpath(f'{image_id:0>16}.jpg')
134 |
135 | def get_image_description(self, image_id: str) -> Dict[str, Any]:
136 | image_path = self.get_image_path(image_id)
137 | return {'file_path': str(image_path), 'file_name': image_path.name}
138 |
--------------------------------------------------------------------------------
/src/taming-transformers/taming/data/base.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import numpy as np
3 | import albumentations
4 | from PIL import Image
5 | from torch.utils.data import Dataset, ConcatDataset
6 |
7 |
8 | class ConcatDatasetWithIndex(ConcatDataset):
9 | """Modified from original pytorch code to return dataset idx"""
10 | def __getitem__(self, idx):
11 | if idx < 0:
12 | if -idx > len(self):
13 | raise ValueError("absolute value of index should not exceed dataset length")
14 | idx = len(self) + idx
15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
16 | if dataset_idx == 0:
17 | sample_idx = idx
18 | else:
19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
20 | return self.datasets[dataset_idx][sample_idx], dataset_idx
21 |
22 |
23 | class ImagePaths(Dataset):
24 | def __init__(self, paths, size=None, random_crop=False, labels=None):
25 | self.size = size
26 | self.random_crop = random_crop
27 |
28 | self.labels = dict() if labels is None else labels
29 | self.labels["file_path_"] = paths
30 | self._length = len(paths)
31 |
32 | if self.size is not None and self.size > 0:
33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
34 | if not self.random_crop:
35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
36 | else:
37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
39 | else:
40 | self.preprocessor = lambda **kwargs: kwargs
41 |
42 | def __len__(self):
43 | return self._length
44 |
45 | def preprocess_image(self, image_path):
46 | image = Image.open(image_path)
47 | if not image.mode == "RGB":
48 | image = image.convert("RGB")
49 | image = np.array(image).astype(np.uint8)
50 | image = self.preprocessor(image=image)["image"]
51 | image = (image/127.5 - 1.0).astype(np.float32)
52 | return image
53 |
54 | def __getitem__(self, i):
55 | example = dict()
56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i])
57 | for k in self.labels:
58 | example[k] = self.labels[k][i]
59 | return example
60 |
61 |
62 | class NumpyPaths(ImagePaths):
63 | def preprocess_image(self, image_path):
64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
65 | image = np.transpose(image, (1,2,0))
66 | image = Image.fromarray(image, mode="RGB")
67 | image = np.array(image).astype(np.uint8)
68 | image = self.preprocessor(image=image)["image"]
69 | image = (image/127.5 - 1.0).astype(np.float32)
70 | return image
71 |
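
A minimal usage sketch for ImagePaths above (the path is hypothetical); images are rescaled, center-cropped and mapped to float32 in [-1, 1]:

from taming.data.base import ImagePaths

dataset = ImagePaths(paths=["data/samples/00000.png"], size=256, random_crop=False)
example = dataset[0]
print(example["image"].shape, example["image"].dtype)               # (256, 256, 3) float32
print(example["image"].min() >= -1.0, example["image"].max() <= 1.0)
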
--------------------------------------------------------------------------------
/src/taming-transformers/taming/data/conditional_builder/objects_bbox.py:
--------------------------------------------------------------------------------
1 | from itertools import cycle
2 | from typing import List, Tuple, Callable, Optional
3 |
4 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
5 | from more_itertools.recipes import grouper
6 | from taming.data.image_transforms import convert_pil_to_tensor
7 | from torch import LongTensor, Tensor
8 |
9 | from taming.data.helper_types import BoundingBox, Annotation
10 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
11 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
12 | pad_list, get_plot_font_size, absolute_bbox
13 |
14 |
15 | class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
16 | @property
17 | def object_descriptor_length(self) -> int:
18 | return 3
19 |
20 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
21 | object_triples = [
22 | (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
23 | for ann in annotations
24 | ]
25 | empty_triple = (self.none, self.none, self.none)
26 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
27 | return object_triples
28 |
29 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
30 | conditional_list = conditional.tolist()
31 | crop_coordinates = None
32 | if self.encode_crop:
33 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
34 | conditional_list = conditional_list[:-2]
35 | object_triples = grouper(conditional_list, 3)
36 | assert conditional.shape[0] == self.embedding_dim
37 | return [
38 | (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
39 | for object_triple in object_triples if object_triple[0] != self.none
40 | ], crop_coordinates
41 |
42 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
43 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
44 | plot = pil_image.new('RGB', figure_size, WHITE)
45 | draw = pil_img_draw.Draw(plot)
46 | font = ImageFont.truetype(
47 | "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
48 | size=get_plot_font_size(font_size, figure_size)
49 | )
50 | width, height = plot.size
51 | description, crop_coordinates = self.inverse_build(conditional)
52 | for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
53 | annotation = self.representation_to_annotation(representation)
54 | class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
55 | bbox = absolute_bbox(bbox, width, height)
56 | draw.rectangle(bbox, outline=color, width=line_width)
57 | draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
58 | if crop_coordinates is not None:
59 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
60 | return convert_pil_to_tensor(plot) / 127.5 - 1.
61 |
--------------------------------------------------------------------------------
/src/taming-transformers/taming/data/conditional_builder/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from typing import List, Any, Tuple, Optional
3 |
4 | from taming.data.helper_types import BoundingBox, Annotation
5 |
6 | # source: seaborn, color palette tab10
7 | COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
8 | (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
9 | BLACK = (0, 0, 0)
10 | GRAY_75 = (63, 63, 63)
11 | GRAY_50 = (127, 127, 127)
12 | GRAY_25 = (191, 191, 191)
13 | WHITE = (255, 255, 255)
14 | FULL_CROP = (0., 0., 1., 1.)
15 |
16 |
17 | def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
18 | """
19 | Give intersection area of two rectangles.
20 | @param rectangle1: (x0, y0, w, h) of first rectangle
21 | @param rectangle2: (x0, y0, w, h) of second rectangle
22 | """
23 | rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
24 | rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
25 | x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
26 | y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
27 | return x_overlap * y_overlap
28 |
29 |
30 | def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
31 | return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
32 |
33 |
34 | def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
35 | bbox = relative_bbox
36 | bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
37 | return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
38 |
39 |
40 | def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
41 | return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
42 |
43 |
44 | def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
45 | List[Annotation]:
46 | def clamp(x: float):
47 | return max(min(x, 1.), 0.)
48 |
49 | def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
50 | x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
51 | y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
52 | w = min(bbox[2] / crop_coordinates[2], 1 - x0)
53 | h = min(bbox[3] / crop_coordinates[3], 1 - y0)
54 | if flip:
55 | x0 = 1 - (x0 + w)
56 | return x0, y0, w, h
57 |
58 | return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
59 |
60 |
61 | def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
62 | return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
63 |
64 |
65 | def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
66 | sl = slice(1) if short else slice(None)
67 | string = ''
68 | if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
69 | return string
70 | if annotation.is_group_of:
71 | string += 'group'[sl] + ','
72 | if annotation.is_occluded:
73 | string += 'occluded'[sl] + ','
74 | if annotation.is_depiction:
75 | string += 'depiction'[sl] + ','
76 | if annotation.is_inside:
77 | string += 'inside'[sl]
78 | return '(' + string.strip(",") + ')'
79 |
80 |
81 | def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
82 | if font_size is None:
83 | font_size = 10
84 | if max(figure_size) >= 256:
85 | font_size = 12
86 | if max(figure_size) >= 512:
87 | font_size = 15
88 | return font_size
89 |
90 |
91 | def get_circle_size(figure_size: Tuple[int, int]) -> int:
92 | circle_size = 2
93 | if max(figure_size) >= 256:
94 | circle_size = 3
95 | if max(figure_size) >= 512:
96 | circle_size = 4
97 | return circle_size
98 |
99 |
100 | def load_object_from_string(object_string: str) -> Any:
101 | """
102 | Source: https://stackoverflow.com/a/10773699
103 | """
104 | module_name, class_name = object_string.rsplit(".", 1)
105 | return getattr(importlib.import_module(module_name), class_name)
106 |
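
A quick check of intersection_area above: two relative (x0, y0, w, h) boxes that overlap in a 0.25 x 0.25 patch:

from taming.data.conditional_builder.utils import intersection_area

r1 = (0.0, 0.0, 0.5, 0.5)
r2 = (0.25, 0.25, 0.5, 0.5)
assert abs(intersection_area(r1, r2) - 0.0625) < 1e-9   # 0.25 * 0.25 overlap
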
--------------------------------------------------------------------------------
/src/taming-transformers/taming/data/custom.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import albumentations
4 | from torch.utils.data import Dataset
5 |
6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
7 |
8 |
9 | class CustomBase(Dataset):
10 | def __init__(self, *args, **kwargs):
11 | super().__init__()
12 | self.data = None
13 |
14 | def __len__(self):
15 | return len(self.data)
16 |
17 | def __getitem__(self, i):
18 | example = self.data[i]
19 | return example
20 |
21 |
22 |
23 | class CustomTrain(CustomBase):
24 | def __init__(self, size, training_images_list_file):
25 | super().__init__()
26 | with open(training_images_list_file, "r") as f:
27 | paths = f.read().splitlines()
28 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
29 |
30 |
31 | class CustomTest(CustomBase):
32 | def __init__(self, size, test_images_list_file):
33 | super().__init__()
34 | with open(test_images_list_file, "r") as f:
35 | paths = f.read().splitlines()
36 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
37 |
38 |
39 |
--------------------------------------------------------------------------------
/src/taming-transformers/taming/data/faceshq.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import albumentations
4 | from torch.utils.data import Dataset
5 |
6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
7 |
8 |
9 | class FacesBase(Dataset):
10 | def __init__(self, *args, **kwargs):
11 | super().__init__()
12 | self.data = None
13 | self.keys = None
14 |
15 | def __len__(self):
16 | return len(self.data)
17 |
18 | def __getitem__(self, i):
19 | example = self.data[i]
20 | ex = {}
21 | if self.keys is not None:
22 | for k in self.keys:
23 | ex[k] = example[k]
24 | else:
25 | ex = example
26 | return ex
27 |
28 |
29 | class CelebAHQTrain(FacesBase):
30 | def __init__(self, size, keys=None):
31 | super().__init__()
32 | root = "data/celebahq"
33 | with open("data/celebahqtrain.txt", "r") as f:
34 | relpaths = f.read().splitlines()
35 | paths = [os.path.join(root, relpath) for relpath in relpaths]
36 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
37 | self.keys = keys
38 |
39 |
40 | class CelebAHQValidation(FacesBase):
41 | def __init__(self, size, keys=None):
42 | super().__init__()
43 | root = "data/celebahq"
44 | with open("data/celebahqvalidation.txt", "r") as f:
45 | relpaths = f.read().splitlines()
46 | paths = [os.path.join(root, relpath) for relpath in relpaths]
47 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
48 | self.keys = keys
49 |
50 |
51 | class FFHQTrain(FacesBase):
52 | def __init__(self, size, keys=None):
53 | super().__init__()
54 | root = "data/ffhq"
55 | with open("data/ffhqtrain.txt", "r") as f:
56 | relpaths = f.read().splitlines()
57 | paths = [os.path.join(root, relpath) for relpath in relpaths]
58 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
59 | self.keys = keys
60 |
61 |
62 | class FFHQValidation(FacesBase):
63 | def __init__(self, size, keys=None):
64 | super().__init__()
65 | root = "data/ffhq"
66 | with open("data/ffhqvalidation.txt", "r") as f:
67 | relpaths = f.read().splitlines()
68 | paths = [os.path.join(root, relpath) for relpath in relpaths]
69 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
70 | self.keys = keys
71 |
72 |
73 | class FacesHQTrain(Dataset):
74 | # CelebAHQ [0] + FFHQ [1]
75 | def __init__(self, size, keys=None, crop_size=None, coord=False):
76 | d1 = CelebAHQTrain(size=size, keys=keys)
77 | d2 = FFHQTrain(size=size, keys=keys)
78 | self.data = ConcatDatasetWithIndex([d1, d2])
79 | self.coord = coord
80 | if crop_size is not None:
81 | self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
82 | if self.coord:
83 | self.cropper = albumentations.Compose([self.cropper],
84 | additional_targets={"coord": "image"})
85 |
86 | def __len__(self):
87 | return len(self.data)
88 |
89 | def __getitem__(self, i):
90 | ex, y = self.data[i]
91 | if hasattr(self, "cropper"):
92 | if not self.coord:
93 | out = self.cropper(image=ex["image"])
94 | ex["image"] = out["image"]
95 | else:
96 | h,w,_ = ex["image"].shape
97 | coord = np.arange(h*w).reshape(h,w,1)/(h*w)
98 | out = self.cropper(image=ex["image"], coord=coord)
99 | ex["image"] = out["image"]
100 | ex["coord"] = out["coord"]
101 | ex["class"] = y
102 | return ex
103 |
104 |
105 | class FacesHQValidation(Dataset):
106 | # CelebAHQ [0] + FFHQ [1]
107 | def __init__(self, size, keys=None, crop_size=None, coord=False):
108 | d1 = CelebAHQValidation(size=size, keys=keys)
109 | d2 = FFHQValidation(size=size, keys=keys)
110 | self.data = ConcatDatasetWithIndex([d1, d2])
111 | self.coord = coord
112 | if crop_size is not None:
113 | self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
114 | if self.coord:
115 | self.cropper = albumentations.Compose([self.cropper],
116 | additional_targets={"coord": "image"})
117 |
118 | def __len__(self):
119 | return len(self.data)
120 |
121 | def __getitem__(self, i):
122 | ex, y = self.data[i]
123 | if hasattr(self, "cropper"):
124 | if not self.coord:
125 | out = self.cropper(image=ex["image"])
126 | ex["image"] = out["image"]
127 | else:
128 | h,w,_ = ex["image"].shape
129 | coord = np.arange(h*w).reshape(h,w,1)/(h*w)
130 | out = self.cropper(image=ex["image"], coord=coord)
131 | ex["image"] = out["image"]
132 | ex["coord"] = out["coord"]
133 | ex["class"] = y
134 | return ex
135 |
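
The coord target built in __getitem__ above is a raster-scan index normalised to [0, 1) that gets cropped alongside the image; on a toy 2 x 3 grid:

import numpy as np

h, w = 2, 3
coord = np.arange(h * w).reshape(h, w, 1) / (h * w)
# coord[..., 0] == [[0, 1/6, 2/6],
#                   [3/6, 4/6, 5/6]]
assert coord[1, 2, 0] == 5 / 6
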
--------------------------------------------------------------------------------
/src/taming-transformers/taming/data/helper_types.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple, Optional, NamedTuple, Union
2 | from PIL.Image import Image as pil_image
3 | from torch import Tensor
4 |
5 | try:
6 | from typing import Literal
7 | except ImportError:
8 | from typing_extensions import Literal
9 |
10 | Image = Union[Tensor, pil_image]
11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h
12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d']
13 | SplitType = Literal['train', 'validation', 'test']
14 |
15 |
16 | class ImageDescription(NamedTuple):
17 | id: int
18 | file_name: str
19 | original_size: Tuple[int, int] # w, h
20 | url: Optional[str] = None
21 | license: Optional[int] = None
22 | coco_url: Optional[str] = None
23 | date_captured: Optional[str] = None
24 | flickr_url: Optional[str] = None
25 | flickr_id: Optional[str] = None
26 | coco_id: Optional[str] = None
27 |
28 |
29 | class Category(NamedTuple):
30 | id: str
31 | super_category: Optional[str]
32 | name: str
33 |
34 |
35 | class Annotation(NamedTuple):
36 | area: float
37 | image_id: str
38 | bbox: BoundingBox
39 | category_no: int
40 | category_id: str
41 | id: Optional[int] = None
42 | source: Optional[str] = None
43 | confidence: Optional[float] = None
44 | is_group_of: Optional[bool] = None
45 | is_truncated: Optional[bool] = None
46 | is_occluded: Optional[bool] = None
47 | is_depiction: Optional[bool] = None
48 | is_inside: Optional[bool] = None
49 | segmentation: Optional[Dict] = None
50 |
--------------------------------------------------------------------------------
/src/taming-transformers/taming/data/image_transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import warnings
3 | from typing import Union
4 |
5 | import torch
6 | from torch import Tensor
7 | from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
8 | from torchvision.transforms.functional import _get_image_size as get_image_size
9 |
10 | from taming.data.helper_types import BoundingBox, Image
11 |
12 | pil_to_tensor = PILToTensor()
13 |
14 |
15 | def convert_pil_to_tensor(image: Image) -> Tensor:
16 | with warnings.catch_warnings():
17 | # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
18 | warnings.simplefilter("ignore")
19 | return pil_to_tensor(image)
20 |
21 |
22 | class RandomCrop1dReturnCoordinates(RandomCrop):
23 | def forward(self, img: Image) -> (BoundingBox, Image):
24 | """
25 | Additionally to cropping, returns the relative coordinates of the crop bounding box.
26 | Args:
27 | img (PIL Image or Tensor): Image to be cropped.
28 |
29 | Returns:
30 | Bounding box: x0, y0, w, h
31 | PIL Image or Tensor: Cropped image.
32 |
33 | Based on:
34 | torchvision.transforms.RandomCrop, torchvision 1.7.0
35 | """
36 | if self.padding is not None:
37 | img = F.pad(img, self.padding, self.fill, self.padding_mode)
38 |
39 | width, height = get_image_size(img)
40 | # pad the width if needed
41 | if self.pad_if_needed and width < self.size[1]:
42 | padding = [self.size[1] - width, 0]
43 | img = F.pad(img, padding, self.fill, self.padding_mode)
44 | # pad the height if needed
45 | if self.pad_if_needed and height < self.size[0]:
46 | padding = [0, self.size[0] - height]
47 | img = F.pad(img, padding, self.fill, self.padding_mode)
48 |
49 | i, j, h, w = self.get_params(img, self.size)
50 | bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h
51 | return bbox, F.crop(img, i, j, h, w)
52 |
53 |
54 | class Random2dCropReturnCoordinates(torch.nn.Module):
55 | """
56 | Additionally to cropping, returns the relative coordinates of the crop bounding box.
57 | Args:
58 | img (PIL Image or Tensor): Image to be cropped.
59 |
60 | Returns:
61 | Bounding box: x0, y0, w, h
62 | PIL Image or Tensor: Cropped image.
63 |
64 | Based on:
65 | torchvision.transforms.RandomCrop, torchvision 1.7.0
66 | """
67 |
68 | def __init__(self, min_size: int):
69 | super().__init__()
70 | self.min_size = min_size
71 |
72 | def forward(self, img: Image) -> (BoundingBox, Image):
73 | width, height = get_image_size(img)
74 | max_size = min(width, height)
75 | if max_size <= self.min_size:
76 | size = max_size
77 | else:
78 | size = random.randint(self.min_size, max_size)
79 | top = random.randint(0, height - size)
80 | left = random.randint(0, width - size)
81 | bbox = left / width, top / height, size / width, size / height
82 | return bbox, F.crop(img, top, left, size, size)
83 |
84 |
85 | class CenterCropReturnCoordinates(CenterCrop):
86 | @staticmethod
87 | def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
88 | if width > height:
89 | w = height / width
90 | h = 1.0
91 | x0 = 0.5 - w / 2
92 | y0 = 0.
93 | else:
94 | w = 1.0
95 | h = width / height
96 | x0 = 0.
97 | y0 = 0.5 - h / 2
98 | return x0, y0, w, h
99 |
100 | def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
101 | """
102 | Additionally to cropping, returns the relative coordinates of the crop bounding box.
103 | Args:
104 | img (PIL Image or Tensor): Image to be cropped.
105 |
106 | Returns:
107 | Bounding box: x0, y0, w, h
108 | PIL Image or Tensor: Cropped image.
109 | Based on:
110 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
111 | """
112 | width, height = get_image_size(img)
113 | return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size)
114 |
115 |
116 | class RandomHorizontalFlipReturn(RandomHorizontalFlip):
117 | def forward(self, img: Image) -> (bool, Image):
118 | """
119 | Additionally to flipping, returns a boolean whether it was flipped or not.
120 | Args:
121 | img (PIL Image or Tensor): Image to be flipped.
122 |
123 | Returns:
124 | flipped: whether the image was flipped or not
125 | PIL Image or Tensor: Randomly flipped image.
126 |
127 | Based on:
128 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
129 | """
130 | if torch.rand(1) < self.p:
131 | return True, F.hflip(img)
132 | return False, img
133 |
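
The bbox returned by CenterCropReturnCoordinates above is relative; for a 400 x 200 (width x height) image, the centre square spans the middle half of the width and the full height:

from taming.data.image_transforms import CenterCropReturnCoordinates

assert CenterCropReturnCoordinates.get_bbox_of_center_crop(400, 200) == (0.25, 0.0, 0.5, 1.0)
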
--------------------------------------------------------------------------------
/src/taming-transformers/taming/data/sflckr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import albumentations
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 |
8 |
9 | class SegmentationBase(Dataset):
10 | def __init__(self,
11 | data_csv, data_root, segmentation_root,
12 | size=None, random_crop=False, interpolation="bicubic",
13 | n_labels=182, shift_segmentation=False,
14 | ):
15 | self.n_labels = n_labels
16 | self.shift_segmentation = shift_segmentation
17 | self.data_csv = data_csv
18 | self.data_root = data_root
19 | self.segmentation_root = segmentation_root
20 | with open(self.data_csv, "r") as f:
21 | self.image_paths = f.read().splitlines()
22 | self._length = len(self.image_paths)
23 | self.labels = {
24 | "relative_file_path_": [l for l in self.image_paths],
25 | "file_path_": [os.path.join(self.data_root, l)
26 | for l in self.image_paths],
27 | "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png"))
28 | for l in self.image_paths]
29 | }
30 |
31 | size = None if size is not None and size<=0 else size
32 | self.size = size
33 | if self.size is not None:
34 | self.interpolation = interpolation
35 | self.interpolation = {
36 | "nearest": cv2.INTER_NEAREST,
37 | "bilinear": cv2.INTER_LINEAR,
38 | "bicubic": cv2.INTER_CUBIC,
39 | "area": cv2.INTER_AREA,
40 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
41 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
42 | interpolation=self.interpolation)
43 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
44 | interpolation=cv2.INTER_NEAREST)
45 | self.center_crop = not random_crop
46 | if self.center_crop:
47 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
48 | else:
49 | self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
50 | self.preprocessor = self.cropper
51 |
52 | def __len__(self):
53 | return self._length
54 |
55 | def __getitem__(self, i):
56 | example = dict((k, self.labels[k][i]) for k in self.labels)
57 | image = Image.open(example["file_path_"])
58 | if not image.mode == "RGB":
59 | image = image.convert("RGB")
60 | image = np.array(image).astype(np.uint8)
61 | if self.size is not None:
62 | image = self.image_rescaler(image=image)["image"]
63 | segmentation = Image.open(example["segmentation_path_"])
64 | assert segmentation.mode == "L", segmentation.mode
65 | segmentation = np.array(segmentation).astype(np.uint8)
66 | if self.shift_segmentation:
67 | # used to support segmentations containing unlabeled==255 label
68 | segmentation = segmentation+1
69 | if self.size is not None:
70 | segmentation = self.segmentation_rescaler(image=segmentation)["image"]
71 | if self.size is not None:
72 | processed = self.preprocessor(image=image,
73 | mask=segmentation
74 | )
75 | else:
76 | processed = {"image": image,
77 | "mask": segmentation
78 | }
79 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
80 | segmentation = processed["mask"]
81 | onehot = np.eye(self.n_labels)[segmentation]
82 | example["segmentation"] = onehot
83 | return example
84 |
85 |
86 | class Examples(SegmentationBase):
87 | def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
88 | super().__init__(data_csv="data/sflckr_examples.txt",
89 | data_root="data/sflckr_images",
90 | segmentation_root="data/sflckr_segmentations",
91 | size=size, random_crop=random_crop, interpolation=interpolation)
92 |
--------------------------------------------------------------------------------
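For reference, a sketch of what one sample from the `Examples` dataset above contains; it only runs if the sflckr list file, images, and segmentations referenced by the hard-coded paths are present.

```python
# Sketch only: requires data/sflckr_examples.txt plus the image/segmentation folders above.
from taming.data.sflckr import Examples

dataset = Examples(size=256, random_crop=False)
example = dataset[0]
print(example["image"].shape)          # (256, 256, 3), float32 scaled to [-1, 1]
print(example["segmentation"].shape)   # (256, 256, 182), one-hot label map
print(example["relative_file_path_"])  # relative image path from the example list
```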
/src/taming-transformers/taming/data/utils.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import os
3 | import tarfile
4 | import urllib
5 | import zipfile
6 | from pathlib import Path
7 |
8 | import numpy as np
9 | import torch
10 | from taming.data.helper_types import Annotation
11 | from torch._six import string_classes
12 | from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
13 | from tqdm import tqdm
14 |
15 |
16 | def unpack(path):
17 | if path.endswith("tar.gz"):
18 | with tarfile.open(path, "r:gz") as tar:
19 | tar.extractall(path=os.path.split(path)[0])
20 | elif path.endswith("tar"):
21 | with tarfile.open(path, "r:") as tar:
22 | tar.extractall(path=os.path.split(path)[0])
23 | elif path.endswith("zip"):
24 | with zipfile.ZipFile(path, "r") as f:
25 | f.extractall(path=os.path.split(path)[0])
26 | else:
27 | raise NotImplementedError(
28 | "Unknown file extension: {}".format(os.path.splitext(path)[1])
29 | )
30 |
31 |
32 | def reporthook(bar):
33 | """tqdm progress bar for downloads."""
34 |
35 | def hook(b=1, bsize=1, tsize=None):
36 | if tsize is not None:
37 | bar.total = tsize
38 | bar.update(b * bsize - bar.n)
39 |
40 | return hook
41 |
42 |
43 | def get_root(name):
44 | base = "data/"
45 | root = os.path.join(base, name)
46 | os.makedirs(root, exist_ok=True)
47 | return root
48 |
49 |
50 | def is_prepared(root):
51 | return Path(root).joinpath(".ready").exists()
52 |
53 |
54 | def mark_prepared(root):
55 | Path(root).joinpath(".ready").touch()
56 |
57 |
58 | def prompt_download(file_, source, target_dir, content_dir=None):
59 | targetpath = os.path.join(target_dir, file_)
60 | while not os.path.exists(targetpath):
61 | if content_dir is not None and os.path.exists(
62 | os.path.join(target_dir, content_dir)
63 | ):
64 | break
65 | print(
66 | "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
67 | )
68 | if content_dir is not None:
69 | print(
70 | "Or place its content into '{}'.".format(
71 | os.path.join(target_dir, content_dir)
72 | )
73 | )
74 | input("Press Enter when done...")
75 | return targetpath
76 |
77 |
78 | def download_url(file_, url, target_dir):
79 | targetpath = os.path.join(target_dir, file_)
80 | os.makedirs(target_dir, exist_ok=True)
81 | with tqdm(
82 | unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
83 | ) as bar:
84 | urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
85 | return targetpath
86 |
87 |
88 | def download_urls(urls, target_dir):
89 | paths = dict()
90 | for fname, url in urls.items():
91 | outpath = download_url(fname, url, target_dir)
92 | paths[fname] = outpath
93 | return paths
94 |
95 |
96 | def quadratic_crop(x, bbox, alpha=1.0):
97 | """bbox is xmin, ymin, xmax, ymax"""
98 | im_h, im_w = x.shape[:2]
99 | bbox = np.array(bbox, dtype=np.float32)
100 | bbox = np.clip(bbox, 0, max(im_h, im_w))
101 | center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
102 | w = bbox[2] - bbox[0]
103 | h = bbox[3] - bbox[1]
104 | l = int(alpha * max(w, h))
105 | l = max(l, 2)
106 |
107 | required_padding = -1 * min(
108 | center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
109 | )
110 | required_padding = int(np.ceil(required_padding))
111 | if required_padding > 0:
112 | padding = [
113 | [required_padding, required_padding],
114 | [required_padding, required_padding],
115 | ]
116 | padding += [[0, 0]] * (len(x.shape) - 2)
117 | x = np.pad(x, padding, "reflect")
118 | center = center[0] + required_padding, center[1] + required_padding
119 | xmin = int(center[0] - l / 2)
120 | ymin = int(center[1] - l / 2)
121 | return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
122 |
123 |
124 | def custom_collate(batch):
125 | r"""source: pytorch 1.9.0, only one modification to original code """
126 |
127 | elem = batch[0]
128 | elem_type = type(elem)
129 | if isinstance(elem, torch.Tensor):
130 | out = None
131 | if torch.utils.data.get_worker_info() is not None:
132 | # If we're in a background process, concatenate directly into a
133 | # shared memory tensor to avoid an extra copy
134 | numel = sum([x.numel() for x in batch])
135 | storage = elem.storage()._new_shared(numel)
136 | out = elem.new(storage)
137 | return torch.stack(batch, 0, out=out)
138 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
139 | and elem_type.__name__ != 'string_':
140 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
141 | # array of string classes and object
142 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
143 | raise TypeError(default_collate_err_msg_format.format(elem.dtype))
144 |
145 | return custom_collate([torch.as_tensor(b) for b in batch])
146 | elif elem.shape == (): # scalars
147 | return torch.as_tensor(batch)
148 | elif isinstance(elem, float):
149 | return torch.tensor(batch, dtype=torch.float64)
150 | elif isinstance(elem, int):
151 | return torch.tensor(batch)
152 | elif isinstance(elem, string_classes):
153 | return batch
154 | elif isinstance(elem, collections.abc.Mapping):
155 | return {key: custom_collate([d[key] for d in batch]) for key in elem}
156 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
157 | return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
158 | if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added
159 | return batch # added
160 | elif isinstance(elem, collections.abc.Sequence):
161 | # check to make sure that the elements in batch have consistent size
162 | it = iter(batch)
163 | elem_size = len(next(it))
164 | if not all(len(elem) == elem_size for elem in it):
165 | raise RuntimeError('each element in list of batch should be of equal size')
166 | transposed = zip(*batch)
167 | return [custom_collate(samples) for samples in transposed]
168 |
169 | raise TypeError(default_collate_err_msg_format.format(elem_type))
170 |
--------------------------------------------------------------------------------
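A small sketch of `custom_collate` on the common cases (tensors, numbers, dicts), where it matches torch's `default_collate`; its single modification, passing lists of `Annotation` through untouched, is not exercised here because `Annotation` is defined elsewhere in the repo. This assumes a torch version that still ships `torch._six`, as the import at the top of the file requires.

```python
# Sketch only: custom_collate on plain dict samples; usable as
# DataLoader(dataset, batch_size=2, collate_fn=custom_collate).
import torch
from taming.data.utils import custom_collate

batch = [
    {"image": torch.rand(3, 8, 8), "label": 1},
    {"image": torch.rand(3, 8, 8), "label": 0},
]
collated = custom_collate(batch)
print(collated["image"].shape)   # torch.Size([2, 3, 8, 8])
print(collated["label"])         # tensor([1, 0])
```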
/src/taming-transformers/taming/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n):
33 | return self.schedule(n)
34 |
35 |
--------------------------------------------------------------------------------
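The scheduler above returns a learning-rate multiplier rather than a learning rate, which is why its docstring asks for a base lr of 1.0. A sketch of wiring it into `torch.optim.lr_scheduler.LambdaLR`:

```python
# Sketch only: LambdaWarmUpCosineScheduler as the lr_lambda of LambdaLR.
import torch
from taming.lr_scheduler import LambdaWarmUpCosineScheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)   # base lr 1.0, per the class note
schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=100, lr_min=1e-6, lr_max=1e-4, lr_start=1e-6, max_decay_steps=1000
)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

for step in range(5):
    optimizer.step()
    lr_scheduler.step()
print(optimizer.param_groups[0]["lr"])   # rises linearly from lr_start during warm-up
```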
/src/taming-transformers/taming/models/dummy_cond_stage.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 |
3 |
4 | class DummyCondStage:
5 | def __init__(self, conditional_key):
6 | self.conditional_key = conditional_key
7 | self.train = None
8 |
9 | def eval(self):
10 | return self
11 |
12 | @staticmethod
13 | def encode(c: Tensor):
14 | return c, None, (None, None, c)
15 |
16 | @staticmethod
17 | def decode(c: Tensor):
18 | return c
19 |
20 | @staticmethod
21 | def to_rgb(c: Tensor):
22 | return c
23 |
--------------------------------------------------------------------------------
/src/taming-transformers/taming/modules/discriminator/model.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch.nn as nn
3 |
4 |
5 | from taming.modules.util import ActNorm
6 |
7 |
8 | def weights_init(m):
9 | classname = m.__class__.__name__
10 | if classname.find('Conv') != -1:
11 | nn.init.normal_(m.weight.data, 0.0, 0.02)
12 | elif classname.find('BatchNorm') != -1:
13 | nn.init.normal_(m.weight.data, 1.0, 0.02)
14 | nn.init.constant_(m.bias.data, 0)
15 |
16 |
17 | class NLayerDiscriminator(nn.Module):
18 | """Defines a PatchGAN discriminator as in Pix2Pix
19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20 | """
21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22 | """Construct a PatchGAN discriminator
23 | Parameters:
24 | input_nc (int) -- the number of channels in input images
25 | ndf (int) -- the number of filters in the last conv layer
26 | n_layers (int) -- the number of conv layers in the discriminator
27 | norm_layer -- normalization layer
28 | """
29 | super(NLayerDiscriminator, self).__init__()
30 | if not use_actnorm:
31 | norm_layer = nn.BatchNorm2d
32 | else:
33 | norm_layer = ActNorm
34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35 | use_bias = norm_layer.func != nn.BatchNorm2d
36 | else:
37 | use_bias = norm_layer != nn.BatchNorm2d
38 |
39 | kw = 4
40 | padw = 1
41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42 | nf_mult = 1
43 | nf_mult_prev = 1
44 | for n in range(1, n_layers): # gradually increase the number of filters
45 | nf_mult_prev = nf_mult
46 | nf_mult = min(2 ** n, 8)
47 | sequence += [
48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49 | norm_layer(ndf * nf_mult),
50 | nn.LeakyReLU(0.2, True)
51 | ]
52 |
53 | nf_mult_prev = nf_mult
54 | nf_mult = min(2 ** n_layers, 8)
55 | sequence += [
56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57 | norm_layer(ndf * nf_mult),
58 | nn.LeakyReLU(0.2, True)
59 | ]
60 |
61 | sequence += [
62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63 | self.main = nn.Sequential(*sequence)
64 |
65 | def forward(self, input):
66 | """Standard forward."""
67 | return self.main(input)
68 |
--------------------------------------------------------------------------------
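A quick sketch showing that the PatchGAN discriminator above outputs a one-channel map of patch logits rather than a single scalar per image:

```python
# Sketch only: patch-level logits from NLayerDiscriminator.
import torch
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
x = torch.rand(2, 3, 256, 256)
logits = disc(x)
print(logits.shape)   # torch.Size([2, 1, 30, 30]) -- one logit per receptive-field patch
```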
/src/taming-transformers/taming/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from taming.modules.losses.vqperceptual import DummyLoss
2 |
3 |
--------------------------------------------------------------------------------
/src/taming-transformers/taming/modules/losses/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | from collections import namedtuple
7 |
8 | from taming.util import get_ckpt_path
9 |
10 |
11 | class LPIPS(nn.Module):
12 | # Learned perceptual metric
13 | def __init__(self, use_dropout=True):
14 | super().__init__()
15 | self.scaling_layer = ScalingLayer()
16 |         self.chns = [64, 128, 256, 512, 512]  # vgg16 features
17 | self.net = vgg16(pretrained=True, requires_grad=False)
18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
23 | self.load_from_pretrained()
24 | for param in self.parameters():
25 | param.requires_grad = False
26 |
27 | def load_from_pretrained(self, name="vgg_lpips"):
28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
30 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
31 |
32 | @classmethod
33 | def from_pretrained(cls, name="vgg_lpips"):
34 | if name != "vgg_lpips":
35 | raise NotImplementedError
36 | model = cls()
37 | ckpt = get_ckpt_path(name)
38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
39 | return model
40 |
41 | def forward(self, input, target):
42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
43 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
44 | feats0, feats1, diffs = {}, {}, {}
45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
46 | for kk in range(len(self.chns)):
47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
49 |
50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
51 | val = res[0]
52 | for l in range(1, len(self.chns)):
53 | val += res[l]
54 | return val
55 |
56 |
57 | class ScalingLayer(nn.Module):
58 | def __init__(self):
59 | super(ScalingLayer, self).__init__()
60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
62 |
63 | def forward(self, inp):
64 | return (inp - self.shift) / self.scale
65 |
66 |
67 | class NetLinLayer(nn.Module):
68 | """ A single linear layer which does a 1x1 conv """
69 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
70 | super(NetLinLayer, self).__init__()
71 | layers = [nn.Dropout(), ] if (use_dropout) else []
72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
73 | self.model = nn.Sequential(*layers)
74 |
75 |
76 | class vgg16(torch.nn.Module):
77 | def __init__(self, requires_grad=False, pretrained=True):
78 | super(vgg16, self).__init__()
79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
80 | self.slice1 = torch.nn.Sequential()
81 | self.slice2 = torch.nn.Sequential()
82 | self.slice3 = torch.nn.Sequential()
83 | self.slice4 = torch.nn.Sequential()
84 | self.slice5 = torch.nn.Sequential()
85 | self.N_slices = 5
86 | for x in range(4):
87 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
88 | for x in range(4, 9):
89 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
90 | for x in range(9, 16):
91 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
92 | for x in range(16, 23):
93 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
94 | for x in range(23, 30):
95 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
96 | if not requires_grad:
97 | for param in self.parameters():
98 | param.requires_grad = False
99 |
100 | def forward(self, X):
101 | h = self.slice1(X)
102 | h_relu1_2 = h
103 | h = self.slice2(h)
104 | h_relu2_2 = h
105 | h = self.slice3(h)
106 | h_relu3_3 = h
107 | h = self.slice4(h)
108 | h_relu4_3 = h
109 | h = self.slice5(h)
110 | h_relu5_3 = h
111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
113 | return out
114 |
115 |
116 | def normalize_tensor(x,eps=1e-10):
117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
118 | return x/(norm_factor+eps)
119 |
120 |
121 | def spatial_average(x, keepdim=True):
122 | return x.mean([2,3],keepdim=keepdim)
123 |
124 |
--------------------------------------------------------------------------------
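A sketch of calling the LPIPS metric above on two images in [-1, 1]; note that constructing it downloads torchvision's VGG16 weights and the `vgg_lpips` linear-layer checkpoint on first use.

```python
# Sketch only: perceptual distance between two random images (needs network access once).
import torch
from taming.modules.losses.lpips import LPIPS

lpips_fn = LPIPS().eval()
a = torch.rand(1, 3, 256, 256) * 2 - 1
b = torch.rand(1, 3, 256, 256) * 2 - 1
with torch.no_grad():
    d = lpips_fn(a, b)
print(d.shape, d.item())   # shape (1, 1, 1, 1); identical inputs give ~0
```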
/src/taming-transformers/taming/modules/losses/segmentation.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class BCELoss(nn.Module):
6 | def forward(self, prediction, target):
7 | loss = F.binary_cross_entropy_with_logits(prediction,target)
8 | return loss, {}
9 |
10 |
11 | class BCELossWithQuant(nn.Module):
12 | def __init__(self, codebook_weight=1.):
13 | super().__init__()
14 | self.codebook_weight = codebook_weight
15 |
16 | def forward(self, qloss, target, prediction, split):
17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
18 | loss = bce_loss + self.codebook_weight*qloss
19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
20 | "{}/bce_loss".format(split): bce_loss.detach().mean(),
21 | "{}/quant_loss".format(split): qloss.detach().mean()
22 | }
23 |
--------------------------------------------------------------------------------
/src/taming-transformers/taming/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from taming.modules.losses.lpips import LPIPS
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 |
8 |
9 | class DummyLoss(nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 |
13 |
14 | def adopt_weight(weight, global_step, threshold=0, value=0.):
15 | if global_step < threshold:
16 | weight = value
17 | return weight
18 |
19 |
20 | def hinge_d_loss(logits_real, logits_fake):
21 | loss_real = torch.mean(F.relu(1. - logits_real))
22 | loss_fake = torch.mean(F.relu(1. + logits_fake))
23 | d_loss = 0.5 * (loss_real + loss_fake)
24 | return d_loss
25 |
26 |
27 | def vanilla_d_loss(logits_real, logits_fake):
28 | d_loss = 0.5 * (
29 | torch.mean(torch.nn.functional.softplus(-logits_real)) +
30 | torch.mean(torch.nn.functional.softplus(logits_fake)))
31 | return d_loss
32 |
33 |
34 | class VQLPIPSWithDiscriminator(nn.Module):
35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
38 | disc_ndf=64, disc_loss="hinge"):
39 | super().__init__()
40 | assert disc_loss in ["hinge", "vanilla"]
41 | self.codebook_weight = codebook_weight
42 | self.pixel_weight = pixelloss_weight
43 | self.perceptual_loss = LPIPS().eval()
44 | self.perceptual_weight = perceptual_weight
45 |
46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
47 | n_layers=disc_num_layers,
48 | use_actnorm=use_actnorm,
49 | ndf=disc_ndf
50 | ).apply(weights_init)
51 | self.discriminator_iter_start = disc_start
52 | if disc_loss == "hinge":
53 | self.disc_loss = hinge_d_loss
54 | elif disc_loss == "vanilla":
55 | self.disc_loss = vanilla_d_loss
56 | else:
57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
59 | self.disc_factor = disc_factor
60 | self.discriminator_weight = disc_weight
61 | self.disc_conditional = disc_conditional
62 |
63 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
64 | if last_layer is not None:
65 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
66 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
67 | else:
68 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
69 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
70 |
71 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
72 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
73 | d_weight = d_weight * self.discriminator_weight
74 | return d_weight
75 |
76 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
77 | global_step, last_layer=None, cond=None, split="train"):
78 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
79 | if self.perceptual_weight > 0:
80 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
81 | rec_loss = rec_loss + self.perceptual_weight * p_loss
82 | else:
83 | p_loss = torch.tensor([0.0])
84 |
85 | nll_loss = rec_loss
86 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
87 | nll_loss = torch.mean(nll_loss)
88 |
89 | # now the GAN part
90 | if optimizer_idx == 0:
91 | # generator update
92 | if cond is None:
93 | assert not self.disc_conditional
94 | logits_fake = self.discriminator(reconstructions.contiguous())
95 | else:
96 | assert self.disc_conditional
97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
98 | g_loss = -torch.mean(logits_fake)
99 |
100 | try:
101 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
102 | except RuntimeError:
103 | assert not self.training
104 | d_weight = torch.tensor(0.0)
105 |
106 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
107 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
108 |
109 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
110 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
111 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
112 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
113 | "{}/p_loss".format(split): p_loss.detach().mean(),
114 | "{}/d_weight".format(split): d_weight.detach(),
115 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
116 | "{}/g_loss".format(split): g_loss.detach().mean(),
117 | }
118 | return loss, log
119 |
120 | if optimizer_idx == 1:
121 | # second pass for discriminator update
122 | if cond is None:
123 | logits_real = self.discriminator(inputs.contiguous().detach())
124 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
125 | else:
126 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
127 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
128 |
129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
130 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
131 |
132 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
133 | "{}/logits_real".format(split): logits_real.detach().mean(),
134 | "{}/logits_fake".format(split): logits_fake.detach().mean()
135 | }
136 | return d_loss, log
137 |
--------------------------------------------------------------------------------
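A sketch of the two optimizer passes of `VQLPIPSWithDiscriminator`. A toy convolution stands in for the VQ decoder so the adaptive generator weight can be computed against its last layer; `disc_start=0` turns the GAN term on immediately, whereas real configs delay it. Instantiating the loss pulls the pretrained LPIPS weights on first use.

```python
# Sketch only: generator pass (optimizer_idx=0) and discriminator pass (optimizer_idx=1).
import torch
from taming.modules.losses.vqperceptual import VQLPIPSWithDiscriminator

decoder = torch.nn.Conv2d(3, 3, 3, padding=1)      # toy stand-in for the VQ decoder's last layer
loss_fn = VQLPIPSWithDiscriminator(disc_start=0)

inputs = torch.rand(2, 3, 64, 64) * 2 - 1
recons = decoder(inputs)                           # "reconstructions" with a gradient path
codebook_loss = torch.tensor(0.1)

g_total, g_log = loss_fn(codebook_loss, inputs, recons, optimizer_idx=0,
                         global_step=0, last_layer=decoder.weight)
d_total, d_log = loss_fn(codebook_loss, inputs, recons, optimizer_idx=1,
                         global_step=0)
print(sorted(g_log), sorted(d_log))
```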
/src/taming-transformers/taming/modules/misc/coord.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class CoordStage(object):
4 | def __init__(self, n_embed, down_factor):
5 | self.n_embed = n_embed
6 | self.down_factor = down_factor
7 |
8 | def eval(self):
9 | return self
10 |
11 | def encode(self, c):
12 | """fake vqmodel interface"""
13 | assert 0.0 <= c.min() and c.max() <= 1.0
14 | b,ch,h,w = c.shape
15 | assert ch == 1
16 |
17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
18 | mode="area")
19 | c = c.clamp(0.0, 1.0)
20 | c = self.n_embed*c
21 | c_quant = c.round()
22 | c_ind = c_quant.to(dtype=torch.long)
23 |
24 | info = None, None, c_ind
25 | return c_quant, None, info
26 |
27 | def decode(self, c):
28 | c = c/self.n_embed
29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
30 | mode="nearest")
31 | return c
32 |
--------------------------------------------------------------------------------
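A sketch of the fake VQ interface above: `encode` bins a [0, 1] coordinate map into `n_embed` integer levels at reduced resolution, and `decode` upsamples it back.

```python
# Sketch only: round-tripping a coordinate map through CoordStage.
import torch
from taming.modules.misc.coord import CoordStage

stage = CoordStage(n_embed=1024, down_factor=16)
coords = torch.linspace(0, 1, 256).view(1, 1, 1, 256).expand(1, 1, 256, 256).contiguous()
c_quant, _, (_, _, c_ind) = stage.encode(coords)
print(c_quant.shape, c_ind.dtype)      # torch.Size([1, 1, 16, 16]), torch.int64
print(stage.decode(c_quant).shape)     # back to torch.Size([1, 1, 256, 256])
```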
/src/taming-transformers/taming/modules/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def count_params(model):
6 | total_params = sum(p.numel() for p in model.parameters())
7 | return total_params
8 |
9 |
10 | class ActNorm(nn.Module):
11 | def __init__(self, num_features, logdet=False, affine=True,
12 | allow_reverse_init=False):
13 | assert affine
14 | super().__init__()
15 | self.logdet = logdet
16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18 | self.allow_reverse_init = allow_reverse_init
19 |
20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21 |
22 | def initialize(self, input):
23 | with torch.no_grad():
24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25 | mean = (
26 | flatten.mean(1)
27 | .unsqueeze(1)
28 | .unsqueeze(2)
29 | .unsqueeze(3)
30 | .permute(1, 0, 2, 3)
31 | )
32 | std = (
33 | flatten.std(1)
34 | .unsqueeze(1)
35 | .unsqueeze(2)
36 | .unsqueeze(3)
37 | .permute(1, 0, 2, 3)
38 | )
39 |
40 | self.loc.data.copy_(-mean)
41 | self.scale.data.copy_(1 / (std + 1e-6))
42 |
43 | def forward(self, input, reverse=False):
44 | if reverse:
45 | return self.reverse(input)
46 | if len(input.shape) == 2:
47 | input = input[:,:,None,None]
48 | squeeze = True
49 | else:
50 | squeeze = False
51 |
52 | _, _, height, width = input.shape
53 |
54 | if self.training and self.initialized.item() == 0:
55 | self.initialize(input)
56 | self.initialized.fill_(1)
57 |
58 | h = self.scale * (input + self.loc)
59 |
60 | if squeeze:
61 | h = h.squeeze(-1).squeeze(-1)
62 |
63 | if self.logdet:
64 | log_abs = torch.log(torch.abs(self.scale))
65 | logdet = height*width*torch.sum(log_abs)
66 | logdet = logdet * torch.ones(input.shape[0]).to(input)
67 | return h, logdet
68 |
69 | return h
70 |
71 | def reverse(self, output):
72 | if self.training and self.initialized.item() == 0:
73 | if not self.allow_reverse_init:
74 | raise RuntimeError(
75 | "Initializing ActNorm in reverse direction is "
76 | "disabled by default. Use allow_reverse_init=True to enable."
77 | )
78 | else:
79 | self.initialize(output)
80 | self.initialized.fill_(1)
81 |
82 | if len(output.shape) == 2:
83 | output = output[:,:,None,None]
84 | squeeze = True
85 | else:
86 | squeeze = False
87 |
88 | h = output / self.scale - self.loc
89 |
90 | if squeeze:
91 | h = h.squeeze(-1).squeeze(-1)
92 | return h
93 |
94 |
95 | class AbstractEncoder(nn.Module):
96 | def __init__(self):
97 | super().__init__()
98 |
99 | def encode(self, *args, **kwargs):
100 | raise NotImplementedError
101 |
102 |
103 | class Labelator(AbstractEncoder):
104 | """Net2Net Interface for Class-Conditional Model"""
105 | def __init__(self, n_classes, quantize_interface=True):
106 | super().__init__()
107 | self.n_classes = n_classes
108 | self.quantize_interface = quantize_interface
109 |
110 | def encode(self, c):
111 | c = c[:,None]
112 | if self.quantize_interface:
113 | return c, None, [None, None, c.long()]
114 | return c
115 |
116 |
117 | class SOSProvider(AbstractEncoder):
118 | # for unconditional training
119 | def __init__(self, sos_token, quantize_interface=True):
120 | super().__init__()
121 | self.sos_token = sos_token
122 | self.quantize_interface = quantize_interface
123 |
124 | def encode(self, x):
125 | # get batch size from data and replicate sos_token
126 | c = torch.ones(x.shape[0], 1)*self.sos_token
127 | c = c.long().to(x.device)
128 | if self.quantize_interface:
129 | return c, None, [None, None, c]
130 | return c
131 |
--------------------------------------------------------------------------------
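A sketch of ActNorm's data-dependent initialization: the first forward pass in training mode sets `loc`/`scale` from the batch statistics, after which outputs are roughly zero-mean and unit-variance per channel.

```python
# Sketch only: the first forward call triggers the data-dependent init.
import torch
from taming.modules.util import ActNorm

layer = ActNorm(num_features=8, logdet=True)
x = torch.randn(16, 8, 4, 4) * 3 + 1               # arbitrary input statistics
y, logdet = layer(x)                               # init happens here (training mode)
print(round(y.mean().item(), 3), round(y.std().item(), 3))  # ~0.0 and ~1.0
print(layer.initialized.item(), logdet.shape)      # 1, per-sample logdet of shape (16,)
```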
/src/taming-transformers/taming/modules/vqvae/__pycache__/quantize.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/src/taming-transformers/taming/modules/vqvae/__pycache__/quantize.cpython-38.pyc
--------------------------------------------------------------------------------
/src/taming-transformers/taming/util.py:
--------------------------------------------------------------------------------
1 | import os, hashlib
2 | import requests
3 | from tqdm import tqdm
4 |
5 | URL_MAP = {
6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
7 | }
8 |
9 | CKPT_MAP = {
10 | "vgg_lpips": "vgg.pth"
11 | }
12 |
13 | MD5_MAP = {
14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
15 | }
16 |
17 |
18 | def download(url, local_path, chunk_size=1024):
19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
20 | with requests.get(url, stream=True) as r:
21 | total_size = int(r.headers.get("content-length", 0))
22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
23 | with open(local_path, "wb") as f:
24 | for data in r.iter_content(chunk_size=chunk_size):
25 | if data:
26 | f.write(data)
27 | pbar.update(chunk_size)
28 |
29 |
30 | def md5_hash(path):
31 | with open(path, "rb") as f:
32 | content = f.read()
33 | return hashlib.md5(content).hexdigest()
34 |
35 |
36 | def get_ckpt_path(name, root, check=False):
37 | assert name in URL_MAP
38 | path = os.path.join(root, CKPT_MAP[name])
39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
41 | download(URL_MAP[name], path)
42 | md5 = md5_hash(path)
43 | assert md5 == MD5_MAP[name], md5
44 | return path
45 |
46 |
47 | class KeyNotFoundError(Exception):
48 | def __init__(self, cause, keys=None, visited=None):
49 | self.cause = cause
50 | self.keys = keys
51 | self.visited = visited
52 | messages = list()
53 | if keys is not None:
54 | messages.append("Key not found: {}".format(keys))
55 | if visited is not None:
56 | messages.append("Visited: {}".format(visited))
57 | messages.append("Cause:\n{}".format(cause))
58 | message = "\n".join(messages)
59 | super().__init__(message)
60 |
61 |
62 | def retrieve(
63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
64 | ):
65 | """Given a nested list or dict return the desired value at key expanding
66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion
67 | is done in-place.
68 |
69 | Parameters
70 | ----------
71 | list_or_dict : list or dict
72 | Possibly nested list or dictionary.
73 | key : str
74 | key/to/value, path like string describing all keys necessary to
75 | consider to get to the desired value. List indices can also be
76 | passed here.
77 | splitval : str
78 | String that defines the delimiter between keys of the
79 | different depth levels in `key`.
80 | default : obj
81 | Value returned if :attr:`key` is not found.
82 | expand : bool
83 | Whether to expand callable nodes on the path or not.
84 |
85 | Returns
86 | -------
87 | The desired value or if :attr:`default` is not ``None`` and the
88 | :attr:`key` is not found returns ``default``.
89 |
90 | Raises
91 | ------
92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
93 | ``None``.
94 | """
95 |
96 | keys = key.split(splitval)
97 |
98 | success = True
99 | try:
100 | visited = []
101 | parent = None
102 | last_key = None
103 | for key in keys:
104 | if callable(list_or_dict):
105 | if not expand:
106 | raise KeyNotFoundError(
107 | ValueError(
108 | "Trying to get past callable node with expand=False."
109 | ),
110 | keys=keys,
111 | visited=visited,
112 | )
113 | list_or_dict = list_or_dict()
114 | parent[last_key] = list_or_dict
115 |
116 | last_key = key
117 | parent = list_or_dict
118 |
119 | try:
120 | if isinstance(list_or_dict, dict):
121 | list_or_dict = list_or_dict[key]
122 | else:
123 | list_or_dict = list_or_dict[int(key)]
124 | except (KeyError, IndexError, ValueError) as e:
125 | raise KeyNotFoundError(e, keys=keys, visited=visited)
126 |
127 | visited += [key]
128 | # final expansion of retrieved value
129 | if expand and callable(list_or_dict):
130 | list_or_dict = list_or_dict()
131 | parent[last_key] = list_or_dict
132 | except KeyNotFoundError as e:
133 | if default is None:
134 | raise e
135 | else:
136 | list_or_dict = default
137 | success = False
138 |
139 | if not pass_success:
140 | return list_or_dict
141 | else:
142 | return list_or_dict, success
143 |
144 |
145 | if __name__ == "__main__":
146 | config = {"keya": "a",
147 | "keyb": "b",
148 | "keyc":
149 | {"cc1": 1,
150 | "cc2": 2,
151 | }
152 | }
153 | from omegaconf import OmegaConf
154 | config = OmegaConf.create(config)
155 | print(config)
156 | retrieve(config, "keya")
157 |
158 |
--------------------------------------------------------------------------------
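The `__main__` block above only exercises a flat key; a slightly fuller sketch of `retrieve` covering nested keys, list indices, callable expansion, and the default fallback:

```python
# Sketch only: path-style lookups with retrieve().
from taming.util import retrieve

cfg = {"model": {"params": [{"lr": 1e-4}]}, "name": lambda: "expanded"}
print(retrieve(cfg, "model/params/0/lr"))           # 0.0001 -- list index parsed from the path
print(retrieve(cfg, "name"))                        # "expanded" -- callable leaf is expanded in place
print(retrieve(cfg, "missing/key", default=0.0))    # 0.0 -- default instead of KeyNotFoundError
```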
/src/taming-transformers/taming_transformers.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: taming-transformers
3 | Version: 0.0.1
4 | Summary: Taming Transformers for High-Resolution Image Synthesis
5 |
--------------------------------------------------------------------------------
/util/__pycache__/fastmri_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/util/__pycache__/fastmri_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/util/__pycache__/img_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/util/__pycache__/img_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/util/__pycache__/resizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/util/__pycache__/resizer.cpython-38.pyc
--------------------------------------------------------------------------------
/util/compute_metric.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from skimage.metrics import peak_signal_noise_ratio
3 | from tqdm import tqdm
4 |
5 | import matplotlib.pyplot as plt
6 | import lpips
7 | import numpy as np
8 | import torch
9 |
10 |
11 | device = 'cuda:0'
12 | loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)
13 |
14 | task = 'SR'
15 | factor = 4
16 | sigma = 0.1
17 | scale = 1.0
18 |
19 |
20 | label_root = Path(f'/media/harry/tomo/FFHQ/256_1000')
21 |
22 | delta_recon_root = Path(f'./results/{task}/ffhq/{factor}/{sigma}/ps/{scale}/recon')
23 | normal_recon_root = Path(f'./results/{task}/ffhq/{factor}/{sigma}/ps+/{scale}/recon')
24 |
25 | psnr_delta_list = []
26 | psnr_normal_list = []
27 |
28 | lpips_delta_list = []
29 | lpips_normal_list = []
30 | for idx in tqdm(range(150)):
31 | fname = str(idx).zfill(5)
32 |
33 | label = plt.imread(label_root / f'{fname}.png')[:, :, :3]
34 | delta_recon = plt.imread(delta_recon_root / f'{fname}.png')[:, :, :3]
35 | normal_recon = plt.imread(normal_recon_root / f'{fname}.png')[:, :, :3]
36 |
37 | psnr_delta = peak_signal_noise_ratio(label, delta_recon)
38 | psnr_normal = peak_signal_noise_ratio(label, normal_recon)
39 |
40 | psnr_delta_list.append(psnr_delta)
41 | psnr_normal_list.append(psnr_normal)
42 |
43 | delta_recon = torch.from_numpy(delta_recon).permute(2, 0, 1).to(device)
44 | normal_recon = torch.from_numpy(normal_recon).permute(2, 0, 1).to(device)
45 | label = torch.from_numpy(label).permute(2, 0, 1).to(device)
46 |
47 | delta_recon = delta_recon.view(1, 3, 256, 256) * 2. - 1.
48 | normal_recon = normal_recon.view(1, 3, 256, 256) * 2. - 1.
49 | label = label.view(1, 3, 256, 256) * 2. - 1.
50 |
51 | delta_d = loss_fn_vgg(delta_recon, label)
52 | normal_d = loss_fn_vgg(normal_recon, label)
53 |
54 | lpips_delta_list.append(delta_d)
55 | lpips_normal_list.append(normal_d)
56 |
57 | psnr_delta_avg = sum(psnr_delta_list) / len(psnr_delta_list)
58 | lpips_delta_avg = sum(lpips_delta_list) / len(lpips_delta_list)
59 |
60 | psnr_normal_avg = sum(psnr_normal_list) / len(psnr_normal_list)
61 | lpips_normal_avg = sum(lpips_normal_list) / len(lpips_normal_list)
62 |
63 | print(f'Delta PSNR: {psnr_delta_avg}')
64 | print(f'Delta LPIPS: {lpips_delta_avg}')
65 |
66 | print(f'Normal PSNR: {psnr_normal_avg}')
67 | print(f'Normal LPIPS: {lpips_normal_avg}')
--------------------------------------------------------------------------------
/util/fastmri_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 | This source code is licensed under the MIT license found in the
4 | LICENSE file in the root directory of this source tree.
5 | """
6 |
7 | from typing import List, Optional
8 |
9 | import torch
10 | from packaging import version
11 |
12 | if version.parse(torch.__version__) >= version.parse("1.7.0"):
13 | import torch.fft # type: ignore
14 |
15 |
16 | def fft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
17 | """
18 | Apply centered 2 dimensional Fast Fourier Transform.
19 | Args:
20 | data: Complex valued input data containing at least 3 dimensions:
21 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size
22 | 2. All other dimensions are assumed to be batch dimensions.
23 | norm: Whether to include normalization. Must be one of ``"backward"``
24 | or ``"ortho"``. See ``torch.fft.fft`` on PyTorch 1.9.0 for details.
25 | Returns:
26 | The FFT of the input.
27 | """
28 | if not data.shape[-1] == 2:
29 | raise ValueError("Tensor does not have separate complex dim.")
30 | if norm not in ("ortho", "backward"):
31 | raise ValueError("norm must be 'ortho' or 'backward'.")
32 | normalized = True if norm == "ortho" else False
33 |
34 | data = ifftshift(data, dim=[-3, -2])
35 | data = torch.fft(data, 2, normalized=normalized)
36 | data = fftshift(data, dim=[-3, -2])
37 |
38 | return data
39 |
40 |
41 | def ifft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
42 | """
43 | Apply centered 2-dimensional Inverse Fast Fourier Transform.
44 | Args:
45 | data: Complex valued input data containing at least 3 dimensions:
46 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size
47 | 2. All other dimensions are assumed to be batch dimensions.
48 | norm: Whether to include normalization. Must be one of ``"backward"``
49 | or ``"ortho"``. See ``torch.fft.ifft`` on PyTorch 1.9.0 for
50 | details.
51 | Returns:
52 | The IFFT of the input.
53 | """
54 | if not data.shape[-1] == 2:
55 | raise ValueError("Tensor does not have separate complex dim.")
56 | if norm not in ("ortho", "backward"):
57 | raise ValueError("norm must be 'ortho' or 'backward'.")
58 | normalized = True if norm == "ortho" else False
59 |
60 | data = ifftshift(data, dim=[-3, -2])
61 | data = torch.ifft(data, 2, normalized=normalized)
62 | data = fftshift(data, dim=[-3, -2])
63 |
64 | return data
65 |
66 |
67 | def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
68 | """
69 | Apply centered 2 dimensional Fast Fourier Transform.
70 | Args:
71 | data: Complex valued input data containing at least 3 dimensions:
72 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size
73 | 2. All other dimensions are assumed to be batch dimensions.
74 | norm: Normalization mode. See ``torch.fft.fft``.
75 | Returns:
76 | The FFT of the input.
77 | """
78 | if not data.shape[-1] == 2:
79 | raise ValueError("Tensor does not have separate complex dim.")
80 |
81 | data = ifftshift(data, dim=[-3, -2])
82 | data = torch.view_as_real(
83 | torch.fft.fftn( # type: ignore
84 | torch.view_as_complex(data), dim=(-2, -1), norm=norm
85 | )
86 | )
87 | data = fftshift(data, dim=[-3, -2])
88 |
89 | return data
90 |
91 |
92 | def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
93 | """
94 | Apply centered 2-dimensional Inverse Fast Fourier Transform.
95 | Args:
96 | data: Complex valued input data containing at least 3 dimensions:
97 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size
98 | 2. All other dimensions are assumed to be batch dimensions.
99 | norm: Normalization mode. See ``torch.fft.ifft``.
100 | Returns:
101 | The IFFT of the input.
102 | """
103 | if not data.shape[-1] == 2:
104 | raise ValueError("Tensor does not have separate complex dim.")
105 |
106 | data = ifftshift(data, dim=[-3, -2])
107 | data = torch.view_as_real(
108 | torch.fft.ifftn( # type: ignore
109 | torch.view_as_complex(data), dim=(-2, -1), norm=norm
110 | )
111 | )
112 | data = fftshift(data, dim=[-3, -2])
113 |
114 | return data
115 |
116 |
117 | # Helper functions
118 |
119 |
120 | def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
121 | """
122 | Similar to roll but for only one dim.
123 | Args:
124 | x: A PyTorch tensor.
125 | shift: Amount to roll.
126 | dim: Which dimension to roll.
127 | Returns:
128 | Rolled version of x.
129 | """
130 | shift = shift % x.size(dim)
131 | if shift == 0:
132 | return x
133 |
134 | left = x.narrow(dim, 0, x.size(dim) - shift)
135 | right = x.narrow(dim, x.size(dim) - shift, shift)
136 |
137 | return torch.cat((right, left), dim=dim)
138 |
139 |
140 | def roll(
141 | x: torch.Tensor,
142 | shift: List[int],
143 | dim: List[int],
144 | ) -> torch.Tensor:
145 | """
146 | Similar to np.roll but applies to PyTorch Tensors.
147 | Args:
148 | x: A PyTorch tensor.
149 | shift: Amount to roll.
150 | dim: Which dimension to roll.
151 | Returns:
152 | Rolled version of x.
153 | """
154 | if len(shift) != len(dim):
155 | raise ValueError("len(shift) must match len(dim)")
156 |
157 | for (s, d) in zip(shift, dim):
158 | x = roll_one_dim(x, s, d)
159 |
160 | return x
161 |
162 |
163 | def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
164 | """
165 | Similar to np.fft.fftshift but applies to PyTorch Tensors
166 | Args:
167 | x: A PyTorch tensor.
168 | dim: Which dimension to fftshift.
169 | Returns:
170 | fftshifted version of x.
171 | """
172 | if dim is None:
173 |         # this weird code is necessary for torch.jit.script typing
174 | dim = [0] * (x.dim())
175 | for i in range(1, x.dim()):
176 | dim[i] = i
177 |
178 | # also necessary for torch.jit.script
179 | shift = [0] * len(dim)
180 | for i, dim_num in enumerate(dim):
181 | shift[i] = x.shape[dim_num] // 2
182 |
183 | return roll(x, shift, dim)
184 |
185 |
186 | def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
187 | """
188 | Similar to np.fft.ifftshift but applies to PyTorch Tensors
189 | Args:
190 | x: A PyTorch tensor.
191 | dim: Which dimension to ifftshift.
192 | Returns:
193 | ifftshifted version of x.
194 | """
195 | if dim is None:
196 | # this weird code is necessary for toch.jit.script typing
197 | dim = [0] * (x.dim())
198 | for i in range(1, x.dim()):
199 | dim[i] = i
200 |
201 | # also necessary for torch.jit.script
202 | shift = [0] * len(dim)
203 | for i, dim_num in enumerate(dim):
204 | shift[i] = (x.shape[dim_num] + 1) // 2
205 |
206 | return roll(x, shift, dim)
--------------------------------------------------------------------------------
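A sketch of the centered FFT helpers above: they operate on real tensors whose trailing dimension of size 2 holds the real/imaginary parts, and `ifft2c_new` inverts `fft2c_new`. The import assumes the repository root is on `sys.path`.

```python
# Sketch only: FFT/IFFT round trip with the "last dim is (real, imag)" convention.
import torch
from util.fastmri_utils import fft2c_new, ifft2c_new

x = torch.randn(1, 64, 64, 2)                # one 64x64 complex-valued image
k = fft2c_new(x)                             # centered, ortho-normalized 2-D FFT
x_rec = ifft2c_new(k)
print(torch.allclose(x, x_rec, atol=1e-5))   # True: the transforms invert each other
```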
/util/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | def get_logger():
4 | logger = logging.getLogger(name='DPS')
5 | logger.setLevel(logging.INFO)
6 |
7 | formatter = logging.Formatter("%(asctime)s [%(name)s] >> %(message)s")
8 | stream_handler = logging.StreamHandler()
9 | stream_handler.setFormatter(formatter)
10 | logger.addHandler(stream_handler)
11 |
12 | return logger
--------------------------------------------------------------------------------
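A sketch of the logger above; the import again assumes the repository root is on `sys.path`. Note that each call to `get_logger` attaches another stream handler to the same `DPS` logger, so it is meant to be called once per process.

```python
# Sketch only: timestamped INFO logging to stderr via the 'DPS' logger.
from util.logger import get_logger

logger = get_logger()
logger.info("loading model...")   # e.g. "2024-01-01 12:00:00,000 [DPS] >> loading model..."
```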