├── README.md ├── __init__.py ├── configs ├── autoencoder │ ├── autoencoder_kl_16x16x16.yaml │ ├── autoencoder_kl_32x32x4.yaml │ ├── autoencoder_kl_64x64x3.yaml │ └── autoencoder_kl_8x8x64.yaml ├── latent-diffusion │ ├── celebahq-ldm-vq-4.yaml │ ├── cin-ldm-vq-f8.yaml │ ├── cin256-v2.yaml │ ├── ffhq-ldm-vq-4.yaml │ ├── lsun_churches-ldm-kl-8.yaml │ └── txt2img-1p4B-eval.yaml ├── retrieval-augmented-diffusion │ └── 768x768.yaml ├── stable-diffusion │ └── v1-inference.yaml └── tasks │ ├── box_inpainting_config.yaml │ ├── gaussian_deblur_config.yaml │ ├── hdr_config.yaml │ ├── motion_deblur_config.yaml │ ├── nonlinear_deblur_config.yaml │ ├── phase_retrieval_config.yaml │ ├── rand_inpainting_config.yaml │ └── super_resolution_config.yaml ├── data ├── __pycache__ │ └── dataloader.cpython-38.pyc ├── dataloader.py └── samples │ ├── 00000.png │ └── 00003.png ├── diffstategrad_sample_condition.py ├── diffstategrad_utils.py ├── environment.yaml ├── figures ├── hdr_example.png ├── hdr_short.png ├── manifold_diffstategrad.pdf ├── manifold_diffstategrad.png └── phase_example.png ├── ldm ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── util.cpython-38.pyc ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── lsun.cpython-38.pyc │ ├── base.py │ ├── imagenet.py │ └── lsun.py ├── lr_scheduler.py ├── models │ ├── __pycache__ │ │ └── autoencoder.cpython-38.pyc │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── ddim.cpython-38.pyc │ │ ├── ddpm.cpython-38.pyc │ │ └── plms.cpython-38.pyc │ │ ├── classifier.py │ │ ├── ddpm.py │ │ ├── diffstategrad_ddim.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── dpm_solver.cpython-38.pyc │ │ │ └── sampler.cpython-38.pyc │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ ├── plms.py │ │ └── utils.py ├── modules │ ├── __pycache__ │ │ ├── attention.cpython-38.pyc │ │ └── ema.cpython-38.pyc │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── model.cpython-38.pyc │ │ │ ├── openaimodel.cpython-38.pyc │ │ │ └── util.cpython-38.pyc │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── distributions.cpython-38.pyc │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py └── util.py ├── ldm_inverse ├── __pycache__ │ ├── condition_methods.cpython-38.pyc │ └── measurements.cpython-38.pyc ├── condition_methods.py └── measurements.py ├── model_loader.py ├── samples ├── 60004.png └── 60074.png ├── scripts ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── utils.cpython-38.pyc ├── download_first_stages.sh ├── download_models.sh ├── img2img.py ├── inpaint.py ├── inv.py ├── knn2img.py ├── sample_diffusion.py ├── tests │ └── test_watermark.py ├── train_searcher.py ├── txt2img.py └── utils.py ├── src ├── clip │ ├── clip.egg-info │ │ └── PKG-INFO │ ├── clip │ │ ├── __init__.py │ │ ├── clip.py │ │ ├── model.py │ │ └── simple_tokenizer.py │ ├── hubconf.py │ ├── setup.py │ └── tests │ │ └── test_consistency.py └── taming-transformers │ ├── main.py │ ├── scripts 
│ ├── extract_depth.py │ ├── extract_segmentation.py │ ├── extract_submodel.py │ ├── make_samples.py │ ├── make_scene_samples.py │ ├── sample_conditional.py │ └── sample_fast.py │ ├── setup.py │ ├── taming │ ├── data │ │ ├── ade20k.py │ │ ├── annotated_objects_coco.py │ │ ├── annotated_objects_dataset.py │ │ ├── annotated_objects_open_images.py │ │ ├── base.py │ │ ├── coco.py │ │ ├── conditional_builder │ │ │ ├── objects_bbox.py │ │ │ ├── objects_center_points.py │ │ │ └── utils.py │ │ ├── custom.py │ │ ├── faceshq.py │ │ ├── helper_types.py │ │ ├── image_transforms.py │ │ ├── imagenet.py │ │ ├── open_images_helper.py │ │ ├── sflckr.py │ │ └── utils.py │ ├── lr_scheduler.py │ ├── models │ │ ├── cond_transformer.py │ │ ├── dummy_cond_stage.py │ │ └── vqgan.py │ ├── modules │ │ ├── diffusionmodules │ │ │ └── model.py │ │ ├── discriminator │ │ │ └── model.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── lpips.py │ │ │ ├── segmentation.py │ │ │ └── vqperceptual.py │ │ ├── misc │ │ │ └── coord.py │ │ ├── transformer │ │ │ ├── mingpt.py │ │ │ └── permuter.py │ │ ├── util.py │ │ └── vqvae │ │ │ ├── __pycache__ │ │ │ └── quantize.cpython-38.pyc │ │ │ └── quantize.py │ └── util.py │ └── taming_transformers.egg-info │ └── PKG-INFO └── util ├── __pycache__ ├── fastmri_utils.cpython-38.pyc ├── img_utils.cpython-38.pyc └── resizer.cpython-38.pyc ├── compute_metric.py ├── fastmri_utils.py ├── img_utils.py ├── logger.py ├── resizer.py └── tools.py /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion-State Guided Projected Gradient for Inverse Problems (ICLR 2025) 2 | 3 | ![example](https://github.com/rzirvi1665/DiffStateGrad/blob/main/figures/phase_example.png) 4 | 5 | ![example](https://github.com/rzirvi1665/DiffStateGrad/blob/main/figures/hdr_example.png) 6 | 7 | ## Abstract 8 | 9 | In this work, we propose DiffStateGrad, a novel approach that enhances diffusion-based inverse problem solvers by projecting measurement guidance gradients onto a data-driven low-rank subspace defined by intermediate diffusion states. Our algorithm addresses the challenge of maintaining manifold consistency by performing singular value decomposition on intermediate diffusion states to define a projection matrix that captures local data statistics. This projection ensures that measurement guidance remains aligned with the learned data manifold while filtering out artifact-inducing components, leading to improved robustness and performance across various inverse problems. In this repository, we demonstrate the effectiveness of DiffStateGrad by applying it to ReSample's framework. 10 | 11 | ![example](https://github.com/rzirvi1665/DiffStateGrad/blob/main/figures/manifold_diffstategrad.png) 12 | 13 | ## Implementation 14 | 15 | This repository provides a modified version of the ReSample codebase that incorporates our DiffStateGrad method. The implementation maintains the core functionality of ReSample while adding our enhancements for improved performance and stability. 16 | 17 | Our main contributions can be found in `diffstategrad_sample_condition.py` and `ldm/models/diffusion/diffstategrad_ddim.py`. 18 | 19 | ### DiffStateGrad Helper Methods 20 | 21 | The core utilities of DiffStateGrad are implemented in `diffstategrad_utils.py`. These utilities can be applied to other works as a plug-and-play module. The implementation includes three main functions: 22 | 23 | 1. 
`compute_rank_for_explained_variance`: Determines the rank needed to explain a target variance percentage across channels 24 | 2. `compute_svd_and_adaptive_rank`: Performs SVD on diffusion state and computes adaptive rank based on variance cutoff 25 | 3. `apply_diffstategrad`: Computes the projected gradient using our DiffStateGrad algorithm 26 | 27 | ### Example Usage 28 | 29 | ```python 30 | from diffstategrad_utils import compute_svd_and_adaptive_rank, apply_diffstategrad 31 | 32 | # During optimization: 33 | if iteration_count % period == 0: 34 | # Compute SVD and adaptive rank when needed 35 | U, s, Vh, adaptive_rank = compute_svd_and_adaptive_rank(z_t, var_cutoff=0.99) 36 | 37 | # Apply DiffStateGrad to the normalized gradient 38 | projected_grad = apply_diffstategrad( 39 | norm_grad=normalized_gradient, 40 | iteration_count=iteration_count, 41 | period=period, 42 | U=U, s=s, Vh=Vh, 43 | adaptive_rank=adaptive_rank 44 | ) 45 | 46 | # Update diffusion state with projected gradient 47 | z_t = z_t - eta * projected_grad 48 | ``` 49 | 50 | For complete implementation details, please refer to [`diffstategrad_utils.py`](https://github.com/rzirvi1665/DiffStateGrad/blob/main/diffstategrad_utils.py) in our repository. 51 | 52 | ## Getting Started 53 | 54 | ### 1) Clone the repository 55 | 56 | ``` 57 | git clone https://github.com/rzirvi1665/DiffStateGrad.git 58 | 59 | cd DiffStateGrad 60 | ``` 61 | 62 |
63 | 64 | ### 2) Download pretrained checkpoints (autoencoders and model) 65 | 66 | ``` 67 | mkdir -p models/ldm 68 | wget https://ommer-lab.com/files/latent-diffusion/ffhq.zip -P ./models/ldm 69 | unzip models/ldm/ffhq.zip -d ./models/ldm 70 | 71 | mkdir -p models/first_stage_models/vq-f4 72 | wget https://ommer-lab.com/files/latent-diffusion/vq-f4.zip -P ./models/first_stage_models/vq-f4 73 | unzip models/first_stage_models/vq-f4/vq-f4.zip -d ./models/first_stage_models/vq-f4 74 | ``` 75 | 76 |
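Note: the FFHQ LDM config (`configs/latent-diffusion/ffhq-ldm-vq-4.yaml`) expects the first-stage checkpoint at `models/first_stage_models/vq-f4/model.ckpt` (its `first_stage_config.params.ckpt_path`). Verify that unzipping `vq-f4.zip` actually places `model.ckpt` there, and adjust that path in the config if your archive unpacks differently.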
77 | 78 | ### 3) Set environment 79 | 80 | Following the DPS codebase, we use external code for motion blurring and non-linear deblurring: 81 | 82 | ``` 83 | git clone https://github.com/VinAIResearch/blur-kernel-space-exploring bkse 84 | 85 | git clone https://github.com/LeviBorodenko/motionblur motionblur 86 | ``` 87 | 88 | Install dependencies via 89 | 90 | ``` 91 | conda env create -f environment.yaml 92 | ``` 93 | and activate the environment with `conda activate LDM` (the name defined in `environment.yaml`). 94 |
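The `bkse` and `motionblur` clones above are expected at the repository root; for example, `configs/tasks/nonlinear_deblur_config.yaml` locates the blur-kernel model through `./bkse/options/generate_blur/default.yml`.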
95 | 96 | ### 4) Inference 97 | 98 | ``` 99 | python3 diffstategrad_sample_condition.py 100 | ``` 101 | 102 | The script is currently configured to run inference on FFHQ. To use other datasets and models, download the corresponding checkpoints from https://github.com/CompVis/latent-diffusion/tree/main and update the checkpoint paths accordingly. 103 | 104 | 105 |
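Beyond the provided script, the helper functions are meant to be plug-and-play in other measurement-guided samplers. The sketch below wraps one guided update step around them; it is illustrative only — `diffstategrad_update`, `data_fit_loss`, `svd_cache`, and `eta` are placeholder names rather than this repo's interface, and `diffstategrad_utils.py` itself uses `np` and `torch`, so `import numpy as np` and `import torch` must be present in that module when it is reused standalone.

```python
import torch
from diffstategrad_utils import compute_svd_and_adaptive_rank, apply_diffstategrad


def diffstategrad_update(z_t, data_fit_loss, i, svd_cache, period=1, var_cutoff=0.99, eta=1.0):
    """One DiffStateGrad-guided update of a latent z_t with shape [1, C, H, W].

    data_fit_loss: callable mapping z_t to a scalar measurement-consistency loss
    svd_cache:     dict reused across calls so the projection basis is only
                   refreshed every `period` iterations
    """
    z_t = z_t.detach().requires_grad_(True)

    # Gradient of the data-fit term, normalized as in the example usage above
    loss = data_fit_loss(z_t)
    grad = torch.autograd.grad(loss, z_t)[0]
    norm_grad = grad / torch.linalg.norm(grad)

    # Refresh the low-rank basis from the current diffusion state when due
    if i % period == 0:
        U, s, Vh, rank = compute_svd_and_adaptive_rank(z_t, var_cutoff=var_cutoff)
        svd_cache.update(U=U, s=s, Vh=Vh, rank=rank)

    projected_grad = apply_diffstategrad(
        norm_grad, i, period,
        U=svd_cache.get("U"), s=svd_cache.get("s"), Vh=svd_cache.get("Vh"),
        adaptive_rank=svd_cache.get("rank"),
    )
    return (z_t - eta * projected_grad).detach()
```

Create `svd_cache = {}` once before the sampling loop and pass the same dict to every call. With `period=1` the basis is recomputed at every step; with larger values it is reused, and per `apply_diffstategrad` the in-between steps fall back to the unprojected gradient.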
106 | 107 | ## Task Configurations 108 | 109 | ``` 110 | # Linear inverse problems 111 | - configs/tasks/super_resolution_config.yaml 112 | - configs/tasks/gaussian_deblur_config.yaml 113 | - configs/tasks/motion_deblur_config.yaml 114 | - configs/tasks/box_inpainting_config.yaml 115 | - configs/tasks/rand_inpainting_config.yaml 116 | 117 | # Non-linear inverse problems 118 | - configs/tasks/nonlinear_deblur_config.yaml 119 | - configs/tasks/phase_retrieval_config.yaml 120 | - configs/tasks/hdr_config.yaml 121 | ``` 122 | 123 |
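Each task config specifies the sampler (`conditioning`), the data root (`data`), the forward operator (`measurement.operator`), and the measurement noise level (`sigma`). The inpainting and super-resolution configs contain `!!python/tuple` entries, so they must be parsed with a loader that can construct Python tuples. A hedged example of inspecting one config (the loader choice here is an assumption, not necessarily what the inference script uses):

```python
# SafeLoader rejects the !!python/tuple entries used by the inpainting and
# super-resolution configs, so a fuller loader is needed (yaml.FullLoader here;
# yaml.UnsafeLoader is the fallback if your PyYAML version still refuses the tag).
import yaml

with open("configs/tasks/super_resolution_config.yaml") as f:
    task = yaml.load(f, Loader=yaml.FullLoader)

print(task["conditioning"]["main_sampler"])      # resample
print(task["measurement"]["operator"]["name"])   # super_resolution
```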
124 | 125 | ## Citation 126 | If you find our work interesting, please consider citing 127 | 128 | ``` 129 | @inproceedings{ 130 | zirvi2025diffusion, 131 | title={Diffusion State-Guided Projected Gradient for Inverse Problems}, 132 | author={Rayhan Zirvi and Bahareh Tolooshams and Anima Anandkumar}, 133 | booktitle={The Thirteenth International Conference on Learning Representations}, 134 | year={2025} 135 | } 136 | ``` 137 | 138 | ## MIT License 139 | 140 | All rights reserved unless otherwise stated by applicable licenses. 141 | If this code includes third-party components, they remain under their original licenses and attributions. 142 | 143 | 144 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/__init__.py -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_16x16x16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 16 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_32x32x4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 
1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_64x64x3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 3 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_8x8x64.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 64 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16,8] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /configs/latent-diffusion/celebahq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | 15 | unet_config: 16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | image_size: 64 19 | in_channels: 3 20 | 
out_channels: 3 21 | model_channels: 224 22 | attention_resolutions: 23 | # note: this isn\t actually the resolution but 24 | # the downsampling factor, i.e. this corresnponds to 25 | # attention on spatial resolution 8,16,32, as the 26 | # spatial reolution of the latents is 64 for f4 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | num_head_channels: 32 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: __is_unconditional__ 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 48 64 | num_workers: 5 65 | wrap: false 66 | train: 67 | target: taming.data.faceshq.CelebAHQTrain 68 | params: 69 | size: 256 70 | validation: 71 | target: taming.data.faceshq.CelebAHQValidation 72 | params: 73 | size: 256 74 | 75 | 76 | lightning: 77 | callbacks: 78 | image_logger: 79 | target: main.ImageLogger 80 | params: 81 | batch_frequency: 5000 82 | max_images: 8 83 | increase_log_steps: False 84 | 85 | trainer: 86 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/cin-ldm-vq-f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | #note: this isn\t actually the resolution but 26 | # the downsampling factor, i.e. 
this corresnponds to 27 | # attention on spatial resolution 8,16,32, as the 28 | # spatial reolution of the latents is 32 for f8 29 | - 4 30 | - 2 31 | - 1 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 4 37 | num_head_channels: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 512 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 4 45 | n_embed: 16384 46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml 47 | ddconfig: 48 | double_z: false 49 | z_channels: 4 50 | resolution: 256 51 | in_channels: 3 52 | out_ch: 3 53 | ch: 128 54 | ch_mult: 55 | - 1 56 | - 2 57 | - 2 58 | - 4 59 | num_res_blocks: 2 60 | attn_resolutions: 61 | - 32 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | cond_stage_config: 66 | target: ldm.modules.encoders.modules.ClassEmbedder 67 | params: 68 | embed_dim: 512 69 | key: class_label 70 | data: 71 | target: main.DataModuleFromConfig 72 | params: 73 | batch_size: 64 74 | num_workers: 12 75 | wrap: false 76 | train: 77 | target: ldm.data.imagenet.ImageNetTrain 78 | params: 79 | config: 80 | size: 256 81 | validation: 82 | target: ldm.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | 87 | 88 | lightning: 89 | callbacks: 90 | image_logger: 91 | target: main.ImageLogger 92 | params: 93 | batch_frequency: 5000 94 | max_images: 8 95 | increase_log_steps: False 96 | 97 | trainer: 98 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/cin256-v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /configs/latent-diffusion/ffhq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | 
timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | embed_dim: 3 40 | n_embed: 8192 41 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 42 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: taming.data.faceshq.FFHQTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: taming.data.faceshq.FFHQValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [10000] 24 | cycle_lengths: [10000000000000] 25 | f_start: [1.e-6] 26 | f_max: [1.] 27 | f_min: [ 1.] 
28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 192 36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 37 | num_res_blocks: 2 38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 39 | num_heads: 8 40 | use_scale_shift_norm: True 41 | resblock_updown: True 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: "val/rec_loss" 48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 49 | ddconfig: 50 | double_z: True 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: "__is_unconditional__" 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 96 69 | num_workers: 5 70 | wrap: False 71 | train: 72 | target: ldm.data.lsun.LSUNChurchesTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: ldm.data.lsun.LSUNChurchesValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 5000 86 | max_images: 8 87 | increase_log_steps: False 88 | 89 | 90 | trainer: 91 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/txt2img-1p4B-eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 32 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: 28 | - 4 29 | - 2 30 | - 1 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 4 36 | - 4 37 | num_heads: 8 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 1280 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 1280 71 | n_layer: 32 72 | -------------------------------------------------------------------------------- /configs/retrieval-augmented-diffusion/768x768.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.015 7 | 
num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: jpg 11 | cond_stage_key: nix 12 | image_size: 48 13 | channels: 16 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_by_std: false 18 | scale_factor: 0.22765929 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 48 23 | in_channels: 16 24 | out_channels: 16 25 | model_channels: 448 26 | attention_resolutions: 27 | - 4 28 | - 2 29 | - 1 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | use_scale_shift_norm: false 37 | resblock_updown: false 38 | num_head_channels: 32 39 | use_spatial_transformer: true 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: true 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | monitor: val/rec_loss 47 | embed_dim: 16 48 | ddconfig: 49 | double_z: true 50 | z_channels: 16 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | cond_stage_config: 68 | target: torch.nn.Identity -------------------------------------------------------------------------------- /configs/stable-diffusion/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /configs/tasks/box_inpainting_config.yaml: -------------------------------------------------------------------------------- 1 | conditioning: 2 | main_sampler: resample 3 | method: ps # Do not touch 4 | params: 5 | scale: 0.5 6 | 7 | data: 8 | name: celeb 9 | root: ./data/samples/ 10 | 11 | measurement: 12 | operator: 13 | name: inpainting 14 | mask_opt: 15 | mask_type: box 16 | mask_len_range: !!python/tuple [128, 129] # for box 17 | # mask_prob_range: !!python/tuple [0.7, 0.71] 18 | image_size: 256 19 | 20 | 21 | noise: 22 | name: gaussian 23 | sigma: 0.01 24 | -------------------------------------------------------------------------------- /configs/tasks/gaussian_deblur_config.yaml: -------------------------------------------------------------------------------- 1 | conditioning: 2 | main_sampler: resample 3 | method: ps # Do not touch 4 | params: 5 | scale: 0.5 6 | 7 | data: 8 | name: celeb 9 | root: ./data/samples/ 10 | 11 | measurement: 12 | operator: 13 | name: gaussian_blur 14 | kernel_size: 61 15 | intensity: 3.0 16 | 17 | noise: 18 | name: gaussian 19 | sigma: 0.01 20 | -------------------------------------------------------------------------------- /configs/tasks/hdr_config.yaml: -------------------------------------------------------------------------------- 1 | conditioning: 2 | main_sampler: resample 3 | method: ps # Do not touch 4 | params: 5 | scale: 1.0 6 | 7 | data: 8 | name: celeb 9 | root: ./data/samples/ 10 | 11 | measurement: 12 | operator: 13 | name: high_dynamic_range 14 | scale: 2.0 15 | 16 | noise: 17 | name: gaussian 18 | sigma: 0.01 19 | -------------------------------------------------------------------------------- /configs/tasks/motion_deblur_config.yaml: -------------------------------------------------------------------------------- 1 | conditioning: 2 | main_sampler: resample 3 | method: ps # Do not touch 4 | params: 5 | scale: 1.0 6 | 7 | data: 8 | name: celeb 9 | root: ./data/samples/ 10 | 11 | measurement: 12 | operator: 13 | name: motion_blur 14 | kernel_size: 61 15 | intensity: 0.5 16 | 17 | noise: 18 | name: gaussian 19 | sigma: 0.01 20 | -------------------------------------------------------------------------------- /configs/tasks/nonlinear_deblur_config.yaml: -------------------------------------------------------------------------------- 1 | conditioning: 2 | main_sampler: dps 3 | method: ps # Do not touch 4 | params: 5 | scale: 0.3 6 | 7 | data: 8 | name: celeb 9 | root: ./data/samples/ 10 | 11 | measurement: 12 | operator: 
13 | name: nonlinear_blur 14 | opt_yml_path: ./bkse/options/generate_blur/default.yml 15 | 16 | noise: 17 | name: gaussian 18 | sigma: 0.01 19 | -------------------------------------------------------------------------------- /configs/tasks/phase_retrieval_config.yaml: -------------------------------------------------------------------------------- 1 | conditioning: 2 | main_sampler: resample 3 | method: ps # Do not touch 4 | params: 5 | scale: 1.0 6 | 7 | data: 8 | name: celeb 9 | root: ./data/samples/ 10 | 11 | measurement: 12 | operator: 13 | name: phase_retrieval 14 | oversample: 2.0 15 | 16 | noise: 17 | name: gaussian 18 | sigma: 0.01 19 | -------------------------------------------------------------------------------- /configs/tasks/rand_inpainting_config.yaml: -------------------------------------------------------------------------------- 1 | conditioning: 2 | main_sampler: resample 3 | method: ps # Do not touch 4 | params: 5 | scale: 0.5 6 | 7 | data: 8 | name: celeb 9 | root: ./data/samples/ 10 | 11 | measurement: 12 | operator: 13 | name: inpainting 14 | mask_opt: 15 | mask_type: random 16 | mask_prob_range: !!python/tuple [0.7, 0.71] 17 | image_size: 256 18 | 19 | 20 | noise: 21 | name: gaussian 22 | sigma: 0.01 23 | -------------------------------------------------------------------------------- /configs/tasks/super_resolution_config.yaml: -------------------------------------------------------------------------------- 1 | conditioning: 2 | main_sampler: resample 3 | method: ps # Do not touch 4 | params: 5 | scale: 0.1 # Try changing this 6 | 7 | data: 8 | name: ffhq 9 | root: ./data/samples/ 10 | 11 | measurement: 12 | operator: 13 | name: super_resolution 14 | in_shape: !!python/tuple [1, 3, 256, 256] 15 | scale_factor: 4 16 | 17 | noise: 18 | name: gaussian 19 | sigma: 0.01 20 | -------------------------------------------------------------------------------- /data/__pycache__/dataloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/data/__pycache__/dataloader.cpython-38.pyc -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from PIL import Image 3 | from typing import Callable, Optional 4 | from torch.utils.data import DataLoader 5 | from torchvision.datasets import VisionDataset 6 | 7 | 8 | __DATASET__ = {} 9 | 10 | def register_dataset(name: str): 11 | def wrapper(cls): 12 | if __DATASET__.get(name, None): 13 | raise NameError(f"Name {name} is already registered!") 14 | __DATASET__[name] = cls 15 | return cls 16 | return wrapper 17 | 18 | 19 | def get_dataset(name: str, root: str, **kwargs): 20 | if __DATASET__.get(name, None) is None: 21 | raise NameError(f"Dataset {name} is not defined.") 22 | return __DATASET__[name](root=root, **kwargs) 23 | 24 | 25 | def get_dataloader(dataset: VisionDataset, 26 | batch_size: int, 27 | num_workers: int, 28 | train: bool): 29 | dataloader = DataLoader(dataset, 30 | batch_size, 31 | shuffle=train, 32 | num_workers=num_workers, 33 | drop_last=train) 34 | return dataloader 35 | 36 | 37 | @register_dataset(name='celeb') 38 | class CELEBDataset(VisionDataset): 39 | def __init__(self, root: str, transforms: Optional[Callable]=None): 40 | super().__init__(root, transforms) 41 | 42 | self.fpaths = sorted(glob(root + '/**/*.png', 
recursive=True)) 43 | assert len(self.fpaths) > 0, "File list is empty. Check the root." 44 | 45 | def __len__(self): 46 | return len(self.fpaths) 47 | 48 | def __getitem__(self, index: int): 49 | fpath = self.fpaths[index] 50 | img = Image.open(fpath).convert('RGB') 51 | 52 | if self.transforms is not None: 53 | img = self.transforms(img) 54 | 55 | return img 56 | 57 | 58 | @register_dataset(name='ffhq') 59 | class FFHQDataset(VisionDataset): 60 | def __init__(self, root: str, transforms: Optional[Callable]=None): 61 | super().__init__(root, transforms) 62 | 63 | self.fpaths = sorted(glob(root + '/**/*.png', recursive=True)) 64 | assert len(self.fpaths) > 0, "File list is empty. Check the root." 65 | 66 | def __len__(self): 67 | return len(self.fpaths) 68 | 69 | def __getitem__(self, index: int): 70 | fpath = self.fpaths[index] 71 | img = Image.open(fpath).convert('RGB') 72 | 73 | if self.transforms is not None: 74 | img = self.transforms(img) 75 | 76 | return img 77 | 78 | -------------------------------------------------------------------------------- /data/samples/00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/data/samples/00000.png -------------------------------------------------------------------------------- /data/samples/00003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/data/samples/00003.png -------------------------------------------------------------------------------- /diffstategrad_utils.py: -------------------------------------------------------------------------------- 1 | # DiffStateGrad helper method 2 | def compute_rank_for_explained_variance(singular_values, explained_variance_cutoff): 3 | """ 4 | Computes average rank needed across channels to explain target variance percentage. 5 | 6 | Args: 7 | singular_values: List of arrays containing singular values per channel 8 | explained_variance_cutoff: Target explained variance ratio (0-1) 9 | 10 | Returns: 11 | int: Average rank needed across RGB channels 12 | """ 13 | total_rank = 0 14 | for channel_singular_values in singular_values: 15 | squared_singular_values = channel_singular_values ** 2 16 | cumulative_variance = np.cumsum(squared_singular_values) / np.sum(squared_singular_values) 17 | rank = np.searchsorted(cumulative_variance, explained_variance_cutoff) + 1 18 | total_rank += rank 19 | return int(total_rank / 3) 20 | 21 | def compute_svd_and_adaptive_rank(z_t, var_cutoff): 22 | """ 23 | Compute SVD and adaptive rank for the input tensor. 24 | 25 | Args: 26 | z_t: Input tensor (current image representation at time step t) 27 | var_cutoff: Variance cutoff for rank adaptation 28 | 29 | Returns: 30 | tuple: (U, s, Vh, adaptive_rank) where U, s, Vh are SVD components 31 | and adaptive_rank is the computed rank 32 | """ 33 | # Compute SVD of current image representation 34 | U, s, Vh = torch.linalg.svd(z_t[0], full_matrices=False) 35 | 36 | # Compute adaptive rank 37 | s_numpy = s.detach().cpu().numpy() 38 | 39 | adaptive_rank = compute_rank_for_explained_variance([s_numpy], var_cutoff) 40 | 41 | return U, s, Vh, adaptive_rank 42 | 43 | def apply_diffstategrad(norm_grad, iteration_count, period, U=None, s=None, Vh=None, adaptive_rank=None): 44 | """ 45 | Compute projected gradient using DiffStateGrad algorithm. 
46 | 47 | Args: 48 | norm_grad: Normalized gradient 49 | iteration_count: Current iteration count 50 | period: Period of SVD projection 51 | U: Left singular vectors from SVD 52 | s: Singular values from SVD 53 | Vh: Right singular vectors from SVD 54 | adaptive_rank: Computed adaptive rank 55 | 56 | Returns: 57 | torch.Tensor: Projected gradient if period condition is met, otherwise original gradient 58 | """ 59 | if period != 0 and iteration_count % period == 0: 60 | if any(param is None for param in [U, s, Vh, adaptive_rank]): 61 | raise ValueError("SVD components and adaptive_rank must be provided when iteration_count % period == 0") 62 | 63 | # Project gradient 64 | A = U[:, :, :adaptive_rank] 65 | B = Vh[:, :adaptive_rank, :] 66 | 67 | low_rank_grad = torch.matmul(A.permute(0, 2, 1), norm_grad[0]) @ B.permute(0, 2, 1) 68 | projected_grad = torch.matmul(A, low_rank_grad) @ B 69 | 70 | # Reshape projected gradient to match original shape 71 | projected_grad = projected_grad.float().unsqueeze(0) # Add batch dimension back 72 | 73 | return projected_grad 74 | 75 | return norm_grad 76 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: LDM 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - blas=1.0=mkl 9 | - brotli-python=1.0.9=py38h6a678d5_7 10 | - bzip2=1.0.8=h5eee18b_5 11 | - ca-certificates=2023.12.12=h06a4308_0 12 | - certifi=2024.2.2=py38h06a4308_0 13 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 14 | - cudatoolkit=11.3.1=h2bc3f7f_2 15 | - ffmpeg=4.3=hf484d3e_0 16 | - freetype=2.12.1=h4a9f257_0 17 | - gmp=6.2.1=h295c915_3 18 | - gnutls=3.6.15=he1e5248_0 19 | - idna=3.4=py38h06a4308_0 20 | - intel-openmp=2021.4.0=h06a4308_3561 21 | - jpeg=9e=h5eee18b_1 22 | - lame=3.100=h7b6447c_0 23 | - lcms2=2.12=h3be6417_0 24 | - ld_impl_linux-64=2.38=h1181459_1 25 | - lerc=3.0=h295c915_0 26 | - libdeflate=1.17=h5eee18b_1 27 | - libffi=3.3=he6710b0_2 28 | - libgcc-ng=11.2.0=h1234567_1 29 | - libgfortran-ng=11.2.0=h00389a5_1 30 | - libgfortran5=11.2.0=h1234567_1 31 | - libgomp=11.2.0=h1234567_1 32 | - libiconv=1.16=h7f8727e_2 33 | - libidn2=2.3.4=h5eee18b_0 34 | - libpng=1.6.39=h5eee18b_0 35 | - libstdcxx-ng=11.2.0=h1234567_1 36 | - libtasn1=4.19.0=h5eee18b_0 37 | - libtiff=4.5.1=h6a678d5_0 38 | - libunistring=0.9.10=h27cfd23_0 39 | - libuv=1.44.2=h5eee18b_0 40 | - libwebp-base=1.3.2=h5eee18b_0 41 | - lz4-c=1.9.4=h6a678d5_0 42 | - mkl=2021.4.0=h06a4308_640 43 | - mkl-service=2.4.0=py38h7f8727e_0 44 | - mkl_fft=1.3.1=py38hd3c417c_0 45 | - mkl_random=1.2.2=py38h51133e4_0 46 | - ncurses=6.4=h6a678d5_0 47 | - nettle=3.7.3=hbbd107a_1 48 | - openh264=2.1.1=h4ff587b_0 49 | - openjpeg=2.4.0=h3ad879b_0 50 | - openssl=1.1.1w=h7f8727e_0 51 | - pillow=10.2.0=py38h5eee18b_0 52 | - pip=20.3.3=py38h06a4308_0 53 | - pysocks=1.7.1=py38h06a4308_0 54 | - python=3.8.5=h7579374_1 55 | - pytorch=1.11.0=py3.8_cuda11.3_cudnn8.2.0_0 56 | - pytorch-mutex=1.0=cuda 57 | - readline=8.2=h5eee18b_0 58 | - requests=2.31.0=py38h06a4308_1 59 | - setuptools=68.2.2=py38h06a4308_0 60 | - six=1.16.0=pyhd3eb1b0_1 61 | - sqlite=3.41.2=h5eee18b_0 62 | - tk=8.6.12=h1ccaba5_0 63 | - torchvision=0.12.0=py38_cu113 64 | - typing_extensions=4.9.0=py38h06a4308_1 65 | - urllib3=2.1.0=py38h06a4308_1 66 | - wheel=0.41.2=py38h06a4308_0 67 | - xz=5.4.6=h5eee18b_0 68 | - zlib=1.2.13=h5eee18b_0 69 | - zstd=1.5.5=hc292b87_0 70 | - pip: 71 | - 
absl-py==2.1.0 72 | - aiohttp==3.9.3 73 | - aiosignal==1.3.1 74 | - antlr4-python3-runtime==4.9.3 75 | - async-timeout==4.0.3 76 | - attrs==23.2.0 77 | - cachetools==5.3.3 78 | - contourpy==1.1.1 79 | - cycler==0.12.1 80 | - einops==0.7.0 81 | - fonttools==4.49.0 82 | - frozenlist==1.4.1 83 | - fsspec==2024.2.0 84 | - future==1.0.0 85 | - google-auth==2.28.1 86 | - google-auth-oauthlib==1.0.0 87 | - grpcio==1.62.0 88 | - imageio==2.34.0 89 | - importlib-metadata==7.0.1 90 | - kiwisolver==1.4.5 91 | - lazy-loader==0.3 92 | - lightning-utilities==0.10.1 93 | - markdown==3.5.2 94 | - markupsafe==2.1.5 95 | - matplotlib==3.6.0 96 | - multidict==6.0.5 97 | - networkx==3.1 98 | - numpy==1.24.4 99 | - oauthlib==3.2.2 100 | - omegaconf==2.3.0 101 | - packaging==23.2 102 | - protobuf==4.25.3 103 | - pyasn1==0.5.1 104 | - pyasn1-modules==0.3.0 105 | - pydeprecate==0.3.1 106 | - pyparsing==3.1.2 107 | - python-dateutil==2.9.0.post0 108 | - pytorch-lightning==1.4.2 109 | - pywavelets==1.4.1 110 | - pyyaml==6.0 111 | - requests-oauthlib==1.3.1 112 | - rsa==4.9 113 | - scikit-image==0.21.0 114 | - scipy==1.10.1 115 | - taming-transformers==0.0.1 116 | - taming-transformers-rom1504==0.0.6 117 | - tensorboard==2.14.0 118 | - tensorboard-data-server==0.7.2 119 | - tifffile==2023.7.10 120 | - torch-fidelity==0.3.0 121 | - torchmetrics==0.7.0 122 | - tqdm==4.64.1 123 | - werkzeug==3.0.1 124 | - yarl==1.9.4 125 | - zipp==3.17.0 126 | prefix: /home/kwonsm/.conda/envs/LDM 127 | -------------------------------------------------------------------------------- /figures/hdr_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/figures/hdr_example.png -------------------------------------------------------------------------------- /figures/hdr_short.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/figures/hdr_short.png -------------------------------------------------------------------------------- /figures/manifold_diffstategrad.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/figures/manifold_diffstategrad.pdf -------------------------------------------------------------------------------- /figures/manifold_diffstategrad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/figures/manifold_diffstategrad.png -------------------------------------------------------------------------------- /figures/phase_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/figures/phase_example.png -------------------------------------------------------------------------------- /ldm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/__init__.py -------------------------------------------------------------------------------- /ldm/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/lsun.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/data/__pycache__/lsun.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = 
transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val_100_images.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/__pycache__/autoencoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/__pycache__/autoencoder.cpython-38.pyc 
-------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/plms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | 5 | from .dpm_solver import 
NoiseScheduleVP, model_wrapper, DPM_Solver 6 | 7 | 8 | class DPMSolverSampler(object): 9 | def __init__(self, model, **kwargs): 10 | super().__init__() 11 | self.model = model 12 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 13 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 14 | 15 | def register_buffer(self, name, attr): 16 | if type(attr) == torch.Tensor: 17 | if attr.device != torch.device("cuda"): 18 | attr = attr.to(torch.device("cuda")) 19 | setattr(self, name, attr) 20 | 21 | @torch.no_grad() 22 | def sample(self, 23 | S, 24 | batch_size, 25 | shape, 26 | conditioning=None, 27 | callback=None, 28 | normals_sequence=None, 29 | img_callback=None, 30 | quantize_x0=False, 31 | eta=0., 32 | mask=None, 33 | x0=None, 34 | temperature=1., 35 | noise_dropout=0., 36 | score_corrector=None, 37 | corrector_kwargs=None, 38 | verbose=True, 39 | x_T=None, 40 | log_every_t=100, 41 | unconditional_guidance_scale=1., 42 | unconditional_conditioning=None, 43 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 44 | **kwargs 45 | ): 46 | if conditioning is not None: 47 | if isinstance(conditioning, dict): 48 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 49 | if cbs != batch_size: 50 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 51 | else: 52 | if conditioning.shape[0] != batch_size: 53 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 54 | 55 | # sampling 56 | C, H, W = shape 57 | size = (batch_size, C, H, W) 58 | 59 | # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 60 | 61 | device = self.model.betas.device 62 | if x_T is None: 63 | img = torch.randn(size, device=device) 64 | else: 65 | img = x_T 66 | 67 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 68 | 69 | model_fn = model_wrapper( 70 | lambda x, t, c: self.model.apply_model(x, t, c), 71 | ns, 72 | model_type="noise", 73 | guidance_type="classifier-free", 74 | condition=conditioning, 75 | unconditional_condition=unconditional_conditioning, 76 | guidance_scale=unconditional_guidance_scale, 77 | ) 78 | 79 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 80 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 81 | 82 | return x.to(device), None 83 | -------------------------------------------------------------------------------- /ldm/models/diffusion/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import math 4 | import numpy as np 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | import torchvision.utils as vutils 9 | 10 | 11 | def get_config(config): 12 | with open(config, 'r') as stream: 13 | return yaml.load(stream) 14 | 15 | def prepare_sub_folder(output_directory): 16 | image_directory = os.path.join(output_directory, 'images') 17 | if not os.path.exists(image_directory): 18 | print("Creating directory: {}".format(image_directory)) 19 | os.makedirs(image_directory) 20 | checkpoint_directory = os.path.join(output_directory, 'checkpoints') 21 | if not os.path.exists(checkpoint_directory): 22 | print("Creating directory: {}".format(checkpoint_directory)) 23 | os.makedirs(checkpoint_directory) 24 | return checkpoint_directory, image_directory 25 | 26 | 27 | 28 | 29 | def save_image_3d(tensor, slice_idx, file_name): 30 | ''' 31 | 
tensor: [bs, c, h, w, 1] 32 | ''' 33 | image_num = len(slice_idx) 34 | tensor = tensor[0, slice_idx, ...].permute(0, 3, 1, 2).cpu().data # [c, 1, h, w] 35 | image_grid = vutils.make_grid(tensor, nrow=image_num, padding=0, normalize=True, scale_each=True) 36 | vutils.save_image(image_grid, file_name, nrow=1) 37 | 38 | 39 | 40 | def map_coordinates(input, coordinates): 41 | ''' PyTorch version of scipy.ndimage.interpolation.map_coordinates 42 | input: (B, H, W, C) 43 | coordinates: (2, ...) 44 | ''' 45 | bs, h, w, c = input.size() 46 | 47 | def _coordinates_pad_wrap(h, w, coordinates): 48 | coordinates[0] = coordinates[0] % h 49 | coordinates[1] = coordinates[1] % w 50 | return coordinates 51 | 52 | co_floor = torch.floor(coordinates).long() 53 | co_ceil = torch.ceil(coordinates).long() 54 | d1 = (coordinates[1] - co_floor[1].float()) 55 | d2 = (coordinates[0] - co_floor[0].float()) 56 | co_floor = _coordinates_pad_wrap(h, w, co_floor) 57 | co_ceil = _coordinates_pad_wrap(h, w, co_ceil) 58 | 59 | f00 = input[:, co_floor[0], co_floor[1], :] 60 | f10 = input[:, co_floor[0], co_ceil[1], :] 61 | f01 = input[:, co_ceil[0], co_floor[1], :] 62 | f11 = input[:, co_ceil[0], co_ceil[1], :] 63 | d1 = d1[None, :, :, None].expand(bs, -1, -1, c) 64 | d2 = d2[None, :, :, None].expand(bs, -1, -1, c) 65 | 66 | fx1 = f00 + d1 * (f10 - f00) 67 | fx2 = f01 + d1 * (f11 - f01) 68 | 69 | return fx1 + d2 * (fx2 - fx1) 70 | 71 | 72 | def ct_parallel_project_2d(img, theta): 73 | bs, h, w, c = img.size() 74 | 75 | # (y, x)=(i, j): [0, w] -> [-0.5, 0.5] 76 | y, x = torch.meshgrid([torch.arange(h, dtype=torch.float32) / h - 0.5, 77 | torch.arange(w, dtype=torch.float32) / w - 0.5]) 78 | 79 | # Rotation transform matrix: simulate parallel projection rays 80 | x_rot = x * torch.cos(theta) - y * torch.sin(theta) 81 | y_rot = x * torch.sin(theta) + y * torch.cos(theta) 82 | 83 | # Reverse back to index [0, w] 84 | x_rot = (x_rot + 0.5) * w 85 | y_rot = (y_rot + 0.5) * h 86 | 87 | # Resample (x, y) index of the pixel on the projection ray-theta 88 | sample_coords = torch.stack([y_rot, x_rot], dim=0).cuda() # [2, h, w] 89 | img_resampled = map_coordinates(img, sample_coords) # [b, h, w, c] 90 | 91 | # Compute integral projections along rays 92 | proj = torch.mean(img_resampled, dim=1, keepdim=True) # [b, 1, w, c] 93 | 94 | return proj 95 | 96 | 97 | def ct_parallel_project_2d_batch(img, thetas): 98 | ''' 99 | img: input tensor [B, H, W, C] 100 | thetas: list of projection angles 101 | ''' 102 | projs = [] 103 | for theta in thetas: 104 | proj = ct_parallel_project_2d(img, theta) 105 | projs.append(proj) 106 | projs = torch.cat(projs, dim=1) # [b, num, w, c] 107 | 108 | return projs -------------------------------------------------------------------------------- /ldm/modules/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/ema.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/__pycache__/ema.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class 
DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
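    # The return value below is the standard closed form for the KL divergence
    # between two diagonal Gaussians, evaluated elementwise with broadcasting:
    #   KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)))
    #     = 0.5 * (-1 + logvar2 - logvar1
    #              + exp(logvar1 - logvar2) + (mean1 - mean2)**2 * exp(-logvar2))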
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 
67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 110 | ): 111 | # if target_data_type not in ["ndarray", "list"]: 112 | # raise ValueError( 113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 114 | # ) 115 | if isinstance(data, np.ndarray) and target_data_type == "list": 116 | raise ValueError("list expected but function got ndarray.") 117 | elif isinstance(data, abc.Iterable): 118 | if isinstance(data, dict): 119 | print( 120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 121 | ) 122 | data = list(data.values()) 123 | if target_data_type == "ndarray": 124 | data = np.asarray(data) 125 | else: 126 | data = list(data) 127 | else: 128 | raise TypeError( 129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
130 | ) 131 | 132 | if cpu_intensive: 133 | Q = mp.Queue(1000) 134 | proc = mp.Process 135 | else: 136 | Q = Queue(1000) 137 | proc = Thread 138 | # spawn processes 139 | if target_data_type == "ndarray": 140 | arguments = [ 141 | [func, Q, part, i, use_worker_id] 142 | for i, part in enumerate(np.array_split(data, n_proc)) 143 | ] 144 | else: 145 | step = ( 146 | int(len(data) / n_proc + 1) 147 | if len(data) % n_proc != 0 148 | else int(len(data) / n_proc) 149 | ) 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate( 153 | [data[i: i + step] for i in range(0, len(data), step)] 154 | ) 155 | ] 156 | processes = [] 157 | for i in range(n_proc): 158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 159 | processes += [p] 160 | 161 | # start processes 162 | print(f"Start prefetching...") 163 | import time 164 | 165 | start = time.time() 166 | gather_res = [[] for _ in range(n_proc)] 167 | try: 168 | for p in processes: 169 | p.start() 170 | 171 | k = 0 172 | while k < n_proc: 173 | # get result 174 | res = Q.get() 175 | if res == "Done": 176 | k += 1 177 | else: 178 | gather_res[res[0]] = res[1] 179 | 180 | except Exception as e: 181 | print("Exception: ", e) 182 | for p in processes: 183 | p.terminate() 184 | 185 | raise e 186 | finally: 187 | for p in processes: 188 | p.join() 189 | print(f"Prefetching complete. [{time.time() - start} sec.]") 190 | 191 | if target_data_type == 'ndarray': 192 | if not isinstance(gather_res[0], np.ndarray): 193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 194 | 195 | # order outputs 196 | return np.concatenate(gather_res, axis=0) 197 | elif target_data_type == 'list': 198 | out = [] 199 | for r in gather_res: 200 | out.extend(r) 201 | return out 202 | else: 203 | return gather_res 204 | -------------------------------------------------------------------------------- /ldm_inverse/__pycache__/condition_methods.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm_inverse/__pycache__/condition_methods.cpython-38.pyc -------------------------------------------------------------------------------- /ldm_inverse/__pycache__/measurements.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/ldm_inverse/__pycache__/measurements.cpython-38.pyc -------------------------------------------------------------------------------- /ldm_inverse/condition_methods.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | 4 | __CONDITIONING_METHOD__ = {} 5 | 6 | def register_conditioning_method(name: str): 7 | def wrapper(cls): 8 | if __CONDITIONING_METHOD__.get(name, None): 9 | raise NameError(f"Name {name} is already registered!") 10 | __CONDITIONING_METHOD__[name] = cls 11 | return cls 12 | return wrapper 13 | 14 | def get_conditioning_method(name: str, model, operator, noiser, **kwargs): 15 | if __CONDITIONING_METHOD__.get(name, None) is None: 16 | raise NameError(f"Name {name} is not defined!") 17 | return __CONDITIONING_METHOD__[name](model=model, operator=operator, noiser=noiser, **kwargs) 18 | 19 | 20 | class ConditioningMethod(ABC): 21 | def __init__(self, model, operator, noiser, **kwargs): 22 | self.model = model 23 | 
self.operator = operator 24 | self.noiser = noiser 25 | 26 | def project(self, data, noisy_measurement, **kwargs): 27 | return self.operator.project(data=data, measurement=noisy_measurement, **kwargs) 28 | 29 | def grad_and_value(self, x_prev, x_0_hat, measurement, **kwargs): 30 | if self.noiser.__name__ == 'gaussian': 31 | difference = measurement - self.operator.forward(self.model.differentiable_decode_first_stage( x_0_hat ), **kwargs) 32 | norm = torch.linalg.norm(difference) 33 | norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0] 34 | elif self.noiser.__name__ == 'poisson': 35 | Ax = self.operator.forward(self.model.differentiable_decode_first_stage(x_0_hat), **kwargs) 36 | difference = measurement-Ax 37 | norm = torch.linalg.norm(difference) / measurement.abs() 38 | norm = norm.mean() 39 | norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0] 40 | 41 | else: 42 | raise NotImplementedError 43 | 44 | return norm_grad, norm 45 | 46 | 47 | @abstractmethod 48 | def conditioning(self, x_t, measurement, noisy_measurement=None, **kwargs): 49 | pass 50 | 51 | 52 | @register_conditioning_method(name='ps') 53 | class PosteriorSampling(ConditioningMethod): 54 | def __init__(self, model, operator, noiser, **kwargs): 55 | super().__init__(model, operator, noiser) 56 | self.operator = operator 57 | 58 | def conditioning(self, x_prev, x_t, x_0_hat, measurement, scale=None, **kwargs): 59 | if scale is None: 60 | scale = 0.3 61 | 62 | norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat, measurement=measurement, **kwargs) 63 | x_t -= norm_grad * scale 64 | return x_t, norm 65 | 66 | -------------------------------------------------------------------------------- /model_loader.py: -------------------------------------------------------------------------------- 1 | from ldm.util import instantiate_from_config 2 | import yaml 3 | import torch 4 | 5 | def load_yaml(file_path: str) -> dict: 6 | with open(file_path) as f: 7 | config = yaml.load(f, Loader=yaml.FullLoader) 8 | return config 9 | 10 | 11 | def load_model_from_config(config, ckpt, train=False): 12 | print(f"Loading model from {ckpt}") 13 | pl_sd = torch.load(ckpt)#, map_location="cpu") 14 | sd = pl_sd["state_dict"] 15 | model = instantiate_from_config(config.model) 16 | _, _ = model.load_state_dict(sd, strict=False) 17 | 18 | model.cuda() 19 | 20 | if train: 21 | model.train() 22 | else: 23 | model.eval() 24 | 25 | return model -------------------------------------------------------------------------------- /samples/60004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/samples/60004.png -------------------------------------------------------------------------------- /samples/60074.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/samples/60074.png -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/scripts/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/scripts/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/download_first_stages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Usage: 4 | # ./scripts/download_first_stages.sh kl-f4 kl-f8 5 | 6 | MODELS=("kl-f4" "kl-f8" "kl-f16" "kl-f32" "vq-f4" "vq-f4-noattn" "vq-f8" "vq-f8-n256" "vq-f16") 7 | DOWNLOAD_PATH="https://ommer-lab.com/files/latent-diffusion" 8 | 9 | function download_first_stages() { 10 | local list=("$@") 11 | 12 | for arg in "${list[@]}"; do 13 | for model in "${MODELS[@]}"; do 14 | if [[ "$model" == "$arg" ]]; then 15 | echo "Downloading $model" 16 | model_dir="./models/first_stage_models/$arg" 17 | if [ ! -d "$model_dir" ]; then 18 | mkdir -p "$model_dir" 19 | echo "Directory created: $model_dir" 20 | else 21 | echo "Directory already exists: $model_dir" 22 | fi 23 | wget -O "$model_dir/model.zip" "$DOWNLOAD_PATH/$arg.zip" 24 | unzip -o "$model_dir/model.zip" -d "$model_dir" 25 | rm -rf "$model_dir/model.zip" 26 | fi 27 | done 28 | done 29 | } 30 | 31 | download_first_stages "$@" 32 | -------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Usage: 4 | # ./scripts/download_models.sh celeba ffhq 5 | 6 | MODELS=("celeba" "ffhq" "lsun_churches" "lsun_bedrooms" "text2img" "cin" "semantic_synthesis" "semantic_synthesis256" "sr_bsr" "layout2img_model" "inpainting_big") 7 | DOWNLOAD_PATH="https://ommer-lab.com/files/latent-diffusion" 8 | 9 | function download_models() { 10 | local list=("$@") 11 | 12 | for arg in "${list[@]}"; do 13 | for model in "${MODELS[@]}"; do 14 | if [[ "$model" == "$arg" ]]; then 15 | echo "Downloading $model" 16 | model_dir="./models/ldm/$arg" 17 | if [ ! 
-d "$model_dir" ]; then 18 | mkdir -p "$model_dir" 19 | echo "Directory created: $model_dir" 20 | else 21 | echo "Directory already exists: $model_dir" 22 | fi 23 | wget -O "$model_dir/model.zip" "$DOWNLOAD_PATH/$arg.zip" 24 | unzip -o "$model_dir/model.zip" -d "$model_dir" 25 | rm -rf "$model_dir/model.zip" 26 | fi 27 | done 28 | done 29 | } 30 | 31 | download_models "$@" 32 | -------------------------------------------------------------------------------- /scripts/inpaint.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | from omegaconf import OmegaConf 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | from main import instantiate_from_config 8 | from ldm.models.diffusion.ddim import DDIMSampler 9 | 10 | 11 | def make_batch(image, mask, device): 12 | image = np.array(Image.open(image).convert("RGB")) 13 | image = image.astype(np.float32)/255.0 14 | image = image[None].transpose(0,3,1,2) 15 | image = torch.from_numpy(image) 16 | 17 | mask = np.array(Image.open(mask).convert("L")) 18 | mask = mask.astype(np.float32)/255.0 19 | mask = mask[None,None] 20 | mask[mask < 0.5] = 0 21 | mask[mask >= 0.5] = 1 22 | mask = torch.from_numpy(mask) 23 | 24 | masked_image = (1-mask)*image 25 | 26 | batch = {"image": image, "mask": mask, "masked_image": masked_image} 27 | for k in batch: 28 | batch[k] = batch[k].to(device=device) 29 | batch[k] = batch[k]*2.0-1.0 30 | return batch 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--indir", 37 | type=str, 38 | nargs="?", 39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)", 40 | ) 41 | parser.add_argument( 42 | "--outdir", 43 | type=str, 44 | nargs="?", 45 | help="dir to write results to", 46 | ) 47 | parser.add_argument( 48 | "--steps", 49 | type=int, 50 | default=50, 51 | help="number of ddim sampling steps", 52 | ) 53 | opt = parser.parse_args() 54 | 55 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png"))) 56 | images = [x.replace("_mask.png", ".png") for x in masks] 57 | print(f"Found {len(masks)} inputs.") 58 | 59 | config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") 60 | model = instantiate_from_config(config.model) 61 | model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], 62 | strict=False) 63 | 64 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 65 | model = model.to(device) 66 | sampler = DDIMSampler(model) 67 | 68 | os.makedirs(opt.outdir, exist_ok=True) 69 | with torch.no_grad(): 70 | with model.ema_scope(): 71 | for image, mask in tqdm(zip(images, masks)): 72 | outpath = os.path.join(opt.outdir, os.path.split(image)[1]) 73 | batch = make_batch(image, mask, device=device) 74 | 75 | # encode masked image and concat downsampled mask 76 | c = model.cond_stage_model.encode(batch["masked_image"]) 77 | cc = torch.nn.functional.interpolate(batch["mask"], 78 | size=c.shape[-2:]) 79 | c = torch.cat((c, cc), dim=1) 80 | 81 | shape = (c.shape[1]-1,)+c.shape[2:] 82 | samples_ddim, _ = sampler.sample(S=opt.steps, 83 | conditioning=c, 84 | batch_size=c.shape[0], 85 | shape=shape, 86 | verbose=False) 87 | x_samples_ddim = model.decode_first_stage(samples_ddim) 88 | 89 | image = torch.clamp((batch["image"]+1.0)/2.0, 90 | min=0.0, max=1.0) 91 | mask = torch.clamp((batch["mask"]+1.0)/2.0, 92 | min=0.0, max=1.0) 93 | predicted_image = 
torch.clamp((x_samples_ddim+1.0)/2.0, 94 | min=0.0, max=1.0) 95 | 96 | inpainted = (1-mask)*image+mask*predicted_image 97 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 98 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath) 99 | -------------------------------------------------------------------------------- /scripts/inv.py: -------------------------------------------------------------------------------- 1 | from taming.models import vqgan 2 | from ldm.models.diffusion.ddim import DDIMSampler 3 | #@title loading utils 4 | import torch 5 | from omegaconf import OmegaConf 6 | from ldm.util import instantiate_from_config 7 | import numpy as np 8 | from PIL import Image 9 | from einops import rearrange 10 | from torchvision.utils import make_grid 11 | from einops import rearrange, repeat 12 | from utils import * 13 | import PIL 14 | import argparse 15 | 16 | def load_model_from_config(config, ckpt, train=False): 17 | print(f"Loading model from {ckpt}") 18 | pl_sd = torch.load(ckpt)#, map_location="cpu") 19 | sd = pl_sd["state_dict"] 20 | model = instantiate_from_config(config.model) 21 | m, u = model.load_state_dict(sd, strict=False) 22 | 23 | model.cuda() 24 | # model.train() 25 | if train: 26 | model.train() 27 | else: 28 | model.eval() 29 | return model 30 | 31 | 32 | def get_model(model_type = None): 33 | if model_type is None: 34 | config = OmegaConf.load("configs/latent-diffusion/celebahq-ldm-vq-4.yaml") 35 | model = load_model_from_config(config, "models/ldm/celeba256/model.ckpt") 36 | elif model_type == "celeb": 37 | config = OmegaConf.load("configs/latent-diffusion/celebahq-ldm-vq-4.yaml") 38 | model = load_model_from_config(config, "models/ldm/celeba256/model.ckpt") 39 | else: 40 | model = None 41 | return model 42 | 43 | 44 | 45 | def load_img(path): 46 | image = Image.open(path).convert("RGB") 47 | w, h = image.size 48 | print(f"loaded input image of size ({w}, {h}) from {path}") 49 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 50 | image = image.resize((w, h), resample=PIL.Image.LANCZOS) 51 | image = np.array(image).astype(np.float32) / 255.0 52 | image = image[None].transpose(0, 3, 1, 2) 53 | image = torch.from_numpy(image) 54 | return 2.*image - 1. 
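# The script body below runs a CSGM-style reconstruction: a latent z is
# optimized with AdamW so that the parallel-beam CT projections of the decoded
# image (sampler.ddecode followed by differentiable_decode_first_stage) match
# the projections of the reference image loaded above.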
55 | 56 | 57 | 58 | device = torch.device("cuda") 59 | model = get_model() 60 | model.learning_rate = 5e-3 61 | ddim_steps = 100 62 | num_timesteps = 1000 63 | shape=[1, 3, 64, 64] 64 | z = torch.randn(shape, device=device) 65 | z.requires_grad = True 66 | init_image = load_img("/content/drive/MyDrive/stable-diffusion/ldct/test_recon.png").to(device) 67 | init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) 68 | opt = torch.optim.AdamW([z], lr = model.learning_rate) 69 | loss = torch.nn.MSELoss() 70 | angles = torch.tensor(np.linspace(0, np.pi, 25, endpoint=False)) 71 | sampler = DDIMSampler(model) 72 | sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0, verbose = False) 73 | projection_orig = ct_parallel_project_2d_batch(init_image.permute(0,2,3,1), angles) 74 | 75 | 76 | 77 | for i in range(2000): 78 | opt.zero_grad() 79 | decoded_z = sampler.ddecode(z, t_start = 100, temp = 0) 80 | decoded_img = model.differentiable_decode_first_stage(decoded_z) 81 | # decoded_img = model.differentiable_decode_first_stage(z) 82 | projection_recon = ct_parallel_project_2d_batch(decoded_img.permute(0,2,3,1), angles) 83 | output = loss(projection_orig, projection_recon) 84 | # output = loss(init_image, decoded_img) 85 | output.backward() 86 | opt.step() 87 | loss_val = output.detach().cpu().numpy() 88 | if i % 5 == 0: 89 | print(loss_val, "loss ", str(i), " iter") 90 | if loss_val < 6.5e-5: 91 | break 92 | 93 | # sampler.make_schedule(ddim_num_steps=50, ddim_eta=0, verbose = False) 94 | 95 | 96 | print("start encoding decoding") 97 | # decoded_z = sampler.decode(z, cond = None, t_start = 499) 98 | t = repeat(torch.tensor([250]), '1 -> b', b=1) 99 | t = t.to(device).long() 100 | # np.save("CSGM_nopnp_500iter.npy", model.decode_first_stage(z).detach().cpu().numpy()) 101 | np.save("CSGM_recon_2000iter_100_tstep_norandomness.npy", model.decode_first_stage(sampler.ddecode(z,cond=None,t_start=100, temp = 0)).detach().cpu().numpy()) 102 | # print(z) 103 | # encoded = sampler.stochastic_encode(z, t) 104 | # decoded = sampler.decode(encoded, None, 250) 105 | # np.save("PnP_CSGM_recon500iter.npy", model.decode_first_stage(decoded).detach().cpu().numpy()) 106 | 107 | -------------------------------------------------------------------------------- /scripts/tests/test_watermark.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import fire 3 | from imwatermark import WatermarkDecoder 4 | 5 | 6 | def testit(img_path): 7 | bgr = cv2.imread(img_path) 8 | decoder = WatermarkDecoder('bytes', 136) 9 | watermark = decoder.decode(bgr, 'dwtDct') 10 | try: 11 | dec = watermark.decode('utf-8') 12 | except: 13 | dec = "null" 14 | print(dec) 15 | 16 | 17 | if __name__ == "__main__": 18 | fire.Fire(testit) -------------------------------------------------------------------------------- /scripts/train_searcher.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import scann 4 | import argparse 5 | import glob 6 | from multiprocessing import cpu_count 7 | from tqdm import tqdm 8 | 9 | from ldm.util import parallel_data_prefetch 10 | 11 | 12 | def search_bruteforce(searcher): 13 | return searcher.score_brute_force().build() 14 | 15 | 16 | def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, 17 | partioning_trainsize, num_leaves, num_leaves_to_search): 18 | return searcher.tree(num_leaves=num_leaves, 19 | 
num_leaves_to_search=num_leaves_to_search, 20 | training_sample_size=partioning_trainsize). \ 21 | score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() 22 | 23 | 24 | def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): 25 | return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( 26 | reorder_k).build() 27 | 28 | def load_datapool(dpath): 29 | 30 | 31 | def load_single_file(saved_embeddings): 32 | compressed = np.load(saved_embeddings) 33 | database = {key: compressed[key] for key in compressed.files} 34 | return database 35 | 36 | def load_multi_files(data_archive): 37 | database = {key: [] for key in data_archive[0].files} 38 | for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): 39 | for key in d.files: 40 | database[key].append(d[key]) 41 | 42 | return database 43 | 44 | print(f'Load saved patch embedding from "{dpath}"') 45 | file_content = glob.glob(os.path.join(dpath, '*.npz')) 46 | 47 | if len(file_content) == 1: 48 | data_pool = load_single_file(file_content[0]) 49 | elif len(file_content) > 1: 50 | data = [np.load(f) for f in file_content] 51 | prefetched_data = parallel_data_prefetch(load_multi_files, data, 52 | n_proc=min(len(data), cpu_count()), target_data_type='dict') 53 | 54 | data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} 55 | else: 56 | raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') 57 | 58 | print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.') 59 | return data_pool 60 | 61 | 62 | def train_searcher(opt, 63 | metric='dot_product', 64 | partioning_trainsize=None, 65 | reorder_k=None, 66 | # todo tune 67 | aiq_thld=0.2, 68 | dims_per_block=2, 69 | num_leaves=None, 70 | num_leaves_to_search=None,): 71 | 72 | data_pool = load_datapool(opt.database) 73 | k = opt.knn 74 | 75 | if not reorder_k: 76 | reorder_k = 2 * k 77 | 78 | # normalize 79 | # embeddings = 80 | searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) 81 | pool_size = data_pool['embedding'].shape[0] 82 | 83 | print(*(['#'] * 100)) 84 | print('Initializing scaNN searcher with the following values:') 85 | print(f'k: {k}') 86 | print(f'metric: {metric}') 87 | print(f'reorder_k: {reorder_k}') 88 | print(f'anisotropic_quantization_threshold: {aiq_thld}') 89 | print(f'dims_per_block: {dims_per_block}') 90 | print(*(['#'] * 100)) 91 | print('Start training searcher....') 92 | print(f'N samples in pool is {pool_size}') 93 | 94 | # this reflects the recommended design choices proposed at 95 | # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md 96 | if pool_size < 2e4: 97 | print('Using brute force search.') 98 | searcher = search_bruteforce(searcher) 99 | elif 2e4 <= pool_size and pool_size < 1e5: 100 | print('Using asymmetric hashing search and reordering.') 101 | searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) 102 | else: 103 | print('Using using partioning, asymmetric hashing search and reordering.') 104 | 105 | if not partioning_trainsize: 106 | partioning_trainsize = data_pool['embedding'].shape[0] // 10 107 | if not num_leaves: 108 | num_leaves = int(np.sqrt(pool_size)) 109 | 110 | if not num_leaves_to_search: 111 | num_leaves_to_search = 
max(num_leaves // 20, 1) 112 | 113 | print('Partitioning params:') 114 | print(f'num_leaves: {num_leaves}') 115 | print(f'num_leaves_to_search: {num_leaves_to_search}') 116 | # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) 117 | searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, 118 | partioning_trainsize, num_leaves, num_leaves_to_search) 119 | 120 | print('Finish training searcher') 121 | searcher_savedir = opt.target_path 122 | os.makedirs(searcher_savedir, exist_ok=True) 123 | searcher.serialize(searcher_savedir) 124 | print(f'Saved trained searcher under "{searcher_savedir}"') 125 | 126 | if __name__ == '__main__': 127 | sys.path.append(os.getcwd()) 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--database', 130 | '-d', 131 | default='data/rdm/retrieval_databases/openimages', 132 | type=str, 133 | help='path to folder containing the clip feature of the database') 134 | parser.add_argument('--target_path', 135 | '-t', 136 | default='data/rdm/searchers/openimages', 137 | type=str, 138 | help='path to the target folder where the searcher shall be stored.') 139 | parser.add_argument('--knn', 140 | '-k', 141 | default=20, 142 | type=int, 143 | help='number of nearest neighbors, for which the searcher shall be optimized') 144 | 145 | opt, _ = parser.parse_known_args() 146 | 147 | train_searcher(opt,) -------------------------------------------------------------------------------- /src/clip/clip.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: clip 3 | Version: 1.0 4 | Author: OpenAI 5 | Provides-Extra: dev 6 | License-File: LICENSE 7 | -------------------------------------------------------------------------------- /src/clip/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /src/clip/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /src/clip/hubconf.py: -------------------------------------------------------------------------------- 1 | from clip.clip import tokenize as _tokenize, load as _load, available_models as _available_models 2 | import re 3 | import string 4 | 5 | dependencies = ["torch", "torchvision", "ftfy", "regex", "tqdm"] 6 | 7 | # For compatibility (cannot include special characters in function name) 8 | model_functions = { model: re.sub(f'[{string.punctuation}]', '_', model) for model 
in _available_models()} 9 | 10 | def _create_hub_entrypoint(model): 11 | def entrypoint(**kwargs): 12 | return _load(model, **kwargs) 13 | 14 | entrypoint.__doc__ = f"""Loads the {model} CLIP model 15 | 16 | Parameters 17 | ---------- 18 | device : Union[str, torch.device] 19 | The device to put the loaded model 20 | 21 | jit : bool 22 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 23 | 24 | download_root: str 25 | path to download the model files; by default, it uses "~/.cache/clip" 26 | 27 | Returns 28 | ------- 29 | model : torch.nn.Module 30 | The {model} CLIP model 31 | 32 | preprocess : Callable[[PIL.Image], torch.Tensor] 33 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 34 | """ 35 | return entrypoint 36 | 37 | def tokenize(): 38 | return _tokenize 39 | 40 | _entrypoints = {model_functions[model]: _create_hub_entrypoint(model) for model in _available_models()} 41 | 42 | globals().update(_entrypoints) -------------------------------------------------------------------------------- /src/clip/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name="clip", 8 | py_modules=["clip"], 9 | version="1.0", 10 | description="", 11 | author="OpenAI", 12 | packages=find_packages(exclude=["tests*"]), 13 | install_requires=[ 14 | str(r) 15 | for r in pkg_resources.parse_requirements( 16 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 17 | ) 18 | ], 19 | include_package_data=True, 20 | extras_require={'dev': ['pytest']}, 21 | ) 22 | -------------------------------------------------------------------------------- /src/clip/tests/test_consistency.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from PIL import Image 5 | 6 | import clip 7 | 8 | 9 | @pytest.mark.parametrize('model_name', clip.available_models()) 10 | def test_consistency(model_name): 11 | device = "cpu" 12 | jit_model, transform = clip.load(model_name, device=device, jit=True) 13 | py_model, _ = clip.load(model_name, device=device, jit=False) 14 | 15 | image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device) 16 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device) 17 | 18 | with torch.no_grad(): 19 | logits_per_image, _ = jit_model(image, text) 20 | jit_probs = logits_per_image.softmax(dim=-1).cpu().numpy() 21 | 22 | logits_per_image, _ = py_model(image, text) 23 | py_probs = logits_per_image.softmax(dim=-1).cpu().numpy() 24 | 25 | assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1) 26 | -------------------------------------------------------------------------------- /src/taming-transformers/scripts/extract_depth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from tqdm import trange 5 | from PIL import Image 6 | 7 | 8 | def get_state(gpu): 9 | import torch 10 | midas = torch.hub.load("intel-isl/MiDaS", "MiDaS") 11 | if gpu: 12 | midas.cuda() 13 | midas.eval() 14 | 15 | midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms") 16 | transform = midas_transforms.default_transform 17 | 18 | state = {"model": midas, 19 | "transform": transform} 20 | return state 21 | 22 | 23 | def depth_to_rgba(x): 24 | assert x.dtype == np.float32 25 | assert 
len(x.shape) == 2 26 | y = x.copy() 27 | y.dtype = np.uint8 28 | y = y.reshape(x.shape+(4,)) 29 | return np.ascontiguousarray(y) 30 | 31 | 32 | def rgba_to_depth(x): 33 | assert x.dtype == np.uint8 34 | assert len(x.shape) == 3 and x.shape[2] == 4 35 | y = x.copy() 36 | y.dtype = np.float32 37 | y = y.reshape(x.shape[:2]) 38 | return np.ascontiguousarray(y) 39 | 40 | 41 | def run(x, state): 42 | model = state["model"] 43 | transform = state["transform"] 44 | hw = x.shape[:2] 45 | with torch.no_grad(): 46 | prediction = model(transform((x + 1.0) * 127.5).cuda()) 47 | prediction = torch.nn.functional.interpolate( 48 | prediction.unsqueeze(1), 49 | size=hw, 50 | mode="bicubic", 51 | align_corners=False, 52 | ).squeeze() 53 | output = prediction.cpu().numpy() 54 | return output 55 | 56 | 57 | def get_filename(relpath, level=-2): 58 | # save class folder structure and filename: 59 | fn = relpath.split(os.sep)[level:] 60 | folder = fn[-2] 61 | file = fn[-1].split('.')[0] 62 | return folder, file 63 | 64 | 65 | def save_depth(dataset, path, debug=False): 66 | os.makedirs(path) 67 | N = len(dset) 68 | if debug: 69 | N = 10 70 | state = get_state(gpu=True) 71 | for idx in trange(N, desc="Data"): 72 | ex = dataset[idx] 73 | image, relpath = ex["image"], ex["relpath"] 74 | folder, filename = get_filename(relpath) 75 | # prepare 76 | folderabspath = os.path.join(path, folder) 77 | os.makedirs(folderabspath, exist_ok=True) 78 | savepath = os.path.join(folderabspath, filename) 79 | # run model 80 | xout = run(image, state) 81 | I = depth_to_rgba(xout) 82 | Image.fromarray(I).save("{}.png".format(savepath)) 83 | 84 | 85 | if __name__ == "__main__": 86 | from taming.data.imagenet import ImageNetTrain, ImageNetValidation 87 | out = "data/imagenet_depth" 88 | if not os.path.exists(out): 89 | print("Please create a folder or symlink '{}' to extract depth data ".format(out) + 90 | "(be prepared that the output size will be larger than ImageNet itself).") 91 | exit(1) 92 | 93 | # go 94 | dset = ImageNetValidation() 95 | abspath = os.path.join(out, "val") 96 | if os.path.exists(abspath): 97 | print("{} exists - not doing anything.".format(abspath)) 98 | else: 99 | print("preparing {}".format(abspath)) 100 | save_depth(dset, abspath) 101 | print("done with validation split") 102 | 103 | dset = ImageNetTrain() 104 | abspath = os.path.join(out, "train") 105 | if os.path.exists(abspath): 106 | print("{} exists - not doing anything.".format(abspath)) 107 | else: 108 | print("preparing {}".format(abspath)) 109 | save_depth(dset, abspath) 110 | print("done with train split") 111 | 112 | print("done done.") 113 | -------------------------------------------------------------------------------- /src/taming-transformers/scripts/extract_segmentation.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import numpy as np 3 | import scipy 4 | import torch 5 | import torch.nn as nn 6 | from scipy import ndimage 7 | from tqdm import tqdm, trange 8 | from PIL import Image 9 | import torch.hub 10 | import torchvision 11 | import torch.nn.functional as F 12 | 13 | # download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from 14 | # https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth 15 | # and put the path here 16 | CKPT_PATH = "TODO" 17 | 18 | rescale = lambda x: (x + 1.) / 2. 
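# Value-range helpers: dataset images arrive as float arrays in [-1, 1];
# `rescale` maps them to [0, 1], while `rescale_bgr` below maps to [0, 255]
# and flips the channel order to BGR, which appears to match the Caffe-style
# preprocessing expected by the DeepLab v2 checkpoint (see the BGR mean/std
# properties of COCOStuffSegmenter).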
19 | 20 | def rescale_bgr(x): 21 | x = (x+1)*127.5 22 | x = torch.flip(x, dims=[0]) 23 | return x 24 | 25 | 26 | class COCOStuffSegmenter(nn.Module): 27 | def __init__(self, config): 28 | super().__init__() 29 | self.config = config 30 | self.n_labels = 182 31 | model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels) 32 | ckpt_path = CKPT_PATH 33 | model.load_state_dict(torch.load(ckpt_path)) 34 | self.model = model 35 | 36 | normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std) 37 | self.image_transform = torchvision.transforms.Compose([ 38 | torchvision.transforms.Lambda(lambda image: torch.stack( 39 | [normalize(rescale_bgr(x)) for x in image])) 40 | ]) 41 | 42 | def forward(self, x, upsample=None): 43 | x = self._pre_process(x) 44 | x = self.model(x) 45 | if upsample is not None: 46 | x = torch.nn.functional.upsample_bilinear(x, size=upsample) 47 | return x 48 | 49 | def _pre_process(self, x): 50 | x = self.image_transform(x) 51 | return x 52 | 53 | @property 54 | def mean(self): 55 | # bgr 56 | return [104.008, 116.669, 122.675] 57 | 58 | @property 59 | def std(self): 60 | return [1.0, 1.0, 1.0] 61 | 62 | @property 63 | def input_size(self): 64 | return [3, 224, 224] 65 | 66 | 67 | def run_model(img, model): 68 | model = model.eval() 69 | with torch.no_grad(): 70 | segmentation = model(img, upsample=(img.shape[2], img.shape[3])) 71 | segmentation = torch.argmax(segmentation, dim=1, keepdim=True) 72 | return segmentation.detach().cpu() 73 | 74 | 75 | def get_input(batch, k): 76 | x = batch[k] 77 | if len(x.shape) == 3: 78 | x = x[..., None] 79 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) 80 | return x.float() 81 | 82 | 83 | def save_segmentation(segmentation, path): 84 | # --> class label to uint8, save as png 85 | os.makedirs(os.path.dirname(path), exist_ok=True) 86 | assert len(segmentation.shape)==4 87 | assert segmentation.shape[0]==1 88 | for seg in segmentation: 89 | seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8) 90 | seg = Image.fromarray(seg) 91 | seg.save(path) 92 | 93 | 94 | def iterate_dataset(dataloader, destpath, model): 95 | os.makedirs(destpath, exist_ok=True) 96 | num_processed = 0 97 | for i, batch in tqdm(enumerate(dataloader), desc="Data"): 98 | try: 99 | img = get_input(batch, "image") 100 | img = img.cuda() 101 | seg = run_model(img, model) 102 | 103 | path = batch["relative_file_path_"][0] 104 | path = os.path.splitext(path)[0] 105 | 106 | path = os.path.join(destpath, path + ".png") 107 | save_segmentation(seg, path) 108 | num_processed += 1 109 | except Exception as e: 110 | print(e) 111 | print("but anyhow..") 112 | 113 | print("Processed {} files. 
Bye.".format(num_processed)) 114 | 115 | 116 | from taming.data.sflckr import Examples 117 | from torch.utils.data import DataLoader 118 | 119 | if __name__ == "__main__": 120 | dest = sys.argv[1] 121 | batchsize = 1 122 | print("Running with batch-size {}, saving to {}...".format(batchsize, dest)) 123 | 124 | model = COCOStuffSegmenter({}).cuda() 125 | print("Instantiated model.") 126 | 127 | dataset = Examples() 128 | dloader = DataLoader(dataset, batch_size=batchsize) 129 | iterate_dataset(dataloader=dloader, destpath=dest, model=model) 130 | print("done.") 131 | -------------------------------------------------------------------------------- /src/taming-transformers/scripts/extract_submodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | 4 | if __name__ == "__main__": 5 | inpath = sys.argv[1] 6 | outpath = sys.argv[2] 7 | submodel = "cond_stage_model" 8 | if len(sys.argv) > 3: 9 | submodel = sys.argv[3] 10 | 11 | print("Extracting {} from {} to {}.".format(submodel, inpath, outpath)) 12 | 13 | sd = torch.load(inpath, map_location="cpu") 14 | new_sd = {"state_dict": dict((k.split(".", 1)[-1],v) 15 | for k,v in sd["state_dict"].items() 16 | if k.startswith("cond_stage_model"))} 17 | torch.save(new_sd, outpath) 18 | -------------------------------------------------------------------------------- /src/taming-transformers/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='taming-transformers', 5 | version='0.0.1', 6 | description='Taming Transformers for High-Resolution Image Synthesis', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) 14 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/ade20k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | from taming.data.sflckr import SegmentationBase # for examples included in repo 9 | 10 | 11 | class Examples(SegmentationBase): 12 | def __init__(self, size=256, random_crop=False, interpolation="bicubic"): 13 | super().__init__(data_csv="data/ade20k_examples.txt", 14 | data_root="data/ade20k_images", 15 | segmentation_root="data/ade20k_segmentations", 16 | size=size, random_crop=random_crop, 17 | interpolation=interpolation, 18 | n_labels=151, shift_segmentation=False) 19 | 20 | 21 | # With semantic map and scene label 22 | class ADE20kBase(Dataset): 23 | def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None): 24 | self.split = self.get_split() 25 | self.n_labels = 151 # unknown + 150 26 | self.data_csv = {"train": "data/ade20k_train.txt", 27 | "validation": "data/ade20k_test.txt"}[self.split] 28 | self.data_root = "data/ade20k_root" 29 | with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f: 30 | self.scene_categories = f.read().splitlines() 31 | self.scene_categories = dict(line.split() for line in self.scene_categories) 32 | with open(self.data_csv, "r") as f: 33 | self.image_paths = f.read().splitlines() 34 | self._length = len(self.image_paths) 35 | self.labels = { 36 | "relative_file_path_": [l for l in self.image_paths], 37 | "file_path_": [os.path.join(self.data_root, "images", l) 38 
| for l in self.image_paths], 39 | "relative_segmentation_path_": [l.replace(".jpg", ".png") 40 | for l in self.image_paths], 41 | "segmentation_path_": [os.path.join(self.data_root, "annotations", 42 | l.replace(".jpg", ".png")) 43 | for l in self.image_paths], 44 | "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")] 45 | for l in self.image_paths], 46 | } 47 | 48 | size = None if size is not None and size<=0 else size 49 | self.size = size 50 | if crop_size is None: 51 | self.crop_size = size if size is not None else None 52 | else: 53 | self.crop_size = crop_size 54 | if self.size is not None: 55 | self.interpolation = interpolation 56 | self.interpolation = { 57 | "nearest": cv2.INTER_NEAREST, 58 | "bilinear": cv2.INTER_LINEAR, 59 | "bicubic": cv2.INTER_CUBIC, 60 | "area": cv2.INTER_AREA, 61 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 62 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 63 | interpolation=self.interpolation) 64 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 65 | interpolation=cv2.INTER_NEAREST) 66 | 67 | if crop_size is not None: 68 | self.center_crop = not random_crop 69 | if self.center_crop: 70 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 71 | else: 72 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) 73 | self.preprocessor = self.cropper 74 | 75 | def __len__(self): 76 | return self._length 77 | 78 | def __getitem__(self, i): 79 | example = dict((k, self.labels[k][i]) for k in self.labels) 80 | image = Image.open(example["file_path_"]) 81 | if not image.mode == "RGB": 82 | image = image.convert("RGB") 83 | image = np.array(image).astype(np.uint8) 84 | if self.size is not None: 85 | image = self.image_rescaler(image=image)["image"] 86 | segmentation = Image.open(example["segmentation_path_"]) 87 | segmentation = np.array(segmentation).astype(np.uint8) 88 | if self.size is not None: 89 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 90 | if self.size is not None: 91 | processed = self.preprocessor(image=image, mask=segmentation) 92 | else: 93 | processed = {"image": image, "mask": segmentation} 94 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 95 | segmentation = processed["mask"] 96 | onehot = np.eye(self.n_labels)[segmentation] 97 | example["segmentation"] = onehot 98 | return example 99 | 100 | 101 | class ADE20kTrain(ADE20kBase): 102 | # default to random_crop=True 103 | def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None): 104 | super().__init__(config=config, size=size, random_crop=random_crop, 105 | interpolation=interpolation, crop_size=crop_size) 106 | 107 | def get_split(self): 108 | return "train" 109 | 110 | 111 | class ADE20kValidation(ADE20kBase): 112 | def get_split(self): 113 | return "validation" 114 | 115 | 116 | if __name__ == "__main__": 117 | dset = ADE20kValidation() 118 | ex = dset[0] 119 | for k in ["image", "scene_category", "segmentation"]: 120 | print(type(ex[k])) 121 | try: 122 | print(ex[k].shape) 123 | except: 124 | print(ex[k]) 125 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/annotated_objects_coco.py: -------------------------------------------------------------------------------- 1 | import json 2 | from itertools import chain 3 | from pathlib import Path 4 | from typing import Iterable, 
Dict, List, Callable, Any 5 | from collections import defaultdict 6 | 7 | from tqdm import tqdm 8 | 9 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset 10 | from taming.data.helper_types import Annotation, ImageDescription, Category 11 | 12 | COCO_PATH_STRUCTURE = { 13 | 'train': { 14 | 'top_level': '', 15 | 'instances_annotations': 'annotations/instances_train2017.json', 16 | 'stuff_annotations': 'annotations/stuff_train2017.json', 17 | 'files': 'train2017' 18 | }, 19 | 'validation': { 20 | 'top_level': '', 21 | 'instances_annotations': 'annotations/instances_val2017.json', 22 | 'stuff_annotations': 'annotations/stuff_val2017.json', 23 | 'files': 'val2017' 24 | } 25 | } 26 | 27 | 28 | def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]: 29 | return { 30 | str(img['id']): ImageDescription( 31 | id=img['id'], 32 | license=img.get('license'), 33 | file_name=img['file_name'], 34 | coco_url=img['coco_url'], 35 | original_size=(img['width'], img['height']), 36 | date_captured=img.get('date_captured'), 37 | flickr_url=img.get('flickr_url') 38 | ) 39 | for img in description_json 40 | } 41 | 42 | 43 | def load_categories(category_json: Iterable) -> Dict[str, Category]: 44 | return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name']) 45 | for cat in category_json if cat['name'] != 'other'} 46 | 47 | 48 | def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription], 49 | category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]: 50 | annotations = defaultdict(list) 51 | total = sum(len(a) for a in annotations_json) 52 | for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total): 53 | image_id = str(ann['image_id']) 54 | if image_id not in image_descriptions: 55 | raise ValueError(f'image_id [{image_id}] has no image description.') 56 | category_id = ann['category_id'] 57 | try: 58 | category_no = category_no_for_id(str(category_id)) 59 | except KeyError: 60 | continue 61 | 62 | width, height = image_descriptions[image_id].original_size 63 | bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height) 64 | 65 | annotations[image_id].append( 66 | Annotation( 67 | id=ann['id'], 68 | area=bbox[2]*bbox[3], # use bbox area 69 | is_group_of=ann['iscrowd'], 70 | image_id=ann['image_id'], 71 | bbox=bbox, 72 | category_id=str(category_id), 73 | category_no=category_no 74 | ) 75 | ) 76 | return dict(annotations) 77 | 78 | 79 | class AnnotatedObjectsCoco(AnnotatedObjectsDataset): 80 | def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs): 81 | """ 82 | @param data_path: is the path to the following folder structure: 83 | coco/ 84 | ├── annotations 85 | │ ├── instances_train2017.json 86 | │ ├── instances_val2017.json 87 | │ ├── stuff_train2017.json 88 | │ └── stuff_val2017.json 89 | ├── train2017 90 | │ ├── 000000000009.jpg 91 | │ ├── 000000000025.jpg 92 | │ └── ... 93 | ├── val2017 94 | │ ├── 000000000139.jpg 95 | │ ├── 000000000285.jpg 96 | │ └── ... 
97 | @param: split: one of 'train' or 'validation' 98 | @param: desired image size (give square images) 99 | """ 100 | super().__init__(**kwargs) 101 | self.use_things = use_things 102 | self.use_stuff = use_stuff 103 | 104 | with open(self.paths['instances_annotations']) as f: 105 | inst_data_json = json.load(f) 106 | with open(self.paths['stuff_annotations']) as f: 107 | stuff_data_json = json.load(f) 108 | 109 | category_jsons = [] 110 | annotation_jsons = [] 111 | if self.use_things: 112 | category_jsons.append(inst_data_json['categories']) 113 | annotation_jsons.append(inst_data_json['annotations']) 114 | if self.use_stuff: 115 | category_jsons.append(stuff_data_json['categories']) 116 | annotation_jsons.append(stuff_data_json['annotations']) 117 | 118 | self.categories = load_categories(chain(*category_jsons)) 119 | self.filter_categories() 120 | self.setup_category_id_and_number() 121 | 122 | self.image_descriptions = load_image_descriptions(inst_data_json['images']) 123 | annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split) 124 | self.annotations = self.filter_object_number(annotations, self.min_object_area, 125 | self.min_objects_per_image, self.max_objects_per_image) 126 | self.image_ids = list(self.annotations.keys()) 127 | self.clean_up_annotations_and_image_descriptions() 128 | 129 | def get_path_structure(self) -> Dict[str, str]: 130 | if self.split not in COCO_PATH_STRUCTURE: 131 | raise ValueError(f'Split [{self.split} does not exist for COCO data.]') 132 | return COCO_PATH_STRUCTURE[self.split] 133 | 134 | def get_image_path(self, image_id: str) -> Path: 135 | return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name) 136 | 137 | def get_image_description(self, image_id: str) -> Dict[str, Any]: 138 | # noinspection PyProtectedMember 139 | return self.image_descriptions[image_id]._asdict() 140 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/annotated_objects_open_images.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from csv import DictReader, reader as TupleReader 3 | from pathlib import Path 4 | from typing import Dict, List, Any 5 | import warnings 6 | 7 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset 8 | from taming.data.helper_types import Annotation, Category 9 | from tqdm import tqdm 10 | 11 | OPEN_IMAGES_STRUCTURE = { 12 | 'train': { 13 | 'top_level': '', 14 | 'class_descriptions': 'class-descriptions-boxable.csv', 15 | 'annotations': 'oidv6-train-annotations-bbox.csv', 16 | 'file_list': 'train-images-boxable.csv', 17 | 'files': 'train' 18 | }, 19 | 'validation': { 20 | 'top_level': '', 21 | 'class_descriptions': 'class-descriptions-boxable.csv', 22 | 'annotations': 'validation-annotations-bbox.csv', 23 | 'file_list': 'validation-images.csv', 24 | 'files': 'validation' 25 | }, 26 | 'test': { 27 | 'top_level': '', 28 | 'class_descriptions': 'class-descriptions-boxable.csv', 29 | 'annotations': 'test-annotations-bbox.csv', 30 | 'file_list': 'test-images.csv', 31 | 'files': 'test' 32 | } 33 | } 34 | 35 | 36 | def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str], 37 | category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]: 38 | annotations: Dict[str, List[Annotation]] = defaultdict(list) 39 | with open(descriptor_path) as file: 40 | reader = 
DictReader(file) 41 | for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'): 42 | width = float(row['XMax']) - float(row['XMin']) 43 | height = float(row['YMax']) - float(row['YMin']) 44 | area = width * height 45 | category_id = row['LabelName'] 46 | if category_id in category_mapping: 47 | category_id = category_mapping[category_id] 48 | if area >= min_object_area and category_id in category_no_for_id: 49 | annotations[row['ImageID']].append( 50 | Annotation( 51 | id=i, 52 | image_id=row['ImageID'], 53 | source=row['Source'], 54 | category_id=category_id, 55 | category_no=category_no_for_id[category_id], 56 | confidence=float(row['Confidence']), 57 | bbox=(float(row['XMin']), float(row['YMin']), width, height), 58 | area=area, 59 | is_occluded=bool(int(row['IsOccluded'])), 60 | is_truncated=bool(int(row['IsTruncated'])), 61 | is_group_of=bool(int(row['IsGroupOf'])), 62 | is_depiction=bool(int(row['IsDepiction'])), 63 | is_inside=bool(int(row['IsInside'])) 64 | ) 65 | ) 66 | if 'train' in str(descriptor_path) and i < 14000000: 67 | warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].') 68 | return dict(annotations) 69 | 70 | 71 | def load_image_ids(csv_path: Path) -> List[str]: 72 | with open(csv_path) as file: 73 | reader = DictReader(file) 74 | return [row['image_name'] for row in reader] 75 | 76 | 77 | def load_categories(csv_path: Path) -> Dict[str, Category]: 78 | with open(csv_path) as file: 79 | reader = TupleReader(file) 80 | return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader} 81 | 82 | 83 | class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset): 84 | def __init__(self, use_additional_parameters: bool, **kwargs): 85 | """ 86 | @param data_path: is the path to the following folder structure: 87 | open_images/ 88 | │ oidv6-train-annotations-bbox.csv 89 | ├── class-descriptions-boxable.csv 90 | ├── oidv6-train-annotations-bbox.csv 91 | ├── test 92 | │ ├── 000026e7ee790996.jpg 93 | │ ├── 000062a39995e348.jpg 94 | │ └── ... 95 | ├── test-annotations-bbox.csv 96 | ├── test-images.csv 97 | ├── train 98 | │ ├── 000002b66c9c498e.jpg 99 | │ ├── 000002b97e5471a0.jpg 100 | │ └── ... 101 | ├── train-images-boxable.csv 102 | ├── validation 103 | │ ├── 0001eeaf4aed83f9.jpg 104 | │ ├── 0004886b7d043cfd.jpg 105 | │ └── ... 
106 | ├── validation-annotations-bbox.csv 107 | └── validation-images.csv 108 | @param: split: one of 'train', 'validation' or 'test' 109 | @param: desired image size (returns square images) 110 | """ 111 | 112 | super().__init__(**kwargs) 113 | self.use_additional_parameters = use_additional_parameters 114 | 115 | self.categories = load_categories(self.paths['class_descriptions']) 116 | self.filter_categories() 117 | self.setup_category_id_and_number() 118 | 119 | self.image_descriptions = {} 120 | annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping, 121 | self.category_number) 122 | self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image, 123 | self.max_objects_per_image) 124 | self.image_ids = list(self.annotations.keys()) 125 | self.clean_up_annotations_and_image_descriptions() 126 | 127 | def get_path_structure(self) -> Dict[str, str]: 128 | if self.split not in OPEN_IMAGES_STRUCTURE: 129 | raise ValueError(f'Split [{self.split} does not exist for Open Images data.]') 130 | return OPEN_IMAGES_STRUCTURE[self.split] 131 | 132 | def get_image_path(self, image_id: str) -> Path: 133 | return self.paths['files'].joinpath(f'{image_id:0>16}.jpg') 134 | 135 | def get_image_description(self, image_id: str) -> Dict[str, Any]: 136 | image_path = self.get_image_path(image_id) 137 | return {'file_path': str(image_path), 'file_name': image_path.name} 138 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/base.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import albumentations 4 | from PIL import Image 5 | from torch.utils.data import Dataset, ConcatDataset 6 | 7 | 8 | class ConcatDatasetWithIndex(ConcatDataset): 9 | """Modified from original pytorch code to return dataset idx""" 10 | def __getitem__(self, idx): 11 | if idx < 0: 12 | if -idx > len(self): 13 | raise ValueError("absolute value of index should not exceed dataset length") 14 | idx = len(self) + idx 15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 16 | if dataset_idx == 0: 17 | sample_idx = idx 18 | else: 19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 20 | return self.datasets[dataset_idx][sample_idx], dataset_idx 21 | 22 | 23 | class ImagePaths(Dataset): 24 | def __init__(self, paths, size=None, random_crop=False, labels=None): 25 | self.size = size 26 | self.random_crop = random_crop 27 | 28 | self.labels = dict() if labels is None else labels 29 | self.labels["file_path_"] = paths 30 | self._length = len(paths) 31 | 32 | if self.size is not None and self.size > 0: 33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 34 | if not self.random_crop: 35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 36 | else: 37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 39 | else: 40 | self.preprocessor = lambda **kwargs: kwargs 41 | 42 | def __len__(self): 43 | return self._length 44 | 45 | def preprocess_image(self, image_path): 46 | image = Image.open(image_path) 47 | if not image.mode == "RGB": 48 | image = image.convert("RGB") 49 | image = np.array(image).astype(np.uint8) 50 | image = self.preprocessor(image=image)["image"] 51 | image = (image/127.5 - 1.0).astype(np.float32) 52 | return image 53 | 54 | def 
__getitem__(self, i): 55 | example = dict() 56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 57 | for k in self.labels: 58 | example[k] = self.labels[k][i] 59 | return example 60 | 61 | 62 | class NumpyPaths(ImagePaths): 63 | def preprocess_image(self, image_path): 64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 65 | image = np.transpose(image, (1,2,0)) 66 | image = Image.fromarray(image, mode="RGB") 67 | image = np.array(image).astype(np.uint8) 68 | image = self.preprocessor(image=image)["image"] 69 | image = (image/127.5 - 1.0).astype(np.float32) 70 | return image 71 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/conditional_builder/objects_bbox.py: -------------------------------------------------------------------------------- 1 | from itertools import cycle 2 | from typing import List, Tuple, Callable, Optional 3 | 4 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont 5 | from more_itertools.recipes import grouper 6 | from taming.data.image_transforms import convert_pil_to_tensor 7 | from torch import LongTensor, Tensor 8 | 9 | from taming.data.helper_types import BoundingBox, Annotation 10 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder 11 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ 12 | pad_list, get_plot_font_size, absolute_bbox 13 | 14 | 15 | class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): 16 | @property 17 | def object_descriptor_length(self) -> int: 18 | return 3 19 | 20 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: 21 | object_triples = [ 22 | (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) 23 | for ann in annotations 24 | ] 25 | empty_triple = (self.none, self.none, self.none) 26 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) 27 | return object_triples 28 | 29 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: 30 | conditional_list = conditional.tolist() 31 | crop_coordinates = None 32 | if self.encode_crop: 33 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) 34 | conditional_list = conditional_list[:-2] 35 | object_triples = grouper(conditional_list, 3) 36 | assert conditional.shape[0] == self.embedding_dim 37 | return [ 38 | (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) 39 | for object_triple in object_triples if object_triple[0] != self.none 40 | ], crop_coordinates 41 | 42 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], 43 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor: 44 | plot = pil_image.new('RGB', figure_size, WHITE) 45 | draw = pil_img_draw.Draw(plot) 46 | font = ImageFont.truetype( 47 | "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", 48 | size=get_plot_font_size(font_size, figure_size) 49 | ) 50 | width, height = plot.size 51 | description, crop_coordinates = self.inverse_build(conditional) 52 | for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): 53 | annotation = self.representation_to_annotation(representation) 54 | class_label = label_for_category_no(annotation.category_no) + ' ' + 
additional_parameters_string(annotation) 55 | bbox = absolute_bbox(bbox, width, height) 56 | draw.rectangle(bbox, outline=color, width=line_width) 57 | draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) 58 | if crop_coordinates is not None: 59 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) 60 | return convert_pil_to_tensor(plot) / 127.5 - 1. 61 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/conditional_builder/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import List, Any, Tuple, Optional 3 | 4 | from taming.data.helper_types import BoundingBox, Annotation 5 | 6 | # source: seaborn, color palette tab10 7 | COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), 8 | (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)] 9 | BLACK = (0, 0, 0) 10 | GRAY_75 = (63, 63, 63) 11 | GRAY_50 = (127, 127, 127) 12 | GRAY_25 = (191, 191, 191) 13 | WHITE = (255, 255, 255) 14 | FULL_CROP = (0., 0., 1., 1.) 15 | 16 | 17 | def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: 18 | """ 19 | Give intersection area of two rectangles. 20 | @param rectangle1: (x0, y0, w, h) of first rectangle 21 | @param rectangle2: (x0, y0, w, h) of second rectangle 22 | """ 23 | rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] 24 | rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] 25 | x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) 26 | y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) 27 | return x_overlap * y_overlap 28 | 29 | 30 | def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: 31 | return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] 32 | 33 | 34 | def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: 35 | bbox = relative_bbox 36 | bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height 37 | return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) 38 | 39 | 40 | def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: 41 | return list_ + [pad_element for _ in range(pad_to_length - len(list_))] 42 | 43 | 44 | def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ 45 | List[Annotation]: 46 | def clamp(x: float): 47 | return max(min(x, 1.), 0.) 
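    # rescale_bbox re-expresses each (x0, y0, w, h) box relative to the crop
    # window: the origin is shifted by the crop offset and divided by the crop
    # size, the result is clamped into [0, 1], and width/height are truncated
    # so the box cannot extend past the crop; with flip=True, x0 is mirrored
    # to follow a horizontal flip of the cropped image.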
48 | 49 | def rescale_bbox(bbox: BoundingBox) -> BoundingBox: 50 | x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) 51 | y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) 52 | w = min(bbox[2] / crop_coordinates[2], 1 - x0) 53 | h = min(bbox[3] / crop_coordinates[3], 1 - y0) 54 | if flip: 55 | x0 = 1 - (x0 + w) 56 | return x0, y0, w, h 57 | 58 | return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] 59 | 60 | 61 | def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: 62 | return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] 63 | 64 | 65 | def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: 66 | sl = slice(1) if short else slice(None) 67 | string = '' 68 | if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): 69 | return string 70 | if annotation.is_group_of: 71 | string += 'group'[sl] + ',' 72 | if annotation.is_occluded: 73 | string += 'occluded'[sl] + ',' 74 | if annotation.is_depiction: 75 | string += 'depiction'[sl] + ',' 76 | if annotation.is_inside: 77 | string += 'inside'[sl] 78 | return '(' + string.strip(",") + ')' 79 | 80 | 81 | def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: 82 | if font_size is None: 83 | font_size = 10 84 | if max(figure_size) >= 256: 85 | font_size = 12 86 | if max(figure_size) >= 512: 87 | font_size = 15 88 | return font_size 89 | 90 | 91 | def get_circle_size(figure_size: Tuple[int, int]) -> int: 92 | circle_size = 2 93 | if max(figure_size) >= 256: 94 | circle_size = 3 95 | if max(figure_size) >= 512: 96 | circle_size = 4 97 | return circle_size 98 | 99 | 100 | def load_object_from_string(object_string: str) -> Any: 101 | """ 102 | Source: https://stackoverflow.com/a/10773699 103 | """ 104 | module_name, class_name = object_string.rsplit(".", 1) 105 | return getattr(importlib.import_module(module_name), class_name) 106 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/custom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class CustomBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | 14 | def __len__(self): 15 | return len(self.data) 16 | 17 | def __getitem__(self, i): 18 | example = self.data[i] 19 | return example 20 | 21 | 22 | 23 | class CustomTrain(CustomBase): 24 | def __init__(self, size, training_images_list_file): 25 | super().__init__() 26 | with open(training_images_list_file, "r") as f: 27 | paths = f.read().splitlines() 28 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 29 | 30 | 31 | class CustomTest(CustomBase): 32 | def __init__(self, size, test_images_list_file): 33 | super().__init__() 34 | with open(test_images_list_file, "r") as f: 35 | paths = f.read().splitlines() 36 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 37 | 38 | 39 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/faceshq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 
4 | from torch.utils.data import Dataset 5 | 6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class FacesBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | self.keys = None 14 | 15 | def __len__(self): 16 | return len(self.data) 17 | 18 | def __getitem__(self, i): 19 | example = self.data[i] 20 | ex = {} 21 | if self.keys is not None: 22 | for k in self.keys: 23 | ex[k] = example[k] 24 | else: 25 | ex = example 26 | return ex 27 | 28 | 29 | class CelebAHQTrain(FacesBase): 30 | def __init__(self, size, keys=None): 31 | super().__init__() 32 | root = "data/celebahq" 33 | with open("data/celebahqtrain.txt", "r") as f: 34 | relpaths = f.read().splitlines() 35 | paths = [os.path.join(root, relpath) for relpath in relpaths] 36 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 37 | self.keys = keys 38 | 39 | 40 | class CelebAHQValidation(FacesBase): 41 | def __init__(self, size, keys=None): 42 | super().__init__() 43 | root = "data/celebahq" 44 | with open("data/celebahqvalidation.txt", "r") as f: 45 | relpaths = f.read().splitlines() 46 | paths = [os.path.join(root, relpath) for relpath in relpaths] 47 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 48 | self.keys = keys 49 | 50 | 51 | class FFHQTrain(FacesBase): 52 | def __init__(self, size, keys=None): 53 | super().__init__() 54 | root = "data/ffhq" 55 | with open("data/ffhqtrain.txt", "r") as f: 56 | relpaths = f.read().splitlines() 57 | paths = [os.path.join(root, relpath) for relpath in relpaths] 58 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 59 | self.keys = keys 60 | 61 | 62 | class FFHQValidation(FacesBase): 63 | def __init__(self, size, keys=None): 64 | super().__init__() 65 | root = "data/ffhq" 66 | with open("data/ffhqvalidation.txt", "r") as f: 67 | relpaths = f.read().splitlines() 68 | paths = [os.path.join(root, relpath) for relpath in relpaths] 69 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 70 | self.keys = keys 71 | 72 | 73 | class FacesHQTrain(Dataset): 74 | # CelebAHQ [0] + FFHQ [1] 75 | def __init__(self, size, keys=None, crop_size=None, coord=False): 76 | d1 = CelebAHQTrain(size=size, keys=keys) 77 | d2 = FFHQTrain(size=size, keys=keys) 78 | self.data = ConcatDatasetWithIndex([d1, d2]) 79 | self.coord = coord 80 | if crop_size is not None: 81 | self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) 82 | if self.coord: 83 | self.cropper = albumentations.Compose([self.cropper], 84 | additional_targets={"coord": "image"}) 85 | 86 | def __len__(self): 87 | return len(self.data) 88 | 89 | def __getitem__(self, i): 90 | ex, y = self.data[i] 91 | if hasattr(self, "cropper"): 92 | if not self.coord: 93 | out = self.cropper(image=ex["image"]) 94 | ex["image"] = out["image"] 95 | else: 96 | h,w,_ = ex["image"].shape 97 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 98 | out = self.cropper(image=ex["image"], coord=coord) 99 | ex["image"] = out["image"] 100 | ex["coord"] = out["coord"] 101 | ex["class"] = y 102 | return ex 103 | 104 | 105 | class FacesHQValidation(Dataset): 106 | # CelebAHQ [0] + FFHQ [1] 107 | def __init__(self, size, keys=None, crop_size=None, coord=False): 108 | d1 = CelebAHQValidation(size=size, keys=keys) 109 | d2 = FFHQValidation(size=size, keys=keys) 110 | self.data = ConcatDatasetWithIndex([d1, d2]) 111 | self.coord = coord 112 | if crop_size is not None: 113 | self.cropper = 
albumentations.CenterCrop(height=crop_size,width=crop_size) 114 | if self.coord: 115 | self.cropper = albumentations.Compose([self.cropper], 116 | additional_targets={"coord": "image"}) 117 | 118 | def __len__(self): 119 | return len(self.data) 120 | 121 | def __getitem__(self, i): 122 | ex, y = self.data[i] 123 | if hasattr(self, "cropper"): 124 | if not self.coord: 125 | out = self.cropper(image=ex["image"]) 126 | ex["image"] = out["image"] 127 | else: 128 | h,w,_ = ex["image"].shape 129 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 130 | out = self.cropper(image=ex["image"], coord=coord) 131 | ex["image"] = out["image"] 132 | ex["coord"] = out["coord"] 133 | ex["class"] = y 134 | return ex 135 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/helper_types.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Optional, NamedTuple, Union 2 | from PIL.Image import Image as pil_image 3 | from torch import Tensor 4 | 5 | try: 6 | from typing import Literal 7 | except ImportError: 8 | from typing_extensions import Literal 9 | 10 | Image = Union[Tensor, pil_image] 11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h 12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d'] 13 | SplitType = Literal['train', 'validation', 'test'] 14 | 15 | 16 | class ImageDescription(NamedTuple): 17 | id: int 18 | file_name: str 19 | original_size: Tuple[int, int] # w, h 20 | url: Optional[str] = None 21 | license: Optional[int] = None 22 | coco_url: Optional[str] = None 23 | date_captured: Optional[str] = None 24 | flickr_url: Optional[str] = None 25 | flickr_id: Optional[str] = None 26 | coco_id: Optional[str] = None 27 | 28 | 29 | class Category(NamedTuple): 30 | id: str 31 | super_category: Optional[str] 32 | name: str 33 | 34 | 35 | class Annotation(NamedTuple): 36 | area: float 37 | image_id: str 38 | bbox: BoundingBox 39 | category_no: int 40 | category_id: str 41 | id: Optional[int] = None 42 | source: Optional[str] = None 43 | confidence: Optional[float] = None 44 | is_group_of: Optional[bool] = None 45 | is_truncated: Optional[bool] = None 46 | is_occluded: Optional[bool] = None 47 | is_depiction: Optional[bool] = None 48 | is_inside: Optional[bool] = None 49 | segmentation: Optional[Dict] = None 50 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/image_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | from typing import Union 4 | 5 | import torch 6 | from torch import Tensor 7 | from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor 8 | from torchvision.transforms.functional import _get_image_size as get_image_size 9 | 10 | from taming.data.helper_types import BoundingBox, Image 11 | 12 | pil_to_tensor = PILToTensor() 13 | 14 | 15 | def convert_pil_to_tensor(image: Image) -> Tensor: 16 | with warnings.catch_warnings(): 17 | # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 18 | warnings.simplefilter("ignore") 19 | return pil_to_tensor(image) 20 | 21 | 22 | class RandomCrop1dReturnCoordinates(RandomCrop): 23 | def forward(self, img: Image) -> (BoundingBox, Image): 24 | """ 25 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 
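        For example, a 50x50 crop taken at top=20, left=10 from a 100 (w) x 200 (h)
        image yields the relative bounding box (0.10, 0.10, 0.50, 0.25).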
26 | Args: 27 | img (PIL Image or Tensor): Image to be cropped. 28 | 29 | Returns: 30 | Bounding box: x0, y0, w, h 31 | PIL Image or Tensor: Cropped image. 32 | 33 | Based on: 34 | torchvision.transforms.RandomCrop, torchvision 1.7.0 35 | """ 36 | if self.padding is not None: 37 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 38 | 39 | width, height = get_image_size(img) 40 | # pad the width if needed 41 | if self.pad_if_needed and width < self.size[1]: 42 | padding = [self.size[1] - width, 0] 43 | img = F.pad(img, padding, self.fill, self.padding_mode) 44 | # pad the height if needed 45 | if self.pad_if_needed and height < self.size[0]: 46 | padding = [0, self.size[0] - height] 47 | img = F.pad(img, padding, self.fill, self.padding_mode) 48 | 49 | i, j, h, w = self.get_params(img, self.size) 50 | bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h 51 | return bbox, F.crop(img, i, j, h, w) 52 | 53 | 54 | class Random2dCropReturnCoordinates(torch.nn.Module): 55 | """ 56 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 57 | Args: 58 | img (PIL Image or Tensor): Image to be cropped. 59 | 60 | Returns: 61 | Bounding box: x0, y0, w, h 62 | PIL Image or Tensor: Cropped image. 63 | 64 | Based on: 65 | torchvision.transforms.RandomCrop, torchvision 1.7.0 66 | """ 67 | 68 | def __init__(self, min_size: int): 69 | super().__init__() 70 | self.min_size = min_size 71 | 72 | def forward(self, img: Image) -> (BoundingBox, Image): 73 | width, height = get_image_size(img) 74 | max_size = min(width, height) 75 | if max_size <= self.min_size: 76 | size = max_size 77 | else: 78 | size = random.randint(self.min_size, max_size) 79 | top = random.randint(0, height - size) 80 | left = random.randint(0, width - size) 81 | bbox = left / width, top / height, size / width, size / height 82 | return bbox, F.crop(img, top, left, size, size) 83 | 84 | 85 | class CenterCropReturnCoordinates(CenterCrop): 86 | @staticmethod 87 | def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: 88 | if width > height: 89 | w = height / width 90 | h = 1.0 91 | x0 = 0.5 - w / 2 92 | y0 = 0. 93 | else: 94 | w = 1.0 95 | h = width / height 96 | x0 = 0. 97 | y0 = 0.5 - h / 2 98 | return x0, y0, w, h 99 | 100 | def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]): 101 | """ 102 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 103 | Args: 104 | img (PIL Image or Tensor): Image to be cropped. 105 | 106 | Returns: 107 | Bounding box: x0, y0, w, h 108 | PIL Image or Tensor: Cropped image. 109 | Based on: 110 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 111 | """ 112 | width, height = get_image_size(img) 113 | return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) 114 | 115 | 116 | class RandomHorizontalFlipReturn(RandomHorizontalFlip): 117 | def forward(self, img: Image) -> (bool, Image): 118 | """ 119 | Additionally to flipping, returns a boolean whether it was flipped or not. 120 | Args: 121 | img (PIL Image or Tensor): Image to be flipped. 122 | 123 | Returns: 124 | flipped: whether the image was flipped or not 125 | PIL Image or Tensor: Randomly flipped image. 
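        The flip is applied with probability self.p (0.5 by default in torchvision).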
126 | 127 | Based on: 128 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 129 | """ 130 | if torch.rand(1) < self.p: 131 | return True, F.hflip(img) 132 | return False, img 133 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/sflckr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class SegmentationBase(Dataset): 10 | def __init__(self, 11 | data_csv, data_root, segmentation_root, 12 | size=None, random_crop=False, interpolation="bicubic", 13 | n_labels=182, shift_segmentation=False, 14 | ): 15 | self.n_labels = n_labels 16 | self.shift_segmentation = shift_segmentation 17 | self.data_csv = data_csv 18 | self.data_root = data_root 19 | self.segmentation_root = segmentation_root 20 | with open(self.data_csv, "r") as f: 21 | self.image_paths = f.read().splitlines() 22 | self._length = len(self.image_paths) 23 | self.labels = { 24 | "relative_file_path_": [l for l in self.image_paths], 25 | "file_path_": [os.path.join(self.data_root, l) 26 | for l in self.image_paths], 27 | "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) 28 | for l in self.image_paths] 29 | } 30 | 31 | size = None if size is not None and size<=0 else size 32 | self.size = size 33 | if self.size is not None: 34 | self.interpolation = interpolation 35 | self.interpolation = { 36 | "nearest": cv2.INTER_NEAREST, 37 | "bilinear": cv2.INTER_LINEAR, 38 | "bicubic": cv2.INTER_CUBIC, 39 | "area": cv2.INTER_AREA, 40 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 41 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 42 | interpolation=self.interpolation) 43 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 44 | interpolation=cv2.INTER_NEAREST) 45 | self.center_crop = not random_crop 46 | if self.center_crop: 47 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) 48 | else: 49 | self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) 50 | self.preprocessor = self.cropper 51 | 52 | def __len__(self): 53 | return self._length 54 | 55 | def __getitem__(self, i): 56 | example = dict((k, self.labels[k][i]) for k in self.labels) 57 | image = Image.open(example["file_path_"]) 58 | if not image.mode == "RGB": 59 | image = image.convert("RGB") 60 | image = np.array(image).astype(np.uint8) 61 | if self.size is not None: 62 | image = self.image_rescaler(image=image)["image"] 63 | segmentation = Image.open(example["segmentation_path_"]) 64 | assert segmentation.mode == "L", segmentation.mode 65 | segmentation = np.array(segmentation).astype(np.uint8) 66 | if self.shift_segmentation: 67 | # used to support segmentations containing unlabeled==255 label 68 | segmentation = segmentation+1 69 | if self.size is not None: 70 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 71 | if self.size is not None: 72 | processed = self.preprocessor(image=image, 73 | mask=segmentation 74 | ) 75 | else: 76 | processed = {"image": image, 77 | "mask": segmentation 78 | } 79 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 80 | segmentation = processed["mask"] 81 | onehot = np.eye(self.n_labels)[segmentation] 82 | example["segmentation"] = onehot 83 | return example 84 | 85 | 86 | class 
Examples(SegmentationBase): 87 | def __init__(self, size=None, random_crop=False, interpolation="bicubic"): 88 | super().__init__(data_csv="data/sflckr_examples.txt", 89 | data_root="data/sflckr_images", 90 | segmentation_root="data/sflckr_segmentations", 91 | size=size, random_crop=random_crop, interpolation=interpolation) 92 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import tarfile 4 | import urllib 5 | import zipfile 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | from taming.data.helper_types import Annotation 11 | from torch._six import string_classes 12 | from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format 13 | from tqdm import tqdm 14 | 15 | 16 | def unpack(path): 17 | if path.endswith("tar.gz"): 18 | with tarfile.open(path, "r:gz") as tar: 19 | tar.extractall(path=os.path.split(path)[0]) 20 | elif path.endswith("tar"): 21 | with tarfile.open(path, "r:") as tar: 22 | tar.extractall(path=os.path.split(path)[0]) 23 | elif path.endswith("zip"): 24 | with zipfile.ZipFile(path, "r") as f: 25 | f.extractall(path=os.path.split(path)[0]) 26 | else: 27 | raise NotImplementedError( 28 | "Unknown file extension: {}".format(os.path.splitext(path)[1]) 29 | ) 30 | 31 | 32 | def reporthook(bar): 33 | """tqdm progress bar for downloads.""" 34 | 35 | def hook(b=1, bsize=1, tsize=None): 36 | if tsize is not None: 37 | bar.total = tsize 38 | bar.update(b * bsize - bar.n) 39 | 40 | return hook 41 | 42 | 43 | def get_root(name): 44 | base = "data/" 45 | root = os.path.join(base, name) 46 | os.makedirs(root, exist_ok=True) 47 | return root 48 | 49 | 50 | def is_prepared(root): 51 | return Path(root).joinpath(".ready").exists() 52 | 53 | 54 | def mark_prepared(root): 55 | Path(root).joinpath(".ready").touch() 56 | 57 | 58 | def prompt_download(file_, source, target_dir, content_dir=None): 59 | targetpath = os.path.join(target_dir, file_) 60 | while not os.path.exists(targetpath): 61 | if content_dir is not None and os.path.exists( 62 | os.path.join(target_dir, content_dir) 63 | ): 64 | break 65 | print( 66 | "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath) 67 | ) 68 | if content_dir is not None: 69 | print( 70 | "Or place its content into '{}'.".format( 71 | os.path.join(target_dir, content_dir) 72 | ) 73 | ) 74 | input("Press Enter when done...") 75 | return targetpath 76 | 77 | 78 | def download_url(file_, url, target_dir): 79 | targetpath = os.path.join(target_dir, file_) 80 | os.makedirs(target_dir, exist_ok=True) 81 | with tqdm( 82 | unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_ 83 | ) as bar: 84 | urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) 85 | return targetpath 86 | 87 | 88 | def download_urls(urls, target_dir): 89 | paths = dict() 90 | for fname, url in urls.items(): 91 | outpath = download_url(fname, url, target_dir) 92 | paths[fname] = outpath 93 | return paths 94 | 95 | 96 | def quadratic_crop(x, bbox, alpha=1.0): 97 | """bbox is xmin, ymin, xmax, ymax""" 98 | im_h, im_w = x.shape[:2] 99 | bbox = np.array(bbox, dtype=np.float32) 100 | bbox = np.clip(bbox, 0, max(im_h, im_w)) 101 | center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3]) 102 | w = bbox[2] - bbox[0] 103 | h = bbox[3] - bbox[1] 104 | l = int(alpha * max(w, h)) 105 | 
l = max(l, 2) 106 | 107 | required_padding = -1 * min( 108 | center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) 109 | ) 110 | required_padding = int(np.ceil(required_padding)) 111 | if required_padding > 0: 112 | padding = [ 113 | [required_padding, required_padding], 114 | [required_padding, required_padding], 115 | ] 116 | padding += [[0, 0]] * (len(x.shape) - 2) 117 | x = np.pad(x, padding, "reflect") 118 | center = center[0] + required_padding, center[1] + required_padding 119 | xmin = int(center[0] - l / 2) 120 | ymin = int(center[1] - l / 2) 121 | return np.array(x[ymin : ymin + l, xmin : xmin + l, ...]) 122 | 123 | 124 | def custom_collate(batch): 125 | r"""source: pytorch 1.9.0, only one modification to original code """ 126 | 127 | elem = batch[0] 128 | elem_type = type(elem) 129 | if isinstance(elem, torch.Tensor): 130 | out = None 131 | if torch.utils.data.get_worker_info() is not None: 132 | # If we're in a background process, concatenate directly into a 133 | # shared memory tensor to avoid an extra copy 134 | numel = sum([x.numel() for x in batch]) 135 | storage = elem.storage()._new_shared(numel) 136 | out = elem.new(storage) 137 | return torch.stack(batch, 0, out=out) 138 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 139 | and elem_type.__name__ != 'string_': 140 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 141 | # array of string classes and object 142 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 143 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 144 | 145 | return custom_collate([torch.as_tensor(b) for b in batch]) 146 | elif elem.shape == (): # scalars 147 | return torch.as_tensor(batch) 148 | elif isinstance(elem, float): 149 | return torch.tensor(batch, dtype=torch.float64) 150 | elif isinstance(elem, int): 151 | return torch.tensor(batch) 152 | elif isinstance(elem, string_classes): 153 | return batch 154 | elif isinstance(elem, collections.abc.Mapping): 155 | return {key: custom_collate([d[key] for d in batch]) for key in elem} 156 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 157 | return elem_type(*(custom_collate(samples) for samples in zip(*batch))) 158 | if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added 159 | return batch # added 160 | elif isinstance(elem, collections.abc.Sequence): 161 | # check to make sure that the elements in batch have consistent size 162 | it = iter(batch) 163 | elem_size = len(next(it)) 164 | if not all(len(elem) == elem_size for elem in it): 165 | raise RuntimeError('each element in list of batch should be of equal size') 166 | transposed = zip(*batch) 167 | return [custom_collate(samples) for samples in transposed] 168 | 169 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 170 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
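        # last_lr caches the most recently returned multiplier so that
        # schedule() can report it every `verbosity_interval` steps when
        # verbosity_interval > 0.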
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/models/dummy_cond_stage.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | class DummyCondStage: 5 | def __init__(self, conditional_key): 6 | self.conditional_key = conditional_key 7 | self.train = None 8 | 9 | def eval(self): 10 | return self 11 | 12 | @staticmethod 13 | def encode(c: Tensor): 14 | return c, None, (None, None, c) 15 | 16 | @staticmethod 17 | def decode(c: Tensor): 18 | return c 19 | 20 | @staticmethod 21 | def to_rgb(c: Tensor): 22 | return c 23 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from taming.modules.util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * 
nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from taming.modules.losses.vqperceptual import DummyLoss 2 | 3 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from taming.util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / 
self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | 124 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BCELoss(nn.Module): 6 | def forward(self, prediction, target): 7 | loss = F.binary_cross_entropy_with_logits(prediction,target) 8 | return loss, {} 9 | 10 | 11 | class BCELossWithQuant(nn.Module): 12 | def __init__(self, codebook_weight=1.): 13 | super().__init__() 14 | self.codebook_weight = codebook_weight 15 | 16 | def forward(self, qloss, target, prediction, split): 17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 18 | loss = bce_loss + self.codebook_weight*qloss 19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), 20 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 21 | "{}/quant_loss".format(split): qloss.detach().mean() 22 | } 23 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from taming.modules.losses.lpips import LPIPS 6 | from taming.modules.discriminator.model import 
NLayerDiscriminator, weights_init 7 | 8 | 9 | class DummyLoss(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | 14 | def adopt_weight(weight, global_step, threshold=0, value=0.): 15 | if global_step < threshold: 16 | weight = value 17 | return weight 18 | 19 | 20 | def hinge_d_loss(logits_real, logits_fake): 21 | loss_real = torch.mean(F.relu(1. - logits_real)) 22 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 23 | d_loss = 0.5 * (loss_real + loss_fake) 24 | return d_loss 25 | 26 | 27 | def vanilla_d_loss(logits_real, logits_fake): 28 | d_loss = 0.5 * ( 29 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 30 | torch.mean(torch.nn.functional.softplus(logits_fake))) 31 | return d_loss 32 | 33 | 34 | class VQLPIPSWithDiscriminator(nn.Module): 35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 38 | disc_ndf=64, disc_loss="hinge"): 39 | super().__init__() 40 | assert disc_loss in ["hinge", "vanilla"] 41 | self.codebook_weight = codebook_weight 42 | self.pixel_weight = pixelloss_weight 43 | self.perceptual_loss = LPIPS().eval() 44 | self.perceptual_weight = perceptual_weight 45 | 46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 47 | n_layers=disc_num_layers, 48 | use_actnorm=use_actnorm, 49 | ndf=disc_ndf 50 | ).apply(weights_init) 51 | self.discriminator_iter_start = disc_start 52 | if disc_loss == "hinge": 53 | self.disc_loss = hinge_d_loss 54 | elif disc_loss == "vanilla": 55 | self.disc_loss = vanilla_d_loss 56 | else: 57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 59 | self.disc_factor = disc_factor 60 | self.discriminator_weight = disc_weight 61 | self.disc_conditional = disc_conditional 62 | 63 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 64 | if last_layer is not None: 65 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 66 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 67 | else: 68 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 69 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 70 | 71 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 72 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 73 | d_weight = d_weight * self.discriminator_weight 74 | return d_weight 75 | 76 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 77 | global_step, last_layer=None, cond=None, split="train"): 78 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 79 | if self.perceptual_weight > 0: 80 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 81 | rec_loss = rec_loss + self.perceptual_weight * p_loss 82 | else: 83 | p_loss = torch.tensor([0.0]) 84 | 85 | nll_loss = rec_loss 86 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 87 | nll_loss = torch.mean(nll_loss) 88 | 89 | # now the GAN part 90 | if optimizer_idx == 0: 91 | # generator update 92 | if cond is None: 93 | assert not self.disc_conditional 94 | logits_fake = self.discriminator(reconstructions.contiguous()) 95 | else: 96 | assert self.disc_conditional 97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 98 | g_loss = -torch.mean(logits_fake) 
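# Adaptive weighting (see calculate_adaptive_weight above): the generator loss is scaled by
# d_weight = ||grad(nll_loss)|| / (||grad(g_loss)|| + 1e-4), measured at the last decoder layer,
# clamped to [0, 1e4] and multiplied by discriminator_weight. When those gradients are
# unavailable (e.g. during validation), the RuntimeError branch below falls back to d_weight = 0.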
99 | 100 | try: 101 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 102 | except RuntimeError: 103 | assert not self.training 104 | d_weight = torch.tensor(0.0) 105 | 106 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 107 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 108 | 109 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 110 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 111 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 112 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 113 | "{}/p_loss".format(split): p_loss.detach().mean(), 114 | "{}/d_weight".format(split): d_weight.detach(), 115 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 116 | "{}/g_loss".format(split): g_loss.detach().mean(), 117 | } 118 | return loss, log 119 | 120 | if optimizer_idx == 1: 121 | # second pass for discriminator update 122 | if cond is None: 123 | logits_real = self.discriminator(inputs.contiguous().detach()) 124 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 125 | else: 126 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 127 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 128 | 129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 130 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 131 | 132 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 133 | "{}/logits_real".format(split): logits_real.detach().mean(), 134 | "{}/logits_fake".format(split): logits_fake.detach().mean() 135 | } 136 | return d_loss, log 137 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/misc/coord.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CoordStage(object): 4 | def __init__(self, n_embed, down_factor): 5 | self.n_embed = n_embed 6 | self.down_factor = down_factor 7 | 8 | def eval(self): 9 | return self 10 | 11 | def encode(self, c): 12 | """fake vqmodel interface""" 13 | assert 0.0 <= c.min() and c.max() <= 1.0 14 | b,ch,h,w = c.shape 15 | assert ch == 1 16 | 17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, 18 | mode="area") 19 | c = c.clamp(0.0, 1.0) 20 | c = self.n_embed*c 21 | c_quant = c.round() 22 | c_ind = c_quant.to(dtype=torch.long) 23 | 24 | info = None, None, c_ind 25 | return c_quant, None, info 26 | 27 | def decode(self, c): 28 | c = c/self.n_embed 29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, 30 | mode="nearest") 31 | return c 32 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, logdet=False, affine=True, 12 | allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = 
nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | .unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 1e-6)) 42 | 43 | def forward(self, input, reverse=False): 44 | if reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 | h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/vqvae/__pycache__/quantize.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/src/taming-transformers/taming/modules/vqvae/__pycache__/quantize.cpython-38.pyc 
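The ActNorm layer above performs data-dependent initialization: on its first forward pass in training mode it sets loc to the negative per-channel mean and scale to the inverse per-channel standard deviation of that batch, so the normalized activations come out roughly zero-mean and unit-variance, and the transform stays exactly invertible. A minimal sketch of that behaviour, assuming the taming package is on the import path (the tensor shapes, input statistics, and tolerance below are illustrative, not taken from the repository):

import torch
from taming.modules.util import ActNorm

actnorm = ActNorm(num_features=3, logdet=True)      # fresh layer, training mode by default
x = torch.randn(8, 3, 16, 16) * 2.5 + 1.0           # deliberately shifted and scaled input

h, logdet = actnorm(x)                               # first call triggers initialize()
print(round(h.mean().item(), 3), round(h.std().item(), 3))   # roughly 0.0 and 1.0
print(logdet.shape)                                  # torch.Size([8]), one value per sample

x_rec = actnorm(h, reverse=True)                     # reverse() undoes the affine transform
print(torch.allclose(x, x_rec, atol=1e-4))           # True up to numerical precision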
-------------------------------------------------------------------------------- /src/taming-transformers/taming/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": "vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def get_ckpt_path(name, root, check=False): 37 | assert name in URL_MAP 38 | path = os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class KeyNotFoundError(Exception): 48 | def __init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not found: {}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 | ): 65 | """Given a nested list or dict return the desired value at key expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | 69 | Parameters 70 | ---------- 71 | list_or_dict : list or dict 72 | Possibly nested list or dictionary. 73 | key : str 74 | key/to/value, path like string describing all keys necessary to 75 | consider to get to the desired value. List indices can also be 76 | passed here. 77 | splitval : str 78 | String that defines the delimiter between keys of the 79 | different depth levels in `key`. 80 | default : obj 81 | Value returned if :attr:`key` is not found. 82 | expand : bool 83 | Whether to expand callable nodes on the path or not. 84 | 85 | Returns 86 | ------- 87 | The desired value or if :attr:`default` is not ``None`` and the 88 | :attr:`key` is not found returns ``default``. 89 | 90 | Raises 91 | ------ 92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 93 | ``None``. 94 | """ 95 | 96 | keys = key.split(splitval) 97 | 98 | success = True 99 | try: 100 | visited = [] 101 | parent = None 102 | last_key = None 103 | for key in keys: 104 | if callable(list_or_dict): 105 | if not expand: 106 | raise KeyNotFoundError( 107 | ValueError( 108 | "Trying to get past callable node with expand=False." 
109 | ), 110 | keys=keys, 111 | visited=visited, 112 | ) 113 | list_or_dict = list_or_dict() 114 | parent[last_key] = list_or_dict 115 | 116 | last_key = key 117 | parent = list_or_dict 118 | 119 | try: 120 | if isinstance(list_or_dict, dict): 121 | list_or_dict = list_or_dict[key] 122 | else: 123 | list_or_dict = list_or_dict[int(key)] 124 | except (KeyError, IndexError, ValueError) as e: 125 | raise KeyNotFoundError(e, keys=keys, visited=visited) 126 | 127 | visited += [key] 128 | # final expansion of retrieved value 129 | if expand and callable(list_or_dict): 130 | list_or_dict = list_or_dict() 131 | parent[last_key] = list_or_dict 132 | except KeyNotFoundError as e: 133 | if default is None: 134 | raise e 135 | else: 136 | list_or_dict = default 137 | success = False 138 | 139 | if not pass_success: 140 | return list_or_dict 141 | else: 142 | return list_or_dict, success 143 | 144 | 145 | if __name__ == "__main__": 146 | config = {"keya": "a", 147 | "keyb": "b", 148 | "keyc": 149 | {"cc1": 1, 150 | "cc2": 2, 151 | } 152 | } 153 | from omegaconf import OmegaConf 154 | config = OmegaConf.create(config) 155 | print(config) 156 | retrieve(config, "keya") 157 | 158 | -------------------------------------------------------------------------------- /src/taming-transformers/taming_transformers.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: taming-transformers 3 | Version: 0.0.1 4 | Summary: Taming Transformers for High-Resolution Image Synthesis 5 | -------------------------------------------------------------------------------- /util/__pycache__/fastmri_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/util/__pycache__/fastmri_utils.cpython-38.pyc -------------------------------------------------------------------------------- /util/__pycache__/img_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/util/__pycache__/img_utils.cpython-38.pyc -------------------------------------------------------------------------------- /util/__pycache__/resizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anima-Lab/DiffStateGrad/c2a6a11c0475c21e6dad6a3dfe30d194389ddafb/util/__pycache__/resizer.cpython-38.pyc -------------------------------------------------------------------------------- /util/compute_metric.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from skimage.metrics import peak_signal_noise_ratio 3 | from tqdm import tqdm 4 | 5 | import matplotlib.pyplot as plt 6 | import lpips 7 | import numpy as np 8 | import torch 9 | 10 | 11 | device = 'cuda:0' 12 | loss_fn_vgg = lpips.LPIPS(net='vgg').to(device) 13 | 14 | task = 'SR' 15 | factor = 4 16 | sigma = 0.1 17 | scale = 1.0 18 | 19 | 20 | label_root = Path(f'/media/harry/tomo/FFHQ/256_1000') 21 | 22 | delta_recon_root = Path(f'./results/{task}/ffhq/{factor}/{sigma}/ps/{scale}/recon') 23 | normal_recon_root = Path(f'./results/{task}/ffhq/{factor}/{sigma}/ps+/{scale}/recon') 24 | 25 | psnr_delta_list = [] 26 | psnr_normal_list = [] 27 | 28 | lpips_delta_list = [] 29 | lpips_normal_list = [] 30 | for idx in 
tqdm(range(150)): 31 | fname = str(idx).zfill(5) 32 | 33 | label = plt.imread(label_root / f'{fname}.png')[:, :, :3] 34 | delta_recon = plt.imread(delta_recon_root / f'{fname}.png')[:, :, :3] 35 | normal_recon = plt.imread(normal_recon_root / f'{fname}.png')[:, :, :3] 36 | 37 | psnr_delta = peak_signal_noise_ratio(label, delta_recon) 38 | psnr_normal = peak_signal_noise_ratio(label, normal_recon) 39 | 40 | psnr_delta_list.append(psnr_delta) 41 | psnr_normal_list.append(psnr_normal) 42 | 43 | delta_recon = torch.from_numpy(delta_recon).permute(2, 0, 1).to(device) 44 | normal_recon = torch.from_numpy(normal_recon).permute(2, 0, 1).to(device) 45 | label = torch.from_numpy(label).permute(2, 0, 1).to(device) 46 | 47 | delta_recon = delta_recon.view(1, 3, 256, 256) * 2. - 1. 48 | normal_recon = normal_recon.view(1, 3, 256, 256) * 2. - 1. 49 | label = label.view(1, 3, 256, 256) * 2. - 1. 50 | 51 | delta_d = loss_fn_vgg(delta_recon, label) 52 | normal_d = loss_fn_vgg(normal_recon, label) 53 | 54 | lpips_delta_list.append(delta_d) 55 | lpips_normal_list.append(normal_d) 56 | 57 | psnr_delta_avg = sum(psnr_delta_list) / len(psnr_delta_list) 58 | lpips_delta_avg = sum(lpips_delta_list) / len(lpips_delta_list) 59 | 60 | psnr_normal_avg = sum(psnr_normal_list) / len(psnr_normal_list) 61 | lpips_normal_avg = sum(lpips_normal_list) / len(lpips_normal_list) 62 | 63 | print(f'Delta PSNR: {psnr_delta_avg}') 64 | print(f'Delta LPIPS: {lpips_delta_avg}') 65 | 66 | print(f'Normal PSNR: {psnr_normal_avg}') 67 | print(f'Normal LPIPS: {lpips_normal_avg}') -------------------------------------------------------------------------------- /util/fastmri_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | This source code is licensed under the MIT license found in the 4 | LICENSE file in the root directory of this source tree. 5 | """ 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | from packaging import version 11 | 12 | if version.parse(torch.__version__) >= version.parse("1.7.0"): 13 | import torch.fft # type: ignore 14 | 15 | 16 | def fft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 17 | """ 18 | Apply centered 2 dimensional Fast Fourier Transform. 19 | Args: 20 | data: Complex valued input data containing at least 3 dimensions: 21 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 22 | 2. All other dimensions are assumed to be batch dimensions. 23 | norm: Whether to include normalization. Must be one of ``"backward"`` 24 | or ``"ortho"``. See ``torch.fft.fft`` on PyTorch 1.9.0 for details. 25 | Returns: 26 | The FFT of the input. 27 | """ 28 | if not data.shape[-1] == 2: 29 | raise ValueError("Tensor does not have separate complex dim.") 30 | if norm not in ("ortho", "backward"): 31 | raise ValueError("norm must be 'ortho' or 'backward'.") 32 | normalized = True if norm == "ortho" else False 33 | 34 | data = ifftshift(data, dim=[-3, -2]) 35 | data = torch.fft(data, 2, normalized=normalized) 36 | data = fftshift(data, dim=[-3, -2]) 37 | 38 | return data 39 | 40 | 41 | def ifft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 42 | """ 43 | Apply centered 2-dimensional Inverse Fast Fourier Transform. 44 | Args: 45 | data: Complex valued input data containing at least 3 dimensions: 46 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 47 | 2. All other dimensions are assumed to be batch dimensions. 
48 | norm: Whether to include normalization. Must be one of ``"backward"`` 49 | or ``"ortho"``. See ``torch.fft.ifft`` on PyTorch 1.9.0 for 50 | details. 51 | Returns: 52 | The IFFT of the input. 53 | """ 54 | if not data.shape[-1] == 2: 55 | raise ValueError("Tensor does not have separate complex dim.") 56 | if norm not in ("ortho", "backward"): 57 | raise ValueError("norm must be 'ortho' or 'backward'.") 58 | normalized = True if norm == "ortho" else False 59 | 60 | data = ifftshift(data, dim=[-3, -2]) 61 | data = torch.ifft(data, 2, normalized=normalized) 62 | data = fftshift(data, dim=[-3, -2]) 63 | 64 | return data 65 | 66 | 67 | def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 68 | """ 69 | Apply centered 2 dimensional Fast Fourier Transform. 70 | Args: 71 | data: Complex valued input data containing at least 3 dimensions: 72 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 73 | 2. All other dimensions are assumed to be batch dimensions. 74 | norm: Normalization mode. See ``torch.fft.fft``. 75 | Returns: 76 | The FFT of the input. 77 | """ 78 | if not data.shape[-1] == 2: 79 | raise ValueError("Tensor does not have separate complex dim.") 80 | 81 | data = ifftshift(data, dim=[-3, -2]) 82 | data = torch.view_as_real( 83 | torch.fft.fftn( # type: ignore 84 | torch.view_as_complex(data), dim=(-2, -1), norm=norm 85 | ) 86 | ) 87 | data = fftshift(data, dim=[-3, -2]) 88 | 89 | return data 90 | 91 | 92 | def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 93 | """ 94 | Apply centered 2-dimensional Inverse Fast Fourier Transform. 95 | Args: 96 | data: Complex valued input data containing at least 3 dimensions: 97 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 98 | 2. All other dimensions are assumed to be batch dimensions. 99 | norm: Normalization mode. See ``torch.fft.ifft``. 100 | Returns: 101 | The IFFT of the input. 102 | """ 103 | if not data.shape[-1] == 2: 104 | raise ValueError("Tensor does not have separate complex dim.") 105 | 106 | data = ifftshift(data, dim=[-3, -2]) 107 | data = torch.view_as_real( 108 | torch.fft.ifftn( # type: ignore 109 | torch.view_as_complex(data), dim=(-2, -1), norm=norm 110 | ) 111 | ) 112 | data = fftshift(data, dim=[-3, -2]) 113 | 114 | return data 115 | 116 | 117 | # Helper functions 118 | 119 | 120 | def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor: 121 | """ 122 | Similar to roll but for only one dim. 123 | Args: 124 | x: A PyTorch tensor. 125 | shift: Amount to roll. 126 | dim: Which dimension to roll. 127 | Returns: 128 | Rolled version of x. 129 | """ 130 | shift = shift % x.size(dim) 131 | if shift == 0: 132 | return x 133 | 134 | left = x.narrow(dim, 0, x.size(dim) - shift) 135 | right = x.narrow(dim, x.size(dim) - shift, shift) 136 | 137 | return torch.cat((right, left), dim=dim) 138 | 139 | 140 | def roll( 141 | x: torch.Tensor, 142 | shift: List[int], 143 | dim: List[int], 144 | ) -> torch.Tensor: 145 | """ 146 | Similar to np.roll but applies to PyTorch Tensors. 147 | Args: 148 | x: A PyTorch tensor. 149 | shift: Amount to roll. 150 | dim: Which dimension to roll. 151 | Returns: 152 | Rolled version of x. 
153 | """ 154 | if len(shift) != len(dim): 155 | raise ValueError("len(shift) must match len(dim)") 156 | 157 | for (s, d) in zip(shift, dim): 158 | x = roll_one_dim(x, s, d) 159 | 160 | return x 161 | 162 | 163 | def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor: 164 | """ 165 | Similar to np.fft.fftshift but applies to PyTorch Tensors 166 | Args: 167 | x: A PyTorch tensor. 168 | dim: Which dimension to fftshift. 169 | Returns: 170 | fftshifted version of x. 171 | """ 172 | if dim is None: 173 | # this weird code is necessary for toch.jit.script typing 174 | dim = [0] * (x.dim()) 175 | for i in range(1, x.dim()): 176 | dim[i] = i 177 | 178 | # also necessary for torch.jit.script 179 | shift = [0] * len(dim) 180 | for i, dim_num in enumerate(dim): 181 | shift[i] = x.shape[dim_num] // 2 182 | 183 | return roll(x, shift, dim) 184 | 185 | 186 | def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor: 187 | """ 188 | Similar to np.fft.ifftshift but applies to PyTorch Tensors 189 | Args: 190 | x: A PyTorch tensor. 191 | dim: Which dimension to ifftshift. 192 | Returns: 193 | ifftshifted version of x. 194 | """ 195 | if dim is None: 196 | # this weird code is necessary for toch.jit.script typing 197 | dim = [0] * (x.dim()) 198 | for i in range(1, x.dim()): 199 | dim[i] = i 200 | 201 | # also necessary for torch.jit.script 202 | shift = [0] * len(dim) 203 | for i, dim_num in enumerate(dim): 204 | shift[i] = (x.shape[dim_num] + 1) // 2 205 | 206 | return roll(x, shift, dim) -------------------------------------------------------------------------------- /util/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | def get_logger(): 4 | logger = logging.getLogger(name='DPS') 5 | logger.setLevel(logging.INFO) 6 | 7 | formatter = logging.Formatter("%(asctime)s [%(name)s] >> %(message)s") 8 | stream_handler = logging.StreamHandler() 9 | stream_handler.setFormatter(formatter) 10 | logger.addHandler(stream_handler) 11 | 12 | return logger --------------------------------------------------------------------------------
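get_logger above returns the logger named 'DPS' with a timestamped stream handler attached. A minimal usage sketch, assuming the repository root is on the import path (the message text and timestamp are illustrative):

from util.logger import get_logger

logger = get_logger()
logger.info("running reconstruction")
# emits something like: 2024-01-01 12:00:00,000 [DPS] >> running reconstruction

Note that every call to get_logger() attaches another StreamHandler to the same named logger, so calling it repeatedly duplicates each log line; fetching the logger once and passing it around avoids that.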