├── snowification
├── diffusion
│ ├── model
│ │ ├── __init__.py
│ │ ├── get_model.py
│ │ └── unet_convnext.py
│ ├── __init__.py
│ ├── get_dataset.py
│ └── utils.py
├── Fid
│ └── __init__.py
├── training_script.sh
├── testing_script.sh
├── setup.py
├── train.py
└── test.py
├── decolor-diffusion
├── diffusion
│ ├── model
│ │ ├── __init__.py
│ │ ├── get_model.py
│ │ └── unet_convnext.py
│ ├── __init__.py
│ └── get_dataset.py
├── Fid
│ └── __init__.py
├── setup.py
├── train.py
└── test.py
├── deblurring-diffusion-pytorch
├── deblurring_diffusion_pytorch
│ ├── cal_Fid.py
│ └── __init__.py
├── Fid
│ └── __init__.py
├── .github
│ └── workflows
│ │ └── python-publish.yml
├── .gitignore
├── AFHQ_128.py
├── mnist_train.py
├── dispatch.py
├── cifar10_train.py
├── celebA_128.py
├── mnist_test.py
├── cifar10_test.py
├── AFHQ_128_test.py
└── celebA_128_test.py
├── demixing-diffusion-pytorch
├── demixing_diffusion_pytorch
│ ├── cal_Fid.py
│ └── __init__.py
├── Fid
│ └── __init__.py
├── .github
│ └── workflows
│ │ └── python-publish.yml
├── .gitignore
├── AFHQ_128_to_celebA_128.py
├── AFHQ_128_to_celebA_128_test.py
└── dispatch.py
├── denoising-diffusion-pytorch
├── denoising_diffusion_pytorch
│ ├── cal_Fid.py
│ └── __init__.py
├── Fid
│ └── __init__.py
├── AFHQ_noise_128.py
├── celebA_noise_128.py
├── AFHQ_noise_128_test.py
└── celebA_noise_128_test.py
├── defading-generation-diffusion-pytorch
├── defading_diffusion_pytorch
│ ├── cal_Fid.py
│ └── __init__.py
├── Fid
│ └── __init__.py
├── celebA_128.py
└── celebA_128_test.py
├── all_transform_cover.png
├── defading-diffusion-pytorch
├── Fid
│ └── __init__.py
├── defading_diffusion_pytorch
│ └── __init__.py
├── bigger_mask_train.sh
├── inpainting_test_paper.sh
├── .github
│ └── workflows
│ │ └── python-publish.yml
├── LICENSE
├── .gitignore
├── mnist_train.py
├── cifar10_train_wd.py
├── cifar10_train.py
├── celebA_train.py
├── mnist_test.py
├── celebA_test.py
├── cifar10_test.py
└── inpainting_paper_train.sh
├── resolution-diffusion-pytorch
├── Fid
│ └── __init__.py
├── resolution_diffusion_pytorch
│ └── __init__.py
├── .github
│ └── workflows
│ │ └── python-publish.yml
├── runs.sh
├── celebA_128.py
├── .gitignore
├── celebA.py
├── mnist_train.py
├── cifar10_train.py
├── celebA_128_test.py
├── cifar10_test.py
└── mnist_test.py
├── .idea
├── misc.xml
├── vcs.xml
├── .gitignore
├── inspectionProfiles
│ ├── profiles_settings.xml
│ └── Project_Default.xml
├── Cold-Diffusion-Models.iml
└── modules.xml
└── create_data.py
/snowification/diffusion/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/decolor-diffusion/diffusion/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/deblurring-diffusion-pytorch/deblurring_diffusion_pytorch/cal_Fid.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/demixing-diffusion-pytorch/demixing_diffusion_pytorch/cal_Fid.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/denoising-diffusion-pytorch/denoising_diffusion_pytorch/cal_Fid.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/defading-generation-diffusion-pytorch/defading_diffusion_pytorch/cal_Fid.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/snowification/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from diffusion.diffusion import GaussianDiffusion, Trainer
2 |
--------------------------------------------------------------------------------
/decolor-diffusion/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from diffusion.diffusion import GaussianDiffusion, Trainer
2 |
--------------------------------------------------------------------------------
/all_transform_cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arpitbansal297/Cold-Diffusion-Models/HEAD/all_transform_cover.png
--------------------------------------------------------------------------------
/decolor-diffusion/Fid/__init__.py:
--------------------------------------------------------------------------------
1 | from Fid.inception import InceptionV3
2 | from Fid.fid_score import calculate_fid_given_samples
--------------------------------------------------------------------------------
/snowification/Fid/__init__.py:
--------------------------------------------------------------------------------
1 | from Fid.inception import InceptionV3
2 | from Fid.fid_score import calculate_fid_given_samples
--------------------------------------------------------------------------------
/defading-diffusion-pytorch/Fid/__init__.py:
--------------------------------------------------------------------------------
1 | from Fid.inception import InceptionV3
2 | from Fid.fid_score import calculate_fid_given_samples
--------------------------------------------------------------------------------
/demixing-diffusion-pytorch/Fid/__init__.py:
--------------------------------------------------------------------------------
1 | from Fid.inception import InceptionV3
2 | from Fid.fid_score import calculate_fid_given_samples
--------------------------------------------------------------------------------
/denoising-diffusion-pytorch/Fid/__init__.py:
--------------------------------------------------------------------------------
1 | from Fid.inception import InceptionV3
2 | from Fid.fid_score import calculate_fid_given_samples
--------------------------------------------------------------------------------
/deblurring-diffusion-pytorch/Fid/__init__.py:
--------------------------------------------------------------------------------
1 | from Fid.inception import InceptionV3
2 | from Fid.fid_score import calculate_fid_given_samples
--------------------------------------------------------------------------------
/resolution-diffusion-pytorch/Fid/__init__.py:
--------------------------------------------------------------------------------
1 | from Fid.inception import InceptionV3
2 | from Fid.fid_score import calculate_fid_given_samples
--------------------------------------------------------------------------------
/defading-generation-diffusion-pytorch/Fid/__init__.py:
--------------------------------------------------------------------------------
1 | from Fid.inception import InceptionV3
2 | from Fid.fid_score import calculate_fid_given_samples
--------------------------------------------------------------------------------
/demixing-diffusion-pytorch/demixing_diffusion_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from demixing_diffusion_pytorch.demixing_diffusion_pytorch import GaussianDiffusion, Unet, Trainer
2 |
3 |
--------------------------------------------------------------------------------
/denoising-diffusion-pytorch/denoising_diffusion_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, Unet, Trainer
2 |
3 |
--------------------------------------------------------------------------------
/defading-generation-diffusion-pytorch/defading_diffusion_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from defading_diffusion_pytorch.defading_diffusion_pytorch import GaussianDiffusion, Unet, Trainer
2 |
3 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/defading-diffusion-pytorch/defading_diffusion_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from defading_diffusion_pytorch.defading_diffusion_gaussian import GaussianDiffusion, Unet, Trainer
2 | from defading_diffusion_pytorch.new_model import Model
3 |
--------------------------------------------------------------------------------
/deblurring-diffusion-pytorch/deblurring_diffusion_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from deblurring_diffusion_pytorch.deblurring_diffusion_pytorch import GaussianDiffusion, Unet, Trainer
2 | from deblurring_diffusion_pytorch.Model2 import Model
3 |
--------------------------------------------------------------------------------
/resolution-diffusion-pytorch/resolution_diffusion_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from resolution_diffusion_pytorch.resolution_diffusion_pytorch import GaussianDiffusion, Unet, Trainer
2 | from resolution_diffusion_pytorch.Model2 import Model
3 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/Cold-Diffusion-Models.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/snowification/training_script.sh:
--------------------------------------------------------------------------------
1 | # Longer step experiments
2 | python train.py --time_steps 20 --forward_process_type 'Decolorization' --dataset_folder --exp_name 'cifar_exp' --decolor_total_remove --decolor_routine 'Linear'
3 | python train.py --time_steps 20 --forward_process_type 'Decolorization' --exp_name 'celeba_exp' --decolor_ema_factor 0.9 --decolor_total_remove --decolor_routine 'Linear' --dataset celebA --dataset_folder --resolution 64 --resume_training
4 |
--------------------------------------------------------------------------------
/snowification/testing_script.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/sh
2 | python test.py --time_steps 20 --forward_process_type 'Decolorization' --dataset_folder --exp_name 'cifar_exp' --decolor_total_remove --decolor_routine 'Linear' --sampling_routine x0_step_down --test_type test_paper --order_seed 1 --test_fid
3 | python test.py --time_steps 20 --forward_process_type 'Decolorization' --exp_name 'celeba_exp' --decolor_total_remove --decolor_routine 'Linear' --dataset celebA --dataset_folder --resolution 64 --resume_training --sampling_routine x0_step_down --test_type test_paper --order_seed 1 --test_fid
4 |
--------------------------------------------------------------------------------
/defading-diffusion-pytorch/bigger_mask_train.sh:
--------------------------------------------------------------------------------
1 | #python defading-diffusion-pytorch/celebA_train.py --image_size 64 --time_steps 100 --initial_mask 11 --save_folder ./celebA_64x64_100_step_11_init --discrete --sampling_routine x0_step_down --train_steps 700000 --kernel_std 0.1 --fade_routine Random_Incremental --load_path /cmlscratch/eborgnia/cold_diffusion/test_speed_celebA/model.pt
2 | python defading-diffusion-pytorch/celebA_train.py --image_size 128 --time_steps 200 --initial_mask 21 --save_folder ./celebA_128x128_200_step_21_init --discrete --sampling_routine x0_step_down --train_steps 700000 --kernel_std 0.1 --fade_routine Random_Incremental
3 |
--------------------------------------------------------------------------------
/defading-diffusion-pytorch/inpainting_test_paper.sh:
--------------------------------------------------------------------------------
1 | #python defading-diffusion-pytorch/cifar10_test.py --time_steps 50 --save_folder ./cifar_inpainting_test --discrete --sampling_routine x0_step_down --blur_std 0.1 --fade_routine Random_Incremental --test_type test_paper_invert_section_images
2 |
3 | #python defading-diffusion-pytorch/celebA_128_test.py --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_64x64_100_step_11_init/model.pt --time_steps 100 --save_folder ./celebA_inpainting_test --discrete --sampling_routine x0_step_down --blur_std 0.1 --initial_mask 11 --fade_routine Random_Incremental --test_type test_data
4 |
5 | #python defading-diffusion-pytorch/mnist_test.py --time_steps 50 --save_folder ./mnist_inpainting_test --discrete --sampling_routine x0_step_down --blur_std 0.1 --fade_routine Random_Incremental --test_type test_paper_invert_section_images
6 |
7 | python defading-diffusion-pytorch/celebA_test.py
--------------------------------------------------------------------------------
/snowification/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'denoising-diffusion-pytorch',
5 | packages = find_packages(),
6 | version = '0.7.1',
7 | license='MIT',
8 | description = 'Denoising Diffusion Probabilistic Models - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/denoising-diffusion-pytorch',
12 | keywords = [
13 | 'artificial intelligence',
14 | 'generative models'
15 | ],
16 | install_requires=[
17 | 'einops',
18 | 'numpy',
19 | 'pillow',
20 | 'torch',
21 | 'torchvision',
22 | 'tqdm'
23 | ],
24 | classifiers=[
25 | 'Development Status :: 4 - Beta',
26 | 'Intended Audience :: Developers',
27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
28 | 'License :: OSI Approved :: MIT License',
29 | 'Programming Language :: Python :: 3.6',
30 | ],
31 | )
--------------------------------------------------------------------------------
/decolor-diffusion/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'denoising-diffusion-pytorch',
5 | packages = find_packages(),
6 | version = '0.7.1',
7 | license='MIT',
8 | description = 'Denoising Diffusion Probabilistic Models - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/denoising-diffusion-pytorch',
12 | keywords = [
13 | 'artificial intelligence',
14 | 'generative models'
15 | ],
16 | install_requires=[
17 | 'einops',
18 | 'numpy',
19 | 'pillow',
20 | 'torch',
21 | 'torchvision',
22 | 'tqdm'
23 | ],
24 | classifiers=[
25 | 'Development Status :: 4 - Beta',
26 | 'Intended Audience :: Developers',
27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
28 | 'License :: OSI Approved :: MIT License',
29 | 'Programming Language :: Python :: 3.6',
30 | ],
31 | )
--------------------------------------------------------------------------------
/deblurring-diffusion-pytorch/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/defading-diffusion-pytorch/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/demixing-diffusion-pytorch/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/resolution-diffusion-pytorch/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/defading-diffusion-pytorch/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/resolution-diffusion-pytorch/runs.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | python celebA_128.py --time_steps 4 --resolution_routine 'Incremental_factor_2' --save_folder './celebA_4_steps_fac2_train'
4 | python celebA_test.py --time_steps 4 --train_routine 'Final' --sampling_routine 'x0_step_down' --resolution_routine 'Incremental_factor_2' --save_folder './celebA_test' --load_path './celebA_final_ckpt/model.pt' --test_type 'test_fid_distance_decrease_from_manifold'
5 |
6 |
7 | python mnist_train.py --time_steps 3 --resolution_routine 'Incremental_factor_2' --save_folder './mnist_3_steps_fac2_train'
8 | python mnist_test.py --time_steps 3 --train_routine 'Final' --sampling_routine 'x0_step_down' --resolution_routine 'Incremental_factor_2' --save_folder './mnist_test' --load_path './mnist_final_ckpt/model.pt' --test_type 'test_fid_distance_decrease_from_manifold'
9 |
10 |
11 | python cifar10_train.py --time_steps 3 --resolution_routine 'Incremental_factor_2' --save_folder './cifar10Aug_3_steps_fac2_train'
12 | python cifar10_test.py --time_steps 3 --train_routine 'Final' --sampling_routine 'x0_step_down' --resolution_routine 'Incremental_factor_2' --save_folder './cifar10_test' --load_path './cifar10_final_ckpt/model.pt' --test_type 'test_fid_distance_decrease_from_manifold'
13 |
14 |
15 |
--------------------------------------------------------------------------------
/decolor-diffusion/diffusion/model/get_model.py:
--------------------------------------------------------------------------------
1 | from .unet_convnext import UnetConvNextBlock
2 | from .unet_resnet import UnetResNetBlock
3 |
4 | def get_model(args, with_time_emb=True):
5 | if args.model == 'UnetConvNext':
6 | return UnetConvNextBlock(
7 | dim = 64,
8 | dim_mults = (1, 2, 4, 8),
9 | channels=3,
10 | with_time_emb=with_time_emb,
11 | residual=False,
12 | )
13 | if args.model == 'UnetResNet':
14 | if 'cifar10' in args.dataset:
15 | return UnetResNetBlock(resolution=32,
16 | in_channels=3,
17 | out_ch=3,
18 | ch=128,
19 | ch_mult=(1,2,2,2),
20 | num_res_blocks=2,
21 | attn_resolutions=(16,),
22 | with_time_emb=with_time_emb,
23 | dropout=0.1
24 | )
25 | if 'celebA' in args.dataset:
26 | return UnetResNetBlock(resolution=128,
27 | in_channels=3,
28 | out_ch=3,
29 | ch=128,
30 | ch_mult=(1,2,2,2),
31 | num_res_blocks=2,
32 | attn_resolutions=(16,),
33 | with_time_emb=with_time_emb,
34 | dropout=0.1
35 | )
36 |
37 |
--------------------------------------------------------------------------------
/snowification/diffusion/model/get_model.py:
--------------------------------------------------------------------------------
1 | from .unet_convnext import UnetConvNextBlock
2 | from .unet_resnet import UnetResNetBlock
3 |
4 | def get_model(args, with_time_emb=True):
5 | if args.model == 'UnetConvNext':
6 | return UnetConvNextBlock(
7 | dim = 64,
8 | dim_mults = (1, 2, 4, 8),
9 | channels=3,
10 | with_time_emb=with_time_emb,
11 | residual=False,
12 | )
13 | if args.model == 'UnetResNet':
14 | if 'cifar10' in args.dataset:
15 | return UnetResNetBlock(resolution=32,
16 | in_channels=3,
17 | out_ch=3,
18 | ch=128,
19 | ch_mult=(1,2,2,2),
20 | num_res_blocks=2,
21 | attn_resolutions=(16,),
22 | with_time_emb=with_time_emb,
23 | dropout=0.1
24 | )
25 | if 'celebA' in args.dataset:
26 | return UnetResNetBlock(resolution=128,
27 | in_channels=3,
28 | out_ch=3,
29 | ch=128,
30 | ch_mult=(1,2,2,2),
31 | num_res_blocks=2,
32 | attn_resolutions=(16,),
33 | with_time_emb=with_time_emb,
34 | dropout=0.1
35 | )
36 |
37 |
--------------------------------------------------------------------------------
/resolution-diffusion-pytorch/celebA_128.py:
--------------------------------------------------------------------------------
1 | from resolution_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
2 | import torchvision
3 | import argparse
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--time_steps', default=50, type=int)
7 | parser.add_argument('--train_steps', default=700000, type=int)
8 | parser.add_argument('--save_folder', default='./results_celebA', type=str)
9 | parser.add_argument('--load_path', default=None, type=str)
10 | parser.add_argument('--data_path', default='./root_celebA_128_train_new/', type=str)
11 | parser.add_argument('--resolution_routine', default='Incremental', type=str)
12 | parser.add_argument('--train_routine', default='Final', type=str)
13 | parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
14 | parser.add_argument('--remove_time_embed', action="store_true")
15 | parser.add_argument('--residual', action="store_true")
16 |
17 | args = parser.parse_args()
18 | print(args)
19 |
20 | model = Unet(
21 | dim = 64,
22 | dim_mults = (1, 2, 4, 8),
23 | channels=3,
24 | with_time_emb=not(args.remove_time_embed),
25 | residual=args.residual
26 | ).cuda()
27 |
28 | diffusion = GaussianDiffusion(
29 | model,
30 | image_size = 128,
31 | device_of_kernel = 'cuda',
32 | channels = 3,
33 | timesteps = args.time_steps, # number of steps
34 | loss_type = 'l1', # L1 or L2
35 | resolution_routine=args.resolution_routine,
36 | train_routine = args.train_routine,
37 | sampling_routine = args.sampling_routine
38 | ).cuda()
39 |
40 | import torch
41 | diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))
42 |
43 | trainer = Trainer(
44 | diffusion,
45 | args.data_path,
46 | image_size = 128,
47 | train_batch_size = 32,
48 | train_lr = 2e-5,
49 | train_num_steps = args.train_steps, # total training steps
50 | gradient_accumulate_every = 2, # gradient accumulation steps
51 | ema_decay = 0.995, # exponential moving average decay
52 | fp16 = False, # turn on mixed precision training with apex
53 | results_folder = args.save_folder,
54 | load_path = args.load_path,
55 | dataset = 'celebA'
56 | )
57 |
58 | trainer.train()
59 |
--------------------------------------------------------------------------------
/deblurring-diffusion-pytorch/.gitignore:
--------------------------------------------------------------------------------
1 | # Generation results
2 | results/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
--------------------------------------------------------------------------------
/defading-diffusion-pytorch/.gitignore:
--------------------------------------------------------------------------------
1 | # Generation results
2 | results/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
--------------------------------------------------------------------------------
/demixing-diffusion-pytorch/.gitignore:
--------------------------------------------------------------------------------
1 | # Generation results
2 | results/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
--------------------------------------------------------------------------------
/resolution-diffusion-pytorch/.gitignore:
--------------------------------------------------------------------------------
1 | # Generation results
2 | results/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
--------------------------------------------------------------------------------
/resolution-diffusion-pytorch/celebA.py:
--------------------------------------------------------------------------------
1 | #from comet_ml import Experiment
2 | from resolution_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
3 | import torchvision
4 | import os
5 | import errno
6 | import shutil
7 | import argparse
8 |
def create_folder(path):
    """Create directory `path`; silently succeed if it already exists.

    Any other OSError (e.g. missing parent, permissions) still propagates.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # FileExistsError is exactly OSError with errno == EEXIST,
        # so the old errno check is no longer needed.
        pass
16 |
def del_folder(path):
    """Recursively delete `path`; best-effort, errors are ignored.

    A missing path (or any other OSError during removal) is swallowed
    on purpose -- callers use this as a "clean slate" helper.
    """
    try:
        shutil.rmtree(path)
    except OSError:
        pass
22 |
23 |
import torch  # hoisted: was imported mid-script, between model creation and DataParallel

# ---- command-line configuration -------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='./results_celebA', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--resolution_routine', default='Incremental', type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
# Generalization: the training-data location used to be hard-coded in the
# Trainer call below; the old path stays as the default for compatibility.
parser.add_argument('--data_path', default='/cmlscratch/bansal01/CelebA_train_new/train/', type=str)

args = parser.parse_args()
print(args)

# Denoiser backbone, replicated across every visible GPU.
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    channels=3,
    with_time_emb=not(args.remove_time_embed),
    residual=args.residual
).cuda()
model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

diffusion = GaussianDiffusion(
    model,
    image_size = 64,
    device_of_kernel = 'cuda',
    channels = 3,
    timesteps = args.time_steps,   # number of steps
    loss_type = 'l1',              # L1 or L2
    resolution_routine=args.resolution_routine,
    train_routine = args.train_routine,
    sampling_routine = args.sampling_routine
).cuda()

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size = 64,
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = args.train_steps,         # total training steps
    gradient_accumulate_every = 2,              # gradient accumulation steps
    ema_decay = 0.995,                          # exponential moving average decay
    fp16 = False,                               # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path
)

trainer.train()
80 |
--------------------------------------------------------------------------------
/snowification/diffusion/get_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torchvision import transforms, utils
3 | from torchvision import datasets
4 |
def get_transform(image_size, random_aug=False, resize=False):
    """Build a torchvision pipeline producing tensors scaled to [-1, 1].

    image_size: (H, W) target size. When H == 64 the image is first
        center-cropped to 128x128 and then resized down.
    random_aug: use SimCLR-style random crop / flip / color jitter instead
        of a deterministic center crop.
    resize: (deterministic path only) prepend a Resize to `image_size`.
    """
    to_scaled_tensor = [
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1),  # map [0, 1] -> [-1, 1]
    ]
    if image_size[0] == 64:
        transform_list = [
            transforms.CenterCrop((128, 128)),
            transforms.Resize(image_size),
        ] + to_scaled_tensor
    elif not random_aug:
        transform_list = [transforms.CenterCrop(image_size)] + to_scaled_tensor
        if resize:
            transform_list = [transforms.Resize(image_size)] + transform_list
    else:
        s = 1.0
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        transform_list = [
            transforms.RandomResizedCrop(size=image_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
        ] + to_scaled_tensor

    # BUG FIX: the 64x64 branch previously built `transform_list` but never
    # assigned `T`, so `return T` raised NameError. Compose once for all paths.
    return transforms.Compose(transform_list)
34 |
def get_image_size(name):
    """Return the (H, W) training resolution for a known dataset name.

    Raises ValueError for an unrecognized name instead of silently
    returning None (the previous behavior), so typos fail loudly.
    """
    if 'cifar10' in name:
        return (32, 32)
    if 'celebA' in name:
        return (128, 128)
    if 'flower' in name:
        return (128, 128)
    raise ValueError(f'Unknown dataset name: {name!r}')
42 |
def get_dataset(name, folder, image_size, random_aug=False):
    """Instantiate the torchvision dataset identified by `name`.

    folder: dataset root directory (printed for logging).
    image_size: (H, W) tuple forwarded to get_transform.
    random_aug: forwarded to get_transform.
    Raises ValueError for an unrecognized name instead of silently
    returning None.
    """
    print(folder)
    if name == 'cifar10_train':
        return datasets.CIFAR10(folder, train=True, transform=get_transform(image_size, random_aug=random_aug))
    if name == 'cifar10_test':
        return datasets.CIFAR10(folder, train=False, transform=get_transform(image_size, random_aug=random_aug))
    if name == 'CelebA_train':
        return datasets.CelebA(folder, split='train', transform=get_transform(image_size, random_aug=random_aug), download=True)
    if name == 'CelebA_test':
        return datasets.CelebA(folder, split='test', transform=get_transform(image_size, random_aug=random_aug))
    if name == 'flower_train':
        return datasets.Flowers102(folder, split='train', transform=get_transform(image_size, random_aug=random_aug, resize=True), download=True)
    if name == 'flower_test':
        return datasets.Flowers102(folder, split='test', transform=get_transform(image_size, random_aug=random_aug, resize=True), download=True)
    raise ValueError(f'Unknown dataset name: {name!r}')
57 |
58 |
--------------------------------------------------------------------------------
/decolor-diffusion/diffusion/get_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torchvision import transforms, utils
3 | from torchvision import datasets
4 |
def get_transform(image_size, random_aug=False, resize=False):
    """Build a torchvision pipeline producing tensors scaled to [-1, 1].

    image_size: (H, W) target size. When H == 64 the image is first
        center-cropped to 128x128 and then resized down.
    random_aug: use SimCLR-style random crop / flip / color jitter instead
        of a deterministic center crop.
    resize: (deterministic path only) prepend a Resize to `image_size`.
    """
    to_scaled_tensor = [
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1),  # map [0, 1] -> [-1, 1]
    ]
    if image_size[0] == 64:
        transform_list = [
            transforms.CenterCrop((128, 128)),
            transforms.Resize(image_size),
        ] + to_scaled_tensor
    elif not random_aug:
        transform_list = [transforms.CenterCrop(image_size)] + to_scaled_tensor
        if resize:
            transform_list = [transforms.Resize(image_size)] + transform_list
    else:
        s = 1.0
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        transform_list = [
            transforms.RandomResizedCrop(size=image_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
        ] + to_scaled_tensor

    # BUG FIX: the 64x64 branch previously built `transform_list` but never
    # assigned `T`, so `return T` raised NameError. Compose once for all paths.
    return transforms.Compose(transform_list)
34 |
def get_image_size(name):
    """Return the (H, W) training resolution for a known dataset name.

    Raises ValueError for an unrecognized name instead of silently
    returning None (the previous behavior), so typos fail loudly.
    """
    if 'cifar10' in name:
        return (32, 32)
    if 'celebA' in name:
        return (128, 128)
    if 'flower' in name:
        return (128, 128)
    raise ValueError(f'Unknown dataset name: {name!r}')
42 |
def get_dataset(name, folder, image_size, random_aug=False):
    """Instantiate the torchvision dataset identified by `name`.

    folder: dataset root directory (printed for logging).
    image_size: (H, W) tuple forwarded to get_transform.
    random_aug: forwarded to get_transform.
    Raises ValueError for an unrecognized name instead of silently
    returning None.
    """
    print(folder)
    if name == 'cifar10_train':
        return datasets.CIFAR10(folder, train=True, transform=get_transform(image_size, random_aug=random_aug))
    if name == 'cifar10_test':
        return datasets.CIFAR10(folder, train=False, transform=get_transform(image_size, random_aug=random_aug))
    if name == 'CelebA_train':
        return datasets.CelebA(folder, split='train', transform=get_transform(image_size, random_aug=random_aug), download=True)
    if name == 'CelebA_test':
        return datasets.CelebA(folder, split='test', transform=get_transform(image_size, random_aug=random_aug))
    if name == 'flower_train':
        return datasets.Flowers102(folder, split='train', transform=get_transform(image_size, random_aug=random_aug, resize=True), download=True)
    if name == 'flower_test':
        return datasets.Flowers102(folder, split='test', transform=get_transform(image_size, random_aug=random_aug, resize=True), download=True)
    raise ValueError(f'Unknown dataset name: {name!r}')
57 |
58 |
--------------------------------------------------------------------------------
/demixing-diffusion-pytorch/AFHQ_128_to_celebA_128.py:
--------------------------------------------------------------------------------
1 | #from comet_ml import Experiment
2 | from demixing_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
3 | import torchvision
4 | import os
5 | import errno
6 | import shutil
7 | import argparse
8 |
def create_folder(path):
    """Create directory `path`; silently succeed if it already exists.

    Any other OSError (e.g. missing parent, permissions) still propagates.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # FileExistsError is exactly OSError with errno == EEXIST,
        # so the old errno check is no longer needed.
        pass
