├── README.md
├── adpm.py
├── ddcm.py
├── ddim.py
├── ddpm-gau.py
├── ddpm.py
└── ddpm2.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Keras-DDPM
A Keras implementation of generative diffusion models

## Introduction
- Blog post: https://kexue.fm/archives/9152
- Blog post: https://kexue.fm/archives/9119

## Notes
- The backbone is still a U-Net, but simplified by the author (feature concatenation replaced with addition, Attention removed, etc.), which speeds up convergence
- On a single 3090, initial results appear after about half a day of training; results after 3 days of training are shown below:

## Environment
- tensorflow 1.15
- keras 2.3.1
- bert4keras (the latest version from GitHub; the pip-installable release does not work)

## Key points
- The loss cannot be the built-in mse
- Normalization cannot be BatchNorm
- The step index t can be encoded directly with an Embedding layer

## Contact
QQ group: 808623966; for the WeChat group, add the bot WeChat ID spaces_ac_cn

--------------------------------------------------------------------------------
/adpm.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-
# Reference code for Analytic-DPM (generative diffusion models)
# Built on DDIM: training is unchanged, only the variance used during sampling is modified
# Blog post: https://kexue.fm/archives/9245

# from ddpm import *  # load the trained model
from ddpm2 import *  # load the trained model


def data_generator(t=0):
    """Image data generator
    """
    batch_imgs = []
    while True:
        for i in np.random.permutation(len(imgs)):
            batch_imgs.append(imread(imgs[i]))
            if len(batch_imgs) == batch_size:
                batch_imgs = np.array(batch_imgs)
                batch_steps = np.array([t] * batch_size)
                batch_bar_alpha = bar_alpha[batch_steps][:, None, None, None]
                batch_bar_beta = bar_beta[batch_steps][:, None, None, None]
                batch_noise = np.random.randn(*batch_imgs.shape)
                batch_noisy_imgs = batch_imgs * batch_bar_alpha + batch_noise * batch_bar_beta
                yield [batch_noisy_imgs, batch_steps[:, None]]
                batch_imgs = []


factors = [(model.predict(data_generator(t), steps=5)**2).mean()
           for t in tqdm(range(T), ncols=0)]  # estimate the variance correction term from (batch_size * steps) samples
factors = np.clip(1 - np.array(factors), 0, 1)


def sample(path=None, n=4, z_samples=None, stride=1, eta=1):
    """Random sampling function
    Note: eta controls the relative size of the variance; stride controls the step skipping
    """
    # Sampling parameters
    bar_alpha_ = bar_alpha[::stride]
    bar_alpha_pre_ = np.pad(bar_alpha_[:-1], [1, 0], constant_values=1)
    bar_beta_ = np.sqrt(1 - bar_alpha_**2)
    bar_beta_pre_ = np.sqrt(1 - bar_alpha_pre_**2)
    alpha_ = bar_alpha_ / bar_alpha_pre_
    sigma_ = bar_beta_pre_ / bar_beta_ * np.sqrt(1 - alpha_**2) * eta
    epsilon_ = bar_beta_ - alpha_ * np.sqrt(bar_beta_pre_**2 - sigma_**2)
    gamma_ = epsilon_ * bar_alpha_pre_ / bar_alpha_  # added relative to DDIM
    sigma_ = np.sqrt(sigma_**2 + gamma_**2 * factors[::stride])  # added relative to DDIM: Analytic-DPM variance correction
    T_ = len(bar_alpha_)
    # Sampling loop
    if z_samples is None:
        z_samples = np.random.randn(n**2, img_size, img_size, 3)
    else:
        z_samples = z_samples.copy()
    for t in tqdm(range(T_), ncols=0):
        t = T_ - t - 1
        bt = np.array([[t * stride]] * z_samples.shape[0])
        z_samples -= epsilon_[t] * model.predict([z_samples, bt])
        z_samples /= alpha_[t]
        z_samples += np.random.randn(*z_samples.shape) * sigma_[t]
    x_samples = np.clip(z_samples, -1, 1)
    if path is None:
        return x_samples
    figure = np.zeros((img_size * n, img_size * n, 3))
    for i in range(n):
        for j in range(n):
            digit = x_samples[i * n + j]
            figure[i * img_size:(i + 1) * img_size,
                   j * img_size:(j + 1) * img_size] = digit
    imwrite(path, figure)


sample('test.png', n=8, stride=100, eta=1)
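
# Optional caching sketch (illustrative, not part of the original script; the file
# name 'factors.npy' is arbitrary): estimating `factors` above costs T * 5 forward
# passes through the model, so it can be worth saving the array once and reloading
# it in later runs instead of re-estimating it every time:
# np.save('factors.npy', factors)   # cache right after the estimation above
# factors = np.load('factors.npy')  # reuse in a later session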
--------------------------------------------------------------------------------
/ddcm.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-
# Reference code for DDCM (Denoising Diffusion Codebook Models)
# Built on DDPM: training is unchanged, only the sampling procedure is modified
# Blog post: https://kexue.fm/archives/9245

from ddpm2 import *  # load the trained model

K = 64  # codebook size per step
codebook = np.random.randn(T + 1, K, img_size, img_size, 3)


def sample(path, n=4):
    """Random sampling function
    """
    z_samples = codebook[T][np.random.choice(K, size=n**2)]
    for t in tqdm(range(T), ncols=0):
        t = T - t - 1
        bt = np.array([[t]] * z_samples.shape[0])
        z_samples -= beta[t]**2 / bar_beta[t] * model.predict([z_samples, bt])
        z_samples /= alpha[t]
        z_samples += codebook[t][np.random.choice(K, size=n**2)] * sigma[t]
    x_samples = np.clip(z_samples, -1, 1)
    figure = np.zeros((img_size * n, img_size * n, 3))
    for i in range(n):
        for j in range(n):
            digit = x_samples[i * n + j]
            figure[i * img_size:(i + 1) * img_size,
                   j * img_size:(j + 1) * img_size] = digit
    imwrite(path, figure)


def encode(path, n=4):
    """Randomly pick some images, then encode and reconstruct them
    """
    x_samples = [imread(f) for f in np.random.choice(imgs, n**2)]
    z_samples = np.repeat(codebook[T][:1], n**2, axis=0)
    for t in tqdm(range(T), ncols=0):
        t = T - t - 1
        bt = np.array([[t]] * z_samples.shape[0])
        mp = model.predict([z_samples, bt])
        x0 = (z_samples - bar_beta[t] * mp) / bar_alpha[t]
        sims = np.einsum('kuwv,buwv->kb', codebook[t], x_samples - x0)
        idxs = sims.argmax(0)
        z_samples -= beta[t]**2 / bar_beta[t] * mp
        z_samples /= alpha[t]
        z_samples += codebook[t][idxs] * sigma[t]
    z_samples = np.clip(z_samples, -1, 1)
    figure = np.zeros((img_size * n, img_size * n * 2, 3))
    for i in range(n):
        for j in range(n):
            digit = x_samples[i * n + j]
            figure[i * img_size:(i + 1) * img_size,
                   2 * j * img_size:(2 * j + 1) * img_size] = digit
            digit = z_samples[i * n + j]
            figure[i * img_size:(i + 1) * img_size,
                   (2 * j + 1) * img_size:(2 * j + 2) * img_size] = digit
    imwrite(path, figure)


sample(f'test1.png')
encode(f'test2.png')
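
# Back-of-the-envelope note (illustrative, not used by the script): in encode()
# each image is represented by one codebook index per denoising step, i.e. T
# indices of log2(K) bits each (sample() additionally draws the initial code from
# codebook[T]). With the defaults above that is a fixed-length code of about
# T * log2(K) = 1000 * 6 = 6000 bits, roughly 750 bytes per image.
code_bits = T * np.log2(K)  # ~6000 bits with T=1000, K=64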
--------------------------------------------------------------------------------
/ddim.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-
# Reference code for DDIM (generative diffusion models)
# DDIM keeps training unchanged and only modifies the sampling procedure
# Blog post: https://kexue.fm/archives/9181

# from ddpm import *  # load the trained model
from ddpm2 import *  # load the trained model


def sample(path=None, n=4, z_samples=None, stride=1, eta=1):
    """Random sampling function
    Note: eta controls the relative size of the variance; stride controls the step skipping
    """
    # Sampling parameters
    bar_alpha_ = bar_alpha[::stride]
    bar_alpha_pre_ = np.pad(bar_alpha_[:-1], [1, 0], constant_values=1)
    bar_beta_ = np.sqrt(1 - bar_alpha_**2)
    bar_beta_pre_ = np.sqrt(1 - bar_alpha_pre_**2)
    alpha_ = bar_alpha_ / bar_alpha_pre_
    sigma_ = bar_beta_pre_ / bar_beta_ * np.sqrt(1 - alpha_**2) * eta
    epsilon_ = bar_beta_ - alpha_ * np.sqrt(bar_beta_pre_**2 - sigma_**2)
    T_ = len(bar_alpha_)
    # Sampling loop
    if z_samples is None:
        z_samples = np.random.randn(n**2, img_size, img_size, 3)
    else:
        z_samples = z_samples.copy()
    for t in tqdm(range(T_), ncols=0):
        t = T_ - t - 1
        bt = np.array([[t * stride]] * z_samples.shape[0])
        z_samples -= epsilon_[t] * model.predict([z_samples, bt])
        z_samples /= alpha_[t]
        z_samples += np.random.randn(*z_samples.shape) * sigma_[t]
    x_samples = np.clip(z_samples, -1, 1)
    if path is None:
        return x_samples
    figure = np.zeros((img_size * n, img_size * n, 3))
    for i in range(n):
        for j in range(n):
            digit = x_samples[i * n + j]
            figure[i * img_size:(i + 1) * img_size,
                   j * img_size:(j + 1) * img_size] = digit
    imwrite(path, figure)


def sample_inter(path, n=4, k=8, stride=1):
    """Random interpolation sampling function
    Note: picks two random latent vectors, interpolates them uniformly on the sphere, and generates the corresponding results.
    """
    figure = np.ones((img_size * n, img_size * k, 3))
    Z = np.random.randn(n * 2, img_size, img_size, 3)
    z_samples = []
    for i in range(n):
        for j in range(k):
            theta = np.pi / 2 * j / (k - 1)
            z = Z[2 * i] * np.sin(theta) + Z[2 * i + 1] * np.cos(theta)
            z_samples.append(z)
    x_samples = sample(z_samples=np.array(z_samples), stride=stride, eta=0)
    for i in range(n):
        for j in range(k):
            ij = i * k + j
            figure[i * img_size:(i + 1) * img_size,
                   img_size * j:img_size * (j + 1)] = x_samples[ij]
    imwrite(path, figure)


sample('test.png', n=4, stride=100, eta=0)
sample_inter('test_inter.png', n=8, k=15, stride=20)
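
# Illustrative check (defined but not called): with eta=0 all the sigma_ terms
# above are zero, so DDIM becomes a deterministic map from z_samples to images,
# and with stride=100 the loop makes only T // stride = 10 calls to the denoiser.
def check_ddim_is_deterministic(stride=100):
    z = np.random.randn(4, img_size, img_size, 3)
    a = sample(z_samples=z, stride=stride, eta=0)
    b = sample(z_samples=z, stride=stride, eta=0)
    assert np.allclose(a, b)  # same latents give the same images when eta=0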
--------------------------------------------------------------------------------
/ddpm-gau.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-
# Reference code for DDPM (generative diffusion models)
# Uses a Pre-Norm GAU architecture
# Environment: tf 1.15 + keras 2.3.1 + bert4keras
# Reference: https://kexue.fm/archives/9984

import os
import cv2
import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.models import Model
from keras.layers import *
from keras.callbacks import Callback
from keras_preprocessing.image import list_pictures
from bert4keras.layers import *
from bert4keras.optimizers import Adam
from bert4keras.optimizers import extend_with_layer_adaptation
from bert4keras.optimizers import extend_with_piecewise_linear_lr
from bert4keras.optimizers import extend_with_exponential_moving_average
from tqdm import tqdm
import warnings

warnings.filterwarnings("ignore")  # silence the flood of keras warnings

if not os.path.exists('samples'):
    os.mkdir('samples')

# Basic configuration
imgs = list_pictures('/mnt/vepfs/sujianlin/CelebA-HQ/train/', 'png')
imgs += list_pictures('/mnt/vepfs/sujianlin/CelebA-HQ/valid/', 'png')
np.random.shuffle(imgs)
img_size = 128  # change to 64 for a quick experiment
batch_size = 64  # reduce to 32 or 16 if GPU memory is tight, but going below 16 is not recommended
hidden_size = 768
num_layers = 24

# Hyperparameter choices
T = 1000
alpha = np.sqrt(1 - 0.02 * np.arange(1, T + 1) / T)
beta = np.sqrt(1 - alpha**2)
bar_alpha = np.cumprod(alpha)
bar_beta = np.sqrt(1 - bar_alpha**2)
sigma = beta.copy()
# sigma *= np.pad(bar_beta[:-1], [1, 0]) / bar_beta


def imread(f, crop_size=None):
    """Read an image
    """
    x = cv2.imread(f)
    height, width = x.shape[:2]
    if crop_size is None:
        crop_size = min([height, width])
    else:
        crop_size = min([crop_size, height, width])
    height_x = (height - crop_size + 1) // 2
    width_x = (width - crop_size + 1) // 2
    x = x[height_x:height_x + crop_size, width_x:width_x + crop_size]
    if x.shape[:2] != (img_size, img_size):
        x = cv2.resize(x, (img_size, img_size))
    x = x.astype('float32')
    x = x / 255 * 2 - 1
    return x


def imwrite(path, figure):
    """Save an image array that has been normalized to [-1, 1] as an image file
    """
    figure = (figure + 1) / 2 * 255
    figure = np.round(figure, 0).astype('uint8')
    cv2.imwrite(path, figure)


def data_generator():
    """Image data generator
    """
    batch_imgs = []
    while True:
        for i in np.random.permutation(len(imgs)):
            batch_imgs.append(imread(imgs[i]))
            if len(batch_imgs) == batch_size:
                batch_imgs = np.array(batch_imgs)
                batch_steps = np.random.choice(T, batch_size)
                batch_bar_alpha = bar_alpha[batch_steps][:, None, None, None]
                batch_bar_beta = bar_beta[batch_steps][:, None, None, None]
                batch_noise = np.random.randn(*batch_imgs.shape)
                batch_noisy_imgs = batch_imgs * batch_bar_alpha + batch_noise * batch_bar_beta
                yield [batch_noisy_imgs, batch_steps[:, None]], batch_noise
                batch_imgs = []


def rope_2d(x):
    """2D-RoPE
    """
    w = img_size // 8
    pos = K.arange(0, w**2, dtype='float32')
    pos1, pos2 = pos // w, pos % w
    pos1 = sinusoidal_embeddings(pos1, 64, 1000)
    pos2 = sinusoidal_embeddings(pos2, 64, 1000)
    return K.concatenate([pos1, pos2], 1)[None]


def l2_loss(y_true, y_pred):
    """Use the summed l2 distance as the loss; plain mse is not a valid substitute
    """
    return K.sum((y_true - y_pred)**2, axis=[1, 2, 3])
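
# Numpy sanity check (illustrative only) for the patch reshaping used below: the
# image is cut into 8x8 patches (8 * 8 * 3 = 192 values each), flattened into a
# sequence of img_size**2 // 64 tokens for the GAU stack, and the output head
# inverts the operation. The round trip is lossless.
def check_patchify_roundtrip(batch=2):
    x = np.random.randn(batch, img_size, img_size, 3)
    p = x.reshape(batch, img_size // 8, 8, img_size // 8, 8, 3)
    p = p.transpose(0, 1, 3, 2, 4, 5)                  # Permute((1, 3, 2, 4, 5))
    tokens = p.reshape(batch, img_size**2 // 64, 192)  # sequence of patch tokens
    q = tokens.reshape(batch, img_size // 8, img_size // 8, 8, 8, 3)
    q = q.transpose(0, 1, 3, 2, 4, 5)
    assert np.allclose(x, q.reshape(batch, img_size, img_size, 3))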

# Build the denoising model
x_in = x = Input(shape=(img_size, img_size, 3))
x = Reshape((img_size // 8, 8, img_size // 8, 8, 3))(x)
x = Permute((1, 3, 2, 4, 5))(x)
x = Reshape((img_size**2 // 64, 192))(x)
x = Dense(hidden_size, use_bias=False)(x)

t_in = Input(shape=(1,))
t = Embedding(input_dim=T, output_dim=hidden_size)(t_in)

x = Add()([x, t])
p = Lambda(rope_2d)(x)

for i in range(num_layers):
    xi = x
    x = LayerNormalization(zero_mean=False, offset=False)(x)
    x = GatedAttentionUnit(hidden_size * 2, 128, normalization='softmax')([x, p], p_bias='rotary')
    x = Add()([xi, x])

x = LayerNormalization(zero_mean=False, offset=False)(x)
x = Dense(192, use_bias=False)(x)
x = Reshape((img_size // 8, img_size // 8, 8, 8, 3))(x)
x = Permute((1, 3, 2, 4, 5))(x)
x = Reshape((img_size, img_size, 3))(x)

model = Model(inputs=[x_in, t_in], outputs=x)
model.summary()
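
# Scale note (illustrative): with img_size = 128 the GAU stack above operates on a
# sequence of (img_size // 8) ** 2 = 256 patch tokens, each projected from the 192
# raw pixel values of an 8x8 patch up to hidden_size = 768, through num_layers = 24
# pre-norm GAU blocks.
num_tokens = img_size**2 // 64  # 256 when img_size == 128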

OPT = extend_with_layer_adaptation(Adam)
OPT = extend_with_piecewise_linear_lr(OPT)  # this now amounts to the LAMB optimizer
OPT = extend_with_exponential_moving_average(OPT)  # add an exponential moving average of the weights
optimizer = OPT(
    learning_rate=1e-3,
    ema_momentum=0.9999,
    exclude_from_layer_adaptation=['Norm', 'bias'],
    lr_schedule={
        4000: 1,  # warmup steps
        20000: 0.5,
        40000: 0.1,
    }
)
model.compile(loss=l2_loss, optimizer=optimizer)


def sample(path=None, n=4, z_samples=None, t0=0):
    """Random sampling function
    """
    if z_samples is None:
        z_samples = np.random.randn(n**2, img_size, img_size, 3)
    else:
        z_samples = z_samples.copy()
    for t in tqdm(range(t0, T), ncols=0):
        t = T - t - 1
        bt = np.array([[t]] * z_samples.shape[0])
        z_samples -= beta[t]**2 / bar_beta[t] * model.predict([z_samples, bt])
        z_samples /= alpha[t]
        z_samples += np.random.randn(*z_samples.shape) * sigma[t]
    x_samples = np.clip(z_samples, -1, 1)
    if path is None:
        return x_samples
    figure = np.zeros((img_size * n, img_size * n, 3))
    for i in range(n):
        for j in range(n):
            digit = x_samples[i * n + j]
            figure[i * img_size:(i + 1) * img_size,
                   j * img_size:(j + 1) * img_size] = digit
    imwrite(path, figure)


def sample_inter(path, n=4, k=8, sep=10, t0=500):
    """Random interpolation sampling function
    """
    figure = np.ones((img_size * n, img_size * (k + 2) + sep * 2, 3))
    x_samples = [imread(f) for f in np.random.choice(imgs, n * 2)]
    X = []
    for i in range(n):
        figure[i * img_size:(i + 1) * img_size, :img_size] = x_samples[2 * i]
        figure[i * img_size:(i + 1) * img_size,
               -img_size:] = x_samples[2 * i + 1]
        for j in range(k):
            lamb = 1. * j / (k - 1)
            x = x_samples[2 * i] * (1 - lamb) + x_samples[2 * i + 1] * lamb
            X.append(x)
    x_samples = np.array(X) * bar_alpha[t0]
    x_samples += np.random.randn(*x_samples.shape) * bar_beta[t0]
    x_rec_samples = sample(z_samples=x_samples, t0=t0)
    for i in range(n):
        for j in range(k):
            ij = i * k + j
            figure[i * img_size:(i + 1) * img_size, img_size * (j + 1) +
                   sep:img_size * (j + 2) + sep] = x_rec_samples[ij]
    imwrite(path, figure)


class Trainer(Callback):
    """Training callback
    """
    def on_epoch_end(self, epoch, logs=None):
        model.save_weights('model.weights')
        sample('samples/%05d.png' % (epoch + 1))
        optimizer.apply_ema_weights()
        model.save_weights('model.ema.weights')
        sample('samples/%05d_ema.png' % (epoch + 1))
        optimizer.reset_old_weights()


if __name__ == '__main__':

    trainer = Trainer()
    model.fit(
        data_generator(),
        steps_per_epoch=2000,
        epochs=10000,  # just a sufficiently large preset number of epochs; interrupt with Ctrl+C at any time
        callbacks=[trainer]
    )

else:

    model.load_weights('model.ema.weights')

--------------------------------------------------------------------------------
/ddpm.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-
# Reference code for DDPM (generative diffusion models)
# The U-Net has been modified by the author to reduce the amount of computation
# Environment: tf 1.15 + keras 2.3.1 + bert4keras (the latest version from GitHub; the pip-installable release does not work)
# Blog post: https://kexue.fm/archives/9152

import os
import cv2
import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.models import Model
from keras.layers import *
from keras.callbacks import Callback
from keras.initializers import VarianceScaling
from keras_preprocessing.image import list_pictures
from bert4keras.layers import ScaleOffset
from bert4keras.optimizers import Adam
from bert4keras.optimizers import extend_with_layer_adaptation
from bert4keras.optimizers import extend_with_piecewise_linear_lr
from bert4keras.optimizers import extend_with_exponential_moving_average
from tqdm import tqdm
import warnings

warnings.filterwarnings("ignore")  # silence the flood of keras warnings

if not os.path.exists('samples'):
    os.mkdir('samples')

# Basic configuration
imgs = list_pictures('/root/CelebA-HQ/train/', 'png')
imgs += list_pictures('/root/CelebA-HQ/valid/', 'png')
np.random.shuffle(imgs)
img_size = 128  # change to 64 for a quick experiment
batch_size = 64  # reduce to 32 or 16 if GPU memory is tight, but going below 16 is not recommended
embedding_size = 128
channels = [1, 1, 2, 2, 4, 4]
num_layers = len(channels) * 2 + 1
blocks = 2  # reduce to 1 if GPU memory is tight
min_pixel = 4  # lowering this is not recommended; raise to 8 if memory allows

# Hyperparameter choices
T = 1000
alpha = np.sqrt(1 - 0.02 * np.arange(1, T + 1) / T)
beta = np.sqrt(1 - alpha**2)
bar_alpha = np.cumprod(alpha)
bar_beta = np.sqrt(1 - bar_alpha**2)
sigma = beta.copy()
# sigma *= np.pad(bar_beta[:-1], [1, 0]) / bar_beta
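
# Quick numpy check (illustrative): alpha and beta above are amplitudes rather
# than variances, so bar_alpha**2 + bar_beta**2 == 1 at every step, the forward
# process used in data_generator is x_t = bar_alpha[t] * x_0 + bar_beta[t] * eps,
# and bar_alpha[-1] is tiny (on the order of 0.006), i.e. x_T is essentially pure noise.
assert np.allclose(bar_alpha**2 + bar_beta**2, 1)
assert bar_alpha[-1] < 0.01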

def imread(f, crop_size=None):
    """Read an image
    """
    x = cv2.imread(f)
    height, width = x.shape[:2]
    if crop_size is None:
        crop_size = min([height, width])
    else:
        crop_size = min([crop_size, height, width])
    height_x = (height - crop_size + 1) // 2
    width_x = (width - crop_size + 1) // 2
    x = x[height_x:height_x + crop_size, width_x:width_x + crop_size]
    if x.shape[:2] != (img_size, img_size):
        x = cv2.resize(x, (img_size, img_size))
    x = x.astype('float32')
    x = x / 255 * 2 - 1
    return x


def imwrite(path, figure):
    """Save an image array that has been normalized to [-1, 1] as an image file
    """
    figure = (figure + 1) / 2 * 255
    figure = np.round(figure, 0).astype('uint8')
    cv2.imwrite(path, figure)


def data_generator():
    """Image data generator
    """
    batch_imgs = []
    while True:
        for i in np.random.permutation(len(imgs)):
            batch_imgs.append(imread(imgs[i]))
            if len(batch_imgs) == batch_size:
                batch_imgs = np.array(batch_imgs)
                batch_steps = np.random.choice(T, batch_size)
                batch_bar_alpha = bar_alpha[batch_steps][:, None, None, None]
                batch_bar_beta = bar_beta[batch_steps][:, None, None, None]
                batch_noise = np.random.randn(*batch_imgs.shape)
                batch_noisy_imgs = batch_imgs * batch_bar_alpha + batch_noise * batch_bar_beta
                yield [batch_noisy_imgs, batch_steps[:, None]], batch_noise
                batch_imgs = []


class GroupNorm(ScaleOffset):
    """GroupNorm with groups=32 by default
    """
    def call(self, inputs):
        inputs = K.reshape(inputs, (-1, 32), -1)
        mean, variance = tf.nn.moments(inputs, axes=[1, 2, 3], keepdims=True)
        inputs = (inputs - mean) * tf.rsqrt(variance + 1e-6)
        inputs = K.flatten(inputs, -2)
        return super(GroupNorm, self).call(inputs)


def dense(x, out_dim, activation=None, init_scale=1):
    """Wrapper around Dense
    """
    init_scale = max(init_scale, 1e-10)
    initializer = VarianceScaling(init_scale, 'fan_avg', 'uniform')
    return Dense(
        out_dim,
        activation=activation,
        use_bias=False,
        kernel_initializer=initializer
    )(x)


def conv2d(x, out_dim, activation=None, init_scale=1):
    """Wrapper around Conv2D
    """
    init_scale = max(init_scale, 1e-10)
    initializer = VarianceScaling(init_scale, 'fan_avg', 'uniform')
    return Conv2D(
        out_dim, (3, 3),
        padding='same',
        activation=activation,
        use_bias=False,
        kernel_initializer=initializer
    )(x)


def residual_block(x, ch, t):
    """Residual block
    """
    in_dim = K.int_shape(x)[-1]
    out_dim = ch * embedding_size
    if in_dim == out_dim:
        xi = x
    else:
        xi = dense(x, out_dim)
    x = Add()([x, dense(t, K.int_shape(x)[-1])])
    x = conv2d(x, out_dim, 'swish', 1 / num_layers**0.5)
    x = conv2d(x, out_dim, 'swish', 1 / num_layers**0.5)
    x = Add()([x, xi])
    x = GroupNorm()(x)
    return x


def l2_loss(y_true, y_pred):
    """Use the summed l2 distance as the loss; plain mse is not a valid substitute
    """
    return K.sum((y_true - y_pred)**2, axis=[1, 2, 3])
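
# Illustrative comparison (not used in training): the loss above sums squared
# errors over all pixels, so the resulting loss value is img_size * img_size * 3
# times larger (49152 for 128x128 RGB) than keras' built-in 'mse', which averages
# instead of summing. The learning rate and LAMB settings are presumably tuned for
# this scale, which is why the README warns that plain mse is not a drop-in
# replacement.
def check_l2_vs_mse(shape=(2, 8, 8, 3)):
    y_true, y_pred = np.random.randn(*shape), np.random.randn(*shape)
    l2 = ((y_true - y_pred)**2).sum(axis=(1, 2, 3))
    mse = ((y_true - y_pred)**2).mean(axis=(1, 2, 3))
    assert np.allclose(l2, mse * np.prod(shape[1:]))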

# Build the denoising model
x_in = x = Input(shape=(img_size, img_size, 3))
t_in = Input(shape=(1,))
t = Embedding(input_dim=T, output_dim=embedding_size)(t_in)
t = Lambda(lambda t: t[:, None])(t)

x = conv2d(x, embedding_size)
inputs, skip_pooling = [x], 0

for i, ch in enumerate(channels):
    for j in range(blocks):
        x = residual_block(x, ch, t)
        inputs.append(x)
    if min(K.int_shape(x)[1:3]) > min_pixel:
        x = AveragePooling2D((2, 2))(x)
        inputs.append(x)
    else:
        skip_pooling += 1

x = residual_block(x, ch, t)
inputs.pop()

for i, ch in enumerate(channels[::-1]):
    if i >= skip_pooling:
        x = UpSampling2D((2, 2))(x)
        x = Add()([x, inputs.pop()])
    for j in range(blocks):
        xi = inputs.pop()
        x = residual_block(x, K.int_shape(xi)[-1] // embedding_size, t)
        x = Add()([x, xi])

x = GroupNorm()(x)
x = conv2d(x, 3)

model = Model(inputs=[x_in, t_in], outputs=x)
model.summary()

OPT = extend_with_layer_adaptation(Adam)
OPT = extend_with_piecewise_linear_lr(OPT)  # this now amounts to the LAMB optimizer
OPT = extend_with_exponential_moving_average(OPT)  # add an exponential moving average of the weights
optimizer = OPT(
    learning_rate=1e-3,
    ema_momentum=0.9999,
    exclude_from_layer_adaptation=['Norm', 'bias'],
    lr_schedule={
        4000: 1,  # warmup steps
        20000: 0.5,
        40000: 0.1,
    }
)
model.compile(loss=l2_loss, optimizer=optimizer)


def sample(path=None, n=4, z_samples=None, t0=0):
    """Random sampling function
    """
    if z_samples is None:
        z_samples = np.random.randn(n**2, img_size, img_size, 3)
    else:
        z_samples = z_samples.copy()
    for t in tqdm(range(t0, T), ncols=0):
        t = T - t - 1
        bt = np.array([[t]] * z_samples.shape[0])
        z_samples -= beta[t]**2 / bar_beta[t] * model.predict([z_samples, bt])
        z_samples /= alpha[t]
        z_samples += np.random.randn(*z_samples.shape) * sigma[t]
    x_samples = np.clip(z_samples, -1, 1)
    if path is None:
        return x_samples
    figure = np.zeros((img_size * n, img_size * n, 3))
    for i in range(n):
        for j in range(n):
            digit = x_samples[i * n + j]
            figure[i * img_size:(i + 1) * img_size,
                   j * img_size:(j + 1) * img_size] = digit
    imwrite(path, figure)


def sample_inter(path, n=4, k=8, sep=10, t0=500):
    """Random interpolation sampling function
    """
    figure = np.ones((img_size * n, img_size * (k + 2) + sep * 2, 3))
    x_samples = [imread(f) for f in np.random.choice(imgs, n * 2)]
    X = []
    for i in range(n):
        figure[i * img_size:(i + 1) * img_size, :img_size] = x_samples[2 * i]
        figure[i * img_size:(i + 1) * img_size,
               -img_size:] = x_samples[2 * i + 1]
        for j in range(k):
            lamb = 1. * j / (k - 1)
            x = x_samples[2 * i] * (1 - lamb) + x_samples[2 * i + 1] * lamb
            X.append(x)
    x_samples = np.array(X) * bar_alpha[t0]
    x_samples += np.random.randn(*x_samples.shape) * bar_beta[t0]
    x_rec_samples = sample(z_samples=x_samples, t0=t0)
    for i in range(n):
        for j in range(k):
            ij = i * k + j
            figure[i * img_size:(i + 1) * img_size, img_size * (j + 1) +
                   sep:img_size * (j + 2) + sep] = x_rec_samples[ij]
    imwrite(path, figure)


class Trainer(Callback):
    """Training callback
    """
    def on_epoch_end(self, epoch, logs=None):
        model.save_weights('model.weights')
        sample('samples/%05d.png' % (epoch + 1))
        optimizer.apply_ema_weights()
        model.save_weights('model.ema.weights')
        sample('samples/%05d_ema.png' % (epoch + 1))
        optimizer.reset_old_weights()


if __name__ == '__main__':

    trainer = Trainer()
    model.fit(
        data_generator(),
        steps_per_epoch=2000,
        epochs=10000,  # just a sufficiently large preset number of epochs; interrupt with Ctrl+C at any time
        callbacks=[trainer]
    )

else:

    model.load_weights('model.ema.weights')
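
# Usage note (illustrative): running `python ddpm.py` takes the __main__ branch
# and trains, periodically writing samples/ plus model.weights / model.ema.weights;
# importing the module instead (as ddim.py, adpm.py and ddcm.py do with ddpm2)
# takes the `else` branch above, so the EMA weights are loaded and `model`,
# `sample` and the schedule arrays can be reused for inference.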

--------------------------------------------------------------------------------
/ddpm2.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-
# Reference code for DDPM (generative diffusion models), version 2
# This U-Net stays as close as possible to the original (except that no Attention is added); the results are relatively better, at the cost of more computation
# Environment: tf 1.15 + keras 2.3.1 + bert4keras (the latest version from GitHub; the pip-installable release does not work)
# Blog post: https://kexue.fm/archives/9152

import os
import cv2
import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.models import Model
from keras.layers import *
from keras.callbacks import Callback
from keras.initializers import VarianceScaling
from keras_preprocessing.image import list_pictures
from bert4keras.layers import ScaleOffset
from bert4keras.optimizers import Adam
from bert4keras.optimizers import extend_with_layer_adaptation
from bert4keras.optimizers import extend_with_piecewise_linear_lr
from bert4keras.optimizers import extend_with_exponential_moving_average
from tqdm import tqdm
import warnings

warnings.filterwarnings("ignore")  # silence the flood of keras warnings

if not os.path.exists('samples'):
    os.mkdir('samples')

# Basic configuration
imgs = list_pictures('/root/CelebA-HQ/train/', 'png')
imgs += list_pictures('/root/CelebA-HQ/valid/', 'png')
np.random.shuffle(imgs)
img_size = 128  # change to 64 for a quick experiment
batch_size = 32  # reduce to 16 if GPU memory is tight, but going below 16 is not recommended
embedding_size = 128
channels = [1, 1, 2, 2, 4, 4]
blocks = 2  # reduce to 1 if GPU memory is tight

# Hyperparameter choices
T = 1000
alpha = np.sqrt(1 - 0.02 * np.arange(1, T + 1) / T)
beta = np.sqrt(1 - alpha**2)
bar_alpha = np.cumprod(alpha)
bar_beta = np.sqrt(1 - bar_alpha**2)
sigma = beta.copy()
# sigma *= np.pad(bar_beta[:-1], [1, 0]) / bar_beta


def imread(f, crop_size=None):
    """Read an image
    """
    x = cv2.imread(f)
    height, width = x.shape[:2]
    if crop_size is None:
        crop_size = min([height, width])
    else:
        crop_size = min([crop_size, height, width])
    height_x = (height - crop_size + 1) // 2
    width_x = (width - crop_size + 1) // 2
    x = x[height_x:height_x + crop_size, width_x:width_x + crop_size]
    if x.shape[:2] != (img_size, img_size):
        x = cv2.resize(x, (img_size, img_size))
    x = x.astype('float32')
    x = x / 255 * 2 - 1
    return x


def imwrite(path, figure):
    """Save an image array that has been normalized to [-1, 1] as an image file
    """
    figure = (figure + 1) / 2 * 255
    figure = np.round(figure, 0).astype('uint8')
    cv2.imwrite(path, figure)


def data_generator():
    """Image data generator
    """
    batch_imgs = []
    while True:
        for i in np.random.permutation(len(imgs)):
            batch_imgs.append(imread(imgs[i]))
            if len(batch_imgs) == batch_size:
                batch_imgs = np.array(batch_imgs)
                batch_steps = np.random.choice(T, batch_size)
                batch_bar_alpha = bar_alpha[batch_steps][:, None, None, None]
                batch_bar_beta = bar_beta[batch_steps][:, None, None, None]
                batch_noise = np.random.randn(*batch_imgs.shape)
                batch_noisy_imgs = batch_imgs * batch_bar_alpha + batch_noise * batch_bar_beta
                yield [batch_noisy_imgs, batch_steps[:, None]], batch_noise
                batch_imgs = []


class GroupNorm(ScaleOffset):
    """GroupNorm with groups=32 by default
    """
    def call(self, inputs):
        inputs = K.reshape(inputs, (-1, 32), -1)
        mean, variance = tf.nn.moments(inputs, axes=[1, 2, 3], keepdims=True)
        inputs = (inputs - mean) * tf.rsqrt(variance + 1e-6)
        inputs = K.flatten(inputs, -2)
        return super(GroupNorm, self).call(inputs)
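
# Note (illustrative): this GroupNorm normalizes each sample independently using
# 32 channel groups, so its statistics do not depend on which diffusion steps t
# happen to share a batch; that per-sample behaviour is presumably why the README
# warns against BatchNorm, whose batch statistics would mix very different noise
# levels. Also note that K.reshape / K.flatten are called with extended signatures
# here (reshaping only the last axis), which stock keras.backend does not accept;
# this relies on the bert4keras environment the repo targets.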
107 | """Dense包装 108 | """ 109 | init_scale = max(init_scale, 1e-10) 110 | initializer = VarianceScaling(init_scale, 'fan_avg', 'uniform') 111 | return Dense( 112 | out_dim, 113 | activation=activation, 114 | use_bias=False, 115 | kernel_initializer=initializer 116 | )(x) 117 | 118 | 119 | def conv2d(x, out_dim, activation=None, init_scale=1): 120 | """Conv2D包装 121 | """ 122 | init_scale = max(init_scale, 1e-10) 123 | initializer = VarianceScaling(init_scale, 'fan_avg', 'uniform') 124 | return Conv2D( 125 | out_dim, (3, 3), 126 | padding='same', 127 | activation=activation, 128 | use_bias=False, 129 | kernel_initializer=initializer 130 | )(x) 131 | 132 | 133 | def residual_block(x, ch, t): 134 | """残差block 135 | """ 136 | in_dim = K.int_shape(x)[-1] 137 | out_dim = ch * embedding_size 138 | if in_dim == out_dim: 139 | xi = x 140 | else: 141 | xi = dense(x, out_dim) 142 | x = GroupNorm()(x) 143 | x = Activation('swish')(x) 144 | x = conv2d(x, out_dim) 145 | x = Add()([x, dense(t, K.int_shape(x)[-1])]) 146 | x = GroupNorm()(x) 147 | x = Activation('swish')(x) 148 | x = conv2d(x, out_dim, None, 0) 149 | x = Add()([x, xi]) 150 | return x 151 | 152 | 153 | def l2_loss(y_true, y_pred): 154 | """用l2距离为损失,不能用mse代替 155 | """ 156 | return K.sum((y_true - y_pred)**2, axis=[1, 2, 3]) 157 | 158 | 159 | # 搭建去噪模型 160 | x_in = x = Input(shape=(img_size, img_size, 3)) 161 | t_in = Input(shape=(1,)) 162 | t = Embedding( 163 | input_dim=T, 164 | output_dim=embedding_size, 165 | embeddings_initializer='Sinusoidal', 166 | trainable=False 167 | )(t_in) 168 | t = dense(t, embedding_size * 4, 'swish') 169 | t = dense(t, embedding_size * 4, 'swish') 170 | t = Lambda(lambda t: t[:, None])(t) 171 | 172 | x = conv2d(x, embedding_size) 173 | inputs = [x] 174 | 175 | for i, ch in enumerate(channels): 176 | for j in range(blocks): 177 | x = residual_block(x, ch, t) 178 | inputs.append(x) 179 | if i != len(channels) - 1: 180 | x = AveragePooling2D((2, 2))(x) 181 | inputs.append(x) 182 | 183 | x = residual_block(x, ch, t) 184 | 185 | for i, ch in enumerate(channels[::-1]): 186 | for j in range(blocks + 1): 187 | x = Concatenate()([x, inputs.pop()]) 188 | x = residual_block(x, ch, t) 189 | if i != len(channels) - 1: 190 | x = UpSampling2D((2, 2))(x) 191 | 192 | x = GroupNorm()(x) 193 | x = Activation('swish')(x) 194 | x = conv2d(x, 3) 195 | 196 | model = Model(inputs=[x_in, t_in], outputs=x) 197 | model.summary() 198 | 199 | OPT = extend_with_layer_adaptation(Adam) 200 | OPT = extend_with_piecewise_linear_lr(OPT) # 此时就是LAMB优化器 201 | OPT = extend_with_exponential_moving_average(OPT) # 加上滑动平均 202 | optimizer = OPT( 203 | learning_rate=1e-3, 204 | ema_momentum=0.9999, 205 | exclude_from_layer_adaptation=['Norm', 'bias'], 206 | lr_schedule={ 207 | 4000: 1, # Warmup步数 208 | 20000: 0.5, 209 | 40000: 0.1, 210 | } 211 | ) 212 | model.compile(loss=l2_loss, optimizer=optimizer) 213 | 214 | 215 | def sample(path=None, n=4, z_samples=None, t0=0): 216 | """随机采样函数 217 | """ 218 | if z_samples is None: 219 | z_samples = np.random.randn(n**2, img_size, img_size, 3) 220 | else: 221 | z_samples = z_samples.copy() 222 | for t in tqdm(range(t0, T), ncols=0): 223 | t = T - t - 1 224 | bt = np.array([[t]] * z_samples.shape[0]) 225 | z_samples -= beta[t]**2 / bar_beta[t] * model.predict([z_samples, bt]) 226 | z_samples /= alpha[t] 227 | z_samples += np.random.randn(*z_samples.shape) * sigma[t] 228 | x_samples = np.clip(z_samples, -1, 1) 229 | if path is None: 230 | return x_samples 231 | figure = np.zeros((img_size * n, img_size * n, 3)) 232 | 

def sample(path=None, n=4, z_samples=None, t0=0):
    """Random sampling function
    """
    if z_samples is None:
        z_samples = np.random.randn(n**2, img_size, img_size, 3)
    else:
        z_samples = z_samples.copy()
    for t in tqdm(range(t0, T), ncols=0):
        t = T - t - 1
        bt = np.array([[t]] * z_samples.shape[0])
        z_samples -= beta[t]**2 / bar_beta[t] * model.predict([z_samples, bt])
        z_samples /= alpha[t]
        z_samples += np.random.randn(*z_samples.shape) * sigma[t]
    x_samples = np.clip(z_samples, -1, 1)
    if path is None:
        return x_samples
    figure = np.zeros((img_size * n, img_size * n, 3))
    for i in range(n):
        for j in range(n):
            digit = x_samples[i * n + j]
            figure[i * img_size:(i + 1) * img_size,
                   j * img_size:(j + 1) * img_size] = digit
    imwrite(path, figure)


def sample_inter(path, n=4, k=8, sep=10, t0=500):
    """Random interpolation sampling function
    """
    figure = np.ones((img_size * n, img_size * (k + 2) + sep * 2, 3))
    x_samples = [imread(f) for f in np.random.choice(imgs, n * 2)]
    X = []
    for i in range(n):
        figure[i * img_size:(i + 1) * img_size, :img_size] = x_samples[2 * i]
        figure[i * img_size:(i + 1) * img_size,
               -img_size:] = x_samples[2 * i + 1]
        for j in range(k):
            lamb = 1. * j / (k - 1)
            x = x_samples[2 * i] * (1 - lamb) + x_samples[2 * i + 1] * lamb
            X.append(x)
    x_samples = np.array(X) * bar_alpha[t0]
    x_samples += np.random.randn(*x_samples.shape) * bar_beta[t0]
    x_rec_samples = sample(z_samples=x_samples, t0=t0)
    for i in range(n):
        for j in range(k):
            ij = i * k + j
            figure[i * img_size:(i + 1) * img_size, img_size * (j + 1) +
                   sep:img_size * (j + 2) + sep] = x_rec_samples[ij]
    imwrite(path, figure)


class Trainer(Callback):
    """Training callback
    """
    def on_epoch_end(self, epoch, logs=None):
        model.save_weights('model.weights')
        sample('samples/%05d.png' % (epoch + 1))
        optimizer.apply_ema_weights()
        model.save_weights('model.ema.weights')
        sample('samples/%05d_ema.png' % (epoch + 1))
        optimizer.reset_old_weights()


if __name__ == '__main__':

    trainer = Trainer()
    model.fit(
        data_generator(),
        steps_per_epoch=2000,
        epochs=10000,  # just a sufficiently large preset number of epochs; interrupt with Ctrl+C at any time
        callbacks=[trainer]
    )

else:

    model.load_weights('model.ema.weights')
--------------------------------------------------------------------------------