├── CIFAR10
│   └── blank.txt
├── Checkpoints
│   └── blank.txt
├── CheckpointsCondition
│   └── blank.txt
├── SampledImgs
│   ├── blank.txt
│   ├── noisy.png
│   ├── 104_sampled_64.png
│   ├── NoisyGuidenceImgs.png
│   ├── NoisyNoGuidenceImgs.png
│   ├── SampledGuidenceImgs.png
│   └── SampledNoGuidenceImgs.png
├── Diffusion
│   ├── __init__.py
│   ├── Diffusion.py
│   ├── Train.py
│   └── Model.py
├── DiffusionFreeGuidence
│   ├── __init__.py
│   ├── DiffusionCondition.py
│   ├── TrainCondition.py
│   └── ModelCondition.py
├── LICENSE
├── MainCondition.py
├── Main.py
├── README.md
├── Scheduler.py
└── .gitignore
/CIFAR10/blank.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Checkpoints/blank.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/CheckpointsCondition/blank.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/SampledImgs/blank.txt:
--------------------------------------------------------------------------------
1 | blank
2 |
--------------------------------------------------------------------------------
/Diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from .Diffusion import *
2 | from .Model import *
3 | from .Train import *
--------------------------------------------------------------------------------
/SampledImgs/noisy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-/HEAD/SampledImgs/noisy.png
--------------------------------------------------------------------------------
/DiffusionFreeGuidence/__init__.py:
--------------------------------------------------------------------------------
1 | from .DiffusionCondition import *
2 | from .ModelCondition import *
3 | from .TrainCondition import *
--------------------------------------------------------------------------------
/SampledImgs/104_sampled_64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-/HEAD/SampledImgs/104_sampled_64.png
--------------------------------------------------------------------------------
/SampledImgs/NoisyGuidenceImgs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-/HEAD/SampledImgs/NoisyGuidenceImgs.png
--------------------------------------------------------------------------------
/SampledImgs/NoisyNoGuidenceImgs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-/HEAD/SampledImgs/NoisyNoGuidenceImgs.png
--------------------------------------------------------------------------------
/SampledImgs/SampledGuidenceImgs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-/HEAD/SampledImgs/SampledGuidenceImgs.png
--------------------------------------------------------------------------------
/SampledImgs/SampledNoGuidenceImgs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-/HEAD/SampledImgs/SampledNoGuidenceImgs.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 ZOUbohao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MainCondition.py:
--------------------------------------------------------------------------------
1 | from DiffusionFreeGuidence.TrainCondition import train, eval
2 |
3 |
4 | def main(model_config=None):
5 | modelConfig = {
6 | "state": "train", # or eval
7 | "epoch": 70,
8 | "batch_size": 80,
9 | "T": 500,
10 | "channel": 128,
11 | "channel_mult": [1, 2, 2, 2],
12 | "num_res_blocks": 2,
13 | "dropout": 0.15,
14 | "lr": 1e-4,
15 | "multiplier": 2.5,
16 | "beta_1": 1e-4,
17 | "beta_T": 0.028,
18 | "img_size": 32,
19 | "grad_clip": 1.,
20 | "device": "cuda:0",
21 | "w": 1.8,
22 | "save_dir": "./CheckpointsCondition/",
23 | "training_load_weight": None,
24 | "test_load_weight": "ckpt_63_.pt",
25 | "sampled_dir": "./SampledImgs/",
26 | "sampledNoisyImgName": "NoisyGuidenceImgs.png",
27 | "sampledImgName": "SampledGuidenceImgs.png",
28 | "nrow": 8
29 | }
30 | if model_config is not None:
31 | modelConfig = model_config
32 | if modelConfig["state"] == "train":
33 | train(modelConfig)
34 | else:
35 | eval(modelConfig)
36 |
37 |
38 | if __name__ == '__main__':
39 | main()
40 |
--------------------------------------------------------------------------------
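
A minimal usage sketch (not part of the repository) for MainCondition.py. Since main() swaps in the passed dict wholesale rather than merging it, an override must supply every key; here only "state" and the guidance weight "w" differ from the defaults, assuming a trained checkpoint ckpt_63_.pt exists under ./CheckpointsCondition/:

```python
from MainCondition import main

# Full config dict: main() replaces the defaults entirely, so all keys are required.
config = {
    "state": "eval",               # sample instead of train
    "epoch": 70,
    "batch_size": 80,
    "T": 500,
    "channel": 128,
    "channel_mult": [1, 2, 2, 2],
    "num_res_blocks": 2,
    "dropout": 0.15,
    "lr": 1e-4,
    "multiplier": 2.5,
    "beta_1": 1e-4,
    "beta_T": 0.028,
    "img_size": 32,
    "grad_clip": 1.,
    "device": "cuda:0",
    "w": 3.0,                      # stronger guidance than the default 1.8
    "save_dir": "./CheckpointsCondition/",
    "training_load_weight": None,
    "test_load_weight": "ckpt_63_.pt",
    "sampled_dir": "./SampledImgs/",
    "sampledNoisyImgName": "NoisyGuidenceImgs.png",
    "sampledImgName": "SampledGuidenceImgs.png",
    "nrow": 8,
}
main(config)
```
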
/Main.py:
--------------------------------------------------------------------------------
1 | from Diffusion.Train import train, eval
2 |
3 |
4 | def main(model_config=None):
5 | modelConfig = {
6 | "state": "train", # or eval
7 | "epoch": 200,
8 | "batch_size": 80,
9 | "T": 1000,
10 | "channel": 128,
11 | "channel_mult": [1, 2, 3, 4],
12 | "attn": [2],
13 | "num_res_blocks": 2,
14 | "dropout": 0.15,
15 | "lr": 1e-4,
16 | "multiplier": 2.,
17 | "beta_1": 1e-4,
18 | "beta_T": 0.02,
19 | "img_size": 32,
20 | "grad_clip": 1.,
21 | "device": "cuda:0", ### MAKE SURE YOU HAVE A GPU !!!
22 | "training_load_weight": None,
23 | "save_weight_dir": "./Checkpoints/",
24 | "test_load_weight": "ckpt_199_.pt",
25 | "sampled_dir": "./SampledImgs/",
26 | "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
27 | "sampledImgName": "SampledNoGuidenceImgs.png",
28 | "nrow": 8
29 | }
30 | if model_config is not None:
31 | modelConfig = model_config
32 | if modelConfig["state"] == "train":
33 | train(modelConfig)
34 | else:
35 | eval(modelConfig)
36 |
37 |
38 | if __name__ == '__main__':
39 | main()
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DenoisingDiffusionProbabilityModel
2 | This may be the simplest implementation of DDPM. I trained it on the CIFAR-10 dataset. Links to the pretrained weights (trained on CIFAR-10) are in Issue 2.
3 |
4 | If you really want to know more about the framework of DDPM, I have listed some papers, in reading order, in the closed Issue 1.
5 |
6 |
7 | Lil' Log is also a very nice blog for understanding the details of DDPM; the reference is
8 | https://lilianweng.github.io/posts/2021-07-11-diffusion-models/
9 |
10 |
11 | **HOW TO RUN**
12 | * 1. Run Main.py to train the UNet on the CIFAR-10 dataset. After training, set "state" to "eval" in the model config to watch the amazing sampling process of DDPM.
13 | * 2. Run MainCondition.py to train the UNet on CIFAR-10. This is DDPM + classifier-free guidance.
14 |
15 | Some generated images are shown below:
16 |
17 | * 1. DDPM without guidance:
18 |
19 | 
20 |
21 | * 2. DDPM + classifier-free guidance:
22 |
23 | 
24 |
--------------------------------------------------------------------------------
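
The two entry points can also be driven from Python directly; a sketch (not part of the repository), equivalent to running `python Main.py` and `python MainCondition.py` from the repo root:

```python
import Main
import MainCondition

Main.main()           # trains the unconditional DDPM on CIFAR-10
MainCondition.main()  # trains the DDPM + classifier-free guidance variant
# After training, switch "state" to "eval" in the respective config (or pass
# a full config dict to main(), as in the sketch after MainCondition.py) to
# generate the sample grids shown above.
```
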
/Scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 |
3 | class GradualWarmupScheduler(_LRScheduler):
4 | def __init__(self, optimizer, multiplier, warm_epoch, after_scheduler=None):
5 | self.multiplier = multiplier
6 | self.total_epoch = warm_epoch
7 | self.after_scheduler = after_scheduler
8 | self.finished = False
9 | self.last_epoch = None
10 | self.base_lrs = None
11 | super().__init__(optimizer)
12 |
13 | def get_lr(self):
14 | if self.last_epoch > self.total_epoch:
15 | if self.after_scheduler:
16 | if not self.finished:
17 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
18 | self.finished = True
19 | return self.after_scheduler.get_lr()
20 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
21 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
22 |
23 |
24 | def step(self, epoch=None, metrics=None):
25 | if self.finished and self.after_scheduler:
26 | if epoch is None:
27 | self.after_scheduler.step(None)
28 | else:
29 | self.after_scheduler.step(epoch - self.total_epoch)
30 | else:
31 | return super(GradualWarmupScheduler, self).step(epoch)
--------------------------------------------------------------------------------
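
A minimal sketch (not part of the repository) of how GradualWarmupScheduler is wired up in Train.py: the learning rate ramps linearly from base_lr toward multiplier * base_lr over the first warm_epoch epochs, after which the after_scheduler takes over with the scaled base rates:

```python
import torch
from torch import optim

from Scheduler import GradualWarmupScheduler

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for model parameters
optimizer = optim.AdamW(params, lr=1e-4)
cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=0)
warmup = GradualWarmupScheduler(
    optimizer, multiplier=2., warm_epoch=10, after_scheduler=cosine)

for epoch in range(100):
    # ... one epoch of optimizer.step() calls would go here ...
    warmup.step()  # stepped once per epoch, exactly as in Train.py
    print(epoch, optimizer.param_groups[0]["lr"])
```
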
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/Diffusion/Diffusion.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import numpy as np
7 |
8 |
9 | def extract(v, t, x_shape):
10 | """
11 | Extract some coefficients at specified timesteps, then reshape to
12 | [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
13 | """
14 | device = t.device
15 | out = torch.gather(v, index=t, dim=0).float().to(device)
16 | return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
17 |
18 |
19 | class GaussianDiffusionTrainer(nn.Module):
20 | def __init__(self, model, beta_1, beta_T, T):
21 | super().__init__()
22 |
23 | self.model = model
24 | self.T = T
25 |
26 | self.register_buffer(
27 | 'betas', torch.linspace(beta_1, beta_T, T).double())
28 | alphas = 1. - self.betas
29 | alphas_bar = torch.cumprod(alphas, dim=0)
30 |
31 | # calculations for diffusion q(x_t | x_{t-1}) and others
32 | self.register_buffer(
33 | 'sqrt_alphas_bar', torch.sqrt(alphas_bar))
34 | self.register_buffer(
35 | 'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))
36 |
37 | def forward(self, x_0):
38 | """
39 | Algorithm 1.
40 | """
41 | t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
42 | noise = torch.randn_like(x_0)
43 | x_t = (
44 | extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
45 | extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
46 | loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
47 | return loss
48 |
49 |
50 | class GaussianDiffusionSampler(nn.Module):
51 | def __init__(self, model, beta_1, beta_T, T):
52 | super().__init__()
53 |
54 | self.model = model
55 | self.T = T
56 |
57 | self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
58 | alphas = 1. - self.betas
59 | alphas_bar = torch.cumprod(alphas, dim=0)
60 | alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
61 |
62 | self.register_buffer('coeff1', torch.sqrt(1. / alphas))
63 | self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
64 |
65 | self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))
66 |
67 | def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
68 | assert x_t.shape == eps.shape
69 | return (
70 | extract(self.coeff1, t, x_t.shape) * x_t -
71 | extract(self.coeff2, t, x_t.shape) * eps
72 | )
73 |
74 | def p_mean_variance(self, x_t, t):
75 | # the variance is fixed rather than learned: the first entry uses the posterior variance, the remaining entries use beta_t
76 | var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
77 | var = extract(var, t, x_t.shape)
78 |
79 | eps = self.model(x_t, t)
80 | xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
81 |
82 | return xt_prev_mean, var
83 |
84 | def forward(self, x_T):
85 | """
86 | Algorithm 2.
87 | """
88 | x_t = x_T
89 | for time_step in reversed(range(self.T)):
90 | print(time_step)
91 | t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
92 | mean, var = self.p_mean_variance(x_t=x_t, t=t)
93 | # no noise when t == 0
94 | if time_step > 0:
95 | noise = torch.randn_like(x_t)
96 | else:
97 | noise = 0
98 | x_t = mean + torch.sqrt(var) * noise
99 | assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
100 | x_0 = x_t
101 | return torch.clip(x_0, -1, 1)
102 |
103 |
104 |
--------------------------------------------------------------------------------
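
The trainer implements the closed-form forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps and trains the model to predict eps; the sampler inverts it one timestep at a time via coeff1 and coeff2. A shape-level smoke test (not part of the repository) with a stand-in noise predictor:

```python
import torch

from Diffusion.Diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer

class DummyEps(torch.nn.Module):
    """Stand-in for the UNet: always predicts zero noise."""
    def forward(self, x_t, t):
        return torch.zeros_like(x_t)

model = DummyEps()
trainer = GaussianDiffusionTrainer(model, beta_1=1e-4, beta_T=0.02, T=10)
sampler = GaussianDiffusionSampler(model, beta_1=1e-4, beta_T=0.02, T=10)

x_0 = torch.randn(4, 3, 32, 32)
loss = trainer(x_0)                         # per-pixel MSE, reduction='none'
x_gen = sampler(torch.randn(4, 3, 32, 32))  # prints each timestep as it runs
print(loss.shape, x_gen.shape)              # both torch.Size([4, 3, 32, 32])
```
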
/DiffusionFreeGuidence/DiffusionCondition.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import numpy as np
7 |
8 |
9 | def extract(v, t, x_shape):
10 | """
11 | Extract some coefficients at specified timesteps, then reshape to
12 | [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
13 | """
14 | device = t.device
15 | out = torch.gather(v, index=t, dim=0).float().to(device)
16 | return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
17 |
18 |
19 | class GaussianDiffusionTrainer(nn.Module):
20 | def __init__(self, model, beta_1, beta_T, T):
21 | super().__init__()
22 |
23 | self.model = model
24 | self.T = T
25 |
26 | self.register_buffer(
27 | 'betas', torch.linspace(beta_1, beta_T, T).double())
28 | alphas = 1. - self.betas
29 | alphas_bar = torch.cumprod(alphas, dim=0)
30 |
31 | # calculations for diffusion q(x_t | x_{t-1}) and others
32 | self.register_buffer(
33 | 'sqrt_alphas_bar', torch.sqrt(alphas_bar))
34 | self.register_buffer(
35 | 'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))
36 |
37 | def forward(self, x_0, labels):
38 | """
39 | Algorithm 1.
40 | """
41 | t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
42 | noise = torch.randn_like(x_0)
43 | x_t = extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 + \
44 | extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise
45 | loss = F.mse_loss(self.model(x_t, t, labels), noise, reduction='none')
46 | return loss
47 |
48 |
49 | class GaussianDiffusionSampler(nn.Module):
50 | def __init__(self, model, beta_1, beta_T, T, w = 0.):
51 | super().__init__()
52 |
53 | self.model = model
54 | self.T = T
55 | ### In the classifier-free guidance paper, w is the key that controls the guidance strength.
56 | ### w = 0 with label = 0 means no guidance.
57 | ### w > 0 with label > 0 enables guidance; the guidance gets stronger as w increases.
58 | self.w = w
59 |
60 | self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
61 | alphas = 1. - self.betas
62 | alphas_bar = torch.cumprod(alphas, dim=0)
63 | alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
64 | self.register_buffer('coeff1', torch.sqrt(1. / alphas))
65 | self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
66 | self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))
67 |
68 | def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
69 | assert x_t.shape == eps.shape
70 | return extract(self.coeff1, t, x_t.shape) * x_t - extract(self.coeff2, t, x_t.shape) * eps
71 |
72 | def p_mean_variance(self, x_t, t, labels):
73 | # the variance is fixed rather than learned: the first entry uses the posterior variance, the remaining entries use beta_t
74 | var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
75 | var = extract(var, t, x_t.shape)
76 | eps = self.model(x_t, t, labels)
77 | nonEps = self.model(x_t, t, torch.zeros_like(labels).to(labels.device))
78 | eps = (1. + self.w) * eps - self.w * nonEps
79 | xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
80 | return xt_prev_mean, var
81 |
82 | def forward(self, x_T, labels):
83 | """
84 | Algorithm 2.
85 | """
86 | x_t = x_T
87 | for time_step in reversed(range(self.T)):
88 | print(time_step)
89 | t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
90 | mean, var = self.p_mean_variance(x_t=x_t, t=t, labels=labels)
91 | if time_step > 0:
92 | noise = torch.randn_like(x_t)
93 | else:
94 | noise = 0
95 | x_t = mean + torch.sqrt(var) * noise
96 | assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
97 | x_0 = x_t
98 | return torch.clip(x_0, -1, 1)
99 |
100 |
101 |
--------------------------------------------------------------------------------
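
p_mean_variance combines the conditional and unconditional predictions exactly as in the classifier-free guidance paper: eps_hat = (1 + w) * eps(x_t, t, y) - w * eps(x_t, t, 0). A tiny numeric illustration (not part of the repository):

```python
import torch

w = 1.8  # the default guidance weight from MainCondition.py
eps_cond = torch.tensor([1.0, 2.0])    # model output given the real label
eps_uncond = torch.tensor([0.5, 0.5])  # model output given the null label 0
eps_hat = (1. + w) * eps_cond - w * eps_uncond
print(eps_hat)  # tensor([1.9000, 4.7000])

# With w = 0 the combination reduces to the plain conditional prediction:
assert torch.allclose((1. + 0.) * eps_cond - 0. * eps_uncond, eps_cond)
```
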
/Diffusion/Train.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from typing import Dict
4 |
5 | import torch
6 | import torch.optim as optim
7 | from tqdm import tqdm
8 | from torch.utils.data import DataLoader
9 | from torchvision import transforms
10 | from torchvision.datasets import CIFAR10
11 | from torchvision.utils import save_image
12 |
13 | from Diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer
14 | from Diffusion.Model import UNet
15 | from Scheduler import GradualWarmupScheduler
16 |
17 |
18 | def train(modelConfig: Dict):
19 | device = torch.device(modelConfig["device"])
20 | # dataset
21 | dataset = CIFAR10(
22 | root='./CIFAR10', train=True, download=True,
23 | transform=transforms.Compose([
24 | transforms.RandomHorizontalFlip(),
25 | transforms.ToTensor(),
26 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
27 | ]))
28 | dataloader = DataLoader(
29 | dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True)
30 |
31 | # model setup
32 | net_model = UNet(T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
33 | num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
34 | if modelConfig["training_load_weight"] is not None:
35 | net_model.load_state_dict(torch.load(os.path.join(
36 | modelConfig["save_weight_dir"], modelConfig["training_load_weight"]), map_location=device))
37 | optimizer = torch.optim.AdamW(
38 | net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
39 | cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
40 | optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
41 | warmUpScheduler = GradualWarmupScheduler(
42 | optimizer=optimizer, multiplier=modelConfig["multiplier"], warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)
43 | trainer = GaussianDiffusionTrainer(
44 | net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
45 |
46 | # start training
47 | for e in range(modelConfig["epoch"]):
48 | with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
49 | for images, labels in tqdmDataLoader:
50 | # train
51 | optimizer.zero_grad()
52 | x_0 = images.to(device)
53 | loss = trainer(x_0).sum() / 1000.
54 | loss.backward()
55 | torch.nn.utils.clip_grad_norm_(
56 | net_model.parameters(), modelConfig["grad_clip"])
57 | optimizer.step()
58 | tqdmDataLoader.set_postfix(ordered_dict={
59 | "epoch": e,
60 | "loss: ": loss.item(),
61 | "img shape: ": x_0.shape,
62 | "LR": optimizer.state_dict()['param_groups'][0]["lr"]
63 | })
64 | warmUpScheduler.step()
65 | torch.save(net_model.state_dict(), os.path.join(
66 | modelConfig["save_weight_dir"], 'ckpt_' + str(e) + "_.pt"))
67 |
68 |
69 | def eval(modelConfig: Dict):
70 | # load model and evaluate
71 | with torch.no_grad():
72 | device = torch.device(modelConfig["device"])
73 | model = UNet(T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
74 | num_res_blocks=modelConfig["num_res_blocks"], dropout=0.)
75 | ckpt = torch.load(os.path.join(
76 | modelConfig["save_weight_dir"], modelConfig["test_load_weight"]), map_location=device)
77 | model.load_state_dict(ckpt)
78 | print("model weights loaded.")
79 | model.eval()
80 | sampler = GaussianDiffusionSampler(
81 | model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
82 | # Sampled from standard normal distribution
83 | noisyImage = torch.randn(
84 | size=[modelConfig["batch_size"], 3, 32, 32], device=device)
85 | saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
86 | save_image(saveNoisy, os.path.join(
87 | modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"])
88 | sampledImgs = sampler(noisyImage)
89 | sampledImgs = sampledImgs * 0.5 + 0.5 # [0 ~ 1]
90 | save_image(sampledImgs, os.path.join(
91 | modelConfig["sampled_dir"], modelConfig["sampledImgName"]), nrow=modelConfig["nrow"])
--------------------------------------------------------------------------------
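
The dataset transform maps pixel values from [0, 1] to [-1, 1], and eval() undoes this with * 0.5 + 0.5 before save_image. A quick round-trip check (not part of the repository):

```python
import torch
from torchvision import transforms

norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
img = torch.rand(3, 32, 32)      # a fake image in [0, 1]
normed = norm(img)               # now in [-1, 1], the range the model sees
restored = normed * 0.5 + 0.5    # back to [0, 1] for saving
assert torch.allclose(restored, img, atol=1e-6)
```
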
/DiffusionFreeGuidence/TrainCondition.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import os
4 | from typing import Dict
5 | import numpy as np
6 |
7 | import torch
8 | import torch.optim as optim
9 | from tqdm import tqdm
10 | from torch.utils.data import DataLoader
11 | from torchvision import transforms
12 | from torchvision.datasets import CIFAR10
13 | from torchvision.utils import save_image
14 |
15 | from DiffusionFreeGuidence.DiffusionCondition import GaussianDiffusionSampler, GaussianDiffusionTrainer
16 | from DiffusionFreeGuidence.ModelCondition import UNet
17 | from Scheduler import GradualWarmupScheduler
18 |
19 |
20 | def train(modelConfig: Dict):
21 | device = torch.device(modelConfig["device"])
22 | # dataset
23 | dataset = CIFAR10(
24 | root='./CIFAR10', train=True, download=True,
25 | transform=transforms.Compose([
26 | transforms.ToTensor(),
27 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
28 | ]))
29 | dataloader = DataLoader(
30 | dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True)
31 |
32 | # model setup
33 | net_model = UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"],
34 | num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
35 | if modelConfig["training_load_weight"] is not None:
36 | net_model.load_state_dict(torch.load(os.path.join(
37 | modelConfig["save_dir"], modelConfig["training_load_weight"]), map_location=device), strict=False)
38 | print("Model weights loaded.")
39 | optimizer = torch.optim.AdamW(
40 | net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
41 | cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
42 | optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
43 | warmUpScheduler = GradualWarmupScheduler(optimizer=optimizer, multiplier=modelConfig["multiplier"],
44 | warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)
45 | trainer = GaussianDiffusionTrainer(
46 | net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
47 |
48 | # start training
49 | for e in range(modelConfig["epoch"]):
50 | with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
51 | for images, labels in tqdmDataLoader:
52 | # train
53 | b = images.shape[0]
54 | optimizer.zero_grad()
55 | x_0 = images.to(device)
56 | labels = labels.to(device) + 1  # shift class ids by 1; label 0 is reserved for the null (unconditional) class
57 | if np.random.rand() < 0.1:  # drop the condition ~10% of the time, as classifier-free guidance training requires
58 | labels = torch.zeros_like(labels).to(device)
59 | loss = trainer(x_0, labels).sum() / b ** 2.
60 | loss.backward()
61 | torch.nn.utils.clip_grad_norm_(
62 | net_model.parameters(), modelConfig["grad_clip"])
63 | optimizer.step()
64 | tqdmDataLoader.set_postfix(ordered_dict={
65 | "epoch": e,
66 | "loss: ": loss.item(),
67 | "img shape: ": x_0.shape,
68 | "LR": optimizer.state_dict()['param_groups'][0]["lr"]
69 | })
70 | warmUpScheduler.step()
71 | torch.save(net_model.state_dict(), os.path.join(
72 | modelConfig["save_dir"], 'ckpt_' + str(e) + "_.pt"))
73 |
74 |
75 | def eval(modelConfig: Dict):
76 | device = torch.device(modelConfig["device"])
77 | # load model and evaluate
78 | with torch.no_grad():
79 | step = int(modelConfig["batch_size"] // 10)
80 | labelList = []
81 | k = 0
82 | for i in range(1, modelConfig["batch_size"] + 1):
83 | labelList.append(torch.ones(size=[1]).long() * k)
84 | if i % step == 0:
85 | if k < 10 - 1:
86 | k += 1
87 | labels = torch.cat(labelList, dim=0).long().to(device) + 1
88 | print("labels: ", labels)
89 | model = UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"],
90 | num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
91 | ckpt = torch.load(os.path.join(
92 | modelConfig["save_dir"], modelConfig["test_load_weight"]), map_location=device)
93 | model.load_state_dict(ckpt)
94 | print("model weights loaded.")
95 | model.eval()
96 | sampler = GaussianDiffusionSampler(
97 | model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"], w=modelConfig["w"]).to(device)
98 | # Sampled from standard normal distribution
99 | noisyImage = torch.randn(
100 | size=[modelConfig["batch_size"], 3, modelConfig["img_size"], modelConfig["img_size"]], device=device)
101 | saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
102 | save_image(saveNoisy, os.path.join(
103 | modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"])
104 | sampledImgs = sampler(noisyImage, labels)
105 | sampledImgs = sampledImgs * 0.5 + 0.5 # [0 ~ 1]
106 | print(sampledImgs)
107 | save_image(sampledImgs, os.path.join(
108 | modelConfig["sampled_dir"], modelConfig["sampledImgName"]), nrow=modelConfig["nrow"])
--------------------------------------------------------------------------------
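
The label-building loop in eval() gives each of the 10 CIFAR-10 classes a contiguous block of batch_size // 10 samples, then shifts everything by +1 because label 0 is reserved for the null (unconditional) class. A compact equivalent (not part of the repository), assuming batch_size is a multiple of 10:

```python
import torch

batch_size = 80
labels = torch.arange(10).repeat_interleave(batch_size // 10) + 1
print(labels)  # tensor([1, 1, ..., 1, 2, 2, ..., 9, 10, 10, ..., 10])
```
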
/DiffusionFreeGuidence/ModelCondition.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import math
4 |
5 | import torch
6 | from torch import nn
7 | from torch.nn import init
8 | from torch.nn import functional as F
9 |
10 |
11 | def drop_connect(x, drop_ratio):
12 | keep_ratio = 1.0 - drop_ratio
13 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
14 | mask.bernoulli_(p=keep_ratio)
15 | x.div_(keep_ratio)
16 | x.mul_(mask)
17 | return x
18 |
19 | class Swish(nn.Module):
20 | def forward(self, x):
21 | return x * torch.sigmoid(x)
22 |
23 |
24 | class TimeEmbedding(nn.Module):
25 | def __init__(self, T, d_model, dim):
26 | assert d_model % 2 == 0
27 | super().__init__()
28 | emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
29 | emb = torch.exp(-emb)
30 | pos = torch.arange(T).float()
31 | emb = pos[:, None] * emb[None, :]
32 | assert list(emb.shape) == [T, d_model // 2]
33 | emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
34 | assert list(emb.shape) == [T, d_model // 2, 2]
35 | emb = emb.view(T, d_model)
36 |
37 | self.timembedding = nn.Sequential(
38 | nn.Embedding.from_pretrained(emb, freeze=False),
39 | nn.Linear(d_model, dim),
40 | Swish(),
41 | nn.Linear(dim, dim),
42 | )
43 |
44 | def forward(self, t):
45 | emb = self.timembedding(t)
46 | return emb
47 |
48 |
49 | class ConditionalEmbedding(nn.Module):
50 | def __init__(self, num_labels, d_model, dim):
51 | assert d_model % 2 == 0
52 | super().__init__()
53 | self.condEmbedding = nn.Sequential(
54 | nn.Embedding(num_embeddings=num_labels + 1, embedding_dim=d_model, padding_idx=0),
55 | nn.Linear(d_model, dim),
56 | Swish(),
57 | nn.Linear(dim, dim),
58 | )
59 |
60 | def forward(self, t):
61 | emb = self.condEmbedding(t)
62 | return emb
63 |
64 |
65 | class DownSample(nn.Module):
66 | def __init__(self, in_ch):
67 | super().__init__()
68 | self.c1 = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
69 | self.c2 = nn.Conv2d(in_ch, in_ch, 5, stride=2, padding=2)
70 |
71 | def forward(self, x, temb, cemb):
72 | x = self.c1(x) + self.c2(x)
73 | return x
74 |
75 |
76 | class UpSample(nn.Module):
77 | def __init__(self, in_ch):
78 | super().__init__()
79 | self.c = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
80 | self.t = nn.ConvTranspose2d(in_ch, in_ch, 5, 2, 2, 1)
81 |
82 | def forward(self, x, temb, cemb):
83 | _, _, H, W = x.shape
84 | x = self.t(x)
85 | x = self.c(x)
86 | return x
87 |
88 |
89 | class AttnBlock(nn.Module):
90 | def __init__(self, in_ch):
91 | super().__init__()
92 | self.group_norm = nn.GroupNorm(32, in_ch)
93 | self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
94 | self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
95 | self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
96 | self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
97 |
98 | def forward(self, x):
99 | B, C, H, W = x.shape
100 | h = self.group_norm(x)
101 | q = self.proj_q(h)
102 | k = self.proj_k(h)
103 | v = self.proj_v(h)
104 |
105 | q = q.permute(0, 2, 3, 1).view(B, H * W, C)
106 | k = k.view(B, C, H * W)
107 | w = torch.bmm(q, k) * (int(C) ** (-0.5))
108 | assert list(w.shape) == [B, H * W, H * W]
109 | w = F.softmax(w, dim=-1)
110 |
111 | v = v.permute(0, 2, 3, 1).view(B, H * W, C)
112 | h = torch.bmm(w, v)
113 | assert list(h.shape) == [B, H * W, C]
114 | h = h.view(B, H, W, C).permute(0, 3, 1, 2)
115 | h = self.proj(h)
116 |
117 | return x + h
118 |
119 |
120 |
121 | class ResBlock(nn.Module):
122 | def __init__(self, in_ch, out_ch, tdim, dropout, attn=True):
123 | super().__init__()
124 | self.block1 = nn.Sequential(
125 | nn.GroupNorm(32, in_ch),
126 | Swish(),
127 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
128 | )
129 | self.temb_proj = nn.Sequential(
130 | Swish(),
131 | nn.Linear(tdim, out_ch),
132 | )
133 | self.cond_proj = nn.Sequential(
134 | Swish(),
135 | nn.Linear(tdim, out_ch),
136 | )
137 | self.block2 = nn.Sequential(
138 | nn.GroupNorm(32, out_ch),
139 | Swish(),
140 | nn.Dropout(dropout),
141 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
142 | )
143 | if in_ch != out_ch:
144 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
145 | else:
146 | self.shortcut = nn.Identity()
147 | if attn:
148 | self.attn = AttnBlock(out_ch)
149 | else:
150 | self.attn = nn.Identity()
151 |
152 |
153 | def forward(self, x, temb, labels):
154 | h = self.block1(x)
155 | h += self.temb_proj(temb)[:, :, None, None]
156 | h += self.cond_proj(labels)[:, :, None, None]
157 | h = self.block2(h)
158 |
159 | h = h + self.shortcut(x)
160 | h = self.attn(h)
161 | return h
162 |
163 |
164 | class UNet(nn.Module):
165 | def __init__(self, T, num_labels, ch, ch_mult, num_res_blocks, dropout):
166 | super().__init__()
167 | tdim = ch * 4
168 | self.time_embedding = TimeEmbedding(T, ch, tdim)
169 | self.cond_embedding = ConditionalEmbedding(num_labels, ch, tdim)
170 | self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
171 | self.downblocks = nn.ModuleList()
172 | chs = [ch]  # record output channels during downsampling for the upsampling skip connections
173 | now_ch = ch
174 | for i, mult in enumerate(ch_mult):
175 | out_ch = ch * mult
176 | for _ in range(num_res_blocks):
177 | self.downblocks.append(ResBlock(in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout))
178 | now_ch = out_ch
179 | chs.append(now_ch)
180 | if i != len(ch_mult) - 1:
181 | self.downblocks.append(DownSample(now_ch))
182 | chs.append(now_ch)
183 |
184 | self.middleblocks = nn.ModuleList([
185 | ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
186 | ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
187 | ])
188 |
189 | self.upblocks = nn.ModuleList()
190 | for i, mult in reversed(list(enumerate(ch_mult))):
191 | out_ch = ch * mult
192 | for _ in range(num_res_blocks + 1):
193 | self.upblocks.append(ResBlock(in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=False))
194 | now_ch = out_ch
195 | if i != 0:
196 | self.upblocks.append(UpSample(now_ch))
197 | assert len(chs) == 0
198 |
199 | self.tail = nn.Sequential(
200 | nn.GroupNorm(32, now_ch),
201 | Swish(),
202 | nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
203 | )
204 |
205 |
206 | def forward(self, x, t, labels):
207 | # Timestep embedding
208 | temb = self.time_embedding(t)
209 | cemb = self.cond_embedding(labels)
210 | # Downsampling
211 | h = self.head(x)
212 | hs = [h]
213 | for layer in self.downblocks:
214 | h = layer(h, temb, cemb)
215 | hs.append(h)
216 | # Middle
217 | for layer in self.middleblocks:
218 | h = layer(h, temb, cemb)
219 | # Upsampling
220 | for layer in self.upblocks:
221 | if isinstance(layer, ResBlock):
222 | h = torch.cat([h, hs.pop()], dim=1)
223 | h = layer(h, temb, cemb)
224 | h = self.tail(h)
225 |
226 | assert len(hs) == 0
227 | return h
228 |
229 |
230 | if __name__ == '__main__':
231 | batch_size = 8
232 | model = UNet(
233 | T=1000, num_labels=10, ch=128, ch_mult=[1, 2, 2, 2],
234 | num_res_blocks=2, dropout=0.1)
235 | x = torch.randn(batch_size, 3, 32, 32)
236 | t = torch.randint(1000, size=[batch_size])
237 | labels = torch.randint(10, size=[batch_size])
238 | # resB = ResBlock(128, 256, 64, 0.1)
239 | # x = torch.randn(batch_size, 128, 32, 32)
240 | # t = torch.randn(batch_size, 64)
241 | # labels = torch.randn(batch_size, 64)
242 | # y = resB(x, t, labels)
243 | y = model(x, t, labels)
244 | print(y.shape)
245 |
246 |
--------------------------------------------------------------------------------
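
ConditionalEmbedding passes padding_idx=0 to nn.Embedding, which pins the embedding row for index 0 to the zero vector and masks its gradient; this is what makes label 0 a stable "no condition" token. A small check (not part of the repository):

```python
import torch

from DiffusionFreeGuidence.ModelCondition import ConditionalEmbedding

emb = ConditionalEmbedding(num_labels=10, d_model=128, dim=512)
null = emb(torch.zeros(4, dtype=torch.long))       # the unconditional label 0
cond = emb(torch.full((4,), 3, dtype=torch.long))  # some shifted class label
print(null.shape, cond.shape)                      # torch.Size([4, 512]) twice

# The raw embedding row for index 0 is exactly zero and stays zero (its
# gradient is masked); the MLP on top then maps every unconditional sample
# to one shared learned vector.
assert torch.all(emb.condEmbedding[0].weight[0] == 0)
```
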
/Diffusion/Model.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import math
4 | import torch
5 | from torch import nn
6 | from torch.nn import init
7 | from torch.nn import functional as F
8 |
9 |
10 | class Swish(nn.Module):
11 | def forward(self, x):
12 | return x * torch.sigmoid(x)
13 |
14 |
15 | class TimeEmbedding(nn.Module):
16 | def __init__(self, T, d_model, dim):
17 | assert d_model % 2 == 0
18 | super().__init__()
19 | emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
20 | emb = torch.exp(-emb)
21 | pos = torch.arange(T).float()
22 | emb = pos[:, None] * emb[None, :]
23 | assert list(emb.shape) == [T, d_model // 2]
24 | emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
25 | assert list(emb.shape) == [T, d_model // 2, 2]
26 | emb = emb.view(T, d_model)
27 |
28 | self.timembedding = nn.Sequential(
29 | nn.Embedding.from_pretrained(emb),
30 | nn.Linear(d_model, dim),
31 | Swish(),
32 | nn.Linear(dim, dim),
33 | )
34 | self.initialize()
35 |
36 | def initialize(self):
37 | for module in self.modules():
38 | if isinstance(module, nn.Linear):
39 | init.xavier_uniform_(module.weight)
40 | init.zeros_(module.bias)
41 |
42 | def forward(self, t):
43 | emb = self.timembedding(t)
44 | return emb
45 |
46 |
47 | class DownSample(nn.Module):
48 | def __init__(self, in_ch):
49 | super().__init__()
50 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
51 | self.initialize()
52 |
53 | def initialize(self):
54 | init.xavier_uniform_(self.main.weight)
55 | init.zeros_(self.main.bias)
56 |
57 | def forward(self, x, temb):
58 | x = self.main(x)
59 | return x
60 |
61 |
62 | class UpSample(nn.Module):
63 | def __init__(self, in_ch):
64 | super().__init__()
65 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
66 | self.initialize()
67 |
68 | def initialize(self):
69 | init.xavier_uniform_(self.main.weight)
70 | init.zeros_(self.main.bias)
71 |
72 | def forward(self, x, temb):
73 | _, _, H, W = x.shape
74 | x = F.interpolate(
75 | x, scale_factor=2, mode='nearest')
76 | x = self.main(x)
77 | return x
78 |
79 |
80 | class AttnBlock(nn.Module):
81 | def __init__(self, in_ch):
82 | super().__init__()
83 | self.group_norm = nn.GroupNorm(32, in_ch)
84 | self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
85 | self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
86 | self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
87 | self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
88 | self.initialize()
89 |
90 | def initialize(self):
91 | for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
92 | init.xavier_uniform_(module.weight)
93 | init.zeros_(module.bias)
94 | init.xavier_uniform_(self.proj.weight, gain=1e-5)
95 |
96 | def forward(self, x):
97 | B, C, H, W = x.shape
98 | h = self.group_norm(x)
99 | q = self.proj_q(h)
100 | k = self.proj_k(h)
101 | v = self.proj_v(h)
102 |
103 | q = q.permute(0, 2, 3, 1).view(B, H * W, C)
104 | k = k.view(B, C, H * W)
105 | w = torch.bmm(q, k) * (int(C) ** (-0.5))
106 | assert list(w.shape) == [B, H * W, H * W]
107 | w = F.softmax(w, dim=-1)
108 |
109 | v = v.permute(0, 2, 3, 1).view(B, H * W, C)
110 | h = torch.bmm(w, v)
111 | assert list(h.shape) == [B, H * W, C]
112 | h = h.view(B, H, W, C).permute(0, 3, 1, 2)
113 | h = self.proj(h)
114 |
115 | return x + h
116 |
117 |
118 | class ResBlock(nn.Module):
119 | def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
120 | super().__init__()
121 | self.block1 = nn.Sequential(
122 | nn.GroupNorm(32, in_ch),
123 | Swish(),
124 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
125 | )
126 | self.temb_proj = nn.Sequential(
127 | Swish(),
128 | nn.Linear(tdim, out_ch),
129 | )
130 | self.block2 = nn.Sequential(
131 | nn.GroupNorm(32, out_ch),
132 | Swish(),
133 | nn.Dropout(dropout),
134 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
135 | )
136 | if in_ch != out_ch:
137 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
138 | else:
139 | self.shortcut = nn.Identity()
140 | if attn:
141 | self.attn = AttnBlock(out_ch)
142 | else:
143 | self.attn = nn.Identity()
144 | self.initialize()
145 |
146 | def initialize(self):
147 | for module in self.modules():
148 | if isinstance(module, (nn.Conv2d, nn.Linear)):
149 | init.xavier_uniform_(module.weight)
150 | init.zeros_(module.bias)
151 | init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)
152 |
153 | def forward(self, x, temb):
154 | h = self.block1(x)
155 | h += self.temb_proj(temb)[:, :, None, None]
156 | h = self.block2(h)
157 |
158 | h = h + self.shortcut(x)
159 | h = self.attn(h)
160 | return h
161 |
162 |
163 | class UNet(nn.Module):
164 | def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
165 | super().__init__()
166 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
167 | tdim = ch * 4
168 | self.time_embedding = TimeEmbedding(T, ch, tdim)
169 |
170 | self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
171 | self.downblocks = nn.ModuleList()
172 | chs = [ch]  # record output channels during downsampling for the upsampling skip connections
173 | now_ch = ch
174 | for i, mult in enumerate(ch_mult):
175 | out_ch = ch * mult
176 | for _ in range(num_res_blocks):
177 | self.downblocks.append(ResBlock(
178 | in_ch=now_ch, out_ch=out_ch, tdim=tdim,
179 | dropout=dropout, attn=(i in attn)))
180 | now_ch = out_ch
181 | chs.append(now_ch)
182 | if i != len(ch_mult) - 1:
183 | self.downblocks.append(DownSample(now_ch))
184 | chs.append(now_ch)
185 |
186 | self.middleblocks = nn.ModuleList([
187 | ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
188 | ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
189 | ])
190 |
191 | self.upblocks = nn.ModuleList()
192 | for i, mult in reversed(list(enumerate(ch_mult))):
193 | out_ch = ch * mult
194 | for _ in range(num_res_blocks + 1):
195 | self.upblocks.append(ResBlock(
196 | in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
197 | dropout=dropout, attn=(i in attn)))
198 | now_ch = out_ch
199 | if i != 0:
200 | self.upblocks.append(UpSample(now_ch))
201 | assert len(chs) == 0
202 |
203 | self.tail = nn.Sequential(
204 | nn.GroupNorm(32, now_ch),
205 | Swish(),
206 | nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
207 | )
208 | self.initialize()
209 |
210 | def initialize(self):
211 | init.xavier_uniform_(self.head.weight)
212 | init.zeros_(self.head.bias)
213 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
214 | init.zeros_(self.tail[-1].bias)
215 |
216 | def forward(self, x, t):
217 | # Timestep embedding
218 | temb = self.time_embedding(t)
219 | # Downsampling
220 | h = self.head(x)
221 | hs = [h]
222 | for layer in self.downblocks:
223 | h = layer(h, temb)
224 | hs.append(h)
225 | # Middle
226 | for layer in self.middleblocks:
227 | h = layer(h, temb)
228 | # Upsampling
229 | for layer in self.upblocks:
230 | if isinstance(layer, ResBlock):
231 | h = torch.cat([h, hs.pop()], dim=1)
232 | h = layer(h, temb)
233 | h = self.tail(h)
234 |
235 | assert len(hs) == 0
236 | return h
237 |
238 |
239 | if __name__ == '__main__':
240 | batch_size = 8
241 | model = UNet(
242 | T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
243 | num_res_blocks=2, dropout=0.1)
244 | x = torch.randn(batch_size, 3, 32, 32)
245 | t = torch.randint(1000, (batch_size, ))
246 | y = model(x, t)
247 | print(y.shape)
248 |
249 |
--------------------------------------------------------------------------------
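
TimeEmbedding precomputes the standard sinusoidal position table, interleaving sin and cos per frequency so that emb[t, 2i] = sin(t / 10000^(2i/d_model)) and emb[t, 2i+1] = cos(t / 10000^(2i/d_model)). A sanity-check sketch (not part of the repository) reproducing the construction from __init__:

```python
import math

import torch

T, d_model = 1000, 128
freqs = torch.exp(-torch.arange(0, d_model, step=2) / d_model * math.log(10000))
angles = torch.arange(T).float()[:, None] * freqs[None, :]
table = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1).view(T, d_model)

t, i = 17, 5
expected = math.sin(t / 10000 ** (2 * i / d_model))
assert abs(table[t, 2 * i].item() - expected) < 1e-5
```
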