# Repository layout (Pytorch-DDPM)
#
# ├── DiffusionModels
#     ├── diffusionModels
#     │   └── simpleDiffusion
#     │   │   ├── __pycache__
#     │   │       ├── simpleDiffusion.cpython-36.pyc
#     │   │       └── varianceSchedule.cpython-36.pyc
#     │   │   ├── simpleDiffusion.py
#     │   │   └── varianceSchedule.py
#     ├── image_test.py
#     ├── noisePredictModels
#     │   └── Unet
#     │   │   ├── UNet.py
#     │   │   └── __pycache__
#     │   │       └── UNet.cpython-36.pyc
#     └── utils
#     │   ├── __pycache__
#     │       ├── networkHelper.cpython-36.pyc
#     │       └── trainNetworkHelper.cpython-36.pyc
#     │   ├── networkHelper.py
#     │   └── trainNetworkHelper.py
# └── README.md
#
# --------------------------------------------------------------------------------
# /DiffusionModels/diffusionModels/simpleDiffusion/__pycache__/simpleDiffusion.cpython-36.pyc:
# https://raw.githubusercontent.com/CHAINNEVERLIU/Pytorch-DDPM/e0b970c054f2b3140306ff86f805fca7f2369cb8/DiffusionModels/diffusionModels/simpleDiffusion/__pycache__/simpleDiffusion.cpython-36.pyc
# --------------------------------------------------------------------------------
# /DiffusionModels/diffusionModels/simpleDiffusion/__pycache__/varianceSchedule.cpython-36.pyc:
# https://raw.githubusercontent.com/CHAINNEVERLIU/Pytorch-DDPM/e0b970c054f2b3140306ff86f805fca7f2369cb8/DiffusionModels/diffusionModels/simpleDiffusion/__pycache__/varianceSchedule.cpython-36.pyc
# --------------------------------------------------------------------------------
# /DiffusionModels/diffusionModels/simpleDiffusion/simpleDiffusion.py:
# --------------------------------------------------------------------------------
from utils.networkHelper import *
from diffusionModels.simpleDiffusion.varianceSchedule import VarianceSchedule


