├── snowification ├── diffusion │ ├── model │ │ ├── __init__.py │ │ ├── get_model.py │ │ └── unet_convnext.py │ ├── __init__.py │ ├── get_dataset.py │ └── utils.py ├── Fid │ └── __init__.py ├── training_script.sh ├── testing_script.sh ├── setup.py ├── train.py └── test.py ├── decolor-diffusion ├── diffusion │ ├── model │ │ ├── __init__.py │ │ ├── get_model.py │ │ └── unet_convnext.py │ ├── __init__.py │ └── get_dataset.py ├── Fid │ └── __init__.py ├── setup.py ├── train.py └── test.py ├── deblurring-diffusion-pytorch ├── deblurring_diffusion_pytorch │ ├── cal_Fid.py │ └── __init__.py ├── Fid │ └── __init__.py ├── .github │ └── workflows │ │ └── python-publish.yml ├── .gitignore ├── AFHQ_128.py ├── mnist_train.py ├── dispatch.py ├── cifar10_train.py ├── celebA_128.py ├── mnist_test.py ├── cifar10_test.py ├── AFHQ_128_test.py └── celebA_128_test.py ├── demixing-diffusion-pytorch ├── demixing_diffusion_pytorch │ ├── cal_Fid.py │ └── __init__.py ├── Fid │ └── __init__.py ├── .github │ └── workflows │ │ └── python-publish.yml ├── .gitignore ├── AFHQ_128_to_celebA_128.py ├── AFHQ_128_to_celebA_128_test.py └── dispatch.py ├── denoising-diffusion-pytorch ├── denoising_diffusion_pytorch │ ├── cal_Fid.py │ └── __init__.py ├── Fid │ └── __init__.py ├── AFHQ_noise_128.py ├── celebA_noise_128.py ├── AFHQ_noise_128_test.py └── celebA_noise_128_test.py ├── defading-generation-diffusion-pytorch ├── defading_diffusion_pytorch │ ├── cal_Fid.py │ └── __init__.py ├── Fid │ └── __init__.py ├── celebA_128.py └── celebA_128_test.py ├── all_transform_cover.png ├── defading-diffusion-pytorch ├── Fid │ └── __init__.py ├── defading_diffusion_pytorch │ └── __init__.py ├── bigger_mask_train.sh ├── inpainting_test_paper.sh ├── .github │ └── workflows │ │ └── python-publish.yml ├── LICENSE ├── .gitignore ├── mnist_train.py ├── cifar10_train_wd.py ├── cifar10_train.py ├── celebA_train.py ├── mnist_test.py ├── celebA_test.py ├── cifar10_test.py └── inpainting_paper_train.sh ├── 
resolution-diffusion-pytorch ├── Fid │ └── __init__.py ├── resolution_diffusion_pytorch │ └── __init__.py ├── .github │ └── workflows │ │ └── python-publish.yml ├── runs.sh ├── celebA_128.py ├── .gitignore ├── celebA.py ├── mnist_train.py ├── cifar10_train.py ├── celebA_128_test.py ├── cifar10_test.py └── mnist_test.py ├── .idea ├── misc.xml ├── vcs.xml ├── .gitignore ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml ├── Cold-Diffusion-Models.iml └── modules.xml └── create_data.py /snowification/diffusion/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /decolor-diffusion/diffusion/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /deblurring-diffusion-pytorch/deblurring_diffusion_pytorch/cal_Fid.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demixing-diffusion-pytorch/demixing_diffusion_pytorch/cal_Fid.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /denoising-diffusion-pytorch/denoising_diffusion_pytorch/cal_Fid.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /defading-generation-diffusion-pytorch/defading_diffusion_pytorch/cal_Fid.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /snowification/diffusion/__init__.py: 
-------------------------------------------------------------------------------- 1 | from diffusion.diffusion import GaussianDiffusion, Trainer 2 | -------------------------------------------------------------------------------- /decolor-diffusion/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from diffusion.diffusion import GaussianDiffusion, Trainer 2 | -------------------------------------------------------------------------------- /all_transform_cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arpitbansal297/Cold-Diffusion-Models/HEAD/all_transform_cover.png -------------------------------------------------------------------------------- /decolor-diffusion/Fid/__init__.py: -------------------------------------------------------------------------------- 1 | from Fid.inception import InceptionV3 2 | from Fid.fid_score import calculate_fid_given_samples -------------------------------------------------------------------------------- /snowification/Fid/__init__.py: -------------------------------------------------------------------------------- 1 | from Fid.inception import InceptionV3 2 | from Fid.fid_score import calculate_fid_given_samples -------------------------------------------------------------------------------- /defading-diffusion-pytorch/Fid/__init__.py: -------------------------------------------------------------------------------- 1 | from Fid.inception import InceptionV3 2 | from Fid.fid_score import calculate_fid_given_samples -------------------------------------------------------------------------------- /demixing-diffusion-pytorch/Fid/__init__.py: -------------------------------------------------------------------------------- 1 | from Fid.inception import InceptionV3 2 | from Fid.fid_score import calculate_fid_given_samples -------------------------------------------------------------------------------- 
/denoising-diffusion-pytorch/Fid/__init__.py: -------------------------------------------------------------------------------- 1 | from Fid.inception import InceptionV3 2 | from Fid.fid_score import calculate_fid_given_samples -------------------------------------------------------------------------------- /deblurring-diffusion-pytorch/Fid/__init__.py: -------------------------------------------------------------------------------- 1 | from Fid.inception import InceptionV3 2 | from Fid.fid_score import calculate_fid_given_samples -------------------------------------------------------------------------------- /resolution-diffusion-pytorch/Fid/__init__.py: -------------------------------------------------------------------------------- 1 | from Fid.inception import InceptionV3 2 | from Fid.fid_score import calculate_fid_given_samples -------------------------------------------------------------------------------- /defading-generation-diffusion-pytorch/Fid/__init__.py: -------------------------------------------------------------------------------- 1 | from Fid.inception import InceptionV3 2 | from Fid.fid_score import calculate_fid_given_samples -------------------------------------------------------------------------------- /demixing-diffusion-pytorch/demixing_diffusion_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from demixing_diffusion_pytorch.demixing_diffusion_pytorch import GaussianDiffusion, Unet, Trainer 2 | 3 | -------------------------------------------------------------------------------- /denoising-diffusion-pytorch/denoising_diffusion_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, Unet, Trainer 2 | 3 | -------------------------------------------------------------------------------- 
/defading-generation-diffusion-pytorch/defading_diffusion_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from defading_diffusion_pytorch.defading_diffusion_pytorch import GaussianDiffusion, Unet, Trainer 2 | 3 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /defading-diffusion-pytorch/defading_diffusion_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from defading_diffusion_pytorch.defading_diffusion_gaussian import GaussianDiffusion, Unet, Trainer 2 | from defading_diffusion_pytorch.new_model import Model 3 | -------------------------------------------------------------------------------- /deblurring-diffusion-pytorch/deblurring_diffusion_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from deblurring_diffusion_pytorch.deblurring_diffusion_pytorch import GaussianDiffusion, Unet, Trainer 2 | from deblurring_diffusion_pytorch.Model2 import Model 3 | -------------------------------------------------------------------------------- 
/resolution-diffusion-pytorch/resolution_diffusion_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from resolution_diffusion_pytorch.resolution_diffusion_pytorch import GaussianDiffusion, Unet, Trainer 2 | from resolution_diffusion_pytorch.Model2 import Model 3 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/Cold-Diffusion-Models.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /snowification/training_script.sh: -------------------------------------------------------------------------------- 1 | # Longer step experiments 2 | python train.py --time_steps 20 --forward_process_type 'Decolorization' --dataset_folder --exp_name 'cifar_exp' --decolor_total_remove --decolor_routine 'Linear' 3 | python train.py --time_steps 20 --forward_process_type 'Decolorization' --exp_name 'celeba_exp' --decolor_ema_factor 0.9 --decolor_total_remove --decolor_routine 'Linear' --dataset celebA --dataset_folder --resolution 64 --resume_training 4 | -------------------------------------------------------------------------------- /snowification/testing_script.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/sh 2 | python test.py --time_steps 20 --forward_process_type 'Decolorization' --dataset_folder --exp_name 'cifar_exp' --decolor_total_remove --decolor_routine 'Linear' --sampling_routine x0_step_down --test_type test_paper --order_seed 1 --test_fid 3 | python test.py --time_steps 20 --forward_process_type 'Decolorization' --exp_name 'celeba_exp' --decolor_total_remove --decolor_routine 'Linear' --dataset celebA --dataset_folder --resolution 64 --resume_training --sampling_routine x0_step_down --test_type test_paper --order_seed 1 --test_fid 4 | -------------------------------------------------------------------------------- /defading-diffusion-pytorch/bigger_mask_train.sh: -------------------------------------------------------------------------------- 1 | #python defading-diffusion-pytorch/celebA_train.py --image_size 64 --time_steps 100 --initial_mask 11 --save_folder ./celebA_64x64_100_step_11_init --discrete --sampling_routine x0_step_down --train_steps 700000 --kernel_std 0.1 --fade_routine Random_Incremental --load_path /cmlscratch/eborgnia/cold_diffusion/test_speed_celebA/model.pt 2 | python defading-diffusion-pytorch/celebA_train.py --image_size 128 --time_steps 200 --initial_mask 21 --save_folder ./celebA_128x128_200_step_21_init --discrete --sampling_routine x0_step_down --train_steps 700000 --kernel_std 0.1 --fade_routine Random_Incremental 3 | -------------------------------------------------------------------------------- /defading-diffusion-pytorch/inpainting_test_paper.sh: -------------------------------------------------------------------------------- 1 | #python defading-diffusion-pytorch/cifar10_test.py --time_steps 50 --save_folder ./cifar_inpainting_test --discrete --sampling_routine x0_step_down --blur_std 0.1 --fade_routine Random_Incremental --test_type test_paper_invert_section_images 2 | 3 | #python defading-diffusion-pytorch/celebA_128_test.py 
--load_path /cmlscratch/eborgnia/cold_diffusion/celebA_64x64_100_step_11_init/model.pt --time_steps 100 --save_folder ./celebA_inpainting_test --discrete --sampling_routine x0_step_down --blur_std 0.1 --initial_mask 11 --fade_routine Random_Incremental --test_type test_data 4 | 5 | #python defading-diffusion-pytorch/mnist_test.py --time_steps 50 --save_folder ./mnist_inpainting_test --discrete --sampling_routine x0_step_down --blur_std 0.1 --fade_routine Random_Incremental --test_type test_paper_invert_section_images 6 | 7 | python defading-diffusion-pytorch/celebA_test.py -------------------------------------------------------------------------------- /snowification/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'denoising-diffusion-pytorch', 5 | packages = find_packages(), 6 | version = '0.7.1', 7 | license='MIT', 8 | description = 'Denoising Diffusion Probabilistic Models - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | url = 'https://github.com/lucidrains/denoising-diffusion-pytorch', 12 | keywords = [ 13 | 'artificial intelligence', 14 | 'generative models' 15 | ], 16 | install_requires=[ 17 | 'einops', 18 | 'numpy', 19 | 'pillow', 20 | 'torch', 21 | 'torchvision', 22 | 'tqdm' 23 | ], 24 | classifiers=[ 25 | 'Development Status :: 4 - Beta', 26 | 'Intended Audience :: Developers', 27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 28 | 'License :: OSI Approved :: MIT License', 29 | 'Programming Language :: Python :: 3.6', 30 | ], 31 | ) -------------------------------------------------------------------------------- /decolor-diffusion/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'denoising-diffusion-pytorch', 5 | packages = find_packages(), 6 | version = '0.7.1', 7 | 
license='MIT', 8 | description = 'Denoising Diffusion Probabilistic Models - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | url = 'https://github.com/lucidrains/denoising-diffusion-pytorch', 12 | keywords = [ 13 | 'artificial intelligence', 14 | 'generative models' 15 | ], 16 | install_requires=[ 17 | 'einops', 18 | 'numpy', 19 | 'pillow', 20 | 'torch', 21 | 'torchvision', 22 | 'tqdm' 23 | ], 24 | classifiers=[ 25 | 'Development Status :: 4 - Beta', 26 | 'Intended Audience :: Developers', 27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 28 | 'License :: OSI Approved :: MIT License', 29 | 'Programming Language :: Python :: 3.6', 30 | ], 31 | ) -------------------------------------------------------------------------------- /deblurring-diffusion-pytorch/.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /defading-diffusion-pytorch/.github/workflows/python-publish.yml: 
-------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /demixing-diffusion-pytorch/.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ 
secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /resolution-diffusion-pytorch/.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /defading-diffusion-pytorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission 
notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /resolution-diffusion-pytorch/runs.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | python celebA_128.py --time_steps 4 --resolution_routine 'Incremental_factor_2' --save_folder './celebA_4_steps_fac2_train' 4 | python celebA_test.py --time_steps 4 --train_routine 'Final' --sampling_routine 'x0_step_down' --resolution_routine 'Incremental_factor_2' --save_folder './celebA_test' --load_path './celebA_final_ckpt/model.pt' --test_type 'test_fid_distance_decrease_from_manifold' 5 | 6 | 7 | python mnist_train.py --time_steps 3 --resolution_routine 'Incremental_factor_2' --save_folder './mnist_3_steps_fac2_train' 8 | python mnist_test.py --time_steps 3 --train_routine 'Final' --sampling_routine 'x0_step_down' --resolution_routine 'Incremental_factor_2' --save_folder './mnist_test' --load_path './mnist_final_ckpt/model.pt' --test_type 'test_fid_distance_decrease_from_manifold' 9 | 10 | 11 | python cifar10_train.py --time_steps 3 --resolution_routine 'Incremental_factor_2' --save_folder './cifar10Aug_3_steps_fac2_train' 12 | python cifar10_test.py --time_steps 3 --train_routine 'Final' --sampling_routine 'x0_step_down' --resolution_routine 'Incremental_factor_2' --save_folder './cifar10_test' --load_path './cifar10_final_ckpt/model.pt' --test_type 
'test_fid_distance_decrease_from_manifold' 13 | 14 | 15 | -------------------------------------------------------------------------------- /decolor-diffusion/diffusion/model/get_model.py: -------------------------------------------------------------------------------- 1 | from .unet_convnext import UnetConvNextBlock 2 | from .unet_resnet import UnetResNetBlock 3 | 4 | def get_model(args, with_time_emb=True): 5 | if args.model == 'UnetConvNext': 6 | return UnetConvNextBlock( 7 | dim = 64, 8 | dim_mults = (1, 2, 4, 8), 9 | channels=3, 10 | with_time_emb=with_time_emb, 11 | residual=False, 12 | ) 13 | if args.model == 'UnetResNet': 14 | if 'cifar10' in args.dataset: 15 | return UnetResNetBlock(resolution=32, 16 | in_channels=3, 17 | out_ch=3, 18 | ch=128, 19 | ch_mult=(1,2,2,2), 20 | num_res_blocks=2, 21 | attn_resolutions=(16,), 22 | with_time_emb=with_time_emb, 23 | dropout=0.1 24 | ) 25 | if 'celebA' in args.dataset: 26 | return UnetResNetBlock(resolution=128, 27 | in_channels=3, 28 | out_ch=3, 29 | ch=128, 30 | ch_mult=(1,2,2,2), 31 | num_res_blocks=2, 32 | attn_resolutions=(16,), 33 | with_time_emb=with_time_emb, 34 | dropout=0.1 35 | ) 36 | 37 | -------------------------------------------------------------------------------- /snowification/diffusion/model/get_model.py: -------------------------------------------------------------------------------- 1 | from .unet_convnext import UnetConvNextBlock 2 | from .unet_resnet import UnetResNetBlock 3 | 4 | def get_model(args, with_time_emb=True): 5 | if args.model == 'UnetConvNext': 6 | return UnetConvNextBlock( 7 | dim = 64, 8 | dim_mults = (1, 2, 4, 8), 9 | channels=3, 10 | with_time_emb=with_time_emb, 11 | residual=False, 12 | ) 13 | if args.model == 'UnetResNet': 14 | if 'cifar10' in args.dataset: 15 | return UnetResNetBlock(resolution=32, 16 | in_channels=3, 17 | out_ch=3, 18 | ch=128, 19 | ch_mult=(1,2,2,2), 20 | num_res_blocks=2, 21 | attn_resolutions=(16,), 22 | with_time_emb=with_time_emb, 23 | dropout=0.1 24 | 
) 25 | if 'celebA' in args.dataset: 26 | return UnetResNetBlock(resolution=128, 27 | in_channels=3, 28 | out_ch=3, 29 | ch=128, 30 | ch_mult=(1,2,2,2), 31 | num_res_blocks=2, 32 | attn_resolutions=(16,), 33 | with_time_emb=with_time_emb, 34 | dropout=0.1 35 | ) 36 | 37 | -------------------------------------------------------------------------------- /resolution-diffusion-pytorch/celebA_128.py: -------------------------------------------------------------------------------- 1 | from resolution_diffusion_pytorch import Unet, GaussianDiffusion, Trainer 2 | import torchvision 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--time_steps', default=50, type=int) 7 | parser.add_argument('--train_steps', default=700000, type=int) 8 | parser.add_argument('--save_folder', default='./results_celebA', type=str) 9 | parser.add_argument('--load_path', default=None, type=str) 10 | parser.add_argument('--data_path', default='./root_celebA_128_train_new/', type=str) 11 | parser.add_argument('--resolution_routine', default='Incremental', type=str) 12 | parser.add_argument('--train_routine', default='Final', type=str) 13 | parser.add_argument('--sampling_routine', default='x0_step_down', type=str) 14 | parser.add_argument('--remove_time_embed', action="store_true") 15 | parser.add_argument('--residual', action="store_true") 16 | 17 | args = parser.parse_args() 18 | print(args) 19 | 20 | model = Unet( 21 | dim = 64, 22 | dim_mults = (1, 2, 4, 8), 23 | channels=3, 24 | with_time_emb=not(args.remove_time_embed), 25 | residual=args.residual 26 | ).cuda() 27 | 28 | diffusion = GaussianDiffusion( 29 | model, 30 | image_size = 128, 31 | device_of_kernel = 'cuda', 32 | channels = 3, 33 | timesteps = args.time_steps, # number of steps 34 | loss_type = 'l1', # L1 or L2 35 | resolution_routine=args.resolution_routine, 36 | train_routine = args.train_routine, 37 | sampling_routine = args.sampling_routine 38 | ).cuda() 39 | 40 | import torch 41 | diffusion 
= torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count())) 42 | 43 | trainer = Trainer( 44 | diffusion, 45 | args.data_path, 46 | image_size = 128, 47 | train_batch_size = 32, 48 | train_lr = 2e-5, 49 | train_num_steps = args.train_steps, # total training steps 50 | gradient_accumulate_every = 2, # gradient accumulation steps 51 | ema_decay = 0.995, # exponential moving average decay 52 | fp16 = False, # turn on mixed precision training with apex 53 | results_folder = args.save_folder, 54 | load_path = args.load_path, 55 | dataset = 'celebA' 56 | ) 57 | 58 | trainer.train() 59 | -------------------------------------------------------------------------------- /deblurring-diffusion-pytorch/.gitignore: -------------------------------------------------------------------------------- 1 | # Generation results 2 | results/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /defading-diffusion-pytorch/.gitignore: -------------------------------------------------------------------------------- 1 | # Generation results 2 | results/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /demixing-diffusion-pytorch/.gitignore: -------------------------------------------------------------------------------- 1 | # Generation results 2 | results/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /resolution-diffusion-pytorch/.gitignore: -------------------------------------------------------------------------------- 1 | # Generation results 2 | results/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /resolution-diffusion-pytorch/celebA.py: -------------------------------------------------------------------------------- 1 | #from comet_ml import Experiment 2 | from resolution_diffusion_pytorch import Unet, GaussianDiffusion, Trainer 3 | import torchvision 4 | import os 5 | import errno 6 | import shutil 7 | import argparse 8 | 9 | def create_folder(path): 10 | try: 11 | os.mkdir(path) 12 | except OSError as exc: 13 | if exc.errno != errno.EEXIST: 14 | raise 15 | pass 16 | 17 | def del_folder(path): 18 | try: 19 | shutil.rmtree(path) 20 | except OSError as exc: 21 | pass 22 | 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--time_steps', default=50, type=int) 26 | parser.add_argument('--train_steps', default=700000, type=int) 27 | parser.add_argument('--save_folder', default='./results_celebA', type=str) 28 | parser.add_argument('--load_path', default=None, type=str) 29 | parser.add_argument('--resolution_routine', default='Incremental', type=str) 30 | parser.add_argument('--train_routine', default='Final', type=str) 31 | parser.add_argument('--sampling_routine', default='x0_step_down', type=str) 32 | parser.add_argument('--remove_time_embed', action="store_true") 33 | parser.add_argument('--residual', action="store_true") 34 | 35 | 36 | 37 | args = parser.parse_args() 38 | 
# ====== resolution-diffusion-pytorch/celebA.py (model construction and training) ======
print(args)

# Denoising UNet; time embedding and residual mode are CLI-controlled.
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual,
).cuda()

