├── CIFAR10
│   └── blank.txt
├── Checkpoints
│   └── blank.txt
├── CheckpointsCondition
│   └── blank.txt
├── SampledImgs
│   ├── blank.txt
│   ├── noisy.png
│   ├── 104_sampled_64.png
│   ├── NoisyGuidenceImgs.png
│   ├── NoisyNoGuidenceImgs.png
│   ├── SampledGuidenceImgs.png
│   └── SampledNoGuidenceImgs.png
├── Diffusion
│   ├── __init__.py
│   ├── Train.py
│   ├── Diffusion.py
│   └── Model.py
├── mp_imgs
│   └── cameraman.tif
├── DiffusionFreeGuidence
│   ├── __init__.py
│   ├── DiffusionCondition.py
│   ├── TrainCondition.py
│   └── ModelCondition.py
├── LICENSE
├── Main.py
├── MainCondition.py
├── README.md
├── Scheduler.py
└── .gitignore
/CIFAR10/blank.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Checkpoints/blank.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/CheckpointsCondition/blank.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/SampledImgs/blank.txt:
--------------------------------------------------------------------------------
1 | blank
2 |
--------------------------------------------------------------------------------
/Diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from .Diffusion import *
2 | from .Model import *
3 | from .Train import *
--------------------------------------------------------------------------------
/SampledImgs/noisy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minipuding/DenoisingDiffusionProbabilityModel-ddpm-/HEAD/SampledImgs/noisy.png
--------------------------------------------------------------------------------
/mp_imgs/cameraman.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minipuding/DenoisingDiffusionProbabilityModel-ddpm-/HEAD/mp_imgs/cameraman.tif
--------------------------------------------------------------------------------
/DiffusionFreeGuidence/__init__.py:
--------------------------------------------------------------------------------
1 | from .DiffusionCondition import *
2 | from .ModelCondition import *
3 | from .TrainCondition import *
--------------------------------------------------------------------------------
/SampledImgs/104_sampled_64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minipuding/DenoisingDiffusionProbabilityModel-ddpm-/HEAD/SampledImgs/104_sampled_64.png
--------------------------------------------------------------------------------
/SampledImgs/NoisyGuidenceImgs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minipuding/DenoisingDiffusionProbabilityModel-ddpm-/HEAD/SampledImgs/NoisyGuidenceImgs.png
--------------------------------------------------------------------------------
/SampledImgs/NoisyNoGuidenceImgs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minipuding/DenoisingDiffusionProbabilityModel-ddpm-/HEAD/SampledImgs/NoisyNoGuidenceImgs.png
--------------------------------------------------------------------------------
/SampledImgs/SampledGuidenceImgs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minipuding/DenoisingDiffusionProbabilityModel-ddpm-/HEAD/SampledImgs/SampledGuidenceImgs.png
--------------------------------------------------------------------------------
/SampledImgs/SampledNoGuidenceImgs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minipuding/DenoisingDiffusionProbabilityModel-ddpm-/HEAD/SampledImgs/SampledNoGuidenceImgs.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 ZOUbohao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Main.py:
--------------------------------------------------------------------------------
1 | from Diffusion.Train import train, eval
2 |
3 |
4 | def main(model_config = None):
5 | modelConfig = {
6 | "state": "train", # or eval
7 | "epoch": 200,
8 | "batch_size": 80,
9 | "T": 1000,
10 | "channel": 128,
11 | "channel_mult": [1, 2, 3, 4],
12 | "attn": [2],
13 | "num_res_blocks": 2,
14 | "dropout": 0.15,
15 | "lr": 1e-4,
16 | "multiplier": 2.,
17 | "beta_1": 1e-4,
18 | "beta_T": 0.02,
19 | "img_size": 32,
20 | "grad_clip": 1.,
21 | "device": "cuda:0",
22 | "training_load_weight": None,
23 | "save_weight_dir": "./Checkpoints/",
24 | "test_load_weight": "ckpt_199_.pt",
25 | "sampled_dir": "./SampledImgs/",
26 | "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
27 | "sampledImgName": "SampledNoGuidenceImgs.png",
28 | "nrow": 8
29 | }
30 | if model_config is not None:
31 | modelConfig = model_config
32 | if modelConfig["state"] == "train":
33 | train(modelConfig)
34 | else:
35 | eval(modelConfig)
36 |
37 |
38 | if __name__ == '__main__':
39 | main()
40 |
--------------------------------------------------------------------------------
/MainCondition.py:
--------------------------------------------------------------------------------
1 | from DiffusionFreeGuidence.TrainCondition import train, eval
2 |
3 |
4 | def main(model_config=None):
5 | modelConfig = {
6 | "state": "train", # or eval
7 | "epoch": 70,
8 | "batch_size": 80,
9 | "T": 500,
10 | "channel": 128,
11 | "channel_mult": [1, 2, 2, 2],
12 | "num_res_blocks": 2,
13 | "dropout": 0.15,
14 | "lr": 1e-4,
15 | "multiplier": 2.5,
16 | "beta_1": 1e-4,
17 | "beta_T": 0.028,
18 | "img_size": 32,
19 | "grad_clip": 1.,
20 | "device": "cuda:0",
21 | "w": 1.8,
22 | "save_dir": "./CheckpointsCondition/",
23 | "training_load_weight": None,
24 | "test_load_weight": "ckpt_63_.pt",
25 | "sampled_dir": "./SampledImgs/",
26 | "sampledNoisyImgName": "NoisyGuidenceImgs.png",
27 | "sampledImgName": "SampledGuidenceImgs.png",
28 | "nrow": 8
29 | }
30 | if model_config is not None:
31 | modelConfig = model_config
32 | if modelConfig["state"] == "train":
33 | train(modelConfig)
34 | else:
35 | eval(modelConfig)
36 |
37 |
38 | if __name__ == '__main__':
39 | main()
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DenoisingDiffusionProbabilityModel
2 | This may be the simplest implementation of DDPM. I trained it on the CIFAR-10 dataset. Links to the pretrained weights (trained on CIFAR-10) are in Issue 2.
3 |
4 | If you really want to know more about the framework of DDPM, I have listed some papers to read, in order, in the closed Issue 1.
5 |
6 |
7 | Lil' Log is also a very nice blog for understanding the details of DDPM; the reference is
8 | https://lilianweng.github.io/posts/2021-07-11-diffusion-models/
9 |
10 |
11 | **HOW TO RUN**
12 | * 1. You can run Main.py to train the UNet on the CIFAR-10 dataset. After training, set the parameters in the model config to see the amazing sampling process of DDPM.
13 | * 2. You can run MainCondition.py to train the UNet on CIFAR-10. This is DDPM + classifier-free guidance; a config-override sketch follows below.
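
To run in eval mode without editing the file, you can instead pass a complete config dict to `main`; note that `main()` replaces its default dict wholesale, so every key must be present. A minimal sketch (it assumes training has already produced `./Checkpoints/ckpt_199_.pt`):

```python
from Main import main

config = {
    "state": "eval",  # switch from "train" to "eval" to sample
    "epoch": 200, "batch_size": 80, "T": 1000,
    "channel": 128, "channel_mult": [1, 2, 3, 4], "attn": [2],
    "num_res_blocks": 2, "dropout": 0.15, "lr": 1e-4, "multiplier": 2.,
    "beta_1": 1e-4, "beta_T": 0.02, "img_size": 32, "grad_clip": 1.,
    "device": "cuda:0", "training_load_weight": None,
    "save_weight_dir": "./Checkpoints/", "test_load_weight": "ckpt_199_.pt",
    "sampled_dir": "./SampledImgs/",
    "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
    "sampledImgName": "SampledNoGuidenceImgs.png", "nrow": 8,
}
main(config)
```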
14 |
15 | Some generated images are shown below:
16 |
17 | * 1. DDPM without guidance:
18 |
19 | 
20 |
21 | * 2. DDPM + classifier-free guidance:
22 |
23 | 
24 |
--------------------------------------------------------------------------------
/Scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 |
3 | class GradualWarmupScheduler(_LRScheduler):
4 | def __init__(self, optimizer, multiplier, warm_epoch, after_scheduler=None):
5 | self.multiplier = multiplier
6 | self.total_epoch = warm_epoch
7 | self.after_scheduler = after_scheduler
8 | self.finished = False
9 | self.last_epoch = None
10 | self.base_lrs = None
11 | super().__init__(optimizer)
12 |
13 | def get_lr(self):
14 | if self.last_epoch > self.total_epoch:
15 | if self.after_scheduler:
16 | if not self.finished:
17 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
18 | self.finished = True
19 | return self.after_scheduler.get_lr()
20 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
21 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
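# A worked example of the warm-up ramp above (hypothetical numbers): with base lr 1e-4,
# multiplier=2 and warm_epoch=20, epoch 0 gives lr=1e-4, epoch 10 gives 1.5e-4, and epoch 20
# gives 2e-4; after that ``after_scheduler`` (a CosineAnnealingLR in Train.py) continues from 2e-4.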
22 |
23 |
24 | def step(self, epoch=None, metrics=None):
25 | if self.finished and self.after_scheduler:
26 | if epoch is None:
27 | self.after_scheduler.step(None)
28 | else:
29 | self.after_scheduler.step(epoch - self.total_epoch)
30 | else:
31 | return super(GradualWarmupScheduler, self).step(epoch)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/DiffusionFreeGuidence/DiffusionCondition.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import numpy as np
6 |
7 |
8 | def extract(v, t, x_shape):
9 | """
10 | Extract some coefficients at specified timesteps, then reshape to
11 | [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
12 | """
13 | device = t.device
14 | out = torch.gather(v, index=t, dim=0).float().to(device)
15 | return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
16 |
17 |
18 | class GaussianDiffusionTrainer(nn.Module):
19 | """
20 | The forward (noising) process is almost identical to ``GaussianDiffusionTrainer`` in ``Diffusion.Diffusion.py``.
21 | The difference is the model input: besides ``x_t`` and ``t``, the condition ``labels`` must also be passed in.
22 | """
23 | def __init__(self, model, beta_1, beta_T, T):
24 | super().__init__()
25 |
26 | self.model = model
27 | self.T = T
28 |
29 | self.register_buffer(
30 | 'betas', torch.linspace(beta_1, beta_T, T).double())
31 | alphas = 1. - self.betas
32 | alphas_bar = torch.cumprod(alphas, dim=0)
33 |
34 | # calculations for diffusion q(x_t | x_{t-1}) and others
35 | self.register_buffer(
36 | 'sqrt_alphas_bar', torch.sqrt(alphas_bar))
37 | self.register_buffer(
38 | 'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))
39 |
40 | def forward(self, x_0, labels):
41 | """
42 | Algorithm 1.
43 | """
44 | t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device)
45 | noise = torch.randn_like(x_0)
46 | x_t = extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 + \
47 | extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise
48 | loss = F.mse_loss(self.model(x_t, t, labels), noise, reduction='none')  # the difference: the model input additionally includes ``labels``
49 | return loss
50 |
51 |
52 | class GaussianDiffusionSampler(nn.Module):
53 | """
54 | The reverse diffusion process is mostly the same as ``GaussianDiffusionSampler`` in ``Diffusion.Diffusion.py``,
55 | so only the differences are described here.
56 | """
57 | def __init__(self, model, beta_1, beta_T, T, w=0.):
58 | super().__init__()
59 |
60 | self.model = model
61 | self.T = T
62 | # In the classifier-free guidance paper, w is the key that controls the guidance.
63 | # w = 0 with label = 0 means no guidance.
64 | # w > 0 with label > 0 means guidance; the guidance grows stronger as w gets bigger.
65 | # Difference 1: a weight ``w`` is passed at construction time to control how strong the conditioning is
66 | self.w = w
67 |
68 | self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
69 | alphas = 1. - self.betas
70 | alphas_bar = torch.cumprod(alphas, dim=0)
71 | alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
72 | self.register_buffer('coeff1', torch.sqrt(1. / alphas))
73 | self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
74 | self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))
75 |
76 | def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
77 | assert x_t.shape == eps.shape
78 | return (
79 | extract(self.coeff1, t, x_t.shape) * x_t -
80 | extract(self.coeff2, t, x_t.shape) * eps
81 | )
82 |
83 | def p_mean_variance(self, x_t, t, labels):
84 | # below: only log_variance is used in the KL computations
85 | var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
86 | var = extract(var, t, x_t.shape)
87 |
88 | # Difference 2: at inference time the model output is computed both with the condition and without it (the zero label),
89 | # and the two outputs are combined with the weight ``self.w`` to form the final prediction
90 | eps = self.model(x_t, t, labels)
91 | nonEps = self.model(x_t, t, torch.zeros_like(labels).to(labels.device))
92 | # see Eq. (6) of the original paper
93 | eps = (1. + self.w) * eps - self.w * nonEps
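# A concrete illustration with the ``w`` used in MainCondition.py (w=1.8):
#   eps = 2.8 * eps_conditional - 1.8 * eps_unconditional,
# i.e. the prediction is pushed away from the unconditional output, towards the conditional one.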
94 | xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
95 | return xt_prev_mean, var
96 |
97 | def forward(self, x_T, labels):
98 | """
99 | Algorithm 2.
100 | """
101 | x_t = x_T
102 | for time_step in reversed(range(self.T)):
103 | print(time_step)
104 | t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
105 | # apart from the extra ``labels`` input, everything matches the plain Diffusion Model
106 | mean, var = self.p_mean_variance(x_t=x_t, t=t, labels=labels)
107 | if time_step > 0:
108 | noise = torch.randn_like(x_t)
109 | else:
110 | noise = 0
111 | x_t = mean + torch.sqrt(var) * noise
112 | assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
113 | x_0 = x_t
114 | return torch.clip(x_0, -1, 1)
115 |
--------------------------------------------------------------------------------
/Diffusion/Train.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from typing import Dict
4 |
5 | import torch
6 | import torch.optim as optim
7 | from tqdm import tqdm
8 | from torch.utils.data import DataLoader
9 | from torchvision import transforms
10 | from torchvision.datasets import CIFAR10
11 | from torchvision.utils import save_image
12 |
13 | from Diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer
14 | from Diffusion.Model import UNet
15 | from Scheduler import GradualWarmupScheduler
16 |
17 |
18 | def train(modelConfig: Dict):
19 | device = torch.device(modelConfig["device"])
20 | # dataset
21 | dataset = CIFAR10(
22 | root='./CIFAR10', train=True, download=True,
23 | transform=transforms.Compose([
24 | transforms.RandomHorizontalFlip(),
25 | transforms.ToTensor(),
26 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
27 | ]))
28 | dataloader = DataLoader(
29 | dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True)
30 |
31 | # model setup
32 | net_model = UNet(T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
33 | num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
34 | if modelConfig["training_load_weight"] is not None:
35 | net_model.load_state_dict(torch.load(os.path.join(
36 | modelConfig["save_weight_dir"], modelConfig["training_load_weight"]), map_location=device))
37 | optimizer = torch.optim.AdamW(
38 | net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
39 | # cosine learning-rate decay over half a cosine period, from ``lr`` down to 0
40 | cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
41 | optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
42 | # gradual warm-up scheduler: the LR ramps from 0 up to multiplier * lr over the first 1/10 of the epochs, then follows ``cosineScheduler``
43 | warmUpScheduler = GradualWarmupScheduler(
44 | optimizer=optimizer, multiplier=modelConfig["multiplier"], warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)
45 | # instantiate the training wrapper
46 | trainer = GaussianDiffusionTrainer(
47 | net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
48 |
49 | # start training
50 | for e in range(modelConfig["epoch"]):
51 | with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
52 | for images, labels in tqdmDataLoader:
53 | # train
54 | optimizer.zero_grad()  # clear gradients from the previous step
55 | x_0 = images.to(device)  # move the input images to the compute device
56 | loss = trainer(x_0).sum() / 1000.  # forward pass and loss computation
57 | loss.backward()  # backpropagate the gradients
58 | torch.nn.utils.clip_grad_norm_(
59 | net_model.parameters(), modelConfig["grad_clip"])  # clip gradients to prevent explosion
60 | optimizer.step()  # update the parameters
61 | tqdmDataLoader.set_postfix(ordered_dict={
62 | "epoch": e,
63 | "loss: ": loss.item(),
64 | "img shape: ": x_0.shape,
65 | "LR": optimizer.state_dict()['param_groups'][0]["lr"]
66 | })  # configure the progress-bar display
67 | warmUpScheduler.step()  # advance the LR scheduler
68 | torch.save(net_model.state_dict(), os.path.join(
69 | modelConfig["save_weight_dir"], 'ckpt_' + str(e) + "_.pt"))  # save the model
70 |
71 |
72 | def eval(modelConfig: Dict):
73 | # load model and evaluate
74 | with torch.no_grad():
75 | # build and load the model
76 | device = torch.device(modelConfig["device"])
77 | model = UNet(T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
78 | num_res_blocks=modelConfig["num_res_blocks"], dropout=0.)
79 | ckpt = torch.load(os.path.join(
80 | modelConfig["save_weight_dir"], modelConfig["test_load_weight"]), map_location=device)
81 | model.load_state_dict(ckpt)
82 | print("model load weight done.")
83 | # instantiate the reverse-diffusion sampler
84 | model.eval()
85 | sampler = GaussianDiffusionSampler(
86 | model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
87 | # Sampled from standard normal distribution
88 | # generate random Gaussian noise images and save them
89 | noisyImage = torch.randn(
90 | size=[modelConfig["batch_size"], 3, 32, 32], device=device)
91 | saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
92 | save_image(saveNoisy, os.path.join(
93 | modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"])
94 | # run reverse diffusion and save the output images
95 | sampledImgs = sampler(noisyImage)
96 | sampledImgs = sampledImgs * 0.5 + 0.5 # [0 ~ 1]
97 | save_image(sampledImgs, os.path.join(
98 | modelConfig["sampled_dir"], modelConfig["sampledImgName"]), nrow=modelConfig["nrow"])
--------------------------------------------------------------------------------
/DiffusionFreeGuidence/TrainCondition.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict
3 | import numpy as np
4 |
5 | import torch
6 | import torch.optim as optim
7 | from tqdm import tqdm
8 | from torch.utils.data import DataLoader
9 | from torchvision import transforms
10 | from torchvision.datasets import CIFAR10
11 | from torchvision.utils import save_image
12 |
13 | from DiffusionFreeGuidence.DiffusionCondition import GaussianDiffusionSampler, GaussianDiffusionTrainer
14 | from DiffusionFreeGuidence.ModelCondition import UNet
15 | from Scheduler import GradualWarmupScheduler
16 |
17 |
18 | def train(modelConfig: Dict):
19 | device = torch.device(modelConfig["device"])
20 | # dataset
21 | dataset = CIFAR10(
22 | root='./CIFAR10', train=True, download=True,
23 | transform=transforms.Compose([
24 | transforms.ToTensor(),
25 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
26 | ]))
27 | dataloader = DataLoader(
28 | dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True)
29 |
30 | # model setup
31 | # compared with the unconditional case, the model takes an extra input ``num_labels``: the number of classes in the dataset; CIFAR10 has 10 classes, so num_labels=10
32 | net_model = UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"],
33 | num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
34 | if modelConfig["training_load_weight"] is not None:
35 | net_model.load_state_dict(torch.load(os.path.join(
36 | modelConfig["save_dir"], modelConfig["training_load_weight"]), map_location=device), strict=False)
37 | print("Model weight load down.")
38 | optimizer = torch.optim.AdamW(
39 | net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
40 | # cosine learning-rate decay over half a cosine period, from ``lr`` down to 0
41 | cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
42 | optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
43 | # gradual warm-up scheduler: the LR ramps from 0 up to multiplier * lr over the first 1/10 of the epochs, then follows ``cosineScheduler``
44 | warmUpScheduler = GradualWarmupScheduler(optimizer=optimizer, multiplier=modelConfig["multiplier"],
45 | warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)
46 | # instantiate the training wrapper
47 | trainer = GaussianDiffusionTrainer(
48 | net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
49 |
50 | # start training
51 | for e in range(modelConfig["epoch"]):
52 | with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
53 | for images, labels in tqdmDataLoader:
54 | # train
55 | b = images.shape[0]  # batch size
56 | optimizer.zero_grad()  # clear gradients from the previous step
57 | x_0 = images.to(device)  # move the input images to the compute device
58 | labels = labels.to(device) + 1  # move the labels (the condition) to the device; the +1 here matches
59 | # the ``ConditionalEmbedding`` in ``ModelCondition.py`` (index 0 is reserved for the unconditional case)
60 | if np.random.rand() < 0.1:
61 | labels = torch.zeros_like(labels).to(device)  # with 10% probability, replace the condition with 0
62 | loss = trainer(x_0, labels).sum() / b ** 2.  # forward pass and loss computation
63 | loss.backward()  # backpropagate the gradients
64 | torch.nn.utils.clip_grad_norm_(
65 | net_model.parameters(), modelConfig["grad_clip"])  # clip gradients to prevent explosion
66 | optimizer.step()  # update the parameters
67 | tqdmDataLoader.set_postfix(ordered_dict={
68 | "epoch": e,
69 | "loss: ": loss.item(),
70 | "img shape: ": x_0.shape,
71 | "LR": optimizer.state_dict()['param_groups'][0]["lr"]
72 | })  # configure the progress-bar display
73 | warmUpScheduler.step()  # advance the LR scheduler
74 | torch.save(net_model.state_dict(), os.path.join(
75 | modelConfig["save_dir"], 'ckpt_' + str(e) + "_.pt"))  # save the model
76 |
77 |
78 | def eval(modelConfig: Dict):
79 | device = torch.device(modelConfig["device"])
80 | # load model and evaluate
81 | with torch.no_grad():
82 | # This block builds the labels (the condition) that guide generation:
83 | # the batch is split evenly across the 10 classes, e.g. with batch_size=50, step=5,
84 | # and the for loop yields labelList = [0,0,0,0,0,1,1,1,1,1,2,...,9,9,9,9,9].
85 | # Finally, 1 is added to every label, for the same reason as before (note this assumes batch_size is a multiple of 10, so that step > 0).
86 | step = int(modelConfig["batch_size"] // 10)
87 | labelList = []
88 | k = 0
89 | for i in range(1, modelConfig["batch_size"] + 1):
90 | labelList.append(torch.ones(size=[1]).long() * k)
91 | if i % step == 0:
92 | if k < 10 - 1:
93 | k += 1
94 | labels = torch.cat(labelList, dim=0).long().to(device) + 1
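# For example, with the defaults in MainCondition.py (batch_size=80): step=8 and
# labels = [1]*8 + [2]*8 + ... + [10]*8, i.e. eight samples per CIFAR-10 class.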
95 | print("labels: ", labels)
96 | # build and load the model
97 | model = UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"],
98 | num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
99 | ckpt = torch.load(os.path.join(
100 | modelConfig["save_dir"], modelConfig["test_load_weight"]), map_location=device)
101 | model.load_state_dict(ckpt)
102 | print("model load weight done.")
103 | # instantiate the reverse-diffusion sampler
104 | model.eval()
105 | sampler = GaussianDiffusionSampler(
106 | model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"], w=modelConfig["w"]).to(device)
107 | # Sampled from standard normal distribution
108 | # generate random Gaussian noise images and save them
109 | noisyImage = torch.randn(
110 | size=[modelConfig["batch_size"], 3, modelConfig["img_size"], modelConfig["img_size"]], device=device)
111 | saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
112 | save_image(saveNoisy, os.path.join(
113 | modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"])
114 | # run reverse diffusion and save the output images
115 | sampledImgs = sampler(noisyImage, labels)
116 | sampledImgs = sampledImgs * 0.5 + 0.5 # [0 ~ 1]
117 | print(sampledImgs)
118 | save_image(sampledImgs, os.path.join(
119 | modelConfig["sampled_dir"], modelConfig["sampledImgName"]), nrow=modelConfig["nrow"])
120 |
--------------------------------------------------------------------------------
/Diffusion/Diffusion.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | # ``extract`` picks the entries of the sequence v indexed by t, then reshapes them to broadcast against the input x
8 | def extract(v, t, x_shape):
9 | """
10 | Extract some coefficients at specified timesteps, then reshape to
11 | [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
12 | """
13 | device = t.device
14 | # For the usage of ``torch.gather``, see e.g. the first comment under https://zhuanlan.zhihu.com/p/352877584
15 | # In every call in this project v is one-dimensional, so this is plain index lookup, equivalent to v[t] with t of shape [batch_size]
16 | out = torch.gather(v, index=t, dim=0).float().to(device)
17 | # reshape the gathered values to [batch_size, 1, 1, ...], with as many dims as x_shape
18 | return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
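# A tiny illustration of ``extract`` (hypothetical values):
#   v = torch.tensor([0.1, 0.2, 0.3, 0.4])   # one coefficient per timestep
#   t = torch.tensor([3, 0])                 # a batch of two timesteps
#   extract(v, t, (2, 3, 32, 32))            # -> shape [2, 1, 1, 1], values [[[[0.4]]], [[[0.1]]]]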
19 |
20 |
21 | # ``GaussianDiffusionTrainer`` implements the forward (noising) process & the training step of the Diffusion Model
22 | class GaussianDiffusionTrainer(nn.Module):
23 | def __init__(self, model, beta_1, beta_T, T):
24 | """
25 | Initialize the forward-process model.
26 | Args:
27 | model: the backbone network, typically U-Net + attention
28 | beta_1: the starting value of beta; 1e-4 in this project
29 | beta_T: the value of beta at t=T; 0.02 in this project
30 | T: the number of diffusion steps; 1000 in this project
31 | """
32 | super().__init__()
33 | # store the arguments
34 | self.model = model
35 | self.T = T
36 |
37 | # T evenly spaced beta values between beta_1 and beta_T, registered as a buffer (accessible later via ``self.betas``)
38 | self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
39 | # following the formulas, alphas = 1 - betas
40 | alphas = 1. - self.betas
41 | # following the formulas, compute the running product of the alphas, stored as alphas_bar
42 | # ``torch.cumprod`` multiplies every element of a sequence with all the elements before it, returning a sequence of the same length
43 | # For example:
44 | # a = torch.tensor([2,3,1,4])
45 | # b = torch.cumprod(a, dim=0) equals torch.tensor([2, 2*3, 2*3*1, 2*3*1*4]) = torch.tensor([2, 6, 6, 24])
46 | alphas_bar = torch.cumprod(alphas, dim=0)
47 |
48 | # calculations for diffusion q(x_t | x_{t-1}) and others
49 | # following the formulas, sqrt(alphas_bar) and sqrt(1 - alphas_bar) serve as the mean and std coefficients of the forward diffusion; store them as buffers,
50 | # accessible via ``self.sqrt_alphas_bar`` and ``self.sqrt_one_minus_alphas_bar``
51 | self.register_buffer(
52 | 'sqrt_alphas_bar', torch.sqrt(alphas_bar))
53 | self.register_buffer(
54 | 'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))
55 |
56 | def forward(self, x_0):
57 | """
58 | Algorithm 1.
59 | """
60 | # draw batch_size random timesteps from 0~T-1
61 | t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
62 | # reparameterization trick: draw standard Gaussian noise, then sample indirectly by multiplying by the std and adding the mean
63 | noise = torch.randn_like(x_0)
64 | x_t = (
65 | extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
66 | extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
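# The two lines above implement the closed form of the forward process:
#   q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)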
67 | # one denoising-prediction step: the model should recover the injected noise, i.e. the z_t in the formulas
68 | loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
69 | return loss
70 |
71 |
72 | # ``GaussianDiffusionSampler`` implements the reverse process & inference of the Diffusion Model
73 | class GaussianDiffusionSampler(nn.Module):
74 | def __init__(self, model, beta_1, beta_T, T):
75 | """
76 | All arguments have the same meaning as in ``GaussianDiffusionTrainer`` (the forward process).
77 | """
78 | super().__init__()
79 |
80 | self.model = model
81 | self.T = T
82 |
83 | # betas, alphas and alphas_bar are obtained exactly as in the forward process
84 | self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
85 | alphas = 1. - self.betas
86 | alphas_bar = torch.cumprod(alphas, dim=0)
87 | # a convenience tensor for later computations, effectively alphas_bar_{t-1}
88 | alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T] # prepend 1 to alphas_bar and drop the last element, shifting the sequence one step back
89 |
90 | # following the formulas, the coefficients needed to compute the reverse-process mean are stored as coeff1 and coeff2
91 | self.register_buffer('coeff1', torch.sqrt(1. / alphas))
92 | self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
93 |
94 | # following the formulas, compute the reverse-process (posterior) variance
95 | self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))
96 |
97 | def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
98 | """
99 | Compute the mean of the reverse conditional distribution q(x_{t-1}|x_t).
100 | Args:
101 | x_t: the image at the current step
102 | t: the current timestep
103 | eps: the noise predicted by the model, i.e. z_t
104 | Returns:
105 | the mean of x_{t-1}: mean = coeff1 * x_t - coeff2 * eps
106 | """
107 | assert x_t.shape == eps.shape
108 | return (
109 | extract(self.coeff1, t, x_t.shape) * x_t -
110 | extract(self.coeff2, t, x_t.shape) * eps
111 | )
112 |
113 | def p_mean_variance(self, x_t, t):
114 | """
115 | Compute the mean and variance of the reverse conditional distribution q(x_{t-1}|x_t).
116 | Args:
117 | x_t: the image at the current step
118 | t: the current timestep
119 | Returns:
120 | xt_prev_mean: the mean
121 | var: the variance
122 | """
123 | # below: only log_variance is used in the KL computations
124 | # I am slightly unsure about this step: why replace most of the precomputed reverse-process variance with betas?
125 | # My guess: the posterior variance ``posterior_var`` is just betas times (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t);
126 | # if 1 - alpha_bar_t were extremely close to 0, the denominator would produce nan,
127 | # while the ratio (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) is very close to 1, so betas approximate the posterior variance well.
128 | # At t = 1, however, (1 - alpha_bar_0) / (1 - alpha_bar_1) is not yet very close to 1, so that value is kept,
129 | # hence the concatenation ``torch.cat([self.posterior_var[1:2], self.betas[1:]])`` below
130 | var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
131 | var = extract(var, t, x_t.shape)
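# For reference, the exact posterior variance from the DDPM paper (Ho et al. 2020) is
#   tilde_beta_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t),
# and the paper reports that fixing sigma_t^2 = beta_t instead gives similar sample quality.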
132 |
133 | # the model forward pass predicts eps (i.e. z_t)
134 | eps = self.model(x_t, t)
135 | # compute the mean
136 | xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
137 |
138 | return xt_prev_mean, var
139 |
140 | def forward(self, x_T):
141 | """
142 | Algorithm 2.
143 | """
144 | # reverse diffusion: iterate from x_T down to x_0
145 | x_t = x_T
146 | for time_step in reversed(range(self.T)):
147 | print(time_step)
148 | # t = [1, 1, ...] * time_step, of length batch_size
149 | t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
150 | # compute the mean and variance of q(x_{t-1}|x_t)
151 | mean, var = self.p_mean_variance(x_t=x_t, t=t)
152 | # no noise when t == 0
153 | # the Gaussian noise of the last step is set to 0 (I think keeping it would not matter much either: in this project the variance at t=0 is already tiny)
154 | if time_step > 0:
155 | noise = torch.randn_like(x_t)
156 | else:
157 | noise = 0
158 | x_t = mean + torch.sqrt(var) * noise
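# This is the reverse update of Algorithm 2:
#   x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps) + sigma_t * z
# with coeff1 = 1/sqrt(alpha_t) and coeff2 = coeff1 * (1 - alpha_t)/sqrt(1 - alpha_bar_t).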
159 | assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
160 | x_0 = x_t
161 | # ``torch.clip(x_0, -1, 1)`` clamps the values of x_0 to [-1, 1], truncating anything outside
162 | return torch.clip(x_0, -1, 1)
163 |
164 |
165 |
--------------------------------------------------------------------------------
/DiffusionFreeGuidence/ModelCondition.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import init
6 | from torch.nn import functional as F
7 |
8 |
9 | def drop_connect(x, drop_ratio):
10 | """
11 | This function is never used anywhere in the project, so its behavior is not discussed here.
12 | """
13 | keep_ratio = 1.0 - drop_ratio
14 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
15 | mask.bernoulli_(p=keep_ratio)
16 | x.div_(keep_ratio)
17 | x.mul_(mask)
18 | return x
19 |
20 |
21 | class Swish(nn.Module):
22 | def forward(self, x):
23 | return x * torch.sigmoid(x)
24 |
25 |
26 | class TimeEmbedding(nn.Module):
27 | """
28 | Identical to ``TimeEmbedding`` in ``Diffusion.Model``.
29 | """
30 | def __init__(self, T, d_model, dim):
31 | assert d_model % 2 == 0
32 | super().__init__()
33 | emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
34 | emb = torch.exp(-emb)
35 | pos = torch.arange(T).float()
36 | emb = pos[:, None] * emb[None, :]
37 | assert list(emb.shape) == [T, d_model // 2]
38 | emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
39 | assert list(emb.shape) == [T, d_model // 2, 2]
40 | emb = emb.view(T, d_model)
41 |
42 | self.timembedding = nn.Sequential(
43 | nn.Embedding.from_pretrained(emb, freeze=False),
44 | nn.Linear(d_model, dim),
45 | Swish(),
46 | nn.Linear(dim, dim),
47 | )
48 |
49 | def forward(self, t):
50 | emb = self.timembedding(t)
51 | return emb
52 |
53 |
54 | class ConditionalEmbedding(nn.Module):
55 | """
56 | A conditional-embedding module that encodes the condition into an embedding.
57 | Apart from the embedding initialization, it is identical to the time embedding.
58 | """
59 | def __init__(self, num_labels, d_model, dim):
60 | assert d_model % 2 == 0
61 | super().__init__()
62 | # Note a detail in the embedding initialization: ``num_embeddings=num_labels+1``, i.e. 10+1=11.
63 | # The condition considered here is the CIFAR10 label with 10 classes (0~9), so in principle 10 embeddings would suffice,
64 | # but the ``unconditional`` case needs an embedding of its own, represented by ``0`` in this project;
65 | # the 10 class indices are therefore shifted up by one to 1~10 (done in ``TrainCondition.py``), for 11 embeddings in total
66 | self.condEmbedding = nn.Sequential(
67 | nn.Embedding(num_embeddings=num_labels + 1, embedding_dim=d_model, padding_idx=0),
68 | nn.Linear(d_model, dim),
69 | Swish(),
70 | nn.Linear(dim, dim),
71 | )
72 |
73 | def forward(self, labels):
74 | cemb = self.condEmbedding(labels)
75 | return cemb
76 |
77 |
78 | class DownSample(nn.Module):
79 | """
80 | Compared with ``Diffusion.Model.DownSample``, this downsampling module adds a 5x5, stride-2 conv layer;
81 | the forward pass sums the 3x3 and 5x5 conv outputs. I do not know why; perhaps to fuse information from more scales.
82 | The original paper (Sec. 4 Experiments, lines 3~4) states that the model matches the one in "Diffusion Models Beat GANs on Image Synthesis",
83 | but that paper's source code does not use this kind of downsampling, only a plain 3x3 conv or avg_pool
84 | """
85 | def __init__(self, in_ch):
86 | super().__init__()
87 | self.c1 = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
88 | self.c2 = nn.Conv2d(in_ch, in_ch, 5, stride=2, padding=2)
89 |
90 | def forward(self, x, temb, cemb):
91 | x = self.c1(x) + self.c2(x)
92 | return x
93 |
94 |
95 | class UpSample(nn.Module):
96 | """
97 | Compared with ``Diffusion.Model.UpSample``, this upsampling module uses a transposed conv instead of nearest-neighbor interpolation.
98 | As with ``DownSample``, I do not know the reason; both approaches should work, it is a matter of taste.
99 | """
100 | def __init__(self, in_ch):
101 | super().__init__()
102 | self.c = nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, padding=1)
103 | self.t = nn.ConvTranspose2d(in_ch, in_ch, kernel_size=5, stride=2, padding=2, output_padding=1)
104 |
105 | def forward(self, x, temb, cemb):
106 | _, _, H, W = x.shape
107 | x = self.t(x)
108 | x = self.c(x)
109 | return x
110 |
111 |
112 | class AttnBlock(nn.Module):
113 | """
114 | Identical to ``AttnBlock`` in ``Diffusion.Model``.
115 | """
116 | def __init__(self, in_ch):
117 | super().__init__()
118 | self.group_norm = nn.GroupNorm(32, in_ch)
119 | self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
120 | self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
121 | self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
122 | self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
123 |
124 | def forward(self, x):
125 | B, C, H, W = x.shape
126 | h = self.group_norm(x)
127 | q = self.proj_q(h)
128 | k = self.proj_k(h)
129 | v = self.proj_v(h)
130 |
131 | q = q.permute(0, 2, 3, 1).view(B, H * W, C)
132 | k = k.view(B, C, H * W)
133 | w = torch.bmm(q, k) * (int(C) ** (-0.5))
134 | assert list(w.shape) == [B, H * W, H * W]
135 | w = F.softmax(w, dim=-1)
136 |
137 | v = v.permute(0, 2, 3, 1).view(B, H * W, C)
138 | h = torch.bmm(w, v)
139 | assert list(h.shape) == [B, H * W, C]
140 | h = h.view(B, H, W, C).permute(0, 3, 1, 2)
141 | h = self.proj(h)
142 |
143 | return x + h
144 |
145 |
146 | class ResBlock(nn.Module):
147 | """
148 | Compared with ``Diffusion.Model.ResBlock``, this residual block adds a condition projection layer ``self.cond_proj``;
149 | it can simply be viewed as another time embedding: the two take part in training in exactly the same way
150 | """
151 | def __init__(self, in_ch, out_ch, tdim, dropout, attn=True):
152 | super().__init__()
153 | self.block1 = nn.Sequential(
154 | nn.GroupNorm(32, in_ch),
155 | Swish(),
156 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
157 | )
158 | self.temb_proj = nn.Sequential(
159 | Swish(),
160 | nn.Linear(tdim, out_ch),
161 | )
162 | self.cond_proj = nn.Sequential(
163 | Swish(),
164 | nn.Linear(tdim, out_ch),
165 | )
166 | self.block2 = nn.Sequential(
167 | nn.GroupNorm(32, out_ch),
168 | Swish(),
169 | nn.Dropout(dropout),
170 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
171 | )
172 | if in_ch != out_ch:
173 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
174 | else:
175 | self.shortcut = nn.Identity()
176 | if attn:
177 | self.attn = AttnBlock(out_ch)
178 | else:
179 | self.attn = nn.Identity()
180 |
181 | def forward(self, x, temb, cemb):
182 | h = self.block1(x)
183 | h += self.temb_proj(temb)[:, :, None, None]  # add the time embedding
184 | h += self.cond_proj(cemb)[:, :, None, None]  # add the conditional embedding
185 | h = self.block2(h)  # feature fusion
186 |
187 | h = h + self.shortcut(x)
188 | h = self.attn(h)
189 | return h
190 |
191 |
192 | class UNet(nn.Module):
193 | """
194 | Compared with ``Diffusion.Model.UNet``, this UNet just adds a ``cond_embedding``.
195 | The other change is that no self-attention is added in the down- and up-sampling stages, only once in the middle transition.
196 | I do not understand the intent; perhaps so the network does not learn too much from itself and attends more to the condition? (Pure speculation.)
197 | """
198 | def __init__(self, T, num_labels, ch, ch_mult, num_res_blocks, dropout):
199 | super().__init__()
200 | tdim = ch * 4
201 | self.time_embedding = TimeEmbedding(T, ch, tdim)
202 | self.cond_embedding = ConditionalEmbedding(num_labels, ch, tdim)
203 | self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
204 | self.downblocks = nn.ModuleList()
205 | chs = [ch]  # record the output channels during downsampling, for use in upsampling
206 | now_ch = ch
207 | for i, mult in enumerate(ch_mult):
208 | out_ch = ch * mult
209 | for _ in range(num_res_blocks):
210 | self.downblocks.append(ResBlock(in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout))
211 | now_ch = out_ch
212 | chs.append(now_ch)
213 | if i != len(ch_mult) - 1:
214 | self.downblocks.append(DownSample(now_ch))
215 | chs.append(now_ch)
216 |
217 | self.middleblocks = nn.ModuleList([
218 | ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
219 | ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
220 | ])
221 |
222 | self.upblocks = nn.ModuleList()
223 | for i, mult in reversed(list(enumerate(ch_mult))):
224 | out_ch = ch * mult
225 | for _ in range(num_res_blocks + 1):
226 | self.upblocks.append(ResBlock(in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=False))
227 | now_ch = out_ch
228 | if i != 0:
229 | self.upblocks.append(UpSample(now_ch))
230 | assert len(chs) == 0
231 |
232 | self.tail = nn.Sequential(
233 | nn.GroupNorm(32, now_ch),
234 | Swish(),
235 | nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
236 | )
237 |
238 | def forward(self, x, t, labels):
239 | # Timestep embedding
240 | temb = self.time_embedding(t)
241 | cemb = self.cond_embedding(labels)
242 | # Downsampling
243 | h = self.head(x)
244 | hs = [h]
245 | for layer in self.downblocks:
246 | h = layer(h, temb, cemb)
247 | hs.append(h)
248 | # Middle
249 | for layer in self.middleblocks:
250 | h = layer(h, temb, cemb)
251 | # Upsampling
252 | for layer in self.upblocks:
253 | if isinstance(layer, ResBlock):
254 | h = torch.cat([h, hs.pop()], dim=1)
255 | h = layer(h, temb, cemb)
256 | h = self.tail(h)
257 |
258 | assert len(hs) == 0
259 | return h
260 |
261 |
262 | if __name__ == '__main__':
263 | batch_size = 8
264 | model = UNet(
265 | T=1000, num_labels=10, ch=128, ch_mult=[1, 2, 2, 2],
266 | num_res_blocks=2, dropout=0.1)
267 | x = torch.randn(batch_size, 3, 32, 32)
268 | t = torch.randint(1000, size=[batch_size])
269 | labels = torch.randint(10, size=[batch_size])
270 | # resB = ResBlock(128, 256, 64, 0.1)
271 | # x = torch.randn(batch_size, 128, 32, 32)
272 | # t = torch.randn(batch_size, 64)
273 | # labels = torch.randn(batch_size, 64)
274 | # y = resB(x, t, labels)
275 | y = model(x, t, labels)
276 | print(y.shape)
277 |
278 |
--------------------------------------------------------------------------------
/Diffusion/Model.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import math
4 | import torch
5 | from torch import nn
6 | from torch.nn import init
7 | from torch.nn import functional as F
8 |
9 |
10 | class Swish(nn.Module):
11 | """
12 | The Swish activation; see e.g. https://blog.csdn.net/bblingbbling/article/details/107105648
13 | """
14 | def forward(self, x):
15 | return x * torch.sigmoid(x)
16 |
17 |
18 | class TimeEmbedding(nn.Module):
19 | """
20 | The ``time embedding`` module.
21 | """
22 | def __init__(self, T, d_model, dim):
23 | """
24 | The initial time embedding consists of samples of sine/cosine functions at a series of frequencies,
25 | i.e.: [[sin(w_0*x), cos(w_0*x)],
26 | [sin(w_1*x), cos(w_1*x)],
27 | ...,
28 | [sin(w_T*x), cos(w_T*x)]], with shape T * d_model
29 | In this project the positions cover [0:T] and x ranges over roughly 1e-4~1 with d_model // 2 discrete points; concatenating sin and cos gives d_model points
30 | Args:
31 | T: int, the total number of diffusion steps; T=1000 in this project
32 | d_model: the input dimension (channels / initial embedding length)
33 | dim: the output dimension (channels)
34 | """
35 | assert d_model % 2 == 0
36 | super().__init__()
37 | # the first two lines compute the frequency scaling vector, d_model // 2 (here 64) points
38 | emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
39 | emb = torch.exp(-emb)
40 | # the T timestep positions form the position part
41 | pos = torch.arange(T).float()
42 | # the outer product of the two gives a T x (d_model//2) matrix; assert its shape
43 | emb = pos[:, None] * emb[None, :]
44 | assert list(emb.shape) == [T, d_model // 2]
45 | # take sin and cos at each frequency, check the shape, and reshape to T x d_model
46 | emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
47 | assert list(emb.shape) == [T, d_model // 2, 2]
48 | emb = emb.view(T, d_model)
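# Written out, the table built above is the standard sinusoidal encoding:
#   emb[t, 2i]   = sin(t / 10000^(2i / d_model))
#   emb[t, 2i+1] = cos(t / 10000^(2i / d_model))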
49 |
50 | # an MLP that refines the initial encoding into the final embedding:
51 | # two linear layers, the first followed by a Swish activation, the second with no activation
52 | self.timembedding = nn.Sequential(
53 | nn.Embedding.from_pretrained(emb),
54 | nn.Linear(d_model, dim),
55 | Swish(),
56 | nn.Linear(dim, dim),
57 | )
58 | self.initialize()
59 |
60 | def initialize(self):
61 | for module in self.modules():
62 | if isinstance(module, nn.Linear):
63 | init.xavier_uniform_(module.weight)
64 | init.zeros_(module.bias)
65 |
66 | def forward(self, t):
67 | emb = self.timembedding(t)
68 | return emb
69 |
70 |
71 | class DownSample(nn.Module):
72 | """
73 | Downsampling via a stride-2 convolution.
74 | """
75 | def __init__(self, in_ch):
76 | super().__init__()
77 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
78 | self.initialize()
79 |
80 | def initialize(self):
81 | init.xavier_uniform_(self.main.weight)
82 | init.zeros_(self.main.bias)
83 |
84 | def forward(self, x, temb):
85 | x = self.main(x)
86 | return x
87 |
88 |
89 | class UpSample(nn.Module):
90 | """
91 | Upsampling via nearest-neighbor interpolation followed by a conv.
92 | """
93 | def __init__(self, in_ch):
94 | super().__init__()
95 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
96 | self.initialize()
97 |
98 | def initialize(self):
99 | init.xavier_uniform_(self.main.weight)
100 | init.zeros_(self.main.bias)
101 |
102 | def forward(self, x, temb):
103 | _, _, H, W = x.shape
104 | x = F.interpolate(
105 | x, scale_factor=2, mode='nearest')
106 | x = self.main(x)
107 | return x
108 |
109 |
110 | class AttnBlock(nn.Module):
111 | """
112 | A self-attention block in which every linear layer is expressed as a 1x1 convolution.
113 | """
114 | def __init__(self, in_ch):
115 | # ``self.proj_q``, ``self.proj_k``, ``self.proj_v`` learn the query, key and value respectively
116 | # ``self.proj`` is the linear output projection applied after attention
117 | super().__init__()
118 | self.group_norm = nn.GroupNorm(32, in_ch)
119 | self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
120 | self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
121 | self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
122 | self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
123 | self.initialize()
124 |
125 | def initialize(self):
126 | for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
127 | init.xavier_uniform_(module.weight)
128 | init.zeros_(module.bias)
129 | init.xavier_uniform_(self.proj.weight, gain=1e-5)
130 |
131 | def forward(self, x):
132 | B, C, H, W = x.shape
133 | # after group normalization, the 1x1 convs produce the query, key and value
134 | h = self.group_norm(x)
135 | q = self.proj_q(h)
136 | k = self.proj_k(h)
137 | v = self.proj_v(h)
138 |
139 | # compute the similarity weights w between query and key with a batched matrix multiply
140 | # ``torch.bmm`` keeps the 1st dim fixed and matrix-multiplies the 2nd and 3rd dims,
141 | # e.g. a.shape=(_n, _h, _m), b.shape=(_n, _m, _w) --> torch.bmm(a, b).shape=(_n, _h, _w)
142 | # the product is divided by sqrt(C) to normalize (removing the effect of the channel count on the magnitude of w)
143 | q = q.permute(0, 2, 3, 1).view(B, H * W, C)
144 | k = k.view(B, C, H * W)
145 | w = torch.bmm(q, k) * (int(C) ** (-0.5))
146 | assert list(w.shape) == [B, H * W, H * W]
147 | w = F.softmax(w, dim=-1)
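# Shape walk-through for one hypothetical input, B=8, C=128, H=W=16:
#   q: [8, 256, 128], k: [8, 128, 256] -> w = softmax(q @ k / sqrt(128)): [8, 256, 256]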
148 |
149 | # then use the freshly computed weights w to attend over the values, again a single matrix multiplication
150 | v = v.permute(0, 2, 3, 1).view(B, H * W, C)
151 | h = torch.bmm(w, v)
152 | assert list(h.shape) == [B, H * W, C]
153 | h = h.view(B, H, W, C).permute(0, 3, 1, 2)
154 |
155 | # finally pass through the output projection; adding the input x forms the skip (residual) connection
156 | h = self.proj(h)
157 |
158 | return x + h
159 |
160 |
161 | class ResBlock(nn.Module):
162 | """
163 | A residual block.
164 | """
165 | def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
166 | """
167 | Args:
168 | in_ch: int, the number of input channels
169 | out_ch: int, the number of output channels
170 | tdim: int, the length/dimension of the time embedding
171 | dropout: float, the dropout rate
172 | attn: bool, whether to use a self-attention block
173 | """
174 | super().__init__()
175 | # block 1: gn -> swish -> conv
176 | self.block1 = nn.Sequential(
177 | nn.GroupNorm(32, in_ch),
178 | Swish(),
179 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
180 | )
181 | # time-embedding projection: swish -> fc
182 | self.temb_proj = nn.Sequential(
183 | Swish(),
184 | nn.Linear(tdim, out_ch),
185 | )
186 | # block 2: gn -> swish -> dropout -> conv
187 | self.block2 = nn.Sequential(
188 | nn.GroupNorm(32, out_ch),
189 | Swish(),
190 | nn.Dropout(dropout),
191 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
192 | )
193 | # if the input and output channel counts differ, add a 1x1-conv transition ``shortcut``; otherwise do nothing
194 | if in_ch != out_ch:
195 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
196 | else:
197 | self.shortcut = nn.Identity()
198 | # if attention is requested, add an ``AttnBlock``; otherwise do nothing
199 | if attn:
200 | self.attn = AttnBlock(out_ch)
201 | else:
202 | self.attn = nn.Identity()
203 | self.initialize()
204 |
205 | def initialize(self):
206 | for module in self.modules():
207 | if isinstance(module, (nn.Conv2d, nn.Linear)):
208 | init.xavier_uniform_(module.weight)
209 | init.zeros_(module.bias)
210 | init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)
211 |
212 | def forward(self, x, temb):
213 | h = self.block1(x)  # encode the input features with block 1
214 | h += self.temb_proj(temb)[:, :, None, None]  # inject the time embedding into the network
215 | h = self.block2(h)  # encode the fused features further with block 2
216 |
217 | h = h + self.shortcut(x)  # residual connection
218 | h = self.attn(h)  # self-attention block (if attn=True)
219 | return h
220 |
221 |
222 | class UNet(nn.Module):
223 | def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
224 | """
225 |
226 | Args:
227 | T: int, the total number of diffusion steps; T=1000 in this project
228 | ch: int, the channel count of the first UNet conv, scaled up at each downsampling level; ch=128 in this project
229 | ch_mult: list, the per-level channel multipliers; ch_mult=[1,2,3,4] in this project
230 | attn: list, the indices of the downsampling levels that use attention
231 | num_res_blocks: int, the number of residual blocks per level in the down/up paths
232 | dropout: float, the dropout rate
233 | """
234 | super().__init__()
235 | # the assert ensures every requested attention position is smaller than the number of downsampling levels
236 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
237 | # the time embedding is expanded from length ch to tdim = ch * 4
238 | tdim = ch * 4
239 | # instantiate the initial time-embedding layer
240 | self.time_embedding = TimeEmbedding(T, ch, tdim)
241 | # instantiate the head convolution
242 | self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
243 |
244 | # build the UNet encoder (the downsampling path); each level consists of ``num_res_blocks`` residual blocks
245 | # chs records the channel counts at each stage of downsampling; now_ch is the current channel count
246 | self.downblocks = nn.ModuleList()
247 | chs = [ch]  # record the output channels during downsampling, for use in upsampling
248 | now_ch = ch
249 | for i, mult in enumerate(ch_mult):  # i is the index into ch_mult, mult is ch_mult[i]
250 | out_ch = ch * mult
251 | for _ in range(num_res_blocks):
252 | self.downblocks.append(ResBlock(
253 | in_ch=now_ch, out_ch=out_ch, tdim=tdim,
254 | dropout=dropout, attn=(i in attn)))
255 | now_ch = out_ch
256 | chs.append(now_ch)
257 | if i != len(ch_mult) - 1:
258 | self.downblocks.append(DownSample(now_ch))
259 | chs.append(now_ch)
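# With the defaults in Main.py (ch=128, ch_mult=[1,2,3,4], num_res_blocks=2) the loop above
# leaves chs = [128, 128, 128, 128, 256, 256, 256, 384, 384, 384, 512, 512]; the decoder below
# pops these one by one to size its skip-connection inputs, hence the assert that chs ends up empty.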
260 |
261 | # the transition between the UNet encoder and decoder, made of two residual blocks
262 | # I do not know why the first block uses attention while the second does not... the answer is ``engineering science``
263 | self.middleblocks = nn.ModuleList([
264 | ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
265 | ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
266 | ])
267 |
268 | # build the UNet decoder, almost symmetric to the encoder
269 | # the only difference: each level has one more residual block than the encoder,
270 | # because the first block fuses the current feature map with the skip connection; the second and third mirror the encoder's feature extraction
271 | self.upblocks = nn.ModuleList()
272 | for i, mult in reversed(list(enumerate(ch_mult))):
273 | out_ch = ch * mult
274 | for _ in range(num_res_blocks + 1):
275 | self.upblocks.append(ResBlock(
276 | in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
277 | dropout=dropout, attn=(i in attn)))
278 | now_ch = out_ch
279 | if i != 0:
280 | self.upblocks.append(UpSample(now_ch))
281 | assert len(chs) == 0
282 |
283 | # tail: gn -> swish -> conv, mapping back to the image channel count
284 | self.tail = nn.Sequential(
285 | nn.GroupNorm(32, now_ch),
286 | Swish(),
287 | nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
288 | )
289 | # note that only the head and tail are initialized here; all other modules initialized themselves at construction
290 | self.initialize()
291 |
292 | def initialize(self):
293 | init.xavier_uniform_(self.head.weight)
294 | init.zeros_(self.head.bias)
295 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
296 | init.zeros_(self.tail[-1].bias)
297 |
298 | def forward(self, x, t):
299 | # Timestep embedding
300 | temb = self.time_embedding(t)
301 | # Downsampling
302 | h = self.head(x)
303 | hs = [h]
304 | for layer in self.downblocks:
305 | h = layer(h, temb)
306 | hs.append(h)
307 | # Middle
308 | for layer in self.middleblocks:
309 | h = layer(h, temb)
310 | # Upsampling
311 | for layer in self.upblocks:
312 | if isinstance(layer, ResBlock):
313 | h = torch.cat([h, hs.pop()], dim=1)
314 | h = layer(h, temb)
315 | h = self.tail(h)
316 |
317 | assert len(hs) == 0
318 | return h
319 |
320 |
321 | if __name__ == '__main__':
322 | batch_size = 8
323 | model = UNet(
324 | T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
325 | num_res_blocks=2, dropout=0.1)
326 | x = torch.randn(batch_size, 3, 32, 32)
327 | t = torch.randint(1000, (batch_size, ))
328 | y = model(x, t)
329 | print(y.shape)
330 |
331 |
--------------------------------------------------------------------------------