class DiffusionModel(nn.Module):
    """DDPM wrapper around a noise-prediction network.

    Precomputes the variance schedule and the derived coefficients, and
    exposes the forward (noising) process, the training loss and the
    reverse (sampling) process.
    """

    def __init__(self,
                 schedule_name="linear_beta_schedule",
                 timesteps=1000,
                 beta_start=0.0001,
                 beta_end=0.02,
                 denoise_model=None):
        """
        :param schedule_name: key of the beta schedule understood by ``VarianceSchedule``
        :param timesteps: number of diffusion steps T
        :param beta_start: first beta of the schedule (ignored by the cosine schedule)
        :param beta_end: last beta of the schedule (ignored by the cosine schedule)
        :param denoise_model: module called as ``denoise_model(x_t, t)`` to predict the noise
        """
        super(DiffusionModel, self).__init__()

        self.denoise_model = denoise_model

        # Variance (beta) schedule.
        variance_schedule_func = VarianceSchedule(schedule_name=schedule_name, beta_start=beta_start, beta_end=beta_end)
        self.timesteps = timesteps
        self.betas = variance_schedule_func(timesteps)
        # define alphas
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Cumulative product shifted right by one step, with 1.0 for the step before the first.
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        # x_t = sqrt(alphas_cumprod)*x_0 + sqrt(1 - alphas_cumprod)*z_t
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # This is the computed posterior variance, not the simplified beta_t variant.
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)

    def q_sample(self, x_start, t, noise=None):
        """Sample x_t from x_0 in closed form.

        x_t = sqrt(alphas_cumprod_t) * x_0 + sqrt(1 - alphas_cumprod_t) * noise

        :param x_start: clean input x_0
        :param t: tensor of per-sample timestep indices
        :param noise: optional noise tensor; drawn from N(0, I) when omitted
        :return: the noised tensor x_t
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def compute_loss(self, x_start, t, noise=None, loss_type="l1"):
        """Noise ``x_start`` to step ``t`` and score the model's noise prediction.

        :param x_start: clean input x_0
        :param t: tensor of per-sample timestep indices
        :param noise: optional noise tensor; drawn from N(0, I) when omitted
        :param loss_type: one of ``"l1"``, ``"l2"`` or ``"huber"``
        :return: scalar loss tensor
        :raises NotImplementedError: for any other ``loss_type``
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        predicted_noise = self.denoise_model(x_noisy, t)

        if loss_type == 'l1':
            loss = F.l1_loss(noise, predicted_noise)
        elif loss_type == 'l2':
            loss = F.mse_loss(noise, predicted_noise)
        elif loss_type == "huber":
            loss = F.smooth_l1_loss(noise, predicted_noise)
        else:
            raise NotImplementedError()

        return loss

    @torch.no_grad()
    def p_sample(self, x, t, t_index):
        """Run one reverse step x_t -> x_{t-1}.

        :param x: current sample x_t
        :param t: tensor holding the timestep index for every batch element
        :param t_index: the same timestep as a plain int; 0 means the final step
        :return: the predicted mean at the final step, otherwise mean plus scaled noise
        """
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)

        # Equation 11 in the paper
        # Use our model (noise predictor) to predict the mean
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * self.denoise_model(x, t) / sqrt_one_minus_alphas_cumprod_t
        )

        if t_index == 0:
            # No noise is added on the last step.
            return model_mean
        else:
            posterior_variance_t = extract(self.posterior_variance, t, x.shape)
            noise = torch.randn_like(x)
            # Algorithm 2 line 4:
            return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def p_sample_loop(self, shape):
        """Run the full reverse process starting from Gaussian noise.

        :param shape: shape of the batch to generate; ``shape[0]`` is the batch size
        :return: list with one numpy array per timestep, last entry being the final sample
        """
        device = next(self.denoise_model.parameters()).device

        b = shape[0]
        # start from pure noise (for each example in the batch)
        img = torch.randn(shape, device=device)
        imgs = []

        for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), i)
            imgs.append(img.cpu().numpy())
        return imgs

    @torch.no_grad()
    def sample(self, image_size, batch_size=16, channels=3):
        """Generate ``batch_size`` square images; see ``p_sample_loop`` for the return value."""
        return self.p_sample_loop(shape=(batch_size, channels, image_size, image_size))

    def forward(self, mode, **kwargs):
        """Dispatch to training-loss computation or to sampling.

        :param mode: ``"train"`` or ``"generate"``
        :param kwargs: for ``"train"``: required ``x_start`` and ``t``, optional
            ``noise`` and ``loss_type``; for ``"generate"``: required
            ``image_size``, ``batch_size`` and ``channels``
        :return: the loss tensor (train) or the list of sampling steps (generate)
        :raises ValueError: if ``mode`` is unknown or a required argument is missing
        """
        if mode == "train":
            # Every required key is tested explicitly: `"a" and "b" in d` only tests "b".
            if "x_start" in kwargs and "t" in kwargs:
                # Forward only the optional arguments the caller supplied so that
                # compute_loss falls back to its own defaults for the rest.
                optional = {key: kwargs[key] for key in ("noise", "loss_type") if key in kwargs}
                return self.compute_loss(x_start=kwargs["x_start"], t=kwargs["t"], **optional)
            raise ValueError("扩散模型在训练时必须传入参数x_start和t!")

        elif mode == "generate":
            if all(key in kwargs for key in ("image_size", "batch_size", "channels")):
                return self.sample(image_size=kwargs["image_size"],
                                   batch_size=kwargs["batch_size"],
                                   channels=kwargs["channels"])
            raise ValueError("扩散模型在生成图片时必须传入image_size, batch_size, channels等三个参数")
        else:
            raise ValueError("mode参数必须从{train}和{generate}两种模式中选择")


# --------------------------------------------------------------------------------
# /DiffusionModels/diffusionModels/simpleDiffusion/varianceSchedule.py:
# --------------------------------------------------------------------------------
from utils.networkHelper import *


def cosine_beta_schedule(timesteps, s=0.008, **kwargs):
    """
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672

    Extra keyword arguments are accepted and ignored.
    """

    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)


