├── .gitignore ├── LICENSE ├── README.md ├── _configs ├── algo │ ├── ddim.yaml │ ├── ddrm.yaml │ ├── dps.yaml │ ├── identity.yaml │ ├── mcg.yaml │ ├── pgdm.yaml │ ├── reddiff.yaml │ ├── reddiff_parallel.yaml │ ├── sds.yaml │ └── sds_var.yaml ├── ca.yaml ├── classifier │ ├── imagenet256_cond.yaml │ ├── imagenet512_cond.yaml │ └── none.yaml ├── dataset │ ├── ffhq256_train.yaml │ ├── ffhq256_val.yaml │ ├── imagenet256_train.yaml │ ├── imagenet256_val.yaml │ └── imagenet512_val.yaml ├── ddrmpp.yaml ├── ddrmpp_ffhq.yaml ├── diffusion │ └── linear1000.yaml ├── dist │ └── localhost.yaml ├── exp │ ├── default.yaml │ └── fid_stats.yaml ├── ffhq256_uncond.yaml ├── fid.yaml ├── fid_stats.yaml ├── imagenet256_cond.yaml ├── imagenet256_uncond.yaml ├── imagenet512_cond.yaml ├── inception_score.yaml ├── loader │ ├── imagenet256_ddrm.yaml │ ├── imagenet256_ddrmpp.yaml │ ├── imagenet256_inception.yaml │ └── imagenet512_ddrmpp.yaml ├── model │ ├── ffhq256_uncond.yaml │ ├── imagenet256_cond.yaml │ ├── imagenet256_uncond.yaml │ └── imagenet512_cond.yaml ├── nonlinear_deblur_config.yaml ├── psnr.yaml └── stablediff │ └── params.yaml ├── algos ├── __init__.py ├── ddim.py ├── ddrm.py ├── deis.py ├── dps.py ├── identity.py ├── mcg.py ├── pgdm.py ├── reddiff.py ├── reddiff_parallel.py ├── sds.py └── sds_var.py ├── bkse ├── .gitignore ├── LICENSE ├── README.md ├── data │ ├── GOPRO_dataset.py │ ├── REDS_dataset.py │ ├── __init__.py │ ├── data_sampler.py │ ├── mix_dataset.py │ └── util.py ├── data_augmentation.py ├── domain_specific_deblur.py ├── experiments │ └── pretrained │ │ └── kernel.pth ├── generate_blur.py ├── generic_deblur.py ├── imgs │ ├── blur_faces │ │ └── face01.png │ ├── blur_imgs │ │ ├── blur1.png │ │ └── blur2.png │ ├── results │ │ ├── augmentation.jpg │ │ ├── domain_specific_deblur.jpg │ │ ├── general_deblurring.jpg │ │ ├── generate_blur.jpg │ │ └── kernel_encoding_wGT.png │ ├── sharp_imgs │ │ └── mushishi.png │ └── teaser.jpg ├── models │ ├── __init__.py │ ├── arch_util.py │ ├── backbones │ │ ├── resnet.py │ │ ├── skip │ │ │ ├── concat.py │ │ │ ├── downsampler.py │ │ │ ├── non_local_dot_product.py │ │ │ ├── skip.py │ │ │ └── util.py │ │ └── unet_parts.py │ ├── deblurring │ │ ├── image_deblur.py │ │ └── joint_deblur.py │ ├── dips.py │ ├── dsd │ │ ├── bicubic.py │ │ ├── dsd.py │ │ ├── dsd_stylegan.py │ │ ├── dsd_stylegan2.py │ │ ├── op │ │ │ ├── __init__.py │ │ │ ├── fused_act.py │ │ │ ├── fused_bias_act.cpp │ │ │ ├── fused_bias_act_kernel.cu │ │ │ ├── upfirdn2d.cpp │ │ │ ├── upfirdn2d.py │ │ │ └── upfirdn2d_kernel.cu │ │ ├── spherical_optimizer.py │ │ ├── stylegan.py │ │ └── stylegan2.py │ ├── kernel_encoding │ │ ├── base_model.py │ │ ├── image_base_model.py │ │ └── kernel_wizard.py │ ├── losses │ │ ├── charbonnier_loss.py │ │ ├── dsd_loss.py │ │ ├── gan_loss.py │ │ ├── hyper_laplacian_penalty.py │ │ ├── perceptual_loss.py │ │ └── ssim_loss.py │ └── lr_scheduler.py ├── options │ ├── __init__.py │ ├── data_augmentation │ │ └── default.yml │ ├── domain_specific_deblur │ │ ├── stylegan.yml │ │ └── stylegan2.yml │ ├── generate_blur │ │ └── default.yml │ ├── generic_deblur │ │ └── default.yml │ ├── kernel_encoding │ │ ├── GOPRO │ │ │ ├── wVAE.yml │ │ │ └── woVAE.yml │ │ ├── REDS │ │ │ └── woVAE.yml │ │ └── mix │ │ │ └── woVAE.yml │ └── options.py ├── requirements.txt ├── scripts │ ├── create_lmdb.py │ └── download_dataset.py ├── train.py ├── train_script.sh └── utils │ ├── __init__.py │ └── util.py ├── datasets ├── __init__.py ├── ffhq.py ├── imagenet.py ├── lmdb_dataset.py └── lsun.py ├── demo └── 
output.gif ├── eval ├── ca.py ├── fid.py ├── fid_stats.py ├── inception_score.py └── psnr.py ├── main.py ├── misc ├── dgp_top10.txt ├── dgp_top100.txt ├── dgp_top1k.txt ├── mcg_top1k.txt ├── palette_10k.txt ├── palette_5k.txt ├── palette_jpeg_demo.txt └── sr3_top1k.txt ├── models ├── __init__.py ├── classifier_guidance_model.py ├── diffusion.py └── guided_diffusion │ ├── __init__.py │ ├── fp16_util.py │ ├── logger.py │ ├── nn.py │ ├── script_util.py │ └── unet.py ├── motionblur ├── .gitignore ├── README.md ├── __init__.py ├── environment.yaml ├── example_kernel │ ├── kernel0.png │ ├── kernel100.png │ ├── kernel25.png │ ├── kernel50.png │ └── kernel75.png ├── images │ ├── flag.png │ ├── flagBLURRED.png │ └── moon.png ├── intensity.png └── motionblur.py ├── output.gif ├── playground ├── 100.png ├── 1000.png ├── 20.png ├── 50.png ├── 500.png ├── Untitled.ipynb ├── Untitled1.ipynb ├── adm.png ├── adm0.png ├── adm1.png ├── adm2.png ├── adm3.png ├── adm4.png ├── adm5.png ├── awd.png ├── awd_fwd.png ├── coltran_bot5k.txt ├── coltran_top5k.txt ├── compare_openai_sr_model.ipynb ├── ctest10k.txt ├── ddrmpp_res │ ├── 0.png │ ├── 10.png │ ├── 2.png │ ├── 20.png │ ├── 3.png │ └── 7.png ├── dgp_top1k.txt ├── figures.ipynb ├── jpeg5_deg │ ├── 0.png │ ├── 10.png │ ├── 2.png │ ├── 20.png │ ├── 3.png │ └── 7.png ├── jpeg5_ori │ ├── 0.png │ ├── 10.png │ ├── 2.png │ ├── 20.png │ ├── 3.png │ └── 7.png ├── lip.ipynb ├── palette_10k.txt ├── palette_img │ ├── 0.jpg │ ├── 10.jpg │ ├── 11.jpg │ ├── 2.jpg │ ├── 3.jpg │ └── 7.jpg ├── palette_result │ ├── 0.png │ ├── 10.png │ ├── 11.png │ ├── 2.png │ ├── 20.png │ ├── 3.png │ └── 7.png ├── plot_awd_steps.ipynb ├── process_imagenet_txt.ipynb └── svd.ipynb ├── requirements.txt ├── run_eval.sh ├── sample_batch.sh ├── sample_test.sh └── utils ├── __init__.py ├── checkpoints.py ├── dct.py ├── degredations.py ├── distributed.py ├── fft_utils.py ├── functions.py ├── jpeg_quantization.py ├── jpeg_torch.py └── save.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | outputs/ 3 | _exp 4 | _data 5 | dist/ 6 | !_configs/exp/ 7 | !_configs/dist/ 8 | 9 | src/ 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | *.npy 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023-2024, NVIDIA Corporation & affiliates. All rights reserved. 2 | 3 | 4 | NVIDIA Source Code License 5 | 6 | 7 | ======================================================================= 8 | 9 | 1. Definitions 10 | 11 | "Licensor" means any person or entity that distributes its Work. 12 | 13 | "Software" means the original work of authorship made available under 14 | this License. 15 | 16 | "Work" means the Software and any additions to or derivative works of 17 | the Software that are made available under this License. 18 | 19 | The terms "reproduce," "reproduction," "derivative works," and 20 | "distribution" have the meaning as provided under U.S. copyright law; 21 | provided, however, that for the purposes of this License, derivative 22 | works shall not include works that remain separable from, or merely 23 | link (or bind by name) to the interfaces of, the Work. 24 | 25 | Works, including the Software, are "made available" under this License 26 | by including in or with the Work either (a) a copyright notice 27 | referencing the applicability of this License to the Work, or (b) a 28 | copy of this License. 29 | 30 | 2. License Grants 31 | 32 | 2.1 Copyright Grant. Subject to the terms and conditions of this 33 | License, each Licensor grants to you a perpetual, worldwide, 34 | non-exclusive, royalty-free, copyright license to reproduce, 35 | prepare derivative works of, publicly display, publicly perform, 36 | sublicense and distribute its Work and any resulting derivative 37 | works in any form. 
38 | 39 | 3. Limitations 40 | 41 | 3.1 Redistribution. You may reproduce or distribute the Work only 42 | if (a) you do so under this License, (b) you include a complete 43 | copy of this License with your distribution, and (c) you retain 44 | without modification any copyright, patent, trademark, or 45 | attribution notices that are present in the Work. 46 | 47 | 3.2 Derivative Works. You may specify that additional or different 48 | terms apply to the use, reproduction, and distribution of your 49 | derivative works of the Work ("Your Terms") only if (a) Your Terms 50 | provide that the use limitation in Section 3.3 applies to your 51 | derivative works, and (b) you identify the specific derivative 52 | works that are subject to Your Terms. Notwithstanding Your Terms, 53 | this License (including the redistribution requirements in Section 54 | 3.1) will continue to apply to the Work itself. 55 | 56 | 3.3 Use Limitation. The Work and any derivative works thereof only 57 | may be used or intended for use non-commercially. Notwithstanding 58 | the foregoing, NVIDIA and its affiliates may use the Work and any 59 | derivative works commercially. As used herein, "non-commercially" 60 | means for research or evaluation purposes only. 61 | 62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 63 | against any Licensor (including any claim, cross-claim or 64 | counterclaim in a lawsuit) to enforce any patents that you allege 65 | are infringed by any Work, then your rights under this License from 66 | such Licensor (including the grant in Section 2.1) will terminate 67 | immediately. 68 | 69 | 3.5 Trademarks. This License does not grant any rights to use any 70 | Licensor’s or its affiliates’ names, logos, or trademarks, except 71 | as necessary to reproduce the notices described in this License. 72 | 73 | 3.6 Termination. If you violate any term of this License, then your 74 | rights under this License (including the grant in Section 2.1) will 75 | terminate immediately. 76 | 77 | 4. Disclaimer of Warranty. 78 | 79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 83 | THIS LICENSE. 84 | 85 | 5. Limitation of Liability. 86 | 87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 95 | THE POSSIBILITY OF SUCH DAMAGES. -------------------------------------------------------------------------------- /_configs/algo/ddim.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. 
All rights reserved 2 | 3 | name: 'ddim' 4 | eta: 0.5 5 | sdedit: False 6 | cond_awd: False -------------------------------------------------------------------------------- /_configs/algo/ddrm.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | name: "ddrm" 4 | sigma_y: 0.0 5 | eta: 0.85 6 | eta_b: 1 7 | deg: "deno" 8 | lr: 0.25 -------------------------------------------------------------------------------- /_configs/algo/dps.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | name: "dps" 4 | deg: "deno" 5 | awd: True 6 | cond_awd: False 7 | grad_term_weight: 0.1 #0.1 for in2_20ff, and 1.0 for sr4 8 | sigma_y: 0.0 9 | eta: 0.0 10 | mcg: False 11 | original: True 12 | lr: 0.25 13 | -------------------------------------------------------------------------------- /_configs/algo/identity.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | name: 'identity' -------------------------------------------------------------------------------- /_configs/algo/mcg.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | name: "mcg" 4 | deg: "in2_box" 5 | grad_term_weight: 1 6 | sigma_y: 0.0 7 | eta: 0.0 8 | -------------------------------------------------------------------------------- /_configs/algo/pgdm.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | name: "pgdm" 4 | deg: "deno" 5 | awd: True 6 | cond_awd: False 7 | grad_term_weight: 1 8 | sigma_y: 0.0 9 | eta: 0.0 10 | mcg: False 11 | lr: 0.25 -------------------------------------------------------------------------------- /_configs/algo/reddiff.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | name: "reddiff" 4 | deg: "deno" 5 | awd: True 6 | cond_awd: False 7 | obs_weight: 1.0 8 | grad_term_weight: 0.25 9 | denoise_term_weight: "linear" #"linear", "sqrt", "log", "square", "trunc_linear", "const", "power2over3" 10 | sigma_y: 0.0 11 | eta: 0.0 12 | lr: 0.1 13 | sigma_x0: 0.0 14 | -------------------------------------------------------------------------------- /_configs/algo/reddiff_parallel.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | name: "reddiff_parallel" 4 | deg: "deno" 5 | awd: True 6 | cond_awd: False 7 | grad_term_weight: 0.5 8 | denoise_term_weight: "linear" #"linear", "sqrt", "log", "square", "trunc_linear", "const", "power2over3" 9 | sigma_y: 0.0 10 | eta: 0.0 11 | lr: 0.1 -------------------------------------------------------------------------------- /_configs/algo/sds.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. 
All rights reserved 2 | 3 | name: "sds" 4 | deg: "deno" 5 | awd: True 6 | cond_awd: False 7 | obs_weight: 0.0 8 | grad_term_weight: 1.0 9 | denoise_term_weight: "linear" #"linear", "sqrt", "log", "square", "trunc_linear", "const", "power2over3" 10 | sigma_y: 0.0 11 | eta: 0.0 12 | lr: 0.1 13 | sigma_x0: 0.0 14 | -------------------------------------------------------------------------------- /_configs/algo/sds_var.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | name: "sds_var" 4 | deg: "deno" 5 | awd: True 6 | cond_awd: False 7 | obs_weight: 0.0 8 | grad_term_weight: 1.0 9 | denoise_term_weight: "linear" #"linear", "sqrt", "log", "square", "trunc_linear", "const", "power2over3" 10 | sigma_y: 0.0 11 | eta: 0.0 12 | lr: 0.1 13 | sigma_x0: 0.0 14 | -------------------------------------------------------------------------------- /_configs/ca.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | defaults: 4 | - dist: localhost 5 | - dataset: imagenet256_val 6 | - loader: imagenet256_inception 7 | - exp: fid_stats 8 | - _self_ 9 | 10 | 11 | results: ??? -------------------------------------------------------------------------------- /_configs/classifier/imagenet256_cond.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | _target_: models.guided_diffusion.script_util.create_classifier 4 | image_size: 256 5 | classifier_attention_resolutions: "32,16,8" 6 | classifier_depth: 2 7 | classifier_pool: "attention" 8 | classifier_resblock_updown: True 9 | classifier_width: 128 10 | classifier_use_scale_shift_norm: True 11 | classifier_scale: 1.0 12 | classifier_use_fp16: True 13 | ckpt: "imagenet_256_classifier" -------------------------------------------------------------------------------- /_configs/classifier/imagenet512_cond.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | _target_: models.guided_diffusion.script_util.create_classifier 4 | image_size: 512 5 | classifier_attention_resolutions: "32,16,8" 6 | classifier_depth: 2 7 | classifier_pool: "attention" 8 | classifier_resblock_updown: True 9 | classifier_width: 128 10 | classifier_use_scale_shift_norm: True 11 | classifier_scale: 1.0 12 | classifier_use_fp16: True 13 | ckpt: "imagenet_512_classifier" -------------------------------------------------------------------------------- /_configs/classifier/none.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | _target_: utils.return_none -------------------------------------------------------------------------------- /_configs/dataset/ffhq256_train.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. 
All rights reserved 2 | 3 | name: "FFHQ_256x256" 4 | root: "_data" 5 | split: "train" 6 | image_size: 256 7 | channels: 3 8 | transform: "default" 9 | subset: -1 -------------------------------------------------------------------------------- /_configs/dataset/ffhq256_val.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | name: "FFHQ_256x256" 4 | root: "_data" 5 | split: "val" 6 | image_size: 256 7 | channels: 3 8 | transform: "default" 9 | subset: 1000 -------------------------------------------------------------------------------- /_configs/dataset/imagenet256_train.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | name: "ImageNet_256x256" 4 | root: "_data" 5 | split: "train" 6 | image_size: 256 7 | channels: 3 8 | meta_root: "data/imagenet" 9 | transform: "diffusion" 10 | subset_txt: "" -------------------------------------------------------------------------------- /_configs/dataset/imagenet256_val.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | name: "ImageNet_256x256" 4 | root: "/home/mmardani/research/datasets/imagenet-root" 5 | split: "val" 6 | image_size: 256 7 | channels: 3 8 | meta_root: "/home/mmardani/research/datasets/imagenet-root" 9 | transform: "diffusion" 10 | #subset_txt: "/home/mmardani/research/stable-diffusion-sampling-gitlab/pgdm/misc/sr3_top1k.txt" 11 | subset_txt: "/home/mmardani/research/stable-diffusion-sampling-gitlab/pgdm/misc/dgp_top1k.txt" 12 | -------------------------------------------------------------------------------- /_configs/dataset/imagenet512_val.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | name: "ImageNet_512x512" 4 | root: "/home/mmardani/research/datasets/imagenet-root" 5 | split: "val" 6 | image_size: 512 7 | channels: 3 8 | meta_root: "/home/mmardani/research/datasets/imagenet-root" 9 | transform: "diffusion" 10 | subset_txt: "/home/mmardani/research/stable-diffusion-sampling-gitlab/pgdm/misc/sr3_top1k.txt" 11 | -------------------------------------------------------------------------------- /_configs/ddrmpp.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | defaults: 4 | - imagenet256_uncond 5 | - _self_ 6 | -------------------------------------------------------------------------------- /_configs/ddrmpp_ffhq.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | defaults: 4 | - ffhq256_uncond 5 | - _self_ 6 | -------------------------------------------------------------------------------- /_configs/diffusion/linear1000.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. 
All rights reserved 2 | 3 | beta_schedule: linear 4 | beta_start: 0.0001 5 | beta_end: 0.02 6 | num_diffusion_timesteps: 1000 -------------------------------------------------------------------------------- /_configs/dist/localhost.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | port: "12345" 4 | master_address: "localhost" 5 | node_rank: 0 6 | num_proc_node: 1 7 | num_processes_per_node: 1 8 | backend: "gloo" #"nccl" -------------------------------------------------------------------------------- /_configs/exp/default.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | #local machine 4 | root: "/_exp" 5 | save_evolution: False 6 | ckpt_root: "ckpts" 7 | samples_root: "samples" 8 | overwrite: True 9 | num_steps: 50 10 | start_step: 1000 11 | end_step: 0 12 | smoke_test: 4 13 | logfreq: 200 14 | save_ori: False 15 | save_deg: False 16 | seed: 1 17 | name: ??? 18 | -------------------------------------------------------------------------------- /_configs/exp/fid_stats.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | root: "_exp" -------------------------------------------------------------------------------- /_configs/ffhq256_uncond.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | defaults: 4 | - dataset: ffhq256_val 5 | - loader: imagenet256_ddrm 6 | - model: ffhq256_uncond 7 | - classifier: none 8 | - dist: localhost 9 | - exp: default 10 | - diffusion: linear1000 11 | - algo: ddim 12 | - _self_ -------------------------------------------------------------------------------- /_configs/fid.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | defaults: 4 | - dist: localhost 5 | - exp: fid_stats 6 | - _self_ 7 | 8 | path1: 'none' #??? 9 | path2: 'none' 10 | results: 'none' -------------------------------------------------------------------------------- /_configs/fid_stats.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | defaults: 4 | - dist: localhost 5 | - dataset: imagenet256_val 6 | - loader: imagenet256_ddrmpp 7 | - exp: fid_stats 8 | - _self_ 9 | 10 | fid: 11 | mode: "legacy_pytorch" 12 | 13 | dataset: 14 | transform: "identity" 15 | 16 | save_path: 'none' 17 | mean_std_stats: False -------------------------------------------------------------------------------- /_configs/imagenet256_cond.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. 
All rights reserved 2 | 3 | defaults: 4 | - stablediff: params 5 | - dataset: imagenet256_val 6 | - loader: imagenet256_ddrm 7 | - model: imagenet256_cond 8 | - classifier: imagenet256_cond 9 | - dist: localhost 10 | - exp: default 11 | - diffusion: linear1000 12 | - algo: ddim 13 | - _self_ -------------------------------------------------------------------------------- /_configs/imagenet256_uncond.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | defaults: 4 | - dataset: imagenet256_val 5 | - loader: imagenet256_ddrm 6 | - model: imagenet256_uncond 7 | - classifier: none 8 | - dist: localhost 9 | - exp: default 10 | - diffusion: linear1000 11 | - algo: ddim 12 | - _self_ -------------------------------------------------------------------------------- /_configs/imagenet512_cond.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | defaults: 4 | - stablediff: params 5 | - dataset: imagenet512_val 6 | - loader: imagenet512_ddrmpp 7 | - model: imagenet512_cond 8 | - classifier: imagenet512_cond 9 | - dist: localhost 10 | - exp: default 11 | - diffusion: linear1000 12 | - algo: ddim 13 | - _self_ -------------------------------------------------------------------------------- /_configs/inception_score.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | defaults: 4 | - dist: localhost 5 | - dataset: imagenet256_val 6 | - loader: imagenet256_inception 7 | - exp: fid_stats 8 | - _self_ 9 | 10 | dataset: 11 | transform: "isc_cropped" 12 | 13 | results: ??? -------------------------------------------------------------------------------- /_configs/loader/imagenet256_ddrm.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | batch_size: 1 4 | num_workers: 12 5 | shuffle: False 6 | drop_last: False 7 | pin_memory: True -------------------------------------------------------------------------------- /_configs/loader/imagenet256_ddrmpp.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | batch_size: 1 4 | num_workers: 12 5 | shuffle: False 6 | drop_last: False 7 | pin_memory: True -------------------------------------------------------------------------------- /_configs/loader/imagenet256_inception.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | batch_size: 256 4 | num_workers: 8 5 | shuffle: False 6 | drop_last: False 7 | pin_memory: True -------------------------------------------------------------------------------- /_configs/loader/imagenet512_ddrmpp.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | batch_size: 2 4 | num_workers: 12 5 | shuffle: False 6 | drop_last: False 7 | pin_memory: True -------------------------------------------------------------------------------- /_configs/model/ffhq256_uncond.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. 
All rights reserved 2 | 3 | _target_: models.guided_diffusion.script_util.create_model 4 | in_channels: 3 5 | out_channels: 3 6 | num_channels: 128 7 | num_heads: 4 8 | num_res_blocks: 1 9 | attention_resolutions: "16" 10 | dropout: 0.0 11 | resamp_with_conv: True 12 | learn_sigma: True 13 | use_scale_shift_norm: true 14 | use_fp16: false 15 | resblock_updown: true 16 | num_heads_upsample: -1 17 | var_type: 'fixedsmall' 18 | num_head_channels: 64 19 | image_size: 256 20 | class_cond: false 21 | use_new_attention_order: false 22 | ckpt: "ffhq_256" 23 | -------------------------------------------------------------------------------- /_configs/model/imagenet256_cond.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | _target_: models.guided_diffusion.script_util.create_model 4 | in_channels: 3 5 | out_channels: 3 6 | num_channels: 256 7 | num_heads: 4 8 | num_res_blocks: 2 9 | attention_resolutions: "32,16,8" 10 | dropout: 0.0 11 | resamp_with_conv: True 12 | learn_sigma: True 13 | use_scale_shift_norm: true 14 | use_fp16: true 15 | resblock_updown: true 16 | num_heads_upsample: -1 17 | var_type: 'fixedsmall' 18 | num_head_channels: 64 19 | image_size: 256 20 | class_cond: True 21 | use_checkpoint: true 22 | use_new_attention_order: false 23 | ckpt: "imagenet_256_cond" -------------------------------------------------------------------------------- /_configs/model/imagenet256_uncond.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | _target_: models.guided_diffusion.script_util.create_model 4 | in_channels: 3 5 | out_channels: 3 6 | num_channels: 256 7 | num_heads: 4 8 | num_res_blocks: 2 9 | attention_resolutions: "32,16,8" 10 | dropout: 0.0 11 | resamp_with_conv: True 12 | learn_sigma: True 13 | use_scale_shift_norm: true 14 | use_fp16: true 15 | resblock_updown: true 16 | num_heads_upsample: -1 17 | var_type: 'fixedsmall' 18 | num_head_channels: 64 19 | image_size: 256 20 | class_cond: false 21 | use_checkpoint: true 22 | use_new_attention_order: false 23 | ckpt: "imagenet_256_uncond" 24 | 25 | -------------------------------------------------------------------------------- /_configs/model/imagenet512_cond.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | _target_: models.guided_diffusion.script_util.create_model 4 | in_channels: 3 5 | out_channels: 3 6 | num_channels: 256 7 | num_heads: 4 8 | num_res_blocks: 2 9 | attention_resolutions: "32,16,8" 10 | dropout: 0.0 11 | resamp_with_conv: True 12 | learn_sigma: True 13 | use_scale_shift_norm: true 14 | use_fp16: false 15 | resblock_updown: true 16 | num_heads_upsample: -1 17 | var_type: 'fixedsmall' 18 | num_head_channels: 64 19 | image_size: 512 20 | class_cond: True 21 | use_checkpoint: true 22 | use_new_attention_order: false 23 | ckpt: "imagenet_512_cond" -------------------------------------------------------------------------------- /_configs/nonlinear_deblur_config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. 
All rights reserved 2 | 3 | conditioning: 4 | method: ps 5 | params: 6 | scale: 0.3 7 | 8 | data: 9 | name: ffhq 10 | root: ./data/samples/ 11 | 12 | measurement: 13 | operator: 14 | name: nonlinear_blur 15 | opt_yml_path: ./bkse/options/generate_blur/default.yml 16 | 17 | noise: 18 | name: gaussian 19 | sigma: 0.05 -------------------------------------------------------------------------------- /_configs/psnr.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | defaults: 4 | - dist: localhost 5 | - dataset@dataset1: imagenet256_val 6 | - dataset@dataset2: imagenet256_val 7 | - loader: imagenet256_inception 8 | - exp: fid_stats 9 | - _self_ 10 | 11 | fid: 12 | mode: "legacy_pytorch" 13 | 14 | results: ??? 15 | save_path: ??? -------------------------------------------------------------------------------- /algos/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | from .ddim import DDIM 4 | from .ddrm import DDRM 5 | from .pgdm import PGDM 6 | from .identity import Identity 7 | from .reddiff import REDDIFF 8 | from .reddiff_parallel import REDDIFF_PARALLEL 9 | from .mcg import MCG 10 | from .dps import DPS 11 | from .sds import SDS 12 | from .sds_var import SDS_VAR 13 | 14 | 15 | def build_algo(cg_model, cfg): 16 | if cfg.algo.name == 'identity': 17 | return Identity(cg_model, cfg) 18 | elif cfg.algo.name == 'ddim': 19 | return DDIM(cg_model, cfg) 20 | elif cfg.algo.name == 'ddrm': 21 | return DDRM(cg_model, cfg) 22 | elif cfg.algo.name == 'pgdm': 23 | return PGDM(cg_model, cfg) 24 | elif cfg.algo.name == 'reddiff': 25 | return REDDIFF(cg_model, cfg) 26 | elif cfg.algo.name == 'reddiff_parallel': 27 | return REDDIFF_PARALLEL(cg_model, cfg) 28 | elif cfg.algo.name == 'mcg': 29 | return MCG(cg_model, cfg) 30 | elif cfg.algo.name == 'dps': 31 | return DPS(cg_model, cfg) 32 | elif cfg.algo.name == 'sds': 33 | return SDS(cg_model, cfg) 34 | elif cfg.algo.name == 'sds_var': 35 | return SDS_VAR(cg_model, cfg) 36 | else: 37 | raise ValueError(f'No algorithm named {cfg.algo.name}') 38 | -------------------------------------------------------------------------------- /algos/ddim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. 
All rights reserved 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | 6 | from models.classifier_guidance_model import ClassifierGuidanceModel 7 | 8 | 9 | class DDIM: 10 | def __init__(self, model: ClassifierGuidanceModel, cfg: DictConfig): 11 | self.model = model 12 | self.diffusion = model.diffusion 13 | self.eta = cfg.algo.eta 14 | self.sdedit = cfg.algo.sdedit 15 | self.cond_awd = cfg.algo.cond_awd 16 | 17 | @torch.no_grad() 18 | def sample(self, x, y, ts, **kwargs): 19 | x = self.initialize(x, y, ts, **kwargs) 20 | n = x.size(0) 21 | ss = [-1] + list(ts[:-1]) 22 | xt_s = [x.cpu()] 23 | x0_s = [] 24 | 25 | xt = x 26 | 27 | for ti, si in zip(reversed(ts), reversed(ss)): 28 | t = torch.ones(n).to(x.device).long() * ti 29 | s = torch.ones(n).to(x.device).long() * si 30 | alpha_t = self.diffusion.alpha(t).view(-1, 1, 1, 1) 31 | alpha_s = self.diffusion.alpha(s).view(-1, 1, 1, 1) 32 | c1 = ((1 - alpha_t / alpha_s) * (1 - alpha_s) / (1 - alpha_t)).sqrt() * self.eta 33 | c2 = ((1 - alpha_s) - c1 ** 2).sqrt() 34 | if self.cond_awd: 35 | scale = alpha_s.sqrt() / (alpha_s.sqrt() - c2 * alpha_t.sqrt() / (1 - alpha_t).sqrt()) 36 | scale = scale.view(-1)[0].item() 37 | else: 38 | scale = 1.0 39 | et, x0_pred = self.model(xt, y, t, scale=scale) 40 | xs = alpha_s.sqrt() * x0_pred + c1 * torch.randn_like(xt) + c2 * et 41 | xt_s.append(xs.cpu()) 42 | x0_s.append(x0_pred.cpu()) 43 | xt = xs 44 | 45 | return list(reversed(xt_s)), list(reversed(x0_s)) 46 | 47 | def initialize(self, x, y, ts, **kwargs): 48 | if self.sdedit: 49 | n = x.size(0) 50 | ti = ts[-1] 51 | t = torch.ones(n).to(x.device).long() * ti 52 | alpha_t = self.diffusion.alpha(t).view(-1, 1, 1, 1) 53 | return x * alpha_t.sqrt() + torch.randn_like(x) * (1 - alpha_t).sqrt() 54 | else: 55 | return torch.randn_like(x) 56 | -------------------------------------------------------------------------------- /algos/deis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. 
All rights reserved 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | 6 | from models.classifier_guidance_model import ClassifierGuidanceModel 7 | 8 | 9 | class DEIS: 10 | def __init__(self, model: ClassifierGuidanceModel, cfg: DictConfig): 11 | self.model = model 12 | self.diffusion = model.diffusion 13 | self.eta = cfg.algo.eta 14 | self.sdedit = cfg.algo.sdedit 15 | self.cond_awd = cfg.algo.cond_awd 16 | 17 | @torch.no_grad() 18 | def sample(self, x, y, ts, **kwargs): 19 | x = self.initialize(x, y, ts, **kwargs) 20 | n = x.size(0) 21 | ss = [-1] + list(ts[:-1]) 22 | xt_s = [x.cpu()] 23 | x0_s = [] 24 | 25 | xt = x 26 | 27 | for ti, si in zip(reversed(ts), reversed(ss)): 28 | t = torch.ones(n).to(x.device).long() * ti 29 | s = torch.ones(n).to(x.device).long() * si 30 | alpha_t = self.diffusion.alpha(t).view(-1, 1, 1, 1) 31 | alpha_s = self.diffusion.alpha(s).view(-1, 1, 1, 1) 32 | c1 = ((1 - alpha_t / alpha_s) * (1 - alpha_s) / (1 - alpha_t)).sqrt() * self.eta 33 | c2 = ((1 - alpha_s) - c1 ** 2).sqrt() 34 | if self.cond_awd: 35 | scale = alpha_s.sqrt() / (alpha_s.sqrt() - c2 * alpha_t.sqrt() / (1 - alpha_t).sqrt()) 36 | scale = scale.view(-1)[0].item() 37 | else: 38 | scale = 1.0 39 | et, x0_pred = self.model(xt, y, t, scale=scale) 40 | xs = alpha_s.sqrt() * x0_pred + c1 * torch.randn_like(xt) + c2 * et 41 | xt_s.append(xs.cpu()) 42 | x0_s.append(x0_pred.cpu()) 43 | xt = xs 44 | 45 | return list(reversed(xt_s)), list(reversed(x0_s)) 46 | 47 | def get_coef(self, ts, device): 48 | ss = [-1] + list(ts[:-1]) 49 | rev_ts = list(reversed(ts)) 50 | rev_ss = list(reversed(ss)) 51 | 52 | rev_ts_th = torch.tensor(rev_ts).float().to(device) 53 | rev_ss_th = torch.tensor(rev_ss).float().to(device) 54 | 55 | alpha_rev_ts = self.diffusion.alpha(rev_ts_th) 56 | alpha_rev_ss = self.diffusion.alpha(rev_ss_th) 57 | psi = alpha_rev_ss / alpha_rev_ts 58 | 59 | 60 | 61 | 62 | 63 | def initialize(self, x, y, ts, **kwargs): 64 | if self.sdedit: 65 | n = x.size(0) 66 | ti = ts[-1] 67 | t = torch.ones(n).to(x.device).long() * ti 68 | alpha_t = self.diffusion.alpha(t).view(-1, 1, 1, 1) 69 | return x * alpha_t.sqrt() + torch.randn_like(x) * (1 - alpha_t).sqrt() 70 | else: 71 | return torch.randn_like(x) 72 | -------------------------------------------------------------------------------- /algos/dps.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. 
All rights reserved 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from omegaconf import DictConfig 6 | 7 | from models.classifier_guidance_model import ClassifierGuidanceModel 8 | from utils.degredations import build_degredation_model 9 | from .ddim import DDIM 10 | 11 | 12 | class DPS(DDIM): 13 | def __init__(self, model: ClassifierGuidanceModel, cfg: DictConfig): 14 | self.model = model 15 | self.diffusion = model.diffusion 16 | self.H = build_degredation_model(cfg) 17 | self.cfg = cfg 18 | self.awd = cfg.algo.awd 19 | self.cond_awd = cfg.algo.cond_awd 20 | self.mcg = cfg.algo.mcg 21 | self.grad_term_weight = cfg.algo.grad_term_weight 22 | self.eta = cfg.algo.eta 23 | self.original = cfg.algo.original 24 | 25 | def sample(self, x, y, ts, **kwargs): 26 | y_0 = kwargs["y_0"] 27 | n = x.size(0) 28 | H = self.H 29 | 30 | x = self.initialize(x, y, ts, y_0=y_0) 31 | ss = [-1] + list(ts[:-1]) 32 | xt_s = [x.cpu()] 33 | x0_s = [] 34 | 35 | xt = x 36 | for ti, si in zip(reversed(ts), reversed(ss)): 37 | t = torch.ones(n).to(x.device).long() * ti 38 | s = torch.ones(n).to(x.device).long() * si 39 | alpha_t = self.diffusion.alpha(t).view(-1, 1, 1, 1) 40 | alpha_s = self.diffusion.alpha(s).view(-1, 1, 1, 1) 41 | c1 = ((1 - alpha_t / alpha_s) * (1 - alpha_s) / (1 - alpha_t)).sqrt() * self.eta 42 | c2 = ((1 - alpha_s) - c1 ** 2).sqrt() 43 | xt = xt.clone().to('cuda').requires_grad_(True) 44 | 45 | if self.cond_awd: 46 | scale = alpha_s.sqrt() / (alpha_s.sqrt() - c2 * alpha_t.sqrt() / (1 - alpha_t).sqrt()) 47 | scale = scale.view(-1)[0].item() 48 | else: 49 | scale = 1.0 50 | 51 | et, x0_pred = self.model(xt, y, t, scale=scale) 52 | mat_norm = ((y_0 - H.H(x0_pred)).reshape(n, -1) ** 2).sum(dim=1).sqrt().detach() 53 | mat = ((y_0 - H.H(x0_pred)).reshape(n, -1) ** 2).sum() 54 | 55 | grad_term = torch.autograd.grad(mat, xt, retain_graph=True)[0] 56 | 57 | if self.original: 58 | coeff = self.grad_term_weight / mat_norm.reshape(-1, 1, 1, 1) 59 | else: 60 | coeff = alpha_s.sqrt() * alpha_t.sqrt() # - c2 * alpha_t.sqrt() / (1 - alpha_t).sqrt() 61 | 62 | grad_term = grad_term.detach() 63 | xs = alpha_s.sqrt() * x0_pred.detach() + c1 * torch.randn_like(xt) + c2 * et.detach() - grad_term * coeff 64 | xt_s.append(xs.detach().cpu()) 65 | x0_s.append(x0_pred.detach().cpu()) 66 | xt = xs 67 | 68 | return list(reversed(xt_s)), list(reversed(x0_s)) 69 | 70 | def initialize(self, x, y, ts, **kwargs): 71 | y_0 = kwargs['y_0'] 72 | H = self.H 73 | deg = self.cfg.algo.deg 74 | n = x.size(0) 75 | x_0 = H.H_pinv(y_0).view(*x.size()).detach() 76 | ti = ts[-1] 77 | t = torch.ones(n).to(x.device).long() * ti 78 | alpha_t = self.diffusion.alpha(t).view(-1, 1, 1, 1) 79 | return alpha_t.sqrt() * x_0 + (1 - alpha_t).sqrt() * torch.randn_like(x_0) 80 | -------------------------------------------------------------------------------- /algos/identity.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. 
All rights reserved 2 | 3 | from omegaconf import DictConfig 4 | 5 | from models.classifier_guidance_model import ClassifierGuidanceModel 6 | 7 | 8 | class Identity: 9 | def __init__(self, model: ClassifierGuidanceModel, cfg: DictConfig): 10 | self.model = model 11 | self.diffusion = model.diffusion 12 | 13 | def sample(self, x, y, ts, **kwargs): 14 | return [x], [] 15 | -------------------------------------------------------------------------------- /algos/mcg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from omegaconf import DictConfig 6 | 7 | from models.classifier_guidance_model import ClassifierGuidanceModel 8 | from utils.degredations import build_degredation_model 9 | from .ddim import DDIM 10 | 11 | 12 | class MCG(DDIM): 13 | def __init__(self, model: ClassifierGuidanceModel, cfg: DictConfig): 14 | self.model = model 15 | self.diffusion = model.diffusion 16 | self.H = build_degredation_model(cfg) 17 | self.cfg = cfg 18 | self.grad_term_weight = cfg.algo.grad_term_weight 19 | self.eta = cfg.algo.eta 20 | 21 | def sample(self, x, y, ts, **kwargs): 22 | y_0 = kwargs["y_0"] 23 | n = x.size(0) 24 | H = self.H 25 | 26 | x = self.initialize(x, y, ts, y_0=y_0) 27 | ss = [-1] + list(ts[:-1]) 28 | xt_s = [x.cpu()] 29 | x0_s = [] 30 | 31 | xt = x 32 | 33 | for ti, si in zip(reversed(ts), reversed(ss)): 34 | t = torch.ones(n).to(x.device).long() * ti 35 | s = torch.ones(n).to(x.device).long() * si 36 | alpha_t = self.diffusion.alpha(t).view(-1, 1, 1, 1) 37 | alpha_s = self.diffusion.alpha(s).view(-1, 1, 1, 1) 38 | c1 = ((1 - alpha_t / alpha_s) * (1 - alpha_s) / (1 - alpha_t)).sqrt() * self.eta 39 | c2 = ((1 - alpha_s) - c1 ** 2).sqrt() 40 | xt = xt.clone().to('cuda').requires_grad_(True) 41 | 42 | scale = 1.0 43 | 44 | et, x0_pred = self.model(xt, y, t, scale=scale) 45 | xs1 = alpha_s.sqrt() * x0_pred.detach() + c1 * torch.randn_like(xt) + c2 * et.detach() 46 | 47 | # mat = (H.H_pinv(y_0) - H.H_pinv(H.H(x0_pred))).detach().reshape(n, -1) 48 | 49 | # mat_x = (mat * x0_pred.reshape(n, -1)).sum() 50 | 51 | mat_x = ((H.H_pinv(y_0) - H.H_pinv(H.H(x0_pred))) ** 2).sum(dim=0).sum() 52 | print(mat_x) 53 | 54 | grad_term = torch.autograd.grad(mat_x, xt, retain_graph=True)[0] * self.grad_term_weight * alpha_t.sqrt() 55 | grad_term = grad_term.detach() * (1 - H.singulars().view(x0_pred.size())) 56 | 57 | xs2 = xs1 - grad_term 58 | ys = H.H_pinv(y_0).view(xt.size()) * alpha_s.sqrt() + (1 - alpha_s).sqrt() * torch.randn_like(xt) 59 | # x0_pred = (x0_pred + grad_term).detach() 60 | 61 | xs = xs2 * (1 - H.singulars().view(x0_pred.size())) + ys * H.singulars().view(xt.size()) 62 | xs = xs.detach() 63 | 64 | # if 'in2' in self.cfg.algo.deg: 65 | # x0_pred = x0_pred * (1 - H.singulars().view(x0_pred.size())) + (y_0 * H.singulars()).view(x0_pred.size()) 66 | 67 | # if not self.awd: 68 | # et = (xt - x0_pred * alpha_t.sqrt()) / (1 - alpha_t).sqrt() 69 | # et = et.detach() 70 | 71 | xt_s.append(xs.detach().cpu()) 72 | x0_s.append(x0_pred.detach().cpu()) 73 | xt = xs 74 | 75 | return list(reversed(xt_s)), list(reversed(x0_s)) 76 | 77 | def initialize(self, x, y, ts, **kwargs): 78 | y_0 = kwargs['y_0'] 79 | H = self.H 80 | deg = self.cfg.algo.deg 81 | n = x.size(0) 82 | x_0 = H.H_pinv(y_0).view(*x.size()).detach() 83 | ti = ts[-1] 84 | t = torch.ones(n).to(x.device).long() * ti 85 | alpha_t = self.diffusion.alpha(t).view(-1, 1, 1, 1) 86 | return 
alpha_t.sqrt() * x_0 + (1 - alpha_t).sqrt() * torch.randn_like(x_0) 87 | -------------------------------------------------------------------------------- /algos/pgdm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from omegaconf import DictConfig 6 | 7 | from models.classifier_guidance_model import ClassifierGuidanceModel 8 | from utils.degredations import build_degredation_model 9 | from .ddim import DDIM 10 | 11 | 12 | class PGDM(DDIM): 13 | def __init__(self, model: ClassifierGuidanceModel, cfg: DictConfig): 14 | self.model = model 15 | self.diffusion = model.diffusion 16 | self.H = build_degredation_model(cfg) 17 | self.cfg = cfg 18 | self.awd = cfg.algo.awd 19 | self.cond_awd = cfg.algo.cond_awd 20 | self.mcg = cfg.algo.mcg 21 | self.grad_term_weight = cfg.algo.grad_term_weight 22 | self.eta = cfg.algo.eta 23 | 24 | def sample(self, x, y, ts, **kwargs): 25 | y_0 = kwargs["y_0"] 26 | n = x.size(0) 27 | H = self.H 28 | 29 | x = self.initialize(x, y, ts, y_0=y_0) 30 | ss = [-1] + list(ts[:-1]) 31 | xt_s = [x.cpu()] 32 | x0_s = [] 33 | 34 | xt = x 35 | tot = 0 36 | for ti, si in zip(reversed(ts), reversed(ss)): 37 | t = torch.ones(n).to(x.device).long() * ti 38 | s = torch.ones(n).to(x.device).long() * si 39 | alpha_t = self.diffusion.alpha(t).view(-1, 1, 1, 1) 40 | alpha_s = self.diffusion.alpha(s).view(-1, 1, 1, 1) 41 | c1 = ((1 - alpha_t / alpha_s) * (1 - alpha_s) / (1 - alpha_t)).sqrt() * self.eta 42 | c2 = ((1 - alpha_s) - c1 ** 2).sqrt() 43 | xt = xt.clone().to('cuda').requires_grad_(True) 44 | 45 | if self.cond_awd: 46 | scale = alpha_s.sqrt() / (alpha_s.sqrt() - c2 * alpha_t.sqrt() / (1 - alpha_t).sqrt()) 47 | scale = scale.view(-1)[0].item() 48 | else: 49 | scale = 1.0 50 | 51 | et, x0_pred = self.model(xt, y, t, scale=scale) 52 | mat = (H.H_pinv(y_0) - H.H_pinv(H.H(x0_pred))).reshape(n, -1) 53 | 54 | 55 | mat_x = (mat.detach() * x0_pred.reshape(n, -1)).sum() 56 | if self.cfg.algo.deg == "hdr": 57 | contrast = torch.std(x0_pred.view(n, -1), dim=1).sum() 58 | #print(mat_x, contrast) 59 | mat_x = mat_x + contrast * 1500 60 | 61 | # sigma_t = (1 - alpha_t).sqrt() / alpha_t.sqrt() 62 | # f = lambda x: torch.tanh(x) 63 | grad_term = torch.autograd.grad(mat_x, xt, retain_graph=True)[0] 64 | 65 | # g2 = (grad_term ** 2).sum().sqrt().item() 66 | 67 | grad_term = grad_term.detach() 68 | # * alpha_t.sqrt() # * self.grad_term_weight * alpha_t.sqrt() # (sigma_t / f(sigma_t)) ** 2 * alpha_t.sqrt() 69 | 70 | coeff = alpha_s.sqrt() 71 | if not self.awd: 72 | coeff = coeff - c2 * alpha_t.sqrt() / (1 - alpha_t).sqrt() 73 | coeff = coeff * alpha_t.sqrt() * self.grad_term_weight 74 | 75 | if self.mcg: 76 | coeff = alpha_t.sqrt() * self.grad_term_weight 77 | # coeff = torch.ones_like(alpha_s) * 1.0 / g2 * 5.0 78 | 79 | # tot += (grad_term ** 2).sum().sqrt().item() * coeff.item() 80 | 81 | # print(f'{(mat ** 2).sum().item():.4f}\t{((sigma_t / f(sigma_t)) ** 2 * alpha_t.sqrt()).item():.4f}\t{(grad_term ** 2).sum().sqrt().item() * coeff.item():.4f}\t{coeff.item() * alpha_t.sqrt().item():.6f}\t{g2:.4f}') 82 | # grad_term2 = torch.autograd.grad((mat ** 2).sum(), xt, retain_graph=True)[0] * alpha_t.sqrt() * self.grad_term_weight 83 | # import ipdb; ipdb.set_trace() 84 | 85 | # x0_pred = (x0_pred + grad_term).detach() 86 | x0_pred = x0_pred.detach() 87 | 88 | if 'in2' in self.cfg.algo.deg: 89 | x0_pred = x0_pred * (1 - 
H.singulars().view(x0_pred.size())) + (y_0 * H.singulars()).view(x0_pred.size()) 90 | grad_term = grad_term * (1 - H.singulars().view(grad_term.size())) 91 | 92 | if not self.awd: 93 | et = (xt - x0_pred * alpha_t.sqrt()) / (1 - alpha_t).sqrt() 94 | et = et.detach() 95 | 96 | xs = alpha_s.sqrt() * x0_pred + c1 * torch.randn_like(xt) + c2 * et + grad_term * coeff 97 | xt_s.append(xs.detach().cpu()) 98 | x0_s.append(x0_pred.detach().cpu()) 99 | xt = xs 100 | # print(f'tot: {tot:.4f}') 101 | return list(reversed(xt_s)), list(reversed(x0_s)) 102 | 103 | def initialize(self, x, y, ts, **kwargs): 104 | y_0 = kwargs['y_0'] 105 | H = self.H 106 | deg = self.cfg.algo.deg 107 | n = x.size(0) 108 | x_0 = H.H_pinv(y_0).view(*x.size()).detach() 109 | ti = ts[-1] 110 | t = torch.ones(n).to(x.device).long() * ti 111 | alpha_t = self.diffusion.alpha(t).view(-1, 1, 1, 1) 112 | return alpha_t.sqrt() * x_0 + (1 - alpha_t).sqrt() * torch.randn_like(x_0) 113 | -------------------------------------------------------------------------------- /bkse/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.swp 3 | *.png 4 | *.jpg 5 | !imgs/* 6 | !imgs/** 7 | experiments/pretrained/* 8 | !experiments/pretrained/kernel.pth 9 | results/* 10 | datasets/* 11 | -------------------------------------------------------------------------------- /bkse/data/__init__.py: -------------------------------------------------------------------------------- 1 | """create dataset and dataloader""" 2 | import logging 3 | 4 | import torch 5 | import torch.utils.data 6 | 7 | 8 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): 9 | phase = dataset_opt["phase"] 10 | if phase == "train": 11 | if opt["dist"]: 12 | world_size = torch.distributed.get_world_size() 13 | num_workers = dataset_opt["n_workers"] 14 | assert dataset_opt["batch_size"] % world_size == 0 15 | batch_size = dataset_opt["batch_size"] // world_size 16 | shuffle = False 17 | else: 18 | num_workers = dataset_opt["n_workers"] * len(opt["gpu_ids"]) 19 | batch_size = dataset_opt["batch_size"] 20 | shuffle = True 21 | return torch.utils.data.DataLoader( 22 | dataset, 23 | batch_size=batch_size, 24 | shuffle=shuffle, 25 | num_workers=num_workers, 26 | sampler=sampler, 27 | drop_last=True, 28 | pin_memory=False, 29 | ) 30 | else: 31 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=False) 32 | 33 | 34 | def create_dataset(dataset_opt): 35 | mode = dataset_opt["mode"] 36 | # datasets for image restoration 37 | if mode == "REDS": 38 | from data.REDS_dataset import REDSDataset as D 39 | elif mode == "GOPRO": 40 | from data.GOPRO_dataset import GOPRODataset as D 41 | elif mode == "fewshot": 42 | from data.fewshot_dataset import FewShotDataset as D 43 | elif mode == "levin": 44 | from data.levin_dataset import LevinDataset as D 45 | elif mode == "mix": 46 | from data.mix_dataset import MixDataset as D 47 | else: 48 | raise NotImplementedError(f"Dataset {mode} is not recognized.") 49 | dataset = D(dataset_opt) 50 | 51 | logger = logging.getLogger("base") 52 | logger.info("Dataset [{:s} - {:s}] is created.".format(dataset.__class__.__name__, dataset_opt["name"])) 53 | return dataset 54 | -------------------------------------------------------------------------------- /bkse/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from torch.utils.data.distributed.DistributedSampler 3 | Support enlarging 
the dataset for *iteration-oriented* training, 4 | for saving time when restart the dataloader after each epoch 5 | """ 6 | import math 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from torch.utils.data.sampler import Sampler 11 | 12 | 13 | class DistIterSampler(Sampler): 14 | """Sampler that restricts data loading to a subset of the dataset. 15 | 16 | It is especially useful in conjunction with 17 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 18 | process can pass a DistributedSampler instance as a DataLoader sampler, 19 | and load a subset of the original dataset that is exclusive to it. 20 | 21 | .. note:: 22 | Dataset is assumed to be of constant size. 23 | 24 | Arguments: 25 | dataset: Dataset used for sampling. 26 | num_replicas (optional): Number of processes participating in 27 | distributed training. 28 | rank (optional): Rank of the current process within num_replicas. 29 | """ 30 | 31 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): 32 | if num_replicas is None: 33 | if not dist.is_available(): 34 | raise RuntimeError( 35 | "Requires distributed \ 36 | package to be available" 37 | ) 38 | num_replicas = dist.get_world_size() 39 | if rank is None: 40 | if not dist.is_available(): 41 | raise RuntimeError( 42 | "Requires distributed \ 43 | package to be available" 44 | ) 45 | rank = dist.get_rank() 46 | self.dataset = dataset 47 | self.num_replicas = num_replicas 48 | self.rank = rank 49 | self.epoch = 0 50 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) 51 | self.total_size = self.num_samples * self.num_replicas 52 | 53 | def __iter__(self): 54 | # deterministically shuffle based on epoch 55 | g = torch.Generator() 56 | g.manual_seed(self.epoch) 57 | indices = torch.randperm(self.total_size, generator=g).tolist() 58 | 59 | dsize = len(self.dataset) 60 | indices = [v % dsize for v in indices] 61 | 62 | # subsample 63 | indices = indices[self.rank : self.total_size : self.num_replicas] 64 | assert len(indices) == self.num_samples 65 | 66 | return iter(indices) 67 | 68 | def __len__(self): 69 | return self.num_samples 70 | 71 | def set_epoch(self, epoch): 72 | self.epoch = epoch 73 | -------------------------------------------------------------------------------- /bkse/data/mix_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mix dataset 3 | support reading images from lmdb 4 | """ 5 | import logging 6 | import random 7 | 8 | import data.util as util 9 | import lmdb 10 | import numpy as np 11 | import torch 12 | import torch.utils.data as data 13 | 14 | 15 | logger = logging.getLogger("base") 16 | 17 | 18 | class MixDataset(data.Dataset): 19 | """ 20 | Reading the training REDS dataset 21 | key example: 000_00000000 22 | HQ: Ground-Truth; 23 | LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames 24 | support reading N LQ frames, N = 1, 3, 5, 7 25 | """ 26 | 27 | def __init__(self, opt): 28 | super(MixDataset, self).__init__() 29 | self.opt = opt 30 | # temporal augmentation 31 | 32 | self.HQ_roots = opt["dataroots_HQ"] 33 | self.LQ_roots = opt["dataroots_LQ"] 34 | self.use_identical = opt["identical_loss"] 35 | dataset_weights = opt["dataset_weights"] 36 | self.data_type = "lmdb" 37 | # directly load image keys 38 | self.HQ_envs, self.LQ_envs = None, None 39 | self.paths_HQ = [] 40 | for idx, (HQ_root, LQ_root) in enumerate(zip(self.HQ_roots, self.LQ_roots)): 41 | paths_HQ, _ = 
util.get_image_paths(self.data_type, HQ_root) 42 | self.paths_HQ += list(zip([idx] * len(paths_HQ), paths_HQ)) * dataset_weights[idx] 43 | random.shuffle(self.paths_HQ) 44 | logger.info("Using lmdb meta info for cache keys.") 45 | 46 | def _init_lmdb(self): 47 | self.HQ_envs, self.LQ_envs = [], [] 48 | for HQ_root, LQ_root in zip(self.HQ_roots, self.LQ_roots): 49 | self.HQ_envs.append(lmdb.open(HQ_root, readonly=True, lock=False, readahead=False, meminit=False)) 50 | self.LQ_envs.append(lmdb.open(LQ_root, readonly=True, lock=False, readahead=False, meminit=False)) 51 | 52 | def __getitem__(self, index): 53 | if self.HQ_envs is None: 54 | self._init_lmdb() 55 | 56 | HQ_size = self.opt["HQ_size"] 57 | env_idx, key = self.paths_HQ[index] 58 | name_a, name_b = key.split("_") 59 | target_frame_idx = int(name_b) 60 | 61 | # determine the neighbor frames 62 | # ensure not exceeding the borders 63 | neighbor_list = [target_frame_idx] 64 | name_b = "{:08d}".format(neighbor_list[0]) 65 | 66 | # get the HQ image (as the center frame) 67 | img_HQ_l = [] 68 | for v in neighbor_list: 69 | img_HQ = util.read_img(self.HQ_envs[env_idx], "{}_{:08d}".format(name_a, v), (3, 720, 1280)) 70 | img_HQ_l.append(img_HQ) 71 | 72 | # get LQ images 73 | img_LQ = util.read_img(self.LQ_envs[env_idx], "{}_{:08d}".format(name_a, neighbor_list[-1]), (3, 720, 1280)) 74 | if self.opt["phase"] == "train": 75 | _, H, W = 3, 720, 1280 # LQ size 76 | # randomly crop 77 | rnd_h = random.randint(0, max(0, H - HQ_size)) 78 | rnd_w = random.randint(0, max(0, W - HQ_size)) 79 | img_LQ = img_LQ[rnd_h : rnd_h + HQ_size, rnd_w : rnd_w + HQ_size, :] 80 | img_HQ_l = [v[rnd_h : rnd_h + HQ_size, rnd_w : rnd_w + HQ_size, :] for v in img_HQ_l] 81 | 82 | # augmentation - flip, rotate 83 | img_HQ_l.append(img_LQ) 84 | rlt = util.augment(img_HQ_l, self.opt["use_flip"], self.opt["use_rot"]) 85 | img_HQ_l = rlt[0:-1] 86 | img_LQ = rlt[-1] 87 | 88 | # stack LQ images to NHWC, N is the frame number 89 | img_HQs = np.stack(img_HQ_l, axis=0) 90 | # BGR to RGB, HWC to CHW, numpy to tensor 91 | img_LQ = img_LQ[:, :, [2, 1, 0]] 92 | img_HQs = img_HQs[:, :, :, [2, 1, 0]] 93 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() 94 | img_HQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HQs, (0, 3, 1, 2)))).float() 95 | # print(img_LQ.shape, img_HQs.shape) 96 | 97 | if self.use_identical and np.random.randint(0, 10) == 0: 98 | img_LQ = img_HQs[-1, :, :, :] 99 | return {"LQ": img_LQ, "HQs": img_HQs, "identical_w": 10} 100 | 101 | return {"LQ": img_LQ, "HQs": img_HQs, "identical_w": 0} 102 | 103 | def __len__(self): 104 | return len(self.paths_HQ) 105 | -------------------------------------------------------------------------------- /bkse/domain_specific_deblur.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from math import ceil, log10 3 | from pathlib import Path 4 | 5 | import torchvision 6 | import yaml 7 | from PIL import Image 8 | from torch.nn import DataParallel 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | 12 | class Images(Dataset): 13 | def __init__(self, root_dir, duplicates): 14 | self.root_path = Path(root_dir) 15 | self.image_list = list(self.root_path.glob("*.png")) 16 | self.duplicates = ( 17 | duplicates # Number of times to duplicate the image in the dataset to produce multiple HR images 18 | ) 19 | 20 | def __len__(self): 21 | return self.duplicates * len(self.image_list) 22 | 23 | def __getitem__(self, idx): 24 | 
img_path = self.image_list[idx // self.duplicates] 25 | image = torchvision.transforms.ToTensor()(Image.open(img_path)) 26 | if self.duplicates == 1: 27 | return image, img_path.stem 28 | else: 29 | return image, img_path.stem + f"_{(idx % self.duplicates)+1}" 30 | 31 | 32 | parser = argparse.ArgumentParser(description="PULSE") 33 | 34 | # I/O arguments 35 | parser.add_argument("--input_dir", type=str, default="imgs/blur_faces", help="input data directory") 36 | parser.add_argument( 37 | "--output_dir", type=str, default="experiments/domain_specific_deblur/results", help="output data directory" 38 | ) 39 | parser.add_argument( 40 | "--cache_dir", 41 | type=str, 42 | default="experiments/domain_specific_deblur/cache", 43 | help="cache directory for model weights", 44 | ) 45 | parser.add_argument( 46 | "--yml_path", type=str, default="options/domain_specific_deblur/stylegan2.yml", help="configuration file" 47 | ) 48 | 49 | kwargs = vars(parser.parse_args()) 50 | 51 | with open(kwargs["yml_path"], "rb") as f: 52 | opt = yaml.safe_load(f) 53 | 54 | dataset = Images(kwargs["input_dir"], duplicates=opt["duplicates"]) 55 | out_path = Path(kwargs["output_dir"]) 56 | out_path.mkdir(parents=True, exist_ok=True) 57 | 58 | dataloader = DataLoader(dataset, batch_size=opt["batch_size"]) 59 | 60 | if opt["stylegan_ver"] == 1: 61 | from models.dsd.dsd_stylegan import DSDStyleGAN 62 | 63 | model = DSDStyleGAN(opt=opt, cache_dir=kwargs["cache_dir"]) 64 | else: 65 | from models.dsd.dsd_stylegan2 import DSDStyleGAN2 66 | 67 | model = DSDStyleGAN2(opt=opt, cache_dir=kwargs["cache_dir"]) 68 | 69 | model = DataParallel(model) 70 | 71 | toPIL = torchvision.transforms.ToPILImage() 72 | 73 | for ref_im, ref_im_name in dataloader: 74 | if opt["save_intermediate"]: 75 | padding = ceil(log10(100)) 76 | for i in range(opt["batch_size"]): 77 | int_path_HR = Path(out_path / ref_im_name[i] / "HR") 78 | int_path_LR = Path(out_path / ref_im_name[i] / "LR") 79 | int_path_HR.mkdir(parents=True, exist_ok=True) 80 | int_path_LR.mkdir(parents=True, exist_ok=True) 81 | for j, (HR, LR) in enumerate(model(ref_im)): 82 | for i in range(opt["batch_size"]): 83 | toPIL(HR[i].cpu().detach().clamp(0, 1)).save(int_path_HR / f"{ref_im_name[i]}_{j:0{padding}}.png") 84 | toPIL(LR[i].cpu().detach().clamp(0, 1)).save(int_path_LR / f"{ref_im_name[i]}_{j:0{padding}}.png") 85 | else: 86 | # out_im = model(ref_im,**kwargs) 87 | for j, (HR, LR) in enumerate(model(ref_im)): 88 | for i in range(opt["batch_size"]): 89 | toPIL(HR[i].cpu().detach().clamp(0, 1)).save(out_path / f"{ref_im_name[i]}.png") 90 | -------------------------------------------------------------------------------- /bkse/experiments/pretrained/kernel.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/bkse/experiments/pretrained/kernel.pth -------------------------------------------------------------------------------- /bkse/generate_blur.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import numpy as np 5 | import os.path as osp 6 | import torch 7 | import utils.util as util 8 | import yaml 9 | from models.kernel_encoding.kernel_wizard import KernelWizard 10 | 11 | 12 | def main(): 13 | device = torch.device("cuda") 14 | 15 | parser = argparse.ArgumentParser(description="Kernel extractor testing") 16 | 17 | parser.add_argument("--image_path", action="store", help="image path", 
type=str, required=True) 18 | parser.add_argument("--yml_path", action="store", help="yml path", type=str, required=True) 19 | parser.add_argument("--save_path", action="store", help="save path", type=str, default=".") 20 | parser.add_argument("--num_samples", action="store", help="number of samples", type=int, default=1) 21 | 22 | args = parser.parse_args() 23 | 24 | image_path = args.image_path 25 | yml_path = args.yml_path 26 | num_samples = args.num_samples 27 | 28 | # Initializing mode 29 | with open(yml_path, "r") as f: 30 | opt = yaml.load(f)["KernelWizard"] 31 | model_path = opt["pretrained"] 32 | model = KernelWizard(opt) 33 | model.eval() 34 | model.load_state_dict(torch.load(model_path)) 35 | model = model.to(device) 36 | 37 | HQ = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) / 255.0 38 | HQ = np.transpose(HQ, (2, 0, 1)) 39 | HQ_tensor = torch.Tensor(HQ).unsqueeze(0).to(device).cuda() 40 | 41 | for i in range(num_samples): 42 | print(f"Sample #{i}/{num_samples}") 43 | with torch.no_grad(): 44 | kernel = torch.randn((1, 512, 2, 2)).cuda() * 1.2 45 | LQ_tensor = model.adaptKernel(HQ_tensor, kernel) 46 | 47 | dst = osp.join(args.save_path, f"blur{i:03d}.png") 48 | LQ_img = util.tensor2img(LQ_tensor) 49 | 50 | cv2.imwrite(dst, LQ_img) 51 | 52 | 53 | main() 54 | -------------------------------------------------------------------------------- /bkse/generic_deblur.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import yaml 5 | from models.deblurring.joint_deblur import JointDeblur 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser(description="Kernel extractor testing") 10 | 11 | parser.add_argument("--image_path", action="store", help="image path", type=str, required=True) 12 | parser.add_argument("--save_path", action="store", help="save path", type=str, default="res.png") 13 | parser.add_argument("--yml_path", action="store", help="yml path", type=str, required=True) 14 | 15 | args = parser.parse_args() 16 | 17 | # Initializing mode 18 | with open(args.yml_path, "rb") as f: 19 | opt = yaml.safe_load(f) 20 | model = JointDeblur(opt) 21 | 22 | blur_img = cv2.cvtColor(cv2.imread(args.image_path), cv2.COLOR_BGR2RGB) 23 | sharp_img = model.deblur(blur_img) 24 | 25 | cv2.imwrite(args.save_path, sharp_img) 26 | 27 | 28 | main() 29 | -------------------------------------------------------------------------------- /bkse/imgs/blur_faces/face01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/bkse/imgs/blur_faces/face01.png -------------------------------------------------------------------------------- /bkse/imgs/blur_imgs/blur1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/bkse/imgs/blur_imgs/blur1.png -------------------------------------------------------------------------------- /bkse/imgs/blur_imgs/blur2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/bkse/imgs/blur_imgs/blur2.png -------------------------------------------------------------------------------- /bkse/imgs/results/augmentation.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/bkse/imgs/results/augmentation.jpg -------------------------------------------------------------------------------- /bkse/imgs/results/domain_specific_deblur.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/bkse/imgs/results/domain_specific_deblur.jpg -------------------------------------------------------------------------------- /bkse/imgs/results/general_deblurring.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/bkse/imgs/results/general_deblurring.jpg -------------------------------------------------------------------------------- /bkse/imgs/results/generate_blur.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/bkse/imgs/results/generate_blur.jpg -------------------------------------------------------------------------------- /bkse/imgs/results/kernel_encoding_wGT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/bkse/imgs/results/kernel_encoding_wGT.png -------------------------------------------------------------------------------- /bkse/imgs/sharp_imgs/mushishi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/bkse/imgs/sharp_imgs/mushishi.png -------------------------------------------------------------------------------- /bkse/imgs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/bkse/imgs/teaser.jpg -------------------------------------------------------------------------------- /bkse/models/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | logger = logging.getLogger("base") 5 | 6 | 7 | def create_model(opt): 8 | model = opt["model"] 9 | if model == "image_base": 10 | from models.kernel_encoding.image_base_model import ImageBaseModel as M 11 | else: 12 | raise NotImplementedError("Model [{:s}] not recognized.".format(model)) 13 | m = M(opt) 14 | logger.info("Model [{:s}] is created.".format(m.__class__.__name__)) 15 | return m 16 | -------------------------------------------------------------------------------- /bkse/models/arch_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | 6 | 7 | class Identity(nn.Module): 8 | def forward(self, x): 9 | return x 10 | 11 | 12 | def get_norm_layer(norm_type="instance"): 13 | """Return a normalization layer 14 | Parameters: 15 | norm_type (str) -- the name of the normalization 16 | layer: batch | instance | none 17 | 18 | For BatchNorm, we use learnable affine parameters and 19 | track running statistics (mean/stddev). 20 | 21 | For InstanceNorm, we do not use learnable affine 22 | parameters. We do not track running statistics. 
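    Returns a constructor (a functools.partial for batch/instance, or a callable
    returning Identity for 'none') rather than a layer instance.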
23 | """ 24 | if norm_type == "batch": 25 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 26 | elif norm_type == "instance": 27 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 28 | elif norm_type == "none": 29 | 30 | def norm_layer(x): 31 | return Identity() 32 | 33 | else: 34 | raise NotImplementedError( 35 | f"normalization layer {norm_type}\ 36 | is not found" 37 | ) 38 | return norm_layer 39 | 40 | 41 | def initialize_weights(net_l, scale=1): 42 | if not isinstance(net_l, list): 43 | net_l = [net_l] 44 | for net in net_l: 45 | for m in net.modules(): 46 | if isinstance(m, nn.Conv2d): 47 | init.kaiming_normal_(m.weight, a=0, mode="fan_in") 48 | m.weight.data *= scale # for residual block 49 | if m.bias is not None: 50 | m.bias.data.zero_() 51 | elif isinstance(m, nn.Linear): 52 | init.kaiming_normal_(m.weight, a=0, mode="fan_in") 53 | m.weight.data *= scale 54 | if m.bias is not None: 55 | m.bias.data.zero_() 56 | elif isinstance(m, nn.BatchNorm2d): 57 | init.constant_(m.weight, 1) 58 | init.constant_(m.bias.data, 0.0) 59 | -------------------------------------------------------------------------------- /bkse/models/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from bkse.models.arch_util import initialize_weights 4 | 5 | 6 | class ResnetBlock(nn.Module): 7 | """Define a Resnet block""" 8 | 9 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 10 | """Initialize the Resnet block 11 | A resnet block is a conv block with skip connections 12 | We construct a conv block with build_conv_block function, 13 | and implement skip connections in function. 14 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf 15 | """ 16 | super(ResnetBlock, self).__init__() 17 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 18 | 19 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 20 | """Construct a convolutional block. 21 | Parameters: 22 | dim (int) -- the number of channels in the conv layer. 23 | padding_type (str) -- the name of padding 24 | layer: reflect | replicate | zero 25 | norm_layer -- normalization layer 26 | use_dropout (bool) -- if use dropout layers. 
27 | use_bias (bool) -- if the conv layer uses bias or not 28 | Returns a conv block (with a conv layer, a normalization layer, 29 | and a non-linearity layer (ReLU)) 30 | """ 31 | conv_block = [] 32 | p = 0 33 | if padding_type == "reflect": 34 | conv_block += [nn.ReflectionPad2d(1)] 35 | elif padding_type == "replicate": 36 | conv_block += [nn.ReplicationPad2d(1)] 37 | elif padding_type == "zero": 38 | p = 1 39 | else: 40 | raise NotImplementedError( 41 | f"padding {padding_type} \ 42 | is not implemented" 43 | ) 44 | 45 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] 46 | if use_dropout: 47 | conv_block += [nn.Dropout(0.5)] 48 | 49 | p = 0 50 | if padding_type == "reflect": 51 | conv_block += [nn.ReflectionPad2d(1)] 52 | elif padding_type == "replicate": 53 | conv_block += [nn.ReplicationPad2d(1)] 54 | elif padding_type == "zero": 55 | p = 1 56 | else: 57 | raise NotImplementedError( 58 | f"padding {padding_type} \ 59 | is not implemented" 60 | ) 61 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] 62 | 63 | return nn.Sequential(*conv_block) 64 | 65 | def forward(self, x): 66 | """Forward function (with skip connections)""" 67 | out = x + self.conv_block(x) # add skip connections 68 | return out 69 | 70 | 71 | class ResidualBlock_noBN(nn.Module): 72 | """Residual block w/o BN 73 | ---Conv-ReLU-Conv-+- 74 | |________________| 75 | """ 76 | 77 | def __init__(self, nf=64): 78 | super(ResidualBlock_noBN, self).__init__() 79 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 80 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 81 | 82 | # initialization 83 | initialize_weights([self.conv1, self.conv2], 0.1) 84 | 85 | def forward(self, x): 86 | identity = x 87 | out = F.relu(self.conv1(x), inplace=False) 88 | out = self.conv2(out) 89 | return identity + out 90 | -------------------------------------------------------------------------------- /bkse/models/backbones/skip/concat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class Concat(nn.Module): 7 | def __init__(self, dim, *args): 8 | super(Concat, self).__init__() 9 | self.dim = dim 10 | 11 | for idx, module in enumerate(args): 12 | self.add_module(str(idx), module) 13 | 14 | def forward(self, input): 15 | inputs = [] 16 | for module in self._modules.values(): 17 | inputs.append(module(input)) 18 | 19 | inputs_shapes2 = [x.shape[2] for x in inputs] 20 | inputs_shapes3 = [x.shape[3] for x in inputs] 21 | 22 | if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all( 23 | np.array(inputs_shapes3) == min(inputs_shapes3) 24 | ): 25 | inputs_ = inputs 26 | else: 27 | target_shape2 = min(inputs_shapes2) 28 | target_shape3 = min(inputs_shapes3) 29 | 30 | inputs_ = [] 31 | for inp in inputs: 32 | diff2 = (inp.size(2) - target_shape2) // 2 33 | diff3 = (inp.size(3) - target_shape3) // 2 34 | inputs_.append(inp[:, :, diff2 : diff2 + target_shape2, diff3 : diff3 + target_shape3]) 35 | 36 | return torch.cat(inputs_, dim=self.dim) 37 | 38 | def __len__(self): 39 | return len(self._modules) 40 | -------------------------------------------------------------------------------- /bkse/models/backbones/skip/non_local_dot_product.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class _NonLocalBlockND(nn.Module): 6 | def 
__init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 7 | super(_NonLocalBlockND, self).__init__() 8 | 9 | assert dimension in [1, 2, 3] 10 | 11 | self.dimension = dimension 12 | self.sub_sample = sub_sample 13 | 14 | self.in_channels = in_channels 15 | self.inter_channels = inter_channels 16 | 17 | if self.inter_channels is None: 18 | self.inter_channels = in_channels // 2 19 | if self.inter_channels == 0: 20 | self.inter_channels = 1 21 | 22 | if dimension == 3: 23 | conv_nd = nn.Conv3d 24 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 25 | bn = nn.BatchNorm3d 26 | elif dimension == 2: 27 | conv_nd = nn.Conv2d 28 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 29 | bn = nn.BatchNorm2d 30 | else: 31 | conv_nd = nn.Conv1d 32 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 33 | bn = nn.BatchNorm1d 34 | 35 | self.g = conv_nd( 36 | in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 37 | ) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd( 42 | in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0 43 | ), 44 | bn(self.in_channels), 45 | ) 46 | nn.init.constant_(self.W[1].weight, 0) 47 | nn.init.constant_(self.W[1].bias, 0) 48 | else: 49 | self.W = conv_nd( 50 | in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0 51 | ) 52 | nn.init.constant_(self.W.weight, 0) 53 | nn.init.constant_(self.W.bias, 0) 54 | 55 | self.theta = conv_nd( 56 | in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 57 | ) 58 | 59 | self.phi = conv_nd( 60 | in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 61 | ) 62 | 63 | if sub_sample: 64 | self.g = nn.Sequential(self.g, max_pool_layer) 65 | self.phi = nn.Sequential(self.phi, max_pool_layer) 66 | 67 | def forward(self, x): 68 | """ 69 | :param x: (b, c, t, h, w) 70 | :return: 71 | """ 72 | 73 | batch_size = x.size(0) 74 | 75 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 76 | g_x = g_x.permute(0, 2, 1) 77 | 78 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 79 | theta_x = theta_x.permute(0, 2, 1) 80 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 81 | f = torch.matmul(theta_x, phi_x) 82 | N = f.size(-1) 83 | f_div_C = f / N 84 | 85 | y = torch.matmul(f_div_C, g_x) 86 | y = y.permute(0, 2, 1).contiguous() 87 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 88 | W_y = self.W(y) 89 | z = W_y + x 90 | 91 | return z 92 | 93 | 94 | class NONLocalBlock1D(_NonLocalBlockND): 95 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 96 | super(NONLocalBlock1D, self).__init__( 97 | in_channels, inter_channels=inter_channels, dimension=1, sub_sample=sub_sample, bn_layer=bn_layer 98 | ) 99 | 100 | 101 | class NONLocalBlock2D(_NonLocalBlockND): 102 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 103 | super(NONLocalBlock2D, self).__init__( 104 | in_channels, inter_channels=inter_channels, dimension=2, sub_sample=sub_sample, bn_layer=bn_layer 105 | ) 106 | 107 | 108 | class NONLocalBlock3D(_NonLocalBlockND): 109 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 110 | super(NONLocalBlock3D, self).__init__( 111 | in_channels, inter_channels=inter_channels, dimension=3, sub_sample=sub_sample, bn_layer=bn_layer 112 | ) 113 
| 114 | 115 | if __name__ == "__main__": 116 | for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]: 117 | img = torch.zeros(2, 3, 20) 118 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) 119 | out = net(img) 120 | print(out.size()) 121 | 122 | img = torch.zeros(2, 3, 20, 20) 123 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) 124 | out = net(img) 125 | print(out.size()) 126 | 127 | img = torch.randn(2, 3, 8, 20, 20) 128 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) 129 | out = net(img) 130 | print(out.size()) 131 | -------------------------------------------------------------------------------- /bkse/models/backbones/skip/skip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .concat import Concat 5 | from .non_local_dot_product import NONLocalBlock2D 6 | from .util import get_activation, get_conv 7 | 8 | 9 | def add_module(self, module): 10 | self.add_module(str(len(self) + 1), module) 11 | 12 | 13 | torch.nn.Module.add = add_module 14 | 15 | 16 | def skip( 17 | num_input_channels=2, 18 | num_output_channels=3, 19 | num_channels_down=[16, 32, 64, 128, 128], 20 | num_channels_up=[16, 32, 64, 128, 128], 21 | num_channels_skip=[4, 4, 4, 4, 4], 22 | filter_size_down=3, 23 | filter_size_up=3, 24 | filter_skip_size=1, 25 | need_sigmoid=True, 26 | need_bias=True, 27 | pad="zero", 28 | upsample_mode="nearest", 29 | downsample_mode="stride", 30 | act_fun="LeakyReLU", 31 | need1x1_up=True, 32 | ): 33 | """Assembles encoder-decoder with skip connections. 34 | 35 | Arguments: 36 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU) 37 | pad (string): zero|reflection (default: 'zero') 38 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest') 39 | downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride') 40 | 41 | """ 42 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip) 43 | 44 | n_scales = len(num_channels_down) 45 | 46 | if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)): 47 | upsample_mode = [upsample_mode] * n_scales 48 | 49 | if not (isinstance(downsample_mode, list) or isinstance(downsample_mode, tuple)): 50 | downsample_mode = [downsample_mode] * n_scales 51 | 52 | if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)): 53 | filter_size_down = [filter_size_down] * n_scales 54 | 55 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)): 56 | filter_size_up = [filter_size_up] * n_scales 57 | 58 | last_scale = n_scales - 1 59 | 60 | model = nn.Sequential() 61 | model_tmp = model 62 | 63 | input_depth = num_input_channels 64 | for i in range(len(num_channels_down)): 65 | 66 | deeper = nn.Sequential() 67 | skip = nn.Sequential() 68 | 69 | if num_channels_skip[i] != 0: 70 | model_tmp.add(Concat(1, skip, deeper)) 71 | else: 72 | model_tmp.add(deeper) 73 | 74 | model_tmp.add( 75 | nn.BatchNorm2d(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i])) 76 | ) 77 | 78 | if num_channels_skip[i] != 0: 79 | skip.add(get_conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad)) 80 | skip.add(nn.BatchNorm2d(num_channels_skip[i])) 81 | skip.add(get_activation(act_fun)) 82 | 83 | # skip.add(Concat(2, GenNoise(nums_noise[i]), skip_part)) 84 | 85 | deeper.add( 86 | get_conv( 87 | input_depth, 88 | 
num_channels_down[i], 89 | filter_size_down[i], 90 | 2, 91 | bias=need_bias, 92 | pad=pad, 93 | downsample_mode=downsample_mode[i], 94 | ) 95 | ) 96 | deeper.add(nn.BatchNorm2d(num_channels_down[i])) 97 | deeper.add(get_activation(act_fun)) 98 | if i > 1: 99 | deeper.add(NONLocalBlock2D(in_channels=num_channels_down[i])) 100 | deeper.add(get_conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad)) 101 | deeper.add(nn.BatchNorm2d(num_channels_down[i])) 102 | deeper.add(get_activation(act_fun)) 103 | 104 | deeper_main = nn.Sequential() 105 | 106 | if i == len(num_channels_down) - 1: 107 | # The deepest 108 | k = num_channels_down[i] 109 | else: 110 | deeper.add(deeper_main) 111 | k = num_channels_up[i + 1] 112 | 113 | deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i])) 114 | 115 | model_tmp.add( 116 | get_conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad) 117 | ) 118 | model_tmp.add(nn.BatchNorm2d(num_channels_up[i])) 119 | model_tmp.add(get_activation(act_fun)) 120 | 121 | if need1x1_up: 122 | model_tmp.add(get_conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad)) 123 | model_tmp.add(nn.BatchNorm2d(num_channels_up[i])) 124 | model_tmp.add(get_activation(act_fun)) 125 | 126 | input_depth = num_channels_down[i] 127 | model_tmp = deeper_main 128 | 129 | model.add(get_conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad)) 130 | if need_sigmoid: 131 | model.add(nn.Sigmoid()) 132 | 133 | return model 134 | -------------------------------------------------------------------------------- /bkse/models/backbones/skip/util.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .downsampler import Downsampler 4 | 5 | 6 | class Swish(nn.Module): 7 | """ 8 | https://arxiv.org/abs/1710.05941 9 | The hype was so huge that I could not help but try it 10 | """ 11 | 12 | def __init__(self): 13 | super(Swish, self).__init__() 14 | self.s = nn.Sigmoid() 15 | 16 | def forward(self, x): 17 | return x * self.s(x) 18 | 19 | 20 | def get_conv(in_f, out_f, kernel_size, stride=1, bias=True, pad="zero", downsample_mode="stride"): 21 | downsampler = None 22 | if stride != 1 and downsample_mode != "stride": 23 | 24 | if downsample_mode == "avg": 25 | downsampler = nn.AvgPool2d(stride, stride) 26 | elif downsample_mode == "max": 27 | downsampler = nn.MaxPool2d(stride, stride) 28 | elif downsample_mode in ["lanczos2", "lanczos3"]: 29 | downsampler = Downsampler( 30 | n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True 31 | ) 32 | else: 33 | assert False 34 | 35 | stride = 1 36 | 37 | padder = None 38 | to_pad = int((kernel_size - 1) / 2) 39 | if pad == "reflection": 40 | padder = nn.ReflectionPad2d(to_pad) 41 | to_pad = 0 42 | 43 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias) 44 | 45 | layers = filter(lambda x: x is not None, [padder, convolver, downsampler]) 46 | return nn.Sequential(*layers) 47 | 48 | 49 | def get_activation(act_fun="LeakyReLU"): 50 | """ 51 | Either string defining an activation function or module (e.g. 
nn.ReLU) 52 | """ 53 | if isinstance(act_fun, str): 54 | if act_fun == "LeakyReLU": 55 | return nn.LeakyReLU(0.2, inplace=True) 56 | elif act_fun == "Swish": 57 | return Swish() 58 | elif act_fun == "ELU": 59 | return nn.ELU() 60 | elif act_fun == "none": 61 | return nn.Sequential() 62 | else: 63 | assert False 64 | else: 65 | return act_fun() 66 | -------------------------------------------------------------------------------- /bkse/models/backbones/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import functools 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class DoubleConv(nn.Module): 10 | """(convolution => [BN] => ReLU) * 2""" 11 | 12 | def __init__(self, in_channels, out_channels, mid_channels=None): 13 | super().__init__() 14 | if not mid_channels: 15 | mid_channels = out_channels 16 | self.double_conv = nn.Sequential( 17 | nn.Conv2d(in_channels, mid_channels, kernel_size=5, padding=2), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=5, padding=2), 20 | nn.ReLU(inplace=True), 21 | ) 22 | 23 | def forward(self, x): 24 | return self.double_conv(x) 25 | 26 | 27 | class UnetSkipConnectionBlock(nn.Module): 28 | """Defines the Unet submodule with skip connection. 29 | X -------------------identity---------------------- 30 | |-- downsampling -- |submodule| -- upsampling --| 31 | """ 32 | 33 | def __init__( 34 | self, 35 | outer_nc, 36 | inner_nc, 37 | input_nc=None, 38 | submodule=None, 39 | outermost=False, 40 | innermost=False, 41 | norm_layer=nn.BatchNorm2d, 42 | use_dropout=False, 43 | ): 44 | """Construct a Unet submodule with skip connections. 45 | Parameters: 46 | outer_nc (int) -- the number of filters in the outer conv layer 47 | inner_nc (int) -- the number of filters in the inner conv layer 48 | input_nc (int) -- the number of channels in input images/features 49 | submodule (UnetSkipConnectionBlock) --previously defined submodules 50 | outermost (bool) -- if this module is the outermost module 51 | innermost (bool) -- if this module is the innermost module 52 | norm_layer -- normalization layer 53 | use_dropout (bool) -- if use dropout layers. 
54 | """ 55 | super(UnetSkipConnectionBlock, self).__init__() 56 | self.outermost = outermost 57 | self.innermost = innermost 58 | if type(norm_layer) == functools.partial: 59 | use_bias = norm_layer.func == nn.InstanceNorm2d 60 | else: 61 | use_bias = norm_layer == nn.InstanceNorm2d 62 | if input_nc is None: 63 | input_nc = outer_nc 64 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) 65 | downrelu = nn.LeakyReLU(0.2, True) 66 | downnorm = norm_layer(inner_nc) 67 | uprelu = nn.ReLU(True) 68 | upnorm = norm_layer(outer_nc) 69 | 70 | if outermost: 71 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1) 72 | # upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 73 | # upconv = DoubleConv(inner_nc * 2, outer_nc) 74 | up = [uprelu, upconv, nn.Tanh()] 75 | down = [downconv] 76 | self.down = nn.Sequential(*down) 77 | self.submodule = submodule 78 | self.up = nn.Sequential(*up) 79 | elif innermost: 80 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) 81 | # upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 82 | # upconv = DoubleConv(inner_nc * 2, outer_nc) 83 | down = [downrelu, downconv] 84 | up = [uprelu, upconv, upnorm] 85 | self.down = nn.Sequential(*down) 86 | self.up = nn.Sequential(*up) 87 | else: 88 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) 89 | # upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 90 | # upconv = DoubleConv(inner_nc * 2, outer_nc) 91 | down = [downrelu, downconv, downnorm] 92 | up = [uprelu, upconv, upnorm] 93 | if use_dropout: 94 | up += [nn.Dropout(0.5)] 95 | 96 | self.down = nn.Sequential(*down) 97 | self.submodule = submodule 98 | self.up = nn.Sequential(*up) 99 | 100 | def forward(self, x, noise): 101 | 102 | if self.outermost: 103 | return self.up(self.submodule(self.down(x), noise)) 104 | elif self.innermost: # add skip connections 105 | if noise is None: 106 | noise = torch.randn((1, 512, 8, 8)).cuda() * 0.0007 107 | return torch.cat((self.up(torch.cat((self.down(x), noise), dim=1)), x), dim=1) 108 | else: 109 | return torch.cat((self.up(self.submodule(self.down(x), noise)), x), dim=1) 110 | -------------------------------------------------------------------------------- /bkse/models/deblurring/image_deblur.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import utils.util as util 4 | from models.dips import ImageDIP, KernelDIP 5 | from models.kernel_encoding.kernel_wizard import KernelWizard 6 | from models.losses.hyper_laplacian_penalty import HyperLaplacianPenalty 7 | from models.losses.perceptual_loss import PerceptualLoss 8 | from models.losses.ssim_loss import SSIM 9 | from torch.optim.lr_scheduler import StepLR 10 | from tqdm import tqdm 11 | 12 | 13 | class ImageDeblur: 14 | def __init__(self, opt): 15 | self.opt = opt 16 | 17 | # losses 18 | self.ssim_loss = SSIM().cuda() 19 | self.mse = nn.MSELoss().cuda() 20 | self.perceptual_loss = PerceptualLoss().cuda() 21 | self.laplace_penalty = HyperLaplacianPenalty(3, 0.66).cuda() 22 | 23 | self.kernel_wizard = KernelWizard(opt["KernelWizard"]).cuda() 24 | self.kernel_wizard.load_state_dict(torch.load(opt["KernelWizard"]["pretrained"])) 25 | 26 | for k, v in self.kernel_wizard.named_parameters(): 27 | v.requires_grad = False 28 | 29 | def reset_optimizers(self): 
30 | self.x_optimizer = torch.optim.Adam(self.x_dip.parameters(), lr=self.opt["x_lr"]) 31 | self.k_optimizer = torch.optim.Adam(self.k_dip.parameters(), lr=self.opt["k_lr"]) 32 | 33 | self.x_scheduler = StepLR(self.x_optimizer, step_size=self.opt["num_iters"] // 5, gamma=0.7) 34 | 35 | self.k_scheduler = StepLR(self.k_optimizer, step_size=self.opt["num_iters"] // 5, gamma=0.7) 36 | 37 | def prepare_DIPs(self): 38 | # x is stand for the sharp image, k is stand for the kernel 39 | self.x_dip = ImageDIP(self.opt["ImageDIP"]).cuda() 40 | self.k_dip = KernelDIP(self.opt["KernelDIP"]).cuda() 41 | 42 | # fixed input vectors of DIPs 43 | # zk and zx are the length of the corresponding vectors 44 | self.dip_zk = util.get_noise(64, "noise", (64, 64)).cuda() 45 | self.dip_zx = util.get_noise(8, "noise", self.opt["img_size"]).cuda() 46 | 47 | def warmup(self, warmup_x, warmup_k): 48 | # Input vector of DIPs is sampled from N(z, I) 49 | reg_noise_std = self.opt["reg_noise_std"] 50 | 51 | for step in tqdm(range(self.opt["num_warmup_iters"])): 52 | self.x_optimizer.zero_grad() 53 | dip_zx_rand = self.dip_zx + reg_noise_std * torch.randn_like(self.dip_zx).cuda() 54 | x = self.x_dip(dip_zx_rand) 55 | 56 | loss = self.mse(x, warmup_x) 57 | loss.backward() 58 | self.x_optimizer.step() 59 | 60 | print("Warming up k DIP") 61 | for step in tqdm(range(self.opt["num_warmup_iters"])): 62 | self.k_optimizer.zero_grad() 63 | dip_zk_rand = self.dip_zk + reg_noise_std * torch.randn_like(self.dip_zk).cuda() 64 | k = self.k_dip(dip_zk_rand) 65 | 66 | loss = self.mse(k, warmup_k) 67 | loss.backward() 68 | self.k_optimizer.step() 69 | 70 | def deblur(self, img): 71 | pass 72 | -------------------------------------------------------------------------------- /bkse/models/deblurring/joint_deblur.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils.util as util 3 | from models.deblurring.image_deblur import ImageDeblur 4 | from tqdm import tqdm 5 | 6 | 7 | class JointDeblur(ImageDeblur): 8 | def __init__(self, opt): 9 | super(JointDeblur, self).__init__(opt) 10 | 11 | def deblur(self, y): 12 | """Deblur image 13 | Args: 14 | y: Blur image 15 | """ 16 | y = util.img2tensor(y).unsqueeze(0).cuda() 17 | 18 | self.prepare_DIPs() 19 | self.reset_optimizers() 20 | 21 | warmup_k = torch.load(self.opt["warmup_k_path"]).cuda() 22 | self.warmup(y, warmup_k) 23 | 24 | # Input vector of DIPs is sampled from N(z, I) 25 | 26 | print("Deblurring") 27 | reg_noise_std = self.opt["reg_noise_std"] 28 | for step in tqdm(range(self.opt["num_iters"])): 29 | dip_zx_rand = self.dip_zx + reg_noise_std * torch.randn_like(self.dip_zx).cuda() 30 | dip_zk_rand = self.dip_zk + reg_noise_std * torch.randn_like(self.dip_zk).cuda() 31 | 32 | self.x_optimizer.zero_grad() 33 | self.k_optimizer.zero_grad() 34 | 35 | self.x_scheduler.step() 36 | self.k_scheduler.step() 37 | 38 | x = self.x_dip(dip_zx_rand) 39 | k = self.k_dip(dip_zk_rand) 40 | 41 | fake_y = self.kernel_wizard.adaptKernel(x, k) 42 | 43 | if step < self.opt["num_iters"] // 2: 44 | total_loss = 6e-1 * self.perceptual_loss(fake_y, y) 45 | total_loss += 1 - self.ssim_loss(fake_y, y) 46 | total_loss += 5e-5 * torch.norm(k) 47 | total_loss += 2e-2 * self.laplace_penalty(x) 48 | else: 49 | total_loss = self.perceptual_loss(fake_y, y) 50 | total_loss += 5e-2 * self.laplace_penalty(x) 51 | total_loss += 5e-4 * torch.norm(k) 52 | 53 | total_loss.backward() 54 | 55 | self.x_optimizer.step() 56 | self.k_optimizer.step() 57 | 58 | # debugging 59 
| # if step % 100 == 0: 60 | # print(torch.norm(k)) 61 | # print(f"{self.k_optimizer.param_groups[0]['lr']:.3e}") 62 | 63 | return util.tensor2img(x.detach()) 64 | -------------------------------------------------------------------------------- /bkse/models/dips.py: -------------------------------------------------------------------------------- 1 | import models.arch_util as arch_util 2 | import torch.nn as nn 3 | from models.backbones.resnet import ResnetBlock 4 | from models.backbones.skip.skip import skip 5 | 6 | 7 | class KernelDIP(nn.Module): 8 | """ 9 | DIP (Deep Image Prior) for blur kernel 10 | """ 11 | 12 | def __init__(self, opt): 13 | super(KernelDIP, self).__init__() 14 | 15 | norm_layer = arch_util.get_norm_layer("none") 16 | n_blocks = opt["n_blocks"] 17 | nf = opt["nf"] 18 | padding_type = opt["padding_type"] 19 | use_dropout = opt["use_dropout"] 20 | kernel_dim = opt["kernel_dim"] 21 | 22 | input_nc = 64 23 | model = [ 24 | nn.ReflectionPad2d(3), 25 | nn.Conv2d(input_nc, nf, kernel_size=7, padding=0, bias=True), 26 | norm_layer(nf), 27 | nn.ReLU(True), 28 | ] 29 | 30 | n_downsampling = 5 31 | for i in range(n_downsampling): # add downsampling layers 32 | mult = 2 ** i 33 | input_nc = min(nf * mult, kernel_dim) 34 | output_nc = min(nf * mult * 2, kernel_dim) 35 | model += [ 36 | nn.Conv2d(input_nc, output_nc, kernel_size=3, stride=2, padding=1, bias=True), 37 | norm_layer(nf * mult * 2), 38 | nn.ReLU(True), 39 | ] 40 | 41 | for i in range(n_blocks): # add ResNet blocks 42 | model += [ 43 | ResnetBlock( 44 | kernel_dim, 45 | padding_type=padding_type, 46 | norm_layer=norm_layer, 47 | use_dropout=use_dropout, 48 | use_bias=True, 49 | ) 50 | ] 51 | 52 | self.model = nn.Sequential(*model) 53 | 54 | def forward(self, noise): 55 | return self.model(noise) 56 | 57 | 58 | class ImageDIP(nn.Module): 59 | """ 60 | DIP (Deep Image Prior) for sharp image 61 | """ 62 | 63 | def __init__(self, opt): 64 | super(ImageDIP, self).__init__() 65 | 66 | input_nc = opt["input_nc"] 67 | output_nc = opt["output_nc"] 68 | 69 | self.model = skip( 70 | input_nc, 71 | output_nc, 72 | num_channels_down=[128, 128, 128, 128, 128], 73 | num_channels_up=[128, 128, 128, 128, 128], 74 | num_channels_skip=[16, 16, 16, 16, 16], 75 | upsample_mode="bilinear", 76 | need_sigmoid=True, 77 | need_bias=True, 78 | pad=opt["padding_type"], 79 | act_fun="LeakyReLU", 80 | ) 81 | 82 | def forward(self, img): 83 | return self.model(img) 84 | -------------------------------------------------------------------------------- /bkse/models/dsd/bicubic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class BicubicDownSample(nn.Module): 7 | def bicubic_kernel(self, x, a=-0.50): 8 | """ 9 | This equation is exactly copied from the website below: 10 | https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic 11 | """ 12 | abs_x = torch.abs(x) 13 | if abs_x <= 1.0: 14 | return (a + 2.0) * torch.pow(abs_x, 3.0) - (a + 3.0) * torch.pow(abs_x, 2.0) + 1 15 | elif 1.0 < abs_x < 2.0: 16 | return a * torch.pow(abs_x, 3) - 5.0 * a * torch.pow(abs_x, 2.0) + 8.0 * a * abs_x - 4.0 * a 17 | else: 18 | return 0.0 19 | 20 | def __init__(self, factor=4, cuda=True, padding="reflect"): 21 | super().__init__() 22 | self.factor = factor 23 | size = factor * 4 24 | k = torch.tensor( 25 | [self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor) for i in range(size)], 26 | 
dtype=torch.float32, 27 | ) 28 | k = k / torch.sum(k) 29 | # k = torch.einsum('i,j->ij', (k, k)) 30 | k1 = torch.reshape(k, shape=(1, 1, size, 1)) 31 | self.k1 = torch.cat([k1, k1, k1], dim=0) 32 | k2 = torch.reshape(k, shape=(1, 1, 1, size)) 33 | self.k2 = torch.cat([k2, k2, k2], dim=0) 34 | self.cuda = ".cuda" if cuda else "" 35 | self.padding = padding 36 | for param in self.parameters(): 37 | param.requires_grad = False 38 | 39 | def forward(self, x, nhwc=False, clip_round=False, byte_output=False): 40 | # x = torch.from_numpy(x).type('torch.FloatTensor') 41 | filter_height = self.factor * 4 42 | filter_width = self.factor * 4 43 | stride = self.factor 44 | 45 | pad_along_height = max(filter_height - stride, 0) 46 | pad_along_width = max(filter_width - stride, 0) 47 | filters1 = self.k1.type("torch{}.FloatTensor".format(self.cuda)) 48 | filters2 = self.k2.type("torch{}.FloatTensor".format(self.cuda)) 49 | 50 | # compute actual padding values for each side 51 | pad_top = pad_along_height // 2 52 | pad_bottom = pad_along_height - pad_top 53 | pad_left = pad_along_width // 2 54 | pad_right = pad_along_width - pad_left 55 | 56 | # apply mirror padding 57 | if nhwc: 58 | x = torch.transpose(torch.transpose(x, 2, 3), 1, 2) # NHWC to NCHW 59 | 60 | # downscaling performed by 1-d convolution 61 | x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding) 62 | x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3) 63 | if clip_round: 64 | x = torch.clamp(torch.round(x), 0.0, 255.0) 65 | 66 | x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding) 67 | x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3) 68 | if clip_round: 69 | x = torch.clamp(torch.round(x), 0.0, 255.0) 70 | 71 | if nhwc: 72 | x = torch.transpose(torch.transpose(x, 1, 3), 1, 2) 73 | if byte_output: 74 | return x.type("torch.{}.ByteTensor".format(self.cuda)) 75 | else: 76 | return x 77 | -------------------------------------------------------------------------------- /bkse/models/dsd/dsd_stylegan.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from models.dsd.dsd import DSD 5 | from models.dsd.stylegan import G_mapping, G_synthesis 6 | from models.losses.dsd_loss import LossBuilderStyleGAN 7 | 8 | 9 | class DSDStyleGAN(DSD): 10 | def __init__(self, opt, cache_dir): 11 | super(DSDStyleGAN, self).__init__(opt, cache_dir) 12 | 13 | def load_synthesis_network(self): 14 | self.synthesis = G_synthesis().cuda() 15 | self.synthesis.load_state_dict(torch.load("experiments/pretrained/stylegan_synthesis.pt")) 16 | for v in self.synthesis.parameters(): 17 | v.requires_grad = False 18 | 19 | def initialize_mapping_network(self): 20 | if Path("experiments/pretrained/gaussian_fit_stylegan.pt").exists(): 21 | self.gaussian_fit = torch.load("experiments/pretrained/gaussian_fit_stylegan.pt") 22 | else: 23 | if self.verbose: 24 | print("\tRunning Mapping Network") 25 | 26 | mapping = G_mapping().cuda() 27 | mapping.load_state_dict(torch.load("experiments/pretrained/stylegan_mapping.pt")) 28 | with torch.no_grad(): 29 | torch.manual_seed(0) 30 | latent = torch.randn((1000000, 512), dtype=torch.float32, device="cuda") 31 | latent_out = torch.nn.LeakyReLU(5)(mapping(latent)) 32 | self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)} 33 | torch.save(self.gaussian_fit, "experiments/pretrained/gaussian_fit_stylegan.pt") 34 | if self.verbose: 35 | print('\tSaved "gaussian_fit_stylegan.pt"') 36 | 37 | def 
initialize_latent_space(self): 38 | batch_size = self.opt["batch_size"] 39 | 40 | # Generate latent tensor 41 | if self.opt["tile_latent"]: 42 | self.latent = torch.randn((batch_size, 1, 512), dtype=torch.float, requires_grad=True, device="cuda") 43 | else: 44 | self.latent = torch.randn((batch_size, 18, 512), dtype=torch.float, requires_grad=True, device="cuda") 45 | 46 | # Generate list of noise tensors 47 | noise = [] # stores all of the noise tensors 48 | noise_vars = [] # stores the noise tensors that we want to optimize on 49 | 50 | noise_type = self.opt["noise_type"] 51 | bad_noise_layers = self.opt["bad_noise_layers"] 52 | for i in range(18): 53 | # dimension of the ith noise tensor 54 | res = (batch_size, 1, 2 ** (i // 2 + 2), 2 ** (i // 2 + 2)) 55 | 56 | if noise_type == "zero" or i in [int(layer) for layer in bad_noise_layers.split(".")]: 57 | new_noise = torch.zeros(res, dtype=torch.float, device="cuda") 58 | new_noise.requires_grad = False 59 | elif noise_type == "fixed": 60 | new_noise = torch.randn(res, dtype=torch.float, device="cuda") 61 | new_noise.requires_grad = False 62 | elif noise_type == "trainable": 63 | new_noise = torch.randn(res, dtype=torch.float, device="cuda") 64 | if i < self.opt["num_trainable_noise_layers"]: 65 | new_noise.requires_grad = True 66 | noise_vars.append(new_noise) 67 | else: 68 | new_noise.requires_grad = False 69 | else: 70 | raise Exception("unknown noise type") 71 | 72 | noise.append(new_noise) 73 | 74 | self.latent_x_var_list = [self.latent] + noise_vars 75 | self.noise = noise 76 | 77 | def initialize_loss(self, ref_im): 78 | self.loss_builder = LossBuilderStyleGAN(ref_im, self.opt).cuda() 79 | 80 | def get_gen_im(self, latent_in): 81 | return (self.synthesis(latent_in, self.noise) + 1) / 2 82 | -------------------------------------------------------------------------------- /bkse/models/dsd/dsd_stylegan2.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from models.dsd.dsd import DSD 5 | from models.dsd.stylegan2 import Generator 6 | from models.losses.dsd_loss import LossBuilderStyleGAN2 7 | 8 | 9 | class DSDStyleGAN2(DSD): 10 | def __init__(self, opt, cache_dir): 11 | super(DSDStyleGAN2, self).__init__(opt, cache_dir) 12 | 13 | def load_synthesis_network(self): 14 | self.synthesis = Generator(size=256, style_dim=512, n_mlp=8).cuda() 15 | self.synthesis.load_state_dict(torch.load("experiments/pretrained/stylegan2.pt")["g_ema"], strict=False) 16 | for v in self.synthesis.parameters(): 17 | v.requires_grad = False 18 | 19 | def initialize_mapping_network(self): 20 | if Path("experiments/pretrained/gaussian_fit_stylegan2.pt").exists(): 21 | self.gaussian_fit = torch.load("experiments/pretrained/gaussian_fit_stylegan2.pt") 22 | else: 23 | if self.verbose: 24 | print("\tRunning Mapping Network") 25 | with torch.no_grad(): 26 | torch.manual_seed(0) 27 | latent = torch.randn((1000000, 512), dtype=torch.float32, device="cuda") 28 | latent_out = torch.nn.LeakyReLU(5)(self.synthesis.get_latent(latent)) 29 | self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)} 30 | torch.save(self.gaussian_fit, "experiments/pretrained/gaussian_fit_stylegan2.pt") 31 | if self.verbose: 32 | print('\tSaved "gaussian_fit_stylegan2.pt"') 33 | 34 | def initialize_latent_space(self): 35 | batch_size = self.opt["batch_size"] 36 | 37 | # Generate latent tensor 38 | if self.opt["tile_latent"]: 39 | self.latent = torch.randn((batch_size, 1, 512), dtype=torch.float, 
requires_grad=True, device="cuda") 40 | else: 41 | self.latent = torch.randn((batch_size, 14, 512), dtype=torch.float, requires_grad=True, device="cuda") 42 | 43 | # Generate list of noise tensors 44 | noise = [] # stores all of the noise tensors 45 | noise_vars = [] # stores the noise tensors that we want to optimize on 46 | 47 | for i in range(14): 48 | res = (i + 5) // 2 49 | res = [1, 1, 2 ** res, 2 ** res] 50 | 51 | noise_type = self.opt["noise_type"] 52 | bad_noise_layers = self.opt["bad_noise_layers"] 53 | if noise_type == "zero" or i in [int(layer) for layer in bad_noise_layers.split(".")]: 54 | new_noise = torch.zeros(res, dtype=torch.float, device="cuda") 55 | new_noise.requires_grad = False 56 | elif noise_type == "fixed": 57 | new_noise = torch.randn(res, dtype=torch.float, device="cuda") 58 | new_noise.requires_grad = False 59 | elif noise_type == "trainable": 60 | new_noise = torch.randn(res, dtype=torch.float, device="cuda") 61 | if i < self.opt["num_trainable_noise_layers"]: 62 | new_noise.requires_grad = True 63 | noise_vars.append(new_noise) 64 | else: 65 | new_noise.requires_grad = False 66 | else: 67 | raise Exception("unknown noise type") 68 | 69 | noise.append(new_noise) 70 | 71 | self.latent_x_var_list = [self.latent] + noise_vars 72 | self.noise = noise 73 | 74 | def initialize_loss(self, ref_im): 75 | self.loss_builder = LossBuilderStyleGAN2(ref_im, self.opt).cuda() 76 | 77 | def get_gen_im(self, latent_in): 78 | return (self.synthesis([latent_in], input_is_latent=True, noise=self.noise)[0] + 1) / 2 79 | -------------------------------------------------------------------------------- /bkse/models/dsd/op/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/bkse/models/dsd/op/__init__.py -------------------------------------------------------------------------------- /bkse/models/dsd/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.nn import functional as F 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 30 | 31 | dim = [0] 32 | 33 | if grad_input.ndim > 2: 34 | dim += list(range(2, grad_input.ndim)) 35 | 36 | if bias: 37 | grad_bias = grad_input.sum(dim).detach() 38 | 39 | else: 40 | grad_bias = None 41 | 42 | return grad_input, grad_bias 43 | 44 | @staticmethod 45 | def backward(ctx, gradgrad_input, gradgrad_bias): 46 | (out,) = ctx.saved_tensors 47 | gradgrad_out = fused.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale) 48 | 49 | return gradgrad_out, None, None, None, None 50 | 51 | 52 | class FusedLeakyReLUFunction(Function): 53 | @staticmethod 54 | def forward(ctx, input, bias, negative_slope, scale): 55 | empty = 
input.new_empty(0) 56 | 57 | if bias is None: 58 | bias = empty 59 | 60 | ctx.bias = bias is not None 61 | 62 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 63 | ctx.save_for_backward(out) 64 | ctx.negative_slope = negative_slope 65 | ctx.scale = scale 66 | 67 | return out 68 | 69 | @staticmethod 70 | def backward(ctx, grad_output): 71 | (out,) = ctx.saved_tensors 72 | 73 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 74 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 75 | ) 76 | 77 | return grad_input, grad_bias, None, None 78 | 79 | 80 | class FusedLeakyReLU(nn.Module): 81 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 82 | super().__init__() 83 | 84 | if bias: 85 | self.bias = nn.Parameter(torch.zeros(channel)) 86 | 87 | else: 88 | self.bias = None 89 | 90 | self.negative_slope = negative_slope 91 | self.scale = scale 92 | 93 | def forward(self, input): 94 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 95 | 96 | 97 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 98 | if input.device.type == "cpu": 99 | if bias is not None: 100 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 101 | return F.leaky_relu(input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2) * scale 102 | 103 | else: 104 | return F.leaky_relu(input, negative_slope=0.2) * scale 105 | 106 | else: 107 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 108 | -------------------------------------------------------------------------------- /bkse/models/dsd/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /bkse/models/dsd/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /bkse/models/dsd/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } 
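The op sources above are JIT-compiled through torch.utils.cpp_extension.load the first time the Python wrappers are imported (fused_act.py shows this pattern). A minimal smoke test for the fused activation -- a sketch only, assuming a CUDA device, a working nvcc toolchain, and that it is run from inside bkse/ so the models.dsd.op import resolves -- could look like:

import torch
from models.dsd.op.fused_act import FusedLeakyReLU, fused_leaky_relu

x = torch.randn(2, 8, 16, 16, device="cuda")
act = FusedLeakyReLU(channel=8).cuda()  # learnable per-channel bias, slope 0.2, scale sqrt(2)
y = act(x)
# the functional form with an explicit zero bias should agree with the module
y_ref = fused_leaky_relu(x, torch.zeros(8, device="cuda"))
print(y.shape, torch.allclose(y, y_ref))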
-------------------------------------------------------------------------------- /bkse/models/dsd/spherical_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | 5 | # Spherical Optimizer Class 6 | # Uses the first two dimensions as batch information 7 | # Optimizes over the surface of a sphere using the initial radius throughout 8 | # 9 | # Example Usage: 10 | # opt = SphericalOptimizer(torch.optim.SGD, [x], lr=0.01) 11 | 12 | 13 | class SphericalOptimizer(Optimizer): 14 | def __init__(self, optimizer, params, **kwargs): 15 | self.opt = optimizer(params, **kwargs) 16 | self.params = params 17 | with torch.no_grad(): 18 | self.radii = { 19 | param: (param.pow(2).sum(tuple(range(2, param.ndim)), keepdim=True) + 1e-9).sqrt() for param in params 20 | } 21 | 22 | @torch.no_grad() 23 | def step(self, closure=None): 24 | loss = self.opt.step(closure) 25 | for param in self.params: 26 | param.data.div_((param.pow(2).sum(tuple(range(2, param.ndim)), keepdim=True) + 1e-9).sqrt()) 27 | param.mul_(self.radii[param]) 28 | 29 | return loss 30 | -------------------------------------------------------------------------------- /bkse/models/losses/charbonnier_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CharbonnierLoss(nn.Module): 6 | """Charbonnier Loss (L1)""" 7 | 8 | def __init__(self, eps=1e-6): 9 | super(CharbonnierLoss, self).__init__() 10 | self.eps = eps 11 | 12 | def forward(self, x, y): 13 | diff = x - y 14 | loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 15 | return loss 16 | -------------------------------------------------------------------------------- /bkse/models/losses/dsd_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.dsd.bicubic import BicubicDownSample 3 | from models.kernel_encoding.kernel_wizard import KernelWizard 4 | from models.losses.ssim_loss import SSIM 5 | 6 | 7 | class LossBuilder(torch.nn.Module): 8 | def __init__(self, ref_im, opt): 9 | super(LossBuilder, self).__init__() 10 | assert ref_im.shape[2] == ref_im.shape[3] 11 | self.ref_im = ref_im 12 | loss_str = opt["loss_str"] 13 | self.parsed_loss = [loss_term.split("*") for loss_term in loss_str.split("+")] 14 | self.eps = opt["eps"] 15 | 16 | self.ssim = SSIM().cuda() 17 | 18 | self.D = KernelWizard(opt["KernelWizard"]).cuda() 19 | self.D.load_state_dict(torch.load(opt["KernelWizard"]["pretrained"])) 20 | for v in self.D.parameters(): 21 | v.requires_grad = False 22 | 23 | # Takes a list of tensors, flattens them, and concatenates them into a vector 24 | # Used to calculate euclidian distance between lists of tensors 25 | def flatcat(self, l): 26 | l = l if (isinstance(l, list)) else [l] 27 | return torch.cat([x.flatten() for x in l], dim=0) 28 | 29 | def _loss_l2(self, gen_im_lr, ref_im, **kwargs): 30 | return (gen_im_lr - ref_im).pow(2).mean((1, 2, 3)).clamp(min=self.eps).sum() 31 | 32 | def _loss_l1(self, gen_im_lr, ref_im, **kwargs): 33 | return 10 * ((gen_im_lr - ref_im).abs().mean((1, 2, 3)).clamp(min=self.eps).sum()) 34 | 35 | # Uses geodesic distance on sphere to sum pairwise distances of the 18 vectors 36 | def _loss_geocross(self, latent, **kwargs): 37 | pass 38 | 39 | 40 | class LossBuilderStyleGAN(LossBuilder): 41 | def __init__(self, ref_im, opt): 42 | super(LossBuilderStyleGAN, self).__init__(ref_im, opt) 43 | im_size = 
ref_im.shape[2] 44 | factor = opt["output_size"] // im_size 45 | assert im_size * factor == opt["output_size"] 46 | self.bicub = BicubicDownSample(factor=factor) 47 | 48 | # Uses geodesic distance on sphere to sum pairwise distances of the 18 vectors 49 | def _loss_geocross(self, latent, **kwargs): 50 | if latent.shape[1] == 1: 51 | return 0 52 | else: 53 | X = latent.view(-1, 1, 18, 512) 54 | Y = latent.view(-1, 18, 1, 512) 55 | A = ((X - Y).pow(2).sum(-1) + 1e-9).sqrt() 56 | B = ((X + Y).pow(2).sum(-1) + 1e-9).sqrt() 57 | D = 2 * torch.atan2(A, B) 58 | D = ((D.pow(2) * 512).mean((1, 2)) / 8.0).sum() 59 | return D 60 | 61 | def forward(self, latent, gen_im, kernel, step): 62 | var_dict = { 63 | "latent": latent, 64 | "gen_im_lr": self.D.adaptKernel(self.bicub(gen_im), kernel), 65 | "ref_im": self.ref_im, 66 | } 67 | loss = 0 68 | loss_fun_dict = { 69 | "L2": self._loss_l2, 70 | "L1": self._loss_l1, 71 | "GEOCROSS": self._loss_geocross, 72 | } 73 | losses = {} 74 | 75 | for weight, loss_type in self.parsed_loss: 76 | tmp_loss = loss_fun_dict[loss_type](**var_dict) 77 | losses[loss_type] = tmp_loss 78 | loss += float(weight) * tmp_loss 79 | loss += 5e-5 * torch.norm(kernel) 80 | losses["Norm"] = torch.norm(kernel) 81 | 82 | return loss, losses 83 | 84 | def get_blur_img(self, sharp_img, kernel): 85 | return self.D.adaptKernel(self.bicub(sharp_img), kernel).cpu().detach().clamp(0, 1) 86 | 87 | 88 | class LossBuilderStyleGAN2(LossBuilder): 89 | def __init__(self, ref_im, opt): 90 | super(LossBuilderStyleGAN2, self).__init__(ref_im, opt) 91 | 92 | # Uses geodesic distance on sphere to sum pairwise distances of the 18 vectors 93 | def _loss_geocross(self, latent, **kwargs): 94 | if latent.shape[1] == 1: 95 | return 0 96 | else: 97 | X = latent.view(-1, 1, 14, 512) 98 | Y = latent.view(-1, 14, 1, 512) 99 | A = ((X - Y).pow(2).sum(-1) + 1e-9).sqrt() 100 | B = ((X + Y).pow(2).sum(-1) + 1e-9).sqrt() 101 | D = 2 * torch.atan2(A, B) 102 | D = ((D.pow(2) * 512).mean((1, 2)) / 6.0).sum() 103 | return D 104 | 105 | def forward(self, latent, gen_im, kernel, step): 106 | var_dict = { 107 | "latent": latent, 108 | "gen_im_lr": self.D.adaptKernel(gen_im, kernel), 109 | "ref_im": self.ref_im, 110 | } 111 | loss = 0 112 | loss_fun_dict = { 113 | "L2": self._loss_l2, 114 | "L1": self._loss_l1, 115 | "GEOCROSS": self._loss_geocross, 116 | } 117 | losses = {} 118 | 119 | for weight, loss_type in self.parsed_loss: 120 | tmp_loss = loss_fun_dict[loss_type](**var_dict) 121 | losses[loss_type] = tmp_loss 122 | loss += float(weight) * tmp_loss 123 | loss += 1e-4 * torch.norm(kernel) 124 | losses["Norm"] = torch.norm(kernel) 125 | 126 | return loss, losses 127 | 128 | def get_blur_img(self, sharp_img, kernel): 129 | return self.D.adaptKernel(sharp_img, kernel).cpu().detach().clamp(0, 1) 130 | -------------------------------------------------------------------------------- /bkse/models/losses/gan_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Define GAN loss: [vanilla | lsgan | wgan-gp] 6 | class GANLoss(nn.Module): 7 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 8 | super(GANLoss, self).__init__() 9 | self.gan_type = gan_type.lower() 10 | self.real_label_val = real_label_val 11 | self.fake_label_val = fake_label_val 12 | 13 | if self.gan_type == "gan" or self.gan_type == "ragan": 14 | self.loss = nn.BCEWithLogitsLoss() 15 | elif self.gan_type == "lsgan": 16 | self.loss = nn.MSELoss() 17 | elif 
self.gan_type == "wgan-gp": 18 | 19 | def wgan_loss(input, target): 20 | # target is boolean 21 | return -1 * input.mean() if target else input.mean() 22 | 23 | self.loss = wgan_loss 24 | else: 25 | raise NotImplementedError("GAN type [{:s}] is not found".format(self.gan_type)) 26 | 27 | def get_target_label(self, input, target_is_real): 28 | if self.gan_type == "wgan-gp": 29 | return target_is_real 30 | if target_is_real: 31 | return torch.empty_like(input).fill_(self.real_label_val) 32 | else: 33 | return torch.empty_like(input).fill_(self.fake_label_val) 34 | 35 | def forward(self, input, target_is_real): 36 | target_label = self.get_target_label(input, target_is_real) 37 | loss = self.loss(input, target_label) 38 | return loss 39 | -------------------------------------------------------------------------------- /bkse/models/losses/hyper_laplacian_penalty.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class HyperLaplacianPenalty(nn.Module): 7 | def __init__(self, num_channels, alpha, eps=1e-6): 8 | super(HyperLaplacianPenalty, self).__init__() 9 | 10 | self.alpha = alpha 11 | self.eps = eps 12 | 13 | self.Kx = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).cuda() 14 | self.Kx = self.Kx.expand(1, num_channels, 3, 3) 15 | self.Kx.requires_grad = False 16 | self.Ky = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).cuda() 17 | self.Ky = self.Ky.expand(1, num_channels, 3, 3) 18 | self.Ky.requires_grad = False 19 | 20 | def forward(self, x): 21 | gradX = F.conv2d(x, self.Kx, stride=1, padding=1) 22 | gradY = F.conv2d(x, self.Ky, stride=1, padding=1) 23 | grad = torch.sqrt(gradX ** 2 + gradY ** 2 + self.eps) 24 | 25 | loss = (grad ** self.alpha).mean() 26 | 27 | return loss 28 | -------------------------------------------------------------------------------- /bkse/models/losses/ssim_loss.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | 8 | class SSIM(torch.nn.Module): 9 | @staticmethod 10 | def gaussian(window_size, sigma): 11 | gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)]) 12 | return gauss / gauss.sum() 13 | 14 | @staticmethod 15 | def create_window(window_size, channel): 16 | _1D_window = SSIM.gaussian(window_size, 1.5).unsqueeze(1) 17 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 18 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 19 | return window 20 | 21 | @staticmethod 22 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 23 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 24 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 25 | 26 | mu1_sq = mu1.pow(2) 27 | mu2_sq = mu2.pow(2) 28 | mu1_mu2 = mu1 * mu2 29 | 30 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 31 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 32 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 33 | 34 | C1 = 0.01 ** 2 35 | C2 = 0.03 ** 2 36 | 37 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 38 | 39 | if 
size_average: 40 | return ssim_map.mean() 41 | else: 42 | return ssim_map.mean(1).mean(1).mean(1) 43 | 44 | def __init__(self, window_size=11, size_average=True): 45 | super(SSIM, self).__init__() 46 | self.window_size = window_size 47 | self.size_average = size_average 48 | self.channel = 1 49 | self.window = self.create_window(window_size, self.channel) 50 | 51 | def forward(self, img1, img2): 52 | (_, channel, _, _) = img1.size() 53 | 54 | if channel == self.channel and self.window.data.type() == img1.data.type(): 55 | window = self.window 56 | else: 57 | window = self.create_window(self.window_size, channel) 58 | 59 | if img1.is_cuda: 60 | window = window.cuda(img1.get_device()) 61 | window = window.type_as(img1) 62 | 63 | self.window = window 64 | self.channel = channel 65 | 66 | return self._ssim(img1, img2, window, self.window_size, channel, self.size_average) 67 | -------------------------------------------------------------------------------- /bkse/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/bkse/options/__init__.py -------------------------------------------------------------------------------- /bkse/options/data_augmentation/default.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | gpu_ids: [0] 3 | 4 | #### network structures 5 | KernelWizard: 6 | pretrained: experiments/pretrained/GOPRO_woVAE.pth 7 | input_nc: 3 8 | nf: 64 9 | front_RBs: 10 10 | back_RBs: 20 11 | N_frames: 1 12 | kernel_dim: 512 13 | use_vae: false 14 | KernelExtractor: 15 | norm: none 16 | use_sharp: true 17 | n_blocks: 4 18 | padding_type: reflect 19 | use_dropout: false 20 | Adapter: 21 | norm: none 22 | use_dropout: false 23 | -------------------------------------------------------------------------------- /bkse/options/domain_specific_deblur/stylegan.yml: -------------------------------------------------------------------------------- 1 | stylegan_ver: 1 2 | img_size: &HQ_SIZE [256, 256] 3 | output_size: 1024 4 | verbose: true 5 | num_epochs: 25 6 | num_warmup_iters: 150 7 | num_x_iters: 300 8 | num_k_iters: 200 9 | x_lr: !!float 0.2 10 | k_lr: !!float 1e-4 11 | warmup_k_path: experiments/pretrained/kernel.pth 12 | reg_noise_std: !!float 0.001 13 | duplicates: 1 14 | batch_size: 1 15 | loss_str: '100*L2+0.1*GEOCROSS' 16 | eps: !!float 1e-15 17 | noise_type: trainable 18 | num_trainable_noise_layers: 5 19 | bad_noise_layers: '17' 20 | optimizer_name: adam 21 | lr_schedule: linear1cycledrop 22 | save_intermediate: true 23 | tile_latent: ~ 24 | seed: ~ 25 | 26 | KernelDIP: 27 | nf: 64 28 | n_blocks: 6 29 | padding_type: reflect 30 | use_dropout: false 31 | kernel_dim: 512 32 | norm: none 33 | 34 | KernelWizard: 35 | pretrained: experiments/pretrained/GOPRO_woVAE.pth 36 | input_nc: 3 37 | nf: 64 38 | front_RBs: 10 39 | back_RBs: 20 40 | N_frames: 1 41 | kernel_dim: 512 42 | img_size: *HQ_SIZE 43 | use_vae: false 44 | KernelExtractor: 45 | norm: none 46 | use_sharp: true 47 | n_blocks: 4 48 | padding_type: reflect 49 | use_dropout: false 50 | Adapter: 51 | norm: none 52 | use_dropout: false 53 | -------------------------------------------------------------------------------- /bkse/options/domain_specific_deblur/stylegan2.yml: -------------------------------------------------------------------------------- 1 | stylegan_ver: 2 2 | img_size: &HQ_SIZE [256, 256] 3 | output_size: 256 4 | verbose: true 5 | 
num_epochs: 25 6 | num_warmup_iters: 150 7 | num_x_iters: 300 8 | num_k_iters: 200 9 | x_lr: !!float 0.2 10 | k_lr: !!float 5e-4 11 | warmup_k_path: experiments/pretrained/kernel.pth 12 | reg_noise_std: !!float 0.001 13 | duplicates: 1 14 | batch_size: 1 15 | loss_str: '100*L2+0.1*GEOCROSS' 16 | eps: !!float 1e-15 17 | noise_type: trainable 18 | num_trainable_noise_layers: 5 19 | bad_noise_layers: '17' 20 | optimizer_name: adam 21 | lr_schedule: linear1cycledrop 22 | save_intermediate: true 23 | tile_latent: ~ 24 | seed: ~ 25 | 26 | ImageDIP: 27 | input_nc: 8 28 | output_nc: 3 29 | nf: 64 30 | norm: none 31 | padding_type: reflect 32 | 33 | KernelDIP: 34 | nf: 64 35 | n_blocks: 6 36 | padding_type: reflect 37 | use_dropout: false 38 | kernel_dim: 512 39 | norm: none 40 | 41 | KernelWizard: 42 | pretrained: experiments/pretrained/GOPRO_woVAE.pth 43 | input_nc: 3 44 | nf: 64 45 | front_RBs: 10 46 | back_RBs: 20 47 | N_frames: 1 48 | kernel_dim: 512 49 | img_size: *HQ_SIZE 50 | use_vae: false 51 | KernelExtractor: 52 | norm: none 53 | use_sharp: true 54 | n_blocks: 4 55 | padding_type: reflect 56 | use_dropout: false 57 | Adapter: 58 | norm: none 59 | use_dropout: false 60 | -------------------------------------------------------------------------------- /bkse/options/generate_blur/default.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | gpu_ids: [0] 3 | 4 | #### network structures 5 | KernelWizard: 6 | #pretrained: experiments/pretrained/GOPRO_wVAE.pth 7 | #pretrained: /home/mmardani/research/stable-diffusion-sampling-gitlab/pgdm/bkse/experiments/pretrained/kernel.pth 8 | #pretrained: /lustre/fsw/nvresearch/mmardani/source/latent-diffusion-sampling/pgdm/bkse/experiments/pretrained/kernel.pth 9 | pretrained: bkse/experiments/pretrained/GOPRO_wVAE.pth 10 | input_nc: 3 11 | nf: 64 12 | front_RBs: 10 13 | back_RBs: 20 14 | N_frames: 1 15 | kernel_dim: 512 16 | use_vae: true 17 | KernelExtractor: 18 | norm: none 19 | use_sharp: true 20 | n_blocks: 4 21 | padding_type: reflect 22 | use_dropout: false 23 | Adapter: 24 | norm: none 25 | use_dropout: false 26 | -------------------------------------------------------------------------------- /bkse/options/generic_deblur/default.yml: -------------------------------------------------------------------------------- 1 | num_iters: 5000 2 | num_warmup_iters: 300 3 | x_lr: !!float 5e-4 4 | k_lr: !!float 5e-4 5 | img_size: &HQ_SIZE [256, 256] 6 | warmup_k_path: bkse/experiments/pretrained/kernel.pth 7 | reg_noise_std: !!float 0.001 8 | 9 | ImageDIP: 10 | input_nc: 8 11 | output_nc: 3 12 | nf: 64 13 | norm: none 14 | padding_type: reflect 15 | 16 | KernelDIP: 17 | nf: 64 18 | n_blocks: 6 19 | padding_type: reflect 20 | use_dropout: false 21 | kernel_dim: 512 22 | norm: none 23 | 24 | KernelWizard: 25 | pretrained: bkse/experiments/pretrained/GOPRO_woVAE.pth 26 | input_nc: 3 27 | nf: 64 28 | front_RBs: 10 29 | back_RBs: 20 30 | N_frames: 1 31 | kernel_dim: 512 32 | img_size: *HQ_SIZE 33 | use_vae: false 34 | KernelExtractor: 35 | norm: none 36 | use_sharp: true 37 | n_blocks: 4 38 | padding_type: reflect 39 | use_dropout: false 40 | Adapter: 41 | norm: none 42 | use_dropout: false 43 | -------------------------------------------------------------------------------- /bkse/options/kernel_encoding/GOPRO/wVAE.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: GOPRO_VAE 3 | use_tb_logger: true 4 | model: image_base 5 | 
distortion: deblur 6 | scale: 1 7 | gpu_ids: [0] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: GOPRO 13 | mode: GOPRO 14 | interval_list: [1] 15 | dataroot_HQ: datasets/GOPRO/train_sharp.lmdb 16 | dataroot_LQ: datasets/GOPRO/train_blur_linear.lmdb 17 | cache_keys: ~ 18 | 19 | use_shuffle: true 20 | n_workers: 4 # per GPU 21 | batch_size: 8 22 | HQ_size: &HQ_SIZE 256 23 | LQ_size: 256 24 | use_flip: true 25 | use_rot: true 26 | color: RGB 27 | 28 | #### network structures 29 | KernelWizard: 30 | input_nc: 3 31 | nf: 64 32 | front_RBs: 10 33 | back_RBs: 20 34 | N_frames: 1 35 | kernel_dim: 512 36 | img_size: *HQ_SIZE 37 | use_vae: true 38 | KernelExtractor: 39 | norm: none 40 | use_sharp: true 41 | n_blocks: 4 42 | padding_type: reflect 43 | use_dropout: false 44 | Adapter: 45 | norm: none 46 | use_dropout: false 47 | 48 | #### path 49 | path: 50 | pretrain_model_G: experiments/pretrained/GOPRO_wsharp_woVAE.pth 51 | strict_load: false 52 | resume_state: ~ 53 | 54 | #### training settings: learning rate scheme, loss 55 | train: 56 | lr_G: !!float 1e-4 57 | lr_scheme: CosineAnnealingLR_Restart 58 | beta1: 0.9 59 | beta2: 0.99 60 | niter: 600000 61 | warmup_iter: -1 # -1: no warm up 62 | T_period: [50000, 100000, 150000, 150000, 150000] 63 | restarts: [50000, 150000, 300000, 450000] 64 | restart_weights: [1, 1, 1, 1] 65 | eta_min: !!float 1e-8 66 | 67 | pixel_criterion: cb 68 | pixel_weight: !!float 1.0 69 | kl_weight: !!float 10.0 70 | val_freq: !!float 5e3 71 | 72 | manual_seed: 0 73 | 74 | #### logger 75 | logger: 76 | print_freq: 10 77 | save_checkpoint_freq: !!float 5e3 78 | -------------------------------------------------------------------------------- /bkse/options/kernel_encoding/GOPRO/woVAE.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: GOPRO_woVAE 3 | use_tb_logger: true 4 | model: image_base 5 | distortion: deblur 6 | scale: 1 7 | gpu_ids: [0] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: GOPRO 13 | mode: GOPRO 14 | interval_list: [1] 15 | dataroot_HQ: datasets/GOPRO/train_sharp.lmdb 16 | dataroot_LQ: datasets/GOPRO/train_blur_linear.lmdb 17 | cache_keys: ~ 18 | 19 | use_shuffle: true 20 | n_workers: 4 # per GPU 21 | batch_size: 16 22 | HQ_size: &HQ_SIZE 256 23 | LQ_size: 256 24 | use_flip: true 25 | use_rot: true 26 | color: RGB 27 | 28 | #### network structures 29 | KernelWizard: 30 | input_nc: 3 31 | nf: 64 32 | front_RBs: 10 33 | back_RBs: 20 34 | N_frames: 1 35 | kernel_dim: 512 36 | img_size: *HQ_SIZE 37 | use_vae: false 38 | KernelExtractor: 39 | norm: none 40 | use_sharp: true 41 | n_blocks: 4 42 | padding_type: reflect 43 | use_dropout: false 44 | Adapter: 45 | norm: none 46 | use_dropout: false 47 | 48 | #### path 49 | path: 50 | pretrain_model_G: ~ 51 | strict_load: false 52 | resume_state: ~ 53 | 54 | #### training settings: learning rate scheme, loss 55 | train: 56 | lr_G: !!float 1e-4 57 | lr_scheme: CosineAnnealingLR_Restart 58 | beta1: 0.9 59 | beta2: 0.99 60 | niter: 600000 61 | warmup_iter: -1 # -1: no warm up 62 | T_period: [50000, 100000, 150000, 150000, 150000] 63 | restarts: [50000, 150000, 300000, 450000] 64 | restart_weights: [1, 1, 1, 1] 65 | eta_min: !!float 1e-8 66 | 67 | pixel_criterion: cb 68 | pixel_weight: 1.0 69 | kl_weight: 0.0 70 | val_freq: !!float 5e3 71 | 72 | manual_seed: 0 73 | 74 | #### logger 75 | logger: 76 | print_freq: 10 77 | save_checkpoint_freq: !!float 5e3 78 | 
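The kernel-encoding configs above are plain YAML read by the option parser in `bkse/options/options.py` (included further below in this dump). A minimal loading sketch, assuming the working directory is `bkse/` so the relative imports and paths resolve:

```python
# Sketch: parse one of the kernel-encoding configs (assumes cwd is bkse/).
from options.options import dict_to_nonedict, parse

opt = parse("options/kernel_encoding/GOPRO/woVAE.yml", is_train=True)
opt = dict_to_nonedict(opt)  # missing keys now resolve to None instead of raising KeyError

print(opt["train"]["lr_G"], opt["datasets"]["train"]["batch_size"])
```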
-------------------------------------------------------------------------------- /bkse/options/kernel_encoding/REDS/woVAE.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: REDS_woVAE 3 | use_tb_logger: true 4 | model: image_base 5 | distortion: deblur 6 | scale: 1 7 | gpu_ids: [3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: REDS 13 | mode: REDS 14 | interval_list: [1] 15 | dataroot_HQ: datasets/REDS/train_sharp_wval.lmdb 16 | dataroot_LQ: datasets/REDS/train_blur_wval.lmdb 17 | cache_keys: ~ 18 | 19 | use_shuffle: true 20 | n_workers: 4 # per GPU 21 | batch_size: 13 22 | HQ_size: &HQ_SIZE 256 23 | LQ_size: 256 24 | use_flip: true 25 | use_rot: true 26 | color: RGB 27 | 28 | #### network structures 29 | KernelWizard: 30 | input_nc: 3 31 | nf: 64 32 | front_RBs: 10 33 | back_RBs: 20 34 | N_frames: 1 35 | kernel_dim: 512 36 | img_size: *HQ_SIZE 37 | use_vae: false 38 | KernelExtractor: 39 | norm: none 40 | use_sharp: true 41 | n_blocks: 4 42 | padding_type: reflect 43 | use_dropout: false 44 | Adapter: 45 | norm: none 46 | use_dropout: false 47 | 48 | #### path 49 | path: 50 | pretrain_model_G: ~ 51 | strict_load: false 52 | resume_state: ~ 53 | 54 | #### training settings: learning rate scheme, loss 55 | train: 56 | lr_G: !!float 1e-4 57 | lr_scheme: CosineAnnealingLR_Restart 58 | beta1: 0.9 59 | beta2: 0.99 60 | niter: 600000 61 | warmup_iter: -1 # -1: no warm up 62 | T_period: [50000, 100000, 150000, 150000, 150000] 63 | restarts: [50000, 150000, 300000, 450000] 64 | restart_weights: [1, 1, 1, 1] 65 | eta_min: !!float 1e-6 66 | 67 | pixel_criterion: cb 68 | pixel_weight: 1.0 69 | kl_weight: 0.0 70 | val_freq: !!float 5e3 71 | 72 | manual_seed: 0 73 | 74 | #### logger 75 | logger: 76 | print_freq: 10 77 | save_checkpoint_freq: !!float 5e3 78 | -------------------------------------------------------------------------------- /bkse/options/kernel_encoding/mix/woVAE.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: mix_wsharp 3 | use_tb_logger: true 4 | model: image_base 5 | distortion: deblur 6 | scale: 1 7 | gpu_ids: [0] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: mix 13 | mode: mix 14 | interval_list: [1] 15 | dataroots_HQ: ['datasets/REDS/train_sharp_wval.lmdb', 'datasets/GOPRO/train_sharp.lmdb'] 16 | dataroots_LQ: ['datasets/REDS/train_blur_wval.lmdb', 'datasets/GOPRO/train_blur_linear.lmdb'] 17 | dataset_weights: [1, 10] 18 | cache_keys: ~ 19 | 20 | N_frames: 1 21 | use_shuffle: true 22 | n_workers: 3 # per GPU 23 | batch_size: 16 24 | HQ_size: 256 25 | LQ_size: 256 26 | use_flip: true 27 | use_rot: true 28 | color: RGB 29 | 30 | #### network structures 31 | KernelWizard: 32 | input_nc: 3 33 | nf: 64 34 | front_RBs: 10 35 | back_RBs: 20 36 | N_frames: 1 37 | kernel_dim: 512 38 | use_vae: false 39 | KernelExtractor: 40 | norm: none 41 | use_sharp: true 42 | n_blocks: 4 43 | padding_type: reflect 44 | use_dropout: false 45 | Adapter: 46 | norm: none 47 | use_dropout: false 48 | 49 | #### path 50 | path: 51 | pretrain_model_G: ~ 52 | strict_load: false 53 | resume_state: ~ 54 | 55 | #### training settings: learning rate scheme, loss 56 | train: 57 | lr_G: !!float 1e-4 58 | lr_scheme: CosineAnnealingLR_Restart 59 | beta1: 0.9 60 | beta2: 0.99 61 | niter: 600000 62 | warmup_iter: -1 # -1: no warm up 63 | T_period: [50000, 100000, 150000, 150000, 150000] 64 | restarts: [50000, 150000, 300000, 450000] 65 | restart_weights: 
[1, 1, 1, 1] 66 | eta_min: !!float 1e-8 67 | 68 | pixel_criterion: cb 69 | pixel_weight: 1.0 70 | kernel_weight: 0.1 71 | gradient_loss_weight: 0.3 72 | val_freq: !!float 5e3 73 | 74 | manual_seed: 0 75 | 76 | #### logger 77 | logger: 78 | print_freq: 10 79 | save_checkpoint_freq: !!float 5000 80 | -------------------------------------------------------------------------------- /bkse/options/options.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import os.path as osp 4 | 5 | import yaml 6 | from utils.util import OrderedYaml 7 | 8 | 9 | Loader, Dumper = OrderedYaml() 10 | 11 | 12 | def parse(opt_path, is_train=True): 13 | with open(opt_path, mode="r") as f: 14 | opt = yaml.load(f, Loader=Loader) 15 | # export CUDA_VISIBLE_DEVICES 16 | gpu_list = ",".join(str(x) for x in opt["gpu_ids"]) 17 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_list 18 | print("export CUDA_VISIBLE_DEVICES=" + gpu_list) 19 | 20 | opt["is_train"] = is_train 21 | if opt["distortion"] == "sr": 22 | scale = opt["scale"] 23 | 24 | # datasets 25 | for phase, dataset in opt["datasets"].items(): 26 | phase = phase.split("_")[0] 27 | dataset["phase"] = phase 28 | if opt["distortion"] == "sr": 29 | dataset["scale"] = scale 30 | is_lmdb = False 31 | if dataset.get("dataroot_GT", None) is not None: 32 | dataset["dataroot_GT"] = osp.expanduser(dataset["dataroot_GT"]) 33 | if dataset["dataroot_GT"].endswith("lmdb"): 34 | is_lmdb = True 35 | if dataset.get("dataroot_LQ", None) is not None: 36 | dataset["dataroot_LQ"] = osp.expanduser(dataset["dataroot_LQ"]) 37 | if dataset["dataroot_LQ"].endswith("lmdb"): 38 | is_lmdb = True 39 | dataset["data_type"] = "lmdb" if is_lmdb else "img" 40 | if dataset["mode"].endswith("mc"): # for memcached 41 | dataset["data_type"] = "mc" 42 | dataset["mode"] = dataset["mode"].replace("_mc", "") 43 | 44 | # path 45 | for key, path in opt["path"].items(): 46 | if path and key in opt["path"] and key != "strict_load": 47 | opt["path"][key] = osp.expanduser(path) 48 | opt["path"]["root"] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 49 | if is_train: 50 | experiments_root = osp.join(opt["path"]["root"], "experiments", opt["name"]) 51 | opt["path"]["experiments_root"] = experiments_root 52 | opt["path"]["models"] = osp.join(experiments_root, "models") 53 | opt["path"]["training_state"] = osp.join(experiments_root, "training_state") 54 | opt["path"]["log"] = experiments_root 55 | opt["path"]["val_images"] = osp.join(experiments_root, "val_images") 56 | 57 | # change some options for debug mode 58 | if "debug" in opt["name"]: 59 | opt["train"]["val_freq"] = 8 60 | opt["logger"]["print_freq"] = 1 61 | opt["logger"]["save_checkpoint_freq"] = 8 62 | else: # test 63 | results_root = osp.join(opt["path"]["root"], "results", opt["name"]) 64 | opt["path"]["results_root"] = results_root 65 | opt["path"]["log"] = results_root 66 | 67 | # network 68 | if opt["distortion"] == "sr": 69 | opt["network_G"]["scale"] = scale 70 | 71 | return opt 72 | 73 | 74 | def dict2str(opt, indent_l=1): 75 | """dict to string for logger""" 76 | msg = "" 77 | for k, v in opt.items(): 78 | if isinstance(v, dict): 79 | msg += " " * (indent_l * 2) + k + ":[\n" 80 | msg += dict2str(v, indent_l + 1) 81 | msg += " " * (indent_l * 2) + "]\n" 82 | else: 83 | msg += " " * (indent_l * 2) + k + ": " + str(v) + "\n" 84 | return msg 85 | 86 | 87 | class NoneDict(dict): 88 | def __missing__(self, key): 89 | return None 90 | 91 | 92 | # convert to NoneDict, which return 
None for missing key. 93 | def dict_to_nonedict(opt): 94 | if isinstance(opt, dict): 95 | new_opt = dict() 96 | for key, sub_opt in opt.items(): 97 | new_opt[key] = dict_to_nonedict(sub_opt) 98 | return NoneDict(**new_opt) 99 | elif isinstance(opt, list): 100 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 101 | else: 102 | return opt 103 | 104 | 105 | def check_resume(opt, resume_iter): 106 | """Check resume states and pretrain_model paths""" 107 | logger = logging.getLogger("base") 108 | if opt["path"]["resume_state"]: 109 | if ( 110 | opt["path"].get("pretrain_model_G", None) is not None 111 | or opt["path"].get("pretrain_model_D", None) is not None 112 | ): 113 | logger.warning( 114 | "pretrain_model path will be ignored \ 115 | when resuming training." 116 | ) 117 | 118 | opt["path"]["pretrain_model_G"] = osp.join(opt["path"]["models"], "{}_G.pth".format(resume_iter)) 119 | logger.info("Set [pretrain_model_G] to " + opt["path"]["pretrain_model_G"]) 120 | if "gan" in opt["model"]: 121 | opt["path"]["pretrain_model_D"] = osp.join(opt["path"]["models"], "{}_D.pth".format(resume_iter)) 122 | logger.info("Set [pretrain_model_D] to " + opt["path"]["pretrain_model_D"]) 123 | -------------------------------------------------------------------------------- /bkse/requirements.txt: -------------------------------------------------------------------------------- 1 | torch >= 1.4.0 2 | torchvision >= 0.5.0 3 | pyyaml 4 | opencv-python 5 | numpy 6 | lmdb 7 | tqdm 8 | tensorboard >= 1.15.0 9 | ninja 10 | -------------------------------------------------------------------------------- /bkse/scripts/create_lmdb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | import pickle 4 | import sys 5 | from multiprocessing import Pool 6 | 7 | import cv2 8 | import lmdb 9 | 10 | 11 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) 12 | import data.util as data_util # noqa: E402 13 | import utils.util as util # noqa: E402 14 | 15 | 16 | def read_image_worker(path, key): 17 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 18 | return (key, img) 19 | 20 | 21 | def create_dataset(name, img_folder, lmdb_save_path, H_dst, W_dst, C_dst): 22 | """Create lmdb for the dataset, each image with a fixed size 23 | key pattern: folder_frameid 24 | """ 25 | # configurations 26 | read_all_imgs = False # whether real all images to memory with multiprocessing 27 | # Set False for use limited memory 28 | BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False 29 | n_thread = 40 30 | ######################################################## 31 | if not lmdb_save_path.endswith(".lmdb"): 32 | raise ValueError("lmdb_save_path must end with 'lmdb'.") 33 | if osp.exists(lmdb_save_path): 34 | print("Folder [{:s}] already exists. Exit...".format(lmdb_save_path)) 35 | sys.exit(1) 36 | 37 | # read all the image paths to a list 38 | print("Reading image path list ...") 39 | all_img_list = data_util._get_paths_from_images(img_folder) 40 | keys = [] 41 | for img_path in all_img_list: 42 | split_rlt = img_path.split("/") 43 | folder = split_rlt[-2] 44 | img_name = split_rlt[-1].split(".png")[0] 45 | keys.append(folder + "_" + img_name) 46 | 47 | if read_all_imgs: 48 | # read all images to memory (multiprocessing) 49 | dataset = {} # store all image data. 
list cannot keep the order, use dict 50 | print("Read images with multiprocessing, #thread: {} ...".format(n_thread)) 51 | pbar = util.ProgressBar(len(all_img_list)) 52 | 53 | def mycallback(arg): 54 | """get the image data and update pbar""" 55 | key = arg[0] 56 | dataset[key] = arg[1] 57 | pbar.update("Reading {}".format(key)) 58 | 59 | pool = Pool(n_thread) 60 | for path, key in zip(all_img_list, keys): 61 | pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) 62 | pool.close() 63 | pool.join() 64 | print("Finish reading {} images.\nWrite lmdb...".format(len(all_img_list))) 65 | 66 | # create lmdb environment 67 | data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes 68 | print("data size per image is: ", data_size_per_img) 69 | data_size = data_size_per_img * len(all_img_list) 70 | env = lmdb.open(lmdb_save_path, map_size=data_size * 10) 71 | 72 | # write data to lmdb 73 | pbar = util.ProgressBar(len(all_img_list)) 74 | txn = env.begin(write=True) 75 | for idx, (path, key) in enumerate(zip(all_img_list, keys)): 76 | pbar.update("Write {}".format(key)) 77 | key_byte = key.encode("ascii") 78 | data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED) 79 | 80 | assert len(data.shape) > 2 or C_dst == 1, "different shape" 81 | 82 | if C_dst == 1: 83 | H, W = data.shape 84 | assert H == H_dst and W == W_dst, "different shape." 85 | else: 86 | H, W, C = data.shape 87 | assert H == H_dst and W == W_dst and C == 3, "different shape." 88 | txn.put(key_byte, data) 89 | if not read_all_imgs and idx % BATCH == 0: 90 | txn.commit() 91 | txn = env.begin(write=True) 92 | txn.commit() 93 | env.close() 94 | print("Finish writing lmdb.") 95 | 96 | # create meta information 97 | meta_info = {} 98 | meta_info["name"] = name 99 | channel = C_dst 100 | meta_info["resolution"] = "{}_{}_{}".format(channel, H_dst, W_dst) 101 | meta_info["keys"] = keys 102 | pickle.dump(meta_info, open(osp.join(lmdb_save_path, "meta_info.pkl"), "wb")) 103 | print("Finish creating lmdb meta info.") 104 | 105 | 106 | if __name__ == "__main__": 107 | parser = argparse.ArgumentParser(description="Kernel extractor testing") 108 | 109 | parser.add_argument("--H", action="store", help="source image height", type=int, required=True) 110 | parser.add_argument("--W", action="store", help="source image height", type=int, required=True) 111 | parser.add_argument("--C", action="store", help="source image height", type=int, required=True) 112 | parser.add_argument("--img_folder", action="store", help="img folder", type=str, required=True) 113 | parser.add_argument("--save_path", action="store", help="save path", type=str, default=".") 114 | parser.add_argument("--name", action="store", help="dataset name", type=str, required=True) 115 | 116 | args = parser.parse_args() 117 | create_dataset(args.name, args.img_folder, args.save_path, args.H, args.W, args.C) 118 | -------------------------------------------------------------------------------- /bkse/scripts/download_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | 5 | import requests 6 | 7 | 8 | def download_file_from_google_drive(file_id, destination): 9 | os.makedirs(osp.dirname(destination), exist_ok=True) 10 | URL = "https://docs.google.com/uc?export=download" 11 | 12 | session = requests.Session() 13 | 14 | response = session.get(URL, params={"id": file_id}, stream=True) 15 | token = get_confirm_token(response) 16 | 
17 | if token: 18 | params = {"id": file_id, "confirm": token} 19 | response = session.get(URL, params=params, stream=True) 20 | 21 | save_response_content(response, destination) 22 | 23 | 24 | def get_confirm_token(response): 25 | for key, value in response.cookies.items(): 26 | if key.startswith("download_warning"): 27 | return value 28 | 29 | return None 30 | 31 | 32 | def save_response_content(response, destination): 33 | CHUNK_SIZE = 32768 34 | 35 | with open(destination, "wb") as f: 36 | for chunk in response.iter_content(CHUNK_SIZE): 37 | if chunk: # filter out keep-alive new chunks 38 | f.write(chunk) 39 | 40 | 41 | if __name__ == "__main__": 42 | dataset_ids = { 43 | "GOPRO_Large": "1H0PIXvJH4c40pk7ou6nAwoxuR4Qh_Sa2", 44 | "train_sharp": "1YLksKtMhd2mWyVSkvhDaDLWSc1qYNCz-", 45 | "train_blur": "1Be2cgzuuXibcqAuJekDgvHq4MLYkCgR8", 46 | "val_sharp": "1MGeObVQ1-Z29f-myDP7-8c3u0_xECKXq", 47 | "val_blur": "1N8z2yD0GDWmh6U4d4EADERtcUgDzGrHx", 48 | "test_blur": "1dr0--ZBKqr4P1M8lek6JKD1Vd6bhhrZT", 49 | } 50 | 51 | parser = argparse.ArgumentParser( 52 | description="Download REDS dataset from google drive to current folder", allow_abbrev=False 53 | ) 54 | 55 | parser.add_argument("--REDS_train_sharp", action="store_true", help="download REDS train_sharp.zip") 56 | parser.add_argument("--REDS_train_blur", action="store_true", help="download REDS train_blur.zip") 57 | parser.add_argument("--REDS_val_sharp", action="store_true", help="download REDS val_sharp.zip") 58 | parser.add_argument("--REDS_val_blur", action="store_true", help="download REDS val_blur.zip") 59 | parser.add_argument("--GOPRO", action="store_true", help="download GOPRO_Large.zip") 60 | 61 | args = parser.parse_args() 62 | 63 | if args.REDS_train_sharp: 64 | download_file_from_google_drive(dataset_ids["train_sharp"], "REDS/train_sharp.zip") 65 | if args.REDS_train_blur: 66 | download_file_from_google_drive(dataset_ids["train_blur"], "REDS/train_blur.zip") 67 | if args.REDS_val_sharp: 68 | download_file_from_google_drive(dataset_ids["val_sharp"], "REDS/val_sharp.zip") 69 | if args.REDS_val_blur: 70 | download_file_from_google_drive(dataset_ids["val_blur"], "REDS/val_blur.zip") 71 | if args.GOPRO: 72 | download_file_from_google_drive(dataset_ids["GOPRO_Large"], "GOPRO/GOPRO.zip") 73 | -------------------------------------------------------------------------------- /bkse/train_script.sh: -------------------------------------------------------------------------------- 1 | python3.7 train.py -opt options/REDS/wsharp_woVAE.yml 2 | -------------------------------------------------------------------------------- /bkse/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/bkse/utils/__init__.py -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.distributed as dist 3 | from torch.utils.data import Dataset 4 | 5 | #SELENE 6 | from datasets.ffhq import get_ffhq_dataset, get_ffhq_loader 7 | from datasets.imagenet import get_imagenet_dataset, get_imagenet_loader 8 | from utils.distributed import get_logger 9 | 10 | #LOCAL MACHINE 11 | # from pgdm.datasets.ffhq import get_ffhq_dataset, get_ffhq_loader 12 | # from pgdm.datasets.imagenet import get_imagenet_dataset, get_imagenet_loader 13 | # from pgdm.utils.distributed import get_logger 14 | 
15 | 16 | 17 | class ZipDataset(Dataset): 18 | def __init__(self, datasets): 19 | self.datasets = datasets 20 | assert all(len(dataset) == len(datasets[0]) for dataset in datasets) 21 | 22 | def __getitem__(self, index): 23 | return [dataset[index] for dataset in self.datasets] 24 | 25 | def __len__(self): 26 | return len(self.datasets[0]) 27 | 28 | 29 | def build_one_dataset(cfg, dataset_attr='dataset'): 30 | logger = get_logger('dataset', cfg) 31 | exp_root = cfg.exp.root 32 | cfg_dataset = getattr(cfg, dataset_attr) 33 | try: 34 | samples_root = cfg.exp.samples_root 35 | exp_name = cfg.exp.name 36 | samples_root = os.path.join(exp_root, samples_root, exp_name) 37 | except Exception: 38 | samples_root = '' 39 | logger.info('Does not attempt to prune existing samples (overwrite=False).') 40 | if "ImageNet" in cfg_dataset.name: 41 | overwrite = getattr(cfg.exp, 'overwrite', True) 42 | dset = get_imagenet_dataset(overwrite=overwrite, samples_root=samples_root, **cfg_dataset) 43 | dist.barrier() 44 | if "FFHQ" in cfg_dataset.name: 45 | dset = get_ffhq_dataset(**cfg_dataset) 46 | 47 | return dset 48 | 49 | 50 | def build_loader(cfg, dataset_attr='dataset'): 51 | 52 | if type(dataset_attr) == list: 53 | dsets = [] 54 | for da in dataset_attr: 55 | cfg_dataset = getattr(cfg, da) 56 | dset = build_one_dataset(cfg, dataset_attr=da) 57 | dsets.append(dset) 58 | dsets = ZipDataset(dsets) 59 | if "ImageNet" in cfg_dataset.name: 60 | loader = get_imagenet_loader(dsets, **cfg.loader) 61 | elif "FFHQ" in cfg_dataset.name: 62 | loader = get_ffhq_loader(dsets, **cfg.loader) 63 | else: 64 | cfg_dataset = getattr(cfg, dataset_attr) 65 | dset = build_one_dataset(cfg, dataset_attr=dataset_attr) 66 | if "ImageNet" in cfg_dataset.name: 67 | loader = get_imagenet_loader(dset, **cfg.loader) 68 | elif "FFHQ" in cfg_dataset.name: 69 | loader = get_ffhq_loader(dset, **cfg.loader) 70 | 71 | 72 | return loader 73 | -------------------------------------------------------------------------------- /datasets/ffhq.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | from torch.utils.data import DataLoader, DistributedSampler, Subset 3 | from .lmdb_dataset import LMDBDataset 4 | 5 | 6 | 7 | 8 | def get_ffhq_dataset(root, split, transform='default', subset=-1, **kwargs): 9 | if transform == 'default': 10 | transform = transforms.ToTensor() 11 | dset = LMDBDataset(root, 'ffhq', split, transform) 12 | if isinstance(subset, int) and subset > 0: 13 | dset = Subset(dset, list(range(subset))) 14 | else: 15 | assert isinstance(subset, list) 16 | dset = Subset(dset, subset) 17 | return dset 18 | 19 | 20 | def get_ffhq_loader(dset, *, batch_size, num_workers, shuffle, drop_last, pin_memory, **kwargs): 21 | sampler = DistributedSampler(dset, shuffle=shuffle, drop_last=drop_last) 22 | loader = DataLoader( 23 | dset, num_workers=num_workers, batch_size=batch_size, shuffle=shuffle, sampler=sampler, pin_memory=pin_memory, persistent_workers=True 24 | ) 25 | return loader 26 | 27 | -------------------------------------------------------------------------------- /datasets/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | import lmdb 4 | import os 5 | import io 6 | from PIL import Image 7 | 8 | 9 | def num_samples(dataset, train): 10 | if dataset == 'celeba': 11 | return 27000 if train else 3000 12 | elif dataset == 'celeba64': 13 | return 162770 if 
train else 19867 14 | elif dataset == 'imagenet-oord': 15 | return 1281147 if train else 50000 16 | elif dataset == 'ffhq': 17 | return 63000 if train else 7000 18 | else: 19 | raise NotImplementedError('dataset %s is unknown' % dataset) 20 | 21 | 22 | class LMDBDataset(data.Dataset): 23 | def __init__(self, root, name='', split='train', transform=None, is_encoded=False): 24 | self.name = name 25 | self.transform = transform 26 | self.split = split 27 | if self.split == 'train': 28 | lmdb_path = os.path.join(root, 'train.lmdb') 29 | elif self.split == 'val': 30 | lmdb_path = os.path.join(root, 'validation.lmdb') 31 | else: 32 | lmdb_path = os.path.join(f'{root}.lmdb') 33 | 34 | self.data_lmdb = lmdb.open(lmdb_path, readonly=True, max_readers=1, 35 | lock=False, readahead=False, meminit=False) 36 | self.is_encoded = is_encoded 37 | 38 | def __getitem__(self, index): 39 | target = 0 40 | with self.data_lmdb.begin(write=False, buffers=True) as txn: 41 | data = txn.get(str(index).encode()) 42 | if self.is_encoded: 43 | img = Image.open(io.BytesIO(data)) 44 | img = img.convert('RGB') 45 | else: 46 | img = np.asarray(data, dtype=np.uint8) 47 | # assume data is RGB 48 | size = int(np.sqrt(len(img) / 3)) 49 | img = np.reshape(img, (size, size, 3)) 50 | img = Image.fromarray(img, mode='RGB') 51 | 52 | if self.transform is not None: 53 | img = self.transform(img) 54 | 55 | return img, target, {'index': index} 56 | 57 | def __len__(self): 58 | if hasattr(self, 'length'): 59 | return self.length 60 | else: 61 | with self.data_lmdb.begin() as txn: 62 | self.length = txn.stat()['entries'] 63 | return self.length 64 | -------------------------------------------------------------------------------- /demo/output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/demo/output.gif -------------------------------------------------------------------------------- /eval/ca.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import hydra 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | import torch.nn.functional as F 9 | from cleanfid.features import build_feature_extractor 10 | from hydra.core.hydra_config import HydraConfig 11 | from omegaconf import DictConfig, OmegaConf 12 | from PIL import Image 13 | 14 | from torchvision.transforms import Compose, ToTensor 15 | from tqdm import tqdm 16 | from torch.nn.parallel import DistributedDataParallel as DDP 17 | from datasets import build_loader 18 | import torchvision.transforms as transforms 19 | from utils.distributed import get_logger, init_processes, get_results_file 20 | 21 | from utils.functions import accuracy 22 | 23 | 24 | import sys 25 | import torchvision 26 | 27 | # print(sys.version) 28 | # print(torch.__version__) 29 | # print(torchvision.__version__) 30 | 31 | 32 | def main(cfg: DictConfig): 33 | torch.hub.set_dir(os.path.join(cfg.exp.root, 'hub')) 34 | torch.cuda.set_device(dist.get_rank()) 35 | logger = get_logger("ca", cfg) 36 | exp_root = os.path.join(cfg.exp.root, "fid_stats") 37 | os.makedirs(exp_root, exist_ok=True) 38 | loader = build_loader(cfg) 39 | model = torch.hub.load("pytorch/vision:v0.13.1", "resnet50", weights='IMAGENET1K_V1').cuda() #, force_reload=True 40 | model = DDP(model, device_ids=[dist.get_rank()], output_device=[dist.get_rank()]) 41 | model.eval() 42 | top1, top5 = 0, 0 43 | 
count = 0 44 | logger.info(f'A total of {len(loader.dataset)} images are processed.') 45 | for x, y, info in tqdm(loader): 46 | n, c, h, w = x.size() 47 | with torch.no_grad(): 48 | x = x.cuda() 49 | y = y.cuda() 50 | y_ = model(x) 51 | t1, t5 = accuracy(y_, y, topk=(1, 5)) 52 | top1 += t1 * n 53 | top5 += t5 * n 54 | count += n 55 | 56 | features = torch.tensor([top1, top5, count]).cpu() 57 | features_list = [torch.zeros_like(features) for i in range(dist.get_world_size())] 58 | dist.gather(features, features_list, dst=0) 59 | 60 | if dist.get_rank() == 0: 61 | features = torch.stack(features_list, dim=1) 62 | top1_tot = torch.sum(features[0], dim=0).item() 63 | top5_tot = torch.sum(features[1], dim=0).item() 64 | count_tot = torch.sum(features[2], dim=0).item() 65 | 66 | logger.info(f"Top1: {top1_tot / count_tot}, Top5: {top5_tot / count_tot}.") 67 | results_file = get_results_file(cfg, logger) 68 | 69 | with open(results_file, 'a') as f: 70 | f.write(f"Total: {count_tot}\nTop1: {top1_tot / count_tot}\nTop5: {top5_tot / count_tot}\n") 71 | 72 | dist.barrier() 73 | 74 | 75 | @hydra.main(version_base="1.2", config_path="_configs", config_name="ca") 76 | def main_dist(cfg: DictConfig): 77 | cwd = HydraConfig.get().runtime.output_dir 78 | 79 | if cfg.dist.num_processes_per_node < 0: 80 | size = torch.cuda.device_count() 81 | cfg.dist.num_processes_per_node = size 82 | else: 83 | size = cfg.dist.num_processes_per_node 84 | if size > 1: 85 | num_proc_node = cfg.dist.num_proc_node 86 | num_process_per_node = cfg.dist.num_processes_per_node 87 | world_size = num_proc_node * num_process_per_node 88 | mp.spawn( 89 | init_processes, args=(world_size, main, cfg, cwd), nprocs=world_size, join=True, 90 | ) 91 | else: 92 | init_processes(0, size, main, cfg, cwd) 93 | 94 | 95 | if __name__ == "__main__": 96 | main_dist() 97 | -------------------------------------------------------------------------------- /eval/fid.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import hydra 4 | import numpy as np 5 | from cleanfid.fid import frechet_distance, kernel_distance 6 | from hydra.core.hydra_config import HydraConfig 7 | from omegaconf import DictConfig 8 | import torch.distributed as dist 9 | from utils.distributed import get_logger, init_processes, get_results_file 10 | 11 | 12 | def main(cfg: DictConfig): 13 | logger = get_logger("main", cfg) 14 | features1 = np.load(cfg.path1) 15 | features2 = np.load(cfg.path2) 16 | 17 | # print('cfg.path1', cfg.path1) 18 | # print('features1', features1) 19 | 20 | if 'npy' in cfg.path1: 21 | mu1 = np.mean(features1, axis=0) 22 | sigma1 = np.cov(features1, rowvar=False) 23 | else: 24 | mu1, sigma1 = features1['mu'], features1['sigma'] 25 | 26 | if 'npy' in cfg.path2: 27 | mu2 = np.mean(features2, axis=0) 28 | sigma2 = np.cov(features2, rowvar=False) 29 | else: 30 | mu2, sigma2 = features2['mu'], features2['sigma'] 31 | 32 | results_file = get_results_file(cfg, logger) 33 | 34 | fid = frechet_distance(mu1, sigma1, mu2, sigma2) 35 | logger.info(f"FID: {fid:.4f}") 36 | with open(results_file, 'a') as f: 37 | f.write(f"FID: {fid}\n") 38 | 39 | if 'npy' in cfg.path1 and 'npy' in cfg.path2: 40 | kid = kernel_distance(features1, features2) 41 | logger.info(f"KIDx10^3: {kid * 1000:.7f}") 42 | with open(results_file, 'a') as f: 43 | f.write(f"KID: {kid}\n") 44 | 45 | 46 | @hydra.main(version_base="1.2", config_path="_configs", config_name="fid") 47 | def main_dist(cfg: DictConfig): 48 | cwd = 
HydraConfig.get().runtime.output_dir 49 | init_processes(0, 1, main, cfg, cwd) 50 | #print('cfg', cfg) 51 | 52 | 53 | if __name__ == "__main__": 54 | main_dist() 55 | -------------------------------------------------------------------------------- /eval/inception_score.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.distributed as dist 4 | import torch.multiprocessing as mp 5 | from hydra.core.hydra_config import HydraConfig 6 | from omegaconf import DictConfig 7 | from tqdm import tqdm 8 | 9 | import torch_fidelity 10 | 11 | from datasets import build_loader 12 | from utils.distributed import get_logger, get_results_file, init_processes 13 | 14 | 15 | class InceptionDataset(torch.utils.data.Dataset): 16 | def __init__(self, dset): 17 | super(InceptionDataset, self).__init__() 18 | self.dset = dset 19 | 20 | def __len__(self): 21 | return self.dset.__len__() 22 | 23 | def __getitem__(self, index: int): 24 | return self.dset[index][0] 25 | 26 | 27 | def main(cfg: DictConfig): 28 | logger = get_logger("inception_score", cfg) 29 | loader = build_loader(cfg) 30 | 31 | dset = InceptionDataset(loader.dataset) 32 | metrics = torch_fidelity.calculate_metrics(input1=dset, cuda=True, isc=True, verbose=True, samples_find_deep=True) 33 | isc_mean = metrics['inception_score_mean'] 34 | isc_std = metrics['inception_score_std'] 35 | 36 | if dist.get_rank() == 0: 37 | results_file = get_results_file(cfg, logger) 38 | logger.info(f"IS: {isc_mean} +/- {isc_std}") 39 | 40 | with open(results_file, 'a') as f: 41 | f.write(f'IS_mean: {isc_mean}') 42 | f.write(f'IS_std: {isc_std}') 43 | 44 | dist.barrier() 45 | 46 | 47 | @hydra.main(version_base="1.2", config_path="_configs", config_name="inception_score") 48 | def main_dist(cfg: DictConfig): 49 | cwd = HydraConfig.get().runtime.output_dir 50 | 51 | if cfg.dist.num_processes_per_node < 0: 52 | size = torch.cuda.device_count() 53 | cfg.dist.num_processes_per_node = size 54 | else: 55 | size = cfg.dist.num_processes_per_node 56 | if size > 1: 57 | num_proc_node = cfg.dist.num_proc_node 58 | num_process_per_node = cfg.dist.num_processes_per_node 59 | world_size = num_proc_node * num_process_per_node 60 | mp.spawn( 61 | init_processes, args=(world_size, main, cfg, cwd), nprocs=world_size, join=True, 62 | ) 63 | else: 64 | init_processes(0, size, main, cfg, cwd) 65 | 66 | 67 | if __name__ == "__main__": 68 | main_dist() 69 | -------------------------------------------------------------------------------- /eval/psnr.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.distributed as dist 4 | import torch.multiprocessing as mp 5 | from hydra.core.hydra_config import HydraConfig 6 | from omegaconf import DictConfig 7 | from tqdm import tqdm 8 | from torchmetrics.functional import structural_similarity_index_measure #StructuralSimilarityIndexMeasure 9 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 10 | 11 | from datasets import build_loader 12 | from utils.distributed import get_logger, get_results_file, init_processes 13 | 14 | 15 | def main(cfg: DictConfig): 16 | logger = get_logger("psnr", cfg) 17 | loader = build_loader(cfg, ["dataset1", "dataset2"]) 18 | psnrs = [] 19 | ssims = [] 20 | lpips = [] 21 | 22 | for b1, b2 in tqdm(loader): 23 | x1, x2 = b1[0], b2[0] 24 | x1 = x1.cuda() 25 | x2 = x2.cuda() 26 | mse = torch.mean((x1 - x2) ** 2, dim=(1, 2, 3)) 27 | psnr = 10 * 
torch.log10(1 / (mse + 1e-10)) 28 | ssim = structural_similarity_index_measure(x2, x1, reduction=None) 29 | with torch.no_grad(): 30 | lpip = LearnedPerceptualImagePatchSimilarity().cuda()(x2, x1) 31 | psnrs.append(psnr) 32 | ssims.append(ssim) 33 | lpips.append(lpip.item()) 34 | 35 | 36 | psnrs = torch.cat(psnrs, dim=0) 37 | ssims = torch.cat(ssims, dim=0) 38 | lpips = torch.tensor(lpips).cuda() 39 | 40 | psnrs_list = [torch.zeros_like(psnrs) for i in range(dist.get_world_size())] 41 | ssims_list = [torch.zeros_like(ssims) for i in range(dist.get_world_size())] 42 | lpips_list = [torch.zeros_like(lpips) for i in range(dist.get_world_size())] 43 | dist.gather(psnrs, psnrs_list, dst=0) 44 | dist.gather(ssims, ssims_list, dst=0) 45 | dist.gather(lpips, lpips_list, dst=0) 46 | 47 | if dist.get_rank() == 0: 48 | results_file = get_results_file(cfg, logger) 49 | psnrs = torch.cat(psnrs_list, dim=0) 50 | ssims = torch.cat(ssims_list, dim=0) 51 | lpips = torch.cat(lpips_list, dim=0) 52 | logger.info(f"PSNR: {psnrs.mean().item()} +/- {psnrs.std().item()}") 53 | logger.info(f"SSIM: {ssims.mean().item()}") 54 | logger.info(f"LPIPS: {lpips.mean().item()}") 55 | 56 | with open(results_file, 'a') as f: 57 | f.write(f'PSNR: {psnrs.mean().item()}\n') 58 | f.write(f'SSIM: {ssims.mean().item()}\n') 59 | f.write(f"LPIPS: {lpips.mean().item()}") 60 | 61 | dist.barrier() 62 | 63 | 64 | @hydra.main(version_base="1.2", config_path="_configs", config_name="psnr") 65 | def main_dist(cfg: DictConfig): 66 | cwd = HydraConfig.get().runtime.output_dir 67 | 68 | #print('cfg', cfg) 69 | #import pdb; pdb.set_trace() 70 | 71 | if cfg.dist.num_processes_per_node < 0: 72 | size = torch.cuda.device_count() 73 | cfg.dist.num_processes_per_node = size 74 | else: 75 | size = cfg.dist.num_processes_per_node 76 | if size > 1: 77 | num_proc_node = cfg.dist.num_proc_node 78 | num_process_per_node = cfg.dist.num_processes_per_node 79 | world_size = num_proc_node * num_process_per_node 80 | mp.spawn( 81 | init_processes, args=(world_size, main, cfg, cwd), nprocs=world_size, join=True, 82 | ) 83 | else: 84 | init_processes(0, size, main, cfg, cwd) 85 | 86 | 87 | 88 | if __name__ == "__main__": 89 | main_dist() 90 | -------------------------------------------------------------------------------- /misc/dgp_top10.txt: -------------------------------------------------------------------------------- 1 | n01440764/ILSVRC2012_val_00000293.JPEG 0 2 | n01443537/ILSVRC2012_val_00000236.JPEG 1 3 | n01484850/ILSVRC2012_val_00002338.JPEG 2 4 | n01491361/ILSVRC2012_val_00002922.JPEG 3 5 | n01494475/ILSVRC2012_val_00001676.JPEG 4 6 | n01496331/ILSVRC2012_val_00000921.JPEG 5 7 | n01498041/ILSVRC2012_val_00001935.JPEG 6 8 | n01514668/ILSVRC2012_val_00000329.JPEG 7 9 | n01514859/ILSVRC2012_val_00001114.JPEG 8 10 | n01518878/ILSVRC2012_val_00001031.JPEG 9 11 | -------------------------------------------------------------------------------- /misc/dgp_top100.txt: -------------------------------------------------------------------------------- 1 | n01440764/ILSVRC2012_val_00000293.JPEG 0 2 | n01443537/ILSVRC2012_val_00000236.JPEG 1 3 | n01484850/ILSVRC2012_val_00002338.JPEG 2 4 | n01491361/ILSVRC2012_val_00002922.JPEG 3 5 | n01494475/ILSVRC2012_val_00001676.JPEG 4 6 | n01496331/ILSVRC2012_val_00000921.JPEG 5 7 | n01498041/ILSVRC2012_val_00001935.JPEG 6 8 | n01514668/ILSVRC2012_val_00000329.JPEG 7 9 | n01514859/ILSVRC2012_val_00001114.JPEG 8 10 | n01518878/ILSVRC2012_val_00001031.JPEG 9 11 | n01530575/ILSVRC2012_val_00000651.JPEG 10 12 | 
n01531178/ILSVRC2012_val_00000570.JPEG 11 13 | n01532829/ILSVRC2012_val_00000873.JPEG 12 14 | n01534433/ILSVRC2012_val_00000247.JPEG 13 15 | n01537544/ILSVRC2012_val_00000414.JPEG 14 16 | n01558993/ILSVRC2012_val_00001598.JPEG 15 17 | n01560419/ILSVRC2012_val_00000198.JPEG 16 18 | n01580077/ILSVRC2012_val_00000880.JPEG 17 19 | n01582220/ILSVRC2012_val_00000476.JPEG 18 20 | n01592084/ILSVRC2012_val_00000837.JPEG 19 21 | n01601694/ILSVRC2012_val_00000962.JPEG 20 22 | n01608432/ILSVRC2012_val_00000073.JPEG 21 23 | n01614925/ILSVRC2012_val_00000128.JPEG 22 24 | n01616318/ILSVRC2012_val_00000018.JPEG 23 25 | n01622779/ILSVRC2012_val_00000318.JPEG 24 26 | n01629819/ILSVRC2012_val_00000225.JPEG 25 27 | n01630670/ILSVRC2012_val_00000498.JPEG 26 28 | n01631663/ILSVRC2012_val_00000258.JPEG 27 29 | n01632458/ILSVRC2012_val_00000088.JPEG 28 30 | n01632777/ILSVRC2012_val_00000052.JPEG 29 31 | n01641577/ILSVRC2012_val_00001652.JPEG 30 32 | n01644373/ILSVRC2012_val_00000944.JPEG 31 33 | n01644900/ILSVRC2012_val_00000037.JPEG 32 34 | n01664065/ILSVRC2012_val_00000404.JPEG 33 35 | n01665541/ILSVRC2012_val_00000791.JPEG 34 36 | n01667114/ILSVRC2012_val_00000229.JPEG 35 37 | n01667778/ILSVRC2012_val_00002832.JPEG 36 38 | n01669191/ILSVRC2012_val_00000362.JPEG 37 39 | n01675722/ILSVRC2012_val_00003904.JPEG 38 40 | n01677366/ILSVRC2012_val_00000153.JPEG 39 41 | n01682714/ILSVRC2012_val_00001931.JPEG 40 42 | n01685808/ILSVRC2012_val_00000908.JPEG 41 43 | n01687978/ILSVRC2012_val_00003821.JPEG 42 44 | n01688243/ILSVRC2012_val_00000297.JPEG 43 45 | n01689811/ILSVRC2012_val_00001022.JPEG 44 46 | n01692333/ILSVRC2012_val_00001475.JPEG 45 47 | n01693334/ILSVRC2012_val_00000064.JPEG 46 48 | n01694178/ILSVRC2012_val_00001375.JPEG 47 49 | n01695060/ILSVRC2012_val_00000541.JPEG 48 50 | n01697457/ILSVRC2012_val_00001573.JPEG 49 51 | n01698640/ILSVRC2012_val_00000090.JPEG 50 52 | n01704323/ILSVRC2012_val_00002051.JPEG 51 53 | n01728572/ILSVRC2012_val_00003569.JPEG 52 54 | n01728920/ILSVRC2012_val_00001042.JPEG 53 55 | n01729322/ILSVRC2012_val_00002568.JPEG 54 56 | n01729977/ILSVRC2012_val_00000029.JPEG 55 57 | n01734418/ILSVRC2012_val_00000512.JPEG 56 58 | n01735189/ILSVRC2012_val_00000006.JPEG 57 59 | n01737021/ILSVRC2012_val_00000084.JPEG 58 60 | n01739381/ILSVRC2012_val_00001108.JPEG 59 61 | n01740131/ILSVRC2012_val_00000337.JPEG 60 62 | n01742172/ILSVRC2012_val_00000786.JPEG 61 63 | n01744401/ILSVRC2012_val_00000688.JPEG 62 64 | n01748264/ILSVRC2012_val_00003143.JPEG 63 65 | n01749939/ILSVRC2012_val_00000298.JPEG 64 66 | n01751748/ILSVRC2012_val_00000001.JPEG 65 67 | n01753488/ILSVRC2012_val_00001582.JPEG 66 68 | n01755581/ILSVRC2012_val_00000749.JPEG 67 69 | n01756291/ILSVRC2012_val_00001706.JPEG 68 70 | n01768244/ILSVRC2012_val_00000299.JPEG 69 71 | n01770081/ILSVRC2012_val_00000107.JPEG 70 72 | n01770393/ILSVRC2012_val_00001007.JPEG 71 73 | n01773157/ILSVRC2012_val_00001037.JPEG 72 74 | n01773549/ILSVRC2012_val_00002688.JPEG 73 75 | n01773797/ILSVRC2012_val_00000040.JPEG 74 76 | n01774384/ILSVRC2012_val_00001150.JPEG 75 77 | n01774750/ILSVRC2012_val_00000396.JPEG 76 78 | n01775062/ILSVRC2012_val_00000218.JPEG 77 79 | n01776313/ILSVRC2012_val_00000679.JPEG 78 80 | n01784675/ILSVRC2012_val_00003981.JPEG 79 81 | n01795545/ILSVRC2012_val_00000075.JPEG 80 82 | n01796340/ILSVRC2012_val_00000860.JPEG 81 83 | n01797886/ILSVRC2012_val_00002539.JPEG 82 84 | n01798484/ILSVRC2012_val_00001461.JPEG 83 85 | n01806143/ILSVRC2012_val_00001146.JPEG 84 86 | n01806567/ILSVRC2012_val_00003015.JPEG 85 87 | 
n01807496/ILSVRC2012_val_00000576.JPEG 86 88 | n01817953/ILSVRC2012_val_00001629.JPEG 87 89 | n01818515/ILSVRC2012_val_00001163.JPEG 88 90 | n01819313/ILSVRC2012_val_00001288.JPEG 89 91 | n01820546/ILSVRC2012_val_00001749.JPEG 90 92 | n01824575/ILSVRC2012_val_00000370.JPEG 91 93 | n01828970/ILSVRC2012_val_00000051.JPEG 92 94 | n01829413/ILSVRC2012_val_00002100.JPEG 93 95 | n01833805/ILSVRC2012_val_00002974.JPEG 94 96 | n01843065/ILSVRC2012_val_00002326.JPEG 95 97 | n01843383/ILSVRC2012_val_00000952.JPEG 96 98 | n01847000/ILSVRC2012_val_00000415.JPEG 97 99 | n01855032/ILSVRC2012_val_00002309.JPEG 98 100 | n01855672/ILSVRC2012_val_00000547.JPEG 99 101 | -------------------------------------------------------------------------------- /misc/palette_jpeg_demo.txt: -------------------------------------------------------------------------------- 1 | n01622779/ILSVRC2012_val_00047963.JPEG 24 2 | n02692877/ILSVRC2012_val_00039339.JPEG 405 3 | n07614500/ILSVRC2012_val_00047647.JPEG 928 4 | n09193705/ILSVRC2012_val_00000002.JPEG 970 5 | n02100877/ILSVRC2012_val_00041364.JPEG 213 6 | n01818515/ILSVRC2012_val_00047547.JPEG 88 7 | n01820546/ILSVRC2012_val_00045836.JPEG 90 8 | n02509815/ILSVRC2012_val_00038134.JPEG 387 -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from matplotlib.pyplot import get 2 | import torch 3 | import torch.distributed as dist 4 | from torch.nn.parallel import DistributedDataParallel as DDP 5 | from hydra.utils import call 6 | 7 | from utils.checkpoints import ckpt_path_adm 8 | from utils.distributed import get_logger 9 | 10 | 11 | def build_model(cfg): 12 | logger = get_logger("model", cfg) 13 | model = call(cfg.model) 14 | map_location = {"cuda:0": f"cuda:{dist.get_rank()}"} 15 | model_ckpt = ckpt_path_adm(cfg.model.ckpt, cfg) 16 | logger.info(f"Loading model from {model_ckpt}..") 17 | model.load_state_dict(torch.load(model_ckpt, map_location=map_location)) 18 | classifier = call(cfg.classifier) 19 | 20 | if getattr(cfg.classifier, "ckpt", None): 21 | classifier_ckpt = ckpt_path_adm(cfg.classifier.ckpt, cfg) 22 | logger.info(f"Loading classifier from {classifier_ckpt}..") 23 | classifier.load_state_dict(torch.load(classifier_ckpt, map_location=map_location)) 24 | if classifier is not None: 25 | classifier.cuda(dist.get_rank()) 26 | classifier = DDP(classifier, device_ids=[dist.get_rank()], output_device=[dist.get_rank()],) 27 | 28 | model.cuda(dist.get_rank()) 29 | model = DDP(model, device_ids=[dist.get_rank()], output_device=[dist.get_rank()]) 30 | return model, classifier 31 | -------------------------------------------------------------------------------- /models/classifier_guidance_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from omegaconf import DictConfig 6 | from .diffusion import Diffusion 7 | 8 | 9 | class ClassifierGuidanceModel: 10 | def __init__(self, model: nn.Module, classifier: nn.Module, diffusion: Diffusion, cfg: DictConfig): 11 | self.model = model 12 | self.classifier = classifier 13 | self.diffusion = diffusion 14 | self.cfg = cfg 15 | 16 | def __call__(self, xt, y, t, scale=1.0): 17 | # Returns both the noise value (score function scaled) and the predicted x0. 
18 | alpha_t = self.diffusion.alpha(t).view(-1, 1, 1, 1) 19 | if self.classifier is None: 20 | et = self.model(xt, t)[:, :3] 21 | else: 22 | et = self.model(xt, t, y)[:, :3] 23 | et = et - (1 - alpha_t).sqrt() * self.cond_fn(xt, y, t, scale=scale) 24 | x0_pred = (xt - et * (1 - alpha_t).sqrt()) / alpha_t.sqrt() 25 | return et, x0_pred 26 | 27 | def cond_fn(self, xt, y, t, scale=1.0): 28 | with torch.enable_grad(): 29 | x_in = xt.detach().requires_grad_(True) 30 | logits = self.classifier(x_in, t) 31 | log_probs = F.log_softmax(logits, dim=-1) 32 | selected = log_probs[range(len(logits)), y.view(-1)] 33 | 34 | scale = scale * self.cfg.classifier.classifier_scale 35 | return torch.autograd.grad(selected.sum(), x_in, create_graph=True)[0] * scale -------------------------------------------------------------------------------- /models/diffusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class Diffusion: 6 | def __init__(self, beta_schedule="linear", beta_start=1e-4, beta_end=2e-2, num_diffusion_timesteps=1000, given_betas=None): 7 | from utils.functions import sigmoid 8 | if given_betas is None: 9 | if beta_schedule == "quad": 10 | betas = ( 11 | np.linspace( 12 | beta_start**0.5, 13 | beta_end**0.5, 14 | num_diffusion_timesteps, 15 | dtype=np.float64, 16 | ) 17 | ** 2 18 | ) 19 | elif beta_schedule == "linear": 20 | betas = np.linspace( 21 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 22 | ) 23 | elif beta_schedule == "const": 24 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 25 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 26 | betas = 1.0 / np.linspace( 27 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 28 | ) 29 | elif beta_schedule == "sigmoid": 30 | betas = np.linspace(-6, 6, num_diffusion_timesteps) 31 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start 32 | else: 33 | raise NotImplementedError(beta_schedule) 34 | assert betas.shape == (num_diffusion_timesteps,) 35 | betas = torch.from_numpy(betas) 36 | else: 37 | betas = given_betas 38 | self.betas = torch.cat([torch.zeros(1).to(betas.device), betas], dim=0).cuda().float() 39 | self.alphas = (1 - self.betas).cumprod(dim=0).cuda().float() 40 | self.num_diffusion_timesteps = num_diffusion_timesteps 41 | 42 | def alpha(self, t): 43 | return self.alphas.index_select(0, t+1) 44 | -------------------------------------------------------------------------------- /models/guided_diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/models/guided_diffusion/__init__.py -------------------------------------------------------------------------------- /motionblur/.gitignore: -------------------------------------------------------------------------------- 1 | # Temporary and binary files 2 | *~ 3 | *.py[cod] 4 | *.so 5 | *.cfg 6 | !.isort.cfg 7 | !setup.cfg 8 | *.orig 9 | *.log 10 | *.pot 11 | __pycache__/* 12 | .cache/* 13 | .*.swp 14 | */.ipynb_checkpoints/* 15 | .DS_Store 16 | 17 | # Project files 18 | .ropeproject 19 | .project 20 | .pydevproject 21 | .settings 22 | .idea 23 | tags 24 | 25 | # Package files 26 | *.egg 27 | *.eggs/ 28 | .installed.cfg 29 | *.egg-info 30 | 31 | # Unittest and coverage 32 | htmlcov/* 33 | .coverage 34 | .tox 35 | junit.xml 36 | coverage.xml 37 | .pytest_cache/ 38 | 39 | # Build and docs 
folder/files 40 | build/* 41 | dist/* 42 | sdist/* 43 | docs/api/* 44 | docs/_rst/* 45 | docs/_build/* 46 | cover/* 47 | MANIFEST 48 | 49 | # Per-project virtualenvs 50 | .venv*/ 51 | 52 | # Custom 53 | dataset/* 54 | -------------------------------------------------------------------------------- /motionblur/README.md: -------------------------------------------------------------------------------- 1 | # MotionBlur 2 | 3 | Generate authentic motion blur kernels (point spread functions) and apply them to images en masse. 4 | 5 | Very efficient thanks to numpy's FFT-based convolution and the optimised procedural generation of kernels. Intuitive API. 6 | 7 | # Description 8 | 9 | After installation, import the `Kernel` class from `motionblur.py` and use it to your liking. 10 | 11 | Here is how: 12 | 13 | Initialise a `Kernel` instance with the parameters `size` (size of kernel matrix in pixels - as a tuple of integers) and `intensity`. 14 | 15 | Intensity determines how non-linear and shaken the motion blur is. It must have a value between 0 and 1. 16 | Zero is a linear motion and 1 a highly non-linear and often self-intersecting motion. 17 | 18 | ![Effect of intensity](./intensity.png) 19 | 20 | Once a kernel is initialised, you can utilise a range of properties to make use of it. 21 | 22 | ```python 23 | # Initialise Kernel 24 | kernel = Kernel(size=(100, 100), intensity=0.2) 25 | 26 | # Display kernel 27 | kernel.displayKernel() 28 | 29 | # Get kernel as numpy array 30 | kernel.kernelMatrix 31 | 32 | # Save kernel as image. (Do not show kernel, just save.) 33 | kernel.displayKernel(save_to="./my_file.png", show=False) 34 | 35 | # load image or get image path 36 | image1_path = "./image1.png" 37 | image2 = PIL.Image.open("./image2.png") 38 | 39 | # apply motion blur (returns PIL.Image instance of blurred image) 40 | blurred1 = kernel.applyTo(image1_path) 41 | 42 | blurred2 = kernel.applyTo(image2) 43 | 44 | # if you need the dimension of the blurred image to be the same 45 | # as the original image, pass `keep_image_dim=True` 46 | blurred_same = kernel.applyTo(image2, keep_image_dim=True) 47 | 48 | # show result 49 | blurred1.show() 50 | 51 | # or save to file 52 | blurred2.save("./output2.png", "PNG") 53 | ``` 54 | 55 | 56 | # Installation 57 | 58 | In order to set up the necessary environment: 59 | 60 | 1. create an environment `MotionBlur` with the help of conda, 61 | ``` 62 | conda env create -f environment.yaml 63 | ``` 64 | 2. activate the new environment with 65 | ``` 66 | conda activate MotionBlur 67 | ``` 68 | 69 | Or simply install numpy, pillow and scipy manually.
70 | -------------------------------------------------------------------------------- /motionblur/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/motionblur/__init__.py -------------------------------------------------------------------------------- /motionblur/environment.yaml: -------------------------------------------------------------------------------- 1 | name: MotionBlur 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - python>=3.6 7 | - pip 8 | - numpy 9 | - scipy 10 | - Pillow 11 | 12 | # for development only (could also be kept in a separate environment file) 13 | - pytest 14 | - pytest-cov 15 | - tox 16 | - pre_commit 17 | - nbdime 18 | - nbstripout 19 | - sphinx 20 | - recommonmark 21 | -------------------------------------------------------------------------------- /motionblur/example_kernel/kernel0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/motionblur/example_kernel/kernel0.png -------------------------------------------------------------------------------- /motionblur/example_kernel/kernel100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/motionblur/example_kernel/kernel100.png -------------------------------------------------------------------------------- /motionblur/example_kernel/kernel25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/motionblur/example_kernel/kernel25.png -------------------------------------------------------------------------------- /motionblur/example_kernel/kernel50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/motionblur/example_kernel/kernel50.png -------------------------------------------------------------------------------- /motionblur/example_kernel/kernel75.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/motionblur/example_kernel/kernel75.png -------------------------------------------------------------------------------- /motionblur/images/flag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/motionblur/images/flag.png -------------------------------------------------------------------------------- /motionblur/images/flagBLURRED.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/motionblur/images/flagBLURRED.png -------------------------------------------------------------------------------- /motionblur/images/moon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/motionblur/images/moon.png -------------------------------------------------------------------------------- 
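The README above documents the `Kernel(size=..., intensity=...)` constructor and the `kernelMatrix` property. As a minimal, illustrative sketch (not part of the repository), the generated point spread function could be lifted into a PyTorch depthwise convolution to act as a motion-blur degradation; the import path `motionblur.motionblur` and all tensor shapes here are assumptions, not code from this repo:

```python
# Illustrative only: turn a generated PSF into a depthwise torch convolution.
import torch
import torch.nn.functional as F
from motionblur.motionblur import Kernel  # Kernel class documented in the README above

kernel = Kernel(size=(61, 61), intensity=0.5)
psf = torch.tensor(kernel.kernelMatrix, dtype=torch.float32)   # (61, 61) numpy PSF -> tensor
weight = psf.view(1, 1, 61, 61).repeat(3, 1, 1, 1)             # one copy per RGB channel

x = torch.rand(1, 3, 256, 256)                                 # stand-in image batch
blurred = F.conv2d(x, weight, padding=61 // 2, groups=3)       # depthwise blur, same spatial size
```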
/motionblur/intensity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/motionblur/intensity.png -------------------------------------------------------------------------------- /output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/output.gif -------------------------------------------------------------------------------- /playground/100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/100.png -------------------------------------------------------------------------------- /playground/1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/1000.png -------------------------------------------------------------------------------- /playground/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/20.png -------------------------------------------------------------------------------- /playground/50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/50.png -------------------------------------------------------------------------------- /playground/500.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/500.png -------------------------------------------------------------------------------- /playground/adm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/adm.png -------------------------------------------------------------------------------- /playground/adm0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/adm0.png -------------------------------------------------------------------------------- /playground/adm1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/adm1.png -------------------------------------------------------------------------------- /playground/adm2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/adm2.png -------------------------------------------------------------------------------- /playground/adm3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/adm3.png -------------------------------------------------------------------------------- /playground/adm4.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/adm4.png -------------------------------------------------------------------------------- /playground/adm5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/adm5.png -------------------------------------------------------------------------------- /playground/awd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/awd.png -------------------------------------------------------------------------------- /playground/awd_fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/awd_fwd.png -------------------------------------------------------------------------------- /playground/ddrmpp_res/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/ddrmpp_res/0.png -------------------------------------------------------------------------------- /playground/ddrmpp_res/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/ddrmpp_res/10.png -------------------------------------------------------------------------------- /playground/ddrmpp_res/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/ddrmpp_res/2.png -------------------------------------------------------------------------------- /playground/ddrmpp_res/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/ddrmpp_res/20.png -------------------------------------------------------------------------------- /playground/ddrmpp_res/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/ddrmpp_res/3.png -------------------------------------------------------------------------------- /playground/ddrmpp_res/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/ddrmpp_res/7.png -------------------------------------------------------------------------------- /playground/figures.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "18f8a659-4110-49ec-b766-bddd36ff6732", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%matplotlib inline\n", 11 | "import os\n", 12 | "import sys\n", 13 | "sys.path.append('../')\n", 14 | "\n", 15 | "import numpy as np\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "import torch\n", 18 | "from hydra import 
compose, initialize\n", 19 | "from omegaconf import OmegaConf\n", 20 | "\n", 21 | "from models.diffusion import Diffusion" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "477b7292-35cf-4b7c-a275-50f4298b0ab8", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [] 31 | } 32 | ], 33 | "metadata": { 34 | "kernelspec": { 35 | "display_name": "Python 3 (ipykernel)", 36 | "language": "python", 37 | "name": "python3" 38 | }, 39 | "language_info": { 40 | "codemirror_mode": { 41 | "name": "ipython", 42 | "version": 3 43 | }, 44 | "file_extension": ".py", 45 | "mimetype": "text/x-python", 46 | "name": "python", 47 | "nbconvert_exporter": "python", 48 | "pygments_lexer": "ipython3", 49 | "version": "3.9.7" 50 | } 51 | }, 52 | "nbformat": 4, 53 | "nbformat_minor": 5 54 | } 55 | -------------------------------------------------------------------------------- /playground/jpeg5_deg/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/jpeg5_deg/0.png -------------------------------------------------------------------------------- /playground/jpeg5_deg/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/jpeg5_deg/10.png -------------------------------------------------------------------------------- /playground/jpeg5_deg/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/jpeg5_deg/2.png -------------------------------------------------------------------------------- /playground/jpeg5_deg/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/jpeg5_deg/20.png -------------------------------------------------------------------------------- /playground/jpeg5_deg/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/jpeg5_deg/3.png -------------------------------------------------------------------------------- /playground/jpeg5_deg/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/jpeg5_deg/7.png -------------------------------------------------------------------------------- /playground/jpeg5_ori/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/jpeg5_ori/0.png -------------------------------------------------------------------------------- /playground/jpeg5_ori/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/jpeg5_ori/10.png -------------------------------------------------------------------------------- /playground/jpeg5_ori/2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/jpeg5_ori/2.png -------------------------------------------------------------------------------- /playground/jpeg5_ori/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/jpeg5_ori/20.png -------------------------------------------------------------------------------- /playground/jpeg5_ori/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/jpeg5_ori/3.png -------------------------------------------------------------------------------- /playground/jpeg5_ori/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/jpeg5_ori/7.png -------------------------------------------------------------------------------- /playground/palette_img/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/palette_img/0.jpg -------------------------------------------------------------------------------- /playground/palette_img/10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/palette_img/10.jpg -------------------------------------------------------------------------------- /playground/palette_img/11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/palette_img/11.jpg -------------------------------------------------------------------------------- /playground/palette_img/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/palette_img/2.jpg -------------------------------------------------------------------------------- /playground/palette_img/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/palette_img/3.jpg -------------------------------------------------------------------------------- /playground/palette_img/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/palette_img/7.jpg -------------------------------------------------------------------------------- /playground/palette_result/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/palette_result/0.png -------------------------------------------------------------------------------- /playground/palette_result/10.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/palette_result/10.png -------------------------------------------------------------------------------- /playground/palette_result/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/palette_result/11.png -------------------------------------------------------------------------------- /playground/palette_result/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/palette_result/2.png -------------------------------------------------------------------------------- /playground/palette_result/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/palette_result/20.png -------------------------------------------------------------------------------- /playground/palette_result/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/palette_result/3.png -------------------------------------------------------------------------------- /playground/palette_result/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RED-diff/66482f23e242bb31166c3662002d0a6f9f065030/playground/palette_result/7.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | pillow 4 | cython 5 | matplotlib 6 | tensorboard 7 | tensorboardX 8 | tfrecord 9 | gpustat 10 | tqdm 11 | ipdb 12 | jupyterlab 13 | gdown 14 | hydra-core 15 | clean-fid 16 | wandb 17 | pytorch-lightning 18 | sklearn 19 | hydra-joblib-launcher 20 | hydra_colorlog 21 | python-dotenv 22 | jammy 23 | torchmetrics 24 | einops 25 | jax[cpu] 26 | invisible-watermark 27 | omegaconf 28 | opencv-python==4.5.5.64 29 | opencv-python-headless==4.5.5.64 30 | diffusers 31 | transformers 32 | kornia -------------------------------------------------------------------------------- /run_eval.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. 
All rights reserved 2 | 3 | root=/lustre/fsw/nvresearch/mmardani/output/latent-diffusion-sampling/pgdm 4 | code=$root 5 | dataset=$SHARE_DATA/imagenet-root 6 | subset_txt=./misc/dgp_top1k.txt 7 | #subset_txt=./misc/dgp_top10.txt 8 | #subset_txt=./misc/dgp_top100.txt 9 | 10 | GPUS=1 11 | 12 | EXPDIR=$root/_exp 13 | SOURCE_FID=$EXPDIR/fid_stats/imagenet256_train_mean_std.npz 14 | SOURCE_KID=$EXPDIR/fid_stats/imagenet256_val_dgp_top1k.npy 15 | 16 | MEAN_STD_STATS=False 17 | SPLIT=custom 18 | 19 | 20 | declare -a models=("imagenet256_uncond") #("imagenet256_uncond" "imagenet256_cond") 21 | 22 | # declare -a degs=("sr4" "in2_20ff") # ("deblur_gauss" "deblur_uni") ("deblur_nl" "hdr" "phase_retrieval") 23 | # declare -a algs=("ddrmppvarinf" "ddrm" "ddrmpp" "dps") #"ddrmppvarinf_parallel" 24 | 25 | # declare -a degs=("deblur_gauss" "deblur_uni") 26 | # declare -a algs=("ddrmppvarinf" "ddrm" "ddrmpp" "dps") #"ddrmppvarinf_parallel" 27 | 28 | declare -a degs=("deblur_nl" "hdr" "phase_retrieval") 29 | declare -a algs=("reddiff" "dps") 30 | 31 | 32 | for MODEL in ${models[@]}; do 33 | for DEG in ${degs[@]}; do 34 | for ALGO in ${algs[@]}; do #mcg 35 | 36 | TARGET_DIR=$MODEL/$ALGO/$DEG 37 | arr=($EXPDIR/samples/$TARGET_DIR/*) # This creates an array of the full paths to all subdirs 38 | arr=("${arr[@]##*/}") 39 | 40 | 41 | for DIR in ${arr[@]}; do 42 | 43 | DIR_PARSE=(${DIR//_/ }) 44 | 45 | ETA=${DIR_PARSE[0]} 46 | STEPS=${DIR_PARSE[1]} 47 | GRAD_WEIGHT=${DIR_PARSE[2]} 48 | GRAD_TYPE=${DIR_PARSE[3]} 49 | 50 | 51 | IDX=$DIR 52 | TARGET=$TARGET_DIR/$IDX 53 | 54 | echo "--------------------------------" 55 | echo $GRAD_WEIGHT 56 | echo $GRAD_TYPE 57 | echo $ETA 58 | echo $ALGO 59 | echo $DEG 60 | echo $IDX 61 | echo $TARGET 62 | echo $STEPS 63 | 64 | 65 | cd $code/eval 66 | 67 | #FID 68 | #FID STATS 69 | python fid_stats.py mean_std_stats=True dist.num_processes_per_node=$GPUS save_path=$TARGET/fid exp.root=$EXPDIR dataset.root=$EXPDIR/samples/$TARGET dataset.meta_root=$dataset dataset.split=$SPLIT dataset.subset_txt=$subset_txt 70 | #FID - MEAN_STD_STATS=True 71 | python fid.py path1=$EXPDIR/fid_stats/$TARGET/fid_mean_std.npz path2=$SOURCE_FID results=$TARGET exp.root=$EXPDIR 72 | 73 | #KID 74 | #KID STATS 75 | python fid_stats.py mean_std_stats=False dist.num_processes_per_node=$GPUS save_path=$TARGET/kid_gen_dgp exp.root=$EXPDIR dataset.root=$EXPDIR/samples/$TARGET dataset.meta_root=$dataset dataset.split=$SPLIT dataset.subset_txt=$subset_txt 76 | #python fid_stats.py mean_std_stats=False dist.num_processes_per_node=$GPUS save_path=$SOURCE_KID exp.root=$EXPDIR dataset.root=$dataset/imagenet/val dataset.meta_root=$dataset dataset.split=$SPLIT dataset.subset_txt=$subset_txt 77 | #KID - MEAN_STD_STATS=False 78 | python fid.py path1=$EXPDIR/fid_stats/$TARGET/kid_gen_dgp.npy path2=$SOURCE_KID results=$TARGET exp.root=$EXPDIR 79 | 80 | #PSNR 81 | python psnr.py dist.num_processes_per_node=$GPUS exp.root=$EXPDIR dataset1.root=$dataset dataset2.root=$EXPDIR/samples/$TARGET dataset2.split=custom dataset1.meta_root=$dataset dataset2.meta_root=$dataset results=$TARGET dataset1.subset_txt=$subset_txt dataset2.subset_txt=$subset_txt 82 | 83 | #top1 accuracy 84 | python ca.py dataset.transform=ca_cropped dataset.root=$EXPDIR/samples/$TARGET dataset.meta_root=$dataset dataset.split=${SPLIT} exp.root=$EXPDIR results=$TARGET dataset.subset_txt=$subset_txt 85 | 86 | 87 | done 88 | done 89 | done 90 | done -------------------------------------------------------------------------------- /sample_batch.sh: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved 2 | 3 | root= 4 | 5 | ETA=1.0 6 | STEPS=1000 7 | MODEL=imagenet256 8 | GRAD_WEIGHT=1.0 9 | 10 | 11 | samples_root=$root/_exp/samples #where to save data 12 | save_deg=False 13 | save_ori=False 14 | overwrite=True 15 | smoke_test=1e5 16 | batch_size=20 #50 17 | num_steps=1000 #$STEPS 18 | 19 | 20 | for DEG in sr4 in2_20ff; do 21 | for ALGO in ddrm pgdm reddiff dps; do 22 | 23 | IDX=${ETA}_${STEPS}_${GRAD_WEIGHT} 24 | TARGET=$MODEL/$ALGO/$DEG/$IDX #debug 25 | 26 | # val=`expr $gpu_idx + 1` 27 | # echo "$gpu_idx + 1 : $val" 28 | 29 | echo $ETA 30 | echo $DEG 31 | echo $ALGO 32 | echo $GRAD_WEIGHT 33 | 34 | #sample 35 | python main.py exp.overwrite=$overwrite algo=$ALGO algo.deg=$DEG algo.eta=$ETA exp.num_steps=$num_steps algo.sigma_y=0.0 loader.batch_size=$batch_size exp.seed=3 loader=imagenet256_ddrmpp dist.num_processes_per_node=1 exp.name=$TARGET exp.save_ori=$save_ori exp.save_deg=$save_deg exp.smoke_test=$smoke_test exp.samples_root=$samples_root algo.grad_term_weight=$GRAD_WEIGHT # algo.awd=True 36 | 37 | done 38 | done 39 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | def return_none(): 2 | return None -------------------------------------------------------------------------------- /utils/checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | import requests 4 | from tqdm import tqdm 5 | import gdown 6 | 7 | import torch.distributed as dist 8 | 9 | from .distributed import get_logger 10 | 11 | 12 | URL_MAP = { 13 | "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1", 14 | "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1", 15 | "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1", 16 | "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1", 17 | "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1", 18 | "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1", 19 | "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1", 20 | "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1", 21 | "imagenet_256_uncond": "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt", 22 | "imagenet_256_cond": "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion.pt", 23 | "imagenet_256_classifier": "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_classifier.pt", 24 | "imagenet_512_cond": "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/512x512_diffusion.pt", 25 | "imagenet_512_classifier": "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/512x512_classifier.pt", 26 | "ffhq_256": "https://drive.google.com/uc\?id=117Y6Z6-Hg6TMZVIXMmgYbpZy7QvTXign" 27 | } 28 | CKPT_MAP = { 29 | "cifar10": "diffusion_cifar10_model/model-790000.ckpt", 30 | "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt", 31 | "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt", 32 | "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt", 33 | "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt", 34 | "ema_lsun_cat": 
"ema_diffusion_lsun_cat_model/model-1761000.ckpt", 35 | "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt", 36 | "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt", 37 | "imagenet_256_uncond": "imagenet/256x256_diffusion_uncond.pt", 38 | "imagenet_256_cond": "imagenet/256x256_diffusion.pt", 39 | "imagenet_256_classifier": "imagenet/256x256_classifier.pt", 40 | "imagenet_512_classifier": "imagenet/512x512_classifier.pt", 41 | "imagenet_512_cond": "imagenet/512x512_diffusion.pt", 42 | "ffhq_256": "ffhq/ffhq_10m.pt" 43 | } 44 | MD5_MAP = { 45 | "cifar10": "82ed3067fd1002f5cf4c339fb80c4669", 46 | "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3", 47 | "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c", 48 | "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f", 49 | "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b", 50 | "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558", 51 | "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3", 52 | "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f", 53 | } 54 | 55 | 56 | def download(url, local_path, chunk_size=1024): 57 | if dist.get_rank() == 0: 58 | if 'drive.google.com' in url: 59 | gdown.download(url, local_path, quiet=False) 60 | else: 61 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 62 | with requests.get(url, stream=True) as r: 63 | total_size = int(r.headers.get("content-length", 0)) 64 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 65 | with open(local_path, "wb") as f: 66 | for data in r.iter_content(chunk_size=chunk_size): 67 | if data: 68 | f.write(data) 69 | pbar.update(chunk_size) 70 | dist.barrier() 71 | 72 | 73 | def md5_hash(path): 74 | with open(path, "rb") as f: 75 | content = f.read() 76 | return hashlib.md5(content).hexdigest() 77 | 78 | 79 | def get_ckpt_path(name, root=None, check=False, prefix='exp'): 80 | if 'church_outdoor' in name: 81 | name = name.replace('church_outdoor', 'church') 82 | assert name in URL_MAP 83 | # Modify the path when necessary 84 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.join(prefix, "logs/")) 85 | root = ( 86 | root 87 | if root is not None 88 | else os.path.join(cachedir, "diffusion_models_converted") 89 | ) 90 | path = os.path.join(root, CKPT_MAP[name]) 91 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 92 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 93 | download(URL_MAP[name], path) 94 | md5 = md5_hash(path) 95 | assert md5 == MD5_MAP[name], md5 96 | return path 97 | 98 | 99 | def ckpt_path_adm(name, cfg): 100 | logger = get_logger('ckpt', cfg) 101 | 102 | ckpt_root = os.path.join(cfg.exp.root, cfg.exp.ckpt_root) 103 | ckpt = os.path.join(ckpt_root, CKPT_MAP[name]) 104 | if not os.path.exists(ckpt): 105 | logger.info(URL_MAP[name]) 106 | download(URL_MAP[name], ckpt) 107 | return ckpt -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import traceback 4 | 5 | import numpy as np 6 | import torch 7 | import torch.multiprocessing as mp 8 | import torch.distributed as dist 9 | from omegaconf import OmegaConf 10 | from hydra.core.hydra_config import HydraConfig 11 | 12 | 13 | def init_processes(rank, size, fn, cfg, cwd): 14 | """ Initialize the distributed environment. 
""" 15 | try: 16 | cfg = OmegaConf.create(cfg) 17 | OmegaConf.set_struct(cfg, False) 18 | cfg.cwd = cwd 19 | 20 | os.environ["MASTER_ADDR"] = cfg.dist.master_address 21 | os.environ["MASTER_PORT"] = str(cfg.dist.port) 22 | dist.init_process_group(backend=cfg.dist.backend, init_method="env://", rank=rank, world_size=size) 23 | fn(cfg) 24 | dist.barrier() 25 | dist.destroy_process_group() 26 | except Exception: 27 | logging.error(traceback.format_exc()) 28 | dist.destroy_process_group() 29 | 30 | 31 | def common_init(rank, seed): 32 | # we use different seeds per gpu. But we sync the weights after model initialization. 33 | torch.manual_seed(rank + seed) 34 | np.random.seed(rank + seed) 35 | torch.cuda.manual_seed(rank + seed) 36 | torch.cuda.manual_seed_all(rank + seed) 37 | torch.backends.cudnn.benchmark = True 38 | 39 | 40 | def broadcast_params(params, is_distributed): 41 | if is_distributed: 42 | for param in params: 43 | dist.broadcast(param.data, src=0) 44 | 45 | 46 | def get_logger(name=None, cfg=None): 47 | if dist.get_rank() == 0 or not dist.is_available(): 48 | load_path = os.path.join(cfg.cwd, ".hydra/hydra.yaml") 49 | hydra_conf = OmegaConf.load(load_path) 50 | logging.config.dictConfig(OmegaConf.to_container(hydra_conf.hydra.job_logging, resolve=True)) 51 | return logging.getLogger(name) 52 | 53 | 54 | def get_results_file(cfg, logger): 55 | if dist.get_rank() == 0 or not dist.is_available(): 56 | results_root = os.path.join(cfg.exp.root, 'results') 57 | os.makedirs(results_root, exist_ok=True) 58 | if '/' in cfg.results: 59 | results_dir = '/'.join(cfg.results.split('/')[:-1]) 60 | results_dir = os.path.join(results_root, results_dir) 61 | logger.info(f'Creating directory {results_dir}') 62 | os.makedirs(results_dir, exist_ok=True) 63 | results_file = f'{results_root}/{cfg.results}.yaml' 64 | return results_file 65 | 66 | 67 | def distributed(func): 68 | def wrapper(cfg): 69 | cwd = HydraConfig.get().runtime.output_dir 70 | if cfg.dist.num_processes_per_node < 0: 71 | size = torch.cuda.device_count() 72 | cfg.dist.num_processes_per_node = size 73 | else: 74 | size = cfg.dist.num_processes_per_node 75 | if size > 1: 76 | num_proc_node = cfg.dist.num_proc_node 77 | num_process_per_node = cfg.dist.num_processes_per_node 78 | world_size = num_proc_node * num_process_per_node 79 | mp.spawn( 80 | init_processes, args=(world_size, func, cfg, cwd), nprocs=world_size, join=True, 81 | ) 82 | else: 83 | init_processes(0, size, func, cfg, cwd) 84 | 85 | return wrapper -------------------------------------------------------------------------------- /utils/fft_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.fft import fftshift, ifftshift 3 | 4 | def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 5 | """ 6 | Apply centered 2 dimensional Fast Fourier Transform. 7 | Args: 8 | data: Complex valued input data containing at least 3 dimensions: 9 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 10 | 2. All other dimensions are assumed to be batch dimensions. 11 | norm: Normalization mode. See ``torch.fft.fft``. 12 | Returns: 13 | The FFT of the input. 
14 | """ 15 | if not data.shape[-1] == 2: 16 | raise ValueError("Tensor does not have separate complex dim.") 17 | 18 | data = ifftshift(data, dim=[-3, -2]) 19 | data = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(data), dim=(-2, -1), norm=norm)) 20 | data = fftshift(data, dim=[-3, -2]) 21 | 22 | return data 23 | 24 | 25 | def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 26 | """ 27 | Apply centered 2-dimensional Inverse Fast Fourier Transform. 28 | Args: 29 | data: Complex valued input data containing at least 3 dimensions: 30 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 31 | 2. All other dimensions are assumed to be batch dimensions. 32 | norm: Normalization mode. See ``torch.fft.ifft``. 33 | Returns: 34 | The IFFT of the input. 35 | """ 36 | if not data.shape[-1] == 2: 37 | raise ValueError("Tensor does not have separate complex dim.") 38 | 39 | data = ifftshift(data, dim=[-3, -2]) 40 | data = torch.view_as_real( 41 | torch.fft.ifftn( # type: ignore 42 | torch.view_as_complex(data), dim=(-2, -1), norm=norm 43 | ) 44 | ) 45 | data = fftshift(data, dim=[-3, -2]) 46 | 47 | return data 48 | 49 | 50 | def fft2_m(x): 51 | """ FFT for multi-coil """ 52 | if not torch.is_complex(x): 53 | x = x.type(torch.complex64) 54 | return torch.view_as_complex(fft2c_new(torch.view_as_real(x))) 55 | 56 | 57 | def ifft2_m(x): 58 | """ IFFT for multi-coil """ 59 | if not torch.is_complex(x): 60 | x = x.type(torch.complex64) 61 | return torch.view_as_complex(ifft2c_new(torch.view_as_real(x))) -------------------------------------------------------------------------------- /utils/functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def sigmoid(x): 6 | return 1 / (np.exp(-x) + 1) 7 | 8 | 9 | def postprocess(x): 10 | if type(x) == list: 11 | return [(v + 1) / 2 for v in x] 12 | else: 13 | return (x + 1) / 2 14 | 15 | 16 | def preprocess(x): 17 | return x * 2 - 1 18 | 19 | 20 | def get_timesteps(cfg): 21 | skip = (cfg.exp.start_step - cfg.exp.end_step) // cfg.exp.num_steps 22 | ts = list(range(cfg.exp.end_step, cfg.exp.start_step, skip)) 23 | 24 | return ts 25 | 26 | 27 | def strfdt(dt): 28 | days = dt.days 29 | hours, rem = divmod(dt.seconds, 3600) 30 | minutes, seconds = divmod(rem, 60) 31 | milliseconds, _ = divmod(dt.microseconds, 1000) 32 | 33 | if days > 0: 34 | s = f"{days:3d}-" 35 | else: 36 | s = " " 37 | s += f"{hours:02d}:{minutes:02d}:{seconds:02d}.{milliseconds:03d}" 38 | return s 39 | 40 | 41 | def accuracy(output, target, topk=(1,)): 42 | """Computes the accuracy over the k top predictions for the specified values of k""" 43 | with torch.no_grad(): 44 | maxk = max(topk) 45 | batch_size = target.size(0) 46 | 47 | _, pred = output.topk(maxk, 1, True, True) 48 | pred = pred.t() 49 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 50 | 51 | res = [] 52 | for k in topk: 53 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 54 | res.append(correct_k.mul_(100.0 / batch_size)) 55 | return res -------------------------------------------------------------------------------- /utils/save.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lmdb 3 | import numpy as np 4 | import torch 5 | import torchvision.utils as tvu 6 | import torch.distributed as dist 7 | 8 | 9 | def save_imagenet_result(x, y, info, samples_root, suffix=""): 10 | 11 | if len(x.shape) == 3: 12 | n=1 13 | else: 14 
| n = x.size(0) 15 | 16 | 17 | for i in range(n): 18 | #print('info["class_id"][i]', info["class_id"][i]) 19 | class_dir = os.path.join(samples_root, info["class_id"][i]) 20 | #print('class_dir', class_dir) 21 | os.makedirs(class_dir, exist_ok=True) 22 | for i in range(n): 23 | if len(suffix) > 0: 24 | tvu.save_image(x[i], os.path.join(samples_root, info["class_id"][i], f'{info["name"][i]}_{suffix}.png')) 25 | else: 26 | tvu.save_image(x[i], os.path.join(samples_root, info["class_id"][i], f'{info["name"][i]}.png')) 27 | 28 | dist.barrier() 29 | 30 | 31 | def save_ffhq_result(x, y, info, samples_root, suffix=""): 32 | x_list = [torch.zeros_like(x) for i in range(dist.get_world_size())] 33 | idx = info['index'] 34 | idx_list = [torch.zeros_like(idx) for i in range(dist.get_world_size())] 35 | dist.gather(x, x_list, dst=0) 36 | dist.gather(idx, idx_list, dst=0) 37 | 38 | if len(suffix) == 0: 39 | lmdb_path = f'{samples_root}.lmdb' 40 | else: 41 | lmdb_path = f'{samples_root}_{suffix}.lmdb' 42 | 43 | lmdb_dir = lmdb_path.split('/')[:-1] 44 | if len(lmdb_dir) > 0: 45 | lmdb_dir = '/'.join(lmdb_dir) 46 | os.makedirs(lmdb_dir, exist_ok=True) 47 | 48 | if dist.get_rank() == 0: 49 | x = torch.cat(x_list, dim=0).permute(0, 2, 3, 1).detach().cpu().numpy() 50 | idx = torch.cat(idx_list, dim=0).detach().cpu().numpy() 51 | x = (x * 255.).astype(np.uint8) 52 | n = x.shape[0] 53 | env = lmdb.open(lmdb_path, map_size=int(1e12), readonly=False) 54 | with env.begin(write=True) as txn: 55 | for i in range(n): 56 | xi = x[i].copy() 57 | txn.put(str(int(idx[i])).encode(), xi) 58 | 59 | dist.barrier() 60 | 61 | 62 | def save_result(name, x, y, info, samples_root, suffix=""): 63 | if 'ImageNet' in name: 64 | save_imagenet_result(x, y, info, samples_root, suffix) 65 | elif 'FFHQ' in name: 66 | save_ffhq_result(x, y, info, samples_root, suffix) 67 | --------------------------------------------------------------------------------
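As a final, illustrative sanity check (not part of the repository), the centered, orthonormal FFT helpers `fft2_m` and `ifft2_m` defined in `utils/fft_utils.py` above should round-trip a real-valued batch up to numerical precision; the tensor shape below is arbitrary:

```python
# Illustrative round-trip for fft2_m / ifft2_m from utils/fft_utils.py.
import torch
from utils.fft_utils import fft2_m, ifft2_m

x = torch.rand(2, 3, 64, 64)     # real-valued batch; cast to complex64 inside fft2_m
k = fft2_m(x)                    # centered 2D FFT (ortho norm) -> complex tensor
x_rec = ifft2_m(k).real          # inverse transform; imaginary part is ~0 for real input
assert torch.allclose(x, x_rec, atol=1e-5)
```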