import torch

# Spread batches across every visible GPU.
model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

# Cold-diffusion wrapper driving the resolution schedule over `time_steps` steps.
diffusion = GaussianDiffusion(
    model,
    image_size=64,
    device_of_kernel='cuda',
    channels=3,
    timesteps=args.time_steps,                    # number of steps
    loss_type='l1',                               # L1 or L2
    resolution_routine=args.resolution_routine,
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine,
).cuda()

# NOTE(review): the CelebA folder is hard-coded here; sibling scripts expose a
# --data_path flag instead — consider the same for this script.
trainer = Trainer(
    diffusion,
    '/cmlscratch/bansal01/CelebA_train_new/train/',
    image_size=64,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,             # total training steps
    gradient_accumulate_every=2,                  # gradient accumulation steps
    ema_decay=0.995,                              # exponential moving average decay
    fp16=False,                                   # mixed precision (apex) switch
    results_folder=args.save_folder,
    load_path=args.load_path,
)

trainer.train()
# ====== snowification/diffusion/get_dataset.py ======
import numpy as np
from torchvision import transforms, utils
from torchvision import datasets


def get_transform(image_size, random_aug=False, resize=False):
    """Build the preprocessing pipeline for `image_size` (an (H, W) tuple).

    Every branch ends by rescaling [0, 1] tensors to [-1, 1].
    NOTE(review): `resize` is only honored in the plain non-augmented branch,
    matching the original behavior — confirm that is intended.
    """
    to_signed = transforms.Lambda(lambda t: (t * 2) - 1)
    if image_size[0] == 64:
        # 64px targets are produced by center-cropping at 128 then downscaling.
        transform_list = [
            transforms.CenterCrop((128, 128)),
            transforms.Resize(image_size),
            transforms.ToTensor(),
            to_signed,
        ]
        # Bug fix: this branch never assigned T, so image_size[0] == 64
        # raised UnboundLocalError at `return T`.
        T = transforms.Compose(transform_list)
    elif not random_aug:
        transform_list = [
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            to_signed,
        ]
        if resize:
            transform_list = [transforms.Resize(image_size)] + transform_list
        T = transforms.Compose(transform_list)
    else:
        # SimCLR-style random augmentation.
        s = 1.0
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        T = transforms.Compose([
            transforms.RandomResizedCrop(size=image_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.ToTensor(),
            to_signed,
        ])

    return T


def get_image_size(name):
    """Return the (H, W) training resolution for a known dataset name.

    Raises:
        ValueError: for an unrecognized name (the original silently returned
        None, deferring the crash to the caller).
    """
    if 'cifar10' in name:
        return (32, 32)
    if 'celebA' in name:
        return (128, 128)
    if 'flower' in name:
        return (128, 128)
    raise ValueError(f'unknown dataset name: {name!r}')


def get_dataset(name, folder, image_size, random_aug=False):
    """Instantiate the torchvision dataset selected by `name`, rooted at `folder`.

    Raises:
        ValueError: for an unrecognized name (the original fell through and
        returned None).
    """
    print(folder)
    if name == 'cifar10_train':
        return datasets.CIFAR10(folder, train=True, transform=get_transform(image_size, random_aug=random_aug))
    if name == 'cifar10_test':
        return datasets.CIFAR10(folder, train=False, transform=get_transform(image_size, random_aug=random_aug))
    if name == 'CelebA_train':
        return datasets.CelebA(folder, split='train', transform=get_transform(image_size, random_aug=random_aug), download=True)
    if name == 'CelebA_test':
        return datasets.CelebA(folder, split='test', transform=get_transform(image_size, random_aug=random_aug))
    if name == 'flower_train':
        return datasets.Flowers102(folder, split='train', transform=get_transform(image_size, random_aug=random_aug, resize=True), download=True)
    if name == 'flower_test':
        return datasets.Flowers102(folder, split='test', transform=get_transform(image_size, random_aug=random_aug, resize=True), download=True)
    raise ValueError(f'unknown dataset name: {name!r}')
# ====== decolor-diffusion/diffusion/get_dataset.py ======
# Kept in sync with snowification/diffusion/get_dataset.py (duplicate module).
import numpy as np
from torchvision import transforms, utils
from torchvision import datasets


def get_transform(image_size, random_aug=False, resize=False):
    """Return the preprocessing pipeline for an (H, W) `image_size` tuple.

    All branches finish by mapping [0, 1] tensors to [-1, 1].
    NOTE(review): `resize` only takes effect in the non-augmented branch,
    as in the original.
    """
    to_signed = transforms.Lambda(lambda t: (t * 2) - 1)
    if image_size[0] == 64:
        transform_list = [
            transforms.CenterCrop((128, 128)),
            transforms.Resize(image_size),
            transforms.ToTensor(),
            to_signed,
        ]
        # Bug fix: T was never assigned on this path, so `return T` raised
        # UnboundLocalError whenever image_size[0] == 64.
        T = transforms.Compose(transform_list)
    elif not random_aug:
        transform_list = [
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            to_signed,
        ]
        if resize:
            transform_list = [transforms.Resize(image_size)] + transform_list
        T = transforms.Compose(transform_list)
    else:
        # SimCLR-style random augmentation.
        s = 1.0
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        T = transforms.Compose([
            transforms.RandomResizedCrop(size=image_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.ToTensor(),
            to_signed,
        ])

    return T


def get_image_size(name):
    """Return the (H, W) training resolution for a known dataset name.

    Raises:
        ValueError: for an unrecognized name (previously returned None).
    """
    if 'cifar10' in name:
        return (32, 32)
    if 'celebA' in name:
        return (128, 128)
    if 'flower' in name:
        return (128, 128)
    raise ValueError(f'unknown dataset name: {name!r}')


def get_dataset(name, folder, image_size, random_aug=False):
    """Instantiate the torchvision dataset selected by `name`, rooted at `folder`.

    Raises:
        ValueError: for an unrecognized name (previously returned None).
    """
    print(folder)
    if name == 'cifar10_train':
        return datasets.CIFAR10(folder, train=True, transform=get_transform(image_size, random_aug=random_aug))
    if name == 'cifar10_test':
        return datasets.CIFAR10(folder, train=False, transform=get_transform(image_size, random_aug=random_aug))
    if name == 'CelebA_train':
        return datasets.CelebA(folder, split='train', transform=get_transform(image_size, random_aug=random_aug), download=True)
    if name == 'CelebA_test':
        return datasets.CelebA(folder, split='test', transform=get_transform(image_size, random_aug=random_aug))
    if name == 'flower_train':
        return datasets.Flowers102(folder, split='train', transform=get_transform(image_size, random_aug=random_aug, resize=True), download=True)
    if name == 'flower_test':
        return datasets.Flowers102(folder, split='test', transform=get_transform(image_size, random_aug=random_aug, resize=True), download=True)
    raise ValueError(f'unknown dataset name: {name!r}')
# ====== demixing-diffusion-pytorch/AFHQ_128_to_celebA_128.py (continued) ======
diffusion = GaussianDiffusion(
    model,
    image_size=128,
    channels=3,
    timesteps=args.time_steps,          # number of steps
    loss_type=args.loss_type,           # L1 or L2
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine
).cuda()

import torch
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    args.data_path_end,     # NOTE(review): end-domain folder first, start-domain second — confirm against Trainer's signature
    args.data_path_start,
    image_size=128,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,   # total training steps
    gradient_accumulate_every=2,        # gradient accumulation steps
    ema_decay=0.995,                    # exponential moving average decay
    fp16=False,                         # mixed precision via apex (off)
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='train'
)

trainer.train()


# ====== defading-diffusion-pytorch/mnist_train.py ======
from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse


def create_folder(path):
    """Create *path* (parents included); an already-existing folder is fine."""
    try:
        # makedirs generalizes the original os.mkdir: nested paths work too.
        os.makedirs(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort recursive delete; a missing folder is not an error."""
    shutil.rmtree(path, ignore_errors=True)


create = 0  # flip to 1 to (re)generate the ./root_mnist image folder

if create:
    trainset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True)
    root = './root_mnist/'
    del_folder(root)
    create_folder(root)

    # One sub-folder per digit label.
    for i in range(10):
        label_root = root + str(i) + '/'
        create_folder(label_root)

    for idx in range(len(trainset)):
        img, label = trainset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='./_step_50_gaussian', type=str)
parser.add_argument('--kernel_std', default=0.1, type=float)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='./root_mnist/', type=str)
parser.add_argument('--fade_routine', default='Random_Incremental', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--discrete', action="store_true")
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")

args = parser.parse_args()
print(args)

model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=1,  # MNIST is grayscale
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=32,
    device_of_kernel='cuda',
    channels=1,
    timesteps=args.time_steps,
    loss_type='l1',
    fade_routine=args.fade_routine,
    kernel_std=args.kernel_std,
    sampling_routine=args.sampling_routine,
    discrete=args.discrete
).cuda()

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size=32,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,
    gradient_accumulate_every=2,
    ema_decay=0.995,
    fp16=False,
    results_folder=args.save_folder,
    load_path=args.load_path,
)

trainer.train()
# ====== defading-generation-diffusion-pytorch/celebA_128.py ======
#from comet_ml import Experiment
from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import torch  # hoisted from mid-script to the import block
import os
import errno
import shutil
import argparse


def create_folder(path):
    """Create *path* (parents included); an already-existing folder is fine."""
    try:
        # makedirs generalizes the original os.mkdir: nested paths work too.
        os.makedirs(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort recursive delete; a missing folder is not an error."""
    shutil.rmtree(path, ignore_errors=True)


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path', default='../deblurring-diffusion-pytorch/root_celebA_128_train_new/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)
parser.add_argument('--initial_mask', default=11, type=int)
parser.add_argument('--kernel_std', default=0.15, type=float)
parser.add_argument('--reverse', action="store_true")

args = parser.parse_args()
print(args)

model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    channels=3,
    timesteps=args.time_steps,          # number of steps
    loss_type=args.loss_type,           # L1 or L2
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine,
    reverse=args.reverse,
    kernel_std=args.kernel_std,
    initial_mask=args.initial_mask
).cuda()

diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size=128,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,   # total training steps
    gradient_accumulate_every=2,        # gradient accumulation steps
    ema_decay=0.995,                    # exponential moving average decay
    fp16=False,                         # mixed precision via apex (off)
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='train'
)

trainer.train()


# ====== defading-diffusion-pytorch/cifar10_train_wd.py (top half) ======
from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse


def create_folder(path):
    """Create *path* (parents included); an already-existing folder is fine."""
    try:
        os.makedirs(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort recursive delete; a missing folder is not an error."""
    shutil.rmtree(path, ignore_errors=True)


create = 0  # flip to 1 to (re)generate the ./root_cifar10 image folder

if create:
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True)
    root = './root_cifar10/'
    del_folder(root)
    create_folder(root)

    # One sub-folder per class label.
    for i in range(10):
        label_root = root + str(i) + '/'
        create_folder(label_root)

    for idx in range(len(trainset)):
        img, label = trainset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
# ====== defading-diffusion-pytorch/cifar10_train_wd.py (continued) ======
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--fade_routine', default='Incremental', type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
# New: the dataset root used to be hard-coded in the Trainer call below;
# same default, consistent with cifar10_train.py's --data_path flag.
parser.add_argument('--data_path', default='./root_cifar10/', type=str)

args = parser.parse_args()
print(args)

model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=32,
    device_of_kernel='cuda',
    channels=3,
    timesteps=args.time_steps,          # number of steps
    loss_type='l1',                     # L1 or L2
    fade_routine=args.fade_routine,
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine
).cuda()

trainer = Trainer(
    diffusion,
    args.data_path,                     # was the literal './root_cifar10/'
    image_size=32,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,   # total training steps
    gradient_accumulate_every=2,        # gradient accumulation steps
    ema_decay=0.995,                    # exponential moving average decay
    fp16=False,                         # mixed precision via apex (off)
    results_folder=args.save_folder,
    load_path=args.load_path
)

trainer.train()


# ====== defading-diffusion-pytorch/cifar10_train.py (imports) ======
from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer, Model
import torchvision
import os
import errno
import shutil
import argparse
def create_folder(path):
    """Create *path* (parents included); an already-existing folder is fine."""
    try:
        # makedirs generalizes the original os.mkdir: nested paths work too.
        os.makedirs(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort recursive delete; a missing folder is not an error."""
    shutil.rmtree(path, ignore_errors=True)


create = 0  # flip to 1 to (re)generate ./root_cifar10 as label-foldered PNGs

if create:
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True)
    root = './root_cifar10/'
    del_folder(root)
    create_folder(root)

    # One sub-folder per class label.
    for i in range(10):
        label_root = root + str(i) + '/'
        create_folder(label_root)

    for idx in range(len(trainset)):
        img, label = trainset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
# NOTE(review): default None flows straight into Trainer(results_folder=...);
# confirm Trainer accepts None before relying on the default.
parser.add_argument('--save_folder', default=None, type=str)
parser.add_argument('--kernel_std', default=0.1, type=float)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='./root_cifar10/', type=str)
parser.add_argument('--fade_routine', default='Random_Incremental', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--discrete', action="store_true")
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")

args = parser.parse_args()
print(args)

# DDPM-style UNet ("Model") rather than the lighter Unet used by the sibling scripts.
model = Model(resolution=32,
              in_channels=3,
              out_ch=3,
              ch=128,
              ch_mult=(1, 2, 2, 2),
              num_res_blocks=2,
              attn_resolutions=(16,),
              dropout=0.1).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=32,
    device_of_kernel='cuda',
    channels=3,
    timesteps=args.time_steps,
    loss_type='l1',
    kernel_std=args.kernel_std,
    fade_routine=args.fade_routine,
    sampling_routine=args.sampling_routine,
    discrete=args.discrete
).cuda()

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size=32,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,
    gradient_accumulate_every=2,
    ema_decay=0.995,
    fp16=False,
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='cifar10'
)

trainer.train()


# ====== resolution-diffusion-pytorch/mnist_train.py (top half) ======
from resolution_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse


def create_folder(path):
    """Create *path* (parents included); an already-existing folder is fine."""
    try:
        os.makedirs(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort recursive delete; a missing folder is not an error."""
    shutil.rmtree(path, ignore_errors=True)


create = 0  # flip to 1 to (re)generate the ./root_mnist image folder

if create:
    trainset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True)
    root = './root_mnist/'
    del_folder(root)
    create_folder(root)

    # One sub-folder per digit label.
    for i in range(10):
        label_root = root + str(i) + '/'
        create_folder(label_root)

    for idx in range(len(trainset)):
        img, label = trainset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='./results_mnist', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='./root_mnist/', type=str)
# ====== resolution-diffusion-pytorch/mnist_train.py (continued) ======
parser.add_argument('--resolution_routine', default='Incremental', type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")

args = parser.parse_args()
print(args)

model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=1,  # MNIST is grayscale
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=32,
    device_of_kernel='cuda',
    channels=1,
    timesteps=args.time_steps,                  # number of steps
    loss_type='l1',                             # L1 or L2
    resolution_routine=args.resolution_routine,
    train_routine=args.train_routine,
    # Bug fix: --sampling_routine was parsed but never forwarded, so the
    # diffusion always used its default routine. celebA.py in this package
    # forwards it; this script now matches.
    sampling_routine=args.sampling_routine
).cuda()

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size=32,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,   # total training steps
    gradient_accumulate_every=2,        # gradient accumulation steps
    ema_decay=0.995,                    # exponential moving average decay
    fp16=False,                         # mixed precision via apex (off)
    results_folder=args.save_folder,
    load_path=args.load_path
)

trainer.train()


# ====== denoising-diffusion-pytorch/AFHQ_noise_128.py (top) ======
#from comet_ml import Experiment
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse


def create_folder(path):
    """Create *path* (parents included); an already-existing folder is fine."""
    try:
        os.makedirs(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
del_folder(path): 18 | try: 19 | shutil.rmtree(path) 20 | except OSError as exc: 21 | pass 22 | 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--time_steps', default=50, type=int, 26 | help="The number of steps the scheduler takes to go from clean image to an isotropic gaussian. This is also the number of steps of diffusion.") 27 | parser.add_argument('--train_steps', default=700000, type=int, 28 | help='The number of iterations for training.') 29 | parser.add_argument('--save_folder', default='./results_cifar10', type=str) 30 | parser.add_argument('--data_path', default='../deblurring-diffusion-pytorch/AFHQ/afhq/train/', type=str) 31 | parser.add_argument('--load_path', default=None, type=str) 32 | parser.add_argument('--train_routine', default='Final', type=str) 33 | parser.add_argument('--sampling_routine', default='default', type=str, 34 | help='The choice of sampling routine for reversing the diffusion process.') 35 | parser.add_argument('--remove_time_embed', action="store_true") 36 | parser.add_argument('--residual', action="store_true") 37 | parser.add_argument('--loss_type', default='l1', type=str) 38 | 39 | 40 | args = parser.parse_args() 41 | print(args) 42 | 43 | 44 | model = Unet( 45 | dim = 64, 46 | dim_mults = (1, 2, 4, 8), 47 | channels=3, 48 | with_time_emb=not(args.remove_time_embed), 49 | residual=args.residual 50 | ).cuda() 51 | 52 | diffusion = GaussianDiffusion( 53 | model, 54 | image_size = 128, 55 | channels = 3, 56 | timesteps = args.time_steps, # number of steps 57 | loss_type = args.loss_type, # L1 or L2 58 | train_routine = args.train_routine, 59 | sampling_routine = args.sampling_routine 60 | ).cuda() 61 | 62 | import torch 63 | diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count())) 64 | 65 | 66 | trainer = Trainer( 67 | diffusion, 68 | '../deblurring-diffusion-pytorch/AFHQ/afhq/train/', 69 | image_size = 128, 70 | train_batch_size = 32, 71 | train_lr = 2e-5, 72 | 
#from comet_ml import Experiment
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse
import torch


def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort recursive delete of *path*; any OSError is ignored."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="The number of steps the scheduler takes to go from clean image to an isotropic gaussian. This is also the number of steps of diffusion.")
parser.add_argument('--train_steps', default=700000, type=int,
                    help='The number of iterations for training.')
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='../deblurring-diffusion-pytorch/root_celebA_128_train_new/', type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str,
                    help='The choice of sampling routine for reversing the diffusion process.')
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)

args = parser.parse_args()
print(args)

model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    channels=3,
    timesteps=args.time_steps,            # number of steps
    loss_type=args.loss_type,             # L1 or L2
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine
).cuda()

# Replicate the diffusion wrapper across every visible GPU.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    # BUG FIX: the data folder was hard-coded here, silently ignoring
    # --data_path.  The flag's default is this exact path, so behavior is
    # unchanged unless the user actually passes --data_path.
    args.data_path,
    image_size=128,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,     # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='train'
)

trainer.train()
from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torch
import os
import errno
import shutil
import argparse


def create_folder(path):
    """mkdir that tolerates an already-existing directory."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """rmtree that re-raises anything other than an EEXIST errno."""
    try:
        shutil.rmtree(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=100, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='test_speed_celebA', type=str)
parser.add_argument('--kernel_std', default=0.1, type=float)
parser.add_argument('--initial_mask', default=11, type=int)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='/cmlscratch/bansal01/spring_2022/Cold-Diffusion/deblurring-diffusion-pytorch/root_celebA_128_train_new/', type=str)
parser.add_argument('--fade_routine', default="Incremental", type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--discrete', action="store_true")
parser.add_argument('--image_size', default=64, type=int)
parser.add_argument('--dataset', default=None, type=str)
args = parser.parse_args()
print(args)

# Denoiser backbone for the defading (mask-fade) diffusion.
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual,
).cuda()

# Cold-diffusion wrapper that fades the image according to --fade_routine.
diffusion = GaussianDiffusion(
    model,
    image_size=args.image_size,
    device_of_kernel='cuda',
    channels=3,
    timesteps=args.time_steps,
    loss_type='l1',
    fade_routine=args.fade_routine,
    kernel_std=args.kernel_std,
    initial_mask=args.initial_mask,
    sampling_routine=args.sampling_routine,
    discrete=args.discrete,
).cuda()

# Spread batches over all available GPUs.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size=args.image_size,
    train_batch_size=100,
    train_lr=2e-5,
    train_num_steps=args.train_steps,
    gradient_accumulate_every=2,
    ema_decay=0.995,
    fp16=False,
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset=args.dataset,
)
trainer.train()
| create_folder(root) 31 | 32 | for i in range(10): 33 | lable_root = root + str(i) + '/' 34 | create_folder(lable_root) 35 | 36 | for idx in range(len(trainset)): 37 | img, label = trainset[idx] 38 | print(idx) 39 | img.save(root + str(label) + '/' + str(idx) + '.png') 40 | 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--time_steps', default=30, type=int) 43 | parser.add_argument('--train_steps', default=700000, type=int) 44 | parser.add_argument('--save_folder', default='./results_cifar10', type=str) 45 | parser.add_argument('--load_path', default=None, type=str) 46 | parser.add_argument('--data_path', default='./root_cifar10/', type=str) 47 | parser.add_argument('--resolution_routine', default='Incremental', type=str) 48 | parser.add_argument('--train_routine', default='Final', type=str) 49 | parser.add_argument('--sampling_routine', default='x0_step_down', type=str) 50 | parser.add_argument('--remove_time_embed', action="store_true") 51 | parser.add_argument('--residual', action="store_true") 52 | 53 | args = parser.parse_args() 54 | print(args) 55 | 56 | 57 | model = Model(resolution=32, 58 | in_channels=3, 59 | out_ch=3, 60 | ch=128, 61 | ch_mult=(1,2,2,2), 62 | num_res_blocks=2, 63 | attn_resolutions=(16,), 64 | dropout=0.1).cuda() 65 | 66 | 67 | diffusion = GaussianDiffusion( 68 | model, 69 | image_size = 32, 70 | device_of_kernel = 'cuda', 71 | channels = 3, 72 | timesteps = args.time_steps, # number of steps 73 | loss_type = 'l1', # L1 or L2 74 | resolution_routine=args.resolution_routine, 75 | train_routine = args.train_routine, 76 | sampling_routine = args.sampling_routine 77 | ).cuda() 78 | 79 | trainer = Trainer( 80 | diffusion, 81 | args.data_path, 82 | image_size = 32, 83 | train_batch_size = 32, 84 | train_lr = 2e-5, 85 | train_num_steps = args.train_steps, # total training steps 86 | gradient_accumulate_every = 2, # gradient accumulation steps 87 | ema_decay = 0.995, # exponential moving average decay 88 | fp16 = False, # turn 
import torchvision
import os
import errno
import shutil
from pathlib import Path
from PIL import Image


def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort recursive delete of *path*; any OSError is ignored."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


def export_labeled_dataset(dataset, root):
    """Dump an (img, label) dataset to root/<label>/<idx>.png.

    The root folder is recreated from scratch with one sub-folder per class
    (labels 0-9).  Deduplicates the four identical copy-pasted export loops
    of the original script; behavior is unchanged.
    """
    del_folder(root)
    create_folder(root)
    for i in range(10):
        create_folder(root + str(i) + '/')
    for idx in range(len(dataset)):
        img, label = dataset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


CelebA_folder = '/fs/cml-datasets/CelebA-HQ/images-128/'  # change this to folder which has CelebA data

############################################# MNIST ###############################################
export_labeled_dataset(
    torchvision.datasets.MNIST(root='./data', train=True, download=True), './root_mnist/')
export_labeled_dataset(
    torchvision.datasets.MNIST(root='./data', train=False, download=True), './root_mnist_test/')

############################################# Cifar10 ###############################################
export_labeled_dataset(
    torchvision.datasets.CIFAR10(root='./data', train=True, download=True), './root_cifar10/')
export_labeled_dataset(
    torchvision.datasets.CIFAR10(root='./data', train=False, download=True), './root_cifar10_test/')

############################################# CelebA ###############################################
root_train = './root_celebA_128_train_new/'
root_test = './root_celebA_128_test_new/'
del_folder(root_train)
create_folder(root_train)
del_folder(root_test)
create_folder(root_test)

exts = ['jpg', 'jpeg', 'png']
folder = CelebA_folder
paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

# 90/10 train/test split by index order.
for idx in range(len(paths)):
    img = Image.open(paths[idx])
    print(idx)
    if idx < 0.9 * len(paths):
        img.save(root_train + str(idx) + '.png')
    else:
        img.save(root_test + str(idx) + '.png')
# CONSISTENCY FIX: every sibling script keeps this optional third-party
# import commented out; `Experiment` was imported here but never used and
# made the script fail on machines without comet_ml installed.
#from comet_ml import Experiment
from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse
import torch


def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort recursive delete of *path*; any OSError is ignored."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="This is the number of steps in which a clean image looses information.")
parser.add_argument('--train_steps', default=700000, type=int,
                    help='The number of iterations for training.')
parser.add_argument('--blur_std', default=0.1, type=float,
                    help='It sets the standard deviation for blur routines which have different meaning based on blur routine.')
parser.add_argument('--blur_size', default=3, type=int,
                    help='It sets the size of gaussian blur used in blur routines for each step t')
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path', default='./AFHQ/afhq/train/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--blur_routine', default='Incremental', type=str,
                    help='This will set the type of blur routine one can use, check the code for what each one of them does in detail')
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str,
                    help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 2')
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)
parser.add_argument('--discrete', action="store_true")

args = parser.parse_args()
print(args)

model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    device_of_kernel='cuda',
    channels=3,
    timesteps=args.time_steps,            # number of steps
    loss_type=args.loss_type,             # L1 or L2
    kernel_std=args.blur_std,
    kernel_size=args.blur_size,
    blur_routine=args.blur_routine,
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine,
    discrete=args.discrete
).cuda()

# Replicate the diffusion wrapper across every visible GPU.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size=128,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,     # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='AFHQ'
)

trainer.train()
#from comet_ml import Experiment
# BUG FIX: this script lives in denoising-diffusion-pytorch, whose package is
# `denoising_diffusion_pytorch` (as the sibling celebA test imports); the
# original `demixing_noise_diffusion_pytorch` does not exist.
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse
import torch


def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort recursive delete of *path*; any OSError is ignored."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="The number of steps the scheduler takes to go from clean image to an isotropic gaussian. This is also the number of steps of diffusion.")
parser.add_argument('--train_steps', default=700000, type=int,
                    help='The number of iterations for training.')
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path', default='../deblurring-diffusion-pytorch/AFHQ/afhq/train/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str,
                    help='The choice of sampling routine for reversing the diffusion process.')
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)
parser.add_argument('--test_type', default='train_data', type=str)
# BUG FIX: args.sample_steps was referenced below but never declared,
# crashing with AttributeError on the default test path.
parser.add_argument('--sample_steps', default=None, type=int,
                    help='Number of sampling steps passed to test_from_data (None = all).')

args = parser.parse_args()
print(args)

# Both branches currently point at the same folder; kept for symmetry with
# the other test scripts.
img_path = None
if 'train' in args.test_type:
    img_path = args.data_path
elif 'test' in args.test_type:
    img_path = args.data_path

model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    channels=3,
    timesteps=args.time_steps,            # number of steps
    loss_type=args.loss_type,             # L1 or L2
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine
).cuda()

# Replicate the diffusion wrapper across every visible GPU.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    img_path,
    image_size=128,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,     # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='train'
)

if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=args.sample_steps)

#### for FID and noise ablation ##
elif args.test_type == 'test_sample_and_save_for_fid':
    trainer.sample_and_save_for_fid()

########## for paper ##########
elif args.test_type == 'train_paper_showing_diffusion_images_cover_page':
    trainer.paper_showing_diffusion_images_cover_page()
parser.add_argument('--time_steps', default=50, type=int, 26 | help="The number of steps the scheduler takes to go from clean image to an isotropic gaussian. This is also the number of steps of diffusion.") 27 | parser.add_argument('--train_steps', default=700000, type=int, 28 | help='The number of iterations for training.') 29 | parser.add_argument('--save_folder', default='./results_cifar10', type=str) 30 | parser.add_argument('--data_path', default='../deblurring-diffusion-pytorch/root_celebA_128_train_new/', type=str) 31 | parser.add_argument('--load_path', default=None, type=str) 32 | parser.add_argument('--train_routine', default='Final', type=str) 33 | parser.add_argument('--sampling_routine', default='default', type=str, 34 | help='The choice of sampling routine for reversing the diffusion process.') 35 | parser.add_argument('--remove_time_embed', action="store_true") 36 | parser.add_argument('--residual', action="store_true") 37 | parser.add_argument('--loss_type', default='l1', type=str) 38 | parser.add_argument('--test_type', default='train_data', type=str) 39 | 40 | 41 | args = parser.parse_args() 42 | print(args) 43 | 44 | img_path=None 45 | if 'train' in args.test_type: 46 | img_path = args.data_path 47 | elif 'test' in args.test_type: 48 | img_path = args.data_path 49 | 50 | model = Unet( 51 | dim = 64, 52 | dim_mults = (1, 2, 4, 8), 53 | channels=3, 54 | with_time_emb=not(args.remove_time_embed), 55 | residual=args.residual 56 | ).cuda() 57 | 58 | diffusion = GaussianDiffusion( 59 | model, 60 | image_size = 128, 61 | channels = 3, 62 | timesteps = args.time_steps, # number of steps 63 | loss_type = args.loss_type, # L1 or L2 64 | train_routine = args.train_routine, 65 | sampling_routine = args.sampling_routine 66 | ).cuda() 67 | 68 | import torch 69 | diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count())) 70 | 71 | 72 | trainer = Trainer( 73 | diffusion, 74 | img_path, 75 | image_size = 128, 76 | train_batch_size = 
#from comet_ml import Experiment
from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse
import torch


def create_folder(path):
    """mkdir that tolerates an already-existing directory."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort rmtree; all OSErrors are swallowed."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='../deblurring-diffusion-pytorch/root_celebA_128_train_new/', type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)
parser.add_argument('--initial_mask', default=11, type=int)
parser.add_argument('--kernel_std', default=0.15, type=float)
parser.add_argument('--reverse', action="store_true")
parser.add_argument('--test_type', default='train_data', type=str)
parser.add_argument('--noise', default=0, type=float)

args = parser.parse_args()
print(args)

# Any recognized test type reads from --data_path; otherwise no folder is set.
if 'train' in args.test_type or 'test' in args.test_type:
    img_path = args.data_path
else:
    img_path = None

model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual,
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    channels=3,
    timesteps=args.time_steps,            # number of degradation steps
    loss_type=args.loss_type,             # L1 or L2
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine,
    reverse=args.reverse,
    kernel_std=args.kernel_std,
    initial_mask=args.initial_mask,
).cuda()

# Spread evaluation batches over all available GPUs.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    img_path,
    image_size=128,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,     # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # apex mixed precision
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='train',
)

mode = args.test_type
if mode == 'train_data':
    trainer.test_from_data('train', s_times=None)
elif mode == 'test_data':
    trainer.test_from_data('test', s_times=None)
#### for FID and noise ablation ##
elif mode == 'test_sample_and_save_for_fid':
    trainer.sample_and_save_for_fid(args.noise)
########## for paper ##########
elif mode == 'train_paper_showing_diffusion_images_cover_page':
    trainer.paper_showing_diffusion_images_cover_page()
#from comet_ml import Experiment
from demixing_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse
import torch


def create_folder(path):
    """mkdir that tolerates an already-existing directory."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort rmtree; all OSErrors are swallowed."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path_start', default='../deblurring-diffusion-pytorch/AFHQ/afhq/train/', type=str)
parser.add_argument('--data_path_end', default='../deblurring-diffusion-pytorch/root_celebA_128_train_new/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)
parser.add_argument('--test_type', default='train_data', type=str)
parser.add_argument('--noise', default=0, type=float)

args = parser.parse_args()
print(args)

# Any recognized test type starts from the AFHQ side of the mixing.
if 'train' in args.test_type or 'test' in args.test_type:
    img_path = args.data_path_start
else:
    img_path = None

model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual,
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    channels=3,
    timesteps=args.time_steps,            # number of degradation steps
    loss_type=args.loss_type,             # L1 or L2
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine,
).cuda()

# Spread evaluation batches over all available GPUs.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    args.data_path_end,                   # target-domain folder
    img_path,                             # source-domain folder
    image_size=128,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,     # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # apex mixed precision
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='train',
)

mode = args.test_type
if mode == 'train_data':
    trainer.test_from_data('train', s_times=None)
elif mode == 'test_data':
    trainer.test_from_data('test', s_times=None)
#### for FID and noise ablation ##
elif mode == 'test_sample_and_save_for_fid':
    trainer.sample_and_save_for_fid(args.noise)
########## for paper ##########
elif mode == 'train_paper_showing_diffusion_images_cover_page':
    trainer.paper_showing_diffusion_images_cover_page()
elif mode == 'test_paper_showing_diffusion_images_cover_page':
    trainer.paper_showing_diffusion_images_cover_page()
# ------------------------------------------------------------------
# deblurring-diffusion-pytorch/mnist_train.py
# Train a deblurring cold-diffusion model on MNIST (32x32, 1 channel).
# ------------------------------------------------------------------
from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import torch
import os
import errno
import shutil
import argparse


def create_folder(path):
    """Create directory *path*; silently succeed if it already exists."""
    try:
        # makedirs(exist_ok=True) also creates missing parents, unlike os.mkdir.
        os.makedirs(path, exist_ok=True)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort recursive removal of *path*; a missing path is not an error."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


# Flip to 1/True to (re)export the MNIST training set as PNG files grouped
# into one folder per digit label under ./root_mnist/ before training.
create = 0

if create:
    trainset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True)
    root = './root_mnist/'
    del_folder(root)
    create_folder(root)

    # One sub-folder per digit class 0-9.
    for i in range(10):
        create_folder(root + str(i) + '/')

    for idx in range(len(trainset)):
        img, label = trainset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="This is the number of steps in which a clean image loses information.")
parser.add_argument('--train_steps', default=700000, type=int,
                    help='The number of iterations for training.')
parser.add_argument('--blur_std', default=0.1, type=float,
                    help='It sets the standard deviation for blur routines which have different meaning based on blur routine.')
parser.add_argument('--blur_size', default=3, type=int,
                    help='It sets the size of gaussian blur used in blur routines for each step t')
parser.add_argument('--save_folder', default='./results_mnist', type=str)
parser.add_argument('--data_path', default='./root_mnist/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--blur_routine', default='Incremental', type=str,
                    help='This will set the type of blur routine one can use, check the code for what each one of them does in detail')
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str,
                    help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 2')
parser.add_argument('--discrete', action="store_true")

args = parser.parse_args()
print(args)


# MNIST is single-channel; Unet width doubles per level via dim_mults.
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=1
).cuda()


diffusion = GaussianDiffusion(
    model,
    image_size=32,
    device_of_kernel='cuda',
    channels=1,
    timesteps=args.time_steps,            # number of steps
    loss_type='l1',                       # L1 or L2
    kernel_std=args.blur_std,
    kernel_size=args.blur_size,
    blur_routine=args.blur_routine,
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine,
    discrete=args.discrete
).cuda()

# Spread batches over every visible GPU.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size=32,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,     # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='mnist'
)

trainer.train()
| parser.add_argument('--bs', default=32, type=int) 34 | parser.add_argument('--discrete', action="store_true") 35 | parser.add_argument('--noise', default=0, type=float) 36 | 37 | args = parser.parse_args() 38 | print(args) 39 | 40 | 41 | img_path=None 42 | if 'train' in args.test_type: 43 | img_path = args.data_path 44 | elif 'test' in args.test_type: 45 | img_path = args.data_path 46 | print("Img Path is ", img_path) 47 | 48 | 49 | model = Unet( 50 | dim = 64, 51 | dim_mults = (1, 2, 4, 8), 52 | channels=3, 53 | with_time_emb=not(args.remove_time_embed), 54 | residual=args.residual 55 | ).cuda() 56 | 57 | 58 | diffusion = GaussianDiffusion( 59 | model, 60 | image_size = 128, 61 | device_of_kernel = 'cuda', 62 | channels = 3, 63 | timesteps = args.time_steps, # number of steps 64 | loss_type = 'l1', # L1 or L2 65 | resolution_routine=args.resolution_routine, 66 | train_routine=args.train_routine, 67 | sampling_routine = args.sampling_routine 68 | ).cuda() 69 | 70 | import torch 71 | diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count())) 72 | 73 | trainer = Trainer( 74 | diffusion, 75 | img_path, 76 | image_size = 128, 77 | train_batch_size = 32, 78 | train_lr = 2e-5, 79 | train_num_steps = 700000, # total training steps 80 | gradient_accumulate_every = 2, # gradient accumulation steps 81 | ema_decay = 0.995, # exponential moving average decay 82 | fp16 = False, # turn on mixed precision training with apex 83 | results_folder = args.save_folder, 84 | load_path = args.load_path, 85 | shuffle=False, 86 | dataset = 'celebA' 87 | ) 88 | 89 | if args.test_type == 'train_data': 90 | trainer.test_from_data('train', s_times=args.sample_steps) 91 | 92 | elif args.test_type == 'test_fid_distance_decrease_from_manifold': 93 | trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None) 94 | 95 | elif args.test_type == 'train_fid_distance_decrease_from_manifold': 96 | 
# ------------------------------------------------------------------
# defading-diffusion-pytorch/mnist_test.py
# Evaluate a trained defading cold-diffusion model on MNIST.
# ------------------------------------------------------------------
from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
from Fid import calculate_fid_given_samples
import torchvision
import os
import errno
import shutil
import argparse


def create_folder(path):
    """Create directory *path*; silently succeed if it already exists."""
    try:
        # makedirs(exist_ok=True) also creates missing parents, unlike os.mkdir.
        os.makedirs(path, exist_ok=True)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort recursive removal of *path*; a missing path is not an error."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


# Flip to 1/True to (re)export the MNIST *test* split as per-label PNG
# folders under ./root_mnist_test/ before evaluating.
create = 0

if create:
    trainset = torchvision.datasets.MNIST(
        root='./data', train=False, download=True)
    root = './root_mnist_test/'
    del_folder(root)
    create_folder(root)

    # One sub-folder per digit class 0-9.
    for i in range(10):
        create_folder(root + str(i) + '/')

    for idx in range(len(trainset)):
        img, label = trainset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--sample_steps', default=None, type=int)
parser.add_argument('--kernel_std', default=0.1, type=float)
parser.add_argument('--save_folder', default='progression_mnist', type=str)
parser.add_argument('--load_path', default='/cmlscratch/eborgnia/cold_diffusion/paper_defading_random_mnist_1/model.pt', type=str)
parser.add_argument('--data_path', default='./root_mnist_test/', type=str)
parser.add_argument('--test_type', default='test_fid_distance_decrease_from_manifold', type=str)
parser.add_argument('--fade_routine', default='Random_Incremental', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--discrete', action="store_true")
parser.add_argument('--residual', action="store_true")
args = parser.parse_args()
print(args)

# Both 'train*' and 'test*' modes read images from the same --data_path;
# any other test_type leaves img_path as None (previously two identical
# if/elif branches assigned the same value).
img_path = args.data_path if ('train' in args.test_type or 'test' in args.test_type) else None

print("Img Path is ", img_path)


model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=1
).cuda()


diffusion = GaussianDiffusion(
    model,
    image_size=32,
    device_of_kernel='cuda',
    channels=1,
    timesteps=args.time_steps,            # number of steps
    loss_type='l1',                       # L1 or L2
    kernel_std=args.kernel_std,
    fade_routine=args.fade_routine,
    sampling_routine=args.sampling_routine,
    discrete=args.discrete
).cuda()

trainer = Trainer(
    diffusion,
    img_path,
    image_size=32,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=700000,               # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path
)

# Dispatch the requested evaluation; unknown test_type values do nothing.
if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=args.sample_steps)

elif args.test_type == 'mixup_train_data':
    trainer.test_with_mixup('train')

elif args.test_type == 'mixup_test_data':
    trainer.test_with_mixup('test')

elif args.test_type == 'test_random':
    trainer.test_from_random('random')

elif args.test_type == 'test_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

elif args.test_type == 'test_paper_invert_section_images':
    trainer.paper_invert_section_images()

elif args.test_type == 'test_paper_showing_diffusion_images_diff':
    trainer.paper_showing_diffusion_images()
Default: file name") 19 | parser.add_argument("--gpus", default="1", type=int, help="Requested GPUs per job") 20 | parser.add_argument("--mem", default="64", type=int, help="Requested memory per job") 21 | parser.add_argument("--time", default=10, type=int, help="Requested hour limit per job") 22 | args = parser.parse_args() 23 | 24 | # get username for use in email and squeue below (do crazy things for arjun) 25 | cmluser_str = getpass.getuser() 26 | username = cmluser_str 27 | 28 | # Parse and validate input: 29 | if args.name is None: 30 | dispatch_name = args.file.name 31 | else: 32 | dispatch_name = args.name 33 | 34 | # Usage warnings: 35 | if args.mem > 385: 36 | raise ValueError("Maximal node memory exceeded.") 37 | if args.gpus > 8: 38 | raise ValueError("Maximal node GPU number exceeded.") 39 | if args.qos == "high" and args.gpus > 4: 40 | warnings.warn("QOS only allows for 4 GPUs, GPU request has been reduced to 4.") 41 | args.gpus = 4 42 | if args.qos == "medium" and args.gpus > 2: 43 | warnings.warn("QOS only allows for 2 GPUs, GPU request has been reduced to 2.") 44 | args.gpus = 2 45 | if args.qos == "default" and args.gpus > 1: 46 | warnings.warn("QOS only allows for 1 GPU, GPU request has been reduced to 1.") 47 | args.gpus = 1 48 | if args.mem / args.gpus > 48: 49 | warnings.warn("You are oversubscribing to memory. 
" 50 | "This might leave some GPUs idle as total node memory is consumed.") 51 | 52 | # 1) Stripping file of comments and blank lines 53 | content = args.file.readlines() 54 | jobs = [c.strip().split("#", 1)[0] for c in content if "python" in c and c[0] != "#"] 55 | 56 | print(f"Detected {len(jobs)} jobs.") 57 | 58 | # Write file list 59 | authkey = random.randint(10**5, 10**6 - 1) 60 | with open(f".cml_job_list_{authkey}.temp.sh", "w") as file: 61 | file.writelines(chr(10).join(job for job in jobs)) 62 | file.write("\n") 63 | 64 | # 2) Prepare environment 65 | if not os.path.exists("cmllogs"): 66 | os.makedirs("cmllogs") 67 | 68 | # 3) Construct launch file 69 | SBATCH_PROTOTYPE = \ 70 | f"""#!/bin/bash 71 | # Lines that begin with #SBATCH specify commands to be used by SLURM for scheduling 72 | #SBATCH --job-name={"".join(e for e in dispatch_name if e.isalnum())} 73 | #SBATCH --array={f"1-{len(jobs)}"} 74 | #SBATCH --output=cmllogs/%x_%A_%a.log 75 | #SBATCH --error=cmllogs/%x_%A_%a.log 76 | #SBATCH --time={args.time}:00:00 77 | #SBATCH --account={"tomg" if args.qos != "scav" else "scavenger"} 78 | #SBATCH --qos={args.qos if args.qos != "scav" else "scavenger"} 79 | #SBATCH --gres=gpu:{args.gpus} 80 | #SBATCH --cpus-per-task=4 81 | #SBATCH --partition={"dpart" if args.qos != "scav" else "scavenger"} 82 | #SBATCH --mem={args.mem}gb 83 | #SBATCH --mail-user={username}@umd.edu 84 | #SBATCH --mail-type=END,TIME_LIMIT,FAIL,ARRAY_TASKS 85 | #SBATCH --exclude=cmlgrad05,cmlgrad02,cml12,cml17,cml18,cml19,cml20,cml21,cml22,cml23,cml24 86 | srun $(head -n $((${{SLURM_ARRAY_TASK_ID}} + 0)) .cml_job_list_{authkey}.temp.sh | tail -n 1) 87 | """ 88 | 89 | # Write launch commands to file 90 | with open(f".cml_launch_{authkey}.temp.sh", "w") as file: 91 | file.write(SBATCH_PROTOTYPE) 92 | print("Launch prototype is ...") 93 | print("---------------") 94 | print(SBATCH_PROTOTYPE) 95 | print("---------------") 96 | print(chr(10).join("srun " + job for job in jobs)) 97 | 
"""dispatch.py
for dispatching jobs on CML
July 2021

Reads a shell script of `python ...` commands, writes them to a temp job
list, and submits one SLURM array task per command via sbatch.
"""
import argparse
import getpass
import os
import random
import subprocess
import time
import warnings