def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return ``timesteps`` betas spaced linearly from ``beta_start`` to ``beta_end``."""
    return torch.linspace(beta_start, beta_end, timesteps)


def quadratic_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return betas whose square roots are spaced linearly between the two bounds."""
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2


def sigmoid_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return betas following a sigmoid over [-6, 6], rescaled to [beta_start, beta_end]."""
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start


class VarianceSchedule(nn.Module):
    """Select a beta schedule by name and evaluate it for a number of timesteps."""

    def __init__(self, schedule_name="linear_beta_schedule", beta_start=None, beta_end=None):
        """
        :param schedule_name: one of ``linear_beta_schedule``, ``cosine_beta_schedule``,
            ``quadratic_beta_schedule``, ``sigmoid_beta_schedule``
        :param beta_start: first beta; defaults to 0.0001 for non-cosine schedules
        :param beta_end: last beta; defaults to 0.02 for non-cosine schedules
        :raises ValueError: if ``schedule_name`` is not a known schedule
        """
        super(VarianceSchedule, self).__init__()

        self.schedule_name = schedule_name

        # Lookup table of the available schedule functions.
        beta_schedule_dict = {'linear_beta_schedule': linear_beta_schedule,
                              'cosine_beta_schedule': cosine_beta_schedule,
                              'quadratic_beta_schedule': quadratic_beta_schedule,
                              'sigmoid_beta_schedule': sigmoid_beta_schedule}

        # Pick the schedule function that matches the requested name.
        if schedule_name in beta_schedule_dict:
            self.selected_schedule = beta_schedule_dict[schedule_name]
        else:
            raise ValueError('Function not found in dictionary')

        # Fill in each missing bound independently. The previous combined test
        # (`beta_end and beta_start is None`) never applied the defaults when
        # both were None and discarded a supplied beta_end when only beta_start
        # was missing. The cosine schedule does not use the bounds at all.
        if schedule_name != "cosine_beta_schedule":
            if beta_start is None:
                beta_start = 0.0001
            if beta_end is None:
                beta_end = 0.02
        self.beta_start = beta_start
        self.beta_end = beta_end

    def forward(self, timesteps):
        """Return the tensor of betas for ``timesteps`` steps."""
        # print(self.beta_start)
        return self.selected_schedule(timesteps=timesteps) if self.schedule_name == "cosine_beta_schedule" \
            else self.selected_schedule(timesteps=timesteps, beta_start=self.beta_start, beta_end=self.beta_end)


# --------------------------------------------------------------------------------
# /DiffusionModels/image_test.py:
# --------------------------------------------------------------------------------
import os
import numpy as np
import torchvision
import torch
import torchvision.transforms as transforms
from torch.optim import Adam
from utils.networkHelper import *

from noisePredictModels.Unet.UNet import Unet
from utils.trainNetworkHelper import SimpleDiffusionTrainer
from diffusionModels.simpleDiffusion.simpleDiffusion import DiffusionModel


# Dataset loading
data_root_path = "./dataset/"
os.makedirs(data_root_path, exist_ok=True)

# ToTensor yields values in [0, 1]; rescale to [-1, 1] because sampling starts
# from N(0, I) noise and reverse_transform maps samples back with (t + 1) / 2.
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1),
])

imagenet_data = torchvision.datasets.FashionMNIST(data_root_path, train=True, download=True, transform=data_transform)

image_size = 28
channels = 1
batch_size = 128

data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=0)


device = "cuda" if torch.cuda.is_available() else "cpu"
dim_mults = (1, 2, 4,)

denoise_model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=dim_mults
)

timesteps = 1000
schedule_name = "linear_beta_schedule"
DDPM = DiffusionModel(schedule_name=schedule_name,
                      timesteps=timesteps,
                      beta_start=0.0001,
                      beta_end=0.02,
                      denoise_model=denoise_model).to(device)

optimizer = Adam(DDPM.parameters(), lr=1e-3)
epoches = 20

Trainer = SimpleDiffusionTrainer(epoches=epoches,
                                 train_loader=data_loader,
                                 optimizer=optimizer,
                                 device=device,
                                 timesteps=timesteps)

# Training output location
root_path = "./saved_train_models"
setting = "imageSize{}_channels{}_dimMults{}_timeSteps{}_scheduleName{}".format(image_size, channels, dim_mults, timesteps, schedule_name)

saved_path = os.path.join(root_path, setting)
os.makedirs(saved_path, exist_ok=True)


# Loading a trained model: if the model has already been trained, uncomment the two lines below
# best_model_path = saved_path + '/' + 'BestModel.pth'
# DDPM.load_state_dict(torch.load(best_model_path))

# If the model is already trained, comment out the next line; otherwise comment out the two lines above
DDPM = Trainer(DDPM, model_save_path=saved_path)

# Sampling: sample 64 images
samples = DDPM(mode="generate", image_size=image_size, batch_size=64, channels=channels)

# Show one of the generated images (last sampling step, fixed index)
random_index = 1
generate_image = samples[-1][random_index].reshape(channels, image_size, image_size)
figtest = reverse_transform(torch.from_numpy(generate_image))
figtest.show()


# --------------------------------------------------------------------------------
# /DiffusionModels/noisePredictModels/Unet/UNet.py:
# --------------------------------------------------------------------------------
from functools import partial
from einops import rearrange, reduce

from torch import nn, einsum
import torch.nn.functional as F
from utils.networkHelper import *


class WeightStandardizedConv2d(nn.Conv2d):
    """
    Convolution whose weights are standardized before every forward pass.
    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization
    """

    def forward(self, x):
        # Larger epsilon for non-float32 inputs.
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        # Mean and (biased) variance per output channel.
        mean = reduce(weight, "o ... -> o 1 1 1", "mean")
        var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False))
        normalized_weight = (weight - mean) * (var + eps).rsqrt()

        return F.conv2d(
            x,
            normalized_weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class Block(nn.Module):
    """3x3 weight-standardized conv -> GroupNorm -> optional scale/shift -> SiLU."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        """
        :param x: input feature map
        :param scale_shift: optional ``(scale, shift)`` pair applied after normalization
        """
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385

    Two ``Block`` layers with a residual connection; the first block is
    optionally modulated by a time embedding.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        # Projects the time embedding to a (scale, shift) pair, hence dim_out * 2.
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        # 1x1 conv on the skip path only when the channel count changes.
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, "b c -> b c 1 1")
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)


class Attention(nn.Module):
    """Multi-head softmax self-attention over the spatial positions of a feature map."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        # Split heads and flatten the spatial grid into one axis.
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        # Subtract the row maximum before the softmax for numerical stability.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)