16 |
def del_folder(path):
    """Recursively delete `path`; best-effort, errors are ignored.

    A missing path (or any other OSError during removal) is swallowed
    on purpose -- callers use this as a "clean slate" helper.
    """
    try:
        shutil.rmtree(path)
    except OSError:
        pass
22 |
23 |
import torch  # hoisted: was imported mid-script, between diffusion setup and DataParallel

# ---- command-line configuration -------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path_start', default='../deblurring-diffusion-pytorch/AFHQ/afhq/train/', type=str)
parser.add_argument('--data_path_end', default='../deblurring-diffusion-pytorch/root_celebA_128_train_new/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)

args = parser.parse_args()
print(args)

# Denoiser backbone.
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    channels=3,
    with_time_emb=not(args.remove_time_embed),
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    channels = 3,
    timesteps = args.time_steps,    # number of steps
    loss_type = args.loss_type,     # L1 or L2
    train_routine = args.train_routine,
    sampling_routine = args.sampling_routine
).cuda()

# Replicate the whole diffusion module across every visible GPU.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    args.data_path_end,
    args.data_path_start,
    image_size = 128,
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = args.train_steps,         # total training steps
    gradient_accumulate_every = 2,              # gradient accumulation steps
    ema_decay = 0.995,                          # exponential moving average decay
    fp16 = False,                               # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path,
    dataset = 'train'
)

trainer.train()
--------------------------------------------------------------------------------
/defading-diffusion-pytorch/mnist_train.py:
--------------------------------------------------------------------------------
1 | from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
2 | import torchvision
3 | import os
4 | import errno
5 | import shutil
6 | import argparse
7 |
8 |
def create_folder(path):
    """Create directory `path`; silently succeed if it already exists.

    Any other OSError (e.g. missing parent, permissions) still propagates.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # FileExistsError is exactly OSError with errno == EEXIST,
        # so the old errno check is no longer needed.
        pass
16 |
17 |
def del_folder(path):
    """Recursively delete `path`; best-effort, errors are ignored.

    A missing path (or any other OSError during removal) is swallowed
    on purpose -- callers use this as a "clean slate" helper.
    """
    try:
        shutil.rmtree(path)
    except OSError:
        pass
23 |
24 |
# Flip to True once to dump MNIST to disk as one folder of PNGs per digit.
create = False

if create:
    trainset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True)
    root = './root_mnist/'
    del_folder(root)
    create_folder(root)

    # One sub-folder per class label (0..9).
    for i in range(10):
        label_root = root + str(i) + '/'   # fixed typo: was `lable_root`
        create_folder(label_root)

    for idx, (img, label) in enumerate(trainset):
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


# ---- command-line configuration -------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='./_step_50_gaussian', type=str)
parser.add_argument('--kernel_std', default=0.1, type=float)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='./root_mnist/', type=str)
parser.add_argument('--fade_routine', default='Random_Incremental', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--discrete', action="store_true")
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")

args = parser.parse_args()
print(args)

# Single-channel (grayscale) U-Net backbone.
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=1,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=32,
    device_of_kernel='cuda',
    channels=1,
    timesteps=args.time_steps,     # number of diffusion steps
    loss_type='l1',                # L1 or L2
    fade_routine=args.fade_routine,
    kernel_std=args.kernel_std,
    sampling_routine=args.sampling_routine,
    discrete=args.discrete
).cuda()

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size=32,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,   # total training steps
    gradient_accumulate_every=2,        # gradient accumulation steps
    ema_decay=0.995,                    # exponential moving average decay
    fp16=False,                         # mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
)

trainer.train()
--------------------------------------------------------------------------------
/defading-generation-diffusion-pytorch/celebA_128.py:
--------------------------------------------------------------------------------
1 | #from comet_ml import Experiment
2 | from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
3 | import torchvision
4 | import os
5 | import errno
6 | import shutil
7 | import argparse
8 |
def create_folder(path):
    """Create directory `path`; silently succeed if it already exists.

    Any other OSError (e.g. missing parent, permissions) still propagates.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # FileExistsError is exactly OSError with errno == EEXIST,
        # so the old errno check is no longer needed.
        pass
16 |
def del_folder(path):
    """Recursively delete `path`; best-effort, errors are ignored.

    A missing path (or any other OSError during removal) is swallowed
    on purpose -- callers use this as a "clean slate" helper.
    """
    try:
        shutil.rmtree(path)
    except OSError:
        pass
22 |
23 |
import torch  # hoisted: was imported mid-script, between diffusion setup and DataParallel

# ---- command-line configuration -------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path', default='../deblurring-diffusion-pytorch/root_celebA_128_train_new/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)
parser.add_argument('--initial_mask', default=11, type=int)
parser.add_argument('--kernel_std', default=0.15, type=float)
parser.add_argument('--reverse', action="store_true")

args = parser.parse_args()
print(args)

# Denoiser backbone.
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    channels=3,
    with_time_emb=not(args.remove_time_embed),
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    channels = 3,
    timesteps = args.time_steps,    # number of steps
    loss_type = args.loss_type,     # L1 or L2
    train_routine = args.train_routine,
    sampling_routine = args.sampling_routine,
    reverse = args.reverse,
    kernel_std = args.kernel_std,
    initial_mask=args.initial_mask
).cuda()

# Replicate the whole diffusion module across every visible GPU.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size = 128,
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = args.train_steps,         # total training steps
    gradient_accumulate_every = 2,              # gradient accumulation steps
    ema_decay = 0.995,                          # exponential moving average decay
    fp16 = False,                               # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path,
    dataset = 'train'
)

trainer.train()
--------------------------------------------------------------------------------
/defading-diffusion-pytorch/cifar10_train_wd.py:
--------------------------------------------------------------------------------
1 | from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
2 | import torchvision
3 | import os
4 | import errno
5 | import shutil
6 | import argparse
7 |
def create_folder(path):
    """Create directory `path`; silently succeed if it already exists.

    Any other OSError (e.g. missing parent, permissions) still propagates.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # FileExistsError is exactly OSError with errno == EEXIST,
        # so the old errno check is no longer needed.
        pass
15 |
def del_folder(path):
    """Recursively delete `path`; best-effort, errors are ignored.

    A missing path (or any other OSError during removal) is swallowed
    on purpose -- callers use this as a "clean slate" helper.
    """
    try:
        shutil.rmtree(path)
    except OSError:
        pass
21 |
# Flip to True once to dump CIFAR-10 to disk as one folder of PNGs per class.
create = False

if create:
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True)
    root = './root_cifar10/'
    del_folder(root)
    create_folder(root)

    # One sub-folder per class label (0..9).
    for i in range(10):
        label_root = root + str(i) + '/'   # fixed typo: was `lable_root`
        create_folder(label_root)

    for idx, (img, label) in enumerate(trainset):
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


# ---- command-line configuration -------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--fade_routine', default='Incremental', type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")

args = parser.parse_args()
print(args)

# RGB U-Net backbone.
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    channels=3,
    with_time_emb=not(args.remove_time_embed),
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    device_of_kernel = 'cuda',
    channels = 3,
    timesteps = args.time_steps,    # number of steps
    loss_type = 'l1',               # L1 or L2
    fade_routine=args.fade_routine,
    train_routine = args.train_routine,
    sampling_routine = args.sampling_routine
).cuda()

trainer = Trainer(
    diffusion,
    './root_cifar10/',
    image_size = 32,
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = args.train_steps,         # total training steps
    gradient_accumulate_every = 2,              # gradient accumulation steps
    ema_decay = 0.995,                          # exponential moving average decay
    fp16 = False,                               # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path
)

trainer.train()
--------------------------------------------------------------------------------
/defading-diffusion-pytorch/cifar10_train.py:
--------------------------------------------------------------------------------
1 | from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer, Model
2 | import torchvision
3 | import os
4 | import errno
5 | import shutil
6 | import argparse
7 |
8 |
def create_folder(path):
    """Create directory `path`; silently succeed if it already exists.

    Any other OSError (e.g. missing parent, permissions) still propagates.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # FileExistsError is exactly OSError with errno == EEXIST,
        # so the old errno check is no longer needed.
        pass
16 |
17 |
def del_folder(path):
    """Recursively delete `path`; best-effort, errors are ignored.

    A missing path (or any other OSError during removal) is swallowed
    on purpose -- callers use this as a "clean slate" helper.
    """
    try:
        shutil.rmtree(path)
    except OSError:
        pass
23 |
24 |
# Flip to True once to dump CIFAR-10 to disk as one folder of PNGs per class.
create = False

if create:
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True)
    root = './root_cifar10/'
    del_folder(root)
    create_folder(root)

    # One sub-folder per class label (0..9).
    for i in range(10):
        label_root = root + str(i) + '/'   # fixed typo: was `lable_root`
        create_folder(label_root)

    for idx, (img, label) in enumerate(trainset):
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


# ---- command-line configuration -------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default=None, type=str)
parser.add_argument('--kernel_std', default=0.1, type=float)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='./root_cifar10/', type=str)
parser.add_argument('--fade_routine', default='Random_Incremental', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--discrete', action="store_true")
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")

args = parser.parse_args()
print(args)

# DDPM-style backbone (Model) rather than the Unet used by sibling scripts.
model = Model(resolution=32,
              in_channels=3,
              out_ch=3,
              ch=128,
              ch_mult=(1, 2, 2, 2),
              num_res_blocks=2,
              attn_resolutions=(16,),
              dropout=0.1).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=32,
    device_of_kernel='cuda',
    channels=3,
    timesteps=args.time_steps,      # number of diffusion steps
    loss_type='l1',                 # L1 or L2
    kernel_std=args.kernel_std,
    fade_routine=args.fade_routine,
    sampling_routine=args.sampling_routine,
    discrete=args.discrete
).cuda()

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size=32,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,   # total training steps
    gradient_accumulate_every=2,        # gradient accumulation steps
    ema_decay=0.995,                    # exponential moving average decay
    fp16=False,                         # mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='cifar10'
)

trainer.train()
98 |
--------------------------------------------------------------------------------
/resolution-diffusion-pytorch/mnist_train.py:
--------------------------------------------------------------------------------
1 | from resolution_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
2 | import torchvision
3 | import os
4 | import errno
5 | import shutil
6 | import argparse
7 |
def create_folder(path):
    """Create directory `path`; silently succeed if it already exists.

    Any other OSError (e.g. missing parent, permissions) still propagates.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # FileExistsError is exactly OSError with errno == EEXIST,
        # so the old errno check is no longer needed.
        pass
15 |
def del_folder(path):
    """Recursively delete `path`; best-effort, errors are ignored.

    A missing path (or any other OSError during removal) is swallowed
    on purpose -- callers use this as a "clean slate" helper.
    """
    try:
        shutil.rmtree(path)
    except OSError:
        pass
21 |
# Flip to True once to dump MNIST to disk as one folder of PNGs per digit.
create = False

if create:
    trainset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True)
    root = './root_mnist/'
    del_folder(root)
    create_folder(root)

    # One sub-folder per class label (0..9).
    for i in range(10):
        label_root = root + str(i) + '/'   # fixed typo: was `lable_root`
        create_folder(label_root)

    for idx, (img, label) in enumerate(trainset):
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


# ---- command-line configuration -------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='./results_mnist', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='./root_mnist/', type=str)
parser.add_argument('--resolution_routine', default='Incremental', type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")

args = parser.parse_args()
print(args)

# Single-channel (grayscale) U-Net backbone.
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    channels=1,
    with_time_emb=not(args.remove_time_embed),
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    device_of_kernel = 'cuda',
    channels = 1,
    timesteps = args.time_steps,    # number of steps
    loss_type = 'l1',               # L1 or L2
    resolution_routine=args.resolution_routine,
    train_routine = args.train_routine,
    # BUG FIX: --sampling_routine was parsed but never forwarded here,
    # unlike the sibling celebA.py script in this package.
    sampling_routine = args.sampling_routine
).cuda()

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size = 32,
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = args.train_steps,         # total training steps
    gradient_accumulate_every = 2,              # gradient accumulation steps
    ema_decay = 0.995,                          # exponential moving average decay
    fp16 = False,                               # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path
)

trainer.train()
92 |
--------------------------------------------------------------------------------
/denoising-diffusion-pytorch/AFHQ_noise_128.py:
--------------------------------------------------------------------------------
"""Train an unconditional denoising (Gaussian-noise) diffusion model on AFHQ at 128x128."""
#from comet_ml import Experiment
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse
import torch

def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass

def del_folder(path):
    """Best-effort recursive delete of *path*; OS errors are ignored."""
    try:
        shutil.rmtree(path)
    except OSError as exc:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="The number of steps the scheduler takes to go from clean image to an isotropic gaussian. This is also the number of steps of diffusion.")
parser.add_argument('--train_steps', default=700000, type=int,
                    help='The number of iterations for training.')
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path', default='../deblurring-diffusion-pytorch/AFHQ/afhq/train/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str,
                    help='The choice of sampling routine for reversing the diffusion process.')
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)


args = parser.parse_args()
print(args)


# Backbone denoiser; the time embedding can be disabled from the CLI.
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    channels=3,
    timesteps=args.time_steps,            # number of steps
    loss_type=args.loss_type,             # L1 or L2
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine
).cuda()

# Train data-parallel across all visible GPUs.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))


trainer = Trainer(
    diffusion,
    # Bug fix: the data root was hard-coded here, silently ignoring the
    # --data_path flag declared above (its default is the same path, so
    # default runs are unaffected).
    args.data_path,
    image_size=128,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,     # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='train'
)

trainer.train()
--------------------------------------------------------------------------------
/denoising-diffusion-pytorch/celebA_noise_128.py:
--------------------------------------------------------------------------------
"""Train an unconditional denoising (Gaussian-noise) diffusion model on CelebA at 128x128."""
#from comet_ml import Experiment
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse
import torch

def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass

def del_folder(path):
    """Best-effort recursive delete of *path*; OS errors are ignored."""
    try:
        shutil.rmtree(path)
    except OSError as exc:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="The number of steps the scheduler takes to go from clean image to an isotropic gaussian. This is also the number of steps of diffusion.")
parser.add_argument('--train_steps', default=700000, type=int,
                    help='The number of iterations for training.')
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='../deblurring-diffusion-pytorch/root_celebA_128_train_new/', type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str,
                    help='The choice of sampling routine for reversing the diffusion process.')
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)


args = parser.parse_args()
print(args)


# Backbone denoiser; the time embedding can be disabled from the CLI.
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    channels=3,
    timesteps=args.time_steps,            # number of steps
    loss_type=args.loss_type,             # L1 or L2
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine
).cuda()

# Train data-parallel across all visible GPUs.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))


trainer = Trainer(
    diffusion,
    # Bug fix: the data root was hard-coded here, silently ignoring the
    # --data_path flag declared above (its default is the same path, so
    # default runs are unaffected).
    args.data_path,
    image_size=128,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,     # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='train'
)

trainer.train()
--------------------------------------------------------------------------------
/defading-diffusion-pytorch/celebA_train.py:
--------------------------------------------------------------------------------
"""Train a defading cold diffusion model on CelebA."""
from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torch
import os
import errno
import shutil
import argparse


def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass


def del_folder(path):
    """Best-effort recursive delete of *path*.

    Bug fix: the previous version re-raised unless errno == EEXIST, but
    shutil.rmtree on a missing path raises ENOENT, so a fresh run (no
    folder yet) crashed here. Every other script in this repo swallows
    the OSError; do the same.
    """
    try:
        shutil.rmtree(path)
    except OSError:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=100, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='test_speed_celebA', type=str)
parser.add_argument('--kernel_std', default=0.1, type=float)
parser.add_argument('--initial_mask', default=11, type=int)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='/cmlscratch/bansal01/spring_2022/Cold-Diffusion/deblurring-diffusion-pytorch/root_celebA_128_train_new/', type=str)
parser.add_argument('--fade_routine', default="Incremental", type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--discrete', action="store_true")
parser.add_argument('--image_size', default=64, type=int)
parser.add_argument('--dataset', default=None, type=str)
args = parser.parse_args()
print(args)

# Backbone denoiser; the time embedding can be disabled from the CLI.
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=args.image_size,
    device_of_kernel='cuda',
    channels=3,
    timesteps=args.time_steps,
    loss_type='l1',
    fade_routine=args.fade_routine,
    kernel_std=args.kernel_std,
    initial_mask=args.initial_mask,
    sampling_routine=args.sampling_routine,
    discrete=args.discrete
).cuda()

# Train data-parallel across all visible GPUs.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size=args.image_size,
    train_batch_size=100,
    train_lr=2e-5,
    train_num_steps=args.train_steps,
    gradient_accumulate_every=2,
    ema_decay=0.995,
    fp16=False,
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset=args.dataset
)
trainer.train()
88 |
--------------------------------------------------------------------------------
/resolution-diffusion-pytorch/cifar10_train.py:
--------------------------------------------------------------------------------
"""Train a resolution-based cold diffusion model on CIFAR-10 at 32x32."""
#from comet_ml import Experiment
from resolution_diffusion_pytorch import Unet, GaussianDiffusion, Trainer, Model
import torchvision
import os
import errno
import shutil
import argparse


def create_folder(path):
    """Create directory *path*; silently succeed when it already exists."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass


def del_folder(path):
    """Recursively delete *path*, ignoring any OS error."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


# One-off CIFAR-10 dump to disk, disabled by default (flip `create` to
# regenerate ./root_cifar10/ as <label>/<idx>.png files).
create = 0

if create:
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True)
    root = './root_cifar10/'
    del_folder(root)
    create_folder(root)

    # One sub-folder per class label (0-9).
    for i in range(10):
        create_folder(root + str(i) + '/')

    for idx in range(len(trainset)):
        img, label = trainset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')

parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=30, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='./root_cifar10/', type=str)
parser.add_argument('--resolution_routine', default='Incremental', type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")

args = parser.parse_args()
print(args)


# DDPM-style backbone (Model) rather than the Unet used by sibling scripts.
model = Model(resolution=32,
              in_channels=3,
              out_ch=3,
              ch=128,
              ch_mult=(1, 2, 2, 2),
              num_res_blocks=2,
              attn_resolutions=(16,),
              dropout=0.1).cuda()


diffusion = GaussianDiffusion(
    model,
    image_size=32,
    device_of_kernel='cuda',
    channels=3,
    timesteps=args.time_steps,            # number of steps
    loss_type='l1',                       # L1 or L2
    resolution_routine=args.resolution_routine,
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine
).cuda()

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size=32,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,     # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='cifar10'
)

trainer.train()
95 |
--------------------------------------------------------------------------------
/create_data.py:
--------------------------------------------------------------------------------
"""Materialize the MNIST, CIFAR-10 and CelebA-HQ datasets as plain PNG trees.

MNIST/CIFAR-10 are written as <root>/<label>/<idx>.png; CelebA-HQ is split
90/10 into train/test folders of <idx>.png files.
"""
import torchvision
import os
import errno
import shutil
from pathlib import Path
from PIL import Image

def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass

def del_folder(path):
    """Best-effort recursive delete of *path*; OS errors are ignored."""
    try:
        shutil.rmtree(path)
    except OSError as exc:
        pass


def dump_labeled_dataset(dataset, root):
    """Recreate *root* and save every (img, label) pair as root/<label>/<idx>.png.

    Refactor: this loop was copy-pasted four times (MNIST/CIFAR x train/test);
    behavior -- including the per-index progress print -- is unchanged.
    """
    del_folder(root)
    create_folder(root)

    # One sub-folder per class label (0-9).
    for i in range(10):
        label_root = root + str(i) + '/'
        create_folder(label_root)

    for idx in range(len(dataset)):
        img, label = dataset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


CelebA_folder = '/fs/cml-datasets/CelebA-HQ/images-128/' # change this to folder which has CelebA data

############################################# MNIST ###############################################
dump_labeled_dataset(
    torchvision.datasets.MNIST(root='./data', train=True, download=True),
    './root_mnist/')
dump_labeled_dataset(
    torchvision.datasets.MNIST(root='./data', train=False, download=True),
    './root_mnist_test/')

############################################# Cifar10 ###############################################
dump_labeled_dataset(
    torchvision.datasets.CIFAR10(root='./data', train=True, download=True),
    './root_cifar10/')
dump_labeled_dataset(
    torchvision.datasets.CIFAR10(root='./data', train=False, download=True),
    './root_cifar10_test/')

############################################# CelebA ###############################################
root_train = './root_celebA_128_train_new/'
root_test = './root_celebA_128_test_new/'
del_folder(root_train)
create_folder(root_train)

del_folder(root_test)
create_folder(root_test)

exts = ['jpg', 'jpeg', 'png']
folder = CelebA_folder
paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

# First 90% of the (glob-ordered) images go to train, the rest to test.
for idx in range(len(paths)):
    img = Image.open(paths[idx])
    print(idx)
    if idx < 0.9*len(paths):
        img.save(root_train + str(idx) + '.png')
    else:
        img.save(root_test + str(idx) + '.png')
--------------------------------------------------------------------------------
/deblurring-diffusion-pytorch/AFHQ_128.py:
--------------------------------------------------------------------------------
"""Train a deblurring cold diffusion model on AFHQ at 128x128."""
from comet_ml import Experiment
from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse


def create_folder(path):
    """Create directory *path*; silently succeed when it already exists."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass


def del_folder(path):
    """Recursively delete *path*, ignoring any OS error."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="This is the number of steps in which a clean image looses information.")
parser.add_argument('--train_steps', default=700000, type=int,
                    help='The number of iterations for training.')
parser.add_argument('--blur_std', default=0.1, type=float,
                    help='It sets the standard deviation for blur routines which have different meaning based on blur routine.')
parser.add_argument('--blur_size', default=3, type=int,
                    help='It sets the size of gaussian blur used in blur routines for each step t')
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path', default='./AFHQ/afhq/train/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--blur_routine', default='Incremental', type=str,
                    help='This will set the type of blur routine one can use, check the code for what each one of them does in detail')
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str,
                    help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 2')
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)
parser.add_argument('--discrete', action="store_true")

args = parser.parse_args()
print(args)


# Backbone denoiser; the time embedding can be disabled from the CLI.
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    device_of_kernel='cuda',
    channels=3,
    timesteps=args.time_steps,            # number of steps
    loss_type=args.loss_type,             # L1 or L2
    kernel_std=args.blur_std,
    kernel_size=args.blur_size,
    blur_routine=args.blur_routine,
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine,
    discrete=args.discrete
).cuda()

import torch
# Train data-parallel across all visible GPUs.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size=128,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,     # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='AFHQ'
)

trainer.train()
--------------------------------------------------------------------------------
/denoising-diffusion-pytorch/AFHQ_noise_128_test.py:
--------------------------------------------------------------------------------
"""Test/evaluation entry point for the denoising diffusion AFHQ 128x128 model."""
#from comet_ml import Experiment
# Bug fix: this script lives in denoising-diffusion-pytorch, whose package is
# denoising_diffusion_pytorch (see the sibling celebA test script); it
# previously imported the non-existent demixing_noise_diffusion_pytorch.
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse
import torch

def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass

def del_folder(path):
    """Best-effort recursive delete of *path*; OS errors are ignored."""
    try:
        shutil.rmtree(path)
    except OSError as exc:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="The number of steps the scheduler takes to go from clean image to an isotropic gaussian. This is also the number of steps of diffusion.")
parser.add_argument('--train_steps', default=700000, type=int,
                    help='The number of iterations for training.')
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path', default='../deblurring-diffusion-pytorch/AFHQ/afhq/train/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str,
                    help='The choice of sampling routine for reversing the diffusion process.')
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)
parser.add_argument('--test_type', default='train_data', type=str)
# Bug fix: args.sample_steps is read below but was never registered, so every
# 'train_data'/'test_data' run crashed with AttributeError. None keeps the
# sampler's default (full schedule) behavior.
parser.add_argument('--sample_steps', default=None, type=int,
                    help='Number of reverse steps to take at test time (None = full schedule).')


args = parser.parse_args()
print(args)

# NOTE(review): both branches resolve to the same (training) data root --
# confirm whether 'test' was meant to point at a held-out split.
img_path = None
if 'train' in args.test_type:
    img_path = args.data_path
elif 'test' in args.test_type:
    img_path = args.data_path

model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    channels=3,
    timesteps=args.time_steps,            # number of steps
    loss_type=args.loss_type,             # L1 or L2
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine
).cuda()

# Evaluate data-parallel across all visible GPUs.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))


trainer = Trainer(
    diffusion,
    img_path,
    image_size=128,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,     # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='train'
)

if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=args.sample_steps)

#### for FID and noise ablation ##
elif args.test_type == 'test_sample_and_save_for_fid':
    trainer.sample_and_save_for_fid()

########## for paper ##########

elif args.test_type == 'train_paper_showing_diffusion_images_cover_page':
    trainer.paper_showing_diffusion_images_cover_page()
--------------------------------------------------------------------------------
/denoising-diffusion-pytorch/celebA_noise_128_test.py:
--------------------------------------------------------------------------------
"""Test/evaluation entry point for the denoising diffusion CelebA 128x128 model."""
#from comet_ml import Experiment
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse
import torch

def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass

def del_folder(path):
    """Best-effort recursive delete of *path*; OS errors are ignored."""
    try:
        shutil.rmtree(path)
    except OSError as exc:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="The number of steps the scheduler takes to go from clean image to an isotropic gaussian. This is also the number of steps of diffusion.")
parser.add_argument('--train_steps', default=700000, type=int,
                    help='The number of iterations for training.')
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path', default='../deblurring-diffusion-pytorch/root_celebA_128_train_new/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str,
                    help='The choice of sampling routine for reversing the diffusion process.')
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)
parser.add_argument('--test_type', default='train_data', type=str)
# Bug fix: args.sample_steps is read below but was never registered, so every
# 'train_data'/'test_data' run crashed with AttributeError. None keeps the
# sampler's default (full schedule) behavior.
parser.add_argument('--sample_steps', default=None, type=int,
                    help='Number of reverse steps to take at test time (None = full schedule).')


args = parser.parse_args()
print(args)

# NOTE(review): both branches resolve to the same (training) data root --
# confirm whether 'test' was meant to point at a held-out split.
img_path = None
if 'train' in args.test_type:
    img_path = args.data_path
elif 'test' in args.test_type:
    img_path = args.data_path

model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    channels=3,
    timesteps=args.time_steps,            # number of steps
    loss_type=args.loss_type,             # L1 or L2
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine
).cuda()

# Evaluate data-parallel across all visible GPUs.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))


trainer = Trainer(
    diffusion,
    img_path,
    image_size=128,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,     # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='train'
)

if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=args.sample_steps)

#### for FID and noise ablation ##
elif args.test_type == 'test_sample_and_save_for_fid':
    trainer.sample_and_save_for_fid()

########## for paper ##########

elif args.test_type == 'train_paper_showing_diffusion_images_cover_page':
    trainer.paper_showing_diffusion_images_cover_page()
--------------------------------------------------------------------------------
/defading-generation-diffusion-pytorch/celebA_128_test.py:
--------------------------------------------------------------------------------
"""Test/evaluation entry point for the defading-generation CelebA 128x128 model."""
#from comet_ml import Experiment
from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse


def create_folder(path):
    """Create directory *path*; silently succeed when it already exists."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass


def del_folder(path):
    """Recursively delete *path*, ignoring any OS error."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='../deblurring-diffusion-pytorch/root_celebA_128_train_new/', type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)
parser.add_argument('--initial_mask', default=11, type=int)
parser.add_argument('--kernel_std', default=0.15, type=float)
parser.add_argument('--reverse', action="store_true")
parser.add_argument('--test_type', default='train_data', type=str)
parser.add_argument('--noise', default=0, type=float)

args = parser.parse_args()
print(args)

# Both 'train*' and 'test*' test types read images from the same data root;
# any other test type leaves the path unset.
img_path = args.data_path if ('train' in args.test_type or 'test' in args.test_type) else None

model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    channels=3,
    timesteps=args.time_steps,            # number of steps
    loss_type=args.loss_type,             # L1 or L2
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine,
    reverse=args.reverse,
    kernel_std=args.kernel_std,
    initial_mask=args.initial_mask
).cuda()

import torch
# Evaluate data-parallel across all visible GPUs.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))


trainer = Trainer(
    diffusion,
    img_path,
    image_size=128,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,     # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='train'
)

# Dispatch on the requested test type.
if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=None)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=None)

#### for FID and noise ablation ##
elif args.test_type == 'test_sample_and_save_for_fid':
    trainer.sample_and_save_for_fid(args.noise)

########## for paper ##########

elif args.test_type == 'train_paper_showing_diffusion_images_cover_page':
    trainer.paper_showing_diffusion_images_cover_page()
--------------------------------------------------------------------------------
/demixing-diffusion-pytorch/AFHQ_128_to_celebA_128_test.py:
--------------------------------------------------------------------------------
"""Test entry point for demixing diffusion: AFHQ -> CelebA at 128x128."""
#from comet_ml import Experiment
from demixing_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
import argparse


def create_folder(path):
    """Create directory *path*; silently succeed when it already exists."""
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise
        pass


def del_folder(path):
    """Recursively delete *path*, ignoring any OS error."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path_start', default='../deblurring-diffusion-pytorch/AFHQ/afhq/train/', type=str)
parser.add_argument('--data_path_end', default='../deblurring-diffusion-pytorch/root_celebA_128_train_new/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)
parser.add_argument('--test_type', default='train_data', type=str)
parser.add_argument('--noise', default=0, type=float)


args = parser.parse_args()
print(args)

# Both 'train*' and 'test*' test types read from the start (AFHQ) domain;
# any other test type leaves the path unset.
start_root = args.data_path_start if ('train' in args.test_type or 'test' in args.test_type) else None


model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    channels=3,
    timesteps=args.time_steps,            # number of steps
    loss_type=args.loss_type,             # L1 or L2
    train_routine=args.train_routine,
    sampling_routine=args.sampling_routine
).cuda()

import torch
# Evaluate data-parallel across all visible GPUs.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))


trainer = Trainer(
    diffusion,
    args.data_path_end,                   # target (CelebA) domain root
    start_root,                           # source (AFHQ) domain root
    image_size=128,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=args.train_steps,     # total training steps
    gradient_accumulate_every=2,          # gradient accumulation steps
    ema_decay=0.995,                      # exponential moving average decay
    fp16=False,                           # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset='train'
)

# Dispatch on the requested test type.
if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=None)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=None)


#### for FID and noise ablation ##
elif args.test_type == 'test_sample_and_save_for_fid':
    trainer.sample_and_save_for_fid(args.noise)

########## for paper ##########

elif args.test_type == 'train_paper_showing_diffusion_images_cover_page':
    trainer.paper_showing_diffusion_images_cover_page()

elif args.test_type == 'test_paper_showing_diffusion_images_cover_page':
    trainer.paper_showing_diffusion_images_cover_page()
--------------------------------------------------------------------------------
/deblurring-diffusion-pytorch/mnist_train.py:
--------------------------------------------------------------------------------
1 | from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
2 | import torchvision
3 | import os
4 | import errno
5 | import shutil
6 | import argparse
7 |
def create_folder(path):
    """Make a single directory, tolerating the already-exists case."""
    try:
        os.mkdir(path)
    except OSError as err:
        # Only an already-existing path is acceptable; re-raise the rest.
        if err.errno != errno.EEXIST:
            raise
15 |
def del_folder(path):
    """Best-effort recursive removal of *path*; OS errors are swallowed."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass
21 |
# One-shot export of MNIST to an ImageFolder-style directory tree
# (./root_mnist/<label>/<idx>.png); flip `create` to 1 to regenerate.
create = 0

if create:
    trainset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True)
    root = './root_mnist/'
    del_folder(root)
    create_folder(root)

    for i in range(10):
        lable_root = root + str(i) + '/'
        create_folder(lable_root)

    for idx in range(len(trainset)):
        img, label = trainset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="This is the number of steps in which a clean image looses information.")
parser.add_argument('--train_steps', default=700000, type=int,
                    help='The number of iterations for training.')
parser.add_argument('--blur_std', default=0.1, type=float,
                    help='It sets the standard deviation for blur routines which have different meaning based on blur routine.')
parser.add_argument('--blur_size', default=3, type=int,
                    help='It sets the size of gaussian blur used in blur routines for each step t')
parser.add_argument('--save_folder', default='./results_mnist', type=str)
parser.add_argument('--data_path', default='./root_mnist/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--blur_routine', default='Incremental', type=str,
                    help='This will set the type of blur routine one can use, check the code for what each one of them does in detail')
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str,
                    help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 2')
parser.add_argument('--discrete', action="store_true")

args = parser.parse_args()
print(args)


# Single-channel backbone for 32x32 MNIST.
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    channels=1
).cuda()


# Deblurring diffusion: forward process is Gaussian blur parameterised by
# --blur_std / --blur_size / --blur_routine.
diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    device_of_kernel = 'cuda',
    channels = 1,
    timesteps = args.time_steps,   # number of steps
    loss_type = 'l1',    # L1 or L2
    kernel_std=args.blur_std,
    kernel_size=args.blur_size,
    blur_routine=args.blur_routine,
    train_routine = args.train_routine,
    sampling_routine = args.sampling_routine,
    discrete=args.discrete
).cuda()

import torch
# Multi-GPU training over all visible devices.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size = 32,
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = args.train_steps,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    fp16 = False,                     # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path,
    dataset = 'mnist'
)

trainer.train()
--------------------------------------------------------------------------------
/resolution-diffusion-pytorch/celebA_128_test.py:
--------------------------------------------------------------------------------
# Evaluation script for the resolution-diffusion model on celebA 128x128.
from pycave.bayes import GaussianMixture
import torchvision
import argparse
from Fid import calculate_fid_given_samples
import torch
import random

from resolution_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

# Fix every RNG so evaluation runs are reproducible.
seed_value=123457
torch.manual_seed(seed_value) # cpu vars
random.seed(seed_value) # Python
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--sample_steps', default=None, type=int)
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='./root_celebA_128_train_new/', type=str)
parser.add_argument('--resolution_routine', default='Incremental', type=str)
parser.add_argument('--test_type', default='train_data', type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--gmm_size', default=8, type=int)
parser.add_argument('--gmm_cluster', default=10, type=int)
parser.add_argument('--gmm_sample_at', default=1, type=int)
parser.add_argument('--bs', default=32, type=int)
parser.add_argument('--discrete', action="store_true")
parser.add_argument('--noise', default=0, type=float)

args = parser.parse_args()
print(args)


# Both branches use the same --data_path; the prefix of --test_type only
# selects which split label the Trainer routines are told to use.
img_path=None
if 'train' in args.test_type:
    img_path = args.data_path
elif 'test' in args.test_type:
    img_path = args.data_path
print("Img Path is ", img_path)


model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    channels=3,
    with_time_emb=not(args.remove_time_embed),
    residual=args.residual
).cuda()


# Resolution diffusion: forward process progressively downsamples images
# according to --resolution_routine.
diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    device_of_kernel = 'cuda',
    channels = 3,
    timesteps = args.time_steps,   # number of steps
    loss_type = 'l1',    # L1 or L2
    resolution_routine=args.resolution_routine,
    train_routine=args.train_routine,
    sampling_routine = args.sampling_routine
).cuda()

import torch
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

# shuffle=False keeps evaluation batches deterministic (seeds fixed above).
trainer = Trainer(
    diffusion,
    img_path,
    image_size = 128,
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    fp16 = False,                     # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path,
    shuffle=False,
    dataset = 'celebA'
)

# Dispatch on the requested evaluation routine.
if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)

elif args.test_type == 'test_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

elif args.test_type == 'train_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

elif args.test_type == 'train_distribution_mean_blur_torch_gmm_ablation':
    # Fits a GMM (pycave) in the degraded-image space and samples from it.
    trainer.sample_as_a_mean_blur_torch_gmm_ablation(GaussianMixture, siz=args.gmm_size, ch=3, clusters=args.gmm_cluster, sample_at=args.gmm_sample_at, noise=args.noise)
101 |
--------------------------------------------------------------------------------
/defading-diffusion-pytorch/mnist_test.py:
--------------------------------------------------------------------------------
1 | from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
2 | from Fid import calculate_fid_given_samples
3 | import torchvision
4 | import os
5 | import errno
6 | import shutil
7 | import argparse
8 |
9 |
def create_folder(path):
    """Create the directory *path*; an already-existing directory is fine.

    Any other OSError (missing parent, permission denied, ...) propagates.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # Idiomatic replacement for the manual errno.EEXIST comparison.
        pass
17 |
18 |
def del_folder(path):
    """Recursively delete *path*; best-effort, all filesystem errors ignored.

    Matches the old blanket ``except OSError: pass`` (the bound exception
    was never used).
    """
    shutil.rmtree(path, ignore_errors=True)
24 |
25 |
# One-shot export of the MNIST *test* split to ./root_mnist_test/<label>/;
# flip `create` to 1 to regenerate.
create = 0

if create:
    trainset = torchvision.datasets.MNIST(
        root='./data', train=False, download=True)
    root = './root_mnist_test/'
    del_folder(root)
    create_folder(root)

    for i in range(10):
        lable_root = root + str(i) + '/'
        create_folder(lable_root)

    for idx in range(len(trainset)):
        img, label = trainset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--sample_steps', default=None, type=int)
parser.add_argument('--kernel_std', default=0.1, type=float)
parser.add_argument('--save_folder', default='progression_mnist', type=str)
parser.add_argument('--load_path', default='/cmlscratch/eborgnia/cold_diffusion/paper_defading_random_mnist_1/model.pt', type=str)
parser.add_argument('--data_path', default='./root_mnist_test/', type=str)
parser.add_argument('--test_type', default='test_fid_distance_decrease_from_manifold', type=str)
parser.add_argument('--fade_routine', default='Random_Incremental', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--discrete', action="store_true")
parser.add_argument('--residual', action="store_true")
args = parser.parse_args()
print(args)

# Both branches point at the same --data_path; the test_type prefix only
# selects which split label the Trainer routines are told to use.
img_path = None
if 'train' in args.test_type:
    img_path = args.data_path
elif 'test' in args.test_type:
    img_path = args.data_path

print("Img Path is ", img_path)


model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=1
).cuda()


# Defading diffusion: forward process progressively fades the image
# according to --fade_routine.
diffusion = GaussianDiffusion(
    model,
    image_size=32,
    device_of_kernel='cuda',
    channels=1,
    timesteps=args.time_steps,  # number of steps
    loss_type='l1',  # L1 or L2
    kernel_std=args.kernel_std,
    fade_routine=args.fade_routine,
    sampling_routine=args.sampling_routine,
    discrete=args.discrete
).cuda()

# Trainer is used here only as a checkpoint loader / evaluation harness.
trainer = Trainer(
    diffusion,
    img_path,
    image_size=32,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=700000,  # total training steps
    gradient_accumulate_every=2,  # gradient accumulation steps
    ema_decay=0.995,  # exponential moving average decay
    fp16=False,  # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path
)

# Dispatch on the requested evaluation routine.
if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=args.sample_steps)

elif args.test_type == 'mixup_train_data':
    trainer.test_with_mixup('train')

elif args.test_type == 'mixup_test_data':
    trainer.test_with_mixup('test')