parser = argparse.ArgumentParser(description="Dispatch python jobs on the CML cluster")
parser.add_argument("file", type=argparse.FileType())
parser.add_argument("--qos", default="scav", type=str,
                    help="QOS, choose default, medium, high, scav")
parser.add_argument("--name", default=None, type=str,
                    help="Name that will be displayed in squeue. Default: file name")
# int defaults (the former string defaults "1"/"64" relied on argparse
# re-parsing them through type=int).
parser.add_argument("--gpus", default=1, type=int, help="Requested GPUs per job")
parser.add_argument("--mem", default=64, type=int, help="Requested memory per job")
parser.add_argument("--time", default=10, type=int, help="Requested hour limit per job")
args = parser.parse_args()

# get username for use in email and squeue below (do crazy things for arjun)
cmluser_str = getpass.getuser()
username = cmluser_str

# Parse and validate input:
if args.name is None:
    dispatch_name = args.file.name
else:
    dispatch_name = args.name

# Usage warnings: clamp GPU requests to what each QOS permits.
if args.mem > 385:
    raise ValueError("Maximal node memory exceeded.")
if args.gpus > 8:
    raise ValueError("Maximal node GPU number exceeded.")
if args.qos == "high" and args.gpus > 4:
    warnings.warn("QOS only allows for 4 GPUs, GPU request has been reduced to 4.")
    args.gpus = 4
if args.qos == "medium" and args.gpus > 2:
    warnings.warn("QOS only allows for 2 GPUs, GPU request has been reduced to 2.")
    args.gpus = 2
if args.qos == "default" and args.gpus > 1:
    warnings.warn("QOS only allows for 1 GPU, GPU request has been reduced to 1.")
    args.gpus = 1
if args.mem / args.gpus > 48:
    warnings.warn("You are oversubscribing to memory. "
                  "This might leave some GPUs idle as total node memory is consumed.")

# 1) Stripping file of comments and blank lines.
# A line is a job only if it still contains a command after dropping any
# trailing "#" comment.  The previous one-liner let indented comment lines
# such as "    # python ..." through, which produced empty jobs that
# inflated the SLURM array size and emitted blank `srun` commands.
content = args.file.readlines()
args.file.close()  # argparse.FileType leaves the handle open otherwise
jobs = []
for line in content:
    stripped = line.strip()
    if stripped.startswith("#") or "python" not in stripped:
        continue
    command = stripped.split("#", 1)[0].strip()
    if command:
        jobs.append(command)

print(f"Detected {len(jobs)} jobs.")

# Write file list; the random authkey keeps concurrent dispatches separate.
authkey = random.randint(10**5, 10**6 - 1)
with open(f".cml_job_list_{authkey}.temp.sh", "w") as file:
    file.writelines("\n".join(jobs))
    file.write("\n")

# 2) Prepare environment
if not os.path.exists("cmllogs"):
    os.makedirs("cmllogs")

# 3) Construct launch file
SBATCH_PROTOTYPE = \
    f"""#!/bin/bash

# Lines that begin with #SBATCH specify commands to be used by SLURM for scheduling
#SBATCH --job-name={"".join(e for e in dispatch_name if e.isalnum())}
#SBATCH --array={f"1-{len(jobs)}"}
#SBATCH --output=cmllogs/%x_%A_%a.log
#SBATCH --error=cmllogs/%x_%A_%a.log
#SBATCH --time={args.time}:00:00
#SBATCH --account={"tomg" if args.qos != "scav" else "scavenger"}
#SBATCH --qos={args.qos if args.qos != "scav" else "scavenger"}
#SBATCH --gres=gpu:{args.gpus}
#SBATCH --cpus-per-task=4
#SBATCH --partition={"dpart" if args.qos != "scav" else "scavenger"}
#SBATCH --mem={args.mem}gb
#SBATCH --mail-user={username}@umd.edu
#SBATCH --mail-type=END,TIME_LIMIT,FAIL,ARRAY_TASKS
#SBATCH --exclude=cmlgrad05,cmlgrad02,cml12,cml17,cml18,cml19,cml20,cml21,cml22,cml23,cml24

srun $(head -n $((${{SLURM_ARRAY_TASK_ID}} + 0)) .cml_job_list_{authkey}.temp.sh | tail -n 1)
"""

# Write launch commands to file
with open(f".cml_launch_{authkey}.temp.sh", "w") as file:
    file.write(SBATCH_PROTOTYPE)
print("Launch prototype is ...")
print("---------------")
print(SBATCH_PROTOTYPE)
print("---------------")
print("\n".join("srun " + job for job in jobs))

print(f"Preparing {len(jobs)} jobs ")
print("Terminate if necessary ...")
# Grace period so the user can Ctrl-C before anything is submitted.
for _ in range(10):
    time.sleep(1)

# Execute file with sbatch
subprocess.run(["/usr/bin/sbatch", f".cml_launch_{authkey}.temp.sh"])
print("Subprocess launched ...")
time.sleep(1)
os.system(f"watch squeue -u {cmluser_str}")
action="store_true") 38 | parser.add_argument('--discrete', action="store_true") 39 | parser.add_argument('--residual', action="store_true") 40 | parser.add_argument('--image_size', default=128, type=int) 41 | parser.add_argument('--dataset', default=None, type=str) 42 | args = parser.parse_args() 43 | print(args) 44 | 45 | img_path=None 46 | if 'train' in args.test_type: 47 | img_path = args.data_path 48 | elif 'test' in args.test_type: 49 | img_path = args.data_path 50 | 51 | model = Unet( 52 | dim=64, 53 | dim_mults=(1, 2, 4, 8), 54 | channels=3, 55 | with_time_emb=not args.remove_time_embed, 56 | residual=args.residual 57 | ).cuda() 58 | 59 | # model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) 60 | 61 | diffusion = GaussianDiffusion( 62 | model, 63 | image_size=args.image_size, 64 | device_of_kernel='cuda', 65 | channels=3, 66 | timesteps=args.time_steps, # number of steps 67 | loss_type='l1', 68 | kernel_std=args.kernel_std, # L1 or L2 69 | fade_routine=args.fade_routine, 70 | sampling_routine=args.sampling_routine, 71 | discrete=args.discrete, 72 | initial_mask=args.initial_mask 73 | ).cuda() 74 | 75 | diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count())) 76 | 77 | trainer = Trainer( 78 | diffusion, 79 | img_path, 80 | image_size=args.image_size, 81 | train_batch_size=32, 82 | train_lr=2e-5, 83 | train_num_steps=700000, # total training steps 84 | gradient_accumulate_every=2, # gradient accumulation steps 85 | ema_decay=0.995, # exponential moving average decay 86 | fp16=False, # turn on mixed precision training with apex 87 | results_folder=args.save_folder, 88 | load_path=args.load_path, 89 | dataset=args.dataset 90 | ) 91 | 92 | if args.test_type == 'train_data': 93 | trainer.test_from_data('train', s_times=args.sample_steps) 94 | 95 | elif args.test_type == 'test_data': 96 | trainer.test_from_data('test', s_times=args.sample_steps) 97 | 98 | elif args.test_type == 'mixup_train_data': 99 | 
# ------------------------------------------------------------------
# defading-diffusion-pytorch/cifar10_test.py
# Run evaluation routines for a defading cold-diffusion model on CIFAR-10.
# ------------------------------------------------------------------
from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer, Model
from Fid import calculate_fid_given_samples
import torchvision
import os
import errno
import shutil
import argparse


def create_folder(path):
    """Make directory *path*, tolerating the case where it already exists."""
    try:
        os.mkdir(path)
    except OSError as err:
        if err.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Remove directory tree *path*; ignore failures (e.g. path missing)."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


# Set to 1/True to (re)export the CIFAR-10 test split as per-label PNG
# folders under ./root_cifar10_test/ before evaluating.
create = 0

if create:
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True)
    root = './root_cifar10_test/'
    del_folder(root)
    create_folder(root)

    # One folder per class index 0-9.
    for class_idx in range(10):
        create_folder(root + str(class_idx) + '/')

    for sample_idx in range(len(trainset)):
        img, label = trainset[sample_idx]
        print(sample_idx)
        img.save(root + str(label) + '/' + str(sample_idx) + '.png')


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--sample_steps', default=None, type=int)
parser.add_argument('--kernel_std', default=0.1, type=float)
parser.add_argument('--save_folder', default='progression_cifar', type=str)
parser.add_argument('--load_path', default='/cmlscratch/eborgnia/cold_diffusion/paper_defading_random_1/model.pt', type=str)
parser.add_argument('--data_path', default='./root_cifar10_test/', type=str)
parser.add_argument('--test_type', default='test_paper_showing_diffusion_images_diff', type=str)
parser.add_argument('--fade_routine', default='Random_Incremental', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--discrete', action="store_true")
parser.add_argument('--residual', action="store_true")

args = parser.parse_args()
print(args)

# 'train*' and 'test*' modes both read from --data_path; anything else
# leaves img_path unset.
img_path = None
if 'train' in args.test_type or 'test' in args.test_type:
    img_path = args.data_path

print("Img Path is ", img_path)

# DDPM-style backbone (not the Unet used by other scripts in this package).
model = Model(resolution=32,
              in_channels=3,
              out_ch=3,
              ch=128,
              ch_mult=(1, 2, 2, 2),
              num_res_blocks=2,
              attn_resolutions=(16,),
              dropout=0.1).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=32,
    device_of_kernel='cuda',
    channels=3,
    timesteps=args.time_steps,            # number of steps
    loss_type='l1',                       # L1 or L2
    kernel_std=args.kernel_std,
    fade_routine=args.fade_routine,
    sampling_routine=args.sampling_routine,
    discrete=args.discrete
).cuda()

trainer = Trainer(
    diffusion,
    img_path,
    image_size=32,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=700000,               # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path
)

# Map each recognized test_type to its trainer call; unrecognized values
# fall through and run nothing (same as the original elif chain).
dispatch = {
    'train_data': lambda: trainer.test_from_data('train', s_times=args.sample_steps),
    'test_data': lambda: trainer.test_from_data('test', s_times=args.sample_steps),
    'mixup_train_data': lambda: trainer.test_with_mixup('train'),
    'mixup_test_data': lambda: trainer.test_with_mixup('test'),
    'test_random': lambda: trainer.test_from_random('random'),
    'test_fid_distance_decrease_from_manifold':
        lambda: trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None),
    'test_paper_invert_section_images': lambda: trainer.paper_invert_section_images(),
    'test_paper_showing_diffusion_images_diff': lambda: trainer.paper_showing_diffusion_images(),
}

selected = dispatch.get(args.test_type)
if selected is not None:
    selected()