class LinearAttention(nn.Module):
    """Attention variant that forms a key/value context first, then applies it to the queries."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        # q is normalized over the feature axis, k over the position axis.
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)


class PreNorm(nn.Module):
    """Apply a single-group GroupNorm before calling the wrapped module ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)


class Unet(nn.Module):
    """Time-conditioned U-Net used as the noise predictor.

    :param dim: base channel width; also the size of the sinusoidal time embedding
    :param init_dim: channels produced by the first convolution (defaults to ``dim``)
    :param out_dim: output channels (defaults to ``channels``)
    :param dim_mults: per-resolution multipliers applied to ``dim``
    :param channels: number of input image channels
    :param self_condition: when true, the input is concatenated with ``x_self_cond``
    :param resnet_block_groups: GroupNorm group count inside the ResNet blocks
    """

    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        self_condition=False,
        resnet_block_groups=4,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels
        self.self_condition = self_condition
        input_channels = channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0)  # changed to 1 and 0 from 7,3

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        time_dim = dim * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        # The last stage keeps the resolution and only changes channels.
                        Downsample(dim_in, dim_out)
                        if not is_last
                        else nn.Conv2d(dim_in, dim_out, 3, padding=1),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        # Inputs are concatenated with a skip tensor of dim_in channels.
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Upsample(dim_out, dim_in)
                        if not is_last
                        else nn.Conv2d(dim_out, dim_in, 3, padding=1),
                    ]
                )
            )

        self.out_dim = default(out_dim, channels)

        # The tensor reaching the final block is the last up-stage output
        # (init_dim channels) concatenated with the init_conv residual
        # (init_dim channels), so it must be sized with init_dim, not dim.
        # The two are equal unless the caller passes an explicit init_dim.
        self.final_res_block = block_klass(init_dim * 2, init_dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)

    def forward(self, x, time, x_self_cond=None):
        """
        :param x: input batch of noisy images
        :param time: tensor of timestep indices, one per batch element
        :param x_self_cond: optional self-conditioning input; zeros are used when
            self-conditioning is enabled and this is omitted
        :return: tensor with ``out_dim`` channels
        """
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        # Stack of skip connections, consumed in reverse order on the way up.
        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)