elif args.test_type == 'test_random':
    trainer.test_from_random('random')

elif args.test_type == 'test_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

elif args.test_type == 'test_paper_invert_section_images':
    trainer.paper_invert_section_images()

elif args.test_type == 'test_paper_showing_diffusion_images_diff':
    trainer.paper_showing_diffusion_images()
127 |
--------------------------------------------------------------------------------
/demixing-diffusion-pytorch/dispatch.py:
--------------------------------------------------------------------------------
"""dispatch.py

Dispatch python jobs on the CML SLURM cluster: reads a shell script, turns
every `python ...` line into one SLURM array task, and submits the array.
July 2021
"""
import argparse
import getpass
import os
import random
import subprocess
import time
import warnings

parser = argparse.ArgumentParser(description="Dispatch python jobs on the CML cluster")
parser.add_argument("file", type=argparse.FileType())
parser.add_argument("--qos", default="scav", type=str,
                    help="QOS, choose default, medium, high, scav")
parser.add_argument("--name", default=None, type=str,
                    help="Name that will be displayed in squeue. Default: file name")
# NOTE: string defaults are converted by `type`, so default="1" yields int 1.
parser.add_argument("--gpus", default="1", type=int, help="Requested GPUs per job")
parser.add_argument("--mem", default="64", type=int, help="Requested memory per job")
parser.add_argument("--time", default=10, type=int, help="Requested hour limit per job")
args = parser.parse_args()

# get username for use in email and squeue below (do crazy things for arjun)
cmluser_str = getpass.getuser()
username = cmluser_str

# Parse and validate input:
if args.name is None:
    dispatch_name = args.file.name
else:
    dispatch_name = args.name

# Usage warnings:
if args.mem > 385:
    raise ValueError("Maximal node memory exceeded.")
if args.gpus > 8:
    raise ValueError("Maximal node GPU number exceeded.")
if args.qos == "high" and args.gpus > 4:
    warnings.warn("QOS only allows for 4 GPUs, GPU request has been reduced to 4.")
    args.gpus = 4
if args.qos == "medium" and args.gpus > 2:
    warnings.warn("QOS only allows for 2 GPUs, GPU request has been reduced to 2.")
    args.gpus = 2
if args.qos == "default" and args.gpus > 1:
    warnings.warn("QOS only allows for 1 GPU, GPU request has been reduced to 1.")
    args.gpus = 1
if args.mem / args.gpus > 48:
    warnings.warn("You are oversubscribing to memory. "
                  "This might leave some GPUs idle as total node memory is consumed.")

# 1) Stripping file of comments and blank lines
# Keep only lines that invoke python and are not themselves comments;
# trailing inline comments are cut off at the first '#'.
content = args.file.readlines()
jobs = [c.strip().split("#", 1)[0] for c in content if "python" in c and c[0] != "#"]

print(f"Detected {len(jobs)} jobs.")

# Write file list
# Random authkey suffix keeps concurrent dispatches from clobbering each other.
authkey = random.randint(10**5, 10**6 - 1)
with open(f".cml_job_list_{authkey}.temp.sh", "w") as file:
    file.writelines(chr(10).join(job for job in jobs))
    file.write("\n")

# 2) Prepare environment
if not os.path.exists("cmllogs"):
    os.makedirs("cmllogs")

# 3) Construct launch file
# Each array task (1-based, matching `head -n`) runs one line of the job list.
SBATCH_PROTOTYPE = \
    f"""#!/bin/bash
# Lines that begin with #SBATCH specify commands to be used by SLURM for scheduling
#SBATCH --job-name={"".join(e for e in dispatch_name if e.isalnum())}
#SBATCH --array={f"1-{len(jobs)}"}
#SBATCH --output=cmllogs/%x_%A_%a.log
#SBATCH --error=cmllogs/%x_%A_%a.log
#SBATCH --time={args.time}:00:00
#SBATCH --account={"tomg" if args.qos != "scav" else "scavenger"}
#SBATCH --qos={args.qos if args.qos != "scav" else "scavenger"}
#SBATCH --gres=gpu:{args.gpus}
#SBATCH --cpus-per-task=4
#SBATCH --partition={"dpart" if args.qos != "scav" else "scavenger"}
#SBATCH --mem={args.mem}gb
#SBATCH --mail-user={username}@umd.edu
#SBATCH --mail-type=END,TIME_LIMIT,FAIL,ARRAY_TASKS
#SBATCH --exclude=cmlgrad05,cmlgrad02,cml12,cml17,cml18,cml19,cml20,cml21,cml22,cml23,cml24
srun $(head -n $((${{SLURM_ARRAY_TASK_ID}} + 0)) .cml_job_list_{authkey}.temp.sh | tail -n 1)
"""

# Write launch commands to file
with open(f".cml_launch_{authkey}.temp.sh", "w") as file:
    file.write(SBATCH_PROTOTYPE)
print("Launch prototype is ...")
print("---------------")
print(SBATCH_PROTOTYPE)
print("---------------")
print(chr(10).join("srun " + job for job in jobs))
print(f"Preparing {len(jobs)} jobs ")
print("Terminate if necessary ...")
# 10-second grace period for the user to Ctrl-C before submission.
for _ in range(10):
    time.sleep(1)

# Execute file with sbatch
subprocess.run(["/usr/bin/sbatch", f".cml_launch_{authkey}.temp.sh"])
print("Subprocess launched ...")
time.sleep(1)
os.system(f"watch squeue -u {cmluser_str}")
--------------------------------------------------------------------------------
/deblurring-diffusion-pytorch/dispatch.py:
--------------------------------------------------------------------------------
"""dispatch.py

Dispatch python jobs on the CML SLURM cluster: reads a shell script, turns
every `python ...` line into one SLURM array task, and submits the array.
July 2021
"""
import argparse
import getpass
import os
import random
import subprocess
import time
import warnings

parser = argparse.ArgumentParser(description="Dispatch python jobs on the CML cluster")
parser.add_argument("file", type=argparse.FileType())
parser.add_argument("--qos", default="scav", type=str,
                    help="QOS, choose default, medium, high, scav")
parser.add_argument("--name", default=None, type=str,
                    help="Name that will be displayed in squeue. Default: file name")
# NOTE: string defaults are converted by `type`, so default="1" yields int 1.
parser.add_argument("--gpus", default="1", type=int, help="Requested GPUs per job")
parser.add_argument("--mem", default="64", type=int, help="Requested memory per job")
parser.add_argument("--time", default=10, type=int, help="Requested hour limit per job")
args = parser.parse_args()

# get username for use in email and squeue below (do crazy things for arjun)
cmluser_str = getpass.getuser()
username = cmluser_str

# Parse and validate input:
if args.name is None:
    dispatch_name = args.file.name
else:
    dispatch_name = args.name

# Usage warnings:
if args.mem > 385:
    raise ValueError("Maximal node memory exceeded.")
if args.gpus > 8:
    raise ValueError("Maximal node GPU number exceeded.")
if args.qos == "high" and args.gpus > 4:
    warnings.warn("QOS only allows for 4 GPUs, GPU request has been reduced to 4.")
    args.gpus = 4
if args.qos == "medium" and args.gpus > 2:
    warnings.warn("QOS only allows for 2 GPUs, GPU request has been reduced to 2.")
    args.gpus = 2
if args.qos == "default" and args.gpus > 1:
    warnings.warn("QOS only allows for 1 GPU, GPU request has been reduced to 1.")
    args.gpus = 1
if args.mem / args.gpus > 48:
    warnings.warn("You are oversubscribing to memory. "
                  "This might leave some GPUs idle as total node memory is consumed.")

# 1) Stripping file of comments and blank lines
# Keep only lines that invoke python and are not themselves comments;
# trailing inline comments are cut off at the first '#'.
content = args.file.readlines()
jobs = [c.strip().split("#", 1)[0] for c in content if "python" in c and c[0] != "#"]

print(f"Detected {len(jobs)} jobs.")

# Write file list
# Random authkey suffix keeps concurrent dispatches from clobbering each other.
authkey = random.randint(10**5, 10**6 - 1)
with open(f".cml_job_list_{authkey}.temp.sh", "w") as file:
    file.writelines(chr(10).join(job for job in jobs))
    file.write("\n")

# 2) Prepare environment
if not os.path.exists("cmllogs"):
    os.makedirs("cmllogs")

# 3) Construct launch file
# Each array task (1-based, matching `head -n`) runs one line of the job list.
SBATCH_PROTOTYPE = \
    f"""#!/bin/bash
# Lines that begin with #SBATCH specify commands to be used by SLURM for scheduling
#SBATCH --job-name={"".join(e for e in dispatch_name if e.isalnum())}
#SBATCH --array={f"1-{len(jobs)}"}
#SBATCH --output=cmllogs/%x_%A_%a.log
#SBATCH --error=cmllogs/%x_%A_%a.log
#SBATCH --time={args.time}:00:00
#SBATCH --account={"tomg" if args.qos != "scav" else "scavenger"}
#SBATCH --qos={args.qos if args.qos != "scav" else "scavenger"}
#SBATCH --gres=gpu:{args.gpus}
#SBATCH --cpus-per-task=4
#SBATCH --partition={"dpart" if args.qos != "scav" else "scavenger"}
#SBATCH --mem={args.mem}gb
#SBATCH --mail-user={username}@umd.edu
#SBATCH --mail-type=END,TIME_LIMIT,FAIL,ARRAY_TASKS
#SBATCH --exclude=cmlgrad05,cmlgrad02,cml12,cml17,cml18,cml19,cml20,cml21,cml22,cml23,cml24
srun $(head -n $((${{SLURM_ARRAY_TASK_ID}} + 0)) .cml_job_list_{authkey}.temp.sh | tail -n 1)
"""

# Write launch commands to file
with open(f".cml_launch_{authkey}.temp.sh", "w") as file:
    file.write(SBATCH_PROTOTYPE)
print("Launch prototype is ...")
print("---------------")
print(SBATCH_PROTOTYPE)
print("---------------")
print(chr(10).join("srun " + job for job in jobs))
print(f"Preparing {len(jobs)} jobs ")
print("Terminate if necessary ...")
# 10-second grace period for the user to Ctrl-C before submission.
for _ in range(10):
    time.sleep(1)

# Execute file with sbatch
subprocess.run(["/usr/bin/sbatch", f".cml_launch_{authkey}.temp.sh"])
print("Subprocess launched ...")
time.sleep(1)
os.system(f"watch squeue -u {cmluser_str}")
--------------------------------------------------------------------------------
/defading-diffusion-pytorch/celebA_test.py:
--------------------------------------------------------------------------------
1 | from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
2 | from Fid import calculate_fid_given_samples
3 | import torch
4 | import os
5 | import errno
6 | import shutil
7 | import argparse
8 |
9 |
def create_folder(path):
    """Make a single directory, tolerating the already-exists case."""
    try:
        os.mkdir(path)
    except OSError as err:
        # Re-raise anything that is not "directory already exists".
        if err.errno != errno.EEXIST:
            raise
17 |
18 |
def del_folder(path):
    """Best-effort recursive removal of *path*; OS errors are swallowed."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass
24 |
25 |
# Evaluation driver for the defading (inpainting) diffusion model on celebA.
parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=100, type=int)
parser.add_argument('--sample_steps', default=None, type=int)
parser.add_argument('--kernel_std', default=0.2, type=float)
parser.add_argument('--initial_mask', default=1, type=int)
parser.add_argument('--save_folder', default='inpainting_with_algorithm_1', type=str)
parser.add_argument('--load_path', default='centered_mask_celebA_128_100_step_ker_0.2_data_augmentation_float/model.pt', type=str)
parser.add_argument('--data_path', default='/cmlscratch/bansal01/spring_2022/Cold-Diffusion/deblurring-diffusion-pytorch/root_celebA_128_test_new/', type=str)
parser.add_argument('--test_type', default='test_data', type=str)
parser.add_argument('--fade_routine', default='Incremental', type=str)
parser.add_argument('--sampling_routine', default='default', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--discrete', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--image_size', default=128, type=int)
parser.add_argument('--dataset', default=None, type=str)
args = parser.parse_args()
print(args)

# Both branches point at the same --data_path; the test_type prefix only
# selects which split label the Trainer routines are told to use.
img_path=None
if 'train' in args.test_type:
    img_path = args.data_path
elif 'test' in args.test_type:
    img_path = args.data_path

model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    with_time_emb=not args.remove_time_embed,
    residual=args.residual
).cuda()

# model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

# Defading diffusion; --initial_mask configures the starting fade mask.
diffusion = GaussianDiffusion(
    model,
    image_size=args.image_size,
    device_of_kernel='cuda',
    channels=3,
    timesteps=args.time_steps,  # number of steps
    loss_type='l1',
    kernel_std=args.kernel_std,  # L1 or L2
    fade_routine=args.fade_routine,
    sampling_routine=args.sampling_routine,
    discrete=args.discrete,
    initial_mask=args.initial_mask
).cuda()

diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

# Trainer is used here only as a checkpoint loader / evaluation harness.
trainer = Trainer(
    diffusion,
    img_path,
    image_size=args.image_size,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=700000,  # total training steps
    gradient_accumulate_every=2,  # gradient accumulation steps
    ema_decay=0.995,  # exponential moving average decay
    fp16=False,  # turn on mixed precision training with apex
    results_folder=args.save_folder,
    load_path=args.load_path,
    dataset=args.dataset
)

# Dispatch on the requested evaluation routine.
if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=args.sample_steps)

elif args.test_type == 'mixup_train_data':
    trainer.test_with_mixup('train')

elif args.test_type == 'mixup_test_data':
    trainer.test_with_mixup('test')

elif args.test_type == 'test_random':
    trainer.test_from_random('random')

elif args.test_type == 'test_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

elif args.test_type == 'test_paper_invert_section_images':
    trainer.paper_invert_section_images()

elif args.test_type == 'test_paper_showing_diffusion_images_diff':
    trainer.paper_showing_diffusion_images()

elif args.test_type == 'test_save_images':
    trainer.test_from_data_save_results()
118 |
--------------------------------------------------------------------------------
/defading-diffusion-pytorch/cifar10_test.py:
--------------------------------------------------------------------------------
1 | from defading_diffusion_pytorch import Unet, GaussianDiffusion, Trainer, Model
2 | from Fid import calculate_fid_given_samples
3 | import torchvision
4 | import os
5 | import errno
6 | import shutil
7 | import argparse
8 |
9 |
def create_folder(path):
    """Create the directory *path*; an already-existing directory is fine.

    Any other OSError (missing parent, permission denied, ...) propagates.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # Idiomatic replacement for the manual errno.EEXIST comparison.
        pass
17 |
18 |
def del_folder(path):
    """Recursively delete *path*; best-effort, all filesystem errors ignored.

    Matches the old blanket ``except OSError: pass`` (the bound exception
    was never used).
    """
    shutil.rmtree(path, ignore_errors=True)
24 |
25 |
# One-shot export of the CIFAR10 *test* split to ./root_cifar10_test/<label>/;
# flip `create` to 1 to regenerate.
create = 0

if create:
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True)
    root = './root_cifar10_test/'
    del_folder(root)
    create_folder(root)

    for i in range(10):
        lable_root = root + str(i) + '/'
        create_folder(lable_root)

    for idx in range(len(trainset)):
        img, label = trainset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--sample_steps', default=None, type=int)
parser.add_argument('--kernel_std', default=0.1, type=float)
parser.add_argument('--save_folder', default='progression_cifar', type=str)
parser.add_argument('--load_path', default='/cmlscratch/eborgnia/cold_diffusion/paper_defading_random_1/model.pt', type=str)
parser.add_argument('--data_path', default='./root_cifar10_test/', type=str)
parser.add_argument('--test_type', default='test_paper_showing_diffusion_images_diff', type=str)
parser.add_argument('--fade_routine', default='Random_Incremental', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--discrete', action="store_true")
parser.add_argument('--residual', action="store_true")

args = parser.parse_args()
print(args)

# Both branches point at the same --data_path; the test_type prefix only
# selects which split label the Trainer routines are told to use.
img_path=None
if 'train' in args.test_type:
    img_path = args.data_path
elif 'test' in args.test_type:
    img_path = args.data_path

print("Img Path is ", img_path)

# This script uses the DDPM-style `Model` backbone (not the Unet used by
# the other test scripts in this package).
model = Model(resolution=32,
              in_channels=3,
              out_ch=3,
              ch=128,
              ch_mult=(1, 2, 2, 2),
              num_res_blocks=2,
              attn_resolutions=(16,),
              dropout=0.1).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    device_of_kernel = 'cuda',
    channels = 3,
    timesteps = args.time_steps,   # number of steps
    loss_type = 'l1',    # L1 or L2
    kernel_std=args.kernel_std,
    fade_routine=args.fade_routine,
    sampling_routine = args.sampling_routine,
    discrete=args.discrete
).cuda()

# Trainer is used here only as a checkpoint loader / evaluation harness.
trainer = Trainer(
    diffusion,
    img_path,
    image_size = 32,
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    fp16 = False,                     # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path
)

# Dispatch on the requested evaluation routine.
if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=args.sample_steps)

elif args.test_type == 'mixup_train_data':
    trainer.test_with_mixup('train')

elif args.test_type == 'mixup_test_data':
    trainer.test_with_mixup('test')

elif args.test_type == 'test_random':
    trainer.test_from_random('random')

elif args.test_type == 'test_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

elif args.test_type == 'test_paper_invert_section_images':
    trainer.paper_invert_section_images()

elif args.test_type == 'test_paper_showing_diffusion_images_diff':
    trainer.paper_showing_diffusion_images()
129 |
--------------------------------------------------------------------------------
/deblurring-diffusion-pytorch/cifar10_train.py:
--------------------------------------------------------------------------------
1 | from comet_ml import Experiment
2 | from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer, Model
3 | import torchvision
4 | import os
5 | import errno
6 | import shutil
7 | import argparse
8 |
def create_folder(path):
    """Make a single directory, tolerating the already-exists case."""
    try:
        os.mkdir(path)
    except OSError as err:
        # Re-raise anything that is not "directory already exists".
        if err.errno != errno.EEXIST:
            raise