| for i in range(10): 33 | lable_root = root + str(i) + '/' 34 | create_folder(lable_root) 35 | 36 | for idx in range(len(trainset)): 37 | img, label = trainset[idx] 38 | print(idx) 39 | img.save(root + str(label) + '/' + str(idx) + '.png') 40 | 41 | 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--time_steps', default=50, type=int, 44 | help="This is the number of steps in which a clean image looses information.") 45 | parser.add_argument('--train_steps', default=700000, type=int, 46 | help='The number of iterations for training.') 47 | parser.add_argument('--blur_std', default=0.1, type=float, 48 | help='It sets the standard deviation for blur routines which have different meaning based on blur routine.') 49 | parser.add_argument('--blur_size', default=3, type=int, 50 | help='It sets the size of gaussian blur used in blur routines for each step t') 51 | parser.add_argument('--image_size', default=32, type=int) 52 | parser.add_argument('--batch_size', default=32, type=int) 53 | parser.add_argument('--save_folder', default='./results_cifar10', type=str) 54 | parser.add_argument('--data_path', default='./root_cifar10/', type=str) 55 | parser.add_argument('--load_path', default=None, type=str) 56 | parser.add_argument('--blur_routine', default='Incremental', type=str, 57 | help='This will set the type of blur routine one can use, check the code for what each one of them does in detail') 58 | parser.add_argument('--train_routine', default='Final', type=str) 59 | parser.add_argument('--sampling_routine', default='x0_step_down', type=str, 60 | help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 
2') 61 | parser.add_argument('--remove_time_embed', action="store_true") 62 | parser.add_argument('--residual', action="store_true") 63 | parser.add_argument('--discrete', action="store_true") 64 | parser.add_argument('--da', default='cifar10', type=str) 65 | 66 | 67 | args = parser.parse_args() 68 | print(args) 69 | 70 | 71 | model = Model(resolution=args.image_size, 72 | in_channels=3, 73 | out_ch=3, 74 | ch=128, 75 | ch_mult=(1,2,2,2), 76 | num_res_blocks=2, 77 | attn_resolutions=(16,), 78 | dropout=0.1).cuda() 79 | 80 | diffusion = GaussianDiffusion( 81 | model, 82 | image_size = args.image_size, 83 | device_of_kernel = 'cuda', 84 | channels = 3, 85 | timesteps = args.time_steps, # number of steps 86 | loss_type = 'l1', # L1 or L2 87 | kernel_std=args.blur_std, 88 | kernel_size=args.blur_size, 89 | blur_routine=args.blur_routine, 90 | train_routine = args.train_routine, 91 | sampling_routine = args.sampling_routine, 92 | discrete=args.discrete 93 | ).cuda() 94 | 95 | import torch 96 | diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count())) 97 | 98 | 99 | trainer = Trainer( 100 | diffusion, 101 | args.data_path, 102 | image_size = args.image_size, 103 | train_batch_size = args.batch_size, 104 | train_lr = 2e-5, 105 | train_num_steps = args.train_steps, # total training steps 106 | gradient_accumulate_every = 2, # gradient accumulation steps 107 | ema_decay = 0.995, # exponential moving average decay 108 | fp16 = False, # turn on mixed precision training with apex 109 | results_folder = args.save_folder, 110 | load_path = args.load_path, 111 | dataset = args.da 112 | ) 113 | 114 | trainer.train() -------------------------------------------------------------------------------- /deblurring-diffusion-pytorch/celebA_128.py: -------------------------------------------------------------------------------- 1 | #from comet_ml import Experiment 2 | from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer 3 | import 
torchvision 4 | import os 5 | import errno 6 | import shutil 7 | import argparse 8 | 9 | def create_folder(path): 10 | try: 11 | os.mkdir(path) 12 | except OSError as exc: 13 | if exc.errno != errno.EEXIST: 14 | raise 15 | pass 16 | 17 | def del_folder(path): 18 | try: 19 | shutil.rmtree(path) 20 | except OSError as exc: 21 | pass 22 | 23 | create = 0 24 | from pathlib import Path 25 | from PIL import Image 26 | 27 | if create: 28 | root_train = './root_celebA_128_train_new/' 29 | root_test = './root_celebA_128_test_new/' 30 | del_folder(root_train) 31 | create_folder(root_train) 32 | 33 | del_folder(root_test) 34 | create_folder(root_test) 35 | 36 | exts = ['jpg', 'jpeg', 'png'] 37 | folder = '/fs/cml-datasets/CelebA-HQ/images-128/' 38 | paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] 39 | 40 | for idx in range(len(paths)): 41 | img = Image.open(paths[idx]) 42 | print(idx) 43 | if idx < 0.9*len(paths): 44 | img.save(root_train + str(idx) + '.png') 45 | else: 46 | img.save(root_test + str(idx) + '.png') 47 | 48 | 49 | 50 | 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument('--time_steps', default=50, type=int, 53 | help="This is the number of steps in which a clean image looses information.") 54 | parser.add_argument('--train_steps', default=700000, type=int, 55 | help='The number of iterations for training.') 56 | parser.add_argument('--blur_std', default=0.1, type=float, 57 | help='It sets the standard deviation for blur routines which have different meaning based on blur routine.') 58 | parser.add_argument('--blur_size', default=3, type=int, 59 | help='It sets the size of gaussian blur used in blur routines for each step t') 60 | parser.add_argument('--save_folder', default='./results_cifar10', type=str) 61 | parser.add_argument('--data_path', default='./root_celebA_128_train_new/', type=str) 62 | parser.add_argument('--load_path', default=None, type=str) 63 | parser.add_argument('--blur_routine', default='Incremental', 
type=str, 64 | help='This will set the type of blur routine one can use, check the code for what each one of them does in detail') 65 | parser.add_argument('--train_routine', default='Final', type=str) 66 | parser.add_argument('--sampling_routine', default='default', type=str, 67 | help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 2') 68 | parser.add_argument('--remove_time_embed', action="store_true") 69 | parser.add_argument('--residual', action="store_true") 70 | parser.add_argument('--loss_type', default='l1', type=str) 71 | parser.add_argument('--discrete', action="store_true") 72 | 73 | 74 | args = parser.parse_args() 75 | print(args) 76 | 77 | 78 | model = Unet( 79 | dim = 64, 80 | dim_mults = (1, 2, 4, 8), 81 | channels=3, 82 | with_time_emb=not(args.remove_time_embed), 83 | residual=args.residual 84 | ).cuda() 85 | 86 | diffusion = GaussianDiffusion( 87 | model, 88 | image_size = 128, 89 | device_of_kernel = 'cuda', 90 | channels = 3, 91 | timesteps = args.time_steps, # number of steps 92 | loss_type = args.loss_type, # L1 or L2 93 | kernel_std=args.blur_std, 94 | kernel_size=args.blur_size, 95 | blur_routine=args.blur_routine, 96 | train_routine = args.train_routine, 97 | sampling_routine = args.sampling_routine, 98 | discrete=args.discrete 99 | ).cuda() 100 | 101 | import torch 102 | diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count())) 103 | 104 | trainer = Trainer( 105 | diffusion, 106 | args.data_path, 107 | image_size = 128, 108 | train_batch_size = 32, 109 | train_lr = 2e-5, 110 | train_num_steps = args.train_steps, # total training steps 111 | gradient_accumulate_every = 2, # gradient accumulation steps 112 | ema_decay = 0.995, # exponential moving average decay 113 | fp16 = False, # turn on mixed precision training with apex 114 | results_folder = args.save_folder, 115 | load_path = 
args.load_path, 116 | dataset = 'celebA' 117 | ) 118 | 119 | trainer.train() -------------------------------------------------------------------------------- /snowification/train.py: -------------------------------------------------------------------------------- 1 | from diffusion import GaussianDiffusion, Trainer, get_dataset 2 | import torchvision 3 | import os 4 | import errno 5 | import shutil 6 | import argparse 7 | from diffusion.model.get_model import get_model 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--time_steps', default=50, type=int) 11 | parser.add_argument('--train_steps', default=700000, type=int) 12 | parser.add_argument('--save_folder', default='/cmlscratch/hmchu/cold_diff/', type=str) 13 | parser.add_argument('--load_path', default=None, type=str) 14 | parser.add_argument('--train_routine', default='Final', type=str) 15 | parser.add_argument('--sampling_routine', default='x0_step_down', type=str) 16 | parser.add_argument('--remove_time_embed', action="store_true") 17 | parser.add_argument('--dataset_folder', default='./root_cifar10', type=str) 18 | parser.add_argument('--random_aug', action='store_true') 19 | parser.add_argument('--output_mean_scale', action='store_true') 20 | parser.add_argument('--exp_name', default='', type=str) 21 | parser.add_argument('--dataset', default='cifar10') 22 | parser.add_argument('--model', default='UnetConvNext', type=str) 23 | 24 | parser.add_argument('--forward_process_type', default='Snow') 25 | 26 | # Decolor args 27 | parser.add_argument('--decolor_routine', default='Constant') 28 | parser.add_argument('--decolor_ema_factor', default=0.9, type=float) 29 | parser.add_argument('--decolor_total_remove', action='store_true') 30 | parser.add_argument('--to_lab', action='store_true') 31 | 32 | parser.add_argument('--loss_type', type=str, default='l1') 33 | parser.add_argument('--resume_training', action='store_true') 34 | parser.add_argument('--load_model_steps', default=-1, type=int) 35 
| 36 | # Snow arg 37 | parser.add_argument('--snow_level', default=1, type=int) 38 | parser.add_argument('--random_snow', action='store_true') 39 | parser.add_argument('--single_snow', action='store_true') 40 | parser.add_argument('--fix_brightness', action='store_true') 41 | 42 | parser.add_argument('--resolution', default=-1, type=int) 43 | 44 | args = parser.parse_args() 45 | 46 | assert len(args.exp_name) > 0 47 | args.save_folder = os.path.join(args.save_folder, args.exp_name) 48 | print(args) 49 | 50 | if args.resume_training: 51 | if args.load_model_steps == -1: 52 | args.load_path = os.path.join(args.save_folder, 'model.pt') 53 | else: 54 | args.load_path = os.path.join(args.save_folder, f'model_{args.load_model_steps}.pt') 55 | print(f'resume from checkpoint stored at {args.load_path}') 56 | 57 | with_time_emb = not args.remove_time_embed 58 | 59 | model = get_model(args, with_time_emb=with_time_emb).cuda() 60 | model_one_shot = get_model(args, with_time_emb=False).cuda() 61 | 62 | 63 | 64 | image_size = get_dataset.get_image_size(args.dataset) 65 | if args.resolution != -1: 66 | image_size = (args.resolution, args.resolution) 67 | 68 | use_torchvison_dataset = False 69 | if 'cifar10' in args.dataset: 70 | use_torchvison_dataset = True 71 | args.dataset = 'cifar10_train' 72 | 73 | if image_size[0] <= 64: 74 | train_batch_size = 32 75 | elif image_size[0] > 64: 76 | train_batch_size = 16 77 | 78 | 79 | diffusion = GaussianDiffusion( 80 | model, 81 | image_size = image_size, 82 | device_of_kernel = 'cuda', 83 | channels = 3, 84 | one_shot_denoise_fn=model_one_shot, 85 | timesteps = args.time_steps, # number of steps 86 | loss_type = args.loss_type, # L1 or L2 87 | train_routine = args.train_routine, 88 | sampling_routine = args.sampling_routine, 89 | forward_process_type=args.forward_process_type, 90 | decolor_routine=args.decolor_routine, 91 | decolor_ema_factor=args.decolor_ema_factor, 92 | decolor_total_remove=args.decolor_total_remove, 93 | 
snow_level=args.snow_level, 94 | single_snow=args.single_snow, 95 | batch_size=train_batch_size, 96 | random_snow=args.random_snow, 97 | to_lab=args.to_lab, 98 | load_path=args.load_path, 99 | results_folder=args.save_folder, 100 | fix_brightness=args.fix_brightness, 101 | ).cuda() 102 | 103 | trainer = Trainer( 104 | diffusion, 105 | args.dataset_folder, 106 | image_size = image_size, 107 | train_batch_size = train_batch_size, 108 | train_lr = 2e-5, 109 | train_num_steps = args.train_steps, # total training steps 110 | gradient_accumulate_every = 2, # gradient accumulation steps 111 | ema_decay = 0.995, # exponential moving average decay 112 | fp16 = False, # turn on mixed precision training with apex 113 | results_folder = args.save_folder, 114 | load_path = args.load_path, 115 | random_aug=args.random_aug, 116 | torchvision_dataset=use_torchvison_dataset, 117 | dataset = f'{args.dataset}', 118 | to_lab=args.to_lab, 119 | ) 120 | 121 | trainer.train() 122 | -------------------------------------------------------------------------------- /decolor-diffusion/train.py: -------------------------------------------------------------------------------- 1 | from diffusion import GaussianDiffusion, Trainer, get_dataset 2 | import torchvision 3 | import os 4 | import errno 5 | import shutil 6 | import argparse 7 | from diffusion.model.get_model import get_model 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--time_steps', default=50, type=int) 11 | parser.add_argument('--train_steps', default=700000, type=int) 12 | parser.add_argument('--save_folder', default='/cmlscratch/hmchu/cold_diff/', type=str) 13 | parser.add_argument('--load_path', default=None, type=str) 14 | parser.add_argument('--train_routine', default='Final', type=str) 15 | parser.add_argument('--sampling_routine', default='x0_step_down', type=str) 16 | parser.add_argument('--remove_time_embed', action="store_true") 17 | parser.add_argument('--dataset_folder', default='./root_cifar10', 
type=str) 18 | parser.add_argument('--random_aug', action='store_true') 19 | parser.add_argument('--output_mean_scale', action='store_true') 20 | parser.add_argument('--exp_name', default='', type=str) 21 | parser.add_argument('--dataset', default='cifar10') 22 | parser.add_argument('--model', default='UnetConvNext', type=str) 23 | 24 | parser.add_argument('--forward_process_type', default='Snow') 25 | 26 | # Decolor args 27 | parser.add_argument('--decolor_routine', default='Constant') 28 | parser.add_argument('--decolor_ema_factor', default=0.9, type=float) 29 | parser.add_argument('--decolor_total_remove', action='store_true') 30 | parser.add_argument('--to_lab', action='store_true') 31 | 32 | parser.add_argument('--loss_type', type=str, default='l1') 33 | parser.add_argument('--resume_training', action='store_true') 34 | parser.add_argument('--load_model_steps', default=-1, type=int) 35 | 36 | # Snow arg 37 | parser.add_argument('--snow_level', default=1, type=int) 38 | parser.add_argument('--random_snow', action='store_true') 39 | parser.add_argument('--single_snow', action='store_true') 40 | parser.add_argument('--fix_brightness', action='store_true') 41 | 42 | parser.add_argument('--resolution', default=-1, type=int) 43 | 44 | args = parser.parse_args() 45 | 46 | assert len(args.exp_name) > 0 47 | args.save_folder = os.path.join(args.save_folder, args.exp_name) 48 | print(args) 49 | 50 | if args.resume_training: 51 | if args.load_model_steps == -1: 52 | args.load_path = os.path.join(args.save_folder, 'model.pt') 53 | else: 54 | args.load_path = os.path.join(args.save_folder, f'model_{args.load_model_steps}.pt') 55 | print(f'resume from checkpoint stored at {args.load_path}') 56 | 57 | with_time_emb = not args.remove_time_embed 58 | 59 | model = get_model(args, with_time_emb=with_time_emb).cuda() 60 | model_one_shot = get_model(args, with_time_emb=False).cuda() 61 | 62 | 63 | 64 | image_size = get_dataset.get_image_size(args.dataset) 65 | if args.resolution 
!= -1: 66 | image_size = (args.resolution, args.resolution) 67 | 68 | use_torchvison_dataset = False 69 | if 'cifar10' in args.dataset: 70 | use_torchvison_dataset = True 71 | args.dataset = 'cifar10_train' 72 | 73 | if image_size[0] <= 64: 74 | train_batch_size = 32 75 | elif image_size[0] > 64: 76 | train_batch_size = 16 77 | 78 | 79 | diffusion = GaussianDiffusion( 80 | model, 81 | image_size = image_size, 82 | device_of_kernel = 'cuda', 83 | channels = 3, 84 | one_shot_denoise_fn=model_one_shot, 85 | timesteps = args.time_steps, # number of steps 86 | loss_type = args.loss_type, # L1 or L2 87 | train_routine = args.train_routine, 88 | sampling_routine = args.sampling_routine, 89 | forward_process_type=args.forward_process_type, 90 | decolor_routine=args.decolor_routine, 91 | decolor_ema_factor=args.decolor_ema_factor, 92 | decolor_total_remove=args.decolor_total_remove, 93 | snow_level=args.snow_level, 94 | single_snow=args.single_snow, 95 | batch_size=train_batch_size, 96 | random_snow=args.random_snow, 97 | to_lab=args.to_lab, 98 | load_path=args.load_path, 99 | results_folder=args.save_folder, 100 | fix_brightness=args.fix_brightness, 101 | ).cuda() 102 | 103 | trainer = Trainer( 104 | diffusion, 105 | args.dataset_folder, 106 | image_size = image_size, 107 | train_batch_size = train_batch_size, 108 | train_lr = 2e-5, 109 | train_num_steps = args.train_steps, # total training steps 110 | gradient_accumulate_every = 2, # gradient accumulation steps 111 | ema_decay = 0.995, # exponential moving average decay 112 | fp16 = False, # turn on mixed precision training with apex 113 | results_folder = args.save_folder, 114 | load_path = args.load_path, 115 | random_aug=args.random_aug, 116 | torchvision_dataset=use_torchvison_dataset, 117 | dataset = f'{args.dataset}', 118 | to_lab=args.to_lab, 119 | ) 120 | 121 | trainer.train() 122 | -------------------------------------------------------------------------------- /deblurring-diffusion-pytorch/mnist_test.py: 
-------------------------------------------------------------------------------- 1 | from pycave.bayes import GaussianMixture 2 | from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer 3 | import torchvision 4 | import os 5 | import errno 6 | import shutil 7 | import argparse 8 | from Fid import calculate_fid_given_samples 9 | 10 | def create_folder(path): 11 | try: 12 | os.mkdir(path) 13 | except OSError as exc: 14 | if exc.errno != errno.EEXIST: 15 | raise 16 | pass 17 | 18 | def del_folder(path): 19 | try: 20 | shutil.rmtree(path) 21 | except OSError as exc: 22 | pass 23 | 24 | create = 0 25 | 26 | if create: 27 | trainset = torchvision.datasets.MNIST( 28 | root='./data', train=False, download=True) 29 | root = './root_mnist_test/' 30 | del_folder(root) 31 | create_folder(root) 32 | 33 | for i in range(10): 34 | lable_root = root + str(i) + '/' 35 | create_folder(lable_root) 36 | 37 | for idx in range(len(trainset)): 38 | img, label = trainset[idx] 39 | print(idx) 40 | img.save(root + str(label) + '/' + str(idx) + '.png') 41 | 42 | 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('--time_steps', default=50, type=int, 45 | help="This is the number of steps in which a clean image looses information.") 46 | parser.add_argument('--sample_steps', default=None, type=int) 47 | parser.add_argument('--blur_std', default=0.1, type=float, 48 | help='It sets the standard deviation for blur routines which have different meaning based on blur routine.') 49 | parser.add_argument('--blur_size', default=3, type=int, 50 | help='It sets the size of gaussian blur used in blur routines for each step t') 51 | parser.add_argument('--save_folder', default='./results_mnist', type=str) 52 | parser.add_argument('--load_path', default=None, type=str) 53 | parser.add_argument('--data_path', default='./root_mnist/', type=str) 54 | parser.add_argument('--test_type', default='train_data', type=str) 55 | parser.add_argument('--blur_routine', 
default='Incremental', type=str, 56 | help='This will set the type of blur routine one can use, check the code for what each one of them does in detail') 57 | parser.add_argument('--train_routine', default='Final', type=str) 58 | parser.add_argument('--sampling_routine', default='x0_step_down', type=str, 59 | help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 2') 60 | parser.add_argument('--gmm_size', default=8, type=int) 61 | parser.add_argument('--gmm_cluster', default=10, type=int) 62 | 63 | args = parser.parse_args() 64 | print(args) 65 | 66 | img_path=None 67 | if 'train' in args.test_type: 68 | img_path = args.data_path 69 | elif 'test' in args.test_type: 70 | img_path = args.data_path 71 | 72 | print("Img Path is ", img_path) 73 | 74 | 75 | model = Unet( 76 | dim = 64, 77 | dim_mults = (1, 2, 4, 8), 78 | channels=1 79 | ).cuda() 80 | 81 | 82 | diffusion = GaussianDiffusion( 83 | model, 84 | image_size = 32, 85 | device_of_kernel = 'cuda', 86 | channels = 1, 87 | timesteps = args.time_steps, # number of steps 88 | loss_type = 'l1', # L1 or L2 89 | kernel_std=args.blur_std, 90 | kernel_size=args.blur_size, 91 | blur_routine=args.blur_routine, 92 | train_routine=args.train_routine, 93 | sampling_routine = args.sampling_routine 94 | ).cuda() 95 | 96 | import torch 97 | diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count())) 98 | 99 | trainer = Trainer( 100 | diffusion, 101 | img_path, 102 | image_size = 32, 103 | train_batch_size = 32, 104 | train_lr = 2e-5, 105 | train_num_steps = 700000, # total training steps 106 | gradient_accumulate_every = 2, # gradient accumulation steps 107 | ema_decay = 0.995, # exponential moving average decay 108 | fp16 = False, # turn on mixed precision training with apex 109 | results_folder = args.save_folder, 110 | load_path = args.load_path 111 | ) 112 | 113 | if args.test_type == 
'train_data': 114 | trainer.test_from_data('train', s_times=args.sample_steps) 115 | 116 | elif args.test_type == 'test_data': 117 | trainer.test_from_data('test', s_times=args.sample_steps) 118 | 119 | elif args.test_type == 'train_save_orig_data_same_as_trained': 120 | trainer.save_training_data() 121 | 122 | elif args.test_type == 'test_save_orig_data_same_as_tested': 123 | trainer.save_training_data() 124 | 125 | elif args.test_type == 'test_fid_distance_decrease_from_manifold': 126 | trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None) 127 | 128 | elif args.test_type == 'train_fid_distance_decrease_from_manifold': 129 | trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None) 130 | -------------------------------------------------------------------------------- /deblurring-diffusion-pytorch/cifar10_test.py: -------------------------------------------------------------------------------- 1 | from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer, Model 2 | import torchvision 3 | import os 4 | import errno 5 | import shutil 6 | import argparse 7 | import torch 8 | import random 9 | from Fid import calculate_fid_given_samples 10 | 11 | seed_value=123457 12 | torch.manual_seed(seed_value) # cpu vars 13 | random.seed(seed_value) # Python 14 | torch.cuda.manual_seed(seed_value) 15 | torch.cuda.manual_seed_all(seed_value) 16 | torch.backends.cudnn.deterministic = True 17 | torch.backends.cudnn.benchmark = False 18 | 19 | def create_folder(path): 20 | try: 21 | os.mkdir(path) 22 | except OSError as exc: 23 | if exc.errno != errno.EEXIST: 24 | raise 25 | pass 26 | 27 | def del_folder(path): 28 | try: 29 | shutil.rmtree(path) 30 | except OSError as exc: 31 | pass 32 | 33 | create = 0 34 | 35 | if create: 36 | trainset = torchvision.datasets.CIFAR10( 37 | root='./data', train=False, download=True) 38 | root = './root_cifar10_test/' 39 | del_folder(root) 40 | create_folder(root) 
41 | 42 | for i in range(10): 43 | lable_root = root + str(i) + '/' 44 | create_folder(lable_root) 45 | 46 | for idx in range(len(trainset)): 47 | img, label = trainset[idx] 48 | print(idx) 49 | img.save(root + str(label) + '/' + str(idx) + '.png') 50 | 51 | 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument('--time_steps', default=50, type=int, 54 | help="This is the number of steps in which a clean image looses information.") 55 | parser.add_argument('--sample_steps', default=None, type=int) 56 | parser.add_argument('--blur_std', default=0.1, type=float, 57 | help='It sets the standard deviation for blur routines which have different meaning based on blur routine.') 58 | parser.add_argument('--blur_size', default=3, type=int, 59 | help='It sets the size of gaussian blur used in blur routines for each step t') 60 | parser.add_argument('--save_folder', default='./results_cifar10', type=str) 61 | parser.add_argument('--data_path', default='./root_cifar10/', type=str) 62 | parser.add_argument('--load_path', default=None, type=str) 63 | parser.add_argument('--test_type', default='train_data', type=str) 64 | parser.add_argument('--blur_routine', default='Incremental', type=str, 65 | help='This will set the type of blur routine one can use, check the code for what each one of them does in detail') 66 | parser.add_argument('--train_routine', default='Final', type=str) 67 | parser.add_argument('--sampling_routine', default='x0_step_down', type=str, 68 | help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 
2') 69 | parser.add_argument('--remove_time_embed', action="store_true") 70 | parser.add_argument('--residual', action="store_true") 71 | 72 | 73 | args = parser.parse_args() 74 | print(args) 75 | 76 | img_path=None 77 | if args.test_type == 'train_distribution_cov_vector': 78 | img_path = args.data_path 79 | elif 'train' in args.test_type: 80 | img_path = args.data_path 81 | elif 'test' in args.test_type: 82 | img_path = args.data_path 83 | 84 | print("Img Path is ", img_path) 85 | 86 | 87 | model = Model(resolution=32, 88 | in_channels=3, 89 | out_ch=3, 90 | ch=128, 91 | ch_mult=(1,2,2,2), 92 | num_res_blocks=2, 93 | attn_resolutions=(16,), 94 | dropout=0.01).cuda() 95 | 96 | 97 | diffusion = GaussianDiffusion( 98 | model, 99 | image_size = 32, 100 | device_of_kernel = 'cuda', 101 | channels = 3, 102 | timesteps = args.time_steps, # number of steps 103 | loss_type = 'l1', # L1 or L2 104 | kernel_std=args.blur_std, 105 | kernel_size=args.blur_size, 106 | blur_routine=args.blur_routine, 107 | train_routine=args.train_routine, 108 | sampling_routine = args.sampling_routine 109 | ).cuda() 110 | 111 | trainer = Trainer( 112 | diffusion, 113 | img_path, 114 | image_size = 32, 115 | train_batch_size = 32, 116 | train_lr = 2e-5, 117 | train_num_steps = 700000, # total training steps 118 | gradient_accumulate_every = 2, # gradient accumulation steps 119 | ema_decay = 0.995, # exponential moving average decay 120 | fp16 = False, # turn on mixed precision training with apex 121 | results_folder = args.save_folder, 122 | load_path = args.load_path, 123 | shuffle=True 124 | ) 125 | 126 | import torch 127 | diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count())) 128 | 129 | 130 | if args.test_type == 'train_data': 131 | trainer.test_from_data('train', s_times=args.sample_steps) 132 | 133 | elif args.test_type == 'test_data': 134 | trainer.test_from_data('test', s_times=args.sample_steps) 135 | 136 | elif args.test_type == 
'test_data_save_results': 137 | trainer.test_from_data_save_results('test') 138 | 139 | elif args.test_type == 'test_fid_distance_decrease_from_manifold': 140 | trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None) 141 | 142 | elif args.test_type == 'train_fid_distance_decrease_from_manifold': 143 | trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None) 144 | -------------------------------------------------------------------------------- /resolution-diffusion-pytorch/cifar10_test.py: -------------------------------------------------------------------------------- 1 | from resolution_diffusion_pytorch import Unet, GaussianDiffusion, Trainer, Model 2 | import torchvision 3 | import os 4 | import errno 5 | import shutil 6 | import argparse 7 | import torch 8 | import random 9 | from Fid import calculate_fid_given_samples 10 | 11 | seed_value=123457 12 | torch.manual_seed(seed_value) # cpu vars 13 | random.seed(seed_value) # Python 14 | torch.cuda.manual_seed(seed_value) 15 | torch.cuda.manual_seed_all(seed_value) 16 | torch.backends.cudnn.deterministic = True 17 | torch.backends.cudnn.benchmark = False 18 | 19 | def create_folder(path): 20 | try: 21 | os.mkdir(path) 22 | except OSError as exc: 23 | if exc.errno != errno.EEXIST: 24 | raise 25 | pass 26 | 27 | def del_folder(path): 28 | try: 29 | shutil.rmtree(path) 30 | except OSError as exc: 31 | pass 32 | 33 | create = 0 34 | 35 | if create: 36 | trainset = torchvision.datasets.CIFAR10( 37 | root='./data', train=False, download=True) 38 | root = './root_cifar10_test/' 39 | del_folder(root) 40 | create_folder(root) 41 | 42 | for i in range(10): 43 | lable_root = root + str(i) + '/' 44 | create_folder(lable_root) 45 | 46 | for idx in range(len(trainset)): 47 | img, label = trainset[idx] 48 | print(idx) 49 | img.save(root + str(label) + '/' + str(idx) + '.png') 50 | 51 | 52 | parser = argparse.ArgumentParser() 53 | 
parser.add_argument('--time_steps', default=50, type=int) 54 | parser.add_argument('--sample_steps', default=None, type=int) 55 | parser.add_argument('--save_folder', default='./results_cifar10', type=str) 56 | parser.add_argument('--load_path', default=None, type=str) 57 | parser.add_argument('--data_path', default='./root_cifar10_test/', type=str) 58 | parser.add_argument('--test_type', default='train_data', type=str) 59 | parser.add_argument('--resolution_routine', default='Incremental', type=str) 60 | parser.add_argument('--train_routine', default='Final', type=str) 61 | parser.add_argument('--sampling_routine', default='x0_step_down', type=str) 62 | parser.add_argument('--remove_time_embed', action="store_true") 63 | parser.add_argument('--residual', action="store_true") 64 | 65 | 66 | args = parser.parse_args() 67 | print(args) 68 | 69 | img_path=None 70 | if args.test_type == 'train_distribution_cov_vector': 71 | img_path = args.data_path 72 | elif 'train' in args.test_type: 73 | img_path = args.data_path 74 | elif 'test' in args.test_type: 75 | img_path = args.data_path 76 | 77 | print("Img Path is ", img_path) 78 | 79 | 80 | model = Model(resolution=32, 81 | in_channels=3, 82 | out_ch=3, 83 | ch=128, 84 | ch_mult=(1,2,2,2), 85 | num_res_blocks=2, 86 | attn_resolutions=(16,), 87 | dropout=0.01).cuda() 88 | 89 | 90 | diffusion = GaussianDiffusion( 91 | model, 92 | image_size = 32, 93 | device_of_kernel = 'cuda', 94 | channels = 3, 95 | timesteps = args.time_steps, # number of steps 96 | loss_type = 'l1', # L1 or L2 97 | resolution_routine=args.resolution_routine, 98 | train_routine=args.train_routine, 99 | sampling_routine = args.sampling_routine 100 | ).cuda() 101 | 102 | trainer = Trainer( 103 | diffusion, 104 | img_path, 105 | image_size = 32, 106 | train_batch_size = 32, 107 | train_lr = 2e-5, 108 | train_num_steps = 700000, # total training steps 109 | gradient_accumulate_every = 2, # gradient accumulation steps 110 | ema_decay = 0.995, # exponential 
moving average decay 111 | fp16 = False, # turn on mixed precision training with apex 112 | results_folder = args.save_folder, 113 | load_path = args.load_path, 114 | shuffle=True 115 | ) 116 | 117 | if args.test_type == 'train_data': 118 | trainer.test_from_data('train', s_times=args.sample_steps) 119 | 120 | elif args.test_type == 'train_data_dropout': 121 | trainer.test_from_data_dropout('train_drop', s_times=args.sample_steps) 122 | 123 | elif args.test_type == 'test_data_dropout': 124 | trainer.test_from_data_dropout('test_drop', s_times=args.sample_steps) 125 | 126 | elif args.test_type == 'test_data': 127 | trainer.test_from_data('test', s_times=args.sample_steps) 128 | 129 | elif args.test_type == 'mixup_train_data': 130 | trainer.test_with_mixup('train') 131 | 132 | elif args.test_type == 'mixup_test_data': 133 | trainer.test_with_mixup('test') 134 | 135 | elif args.test_type == 'test_random': 136 | trainer.test_from_random('test_random') 137 | 138 | elif args.test_type == 'train_random': 139 | trainer.test_from_random('train_random') 140 | 141 | elif args.test_type == 'train_distribution_cov_vector': 142 | trainer.sample_as_a_vector_cov(start=0, end=None, siz=4, ch=3) 143 | 144 | elif args.test_type == 'test_fid_distance_decrease_from_manifold': 145 | trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None) 146 | 147 | elif args.test_type == 'train_fid_distance_decrease_from_manifold': 148 | trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None) 149 | 150 | elif args.test_type == 'test_paper_invert_section_images': 151 | trainer.paper_invert_section_images() 152 | 153 | elif args.test_type == 'test_paper_showing_diffusion_images': 154 | trainer.paper_showing_diffusion_images() 155 | 156 | elif args.test_type == 'test_paper_showing_diffusion_imgs_og': 157 | trainer.paper_showing_diffusion_imgs_og() 158 | -------------------------------------------------------------------------------- 
# /deblurring-diffusion-pytorch/AFHQ_128_test.py
# Evaluation driver for the deblurring cold-diffusion model on AFHQ at 128x128.
# Loads a trained checkpoint and dispatches to exactly one Trainer evaluation
# routine selected by --test_type (reconstruction, GMM-based generation, or
# FID measurement).

# NOTE(review): this import is immediately shadowed by the gmm_pycave import
# below; kept only so any pycave import-time side effects are preserved —
# confirm which GaussianMixture implementation is intended.
from pycave.bayes import GaussianMixture
import torchvision
import os
import errno
import shutil
import argparse
from Fid import calculate_fid_given_samples
import torch
import random

from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
from gmm_pycave import GaussianMixture as GaussianMixture

# Seed every RNG source up front so evaluation runs are reproducible.
seed_value = 123457
torch.manual_seed(seed_value)   # cpu vars
random.seed(seed_value)         # Python
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort recursive removal of *path*; all OS errors are ignored."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="This is the number of steps in which a clean image loses information.")
parser.add_argument('--sample_steps', default=None, type=int)
parser.add_argument('--blur_std', default=0.1, type=float,
                    help='It sets the standard deviation for blur routines which have different meaning based on blur routine.')
parser.add_argument('--blur_size', default=3, type=int,
                    help='It sets the size of gaussian blur used in blur routines for each step t')
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path', default='./AFHQ/afhq/train/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--test_type', default='train_data', type=str)
parser.add_argument('--blur_routine', default='Incremental', type=str,
                    help='This will set the type of blur routine one can use, check the code for what each one of them does in detail')
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str,
                    help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 2')
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--gmm_size', default=8, type=int)
parser.add_argument('--gmm_cluster', default=10, type=int)
parser.add_argument('--gmm_sample_at', default=1, type=int)
parser.add_argument('--bs', default=32, type=int)
parser.add_argument('--discrete', action="store_true")
parser.add_argument('--noise', default=0, type=float)

args = parser.parse_args()
print(args)

# The original if/elif assigned the same path in both branches; collapsed to a
# single test. img_path stays None for any other test_type.
img_path = None
if 'train' in args.test_type or 'test' in args.test_type:
    img_path = args.data_path

print("Img Path is ", img_path)


model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()


diffusion = GaussianDiffusion(
    model,
    image_size=128,
    device_of_kernel='cuda',
    channels=3,
    timesteps=args.time_steps,        # number of steps
    loss_type='l1',                   # L1 or L2
    kernel_std=args.blur_std,
    kernel_size=args.blur_size,
    blur_routine=args.blur_routine,
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine,
    discrete=args.discrete
).cuda()

# Spread sampling over every visible GPU.  (The redundant mid-file
# `import torch` from the original was dropped; torch is imported above.)
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))