# --------------------------------------------------------------------------------
# /DiffusionModels/noisePredictModels/Unet/__pycache__/UNet.cpython-36.pyc:
# https://raw.githubusercontent.com/CHAINNEVERLIU/Pytorch-DDPM/e0b970c054f2b3140306ff86f805fca7f2369cb8/DiffusionModels/noisePredictModels/Unet/__pycache__/UNet.cpython-36.pyc
# --------------------------------------------------------------------------------
# /DiffusionModels/utils/__pycache__/networkHelper.cpython-36.pyc:
# https://raw.githubusercontent.com/CHAINNEVERLIU/Pytorch-DDPM/e0b970c054f2b3140306ff86f805fca7f2369cb8/DiffusionModels/utils/__pycache__/networkHelper.cpython-36.pyc
# --------------------------------------------------------------------------------
# /DiffusionModels/utils/__pycache__/trainNetworkHelper.cpython-36.pyc:
# https://raw.githubusercontent.com/CHAINNEVERLIU/Pytorch-DDPM/e0b970c054f2b3140306ff86f805fca7f2369cb8/DiffusionModels/utils/__pycache__/trainNetworkHelper.cpython-36.pyc
# --------------------------------------------------------------------------------
# /DiffusionModels/utils/networkHelper.py:
# --------------------------------------------------------------------------------
import torch
from torch import nn
import numpy as np
import math
from inspect import isfunction
from einops.layers.torch import Rearrange
from torchvision.transforms import Compose, Lambda, ToPILImage
import torch.nn.functional as F
from tqdm.auto import tqdm


def exists(x):
    """
    Tell whether a value is set.
    :param x: value to test
    :return: True if ``x`` is not None, otherwise False
    """
    return x is not None


def default(val, d):
    """
    Return ``val`` when it is set, otherwise a fallback.
    When ``d`` is a function it is called to produce the fallback;
    otherwise ``d`` itself is returned.
    :param val: value to test
    :param d: fallback value, or a function producing it
    :return: ``val`` or the fallback
    """
    if exists(val):
        return val
    return d() if isfunction(d) else d


def num_to_groups(num, divisor):
    """
    Split ``num`` into groups of size ``divisor`` and return the list of
    group sizes. When ``num`` is not a multiple of ``divisor`` the last
    group holds the remainder.
    :param num: total count to split
    :param divisor: size of each full group
    :return: list of group sizes
    """
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


class Residual(nn.Module):
    def __init__(self, fn):
        """
        Residual connection wrapper.
        :param fn: the callable whose output is added to its input
        """
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        """
        Apply the wrapped callable and add the input back.
        :param x: input tensor
        :param args: extra positional arguments passed to ``fn``
        :param kwargs: extra keyword arguments passed to ``fn``
        :return: fn(x) + x
        """
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim, dim_out=None):
    """
    Build a module that doubles the spatial size of its input.
    :param dim: input channels
    :param dim_out: output channels (defaults to ``dim``)
    :return: an ``nn.Sequential``
    """
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),  # nearest-neighbour upsampling doubles height and width
        nn.Conv2d(dim, default(dim_out, dim), 3, padding=1),  # 3x3 conv over the upsampled map
    )


def Downsample(dim, dim_out=None):
    # No More Strided Convolutions or Pooling
    """
    Build a module that halves the spatial size of its input.
    Each 2x2 spatial patch is folded into the channel axis (no pooling or
    strided convolution is involved), then a 1x1 convolution maps the
    ``dim * 4`` channels to the requested output width.
    :param dim: input channels
    :param dim_out: output channels (defaults to ``dim``)
    :return: an ``nn.Sequential``
    """
    return nn.Sequential(
        # (batch, channel, height, width) -> (batch, channel * 4, height / 2, width / 2)
        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
        # 1x1 conv reducing dim * 4 channels to the output channel count.
        nn.Conv2d(dim * 4, default(dim_out, dim), 1),
    )