16 |
def del_folder(path):
    """Best-effort recursive removal of *path*; OS errors are swallowed."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass
22 |
# One-off switch: set to 1 once to dump the CIFAR-10 train split to disk as
# ./root_cifar10/<label>/<idx>.png (the layout --data_path expects);
# afterwards leave 0 so training reuses the dump.
create = 0

if create:
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True)
    root = './root_cifar10/'
    del_folder(root)
    create_folder(root)

    for i in range(10):
        # NOTE(review): 'lable' is a typo for 'label' (harmless local name).
        lable_root = root + str(i) + '/'
        create_folder(lable_root)

    for idx in range(len(trainset)):
        img, label = trainset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


# Command-line configuration for deblurring-diffusion training on CIFAR-10.
parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="This is the number of steps in which a clean image looses information.")
parser.add_argument('--train_steps', default=700000, type=int,
                    help='The number of iterations for training.')
parser.add_argument('--blur_std', default=0.1, type=float,
                    help='It sets the standard deviation for blur routines which have different meaning based on blur routine.')
parser.add_argument('--blur_size', default=3, type=int,
                    help='It sets the size of gaussian blur used in blur routines for each step t')
parser.add_argument('--image_size', default=32, type=int)
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path', default='./root_cifar10/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--blur_routine', default='Incremental', type=str,
                    help='This will set the type of blur routine one can use, check the code for what each one of them does in detail')
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str,
                    help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 2')
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--discrete', action="store_true")
parser.add_argument('--da', default='cifar10', type=str)


args = parser.parse_args()
print(args)


# DDPM-style backbone (Model) rather than the Unet used by other scripts.
model = Model(resolution=args.image_size,
              in_channels=3,
              out_ch=3,
              ch=128,
              ch_mult=(1,2,2,2),
              num_res_blocks=2,
              attn_resolutions=(16,),
              dropout=0.1).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size = args.image_size,
    device_of_kernel = 'cuda',
    channels = 3,
    timesteps = args.time_steps, # number of steps
    loss_type = 'l1', # L1 or L2
    kernel_std=args.blur_std,
    kernel_size=args.blur_size,
    blur_routine=args.blur_routine,
    train_routine = args.train_routine,
    sampling_routine = args.sampling_routine,
    discrete=args.discrete
).cuda()

import torch
# Data-parallel over all visible GPUs, wrapped BEFORE the Trainer captures it.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))


trainer = Trainer(
    diffusion,
    args.data_path,
    image_size = args.image_size,
    train_batch_size = args.batch_size,
    train_lr = 2e-5,
    train_num_steps = args.train_steps, # total training steps
    gradient_accumulate_every = 2, # gradient accumulation steps
    ema_decay = 0.995, # exponential moving average decay
    fp16 = False, # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path,
    dataset = args.da
)

trainer.train()
--------------------------------------------------------------------------------
/deblurring-diffusion-pytorch/celebA_128.py:
--------------------------------------------------------------------------------
1 | #from comet_ml import Experiment
2 | from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
3 | import torchvision
4 | import os
5 | import errno
6 | import shutil
7 | import argparse
8 |
def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        # Swallow only EEXIST; other failures (permissions, missing parent,
        # ...) are re-raised.  (Dead `pass` after `raise` removed.)
        if exc.errno != errno.EEXIST:
            raise
16 |
def del_folder(path):
    """Best-effort removal of the directory tree at *path*.

    OS-level failures (e.g. the path does not exist) are deliberately
    ignored so callers can reset output folders unconditionally.
    """
    try:
        shutil.rmtree(path)
    except OSError:
        # Unused `as exc` binding removed; the suppression is intentional.
        pass
22 |
# One-off switch: set to 1 once to split the 128px CelebA-HQ images into
# train/test dumps on disk; afterwards leave 0 so training reuses the dump.
create = 0
from pathlib import Path
from PIL import Image

if create:
    root_train = './root_celebA_128_train_new/'
    root_test = './root_celebA_128_test_new/'
    del_folder(root_train)
    create_folder(root_train)

    del_folder(root_test)
    create_folder(root_test)

    exts = ['jpg', 'jpeg', 'png']
    folder = '/fs/cml-datasets/CelebA-HQ/images-128/'
    paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

    # 90/10 train/test split by file index.
    for idx in range(len(paths)):
        img = Image.open(paths[idx])
        print(idx)
        if idx < 0.9*len(paths):
            img.save(root_train + str(idx) + '.png')
        else:
            img.save(root_test + str(idx) + '.png')




# Command-line configuration for deblurring-diffusion training on CelebA-HQ 128.
parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="This is the number of steps in which a clean image looses information.")
parser.add_argument('--train_steps', default=700000, type=int,
                    help='The number of iterations for training.')
parser.add_argument('--blur_std', default=0.1, type=float,
                    help='It sets the standard deviation for blur routines which have different meaning based on blur routine.')
parser.add_argument('--blur_size', default=3, type=int,
                    help='It sets the size of gaussian blur used in blur routines for each step t')
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path', default='./root_celebA_128_train_new/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--blur_routine', default='Incremental', type=str,
                    help='This will set the type of blur routine one can use, check the code for what each one of them does in detail')
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='default', type=str,
                    help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 2')
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")
parser.add_argument('--loss_type', default='l1', type=str)
parser.add_argument('--discrete', action="store_true")


args = parser.parse_args()
print(args)


# U-Net denoiser; time embedding and residual prediction are CLI-switchable.
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    channels=3,
    with_time_emb=not(args.remove_time_embed),
    residual=args.residual
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    device_of_kernel = 'cuda',
    channels = 3,
    timesteps = args.time_steps, # number of steps
    loss_type = args.loss_type, # L1 or L2
    kernel_std=args.blur_std,
    kernel_size=args.blur_size,
    blur_routine=args.blur_routine,
    train_routine = args.train_routine,
    sampling_routine = args.sampling_routine,
    discrete=args.discrete
).cuda()

import torch
# Data-parallel over all visible GPUs, wrapped BEFORE the Trainer captures it.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    args.data_path,
    image_size = 128,
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = args.train_steps, # total training steps
    gradient_accumulate_every = 2, # gradient accumulation steps
    ema_decay = 0.995, # exponential moving average decay
    fp16 = False, # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path,
    dataset = 'celebA'
)

trainer.train()
--------------------------------------------------------------------------------
/snowification/train.py:
--------------------------------------------------------------------------------
1 | from diffusion import GaussianDiffusion, Trainer, get_dataset
2 | import torchvision
3 | import os
4 | import errno
5 | import shutil
6 | import argparse
7 | from diffusion.model.get_model import get_model
8 |
# Command-line configuration for cold-diffusion training with a snow (or
# decolor) forward process.
parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='/cmlscratch/hmchu/cold_diff/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--dataset_folder', default='./root_cifar10', type=str)
parser.add_argument('--random_aug', action='store_true')
parser.add_argument('--output_mean_scale', action='store_true')
parser.add_argument('--exp_name', default='', type=str)
parser.add_argument('--dataset', default='cifar10')
parser.add_argument('--model', default='UnetConvNext', type=str)

parser.add_argument('--forward_process_type', default='Snow')

# Decolor args
parser.add_argument('--decolor_routine', default='Constant')
parser.add_argument('--decolor_ema_factor', default=0.9, type=float)
parser.add_argument('--decolor_total_remove', action='store_true')
parser.add_argument('--to_lab', action='store_true')

parser.add_argument('--loss_type', type=str, default='l1')
parser.add_argument('--resume_training', action='store_true')
parser.add_argument('--load_model_steps', default=-1, type=int)

# Snow arg
parser.add_argument('--snow_level', default=1, type=int)
parser.add_argument('--random_snow', action='store_true')
parser.add_argument('--single_snow', action='store_true')
parser.add_argument('--fix_brightness', action='store_true')

parser.add_argument('--resolution', default=-1, type=int)

args = parser.parse_args()

# --exp_name is mandatory: all outputs go under <save_folder>/<exp_name>.
# NOTE(review): assert is stripped under `python -O`; an explicit raise would
# be more robust for CLI validation.
assert len(args.exp_name) > 0
args.save_folder = os.path.join(args.save_folder, args.exp_name)
print(args)

# Resume handling: -1 loads the rolling 'model.pt', otherwise a specific step.
if args.resume_training:
    if args.load_model_steps == -1:
        args.load_path = os.path.join(args.save_folder, 'model.pt')
    else:
        args.load_path = os.path.join(args.save_folder, f'model_{args.load_model_steps}.pt')
    print(f'resume from checkpoint stored at {args.load_path}')

with_time_emb = not args.remove_time_embed

# Two backbones: the main denoiser (optionally time-conditioned) and a
# time-unconditional copy passed to the diffusion as one_shot_denoise_fn.
model = get_model(args, with_time_emb=with_time_emb).cuda()
model_one_shot = get_model(args, with_time_emb=False).cuda()



# image_size is indexed below and overridden with a (resolution, resolution)
# tuple — presumably an (H, W)-style pair from get_dataset; confirm there.
image_size = get_dataset.get_image_size(args.dataset)
if args.resolution != -1:
    image_size = (args.resolution, args.resolution)

# NOTE(review): 'torchvison' is a typo for 'torchvision' (local name only).
use_torchvison_dataset = False
if 'cifar10' in args.dataset:
    use_torchvison_dataset = True
    args.dataset = 'cifar10_train'

# Batch-size heuristic: halve the batch for images larger than 64px.
if image_size[0] <= 64:
    train_batch_size = 32
elif image_size[0] > 64:
    train_batch_size = 16


diffusion = GaussianDiffusion(
    model,
    image_size = image_size,
    device_of_kernel = 'cuda',
    channels = 3,
    one_shot_denoise_fn=model_one_shot,
    timesteps = args.time_steps, # number of steps
    loss_type = args.loss_type, # L1 or L2
    train_routine = args.train_routine,
    sampling_routine = args.sampling_routine,
    forward_process_type=args.forward_process_type,
    decolor_routine=args.decolor_routine,
    decolor_ema_factor=args.decolor_ema_factor,
    decolor_total_remove=args.decolor_total_remove,
    snow_level=args.snow_level,
    single_snow=args.single_snow,
    batch_size=train_batch_size,
    random_snow=args.random_snow,
    to_lab=args.to_lab,
    load_path=args.load_path,
    results_folder=args.save_folder,
    fix_brightness=args.fix_brightness,
).cuda()

trainer = Trainer(
    diffusion,
    args.dataset_folder,
    image_size = image_size,
    train_batch_size = train_batch_size,
    train_lr = 2e-5,
    train_num_steps = args.train_steps, # total training steps
    gradient_accumulate_every = 2, # gradient accumulation steps
    ema_decay = 0.995, # exponential moving average decay
    fp16 = False, # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path,
    random_aug=args.random_aug,
    torchvision_dataset=use_torchvison_dataset,
    dataset = f'{args.dataset}',
    to_lab=args.to_lab,
)

trainer.train()
122 |
--------------------------------------------------------------------------------
/decolor-diffusion/train.py:
--------------------------------------------------------------------------------
1 | from diffusion import GaussianDiffusion, Trainer, get_dataset
2 | import torchvision
3 | import os
4 | import errno
5 | import shutil
6 | import argparse
7 | from diffusion.model.get_model import get_model
8 |
# Command-line configuration for cold-diffusion training with a decolorization
# (or snow) forward process.
parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--train_steps', default=700000, type=int)
parser.add_argument('--save_folder', default='/cmlscratch/hmchu/cold_diff/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--dataset_folder', default='./root_cifar10', type=str)
parser.add_argument('--random_aug', action='store_true')
parser.add_argument('--output_mean_scale', action='store_true')
parser.add_argument('--exp_name', default='', type=str)
parser.add_argument('--dataset', default='cifar10')
parser.add_argument('--model', default='UnetConvNext', type=str)

parser.add_argument('--forward_process_type', default='Snow')

# Decolor args
parser.add_argument('--decolor_routine', default='Constant')
parser.add_argument('--decolor_ema_factor', default=0.9, type=float)
parser.add_argument('--decolor_total_remove', action='store_true')
parser.add_argument('--to_lab', action='store_true')

parser.add_argument('--loss_type', type=str, default='l1')
parser.add_argument('--resume_training', action='store_true')
parser.add_argument('--load_model_steps', default=-1, type=int)

# Snow arg
parser.add_argument('--snow_level', default=1, type=int)
parser.add_argument('--random_snow', action='store_true')
parser.add_argument('--single_snow', action='store_true')
parser.add_argument('--fix_brightness', action='store_true')

parser.add_argument('--resolution', default=-1, type=int)

args = parser.parse_args()

# --exp_name is mandatory: all outputs go under <save_folder>/<exp_name>.
# NOTE(review): assert is stripped under `python -O`; an explicit raise would
# be more robust for CLI validation.
assert len(args.exp_name) > 0
args.save_folder = os.path.join(args.save_folder, args.exp_name)
print(args)

# Resume handling: -1 loads the rolling 'model.pt', otherwise a specific step.
if args.resume_training:
    if args.load_model_steps == -1:
        args.load_path = os.path.join(args.save_folder, 'model.pt')
    else:
        args.load_path = os.path.join(args.save_folder, f'model_{args.load_model_steps}.pt')
    print(f'resume from checkpoint stored at {args.load_path}')

with_time_emb = not args.remove_time_embed

# Two backbones: the main denoiser (optionally time-conditioned) and a
# time-unconditional copy passed to the diffusion as one_shot_denoise_fn.
model = get_model(args, with_time_emb=with_time_emb).cuda()
model_one_shot = get_model(args, with_time_emb=False).cuda()



# image_size is indexed below and overridden with a (resolution, resolution)
# tuple — presumably an (H, W)-style pair from get_dataset; confirm there.
image_size = get_dataset.get_image_size(args.dataset)
if args.resolution != -1:
    image_size = (args.resolution, args.resolution)

# NOTE(review): 'torchvison' is a typo for 'torchvision' (local name only).
use_torchvison_dataset = False
if 'cifar10' in args.dataset:
    use_torchvison_dataset = True
    args.dataset = 'cifar10_train'

# Batch-size heuristic: halve the batch for images larger than 64px.
if image_size[0] <= 64:
    train_batch_size = 32
elif image_size[0] > 64:
    train_batch_size = 16


diffusion = GaussianDiffusion(
    model,
    image_size = image_size,
    device_of_kernel = 'cuda',
    channels = 3,
    one_shot_denoise_fn=model_one_shot,
    timesteps = args.time_steps, # number of steps
    loss_type = args.loss_type, # L1 or L2
    train_routine = args.train_routine,
    sampling_routine = args.sampling_routine,
    forward_process_type=args.forward_process_type,
    decolor_routine=args.decolor_routine,
    decolor_ema_factor=args.decolor_ema_factor,
    decolor_total_remove=args.decolor_total_remove,
    snow_level=args.snow_level,
    single_snow=args.single_snow,
    batch_size=train_batch_size,
    random_snow=args.random_snow,
    to_lab=args.to_lab,
    load_path=args.load_path,
    results_folder=args.save_folder,
    fix_brightness=args.fix_brightness,
).cuda()

trainer = Trainer(
    diffusion,
    args.dataset_folder,
    image_size = image_size,
    train_batch_size = train_batch_size,
    train_lr = 2e-5,
    train_num_steps = args.train_steps, # total training steps
    gradient_accumulate_every = 2, # gradient accumulation steps
    ema_decay = 0.995, # exponential moving average decay
    fp16 = False, # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path,
    random_aug=args.random_aug,
    torchvision_dataset=use_torchvison_dataset,
    dataset = f'{args.dataset}',
    to_lab=args.to_lab,
)

trainer.train()
122 |
--------------------------------------------------------------------------------
/deblurring-diffusion-pytorch/mnist_test.py:
--------------------------------------------------------------------------------
1 | from pycave.bayes import GaussianMixture
2 | from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
3 | import torchvision
4 | import os
5 | import errno
6 | import shutil
7 | import argparse
8 | from Fid import calculate_fid_given_samples
9 |
def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        # Swallow only EEXIST; other failures (permissions, missing parent,
        # ...) are re-raised.  (Dead `pass` after `raise` removed.)
        if exc.errno != errno.EEXIST:
            raise
17 |
def del_folder(path):
    """Best-effort removal of the directory tree at *path*.

    OS-level failures (e.g. the path does not exist) are deliberately
    ignored so callers can reset output folders unconditionally.
    """
    try:
        shutil.rmtree(path)
    except OSError:
        # Unused `as exc` binding removed; the suppression is intentional.
        pass
23 |
# One-off switch: set to 1 once to dump the MNIST *test* split to disk as
# ./root_mnist_test/<label>/<idx>.png; afterwards leave 0.
create = 0

if create:
    trainset = torchvision.datasets.MNIST(
        root='./data', train=False, download=True)
    root = './root_mnist_test/'
    del_folder(root)
    create_folder(root)

    for i in range(10):
        # NOTE(review): 'lable' is a typo for 'label' (harmless local name).
        lable_root = root + str(i) + '/'
        create_folder(lable_root)

    for idx in range(len(trainset)):
        img, label = trainset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


# Command-line configuration for deblurring-diffusion evaluation on MNIST.
parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="This is the number of steps in which a clean image looses information.")
parser.add_argument('--sample_steps', default=None, type=int)
parser.add_argument('--blur_std', default=0.1, type=float,
                    help='It sets the standard deviation for blur routines which have different meaning based on blur routine.')
parser.add_argument('--blur_size', default=3, type=int,
                    help='It sets the size of gaussian blur used in blur routines for each step t')
parser.add_argument('--save_folder', default='./results_mnist', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='./root_mnist/', type=str)
parser.add_argument('--test_type', default='train_data', type=str)
parser.add_argument('--blur_routine', default='Incremental', type=str,
                    help='This will set the type of blur routine one can use, check the code for what each one of them does in detail')
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str,
                    help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 2')
parser.add_argument('--gmm_size', default=8, type=int)
parser.add_argument('--gmm_cluster', default=10, type=int)

args = parser.parse_args()
print(args)

# NOTE(review): both branches assign the same value, so img_path is simply
# args.data_path for any test_type containing 'train' or 'test'.
img_path=None
if 'train' in args.test_type:
    img_path = args.data_path
elif 'test' in args.test_type:
    img_path = args.data_path

print("Img Path is ", img_path)


# Single-channel U-Net for MNIST.
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    channels=1
).cuda()


diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    device_of_kernel = 'cuda',
    channels = 1,
    timesteps = args.time_steps, # number of steps
    loss_type = 'l1', # L1 or L2
    kernel_std=args.blur_std,
    kernel_size=args.blur_size,
    blur_routine=args.blur_routine,
    train_routine=args.train_routine,
    sampling_routine = args.sampling_routine
).cuda()

import torch
# Data-parallel over all visible GPUs, wrapped BEFORE the Trainer captures it.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    img_path,
    image_size = 32,
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = 700000, # total training steps
    gradient_accumulate_every = 2, # gradient accumulation steps
    ema_decay = 0.995, # exponential moving average decay
    fp16 = False, # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path
)

# Dispatch on --test_type.
if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=args.sample_steps)

# NOTE(review): the next two branches invoke the same method; presumably the
# 'test' variant was meant to save the *test* data — confirm against Trainer.
elif args.test_type == 'train_save_orig_data_same_as_trained':
    trainer.save_training_data()

elif args.test_type == 'test_save_orig_data_same_as_tested':
    trainer.save_training_data()

elif args.test_type == 'test_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

elif args.test_type == 'train_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)
130 |
--------------------------------------------------------------------------------
/deblurring-diffusion-pytorch/cifar10_test.py:
--------------------------------------------------------------------------------
1 | from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer, Model
2 | import torchvision
3 | import os
4 | import errno
5 | import shutil
6 | import argparse
7 | import torch
8 | import random
9 | from Fid import calculate_fid_given_samples
10 |
# Make evaluation reproducible: seed every RNG the run touches and pin cuDNN
# to its deterministic (non-benchmarking) kernels.
seed_value=123457
torch.manual_seed(seed_value) # cpu vars
random.seed(seed_value) # Python
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
18 |
def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        # Swallow only EEXIST; other failures (permissions, missing parent,
        # ...) are re-raised.  (Dead `pass` after `raise` removed.)
        if exc.errno != errno.EEXIST:
            raise
26 |
def del_folder(path):
    """Best-effort removal of the directory tree at *path*.

    OS-level failures (e.g. the path does not exist) are deliberately
    ignored so callers can reset output folders unconditionally.
    """
    try:
        shutil.rmtree(path)
    except OSError:
        # Unused `as exc` binding removed; the suppression is intentional.
        pass
32 |
# One-off switch: set to 1 once to dump the CIFAR-10 *test* split to disk as
# ./root_cifar10_test/<label>/<idx>.png; afterwards leave 0.
create = 0

if create:
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True)
    root = './root_cifar10_test/'
    del_folder(root)
    create_folder(root)

    for i in range(10):
        # NOTE(review): 'lable' is a typo for 'label' (harmless local name).
        lable_root = root + str(i) + '/'
        create_folder(lable_root)

    for idx in range(len(trainset)):
        img, label = trainset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


# Command-line configuration for deblurring-diffusion evaluation on CIFAR-10.
parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int,
                    help="This is the number of steps in which a clean image looses information.")
parser.add_argument('--sample_steps', default=None, type=int)
parser.add_argument('--blur_std', default=0.1, type=float,
                    help='It sets the standard deviation for blur routines which have different meaning based on blur routine.')
parser.add_argument('--blur_size', default=3, type=int,
                    help='It sets the size of gaussian blur used in blur routines for each step t')
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--data_path', default='./root_cifar10/', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--test_type', default='train_data', type=str)
parser.add_argument('--blur_routine', default='Incremental', type=str,
                    help='This will set the type of blur routine one can use, check the code for what each one of them does in detail')
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str,
                    help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 2')
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")


args = parser.parse_args()
print(args)

# NOTE(review): every branch assigns the same value, so img_path is simply
# args.data_path for any recognised test_type.
img_path=None
if args.test_type == 'train_distribution_cov_vector':
    img_path = args.data_path
elif 'train' in args.test_type:
    img_path = args.data_path
elif 'test' in args.test_type:
    img_path = args.data_path

print("Img Path is ", img_path)


# DDPM-style backbone (Model) at 32px.
model = Model(resolution=32,
              in_channels=3,
              out_ch=3,
              ch=128,
              ch_mult=(1,2,2,2),
              num_res_blocks=2,
              attn_resolutions=(16,),
              dropout=0.01).cuda()


diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    device_of_kernel = 'cuda',
    channels = 3,
    timesteps = args.time_steps, # number of steps
    loss_type = 'l1', # L1 or L2
    kernel_std=args.blur_std,
    kernel_size=args.blur_size,
    blur_routine=args.blur_routine,
    train_routine=args.train_routine,
    sampling_routine = args.sampling_routine
).cuda()
110 |
# Wrap in DataParallel BEFORE constructing the Trainer.  In the original the
# wrap happened only after the Trainer had already captured `diffusion`, so
# the trainer kept the unwrapped module and the DataParallel copy was never
# used; the repo's train scripts (e.g. cifar10_train.py, mnist_test.py) wrap
# first, which is reproduced here for consistency.  `torch` is already
# imported at the top of this file.
diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))

trainer = Trainer(
    diffusion,
    img_path,
    image_size = 32,
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = 700000, # total training steps
    gradient_accumulate_every = 2, # gradient accumulation steps
    ema_decay = 0.995, # exponential moving average decay
    fp16 = False, # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path,
    shuffle=True
)
128 |
129 |
# Dispatch on --test_type; each branch runs one evaluation routine.
if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=args.sample_steps)

elif args.test_type == 'test_data_save_results':
    trainer.test_from_data_save_results('test')

elif args.test_type == 'test_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

elif args.test_type == 'train_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)
144 |
--------------------------------------------------------------------------------
/resolution-diffusion-pytorch/cifar10_test.py:
--------------------------------------------------------------------------------
1 | from resolution_diffusion_pytorch import Unet, GaussianDiffusion, Trainer, Model
2 | import torchvision
3 | import os
4 | import errno
5 | import shutil
6 | import argparse
7 | import torch
8 | import random
9 | from Fid import calculate_fid_given_samples
10 |
# Make evaluation reproducible: seed every RNG the run touches and pin cuDNN
# to its deterministic (non-benchmarking) kernels.
seed_value=123457
torch.manual_seed(seed_value) # cpu vars
random.seed(seed_value) # Python
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
18 |
def create_folder(path):
    """Create directory *path*; an already-existing directory is not an error."""
    try:
        os.mkdir(path)
    except OSError as exc:
        # Swallow only EEXIST; other failures (permissions, missing parent,
        # ...) are re-raised.  (Dead `pass` after `raise` removed.)
        if exc.errno != errno.EEXIST:
            raise
26 |
def del_folder(path):
    """Best-effort removal of the directory tree at *path*.

    OS-level failures (e.g. the path does not exist) are deliberately
    ignored so callers can reset output folders unconditionally.
    """
    try:
        shutil.rmtree(path)
    except OSError:
        # Unused `as exc` binding removed; the suppression is intentional.
        pass
32 |
# One-off switch: set to 1 once to dump the CIFAR-10 *test* split to disk as
# ./root_cifar10_test/<label>/<idx>.png; afterwards leave 0.
create = 0

if create:
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True)
    root = './root_cifar10_test/'
    del_folder(root)
    create_folder(root)

    for i in range(10):
        # NOTE(review): 'lable' is a typo for 'label' (harmless local name).
        lable_root = root + str(i) + '/'
        create_folder(lable_root)

    for idx in range(len(trainset)):
        img, label = trainset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


# Command-line configuration for resolution-diffusion evaluation on CIFAR-10.
parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--sample_steps', default=None, type=int)
parser.add_argument('--save_folder', default='./results_cifar10', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='./root_cifar10_test/', type=str)
parser.add_argument('--test_type', default='train_data', type=str)
parser.add_argument('--resolution_routine', default='Incremental', type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--remove_time_embed', action="store_true")
parser.add_argument('--residual', action="store_true")


args = parser.parse_args()
print(args)

# NOTE(review): every branch assigns the same value, so img_path is simply
# args.data_path for any recognised test_type.
img_path=None
if args.test_type == 'train_distribution_cov_vector':
    img_path = args.data_path
elif 'train' in args.test_type:
    img_path = args.data_path
elif 'test' in args.test_type:
    img_path = args.data_path

print("Img Path is ", img_path)


# DDPM-style backbone (Model) at 32px.
model = Model(resolution=32,
              in_channels=3,
              out_ch=3,
              ch=128,
              ch_mult=(1,2,2,2),
              num_res_blocks=2,
              attn_resolutions=(16,),
              dropout=0.01).cuda()


diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    device_of_kernel = 'cuda',
    channels = 3,
    timesteps = args.time_steps, # number of steps
    loss_type = 'l1', # L1 or L2
    resolution_routine=args.resolution_routine,
    train_routine=args.train_routine,
    sampling_routine = args.sampling_routine
).cuda()

trainer = Trainer(
    diffusion,
    img_path,
    image_size = 32,
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = 700000, # total training steps
    gradient_accumulate_every = 2, # gradient accumulation steps
    ema_decay = 0.995, # exponential moving average decay
    fp16 = False, # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path,
    shuffle=True
)

# Dispatch on --test_type; each branch runs one evaluation routine.
if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)

elif args.test_type == 'train_data_dropout':
    trainer.test_from_data_dropout('train_drop', s_times=args.sample_steps)

elif args.test_type == 'test_data_dropout':
    trainer.test_from_data_dropout('test_drop', s_times=args.sample_steps)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=args.sample_steps)

elif args.test_type == 'mixup_train_data':
    trainer.test_with_mixup('train')

elif args.test_type == 'mixup_test_data':
    trainer.test_with_mixup('test')

elif args.test_type == 'test_random':
    trainer.test_from_random('test_random')

elif args.test_type == 'train_random':
    trainer.test_from_random('train_random')

elif args.test_type == 'train_distribution_cov_vector':
    trainer.sample_as_a_vector_cov(start=0, end=None, siz=4, ch=3)

elif args.test_type == 'test_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

elif args.test_type == 'train_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

elif args.test_type == 'test_paper_invert_section_images':
    trainer.paper_invert_section_images()

elif args.test_type == 'test_paper_showing_diffusion_images':
    trainer.paper_showing_diffusion_images()

elif args.test_type == 'test_paper_showing_diffusion_imgs_og':
    trainer.paper_showing_diffusion_imgs_og()
158 |
--------------------------------------------------------------------------------
/deblurring-diffusion-pytorch/AFHQ_128_test.py:
--------------------------------------------------------------------------------
1 | from pycave.bayes import GaussianMixture
2 | import torchvision
3 | import os
4 | import errno
5 | import shutil
6 | import argparse
7 | from Fid import calculate_fid_given_samples
8 | import torch
9 | import random
10 |
11 | from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
# NOTE(review): this import shadows the pycave GaussianMixture imported on
# line 1 — the script actually uses the gmm_pycave implementation.
12 | from gmm_pycave import GaussianMixture as GaussianMixture
13 |
# Fix every RNG source so test outputs are reproducible across runs.
14 | seed_value=123457
15 | torch.manual_seed(seed_value) # cpu vars
16 | random.seed(seed_value) # Python
17 | torch.cuda.manual_seed(seed_value)
18 | torch.cuda.manual_seed_all(seed_value)
19 | torch.backends.cudnn.deterministic = True
20 | torch.backends.cudnn.benchmark = False
21 |
def create_folder(path):
    """Create directory `path`; silently succeed if it already exists.

    Any other OS error (permission denied, missing parent, ...) still
    propagates, matching the original errno.EEXIST filter.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # Py3 idiom: FileExistsError is exactly the errno.EEXIST OSError
        # the original code filtered for by hand.
        pass