trainer = Trainer(
    diffusion,
    img_path,
    image_size=128,
    train_batch_size=args.bs,
    train_lr=2e-5,
    train_num_steps=700000,           # total training steps
    gradient_accumulate_every=2,      # gradient accumulation steps
    ema_decay=0.995,                  # exponential moving average decay
    fp16=False,                       # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset=None,
    shuffle=False
)

if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=args.sample_steps)

# this for extreme case and we perform gmm at the meaned vectors and test with what happens with noise
elif args.test_type == 'train_distribution_mean_blur_torch_gmm':
    trainer.sample_as_a_mean_blur_torch_gmm(GaussianMixture, start=0, end=26900, ch=3, clusters=args.gmm_cluster)

# to save
elif args.test_type == 'train_distribution_mean_blur_torch_gmm_ablation':
    trainer.sample_as_a_mean_blur_torch_gmm_ablation(GaussianMixture, ch=3, clusters=args.gmm_cluster, noise=args.noise)

# this for the non extreme case and you can create gmm at any point in the blur sequence and then blur from there
elif args.test_type == 'train_distribution_blur_torch_gmm':
    trainer.sample_as_a_blur_torch_gmm(GaussianMixture, start=0, end=26900, ch=3, clusters=args.gmm_cluster)

######### Quant ##############
elif args.test_type == 'test_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

elif args.test_type == 'train_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

###############################
#!/usr/bin/env bash
# /defading-diffusion-pytorch/inpainting_paper_train.sh
# Launch script for the defading/inpainting paper experiments.
# Earlier runs are kept (commented out) as a record of the sweep over
# time_steps / kernel_std / fade_routine; only the last three commands are live.

# --- initial 50-step runs on cifar10 / celebA / mnist (Random_Incremental fading) ---
#python defading-diffusion-pytorch/cifar10_train.py --time_steps 50 --save_folder ./cifar_inpainting --discrete --sampling_routine x0_step_down --train_steps 700000 --blur_std 0.1 --fade_routine Random_Incremental

#python defading-diffusion-pytorch/celebA_train.py --time_steps 50 --save_folder ./celebA_inpainting --discrete --sampling_routine x0_step_down --train_steps 700000 --blur_std 0.1 --fade_routine Random_Incremental

#python defading-diffusion-pytorch/mnist_train.py --time_steps 50 --save_folder ./mnist_inpainting --discrete --sampling_routine x0_step_down --train_steps 700000 --blur_std 0.1 --fade_routine Random_Incremental

# --- celebA 128px sweeps over time_steps with discrete fading ---
#python defading-diffusion-pytorch/celebA_train.py --time_steps 200 --save_folder ./celebA_128_200_step --discrete --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.1 --fade_routine Random_Incremental --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_200_step/model.pt
#python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --save_folder ./celebA_128_100_step --discrete --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --fade_routine Random_Incremental --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_100_step/model.pt
#python defading-diffusion-pytorch/celebA_train.py --time_steps 50 --save_folder ./celebA_128_50_step --discrete --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.4 --fade_routine Random_Incremental --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_50_step/model.pt
#python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --save_folder ./celebA_128_100_step_data_augmentation --discrete --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --fade_routine Random_Incremental --initial_mask 11 --image_size 128 --dataset celebA --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_100_step_data_augmentation/model.pt

# --- same sweeps without --discrete (float fading) ---
#python defading-diffusion-pytorch/celebA_train.py --time_steps 200 --save_folder ./celebA_128_200_step_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.1 --fade_routine Random_Incremental --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_200_step_float/model.pt
#python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --save_folder ./celebA_128_100_step_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --fade_routine Random_Incremental --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_100_step_float/model.pt
#python defading-diffusion-pytorch/celebA_train.py --time_steps 50 --save_folder ./celebA_128_50_step_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.4 --fade_routine Random_Incremental --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_50_step_float/model.pt
#python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --save_folder ./celebA_128_100_step_data_augmentation_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --fade_routine Random_Incremental --initial_mask 11 --image_size 128 --dataset celebA --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_100_step_data_augmentation_float/model.pt

# --- centered-mask (Incremental) variants ---
#python defading-diffusion-pytorch/celebA_train.py --time_steps 200 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_200_step_ker_0.2_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128
#python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_100_step_ker_0.2_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128
#python defading-diffusion-pytorch/celebA_train.py --time_steps 50 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_50_step_ker_0.2_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128
#python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_100_step_ker_0.2_data_augmentation_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128 --dataset celebA
#python defading-diffusion-pytorch/celebA_train.py --time_steps 125 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_125_step_ker_0.2_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128
#python defading-diffusion-pytorch/celebA_train.py --time_steps 150 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_150_step_ker_0.2_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128
#python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_100_step_ker_0.2_data_augmentation --discrete --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128 --dataset celebA

# --- active runs: resume the three centered-mask float trainings ---
python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_100_step_ker_0.2_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/centered_mask_celebA_128_100_step_ker_0.2_float/model.pt
python defading-diffusion-pytorch/celebA_train.py --time_steps 50 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_50_step_ker_0.2_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/centered_mask_celebA_128_50_step_ker_0.2_float/model.pt
python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_100_step_ker_0.2_data_augmentation_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128 --dataset celebA --load_path /cmlscratch/eborgnia/cold_diffusion/centered_mask_celebA_128_100_step_ker_0.2_data_augmentation_float/model.pt
/cmlscratch/eborgnia/cold_diffusion/centered_mask_celebA_128_100_step_ker_0.2_float/model.pt 26 | python defading-diffusion-pytorch/celebA_train.py --time_steps 50 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_50_step_ker_0.2_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/centered_mask_celebA_128_50_step_ker_0.2_float/model.pt 27 | python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_100_step_ker_0.2_data_augmentation_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128 --dataset celebA --load_path /cmlscratch/eborgnia/cold_diffusion/centered_mask_celebA_128_100_step_ker_0.2_data_augmentation_float/model.pt -------------------------------------------------------------------------------- /deblurring-diffusion-pytorch/celebA_128_test.py: -------------------------------------------------------------------------------- 1 | from pycave.bayes import GaussianMixture 2 | import torchvision 3 | import os 4 | import errno 5 | import shutil 6 | import argparse 7 | from Fid import calculate_fid_given_samples 8 | import torch 9 | import random 10 | 11 | from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer 12 | from gmm_pycave import GaussianMixture as GaussianMixture 13 | 14 | seed_value=123457 15 | torch.manual_seed(seed_value) # cpu vars 16 | random.seed(seed_value) # Python 17 | torch.cuda.manual_seed(seed_value) 18 | torch.cuda.manual_seed_all(seed_value) 19 | torch.backends.cudnn.deterministic = True 20 | torch.backends.cudnn.benchmark = False 21 | 22 | def create_folder(path): 23 | try: 24 | os.mkdir(path) 25 | except OSError as exc: 26 | if exc.errno != errno.EEXIST: 27 | raise 28 | pass 29 | 30 | def del_folder(path): 31 | try: 32 | shutil.rmtree(path) 33 | except OSError as exc: 34 
# /deblurring-diffusion-pytorch/celebA_128_test.py
# Evaluation driver for the deblurring cold-diffusion model on celebA at
# 128x128.  Loads a trained checkpoint and dispatches to exactly one Trainer
# evaluation routine selected by --test_type (reconstruction, GMM-based
# generation, FID measurement, or paper-figure generation).

# NOTE(review): this import is immediately shadowed by the gmm_pycave import
# below; kept only so any pycave import-time side effects are preserved —
# confirm which GaussianMixture implementation is intended.
from pycave.bayes import GaussianMixture
import torchvision
import os
import errno
import shutil
import argparse
from Fid import calculate_fid_given_samples
import torch
import random

from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
from gmm_pycave import GaussianMixture as GaussianMixture

# Seed every RNG source up front so evaluation runs are reproducible.
seed_value = 123457
torch.manual_seed(seed_value)   # cpu vars
random.seed(seed_value)         # Python
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort recursive removal of *path*; all OS errors are ignored."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="This is the number of steps in which a clean image loses information.")
parser.add_argument('--sample_steps', default=None, type=int)
parser.add_argument('--blur_std', default=0.1, type=float,
                    help='It sets the standard deviation for blur routines which have different meaning based on blur routine.')
parser.add_argument('--blur_size', default=3, type=int,
                    help='It sets the size of gaussian blur used in blur routines for each step t')
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path', default='./root_celebA_128_train_new/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--test_type', default='train_data', type=str)
parser.add_argument('--blur_routine', default='Incremental', type=str,
                    help='This will set the type of blur routine one can use, check the code for what each one of them does in detail')
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str,
                    help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 2')
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--gmm_size', default=8, type=int)
parser.add_argument('--gmm_cluster', default=10, type=int)
parser.add_argument('--gmm_sample_at', default=1, type=int)
parser.add_argument('--bs', default=32, type=int)
parser.add_argument('--discrete', action="store_true")
parser.add_argument('--noise', default=0, type=float)

args = parser.parse_args()
print(args)

# The original if/elif assigned the same path in both branches; collapsed to a
# single test. img_path stays None for any other test_type.
img_path = None
if 'train' in args.test_type or 'test' in args.test_type:
    img_path = args.data_path

print("Img Path is ", img_path)


model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()


diffusion = GaussianDiffusion(
    model,
    image_size=128,
    device_of_kernel='cuda',
    channels=3,
    timesteps=args.time_steps,        # number of steps
    loss_type='l1',                   # L1 or L2
    kernel_std=args.blur_std,
    kernel_size=args.blur_size,
    blur_routine=args.blur_routine,
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine,
    discrete=args.discrete
).cuda()

# Spread sampling over every visible GPU.  (The redundant mid-file
# `import torch` from the original was dropped; torch is imported above.)
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))


trainer = Trainer(
    diffusion,
    img_path,
    image_size=128,
    train_batch_size=args.bs,
    train_lr=2e-5,
    train_num_steps=700000,           # total training steps
    gradient_accumulate_every=2,      # gradient accumulation steps
    ema_decay=0.995,                  # exponential moving average decay
    fp16=False,                       # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset=None,
    shuffle=False
)

if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=args.sample_steps)

# this for extreme case and we perform gmm at the meaned vectors and test with what happens with noise
elif args.test_type == 'train_distribution_mean_blur_torch_gmm':
    trainer.sample_as_a_mean_blur_torch_gmm(GaussianMixture, start=0, end=26900, ch=3, clusters=args.gmm_cluster)