class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of timestep indices.

    NOTE: ``half_dim - 1`` is used as a divisor, so ``dim`` must be at least 4;
    the output has ``2 * (dim // 2)`` features, i.e. ``dim`` only when it is even.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


def extract(a, t, x_shape):
    """
    Gather the entries of ``a`` at the indices in ``t`` and reshape them so
    they broadcast against a tensor of shape ``x_shape``.
    :param a: 1-D tensor of per-timestep values
    :param t: tensor of indices, one per batch element
    :param x_shape: shape of the tensor the result will be combined with
    :return: tensor of shape (batch_size, 1, ..., 1) on ``t``'s device
    """
    batch_size = t.shape[0]
    # Index on whatever device `a` lives on; this equals the former t.cpu()
    # when `a` is a CPU tensor and also works if `a` has been moved elsewhere.
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


# Convert a CHW tensor in [-1, 1] into a PIL image.
reverse_transform = Compose([
    # Clamp after rescaling: values outside [0, 1] would wrap around in the uint8 cast below.
    Lambda(lambda t: ((t + 1) / 2).clamp(0, 1)),
    Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
    Lambda(lambda t: t * 255.),
    Lambda(lambda t: t.numpy().astype(np.uint8)),
    ToPILImage(),
])

# --------------------------------------------------------------------------------
# /DiffusionModels/utils/trainNetworkHelper.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
from tqdm.auto import tqdm


class EarlyStopping:
    """Track a validation loss, checkpoint on improvement, flag a stop after ``patience`` misses."""

    def __init__(self, patience=7, verbose=False, delta=0):
        """
        :param patience: number of consecutive non-improving calls before ``early_stop`` is set
        :param verbose: print a message whenever a checkpoint is saved
        :param delta: minimum improvement of the score required to reset the counter
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf is valid in every NumPy version.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model, path):
        # Higher score is better, so negate the loss.
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path):
        """Save ``model.state_dict()`` to ``path/checkpoints.pth`` and record the loss."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), path+'/'+'checkpoints.pth')
        self.val_loss_min = val_loss


class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


class TrainerBase(nn.Module):
    """Shared trainer configuration: validates arguments and holds the LR and checkpoint helpers."""

    def __init__(self,
                 epoches,
                 train_loader,
                 optimizer,
                 device,
                 IFEarlyStopping,
                 IFadjust_learning_rate,
                 **kwargs):
        """
        :param epoches: number of training epochs
        :param train_loader: iterable of training batches
        :param optimizer: optimizer instance
        :param device: device the batches are moved to
        :param IFEarlyStopping: enable early stopping; requires ``patience`` and ``val_loader`` in kwargs
        :param IFadjust_learning_rate: enable LR adjustment; requires ``types`` in kwargs,
            optionally ``lr_adjust`` (an epoch -> lr mapping)
        :raises ValueError: if a required argument is missing
        """
        super(TrainerBase, self).__init__()

        self.epoches = epoches
        if self.epoches is None:
            raise ValueError("请传入训练总迭代次数")

        self.train_loader = train_loader
        if self.train_loader is None:
            raise ValueError("请传入train_loader")

        self.optimizer = optimizer
        if self.optimizer is None:
            raise ValueError("请传入优化器类")

        self.device = device
        if self.device is None:
            raise ValueError("请传入运行设备类型")

        # Early stopping needs its own set of arguments.
        self.IFEarlyStopping = IFEarlyStopping
        if IFEarlyStopping:
            if "patience" in kwargs.keys():
                self.early_stopping = EarlyStopping(patience=kwargs["patience"], verbose=True)
            else:
                raise ValueError("启用提前停止策略必须输入{patience=int X}参数")

            if "val_loader" in kwargs.keys():
                self.val_loader = kwargs["val_loader"]
            else:
                raise ValueError("启用提前停止策略必须输入验证集val_loader")

        # Learning-rate adjustment needs its own set of arguments.
        self.IFadjust_learning_rate = IFadjust_learning_rate
        if IFadjust_learning_rate:
            if "types" in kwargs.keys():
                self.types = kwargs["types"]
                if "lr_adjust" in kwargs.keys():
                    self.lr_adjust = kwargs["lr_adjust"]
                else:
                    self.lr_adjust = None
            else:
                raise ValueError("启用学习率调整策略必须从{type1 or type2}中选择学习率调整策略参数types")

    def adjust_learning_rate(self, epoch, learning_rate):
        """Set the optimizer's learning rate for ``epoch`` according to ``self.types``.

        :param epoch: current epoch number
        :param learning_rate: base learning rate (used by ``type1`` only)
        :raises ValueError: if ``self.types`` is neither ``type1`` nor ``type2``
        """
        # lr = args.learning_rate * (0.2 ** (epoch // 2))
        if self.types == 'type1':
            lr_adjust = {epoch: learning_rate * (0.1 ** ((epoch - 1) // 10))}  # divide the LR by 10 every 10 epochs
        elif self.types == 'type2':
            if self.lr_adjust is not None:
                lr_adjust = self.lr_adjust
            else:
                lr_adjust = {
                    5: 1e-4, 10: 5e-5, 20: 1e-5, 25: 5e-6,
                    30: 1e-6, 35: 5e-7, 40: 1e-8
                }
        else:
            # Literal braces must be doubled; the previous "{{0}or{1}}" made
            # str.format itself fail with "Single '}' encountered".
            raise ValueError("请从{{{0} or {1}}}中选择学习率调整策略参数types".format("type1", "type2"))

        if epoch in lr_adjust.keys():
            lr = lr_adjust[epoch]
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
            print('Updating learning rate to {}'.format(lr))

    @staticmethod
    def save_best_model(model, path):
        """Save ``model.state_dict()`` to ``path/BestModel.pth``."""
        torch.save(model.state_dict(), path+'/'+'BestModel.pth')
        print("成功将此次训练模型存储(储存格式为.pth)至:" + str(path))

    def forward(self, model, *args, **kwargs):

        pass


class SimpleDiffusionTrainer(TrainerBase):
    """Training loop for ``DiffusionModel``; requires ``timesteps`` as a keyword argument."""

    def __init__(self,
                 epoches=None,
                 train_loader=None,
                 optimizer=None,
                 device=None,
                 IFEarlyStopping=False,
                 IFadjust_learning_rate=False,
                 **kwargs):
        super(SimpleDiffusionTrainer, self).__init__(epoches, train_loader, optimizer, device,
                                                     IFEarlyStopping, IFadjust_learning_rate,
                                                     **kwargs)

        if "timesteps" in kwargs.keys():
            self.timesteps = kwargs["timesteps"]
        else:
            raise ValueError("扩散模型训练必须提供扩散步数参数")

    def forward(self, model, *args, **kwargs):
        """Train ``model`` for ``self.epoches`` epochs and return it.

        :param model: module called as ``model(mode="train", x_start=..., t=..., loss_type=...)``
        :param kwargs: optional ``model_save_path``; when given, the final weights are saved there
        :return: the trained model
        """

        for i in range(self.epoches):
            losses = []
            loop = tqdm(enumerate(self.train_loader), total=len(self.train_loader))
            for step, (features, labels) in loop:
                features = features.to(self.device)
                batch_size = features.shape[0]

                # Algorithm 1 line 3: sample t uniformally for every example in the batch
                t = torch.randint(0, self.timesteps, (batch_size,), device=self.device).long()

                loss = model(mode="train", x_start=features, t=t, loss_type="huber")
                # Keep only the Python float: appending the tensor would retain
                # every step's autograd graph for the whole epoch.
                losses.append(loss.item())

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Refresh the progress bar.
                loop.set_description(f'Epoch [{i}/{self.epoches}]')
                loop.set_postfix(loss=loss.item())

        if "model_save_path" in kwargs.keys():
            self.save_best_model(model=model, path=kwargs["model_save_path"])

        return model


# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------
# # Pytorch-DDPM
# 基于pytorch框架从零实现DDPM算法
#
#
# 详细的介绍可参考本人博客:https://zhuanlan.zhihu.com/p/617895786
#
# --------------------------------------------------------------------------------