29 |
def del_folder(path):
    """Best-effort recursive removal of `path`; all OS errors are ignored.

    Same behavior as the original try/except-OSError-pass, but via the
    stdlib flag, dropping the unused `exc` binding.
    """
    shutil.rmtree(path, ignore_errors=True)
35 |
36 |
# --- CLI for the AFHQ 128x128 deblurring-diffusion test driver -------------
37 | parser = argparse.ArgumentParser()
38 | parser.add_argument('--time_steps', default=50, type=int,
39 |                     help="This is the number of steps in which a clean image looses information.")
40 | parser.add_argument('--sample_steps', default=None, type=int)
41 | parser.add_argument('--blur_std', default=0.1, type=float,
42 |                     help='It sets the standard deviation for blur routines which have different meaning based on blur routine.')
43 | parser.add_argument('--blur_size', default=3, type=int,
44 |                     help='It sets the size of gaussian blur used in blur routines for each step t')
45 | parser.add_argument('--save_folder', default='./results_cifar10', type=str)
46 | parser.add_argument('--data_path', default='./AFHQ/afhq/train/', type=str)
47 | parser.add_argument('--load_path', default=None, type=str)
48 | parser.add_argument('--test_type', default='train_data', type=str)
49 | parser.add_argument('--blur_routine', default='Incremental', type=str,
50 |                     help='This will set the type of blur routine one can use, check the code for what each one of them does in detail')
51 | parser.add_argument('--train_routine', default='Final', type=str)
52 | parser.add_argument('--sampling_routine', default='x0_step_down', type=str,
53 |                     help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 2')
54 | parser.add_argument('--remove_time_embed', action="store_true")
55 | parser.add_argument('--residual', action="store_true")
56 | parser.add_argument('--gmm_size', default=8, type=int)
57 | parser.add_argument('--gmm_cluster', default=10, type=int)
58 | parser.add_argument('--gmm_sample_at', default=1, type=int)
59 | parser.add_argument('--bs', default=32, type=int)
60 | parser.add_argument('--discrete', action="store_true")
61 | parser.add_argument('--noise', default=0, type=float)
62 |
63 | args = parser.parse_args()
64 | print(args)
65 |
# NOTE(review): both branches assign the same path; img_path stays None
# only when --test_type contains neither 'train' nor 'test'.
66 | img_path=None
67 | if 'train' in args.test_type:
68 |     img_path = args.data_path
69 | elif 'test' in args.test_type:
70 |     img_path = args.data_path
71 |
72 | print("Img Path is ", img_path)
73 |
74 |
# U-Net denoiser; time embedding and residual prediction are CLI-switchable.
75 | model = Unet(
76 |     dim = 64,
77 |     dim_mults = (1, 2, 4, 8),
78 |     channels=3,
79 |     with_time_emb=not(args.remove_time_embed),
80 |     residual=args.residual
81 | ).cuda()
82 |
83 |
84 | diffusion = GaussianDiffusion(
85 |     model,
86 |     image_size = 128,
87 |     device_of_kernel = 'cuda',
88 |     channels = 3,
89 |     timesteps = args.time_steps,   # number of steps
90 |     loss_type = 'l1',    # L1 or L2
91 |     kernel_std=args.blur_std,
92 |     kernel_size=args.blur_size,
93 |     blur_routine=args.blur_routine,
94 |     train_routine=args.train_routine,
95 |     sampling_routine = args.sampling_routine,
96 |     discrete=args.discrete
97 | ).cuda()
98 |
# NOTE(review): torch is already imported at the top of the file; this
# re-import is redundant (harmless, Python caches modules).
99 | import torch
# Wrap in DataParallel so testing can use every visible GPU.
100 | diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))
101 |
102 |
# Trainer in test mode: weights come from --load_path; the training
# hyper-parameters below are effectively unused here.
103 | trainer = Trainer(
104 |     diffusion,
105 |     img_path,
106 |     image_size = 128,
107 |     train_batch_size = args.bs,
108 |     train_lr = 2e-5,
109 |     train_num_steps = 700000,         # total training steps
110 |     gradient_accumulate_every = 2,    # gradient accumulation steps
111 |     ema_decay = 0.995,                # exponential moving average decay
112 |     fp16 = False,                     # turn on mixed precision training with apex
113 |     results_folder = args.save_folder,
114 |     load_path = args.load_path,
115 |     dataset = None,
116 |     shuffle=False
117 | )
118 |
119 | if args.test_type == 'train_data':
120 |     trainer.test_from_data('train', s_times=args.sample_steps)
121 |
122 | elif args.test_type == 'test_data':
123 |     trainer.test_from_data('test', s_times=args.sample_steps)
124 |
125 |
126 |
127 | # this for extreme case and we perform gmm at the meaned vectors and test with what happens with noise
128 | elif args.test_type == 'train_distribution_mean_blur_torch_gmm':
# NOTE(review): end=26900 is hard-coded — presumably the AFHQ train-split
# size; confirm against the dataset before reusing.
129 |     trainer.sample_as_a_mean_blur_torch_gmm(GaussianMixture, start=0, end=26900, ch=3, clusters=args.gmm_cluster)
130 |
131 | # to save
132 | elif args.test_type == 'train_distribution_mean_blur_torch_gmm_ablation':
133 |     trainer.sample_as_a_mean_blur_torch_gmm_ablation(GaussianMixture, ch=3, clusters=args.gmm_cluster, noise=args.noise)
134 |
135 |
136 | # this for the non extreme case and you can create gmm at any point in the blur sequence and then blur from there
137 | elif args.test_type == 'train_distribution_blur_torch_gmm':
138 |     trainer.sample_as_a_blur_torch_gmm(GaussianMixture, start=0, end=26900, ch=3, clusters=args.gmm_cluster)
139 |
140 |
141 | ######### Quant ##############
# Both FID branches issue the same call; the split evaluated is decided by
# the img_path selection above.
142 | elif args.test_type == 'test_fid_distance_decrease_from_manifold':
143 |     trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)
144 |
145 | elif args.test_type == 'train_fid_distance_decrease_from_manifold':
146 |     trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)
147 |
148 | ###############################
149 |
--------------------------------------------------------------------------------
/defading-diffusion-pytorch/inpainting_paper_train.sh:
--------------------------------------------------------------------------------
# Launcher log for the defading/inpainting training sweeps. Earlier runs
# are kept commented out for provenance; only the final three commands are
# active. Paths under /cmlscratch are site-specific — adjust before reuse.
1 | #python defading-diffusion-pytorch/cifar10_train.py --time_steps 50 --save_folder ./cifar_inpainting --discrete --sampling_routine x0_step_down --train_steps 700000 --blur_std 0.1 --fade_routine Random_Incremental
2 |
3 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 50 --save_folder ./celebA_inpainting --discrete --sampling_routine x0_step_down --train_steps 700000 --blur_std 0.1 --fade_routine Random_Incremental
4 |
5 | #python defading-diffusion-pytorch/mnist_train.py --time_steps 50 --save_folder ./mnist_inpainting --discrete --sampling_routine x0_step_down --train_steps 700000 --blur_std 0.1 --fade_routine Random_Incremental
6 |
7 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 200 --save_folder ./celebA_128_200_step --discrete --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.1 --fade_routine Random_Incremental --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_200_step/model.pt
8 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --save_folder ./celebA_128_100_step --discrete --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --fade_routine Random_Incremental --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_100_step/model.pt
9 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 50 --save_folder ./celebA_128_50_step --discrete --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.4 --fade_routine Random_Incremental --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_50_step/model.pt
10 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --save_folder ./celebA_128_100_step_data_augmentation --discrete --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --fade_routine Random_Incremental --initial_mask 11 --image_size 128 --dataset celebA --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_100_step_data_augmentation/model.pt
11 |
12 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 200 --save_folder ./celebA_128_200_step_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.1 --fade_routine Random_Incremental --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_200_step_float/model.pt
13 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --save_folder ./celebA_128_100_step_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --fade_routine Random_Incremental --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_100_step_float/model.pt
14 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 50 --save_folder ./celebA_128_50_step_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.4 --fade_routine Random_Incremental --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_50_step_float/model.pt
15 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --save_folder ./celebA_128_100_step_data_augmentation_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --fade_routine Random_Incremental --initial_mask 11 --image_size 128 --dataset celebA --load_path /cmlscratch/eborgnia/cold_diffusion/celebA_128_100_step_data_augmentation_float/model.pt
16 |
17 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 200 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_200_step_ker_0.2_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128
18 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_100_step_ker_0.2_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128
19 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 50 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_50_step_ker_0.2_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128
20 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_100_step_ker_0.2_data_augmentation_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128 --dataset celebA
21 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 125 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_125_step_ker_0.2_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128
22 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 150 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_150_step_ker_0.2_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128
23 | #python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_100_step_ker_0.2_data_augmentation --discrete --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128 --dataset celebA
24 |
# Active runs: all resume from a previous model.pt checkpoint.
25 | python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_100_step_ker_0.2_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/centered_mask_celebA_128_100_step_ker_0.2_float/model.pt
26 | python defading-diffusion-pytorch/celebA_train.py --time_steps 50 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_50_step_ker_0.2_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128 --load_path /cmlscratch/eborgnia/cold_diffusion/centered_mask_celebA_128_50_step_ker_0.2_float/model.pt
27 | python defading-diffusion-pytorch/celebA_train.py --time_steps 100 --fade_routine Incremental --save_folder ./centered_mask_celebA_128_100_step_ker_0.2_data_augmentation_float --sampling_routine x0_step_down --train_steps 350000 --kernel_std 0.2 --initial_mask 1 --image_size 128 --dataset celebA --load_path /cmlscratch/eborgnia/cold_diffusion/centered_mask_celebA_128_100_step_ker_0.2_data_augmentation_float/model.pt
--------------------------------------------------------------------------------
/deblurring-diffusion-pytorch/celebA_128_test.py:
--------------------------------------------------------------------------------
1 | from pycave.bayes import GaussianMixture
2 | import torchvision
3 | import os
4 | import errno
5 | import shutil
6 | import argparse
7 | from Fid import calculate_fid_given_samples
8 | import torch
9 | import random
10 |
11 | from deblurring_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
# NOTE(review): this import shadows the pycave GaussianMixture imported on
# line 1 — the script actually uses the gmm_pycave implementation.
12 | from gmm_pycave import GaussianMixture as GaussianMixture
13 |
# Fix every RNG source so test outputs are reproducible across runs.
14 | seed_value=123457
15 | torch.manual_seed(seed_value) # cpu vars
16 | random.seed(seed_value) # Python
17 | torch.cuda.manual_seed(seed_value)
18 | torch.cuda.manual_seed_all(seed_value)
19 | torch.backends.cudnn.deterministic = True
20 | torch.backends.cudnn.benchmark = False
21 |
def create_folder(path):
    """Create directory `path`; silently succeed if it already exists.

    Any other OS error (permission denied, missing parent, ...) still
    propagates, matching the original errno.EEXIST filter.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # Py3 idiom: FileExistsError is exactly the errno.EEXIST OSError
        # the original code filtered for by hand.
        pass
29 |
def del_folder(path):
    """Best-effort recursive removal of `path`; all OS errors are ignored.

    Same behavior as the original try/except-OSError-pass, but via the
    stdlib flag, dropping the unused `exc` binding.
    """
    shutil.rmtree(path, ignore_errors=True)
35 |
36 |
# --- CLI for the celebA 128x128 deblurring-diffusion test driver -----------
37 | parser = argparse.ArgumentParser()
38 | parser.add_argument('--time_steps', default=50, type=int,
39 |                     help="This is the number of steps in which a clean image looses information.")
40 | parser.add_argument('--sample_steps', default=None, type=int)
41 | parser.add_argument('--blur_std', default=0.1, type=float,
42 |                     help='It sets the standard deviation for blur routines which have different meaning based on blur routine.')
43 | parser.add_argument('--blur_size', default=3, type=int,
44 |                     help='It sets the size of gaussian blur used in blur routines for each step t')
45 | parser.add_argument('--save_folder', default='./results_cifar10', type=str)
46 | parser.add_argument('--data_path', default='./root_celebA_128_train_new/', type=str)
47 | parser.add_argument('--load_path', default=None, type=str)
48 | parser.add_argument('--test_type', default='train_data', type=str)
49 | parser.add_argument('--blur_routine', default='Incremental', type=str,
50 |                     help='This will set the type of blur routine one can use, check the code for what each one of them does in detail')
51 | parser.add_argument('--train_routine', default='Final', type=str)
52 | parser.add_argument('--sampling_routine', default='x0_step_down', type=str,
53 |                     help='The choice of sampling routine for reversing the diffusion process, when set as default it corresponds to Alg. 1 while when set as x0_step_down it stands for Alg. 2')
54 | parser.add_argument('--remove_time_embed', action="store_true")
55 | parser.add_argument('--residual', action="store_true")
56 | parser.add_argument('--gmm_size', default=8, type=int)
57 | parser.add_argument('--gmm_cluster', default=10, type=int)
58 | parser.add_argument('--gmm_sample_at', default=1, type=int)
59 | parser.add_argument('--bs', default=32, type=int)
60 | parser.add_argument('--discrete', action="store_true")
61 | parser.add_argument('--noise', default=0, type=float)
62 |
63 | args = parser.parse_args()
64 | print(args)
65 |
# NOTE(review): both branches assign the same path; img_path stays None
# only when --test_type contains neither 'train' nor 'test'.
66 | img_path=None
67 | if 'train' in args.test_type:
68 |     img_path = args.data_path
69 | elif 'test' in args.test_type:
70 |     img_path = args.data_path
71 |
72 | print("Img Path is ", img_path)
73 |
74 |
# U-Net denoiser; time embedding and residual prediction are CLI-switchable.
75 | model = Unet(
76 |     dim = 64,
77 |     dim_mults = (1, 2, 4, 8),
78 |     channels=3,
79 |     with_time_emb=not(args.remove_time_embed),
80 |     residual=args.residual
81 | ).cuda()
82 |
83 |
84 | diffusion = GaussianDiffusion(
85 |     model,
86 |     image_size = 128,
87 |     device_of_kernel = 'cuda',
88 |     channels = 3,
89 |     timesteps = args.time_steps,   # number of steps
90 |     loss_type = 'l1',    # L1 or L2
91 |     kernel_std=args.blur_std,
92 |     kernel_size=args.blur_size,
93 |     blur_routine=args.blur_routine,
94 |     train_routine=args.train_routine,
95 |     sampling_routine = args.sampling_routine,
96 |     discrete=args.discrete
97 | ).cuda()
98 |
# NOTE(review): torch is already imported at the top of the file; this
# re-import is redundant (harmless, Python caches modules).
99 | import torch
# Wrap in DataParallel so testing can use every visible GPU.
100 | diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count()))
101 |
102 |
# Trainer in test mode: weights come from --load_path; the training
# hyper-parameters below are effectively unused here.
103 | trainer = Trainer(
104 |     diffusion,
105 |     img_path,
106 |     image_size = 128,
107 |     train_batch_size = args.bs,
108 |     train_lr = 2e-5,
109 |     train_num_steps = 700000,         # total training steps
110 |     gradient_accumulate_every = 2,    # gradient accumulation steps
111 |     ema_decay = 0.995,                # exponential moving average decay
112 |     fp16 = False,                     # turn on mixed precision training with apex
113 |     results_folder = args.save_folder,
114 |     load_path = args.load_path,
115 |     dataset = None,
116 |     shuffle=False
117 | )
118 |
119 | if args.test_type == 'train_data':
120 |     trainer.test_from_data('train', s_times=args.sample_steps)
121 |
122 | elif args.test_type == 'test_data':
123 |     trainer.test_from_data('test', s_times=args.sample_steps)
124 |
125 |
126 |
127 | # this for extreme case and we perform gmm at the meaned vectors and test with what happens with noise
128 | elif args.test_type == 'train_distribution_mean_blur_torch_gmm':
# NOTE(review): end=26900 is hard-coded — presumably a dataset-split size;
# confirm against the data before reusing.
129 |     trainer.sample_as_a_mean_blur_torch_gmm(GaussianMixture, start=0, end=26900, ch=3, clusters=args.gmm_cluster)
130 |
131 | # to save
132 | elif args.test_type == 'train_distribution_mean_blur_torch_gmm_ablation':
133 |     trainer.sample_as_a_mean_blur_torch_gmm_ablation(GaussianMixture, ch=3, clusters=args.gmm_cluster, noise=args.noise)
134 |
135 |
136 | # this for the non extreme case and you can create gmm at any point in the blur sequence and then blur from there
137 | elif args.test_type == 'train_distribution_blur_torch_gmm':
138 |     trainer.sample_as_a_blur_torch_gmm(GaussianMixture, siz=args.gmm_size, ch=3, clusters=args.gmm_cluster, sample_at=args.gmm_sample_at)
139 |
140 |
141 | ######### Quant ##############
# Both FID branches issue the same call; the split evaluated is decided by
# the img_path selection above.
142 | elif args.test_type == 'test_fid_distance_decrease_from_manifold':
143 |     trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)
144 |
145 | elif args.test_type == 'train_fid_distance_decrease_from_manifold':
146 |     trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)
147 |
148 |
149 | ########## for paper ##########
150 |
151 | elif args.test_type == 'train_paper_showing_diffusion_images_cover_page':
152 |     trainer.paper_showing_diffusion_images_cover_page()
153 |
154 | elif args.test_type == 'train_paper_showing_diffusion_images_cover_page_both_sampling':
155 |     trainer.paper_showing_diffusion_images_cover_page_both_sampling()
156 |
157 |
158 |
--------------------------------------------------------------------------------
/snowification/test.py:
--------------------------------------------------------------------------------
1 | from diffusion import GaussianDiffusion, Trainer, get_dataset
2 | import torchvision
3 | import os
4 | import errno
5 | import shutil
6 | import argparse
7 | from diffusion.model.get_model import get_model
8 | from Fid.fid_score import calculate_fid_given_samples
9 | import torch
10 | import random
11 |
12 |
13 |
def create_folder(path):
    """Create directory `path`; silently succeed if it already exists.

    Any other OS error (permission denied, missing parent, ...) still
    propagates, matching the original errno.EEXIST filter.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # Py3 idiom: FileExistsError is exactly the errno.EEXIST OSError
        # the original code filtered for by hand.
        pass
21 |
def del_folder(path):
    """Best-effort recursive removal of `path`; all OS errors are ignored.

    Same behavior as the original try/except-OSError-pass, but via the
    stdlib flag, dropping the unused `exc` binding.
    """
    shutil.rmtree(path, ignore_errors=True)