# to save
elif args.test_type == 'train_distribution_mean_blur_torch_gmm_ablation':
    trainer.sample_as_a_mean_blur_torch_gmm_ablation(GaussianMixture, ch=3, clusters=args.gmm_cluster, noise=args.noise)

# this for the non extreme case and you can create gmm at any point in the blur sequence and then blur from there
elif args.test_type == 'train_distribution_blur_torch_gmm':
    trainer.sample_as_a_blur_torch_gmm(GaussianMixture, siz=args.gmm_size, ch=3, clusters=args.gmm_cluster, sample_at=args.gmm_sample_at)

######### Quant ##############
elif args.test_type == 'test_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

elif args.test_type == 'train_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

########## for paper ##########
elif args.test_type == 'train_paper_showing_diffusion_images_cover_page':
    trainer.paper_showing_diffusion_images_cover_page()

elif args.test_type == 'train_paper_showing_diffusion_images_cover_page_both_sampling':
    trainer.paper_showing_diffusion_images_cover_page_both_sampling()
# /snowification/test.py
# Evaluation driver for the snowification cold-diffusion model.  Resolves the
# checkpoint path from --save_folder_train/--exp_name, builds the diffusion
# model and Trainer, and dispatches to one evaluation routine selected by
# --test_type, optionally followed by a FID measurement (--test_fid).

from diffusion import GaussianDiffusion, Trainer, get_dataset
import torchvision
import os
import errno
import shutil
import argparse
from diffusion.model.get_model import get_model
from Fid.fid_score import calculate_fid_given_samples
import torch
import random


def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort recursive removal of *path*; all OS errors are ignored."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--sample_steps', default=None, type=int)
parser.add_argument('--save_folder_train', default='/cmlscratch/hmchu/cold_diff/', type=str)
parser.add_argument('--save_folder_test', default='/cmlscratch/hmchu/cold_diff_paper_test/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--test_type', default='train_data', type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")

parser.add_argument('--exp_name', default='', type=str)
parser.add_argument('--test_postfix', default='', type=str)
parser.add_argument('--dataset_folder', default='./root_cifar10', type=str)
parser.add_argument('--output_mean_scale', action='store_true')
parser.add_argument('--random_aug', action='store_true')
parser.add_argument('--model', default='UnetConvNext', type=str)
parser.add_argument('--dataset', default='cifar10')

parser.add_argument('--forward_process_type', default='GaussianBlur')
# GaussianBlur args

# Decolor args
parser.add_argument('--decolor_routine', default='Constant')
parser.add_argument('--decolor_ema_factor', default=0.9, type=float)
parser.add_argument('--decolor_total_remove', action='store_true')

parser.add_argument('--load_model_steps', default=-1, type=int)
parser.add_argument('--resume_training', action='store_true')

parser.add_argument('--order_seed', default=-1.0, type=float)
parser.add_argument('--resolution', default=-1, type=int)

parser.add_argument('--to_lab', action='store_true')
parser.add_argument('--snow_level', default=1, type=int)
parser.add_argument('--random_snow', action='store_true')
parser.add_argument('--single_snow', action='store_true')
parser.add_argument('--fix_brightness', action='store_true')

parser.add_argument('--test_fid', action='store_true')


args = parser.parse_args()
# The checkpoint and output folders are keyed by the experiment name, so it is
# required.  (Was a bare `assert`, which is stripped under `python -O`.)
if not args.exp_name:
    parser.error('--exp_name must be a non-empty string')

#if args.test_type == 'test_paper':
#    args.save_folder_test = '/cmlscratch/hmchu/cold_diff_paper_test/'

# Resolve the checkpoint to load: a specific training step, else the latest.
if args.load_model_steps != -1:
    args.load_path = os.path.join(args.save_folder_train, args.exp_name, f'model_{args.load_model_steps}.pt')
else:
    args.load_path = os.path.join(args.save_folder_train, args.exp_name, 'model.pt')

# Results go to <save_folder_test>/<exp_name>[_<postfix>]/<test_type>.
if args.test_postfix != '':
    save_folder_name = f'{args.exp_name}_{args.test_postfix}'
else:
    save_folder_name = args.exp_name
args.save_folder_test = os.path.join(args.save_folder_test, save_folder_name, args.test_type)
print(args.save_folder_test)
print(args)


img_path = args.dataset_folder

with_time_emb = not args.remove_time_embed

model = get_model(args, with_time_emb=with_time_emb).cuda()
# Second model without the time embedding, used for one-shot denoising.
model_one_shot = get_model(args, with_time_emb=False).cuda()

image_size = get_dataset.get_image_size(args.dataset)
if args.resolution != -1:
    image_size = (args.resolution, args.resolution)
print(f'image_size: {image_size}')

use_torchvision_dataset = False  # renamed from misspelled `use_torchvison_dataset`
# Larger images get a smaller batch; the original's `elif image_size[0] > 64`
# was the exact complement of the `if`, so it is written as `else`.
if image_size[0] <= 64:
    train_batch_size = 32
else:
    train_batch_size = 16

print(args.dataset)

# NOTE(review): --order_seed defaults to -1.0 (a float) and is fed straight to
# torch.manual_seed, which documents an int seed — confirm intended behavior
# when --order_seed is left unset.
seed_value = args.order_seed
torch.manual_seed(seed_value)   # cpu vars
random.seed(seed_value)         # Python
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


diffusion = GaussianDiffusion(
    model,
    image_size=image_size,
    device_of_kernel='cuda',
    channels=3,
    one_shot_denoise_fn=model_one_shot,
    timesteps=args.time_steps,        # number of steps
    loss_type='l1',                   # L1 or L2
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine,
    forward_process_type=args.forward_process_type,
    decolor_routine=args.decolor_routine,
    decolor_ema_factor=args.decolor_ema_factor,
    decolor_total_remove=args.decolor_total_remove,
    snow_level=args.snow_level,
    random_snow=args.random_snow,
    single_snow=args.single_snow,
    batch_size=train_batch_size,
    to_lab=args.to_lab,
    load_snow_base=False,
    fix_brightness=args.fix_brightness,
    load_path=args.load_path,
    results_folder=args.save_folder_test,
).cuda()

trainer = Trainer(
    diffusion,
    img_path,
    image_size=image_size,
    train_batch_size=train_batch_size,
    train_lr=2e-5,
    train_num_steps=700000,           # total training steps
    gradient_accumulate_every=2,      # gradient accumulation steps
    ema_decay=0.995,                  # exponential moving average decay
    fp16=False,                       # turn on mixed precision training with apex
    results_folder=args.save_folder_test,
    load_path=args.load_path,
    random_aug=args.random_aug,
    torchvision_dataset=use_torchvision_dataset,
    dataset=args.dataset,             # was f'{args.dataset}', a redundant wrap of a str
    order_seed=args.order_seed,
    to_lab=args.to_lab,
)

if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)
    if args.test_fid:
        trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=1000)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=args.sample_steps)
    if args.test_fid:
        trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=2000)

elif args.test_type == 'test_paper':
    trainer.paper_invert_section_images()
    if args.test_fid:
        trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=1000)

elif args.test_type == 'test_paper_series':
    trainer.paper_showing_diffusion_images()

elif args.test_type == 'test_rebuttal':
    trainer.paper_showing_diffusion_images_cover_page()
random_aug=args.random_aug, 161 | torchvision_dataset=use_torchvison_dataset, 162 | dataset=f'{args.dataset}', 163 | order_seed=args.order_seed, 164 | to_lab=args.to_lab, 165 | ) 166 | 167 | if args.test_type == 'train_data': 168 | trainer.test_from_data('train', s_times=args.sample_steps) 169 | if args.test_fid: 170 | trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=1000) 171 | 172 | elif args.test_type == 'test_data': 173 | trainer.test_from_data('test', s_times=args.sample_steps) 174 | if args.test_fid: 175 | trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=2000) 176 | 177 | elif args.test_type == 'test_paper': 178 | trainer.paper_invert_section_images() 179 | if args.test_fid: 180 | trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=1000) 181 | 182 | elif args.test_type == 'test_paper_series': 183 | trainer.paper_showing_diffusion_images() 184 | 185 | elif args.test_type == 'test_rebuttal': 186 | trainer.paper_showing_diffusion_images_cover_page() 187 | 188 | -------------------------------------------------------------------------------- /decolor-diffusion/test.py: -------------------------------------------------------------------------------- 1 | from diffusion import GaussianDiffusion, Trainer, get_dataset 2 | import torchvision 3 | import os 4 | import errno 5 | import shutil 6 | import argparse 7 | from diffusion.model.get_model import get_model 8 | from Fid.fid_score import calculate_fid_given_samples 9 | import torch 10 | import random 11 | 12 | 13 | 14 | def create_folder(path): 15 | try: 16 | os.mkdir(path) 17 | except OSError as exc: 18 | if exc.errno != errno.EEXIST: 19 | raise 20 | pass 21 | 22 | def del_folder(path): 23 | try: 24 | shutil.rmtree(path) 25 | except OSError as exc: 26 | pass 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--time_steps', default=50, type=int) 30 | parser.add_argument('--sample_steps', 
# /decolor-diffusion/test.py
# Evaluation driver for the decolorization cold-diffusion model (byte-for-byte
# twin of /snowification/test.py in the original repo; kept consistent with it).
# Resolves the checkpoint path from --save_folder_train/--exp_name, builds the
# diffusion model and Trainer, and dispatches to one evaluation routine
# selected by --test_type, optionally followed by a FID measurement.

from diffusion import GaussianDiffusion, Trainer, get_dataset
import torchvision
import os
import errno
import shutil
import argparse
from diffusion.model.get_model import get_model
from Fid.fid_score import calculate_fid_given_samples
import torch
import random


def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def del_folder(path):
    """Best-effort recursive removal of *path*; all OS errors are ignored."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--sample_steps', default=None, type=int)
parser.add_argument('--save_folder_train', default='/cmlscratch/hmchu/cold_diff/', type=str)
parser.add_argument('--save_folder_test', default='/cmlscratch/hmchu/cold_diff_paper_test/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--test_type', default='train_data', type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")

parser.add_argument('--exp_name', default='', type=str)
parser.add_argument('--test_postfix', default='', type=str)
parser.add_argument('--dataset_folder', default='./root_cifar10', type=str)
parser.add_argument('--output_mean_scale', action='store_true')
parser.add_argument('--random_aug', action='store_true')
parser.add_argument('--model', default='UnetConvNext', type=str)
parser.add_argument('--dataset', default='cifar10')

parser.add_argument('--forward_process_type', default='GaussianBlur')
# GaussianBlur args

# Decolor args
parser.add_argument('--decolor_routine', default='Constant')
parser.add_argument('--decolor_ema_factor', default=0.9, type=float)
parser.add_argument('--decolor_total_remove', action='store_true')

parser.add_argument('--load_model_steps', default=-1, type=int)
parser.add_argument('--resume_training', action='store_true')

parser.add_argument('--order_seed', default=-1.0, type=float)
parser.add_argument('--resolution', default=-1, type=int)

parser.add_argument('--to_lab', action='store_true')
parser.add_argument('--snow_level', default=1, type=int)
parser.add_argument('--random_snow', action='store_true')
parser.add_argument('--single_snow', action='store_true')
parser.add_argument('--fix_brightness', action='store_true')

parser.add_argument('--test_fid', action='store_true')


args = parser.parse_args()
# The checkpoint and output folders are keyed by the experiment name, so it is
# required.  (Was a bare `assert`, which is stripped under `python -O`.)
if not args.exp_name:
    parser.error('--exp_name must be a non-empty string')

#if args.test_type == 'test_paper':
#    args.save_folder_test = '/cmlscratch/hmchu/cold_diff_paper_test/'

# Resolve the checkpoint to load: a specific training step, else the latest.
if args.load_model_steps != -1:
    args.load_path = os.path.join(args.save_folder_train, args.exp_name, f'model_{args.load_model_steps}.pt')
else:
    args.load_path = os.path.join(args.save_folder_train, args.exp_name, 'model.pt')

# Results go to <save_folder_test>/<exp_name>[_<postfix>]/<test_type>.
if args.test_postfix != '':
    save_folder_name = f'{args.exp_name}_{args.test_postfix}'
else:
    save_folder_name = args.exp_name
args.save_folder_test = os.path.join(args.save_folder_test, save_folder_name, args.test_type)
print(args.save_folder_test)
print(args)


img_path = args.dataset_folder

with_time_emb = not args.remove_time_embed

model = get_model(args, with_time_emb=with_time_emb).cuda()
# Second model without the time embedding, used for one-shot denoising.
model_one_shot = get_model(args, with_time_emb=False).cuda()

image_size = get_dataset.get_image_size(args.dataset)
if args.resolution != -1:
    image_size = (args.resolution, args.resolution)
print(f'image_size: {image_size}')

use_torchvision_dataset = False  # renamed from misspelled `use_torchvison_dataset`
# Larger images get a smaller batch; the original's `elif image_size[0] > 64`
# was the exact complement of the `if`, so it is written as `else`.
if image_size[0] <= 64:
    train_batch_size = 32
else:
    train_batch_size = 16

print(args.dataset)

# NOTE(review): --order_seed defaults to -1.0 (a float) and is fed straight to
# torch.manual_seed, which documents an int seed — confirm intended behavior
# when --order_seed is left unset.
seed_value = args.order_seed
torch.manual_seed(seed_value)   # cpu vars
random.seed(seed_value)         # Python
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


diffusion = GaussianDiffusion(
    model,
    image_size=image_size,
    device_of_kernel='cuda',
    channels=3,
    one_shot_denoise_fn=model_one_shot,
    timesteps=args.time_steps,        # number of steps
    loss_type='l1',                   # L1 or L2
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine,
    forward_process_type=args.forward_process_type,
    decolor_routine=args.decolor_routine,
    decolor_ema_factor=args.decolor_ema_factor,
    decolor_total_remove=args.decolor_total_remove,
    snow_level=args.snow_level,
    random_snow=args.random_snow,
    single_snow=args.single_snow,
    batch_size=train_batch_size,
    to_lab=args.to_lab,
    load_snow_base=False,
    fix_brightness=args.fix_brightness,
    load_path=args.load_path,
    results_folder=args.save_folder_test,
).cuda()

trainer = Trainer(
    diffusion,
    img_path,
    image_size=image_size,
    train_batch_size=train_batch_size,
    train_lr=2e-5,
    train_num_steps=700000,           # total training steps
    gradient_accumulate_every=2,      # gradient accumulation steps
    ema_decay=0.995,                  # exponential moving average decay
    fp16=False,                       # turn on mixed precision training with apex
    results_folder=args.save_folder_test,
    load_path=args.load_path,
    random_aug=args.random_aug,
    torchvision_dataset=use_torchvision_dataset,
    dataset=args.dataset,             # was f'{args.dataset}', a redundant wrap of a str
    order_seed=args.order_seed,
    to_lab=args.to_lab,
)

if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)
    if args.test_fid:
        trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=1000)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=args.sample_steps)
    if args.test_fid:
        trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=2000)

elif args.test_type == 'test_paper':
    trainer.paper_invert_section_images()
    if args.test_fid:
        trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=1000)

elif args.test_type == 'test_paper_series':
    trainer.paper_showing_diffusion_images()

elif args.test_type == 'test_rebuttal':
    trainer.paper_showing_diffusion_images_cover_page()
end=1000) 181 | 182 | elif args.test_type == 'test_paper_series': 183 | trainer.paper_showing_diffusion_images() 184 | 185 | elif args.test_type == 'test_rebuttal': 186 | trainer.paper_showing_diffusion_images_cover_page() 187 | 188 | -------------------------------------------------------------------------------- /resolution-diffusion-pytorch/mnist_test.py: -------------------------------------------------------------------------------- 1 | from pycave.bayes import GaussianMixture 2 | from resolution_diffusion_pytorch import Unet, GaussianDiffusion, Trainer 3 | import torchvision 4 | import os 5 | import errno 6 | import shutil 7 | import argparse 8 | from Fid import calculate_fid_given_samples 9 | 10 | def create_folder(path): 11 | try: 12 | os.mkdir(path) 13 | except OSError as exc: 14 | if exc.errno != errno.EEXIST: 15 | raise 16 | pass 17 | 18 | def del_folder(path): 19 | try: 20 | shutil.rmtree(path) 21 | except OSError as exc: 22 | pass 23 | 24 | create = 0 25 | 26 | if create: 27 | trainset = torchvision.datasets.MNIST( 28 | root='./data', train=False, download=True) 29 | root = './root_mnist_test/' 30 | del_folder(root) 31 | create_folder(root) 32 | 33 | for i in range(10): 34 | lable_root = root + str(i) + '/' 35 | create_folder(lable_root) 36 | 37 | for idx in range(len(trainset)): 38 | img, label = trainset[idx] 39 | print(idx) 40 | img.save(root + str(label) + '/' + str(idx) + '.png') 41 | 42 | 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('--time_steps', default=50, type=int) 45 | parser.add_argument('--sample_steps', default=None, type=int) 46 | parser.add_argument('--save_folder', default='./results_mnist', type=str) 47 | parser.add_argument('--load_path', default=None, type=str) 48 | parser.add_argument('--data_path', default='./root_mnist_test/', type=str) 49 | parser.add_argument('--test_type', default='train_data', type=str) 50 | parser.add_argument('--resolution_routine', default='Incremental', type=str) 51 | 
parser.add_argument('--train_routine', default='Final', type=str) 52 | parser.add_argument('--sampling_routine', default='x0_step_down', type=str) 53 | parser.add_argument('--gmm_size', default=8, type=int) 54 | parser.add_argument('--gmm_cluster', default=10, type=int) 55 | 56 | args = parser.parse_args() 57 | print(args) 58 | 59 | img_path=None 60 | if 'train' in args.test_type: 61 | img_path = args.data_path 62 | elif 'test' in args.test_type: 63 | img_path = args.data_path 64 | 65 | print("Img Path is ", img_path) 66 | 67 | 68 | model = Unet( 69 | dim = 64, 70 | dim_mults = (1, 2, 4, 8), 71 | channels=1 72 | ).cuda() 73 | 74 | 75 | diffusion = GaussianDiffusion( 76 | model, 77 | image_size = 32, 78 | device_of_kernel = 'cuda', 79 | channels = 1, 80 | timesteps = args.time_steps, # number of steps 81 | loss_type = 'l1', # L1 or L2 82 | resolution_routine=args.resolution_routine, 83 | train_routine=args.train_routine, 84 | sampling_routine = args.sampling_routine 85 | ).cuda() 86 | 87 | trainer = Trainer( 88 | diffusion, 89 | img_path, 90 | image_size = 32, 91 | train_batch_size = 32, 92 | train_lr = 2e-5, 93 | train_num_steps = 700000, # total training steps 94 | gradient_accumulate_every = 2, # gradient accumulation steps 95 | ema_decay = 0.995, # exponential moving average decay 96 | fp16 = False, # turn on mixed precision training with apex 97 | results_folder = args.save_folder, 98 | load_path = args.load_path 99 | ) 100 | 101 | if args.test_type == 'train_data': 102 | trainer.test_from_data('train', s_times=args.sample_steps) 103 | 104 | elif args.test_type == 'test_data': 105 | trainer.test_from_data('test', s_times=args.sample_steps) 106 | 107 | elif args.test_type == 'mixup_train_data': 108 | trainer.test_with_mixup('train') 109 | 110 | elif args.test_type == 'mixup_test_data': 111 | trainer.test_with_mixup('test') 112 | 113 | elif args.test_type == 'test_random': 114 | trainer.test_from_random('random') 115 | 116 | elif args.test_type == 
'train_distribution_cov_vector': 117 | trainer.sample_as_a_vector_cov(start=0, end=None, siz=4, ch=1) 118 | trainer.sample_as_a_vector_cov(start=0, end=None, siz=8, ch=1) 119 | trainer.sample_as_a_vector_cov(start=0, end=None, siz=16, ch=1) 120 | # trainer.sample_as_a_vector_cov(start=0, end=None, siz=16, ch=1) 121 | # trainer.sample_as_a_vector_cov(start=0, end=None, siz=8, ch=1) 122 | # trainer.sample_as_a_vector_cov(start=0, end=None, siz=4, ch=1) 123 | 124 | elif args.test_type == 'train_distribution_gmm': 125 | trainer.sample_as_a_vector_gmm(start=0, end=None, siz=args.gmm_size, ch=1, clusters=args.gmm_cluster) 126 | # trainer.sample_as_a_vector_gmm(start=0, end=None, siz=8, ch=1, clusters=10) 127 | # trainer.sample_as_a_vector_gmm(start=0, end=None, siz=4, ch=1, clusters=10) 128 | # trainer.sample_as_a_vector_gmm(start=0, end=None, siz=16, ch=1, clusters=10) 129 | 130 | # elif args.test_type == 'train_distribution_save_gmm': 131 | # trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=8, ch=1, clusters=10) 132 | # trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=4, ch=1, clusters=10) 133 | # trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=16, ch=1, clusters=10) 134 | # trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=32, ch=1, clusters=10) 135 | # 136 | # trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=8, ch=1, clusters=20) 137 | # trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=4, ch=1, clusters=20) 138 | # trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=16, ch=1, clusters=20) 139 | # trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=32, ch=1, clusters=20) 140 | # 141 | # trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=8, ch=1, clusters=50) 142 | # trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=4, ch=1, clusters=50) 143 | # trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=16, ch=1, clusters=50) 144 | # 
trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=32, ch=1, clusters=50) 145 | # 146 | # trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=8, ch=1, clusters=100) 147 | # trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=4, ch=1, clusters=100) 148 | # trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=16, ch=1, clusters=100) 149 | # trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=32, ch=1, clusters=100) 150 | 151 | elif args.test_type == 'train_distribution_save_gmm': 152 | trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=args.gmm_size, ch=1, clusters=args.gmm_cluster) 153 | 154 | elif args.test_type == 'train_distribution_save_gmm_slowly': 155 | trainer.sample_as_a_vector_gmm_and_save_slowly(start=0, end=None, siz=args.gmm_size, ch=1, clusters=args.gmm_cluster) 156 | 157 | elif args.test_type == 'train_save_orig_data_same_as_trained': 158 | trainer.save_training_data() 159 | 160 | elif args.test_type == 'test_save_orig_data_same_as_tested': 161 | trainer.save_training_data() 162 | 163 | elif args.test_type == 'test_fid_distance_decrease_from_manifold': 164 | trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None) 165 | 166 | elif args.test_type == 'train_fid_distance_decrease_from_manifold': 167 | trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None) 168 | 169 | elif args.test_type == 'sample_from_train_data': 170 | trainer.sample_from_data_save(start=0, end=None) 171 | 172 | elif args.test_type == 'sample_from_test_data': 173 | trainer.sample_from_data_save(start=0, end=None) 174 | 175 | elif args.test_type == 'train_distribution_save_pytorch_gmm': 176 | trainer.sample_as_a_vector_pytorch_gmm_and_save(GaussianMixture, start=0, end=59000, siz=args.gmm_size, ch=1, clusters=args.gmm_cluster) 177 | 178 | elif args.test_type == 'test_paper_invert_section_images': 179 | trainer.paper_invert_section_images() 180 | 181 | 
import math
import torch
import torch.nn as nn
from inspect import isfunction
from einops import rearrange


def exists(x):
    """True when *x* is not None."""
    return x is not None


def default(val, d):
    """Return *val* if set, otherwise *d* (calling it when it is a function)."""
    if exists(val):
        return val
    return d() if isfunction(d) else d


class Residual(nn.Module):
    """Wrap *fn* so the module computes ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal timestep embedding (sin half, cos half)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=x.device) * -scale)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