27 |
# --- CLI for the snowification test driver ---------------------------------
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument('--time_steps', default=50, type=int)
30 | parser.add_argument('--sample_steps', default=None, type=int)
31 | parser.add_argument('--save_folder_train', default='/cmlscratch/hmchu/cold_diff/', type=str)
32 | parser.add_argument('--save_folder_test', default='/cmlscratch/hmchu/cold_diff_paper_test/', type=str)
33 | parser.add_argument('--load_path', default=None, type=str)
34 | parser.add_argument('--test_type', default='train_data', type=str)
35 | parser.add_argument('--train_routine', default='Final', type=str)
36 | parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
37 | parser.add_argument('--remove_time_embed', action="store_true")
38 |
39 | parser.add_argument('--exp_name', default='', type=str)
40 | parser.add_argument('--test_postfix', default='', type=str)
41 | parser.add_argument('--dataset_folder', default='./root_cifar10', type=str)
42 | parser.add_argument('--output_mean_scale', action='store_true')
43 | parser.add_argument('--random_aug', action='store_true')
44 | parser.add_argument('--model', default='UnetConvNext', type=str)
45 | parser.add_argument('--dataset', default='cifar10')
46 |
47 | parser.add_argument('--forward_process_type', default='GaussianBlur')
48 | # GaussianBlur args
49 |
50 | # Decolor args
51 | parser.add_argument('--decolor_routine', default='Constant')
52 | parser.add_argument('--decolor_ema_factor', default=0.9, type=float)
53 | parser.add_argument('--decolor_total_remove', action='store_true')
54 |
55 | parser.add_argument('--load_model_steps', default=-1, type=int)
56 | parser.add_argument('--resume_training', action='store_true')
57 |
58 | parser.add_argument('--order_seed', default=-1.0, type=float)
59 | parser.add_argument('--resolution', default=-1, type=int)
60 |
61 | parser.add_argument('--to_lab', action='store_true')
62 | parser.add_argument('--snow_level', default=1, type=int)
63 | parser.add_argument('--random_snow', action='store_true')
64 | parser.add_argument('--single_snow', action='store_true')
65 | parser.add_argument('--fix_brightness', action='store_true')
66 |
67 | parser.add_argument('--test_fid', action='store_true')
68 |
69 |
70 |
71 |
72 | args = parser.parse_args()
# --exp_name is mandatory: it names both the checkpoint dir and the output dir.
73 | assert len(args.exp_name) > 0
74 |
75 | #if args.test_type == 'test_paper':
76 | #    args.save_folder_test = '/cmlscratch/hmchu/cold_diff_paper_test/'
77 |
# Resolve the checkpoint to load: a specific step's snapshot, else model.pt.
78 | if args.load_model_steps != -1:
79 |     args.load_path = os.path.join(args.save_folder_train, args.exp_name, f'model_{args.load_model_steps}.pt')
80 | else:
81 |     args.load_path = os.path.join(args.save_folder_train, args.exp_name, 'model.pt')
82 |
83 | if args.test_postfix != '':
84 |     save_folder_name = f'{args.exp_name}_{args.test_postfix}'
85 | else:
86 |     save_folder_name = args.exp_name
87 | args.save_folder_test = os.path.join(args.save_folder_test, save_folder_name, args.test_type)
88 | print(args.save_folder_test)
89 | print(args)
90 |
91 |
92 | img_path = args.dataset_folder
93 |
94 | with_time_emb = not args.remove_time_embed
95 |
# Two backbones: the regular (time-conditioned) model and a one-shot
# variant without the time embedding.
96 | model = get_model(args, with_time_emb=with_time_emb).cuda()
97 | model_one_shot = get_model(args, with_time_emb=False).cuda()
98 |
99 | image_size = get_dataset.get_image_size(args.dataset)
100 | if args.resolution != -1:
101 |     image_size = (args.resolution, args.resolution)
102 | print(f'image_size: {image_size}')
103 |
# NOTE(review): variable name has a typo ('torchvison'); kept as-is since
# it is referenced below. Batch size halves for images larger than 64px.
104 | use_torchvison_dataset = False
105 | if image_size[0] <= 64:
106 |     train_batch_size = 32
107 | elif image_size[0] > 64:
108 |     train_batch_size = 16
109 |
110 | print(args.dataset)
111 |
# NOTE(review): --order_seed is a float (default -1.0) used as the RNG
# seed everywhere below — an int flag would be clearer; confirm intent.
112 | seed_value=args.order_seed
113 | torch.manual_seed(seed_value) # cpu vars
114 | random.seed(seed_value) # Python
115 | torch.cuda.manual_seed(seed_value)
116 | torch.cuda.manual_seed_all(seed_value)
117 | torch.backends.cudnn.deterministic = True
118 | torch.backends.cudnn.benchmark = False
119 |
120 |
121 |
122 |
123 | diffusion = GaussianDiffusion(
124 |     model,
125 |     image_size = image_size,
126 |     device_of_kernel = 'cuda',
127 |     channels = 3,
128 |     one_shot_denoise_fn=model_one_shot,
129 |     timesteps = args.time_steps,   # number of steps
130 |     loss_type = 'l1',    # L1 or L2
131 |     train_routine=args.train_routine,
132 |     sampling_routine = args.sampling_routine,
133 |     forward_process_type=args.forward_process_type,
134 |     decolor_routine=args.decolor_routine,
135 |     decolor_ema_factor=args.decolor_ema_factor,
136 |     decolor_total_remove=args.decolor_total_remove,
137 |     snow_level=args.snow_level,
138 |     random_snow=args.random_snow,
139 |     single_snow=args.single_snow,
140 |     batch_size=train_batch_size,
141 |     to_lab=args.to_lab,
142 |     load_snow_base=False,
143 |     fix_brightness=args.fix_brightness,
144 |     load_path = args.load_path,
145 |     results_folder = args.save_folder_test,
146 | ).cuda()
147 |
# Trainer in test mode: weights come from load_path; the training
# hyper-parameters below are effectively unused here.
148 | trainer = Trainer(
149 |     diffusion,
150 |     img_path,
151 |     image_size = image_size,
152 |     train_batch_size = train_batch_size,
153 |     train_lr = 2e-5,
154 |     train_num_steps = 700000,         # total training steps
155 |     gradient_accumulate_every = 2,    # gradient accumulation steps
156 |     ema_decay = 0.995,                # exponential moving average decay
157 |     fp16 = False,                     # turn on mixed precision training with apex
158 |     results_folder = args.save_folder_test,
159 |     load_path = args.load_path,
160 |     random_aug=args.random_aug,
161 |     torchvision_dataset=use_torchvison_dataset,
162 |     dataset=f'{args.dataset}',
163 |     order_seed=args.order_seed,
164 |     to_lab=args.to_lab,
165 | )
166 |
# Dispatch; --test_fid additionally runs the FID sweep for some branches.
167 | if args.test_type == 'train_data':
168 |     trainer.test_from_data('train', s_times=args.sample_steps)
169 |     if args.test_fid:
170 |         trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=1000)
171 |
172 | elif args.test_type == 'test_data':
173 |     trainer.test_from_data('test', s_times=args.sample_steps)
174 |     if args.test_fid:
175 |         trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=2000)
176 |
177 | elif args.test_type == 'test_paper':
178 |     trainer.paper_invert_section_images()
179 |     if args.test_fid:
180 |         trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=1000)
181 |
182 | elif args.test_type == 'test_paper_series':
183 |     trainer.paper_showing_diffusion_images()
184 |
185 | elif args.test_type == 'test_rebuttal':
186 |     trainer.paper_showing_diffusion_images_cover_page()
187 |
188 |
--------------------------------------------------------------------------------
/decolor-diffusion/test.py:
--------------------------------------------------------------------------------
1 | from diffusion import GaussianDiffusion, Trainer, get_dataset
2 | import torchvision
3 | import os
4 | import errno
5 | import shutil
6 | import argparse
7 | from diffusion.model.get_model import get_model
8 | from Fid.fid_score import calculate_fid_given_samples
9 | import torch
10 | import random
11 |
12 |
13 |
def create_folder(path):
    """Create directory `path`; silently succeed if it already exists.

    Any other OS error (permission denied, missing parent, ...) still
    propagates, matching the original errno.EEXIST filter.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # Py3 idiom: FileExistsError is exactly the errno.EEXIST OSError
        # the original code filtered for by hand.
        pass
21 |
def del_folder(path):
    """Best-effort recursive removal of `path`; all OS errors are ignored.

    Same behavior as the original try/except-OSError-pass, but via the
    stdlib flag, dropping the unused `exc` binding.
    """
    shutil.rmtree(path, ignore_errors=True)
27 |
# --- CLI for the decolorization test driver --------------------------------
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument('--time_steps', default=50, type=int)
30 | parser.add_argument('--sample_steps', default=None, type=int)
31 | parser.add_argument('--save_folder_train', default='/cmlscratch/hmchu/cold_diff/', type=str)
32 | parser.add_argument('--save_folder_test', default='/cmlscratch/hmchu/cold_diff_paper_test/', type=str)
33 | parser.add_argument('--load_path', default=None, type=str)
34 | parser.add_argument('--test_type', default='train_data', type=str)
35 | parser.add_argument('--train_routine', default='Final', type=str)
36 | parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
37 | parser.add_argument('--remove_time_embed', action="store_true")
38 |
39 | parser.add_argument('--exp_name', default='', type=str)
40 | parser.add_argument('--test_postfix', default='', type=str)
41 | parser.add_argument('--dataset_folder', default='./root_cifar10', type=str)
42 | parser.add_argument('--output_mean_scale', action='store_true')
43 | parser.add_argument('--random_aug', action='store_true')
44 | parser.add_argument('--model', default='UnetConvNext', type=str)
45 | parser.add_argument('--dataset', default='cifar10')
46 |
47 | parser.add_argument('--forward_process_type', default='GaussianBlur')
48 | # GaussianBlur args
49 |
50 | # Decolor args
51 | parser.add_argument('--decolor_routine', default='Constant')
52 | parser.add_argument('--decolor_ema_factor', default=0.9, type=float)
53 | parser.add_argument('--decolor_total_remove', action='store_true')
54 |
55 | parser.add_argument('--load_model_steps', default=-1, type=int)
56 | parser.add_argument('--resume_training', action='store_true')
57 |
58 | parser.add_argument('--order_seed', default=-1.0, type=float)
59 | parser.add_argument('--resolution', default=-1, type=int)
60 |
61 | parser.add_argument('--to_lab', action='store_true')
62 | parser.add_argument('--snow_level', default=1, type=int)
63 | parser.add_argument('--random_snow', action='store_true')
64 | parser.add_argument('--single_snow', action='store_true')
65 | parser.add_argument('--fix_brightness', action='store_true')
66 |
67 | parser.add_argument('--test_fid', action='store_true')
68 |
69 |
70 |
71 |
72 | args = parser.parse_args()
# --exp_name is mandatory: it names both the checkpoint dir and the output dir.
73 | assert len(args.exp_name) > 0
74 |
75 | #if args.test_type == 'test_paper':
76 | #    args.save_folder_test = '/cmlscratch/hmchu/cold_diff_paper_test/'
77 |
# Resolve the checkpoint to load: a specific step's snapshot, else model.pt.
78 | if args.load_model_steps != -1:
79 |     args.load_path = os.path.join(args.save_folder_train, args.exp_name, f'model_{args.load_model_steps}.pt')
80 | else:
81 |     args.load_path = os.path.join(args.save_folder_train, args.exp_name, 'model.pt')
82 |
83 | if args.test_postfix != '':
84 |     save_folder_name = f'{args.exp_name}_{args.test_postfix}'
85 | else:
86 |     save_folder_name = args.exp_name
87 | args.save_folder_test = os.path.join(args.save_folder_test, save_folder_name, args.test_type)
88 | print(args.save_folder_test)
89 | print(args)
90 |
91 |
92 | img_path = args.dataset_folder
93 |
94 | with_time_emb = not args.remove_time_embed
95 |
# Two backbones: the regular (time-conditioned) model and a one-shot
# variant without the time embedding.
96 | model = get_model(args, with_time_emb=with_time_emb).cuda()
97 | model_one_shot = get_model(args, with_time_emb=False).cuda()
98 |
99 | image_size = get_dataset.get_image_size(args.dataset)
100 | if args.resolution != -1:
101 |     image_size = (args.resolution, args.resolution)
102 | print(f'image_size: {image_size}')
103 |
# NOTE(review): variable name has a typo ('torchvison'); kept as-is since
# it is referenced below. Batch size halves for images larger than 64px.
104 | use_torchvison_dataset = False
105 | if image_size[0] <= 64:
106 |     train_batch_size = 32
107 | elif image_size[0] > 64:
108 |     train_batch_size = 16
109 |
110 | print(args.dataset)
111 |
# NOTE(review): --order_seed is a float (default -1.0) used as the RNG
# seed everywhere below — an int flag would be clearer; confirm intent.
112 | seed_value=args.order_seed
113 | torch.manual_seed(seed_value) # cpu vars
114 | random.seed(seed_value) # Python
115 | torch.cuda.manual_seed(seed_value)
116 | torch.cuda.manual_seed_all(seed_value)
117 | torch.backends.cudnn.deterministic = True
118 | torch.backends.cudnn.benchmark = False
119 |
120 |
121 |
122 |
123 | diffusion = GaussianDiffusion(
124 |     model,
125 |     image_size = image_size,
126 |     device_of_kernel = 'cuda',
127 |     channels = 3,
128 |     one_shot_denoise_fn=model_one_shot,
129 |     timesteps = args.time_steps,   # number of steps
130 |     loss_type = 'l1',    # L1 or L2
131 |     train_routine=args.train_routine,
132 |     sampling_routine = args.sampling_routine,
133 |     forward_process_type=args.forward_process_type,
134 |     decolor_routine=args.decolor_routine,
135 |     decolor_ema_factor=args.decolor_ema_factor,
136 |     decolor_total_remove=args.decolor_total_remove,
137 |     snow_level=args.snow_level,
138 |     random_snow=args.random_snow,
139 |     single_snow=args.single_snow,
140 |     batch_size=train_batch_size,
141 |     to_lab=args.to_lab,
142 |     load_snow_base=False,
143 |     fix_brightness=args.fix_brightness,
144 |     load_path = args.load_path,
145 |     results_folder = args.save_folder_test,
146 | ).cuda()
147 |
# Trainer in test mode: weights come from load_path; the training
# hyper-parameters below are effectively unused here.
148 | trainer = Trainer(
149 |     diffusion,
150 |     img_path,
151 |     image_size = image_size,
152 |     train_batch_size = train_batch_size,
153 |     train_lr = 2e-5,
154 |     train_num_steps = 700000,         # total training steps
155 |     gradient_accumulate_every = 2,    # gradient accumulation steps
156 |     ema_decay = 0.995,                # exponential moving average decay
157 |     fp16 = False,                     # turn on mixed precision training with apex
158 |     results_folder = args.save_folder_test,
159 |     load_path = args.load_path,
160 |     random_aug=args.random_aug,
161 |     torchvision_dataset=use_torchvison_dataset,
162 |     dataset=f'{args.dataset}',
163 |     order_seed=args.order_seed,
164 |     to_lab=args.to_lab,
165 | )
166 |
167 | if args.test_type == 'train_data':
168 | trainer.test_from_data('train', s_times=args.sample_steps)
169 | if args.test_fid:
170 | trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=1000)
171 |
172 | elif args.test_type == 'test_data':
173 | trainer.test_from_data('test', s_times=args.sample_steps)
174 | if args.test_fid:
175 | trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=2000)
176 |
177 | elif args.test_type == 'test_paper':
178 | trainer.paper_invert_section_images()
179 | if args.test_fid:
180 | trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=1000)
181 |
182 | elif args.test_type == 'test_paper_series':
183 | trainer.paper_showing_diffusion_images()
184 |
185 | elif args.test_type == 'test_rebuttal':
186 | trainer.paper_showing_diffusion_images_cover_page()
187 |
188 |
--------------------------------------------------------------------------------
/resolution-diffusion-pytorch/mnist_test.py:
--------------------------------------------------------------------------------
1 | from pycave.bayes import GaussianMixture
2 | from resolution_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
3 | import torchvision
4 | import os
5 | import errno
6 | import shutil
7 | import argparse
8 | from Fid import calculate_fid_given_samples
9 |
def create_folder(path):
    """Create the directory *path*; silently succeed if it already exists.

    Raises:
        OSError: for any failure other than the directory already existing
            (e.g. a missing parent directory or a permission error).
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # FileExistsError is the OSError subclass for errno.EEXIST, so this
        # is the idiomatic Python 3 form of the old errno check.
        pass
17 |
def del_folder(path):
    """Remove the directory tree at *path*, ignoring removal errors
    (including the path not existing at all)."""
    shutil.rmtree(path, ignore_errors=True)
23 |
# Set to 1 to (re)export the MNIST test split as per-class PNG folders.
create = 0

if create:
    trainset = torchvision.datasets.MNIST(
        root='./data', train=False, download=True)
    root = './root_mnist_test/'
    del_folder(root)
    create_folder(root)

    # One sub-folder per digit class.
    for i in range(10):
        lable_root = root + str(i) + '/'
        create_folder(lable_root)

    for idx in range(len(trainset)):
        img, label = trainset[idx]
        print(idx)
        img.save(root + str(label) + '/' + str(idx) + '.png')


parser = argparse.ArgumentParser()
parser.add_argument('--time_steps', default=50, type=int)
parser.add_argument('--sample_steps', default=None, type=int)
parser.add_argument('--save_folder', default='./results_mnist', type=str)
parser.add_argument('--load_path', default=None, type=str)
parser.add_argument('--data_path', default='./root_mnist_test/', type=str)
parser.add_argument('--test_type', default='train_data', type=str)
parser.add_argument('--resolution_routine', default='Incremental', type=str)
parser.add_argument('--train_routine', default='Final', type=str)
parser.add_argument('--sampling_routine', default='x0_step_down', type=str)
parser.add_argument('--gmm_size', default=8, type=int)
parser.add_argument('--gmm_cluster', default=10, type=int)

args = parser.parse_args()
print(args)

# NOTE(review): both branches assign the same path -- presumably train and
# test splits were meant to use different folders; confirm intended.
img_path=None
if 'train' in args.test_type:
    img_path = args.data_path
elif 'test' in args.test_type:
    img_path = args.data_path

print("Img Path is ", img_path)


# Single-channel U-Net for 32x32 grayscale MNIST.
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    channels=1
).cuda()


diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    device_of_kernel = 'cuda',
    channels = 1,
    timesteps = args.time_steps,   # number of steps
    loss_type = 'l1',    # L1 or L2
    resolution_routine=args.resolution_routine,
    train_routine=args.train_routine,
    sampling_routine = args.sampling_routine
).cuda()

# Trainer is used purely as a test harness; the training hyper-parameters
# below are ignored by the test_* / sample_* entry points.
trainer = Trainer(
    diffusion,
    img_path,
    image_size = 32,
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    fp16 = False,                     # turn on mixed precision training with apex
    results_folder = args.save_folder,
    load_path = args.load_path
)

# Dispatch on the requested evaluation / sampling mode.
if args.test_type == 'train_data':
    trainer.test_from_data('train', s_times=args.sample_steps)

elif args.test_type == 'test_data':
    trainer.test_from_data('test', s_times=args.sample_steps)

elif args.test_type == 'mixup_train_data':
    trainer.test_with_mixup('train')

elif args.test_type == 'mixup_test_data':
    trainer.test_with_mixup('test')

elif args.test_type == 'test_random':
    trainer.test_from_random('random')

elif args.test_type == 'train_distribution_cov_vector':
    trainer.sample_as_a_vector_cov(start=0, end=None, siz=4, ch=1)
    trainer.sample_as_a_vector_cov(start=0, end=None, siz=8, ch=1)
    trainer.sample_as_a_vector_cov(start=0, end=None, siz=16, ch=1)
    # trainer.sample_as_a_vector_cov(start=0, end=None, siz=16, ch=1)
    # trainer.sample_as_a_vector_cov(start=0, end=None, siz=8, ch=1)
    # trainer.sample_as_a_vector_cov(start=0, end=None, siz=4, ch=1)

elif args.test_type == 'train_distribution_gmm':
    trainer.sample_as_a_vector_gmm(start=0, end=None, siz=args.gmm_size, ch=1, clusters=args.gmm_cluster)
    # trainer.sample_as_a_vector_gmm(start=0, end=None, siz=8, ch=1, clusters=10)
    # trainer.sample_as_a_vector_gmm(start=0, end=None, siz=4, ch=1, clusters=10)
    # trainer.sample_as_a_vector_gmm(start=0, end=None, siz=16, ch=1, clusters=10)

# elif args.test_type == 'train_distribution_save_gmm':
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=8, ch=1, clusters=10)
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=4, ch=1, clusters=10)
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=16, ch=1, clusters=10)
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=32, ch=1, clusters=10)
#
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=8, ch=1, clusters=20)
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=4, ch=1, clusters=20)
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=16, ch=1, clusters=20)
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=32, ch=1, clusters=20)
#
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=8, ch=1, clusters=50)
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=4, ch=1, clusters=50)
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=16, ch=1, clusters=50)
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=32, ch=1, clusters=50)
#
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=8, ch=1, clusters=100)
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=4, ch=1, clusters=100)
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=16, ch=1, clusters=100)
#     trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=32, ch=1, clusters=100)

elif args.test_type == 'train_distribution_save_gmm':
    trainer.sample_as_a_vector_gmm_and_save(start=0, end=None, siz=args.gmm_size, ch=1, clusters=args.gmm_cluster)

elif args.test_type == 'train_distribution_save_gmm_slowly':
    trainer.sample_as_a_vector_gmm_and_save_slowly(start=0, end=None, siz=args.gmm_size, ch=1, clusters=args.gmm_cluster)

elif args.test_type == 'train_save_orig_data_same_as_trained':
    trainer.save_training_data()

elif args.test_type == 'test_save_orig_data_same_as_tested':
    trainer.save_training_data()

# NOTE(review): the two FID branches below are identical -- presumably one
# should evaluate the test split; confirm against Trainer's implementation.
elif args.test_type == 'test_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

elif args.test_type == 'train_fid_distance_decrease_from_manifold':
    trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None)

elif args.test_type == 'sample_from_train_data':
    trainer.sample_from_data_save(start=0, end=None)

elif args.test_type == 'sample_from_test_data':
    trainer.sample_from_data_save(start=0, end=None)

elif args.test_type == 'train_distribution_save_pytorch_gmm':
    trainer.sample_as_a_vector_pytorch_gmm_and_save(GaussianMixture, start=0, end=59000, siz=args.gmm_size, ch=1, clusters=args.gmm_cluster)

elif args.test_type == 'test_paper_invert_section_images':
    trainer.paper_invert_section_images()

elif args.test_type == 'test_paper_showing_diffusion_images':
    trainer.paper_showing_diffusion_images()

elif args.test_type == 'test_paper_showing_diffusion_imgs_og':
    trainer.paper_showing_diffusion_imgs_og()
--------------------------------------------------------------------------------
/decolor-diffusion/diffusion/model/unet_convnext.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from inspect import isfunction
5 | from einops import rearrange
6 |
def exists(x):
    """Return True when *x* is not None."""
    return x is not None
9 |
def default(val, d):
    """Return *val* unless it is None; otherwise fall back to *d*,
    calling it first when *d* is a function."""
    if val is not None:
        return val
    return d() if isfunction(d) else d
14 |
15 |
16 |
class Residual(nn.Module):
    """Wrap a module with a skip connection: returns fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        out = self.fn(x, *args, **kwargs)
        return out + x
24 |
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding for a batch of scalar timesteps.

    Maps (B,) -> (B, dim): the first half of the features are sines, the
    second half cosines, over log-spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        # Log-spaced frequencies from 1 down to 1/10000.
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=x.device) * -scale)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
38 |
def Upsample(dim):
    """Double the spatial resolution with a stride-2 transposed convolution."""
    return nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1)
41 |
def Downsample(dim):
    """Halve the spatial resolution with a stride-2 convolution."""
    return nn.Conv2d(dim, dim, kernel_size=4, stride=2, padding=1)
44 |
class LayerNorm(nn.Module):
    """Normalize over the channel axis (dim=1) with a learnable per-channel
    scale ``g`` and shift ``b``; statistics use the biased variance."""

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        mu = x.mean(dim = 1, keepdim = True)
        sigma2 = x.var(dim = 1, unbiased = False, keepdim = True)
        normed = (x - mu) / torch.sqrt(sigma2 + self.eps)
        return normed * self.g + self.b
56 |
class PreNorm(nn.Module):
    """Apply channel LayerNorm before running the wrapped module."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))
66 |
67 | # building block modules
68 |
class ConvNextBlock(nn.Module):
    """ConvNeXt-style residual block (https://arxiv.org/abs/2201.03545).

    Depthwise 7x7 conv, optional additive time-embedding conditioning,
    then a norm -> conv -> GELU -> conv stack, with a 1x1 residual
    projection when the channel count changes.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim = None, mult = 2, norm = True):
        super().__init__()
        # Projects the time embedding onto the input channels.
        if time_emb_dim is not None:
            self.mlp = nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
        else:
            self.mlp = None

        # Depthwise spatial mixing.
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)

        self.net = nn.Sequential(
            LayerNorm(dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding = 1),
            nn.GELU(),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding = 1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        hidden = self.ds_conv(x)

        if self.mlp is not None:
            assert time_emb is not None, 'time emb must be passed in'
            shift = self.mlp(time_emb)
            # Broadcast the per-channel shift over the spatial grid.
            hidden = hidden + rearrange(shift, 'b c -> b c 1 1')

        hidden = self.net(hidden)
        return hidden + self.res_conv(x)
100 |
class LinearAttention(nn.Module):
    """Multi-head self-attention over the spatial grid with linear cost.

    Softmax is taken over keys' spatial axis so k @ v can be contracted
    first, avoiding the quadratic pixel-by-pixel attention matrix.
    """

    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, height, width = x.shape
        q, k, v = (rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads)
                   for t in self.to_qkv(x).chunk(3, dim = 1))
        q = q * self.scale

        k = k.softmax(dim = -1)
        # Aggregate values against the normalized keys first (linear cost).
        ctx = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        attended = torch.einsum('b h d e, b h d n -> b h e n', ctx, q)
        attended = rearrange(attended, 'b h c (x y) -> b (h c) x y', h = self.heads, x = height, y = width)
        return self.to_out(attended)
122 |
123 | # model
124 |
class UnetConvNextBlock(nn.Module):
    """U-Net built from ConvNext blocks with linear attention at every scale.

    Args:
        dim: base channel width, scaled per resolution by ``dim_mults``.
        out_dim: number of output channels (defaults to ``channels``).
        dim_mults: per-resolution channel multipliers.
        channels: number of input image channels.
        with_time_emb: if True, condition every block on a sinusoidal
            timestep embedding.
        output_mean_scale: if True, shift the output so its per-image mean
            matches the input's.
        residual: if True, return input + prediction instead of prediction.
    """

    def __init__(
        self,
        dim,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        with_time_emb = True,
        output_mean_scale = False,
        residual = False,
    ):
        super().__init__()
        self.channels = channels
        self.residual = residual
        print("Is Time embed used ? ", with_time_emb)
        self.output_mean_scale = output_mean_scale

        # Channel widths per resolution, e.g. dim=64 -> [3, 64, 128, 256, 512].
        dims = [channels, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        if with_time_emb:
            time_dim = dim
            self.time_mlp = nn.Sequential(
                SinusoidalPosEmb(dim),
                nn.Linear(dim, dim * 4),
                nn.GELU(),
                nn.Linear(dim * 4, dim)
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        # Encoder: two ConvNext blocks + attention per scale; downsample at
        # every scale except the last.
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ConvNextBlock(dim_in, dim_out, time_emb_dim = time_dim, norm = ind != 0),
                ConvNextBlock(dim_out, dim_out, time_emb_dim = time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))
        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim)

        # Decoder: mirrors the encoder; the first block of each stage takes
        # 2x channels from the skip-connection concat.
        # NOTE(review): this loop has num_resolutions - 1 iterations, so
        # `is_last` can never be True here and every stage upsamples --
        # confirm intended.
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ConvNextBlock(dim_out * 2, dim_in, time_emb_dim = time_dim),
                ConvNextBlock(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            ConvNextBlock(dim, dim),
            nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time=None):
        orig_x = x
        t = None
        if time is not None and exists(self.time_mlp):
            t = self.time_mlp(time)

        original_mean = torch.mean(x, [1, 2, 3], keepdim=True)
        h = []

        # Encoder path; keep pre-downsample activations as skip connections.
        for convnext, convnext2, attn, downsample in self.downs:
            x = convnext(x, t)
            x = convnext2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # Decoder path; pops the skips in reverse order.
        # NOTE(review): the first skip pushed (h[0]) is never popped because
        # there is one more encoder stage than decoder stages -- confirm
        # intended.
        for convnext, convnext2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = convnext(x, t)
            x = convnext2(x, t)
            x = attn(x)
            x = upsample(x)
        if self.residual:
            return self.final_conv(x) + orig_x

        out = self.final_conv(x)
        if self.output_mean_scale:
            out_mean = torch.mean(out, [1,2,3], keepdim=True)
            out = out - original_mean + out_mean

        return out
227 |
228 |
229 |
--------------------------------------------------------------------------------
/snowification/diffusion/model/unet_convnext.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from inspect import isfunction
5 | from einops import rearrange
6 |
def exists(x):
    """Return True when *x* is not None."""
    return x is not None
9 |
def default(val, d):
    """Return *val* unless it is None; otherwise fall back to *d*,
    calling it first when *d* is a function."""
    if val is not None:
        return val
    return d() if isfunction(d) else d
14 |
15 |
16 |
class Residual(nn.Module):
    """Wrap a module with a skip connection: returns fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        out = self.fn(x, *args, **kwargs)
        return out + x
24 |
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding for a batch of scalar timesteps.

    Maps (B,) -> (B, dim): the first half of the features are sines, the
    second half cosines, over log-spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        # Log-spaced frequencies from 1 down to 1/10000.
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=x.device) * -scale)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
38 |
def Upsample(dim):
    """Double the spatial resolution with a stride-2 transposed convolution."""
    return nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1)
41 |
def Downsample(dim):
    """Halve the spatial resolution with a stride-2 convolution."""
    return nn.Conv2d(dim, dim, kernel_size=4, stride=2, padding=1)
44 |
class LayerNorm(nn.Module):
    """Normalize over the channel axis (dim=1) with a learnable per-channel
    scale ``g`` and shift ``b``; statistics use the biased variance."""

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        mu = x.mean(dim = 1, keepdim = True)
        sigma2 = x.var(dim = 1, unbiased = False, keepdim = True)
        normed = (x - mu) / torch.sqrt(sigma2 + self.eps)
        return normed * self.g + self.b
56 |
class PreNorm(nn.Module):
    """Apply channel LayerNorm before running the wrapped module."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))
66 |
67 | # building block modules
68 |
class ConvNextBlock(nn.Module):
    """ConvNeXt-style residual block (https://arxiv.org/abs/2201.03545).

    Depthwise 7x7 conv, optional additive time-embedding conditioning,
    then a norm -> conv -> GELU -> conv stack, with a 1x1 residual
    projection when the channel count changes.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim = None, mult = 2, norm = True):
        super().__init__()
        # Projects the time embedding onto the input channels.
        if time_emb_dim is not None:
            self.mlp = nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
        else:
            self.mlp = None

        # Depthwise spatial mixing.
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)

        self.net = nn.Sequential(
            LayerNorm(dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding = 1),
            nn.GELU(),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding = 1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        hidden = self.ds_conv(x)

        if self.mlp is not None:
            assert time_emb is not None, 'time emb must be passed in'
            shift = self.mlp(time_emb)
            # Broadcast the per-channel shift over the spatial grid.
            hidden = hidden + rearrange(shift, 'b c -> b c 1 1')

        hidden = self.net(hidden)
        return hidden + self.res_conv(x)
100 |
class LinearAttention(nn.Module):
    """Multi-head self-attention over the spatial grid with linear cost.

    Softmax is taken over keys' spatial axis so k @ v can be contracted
    first, avoiding the quadratic pixel-by-pixel attention matrix.
    """

    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, height, width = x.shape
        q, k, v = (rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads)
                   for t in self.to_qkv(x).chunk(3, dim = 1))
        q = q * self.scale

        k = k.softmax(dim = -1)
        # Aggregate values against the normalized keys first (linear cost).
        ctx = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        attended = torch.einsum('b h d e, b h d n -> b h e n', ctx, q)
        attended = rearrange(attended, 'b h c (x y) -> b (h c) x y', h = self.heads, x = height, y = width)
        return self.to_out(attended)
122 |
123 | # model
124 |
class UnetConvNextBlock(nn.Module):
    """U-Net built from ConvNext blocks with linear attention at every scale.

    Args:
        dim: base channel width, scaled per resolution by ``dim_mults``.
        out_dim: number of output channels (defaults to ``channels``).
        dim_mults: per-resolution channel multipliers.
        channels: number of input image channels.
        with_time_emb: if True, condition every block on a sinusoidal
            timestep embedding.
        output_mean_scale: if True, shift the output so its per-image mean
            matches the input's.
        residual: if True, return input + prediction instead of prediction.
    """

    def __init__(
        self,
        dim,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        with_time_emb = True,
        output_mean_scale = False,
        residual = False,
    ):
        super().__init__()
        self.channels = channels
        self.residual = residual
        print("Is Time embed used ? ", with_time_emb)
        self.output_mean_scale = output_mean_scale

        # Channel widths per resolution, e.g. dim=64 -> [3, 64, 128, 256, 512].
        dims = [channels, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        if with_time_emb:
            time_dim = dim
            self.time_mlp = nn.Sequential(
                SinusoidalPosEmb(dim),
                nn.Linear(dim, dim * 4),
                nn.GELU(),
                nn.Linear(dim * 4, dim)
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        # Encoder: two ConvNext blocks + attention per scale; downsample at
        # every scale except the last.
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ConvNextBlock(dim_in, dim_out, time_emb_dim = time_dim, norm = ind != 0),
                ConvNextBlock(dim_out, dim_out, time_emb_dim = time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))
        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim)

        # Decoder: mirrors the encoder; the first block of each stage takes
        # 2x channels from the skip-connection concat.
        # NOTE(review): this loop has num_resolutions - 1 iterations, so
        # `is_last` can never be True here and every stage upsamples --
        # confirm intended.
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ConvNextBlock(dim_out * 2, dim_in, time_emb_dim = time_dim),
                ConvNextBlock(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            ConvNextBlock(dim, dim),
            nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time=None):
        orig_x = x
        t = None
        if time is not None and exists(self.time_mlp):
            t = self.time_mlp(time)

        original_mean = torch.mean(x, [1, 2, 3], keepdim=True)
        h = []

        # Encoder path; keep pre-downsample activations as skip connections.
        for convnext, convnext2, attn, downsample in self.downs:
            x = convnext(x, t)
            x = convnext2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # Decoder path; pops the skips in reverse order.
        # NOTE(review): the first skip pushed (h[0]) is never popped because
        # there is one more encoder stage than decoder stages -- confirm
        # intended.
        for convnext, convnext2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = convnext(x, t)
            x = convnext2(x, t)
            x = attn(x)
            x = upsample(x)
        if self.residual:
            return self.final_conv(x) + orig_x

        out = self.final_conv(x)
        if self.output_mean_scale:
            out_mean = torch.mean(out, [1,2,3], keepdim=True)
            out = out - original_mean + out_mean

        return out
227 |
228 |
229 |
--------------------------------------------------------------------------------
/snowification/diffusion/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from kornia.color.rgb import linear_rgb_to_rgb, rgb_to_linear_rgb
7 | from kornia.color.xyz import rgb_to_xyz, xyz_to_rgb
8 |
def rgb2hsv(image_old: torch.Tensor, eps: float = 1e-8, rescale=True) -> torch.Tensor:
    r"""Convert an image from RGB to HSV.

    .. image:: _static/img/rgb_to_hsv.png

    Args:
        image_old: RGB image with shape of :math:`(*, 3, H, W)`. When
            ``rescale`` is True the input is assumed to be in ``[-1, 1]``
            and is mapped to ``[0, 1]`` first; otherwise it must already be
            in ``[0, 1]``.
        eps: scalar to enforce numerical stability.
        rescale: whether to map the input from ``[-1, 1]`` to ``[0, 1]``.

    Returns:
        HSV version of the image with shape of :math:`(*, 3, H, W)`.
        The H channel values are in the range 0..2pi. S and V are in the range 0..1.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb2hsv(input, rescale=False)  # 2x3x4x5
    """
    if rescale:
        image = (image_old + 1) * 0.5
    else:
        # BUG FIX: `image` was previously left unbound when rescale=False,
        # raising NameError. Use the input unchanged in that case.
        image = image_old
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    max_rgb, argmax_rgb = image.max(-3)
    min_rgb, argmin_rgb = image.min(-3)
    deltac = max_rgb - min_rgb

    v = max_rgb
    s = deltac / (max_rgb + eps)

    # Avoid division by zero for achromatic pixels (hue is arbitrary there).
    deltac = torch.where(deltac == 0, torch.ones_like(deltac), deltac)
    rc, gc, bc = torch.unbind((max_rgb.unsqueeze(-3) - image), dim=-3)

    h1 = (bc - gc)
    h2 = (rc - bc) + 2.0 * deltac
    h3 = (gc - rc) + 4.0 * deltac

    # Select the hue formula matching whichever channel is the maximum.
    h = torch.stack((h1, h2, h3), dim=-3) / deltac.unsqueeze(-3)
    h = torch.gather(h, dim=-3, index=argmax_rgb.unsqueeze(-3)).squeeze(-3)
    h = (h / 6.0) % 1.0
    h = 2. * math.pi * h  # we return 0/2pi output

    return torch.stack((h, s, v), dim=-3)
60 |
61 |
62 |
def hsv2rgb(image: torch.Tensor, rescale=True) -> torch.Tensor:
    r"""Convert an image from HSV to RGB.

    The H channel values are assumed to be in the range 0..2pi. S and V are in the range 0..1.

    Args:
        image: HSV Image to be converted to RGB with shape of :math:`(*, 3, H, W)`.
        rescale: if True, map the RGB output from :math:`[0, 1]` to :math:`[-1, 1]`.

    Returns:
        RGB version of the image with shape of :math:`(*, 3, H, W)`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = hsv2rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    h: torch.Tensor = image[..., 0, :, :] / (2 * math.pi)
    s: torch.Tensor = image[..., 1, :, :]
    v: torch.Tensor = image[..., 2, :, :]

    # hi is the hue sextant (0..5); f is the fractional position inside it.
    hi: torch.Tensor = torch.floor(h * 6) % 6
    f: torch.Tensor = ((h * 6) % 6) - hi
    one: torch.Tensor = torch.tensor(1.0, device=image.device, dtype=image.dtype)
    p: torch.Tensor = v * (one - s)
    q: torch.Tensor = v * (one - f * s)
    t: torch.Tensor = v * (one - (one - f) * s)

    hi = hi.long()
    # The 18-channel stack holds (r, g, b) candidates for all 6 sextants;
    # gather picks the triple for each pixel's sextant.
    indices: torch.Tensor = torch.stack([hi, hi + 6, hi + 12], dim=-3)
    out = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-3)
    out = torch.gather(out, -3, indices)

    if rescale:
        out = 2.0 * out - 1

    return out
104 |
105 |
106 | """
107 | The RGB to Lab color transformations were translated from scikit image's rgb2lab and lab2rgb
108 |
109 | https://github.com/scikit-image/scikit-image/blob/a48bf6774718c64dade4548153ae16065b595ca9/skimage/color/colorconv.py
110 |
111 | """
112 |
def rgb2lab(image_old: torch.Tensor) -> torch.Tensor:
    r"""Convert a RGB image to Lab.

    .. image:: _static/img/rgb_to_lab.png

    The input is assumed to be in the range :math:`[-1, 1]` and is first
    mapped to :math:`[0, 1]`. Lab color is computed using the D65 illuminant
    and Observer 2.

    Args:
        image_old: RGB Image to be converted to Lab with shape :math:`(*, 3, H, W)`.

    Returns:
        Lab version of the image with shape :math:`(*, 3, H, W)`.
        The L channel values are in the range 0..100. a and b are in the range -127..127.

    Example:
        >>> input = torch.rand(2, 3, 4, 5) * 2 - 1
        >>> output = rgb2lab(input)  # 2x3x4x5
    """
    # Map [-1, 1] -> [0, 1] before the colorimetric conversion.
    image = (image_old + 1) * 0.5
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # Convert from sRGB to Linear RGB
    lin_rgb = rgb_to_linear_rgb(image)

    xyz_im: torch.Tensor = rgb_to_xyz(lin_rgb)

    # normalize for D65 white point
    xyz_ref_white = torch.tensor([0.95047, 1.0, 1.08883], device=xyz_im.device, dtype=xyz_im.dtype)[..., :, None, None]
    xyz_normalized = torch.div(xyz_im, xyz_ref_white)

    # CIE f() transfer: cube root above the threshold, linear segment below.
    threshold = 0.008856
    power = torch.pow(xyz_normalized.clamp(min=threshold), 1 / 3.0)
    scale = 7.787 * xyz_normalized + 4.0 / 29.0
    xyz_int = torch.where(xyz_normalized > threshold, power, scale)

    x: torch.Tensor = xyz_int[..., 0, :, :]
    y: torch.Tensor = xyz_int[..., 1, :, :]
    z: torch.Tensor = xyz_int[..., 2, :, :]

    L: torch.Tensor = (116.0 * y) - 16.0
    a: torch.Tensor = 500.0 * (x - y)
    _b: torch.Tensor = 200.0 * (y - z)

    out: torch.Tensor = torch.stack([L, a, _b], dim=-3)

    return out
164 |
165 |
def lab2rgb(image: torch.Tensor, clip: bool = True) -> torch.Tensor:
    r"""Convert a Lab image to RGB.

    Args:
        image: Lab image to be converted to RGB with shape :math:`(*, 3, H, W)`.
        clip: Whether to apply clipping to insure output RGB values in range :math:`[0, 1]`
            (before the final rescale to :math:`[-1, 1]`).

    Returns:
        RGB version of the image with shape :math:`(*, 3, H, W)`, rescaled
        to the range :math:`[-1, 1]`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = lab2rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    L: torch.Tensor = image[..., 0, :, :]
    a: torch.Tensor = image[..., 1, :, :]
    _b: torch.Tensor = image[..., 2, :, :]

    # Invert the CIE L*a*b* encoding back to the f(X/Xn) values.
    fy = (L + 16.0) / 116.0
    fx = (a / 500.0) + fy
    fz = fy - (_b / 200.0)

    # if color data out of range: Z < 0
    fz = fz.clamp(min=0.0)

    fxyz = torch.stack([fx, fy, fz], dim=-3)

    # Convert from Lab to XYZ
    power = torch.pow(fxyz, 3.0)
    scale = (fxyz - 4.0 / 29.0) / 7.787
    xyz = torch.where(fxyz > 0.2068966, power, scale)

    # For D65 white point
    xyz_ref_white = torch.tensor([0.95047, 1.0, 1.08883], device=xyz.device, dtype=xyz.dtype)[..., :, None, None]
    xyz_im = xyz * xyz_ref_white

    rgbs_im: torch.Tensor = xyz_to_rgb(xyz_im)

    # https://github.com/richzhang/colorization-pytorch/blob/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/util/util.py#L107
    # rgbs_im = torch.where(rgbs_im < 0, torch.zeros_like(rgbs_im), rgbs_im)

    # Convert from RGB Linear to sRGB
    rgb_im = linear_rgb_to_rgb(rgbs_im)

    # Clip to 0,1 https://www.w3.org/Graphics/Color/srgb
    if clip:
        rgb_im = torch.clamp(rgb_im, min=0.0, max=1.0)

    # Map [0, 1] -> [-1, 1] to match the rest of the pipeline.
    rgb_im = 2.0 * rgb_im - 1

    return rgb_im
223 |
224 |
class RgbToLab(nn.Module):
    r"""Module wrapper converting an image from RGB to Lab.

    Input values are expected to lie in the range :math:`[0, 1]`. The Lab
    conversion uses the D65 illuminant and the 2-degree standard observer.

    Returns:
        Lab version of the image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> lab = RgbToLab()
        >>> output = lab(input)  # 2x3x4x5

    Reference:
        [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html

        [2] https://www.easyrgb.com/en/math.php

        [3] https://github.com/torch/image/blob/dc061b98fb7e946e00034a5fc73e883a299edc7f/generic/image.c#L1467
    """

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # Stateless module: simply delegate to the functional conversion.
        return rgb_to_lab(image)
253 |
254 |
255 |
256 |
257 | class LabToRgb(nn.Module):
258 | r"""Convert an image from Lab to RGB.
259 |
260 | Returns:
261 | RGB version of the image. Range may not be in :math:`[0, 1]`.
262 |
263 | Shape:
264 | - image: :math:`(*, 3, H, W)`
265 | - output: :math:`(*, 3, H, W)`
266 |
267 | Examples:
268 | >>> input = torch.rand(2, 3, 4, 5)
269 | >>> rgb = LabToRgb()
270 | >>> output = rgb(input) # 2x3x4x5
271 |
272 | References:
273 | [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html
274 |
275 | [2] https://www.easyrgb.com/en/math.php
276 |
277 | [3] https://github.com/torch/image/blob/dc061b98fb7e946e00034a5fc73e883a299edc7f/generic/image.c#L1518
278 | """
279 |
280 | def forward(self, image: torch.Tensor, clip: bool = True) -> torch.Tensor:
281 | return lab_to_rgb(image, clip)
282 |
283 |
--------------------------------------------------------------------------------