def Upsample(dim):
    # Stride-2 transposed conv: doubles spatial resolution.
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)


def Downsample(dim):
    # Stride-2 conv: halves spatial resolution.
    return nn.Conv2d(dim, dim, 4, 2, 1)


class LayerNorm(nn.Module):
    """Channel-wise layer norm for NCHW tensors (biased variance estimate)."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b


class PreNorm(nn.Module):
    """Apply LayerNorm, then *fn*."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))


# building block modules

class ConvNextBlock(nn.Module):
    """ConvNeXt residual block, https://arxiv.org/abs/2201.03545"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )
        # Depthwise 7x7 spatial mixing.
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
        self.net = nn.Sequential(
            LayerNorm(dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)
        if exists(self.mlp):
            assert exists(time_emb), 'time emb must be passed in'
            h = h + rearrange(self.mlp(time_emb), 'b c -> b c 1 1')
        return self.net(h) + self.res_conv(x)


class LinearAttention(nn.Module):
    """Attention with cost linear in the number of pixels."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden, dim, 1)

    def forward(self, x):
        b, c, height, width = x.shape
        q, k, v = (
            rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads)
            for t in self.to_qkv(x).chunk(3, dim=1)
        )
        q = q * self.scale
        k = k.softmax(dim=-1)
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y',
                        h=self.heads, x=height, y=width)
        return self.to_out(out)


# model

class UnetConvNextBlock(nn.Module):
    """U-Net built from ConvNeXt blocks with linear attention at each scale."""

    def __init__(
        self,
        dim,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        output_mean_scale=False,
        residual=False,
    ):
        super().__init__()
        self.channels = channels
        self.residual = residual
        print("Is Time embed used ? ", with_time_emb)
        self.output_mean_scale = output_mean_scale

        dims = [channels] + [dim * m for m in dim_mults]
        in_out = list(zip(dims[:-1], dims[1:]))

        if with_time_emb:
            time_dim = dim
            self.time_mlp = nn.Sequential(
                SinusoidalPosEmb(dim),
                nn.Linear(dim, dim * 4),
                nn.GELU(),
                nn.Linear(dim * 4, dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        n_stages = len(in_out)

        # Encoder: two ConvNeXt blocks + attention per scale, then downsample.
        for idx, (c_in, c_out) in enumerate(in_out):
            last = idx >= n_stages - 1
            self.downs.append(nn.ModuleList([
                ConvNextBlock(c_in, c_out, time_emb_dim=time_dim, norm=idx != 0),
                ConvNextBlock(c_out, c_out, time_emb_dim=time_dim),
                Residual(PreNorm(c_out, LinearAttention(c_out))),
                Downsample(c_out) if not last else nn.Identity(),
            ]))

        mid = dims[-1]
        self.mid_block1 = ConvNextBlock(mid, mid, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid, LinearAttention(mid)))
        self.mid_block2 = ConvNextBlock(mid, mid, time_emb_dim=time_dim)

        # Decoder mirrors the encoder and consumes the skip connections.
        for idx, (c_in, c_out) in enumerate(reversed(in_out[1:])):
            last = idx >= n_stages - 1
            self.ups.append(nn.ModuleList([
                ConvNextBlock(c_out * 2, c_in, time_emb_dim=time_dim),
                ConvNextBlock(c_in, c_in, time_emb_dim=time_dim),
                Residual(PreNorm(c_in, LinearAttention(c_in))),
                Upsample(c_in) if not last else nn.Identity(),
            ]))

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            ConvNextBlock(dim, dim),
            nn.Conv2d(dim, out_dim, 1),
        )

    def forward(self, x, time=None):
        orig_x = x
        t = None
        if time is not None and exists(self.time_mlp):
            t = self.time_mlp(time)

        original_mean = torch.mean(x, [1, 2, 3], keepdim=True)
        skips = []

        for block1, block2, attn, down in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            skips.append(x)
            x = down(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, up in self.ups:
            x = torch.cat((x, skips.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = up(x)

        if self.residual:
            return self.final_conv(x) + orig_x

        out = self.final_conv(x)
        if self.output_mean_scale:
            out = out - original_mean + torch.mean(out, [1, 2, 3], keepdim=True)
        return out
class SinusoidalPosEmb(nn.Module):
    # Sinusoidal timestep embedding: first half sin, second half cos.
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        log_step = math.log(10000) / (half_dim - 1)
        inv_freq = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
        phase = x[:, None] * inv_freq[None, :]
        return torch.cat((phase.sin(), phase.cos()), dim=-1)


def Upsample(dim):
    """Double the spatial resolution with a stride-2 transposed conv."""
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)


def Downsample(dim):
    """Halve the spatial resolution with a stride-2 conv."""
    return nn.Conv2d(dim, dim, 4, 2, 1)


class LayerNorm(nn.Module):
    """Per-pixel channel normalisation with learnable scale and shift."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        mu = torch.mean(x, dim=1, keepdim=True)
        sigma2 = torch.var(x, dim=1, unbiased=False, keepdim=True)
        return (x - mu) / (sigma2 + self.eps).sqrt() * self.g + self.b


class PreNorm(nn.Module):
    """Normalise the input before handing it to *fn*."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)


# building block modules

class ConvNextBlock(nn.Module):
    """ConvNeXt block (https://arxiv.org/abs/2201.03545), optionally
    conditioned on a time embedding."""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        if exists(time_emb_dim):
            # Projects the time embedding onto the input channels.
            self.mlp = nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
        else:
            self.mlp = None

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)  # depthwise
        self.net = nn.Sequential(
            LayerNorm(dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if exists(self.mlp):
            assert exists(time_emb), 'time emb must be passed in'
            h = h + rearrange(self.mlp(time_emb), 'b c -> b c 1 1')

        h = self.net(h)
        return h + self.res_conv(x)


class LinearAttention(nn.Module):
    """Attention with linear cost in the number of pixels."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, inner * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner, dim, 1)

    def forward(self, x):
        b, c, hh, ww = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads),
            qkv,
        )
        q = q * self.scale
        k = k.softmax(dim=-1)
        ctx = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', ctx, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y',
                        h=self.heads, x=hh, y=ww)
        return self.to_out(out)


# model

class UnetConvNextBlock(nn.Module):
    """ConvNeXt U-Net; optionally residual and/or mean-rescaled output."""

    def __init__(
        self,
        dim,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        output_mean_scale=False,
        residual=False,
    ):
        super().__init__()
        self.channels = channels
        self.residual = residual
        print("Is Time embed used ? ", with_time_emb)
        self.output_mean_scale = output_mean_scale

        dims = [channels, *map(lambda m: dim * m, dim_mults)]
        stage_io = list(zip(dims[:-1], dims[1:]))

        if with_time_emb:
            time_dim = dim
            self.time_mlp = nn.Sequential(
                SinusoidalPosEmb(dim),
                nn.Linear(dim, dim * 4),
                nn.GELU(),
                nn.Linear(dim * 4, dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        depth = len(stage_io)

        # Encoder: two ConvNeXt blocks + attention per scale, then downsample.
        for i, (d_in, d_out) in enumerate(stage_io):
            bottom = i >= depth - 1
            self.downs.append(nn.ModuleList([
                ConvNextBlock(d_in, d_out, time_emb_dim=time_dim, norm=i != 0),
                ConvNextBlock(d_out, d_out, time_emb_dim=time_dim),
                Residual(PreNorm(d_out, LinearAttention(d_out))),
                nn.Identity() if bottom else Downsample(d_out),
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))
        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, time_emb_dim=time_dim)

        # Decoder: mirrors the encoder, consuming the skip connections.
        for i, (d_in, d_out) in enumerate(reversed(stage_io[1:])):
            top = i >= depth - 1
            self.ups.append(nn.ModuleList([
                ConvNextBlock(d_out * 2, d_in, time_emb_dim=time_dim),
                ConvNextBlock(d_in, d_in, time_emb_dim=time_dim),
                Residual(PreNorm(d_in, LinearAttention(d_in))),
                nn.Identity() if top else Upsample(d_in),
            ]))

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            ConvNextBlock(dim, dim),
            nn.Conv2d(dim, out_dim, 1),
        )

    def forward(self, x, time=None):
        orig_x = x
        t = None
        if time is not None and exists(self.time_mlp):
            t = self.time_mlp(time)

        original_mean = torch.mean(x, [1, 2, 3], keepdim=True)
        stack = []

        for blk_a, blk_b, attn, down in self.downs:
            x = blk_a(x, t)
            x = blk_b(x, t)
            x = attn(x)
            stack.append(x)
            x = down(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for blk_a, blk_b, attn, up in self.ups:
            x = torch.cat((x, stack.pop()), dim=1)
            x = blk_a(x, t)
            x = blk_b(x, t)
            x = attn(x)
            x = up(x)

        if self.residual:
            return self.final_conv(x) + orig_x

        out = self.final_conv(x)
        if self.output_mean_scale:
            out_mean = torch.mean(out, [1, 2, 3], keepdim=True)
            out = out - original_mean + out_mean

        return out
import math

import torch
import torch.nn as nn

# NOTE: kornia is imported lazily inside rgb2lab/lab2rgb below, so the HSV
# helpers remain usable when kornia is not installed.


def rgb2hsv(image_old: torch.Tensor, eps: float = 1e-8, rescale: bool = True) -> torch.Tensor:
    r"""Convert an RGB image to HSV.

    Args:
        image_old: RGB image of shape :math:`(*, 3, H, W)`. Values are expected
            in ``[-1, 1]`` when ``rescale`` is True, otherwise in ``[0, 1]``.
        eps: scalar to enforce numerical stability of the saturation.
        rescale: when True, map the input from ``[-1, 1]`` to ``[0, 1]`` first.

    Returns:
        HSV version of the image with shape :math:`(*, 3, H, W)`.
        The H channel values are in the range 0..2pi. S and V are in 0..1.
    """
    if rescale:
        image = (image_old + 1) * 0.5
    else:
        # Fix: the original left ``image`` unbound on this path (NameError).
        image = image_old
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    max_rgb, argmax_rgb = image.max(-3)
    min_rgb = image.min(-3).values
    deltac = max_rgb - min_rgb

    v = max_rgb
    s = deltac / (max_rgb + eps)

    # Avoid division by zero for grey pixels (their hue is arbitrary anyway).
    deltac = torch.where(deltac == 0, torch.ones_like(deltac), deltac)
    rc, gc, bc = torch.unbind((max_rgb.unsqueeze(-3) - image), dim=-3)

    h1 = (bc - gc)
    h2 = (rc - bc) + 2.0 * deltac
    h3 = (gc - rc) + 4.0 * deltac

    # Pick the hue formula that matches the dominant channel of each pixel.
    h = torch.stack((h1, h2, h3), dim=-3) / deltac.unsqueeze(-3)
    h = torch.gather(h, dim=-3, index=argmax_rgb.unsqueeze(-3)).squeeze(-3)
    h = (h / 6.0) % 1.0
    h = 2.0 * math.pi * h  # we return a 0..2pi hue

    return torch.stack((h, s, v), dim=-3)


def hsv2rgb(image: torch.Tensor, rescale: bool = True) -> torch.Tensor:
    r"""Convert an HSV image to RGB.

    The H channel values are assumed to be in the range 0..2pi. S and V are
    in the range 0..1.

    Args:
        image: HSV image to be converted to RGB with shape :math:`(*, 3, H, W)`.
        rescale: when True, map the RGB output from ``[0, 1]`` to ``[-1, 1]``.

    Returns:
        RGB version of the image with shape :math:`(*, 3, H, W)`.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    h: torch.Tensor = image[..., 0, :, :] / (2 * math.pi)
    s: torch.Tensor = image[..., 1, :, :]
    v: torch.Tensor = image[..., 2, :, :]

    hi: torch.Tensor = torch.floor(h * 6) % 6
    f: torch.Tensor = ((h * 6) % 6) - hi
    one: torch.Tensor = torch.tensor(1.0, device=image.device, dtype=image.dtype)
    p: torch.Tensor = v * (one - s)
    q: torch.Tensor = v * (one - f * s)
    t: torch.Tensor = v * (one - (one - f) * s)

    hi = hi.long()
    # Select (r, g, b) from the six hue sectors with a single gather.
    indices: torch.Tensor = torch.stack([hi, hi + 6, hi + 12], dim=-3)
    out = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-3)
    out = torch.gather(out, -3, indices)

    if rescale:
        out = 2.0 * out - 1

    return out


"""
The RGB to Lab color transformations were translated from scikit image's
rgb2lab and lab2rgb

https://github.com/scikit-image/scikit-image/blob/a48bf6774718c64dade4548153ae16065b595ca9/skimage/color/colorconv.py
"""


def rgb2lab(image_old: torch.Tensor) -> torch.Tensor:
    r"""Convert a RGB image to Lab.

    The input is expected in the ``[-1, 1]`` range and is rescaled to
    ``[0, 1]`` internally. Lab color is computed using the D65 illuminant and
    Observer 2.

    Args:
        image_old: RGB image to be converted to Lab, shape :math:`(*, 3, H, W)`.

    Returns:
        Lab version of the image with shape :math:`(*, 3, H, W)`.
        The L channel values are in the range 0..100. a and b are in -127..127.
    """
    # Imported lazily; only the Lab conversions depend on kornia.
    from kornia.color.rgb import rgb_to_linear_rgb
    from kornia.color.xyz import rgb_to_xyz

    image = (image_old + 1) * 0.5
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # Convert from sRGB to Linear RGB
    lin_rgb = rgb_to_linear_rgb(image)

    xyz_im: torch.Tensor = rgb_to_xyz(lin_rgb)

    # normalize for D65 white point
    xyz_ref_white = torch.tensor(
        [0.95047, 1.0, 1.08883], device=xyz_im.device, dtype=xyz_im.dtype
    )[..., :, None, None]
    xyz_normalized = torch.div(xyz_im, xyz_ref_white)

    threshold = 0.008856
    power = torch.pow(xyz_normalized.clamp(min=threshold), 1 / 3.0)
    scale = 7.787 * xyz_normalized + 4.0 / 29.0
    xyz_int = torch.where(xyz_normalized > threshold, power, scale)

    x: torch.Tensor = xyz_int[..., 0, :, :]
    y: torch.Tensor = xyz_int[..., 1, :, :]
    z: torch.Tensor = xyz_int[..., 2, :, :]

    L: torch.Tensor = (116.0 * y) - 16.0
    a: torch.Tensor = 500.0 * (x - y)
    _b: torch.Tensor = 200.0 * (y - z)

    return torch.stack([L, a, _b], dim=-3)


def lab2rgb(image: torch.Tensor, clip: bool = True) -> torch.Tensor:
    r"""Convert a Lab image to RGB.

    Args:
        image: Lab image to be converted to RGB with shape :math:`(*, 3, H, W)`.
        clip: Whether to clip so that intermediate RGB values lie in ``[0, 1]``.

    Returns:
        RGB version of the image with shape :math:`(*, 3, H, W)`, rescaled to
        ``[-1, 1]``.
    """
    # Imported lazily; only the Lab conversions depend on kornia.
    from kornia.color.rgb import linear_rgb_to_rgb
    from kornia.color.xyz import xyz_to_rgb

    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    L: torch.Tensor = image[..., 0, :, :]
    a: torch.Tensor = image[..., 1, :, :]
    _b: torch.Tensor = image[..., 2, :, :]

    fy = (L + 16.0) / 116.0
    fx = (a / 500.0) + fy
    fz = fy - (_b / 200.0)

    # if color data out of range: Z < 0
    fz = fz.clamp(min=0.0)

    fxyz = torch.stack([fx, fy, fz], dim=-3)

    # Convert from Lab to XYZ
    power = torch.pow(fxyz, 3.0)
    scale = (fxyz - 4.0 / 29.0) / 7.787
    xyz = torch.where(fxyz > 0.2068966, power, scale)

    # For D65 white point
    xyz_ref_white = torch.tensor(
        [0.95047, 1.0, 1.08883], device=xyz.device, dtype=xyz.dtype
    )[..., :, None, None]
    xyz_im = xyz * xyz_ref_white

    rgbs_im: torch.Tensor = xyz_to_rgb(xyz_im)

    # https://github.com/richzhang/colorization-pytorch/blob/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/util/util.py#L107
    # rgbs_im = torch.where(rgbs_im < 0, torch.zeros_like(rgbs_im), rgbs_im)

    # Convert from RGB Linear to sRGB
    rgb_im = linear_rgb_to_rgb(rgbs_im)

    # Clip to 0,1 https://www.w3.org/Graphics/Color/srgb
    if clip:
        rgb_im = torch.clamp(rgb_im, min=0.0, max=1.0)

    rgb_im = 2.0 * rgb_im - 1

    return rgb_im


class RgbToLab(nn.Module):
    r"""Convert an image from RGB to Lab.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Reference:
        [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html
        [2] https://www.easyrgb.com/en/math.php
    """

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # Fix: the original called the undefined name ``rgb_to_lab``.
        return rgb2lab(image)


class LabToRgb(nn.Module):
    r"""Convert an image from Lab to RGB.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    References:
        [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html
        [2] https://www.easyrgb.com/en/math.php
    """

    def forward(self, image: torch.Tensor, clip: bool = True) -> torch.Tensor:
        # Fix: the original called the undefined name ``lab_to_rgb``.
        return lab2rgb(image, clip)