├── fig ├── backward.gif ├── backward.mp4 └── forward.gif ├── .gitignore ├── ema.py ├── README.md ├── infer.py ├── train.py └── dsb.py /fig/backward.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TongTong313/Diffusion-Schrodinger-Bridge/HEAD/fig/backward.gif -------------------------------------------------------------------------------- /fig/backward.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TongTong313/Diffusion-Schrodinger-Bridge/HEAD/fig/backward.mp4 -------------------------------------------------------------------------------- /fig/forward.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TongTong313/Diffusion-Schrodinger-Bridge/HEAD/fig/forward.gif -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.pt 3 | *.pth 4 | *.safetensors 5 | *.bin 6 | data/ 7 | **/__pycache__/ 8 | results/ -------------------------------------------------------------------------------- /ema.py: -------------------------------------------------------------------------------- 1 | class EMA(): 2 | 3 | def __init__(self, model, decay): 4 | self.model = model 5 | self.decay = decay 6 | self.shadow = {} 7 | self.backup = {} 8 | 9 | def register(self): 10 | for name, param in self.model.named_parameters(): 11 | if param.requires_grad: 12 | self.shadow[name] = param.data.clone() 13 | 14 | def update(self): 15 | for name, param in self.model.named_parameters(): 16 | if param.requires_grad: 17 | assert name in self.shadow 18 | new_average = (1.0 - self.decay 19 | ) * param.data + self.decay * self.shadow[name] 20 | self.shadow[name] = new_average.clone() 21 | 22 | def apply_shadow(self): 23 | for name, param in self.model.named_parameters(): 24 | if param.requires_grad: 25 | assert name in self.shadow 26 | self.backup[name] = param.data 27 | param.data = self.shadow[name] 28 | 29 | def restore(self): 30 | for name, param in self.model.named_parameters(): 31 | if param.requires_grad: 32 | assert name in self.backup 33 | param.data = self.backup[name] 34 | self.backup = {} 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 从零手搓扩散薛定谔桥(Diffusion Schrödinger Bridge) 2 | 3 | 作者:Tong Tong 4 | 5 | B站主页:[Double童发发](https://space.bilibili.com/323109608) 6 | 7 | 本套代码**未来会出讲解视频**,为了让大家更好的理解,强烈建议观看本人下面的几个B站视频: 8 | - [白话薛定谔桥](https://www.bilibili.com/video/BV1dsYieMEvj/) 9 | - [扩散模型随机微分方程(SDE)公式保姆级手推](https://www.bilibili.com/video/BV1y1YpejEB4/) 10 | - [你一定能听懂的扩散模型Flow Matching基本原理深度解析](https://www.bilibili.com/video/BV1Wv3xeNEds/) 11 | - [你一定能听懂的Recitified Flow基本原理深度解析](https://www.bilibili.com/video/BV19m421G7W8/) 12 | - [零门槛掌握DDPM](https://www.bilibili.com/video/BV1zz421i7UM/) 13 | 14 | **特别推荐看一下本人的[扩散模型之老司机开车理论视频](https://www.bilibili.com/video/BV1qW42197dv/),对你理解扩散模型有很大帮助~** 15 | 16 | **TODO**: 17 | - [ ] 加班加点准备代码讲解视频…… 18 | - [ ] 模型权重文件上传 19 | - [ ] 计划实现一些DSB的变种 20 | 21 | **一些bug修复说明**: 22 | - 暂无 23 | 24 | 25 | ## V1.0:Diffusion Schrödinger Bridge 26 | 27 | ### 说明 28 | 29 | * 代码基于人为生成的二维分布数据,一个为棋盘分布一个为爱心曲线分布。 30 | * 本项目完全**从零手搓**,尽可能不参考其他任何代码,从论文原理出发逐步实现,因此算是**极简实现**的一种,并**不能保证最优性能**,各位大佬可以逐步修改完善,欢迎交流。 31 | * 为了让大家都能上手,本代码只基于深度学习框架Pytorch和一些其他必要的库。该数据集随着训练代码生成,数据集维度与规模较小,也方便展示效果,最重要的是**即使是使用CPU都能训练**!!! 32 | * 模型结构自己手搓了一个MLP模型,大家可以根据自己的需求修改,也可以使用其他更复杂的模型。 33 | * 代码中有很多注释,希望能帮助大家理解代码,如果有问题欢迎留言交流。 34 | * 代码环境要求很低,甚至不需要GPU都可以 35 | * Python 3.8+ 36 | * Pytorch 2.0+ 37 | * Numpy 38 | * Matplotlib 39 | * 其他的就缺啥装啥 40 | * 代码运行方式 41 | * 训练:`python train.py` 42 | * 推理:`python infer.py` 43 | * 代码实现原理参考论文 44 | * Diffusion Schrödinger Bridge with Applications to Score-Based Generative Modeling 45 | * 代码结果展示 46 | * 前向过程:爱心分布 -> 棋盘分布 47 | ![result_forward](/fig/forward.gif) 48 | 49 | 50 | * 逆向过程:棋盘分布 -> 爱心分布 51 | ![result_backwrad](/fig/backward.gif) 52 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | # 推理代码,随机生成对应分布的1000个数据,然后进行推理,看看最后会到哪里 2 | 3 | from train import generate_2d_data, dsb, create_chessboard, sample_from_chessboard, sample_heart_shape 4 | from dsb import DSBModel, DSB 5 | from matplotlib import pyplot as plt 6 | from matplotlib.animation import FuncAnimation 7 | import torch 8 | import torch.nn as nn 9 | import os 10 | import numpy as np 11 | from typing import List, Tuple, Dict, Union, Optional 12 | from functools import partial 13 | from torch.distributions import Normal 14 | from moviepy.editor import VideoFileClip 15 | 16 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 17 | 18 | 19 | def infer( 20 | x_0, 21 | x_1, 22 | dsb: DSB, 23 | checkpoint_dir: Optional[str] = './checkpoints/dsb_model_final.pth', 24 | device: Optional[str] = 'cuda' 25 | ) -> Union[List[torch.Tensor], List[torch.Tensor]]: 26 | # 加载模型 27 | save_dict = torch.load(checkpoint_dir) 28 | 29 | dsb.to(device) 30 | dsb.eval() 31 | # 加载前向模型权重 32 | dsb.model_dict['f'].load_state_dict(save_dict['forward_model']) 33 | # 加载后向模型权重 34 | dsb.model_dict['b'].load_state_dict(save_dict['backward_model']) 35 | 36 | # 从x_1开始推理 37 | b_path = dsb.sde_sample(x_1, mode='b') 38 | # 从x_0开始推理 39 | f_path = dsb.sde_sample(x_0, mode='f') 40 | 41 | return f_path, b_path 42 | 43 | 44 | # 画图函数,用于画图 45 | def draw_plot(x_0, x_1, path, ax, step): 46 | """画图函数 47 | 48 | Args: 49 | x_0: 起始点,标准的 50 | x_1: 终止点,标准的 51 | path: 使用模型预测的结果 52 | """ 53 | ax.clear() 54 | ax.set_xlim(-5, 5) 55 | ax.set_ylim(-5, 5) 56 | ax.scatter(x_0[:, 0].cpu().numpy(), 57 | x_0[:, 1].cpu().numpy(), 58 | label=r'$\pi_0$', 59 | color='blue', 60 | alpha=0.15) 61 | ax.scatter(x_1[:, 0].cpu().numpy(), 62 | x_1[:, 1].cpu().numpy(), 63 | label=r'$\pi_1$', 64 | color='orange', 65 | alpha=0.15) 66 | ax.scatter(path[step][:, 0].cpu().numpy(), 67 | path[step][:, 1].cpu().numpy(), 68 | label='Generated', 69 | color='red', 70 | alpha=0.15) 71 | # legend固定左上角 72 | ax.legend(loc='upper left') 73 | ax.set_title(f'Distribution t={step}') 74 | 75 | 76 | if __name__ == '__main__': 77 | n_samples = 10000 78 | checkpoint_dir = './checkpoints/dsb_model_final.pth' 79 | 80 | init_dist = Normal(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0])) 81 | target_dist = create_chessboard() 82 | 83 | # 生成爱心样本点 84 | x_0 = sample_heart_shape(num_samples=n_samples, scale=0.2) 85 | x_0 = torch.tensor(x_0).float() 86 | 87 | # x_0 = init_dist.sample((n_samples, )) 88 | x_1 = sample_from_chessboard(target_dist, num_samples=n_samples) 89 | x_1 = torch.tensor(x_1).float() 90 | 91 | f_path, b_path = infer(x_0, x_1, dsb, checkpoint_dir=checkpoint_dir) 92 | 93 | figure, ax = plt.subplots(figsize=(8, 8)) 94 | 95 | animation_fun_b = partial(draw_plot, x_0, x_1, b_path, ax) 96 | animation_fun_f = partial(draw_plot, x_0, x_1, f_path, ax) 97 | 98 | animation = FuncAnimation(figure, 99 | func=animation_fun_b, 100 | frames=np.arange(0, len(b_path)), 101 | interval=200) 102 | # 保存这个动画 103 | animation.save('./fig/backward.gif', writer='imagemagick') 104 | 105 | animation = FuncAnimation(figure, 106 | func=animation_fun_f, 107 | frames=np.arange(0, len(f_path)), 108 | interval=200) 109 | animation.save('./fig/forward.gif', writer='imagemagick') 110 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """训练数据采用人造的两个高斯分布,看是否能实现两个分布之间的前向和后向变换 2 | """ 3 | import numpy as np 4 | import torch 5 | import os 6 | import torch.nn as nn 7 | from torch.distributions import Normal 8 | from torch.utils.data import DataLoader, Dataset, TensorDataset 9 | from ema import EMA 10 | from typing import Tuple, List, Dict, Union, Optional 11 | from dsb import DSBModel, DSB 12 | 13 | 14 | # 创建一个8x8的棋盘分布 15 | def create_chessboard(size=8): 16 | chessboard = np.zeros((size, size)) 17 | chessboard[1::2, ::2] = 1 18 | chessboard[::2, 1::2] = 1 19 | return chessboard 20 | 21 | 22 | # 从棋盘分布中采样样本 23 | def sample_from_chessboard(chessboard, num_samples=10): 24 | indices = np.argwhere(chessboard == 1) 25 | sampled_points = [] 26 | offset = chessboard.shape[0] // 2 27 | for _ in range(num_samples): 28 | idx = indices[np.random.choice(indices.shape[0])] 29 | x = np.random.uniform(low=idx[1], high=idx[1] + 1) - offset 30 | y = np.random.uniform(low=idx[0], high=idx[0] + 1) - offset 31 | sampled_points.append((x, y)) 32 | return sampled_points 33 | 34 | 35 | # 高斯分布生成2D数据 36 | def generate_2d_data(n_samples: int, 37 | mu1: Optional[List[float]] = [-3.0, -3.0], 38 | mu2: Optional[List[float]] = [3.0, 3.0], 39 | sigma1: Optional[List[float]] = [1.0, 1.0], 40 | sigma2: Optional[List[float]] = [1.0, 1.0]): 41 | """生成2D数据""" 42 | init_dist = Normal(torch.tensor(mu1), torch.tensor(sigma1)) 43 | target_dist = Normal(torch.tensor(mu2), torch.tensor(sigma2)) 44 | x_0 = init_dist.sample((n_samples, )) 45 | x_1 = target_dist.sample((n_samples, )) 46 | return x_0, x_1 47 | 48 | 49 | def sample_heart_shape(num_samples=1000, noise=0.1, scale=1.0): 50 | """ 51 | 生成爱心形状的样本点。 52 | 53 | 参数: 54 | num_samples (int): 样本数量。 55 | noise (float): 噪声强度。 56 | scale (float): 缩放因子。 57 | 58 | 返回: 59 | np.ndarray: 生成的样本点。 60 | """ 61 | t = np.linspace(0, 2 * np.pi, num_samples) 62 | x = 16 * np.sin(t)**3 63 | y = 13 * np.cos(t) - 5 * np.cos(2 * t) - 2 * np.cos(3 * t) - np.cos(4 * t) 64 | 65 | # 添加噪声 66 | x += np.random.normal(scale=noise, size=num_samples) 67 | y += np.random.normal(scale=noise, size=num_samples) 68 | 69 | # 缩放 70 | x *= scale 71 | y *= scale 72 | 73 | return np.vstack((x, y)).T 74 | 75 | 76 | def train(x_0, 77 | x_1, 78 | dsb: DSB, 79 | n_epochs: Optional[int] = 50, 80 | steps_per_epoch: Optional[int] = 1000, 81 | batch_size: Optional[int] = 32, 82 | lr: Optional[float] = 1e-4, 83 | checkpoint_dir: Optional[str] = './checkpoints', 84 | checkpoint_save_interval: Optional[int] = 10, 85 | use_ema: Optional[bool] = True, 86 | device: Optional[str] = "cuda"): 87 | """训练模型 88 | 89 | Args: 90 | x_0: 训练数据x_0,shape=(n_samples, dim) 91 | x_1: 训练数据x_1,shape=(n_samples, dim) 92 | dsb: DSB类,用于训练 93 | n_epochs: 训练轮数,这个是大轮数,默认50 94 | steps_per_epoch: 每轮每个模型训练的步数,默认1000 95 | batch_size: 批大小,默认32 96 | lr: 学习率,默认1e-4 97 | checkpoint_dir: 检查点保存路径,默认./checkpoints 98 | checkpoint_save_interval: 检查点保存间隔,默认10 99 | device: 训练设备,默认cuda 100 | """ 101 | # 优化器 102 | # optimizer = torch.optim.Adam(dsb.parameters(), lr=lr) 103 | # 记录损失函数,分为前向和逆向模型两个,每个Step打印一个 104 | losses = {'f': [], 'b': []} 105 | # 模型放到cuda上 106 | dsb.to(device) 107 | 108 | # 模型使用EMA 109 | if use_ema: 110 | ema_model_forward = EMA(dsb.model_dict['f'], decay=0.999) 111 | ema_model_backward = EMA(dsb.model_dict['b'], decay=0.999) 112 | ema_model_forward.register() 113 | ema_model_backward.register() 114 | 115 | # 训练首先就得弄出来一个数据集,根据x_0和x_1 116 | for epoch in range(n_epochs): 117 | if epoch == 0: # 第一轮,我还没有任何模型 118 | first_it = True 119 | 120 | # 每一轮训练是先b再f 121 | for m in ['b', 'f']: 122 | optimizer = torch.optim.Adam(dsb.model_dict[m].parameters(), lr=lr) 123 | if use_ema and epoch > 0: 124 | ema_model = ema_model_forward if m == 'b' else ema_model_backward 125 | else: 126 | ema_model = None 127 | 128 | x_t, target, t_list = dsb.generate_path_and_target( 129 | x_0, x_1, m, first_it, ema_model=ema_model) 130 | # 把这个构成一个pytorch的dataset 131 | dataset = TensorDataset(x_t, target, t_list) 132 | # dataloader 133 | dataloader = DataLoader(dataset, 134 | batch_size=batch_size, 135 | shuffle=True, 136 | drop_last=True) 137 | dl = iter(dataloader) 138 | 139 | for step in range(steps_per_epoch): 140 | # 注意,这里的step可能会超过数据集的长度,所以一旦报错,就得重新读取一次dataloader 141 | # 取一个batch的数据 142 | try: 143 | batch_data = next(dl) 144 | except StopIteration: 145 | x_t, target, t_list = dsb.generate_path_and_target( 146 | x_0, x_1, m, first_it, ema_model=ema_model) 147 | dl = iter( 148 | DataLoader(TensorDataset(x_t, target, t_list), 149 | batch_size=batch_size, 150 | shuffle=True, 151 | drop_last=True)) 152 | batch_data = next(dl) 153 | 154 | # 取一个batch的数据 155 | b_x_t, b_target, b_t_list = batch_data 156 | # 训练 转GPU 157 | b_x_t = b_x_t.to(device) 158 | b_target = b_target.to(device) 159 | b_t_list = b_t_list.to(device) 160 | 161 | optimizer.zero_grad() 162 | # 对应模型预测 163 | pred = dsb.model_dict[m](b_x_t, b_t_list) 164 | # 计算损失函数 165 | loss = dsb.mse_loss(pred, b_target) 166 | # 反向传播 167 | loss.backward() 168 | # 更新参数 169 | optimizer.step() 170 | # 更新EMA 171 | if use_ema and epoch > 0: 172 | if m == 'f': 173 | ema_model_forward.update() 174 | elif m == 'b': 175 | ema_model_backward.update() 176 | 177 | # 记录损失 178 | losses[m].append(loss.item()) 179 | # # 每1000步打印一次 180 | if step % 1000 == 0 or step == steps_per_epoch - 1: 181 | print( 182 | f"Epoch {epoch}, Step {step}, Mode {m}, Loss {loss.item()}" 183 | ) 184 | # 第一轮模型训练完了,后面就有模型了 185 | first_it = False 186 | 187 | # 每一轮打印一次 188 | print( 189 | f"Epoch {epoch}, Loss f {np.mean(losses['f'][-steps_per_epoch:])}, Loss b {np.mean(losses['b'][-steps_per_epoch:])}" 190 | ) 191 | 192 | # 每若干轮轮保存一次模型 193 | if epoch % checkpoint_save_interval == 0: 194 | print(f"Save model at epoch {epoch}") 195 | if use_ema: 196 | ema_model_backward.apply_shadow() 197 | ema_model_forward.apply_shadow() 198 | save_dict = { 199 | 'forward_model': ema_model_forward.model.state_dict(), 200 | 'backward_model': ema_model_backward.model.state_dict(), 201 | 'optimizer': optimizer.state_dict(), 202 | 'losses': losses 203 | } 204 | ema_model_backward.restore() 205 | ema_model_forward.restore() 206 | else: 207 | save_dict = { 208 | 'forward_model': dsb.model_dict['f'].state_dict(), 209 | 'backward_model': dsb.model_dict['b'].state_dict(), 210 | 'optimizer': optimizer.state_dict(), 211 | 'losses': losses 212 | } 213 | 214 | torch.save(save_dict, 215 | os.path.join(checkpoint_dir, f'dsb_model_{epoch}.pth')) 216 | 217 | # 模型较小,就保存一个torch模型 218 | if use_ema: 219 | ema_model_backward.apply_shadow() 220 | ema_model_forward.apply_shadow() 221 | save_dict = { 222 | 'forward_model': ema_model_forward.model.state_dict(), 223 | 'backward_model': ema_model_backward.model.state_dict(), 224 | 'optimizer': optimizer.state_dict(), 225 | 'losses': losses 226 | } 227 | else: 228 | save_dict = { 229 | 'forward_model': dsb.model_dict['f'].state_dict(), 230 | 'backward_model': dsb.model_dict['b'].state_dict(), 231 | 'optimizer': optimizer.state_dict(), 232 | 'losses': losses 233 | } 234 | 235 | torch.save(save_dict, os.path.join(checkpoint_dir, 'dsb_model_final.pth')) 236 | 237 | 238 | dim = 2 239 | hidden_dim = 256 240 | num_layers = 6 241 | activation = nn.ReLU(True) 242 | 243 | forward_model = DSBModel(input_dim=dim + 1, 244 | hidden_dim=hidden_dim, 245 | output_dim=dim, 246 | num_layers=num_layers, 247 | activation=activation) 248 | backward_model = DSBModel(input_dim=dim + 1, 249 | hidden_dim=hidden_dim, 250 | output_dim=dim, 251 | num_layers=num_layers, 252 | activation=activation) 253 | 254 | dsb = DSB(forward_model, 255 | backward_model, 256 | gamma_max=1.0, 257 | gamma_min=0.02, 258 | device='cuda', 259 | num_steps=20) 260 | 261 | if __name__ == "__main__": 262 | n_samples = 100000 263 | 264 | # init_dist = Normal(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0])) 265 | init_dist = sample_heart_shape(num_samples=n_samples, scale=0.2) 266 | target_dist = create_chessboard() 267 | 268 | x_0 = torch.tensor(init_dist).float() 269 | x_1 = sample_from_chessboard(target_dist, num_samples=n_samples) 270 | x_1 = torch.tensor(x_1).float() 271 | 272 | # steps_per_epoch一定要大,每一轮训练不充分直接影响下一轮! 273 | train(x_0, 274 | x_1, 275 | dsb, 276 | n_epochs=50, 277 | steps_per_epoch=10000, 278 | batch_size=256, 279 | lr=1e-4, 280 | checkpoint_dir='./checkpoints', 281 | checkpoint_save_interval=1, 282 | use_ema=False) 283 | -------------------------------------------------------------------------------- /dsb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from typing import List, Tuple, Dict, Union, Optional 6 | 7 | 8 | class MLP(nn.Module): 9 | """MLP模型 10 | 11 | Args: 12 | input_dim: 输入维度 13 | hidden_dim: 隐藏层维度 14 | output_dim: 输出维度 15 | num_layers: MLP层数 16 | activation: 激活函数 17 | """ 18 | 19 | def __init__(self, 20 | input_dim: int, 21 | hidden_dim: int, 22 | output_dim: int, 23 | num_layers: int, 24 | activation=nn.ReLU()): 25 | super(MLP, self).__init__() 26 | self.input_dim = input_dim 27 | self.hidden_dim = hidden_dim 28 | self.output_dim = output_dim 29 | self.num_layers = num_layers 30 | 31 | self.layers = nn.ModuleList() 32 | self.layers.append(nn.Linear(input_dim, hidden_dim)) 33 | 34 | if num_layers > 2: 35 | for _ in range(num_layers - 2): 36 | self.layers.append(nn.Linear(hidden_dim, hidden_dim)) 37 | 38 | self.layers.append(nn.Linear(hidden_dim, output_dim)) 39 | 40 | self.act = activation 41 | 42 | def forward(self, x): 43 | for i in range(self.num_layers): 44 | x = self.layers[i](x) 45 | if i != self.num_layers - 1: 46 | # 最后一层不要激活函数 47 | x = self.act(x) 48 | return x 49 | 50 | 51 | class DSBModel(nn.Module): 52 | """DSB模型用到的基础模型,包含前向和后向,需要把时间作为条件concat 53 | """ 54 | 55 | def __init__(self, 56 | input_dim, 57 | hidden_dim, 58 | output_dim=None, 59 | num_layers=3, 60 | activation=nn.ReLU()): 61 | super(DSBModel, self).__init__() 62 | if output_dim is None: 63 | output_dim = input_dim 64 | self.model = MLP(input_dim, hidden_dim, output_dim, num_layers, 65 | activation) 66 | 67 | def forward(self, x, t): 68 | return self.model(torch.cat([x, t], dim=-1)) 69 | 70 | 71 | class DSB(nn.Module): 72 | """DSB过程,类似于Rectified Flow类,但是更加复杂,同样采用老司机理论 73 | 路线:依据不同的轮次,由一个前向或后向模型逐步逼近,在逼近过程当中,我们需要得到一些量供司机训练使用 74 | 车:就是SDE 75 | 司机:仍然是MSE loss 76 | 77 | 综上,我们需要一个构建数据的东西 78 | """ 79 | 80 | def __init__(self, 81 | forward_model: nn.Module, 82 | backward_model: nn.Module, 83 | num_steps: Optional[int] = 20, 84 | gamma_min: Optional[float] = 0.02, 85 | gamma_max: Optional[float] = 1.0, 86 | device: Optional[torch.device] = torch.device('cpu')): 87 | super(DSB, self).__init__() 88 | self.forward_model = forward_model 89 | self.backward_model = backward_model 90 | 91 | self.forward_model.to(device) 92 | self.backward_model.to(device) 93 | 94 | self.model_dict = {'f': self.forward_model, 'b': self.backward_model} 95 | 96 | gamma_half = np.linspace(gamma_min, gamma_max, num_steps // 2) 97 | self.gamma = np.concatenate([gamma_half, gamma_half[::-1]]) 98 | 99 | self.num_steps = num_steps 100 | self.device = device 101 | 102 | @torch.no_grad() 103 | def generate_path_and_target(self, 104 | x_0, 105 | x_1, 106 | mode: str = 'b', 107 | first_it: bool = False, 108 | ema_model=None): 109 | """使用前向或逆向模型生成路径和target,每个训练大轮开始都要用上一个时刻的模型去推理出路径和target(本轮要训练前向模型,就用上一轮训好的逆向模型生成) 110 | 111 | Args: 112 | x_0: 起始点样本,shape: (batch_size, input_dim) 113 | x_1: 终止点,shape: (batch_size, input_dim) 114 | mode: 模式,f表示前向,b表示逆向 115 | first_it: 是否是第一个大轮次,第一个大轮次首先训练的b模型是没有前置f模型的,需要认为提供(布朗桥) 116 | ema_model: 是否存在EMA模型,有EMA模型用当前EMA权重替代现有权重 117 | 118 | Returns: 119 | """ 120 | # 推理过程,所有模型变为eval模式 121 | self.model_dict[mode].eval() 122 | 123 | if ema_model is not None: 124 | ema_model.apply_shadow() 125 | self.model_dict[mode].load_state_dict(ema_model.model.state_dict()) 126 | 127 | x_0 = x_0.to(self.device) 128 | x_1 = x_1.to(self.device) 129 | 130 | if mode == 'f': 131 | # 上一轮 132 | prev_mode = 'b' 133 | x_start = x_1 134 | elif mode == 'b': 135 | prev_mode = 'f' 136 | x_start = x_0 137 | else: 138 | raise ValueError('mode must be f or b') 139 | 140 | # 生成一系列时间 141 | # 1. 用SDE生成路径需要时间差分值 142 | dt = 1.0 / self.num_steps 143 | # 2. 上一轮模型所用的归一化时间,如果是forward模型,时间就是从0开始到1时刻前的最后一个值,如果是backward模型,时间就是从1开始到0时刻前的最后一个值 144 | if prev_mode == 'f': 145 | t_prev = np.arange(self.num_steps) / self.num_steps 146 | elif prev_mode == 'b': 147 | t_prev = 1 - np.arange(self.num_steps) / self.num_steps 148 | # 3.本轮模型所对应的时间,比如前向模型预测的x_2,在后向模型应该对应的时刻是t=3,反过来,后向模型预测的x_2,在前向模型应该对应的时刻是t=1,所以时间要做一个修正 149 | if prev_mode == 'f': 150 | t_cur = np.arange(1, self.num_steps + 1) / self.num_steps 151 | elif prev_mode == 'b': 152 | t_cur = 1 - np.arange(1, self.num_steps + 1) / self.num_steps 153 | 154 | # 给定几个列表,用来存储数据 155 | path = [] 156 | target = [] 157 | t_list = [] # 时间列表,模型输入需要时间! 158 | 159 | # 准备进入迭代,给定初始值,前向模型的初始值是x_0,后向模型的初始值是x_1 160 | # 一定不要乱,这里的0和1指的就是正常的时间 161 | x = x_start 162 | 163 | # 生成路径,路径要考虑是否为第一轮 164 | if first_it: 165 | # 难点来了,第一轮的时候,我们是不知道前一个时刻的F是什么的,只能靠自己定义,比如布朗桥 166 | assert mode == 'b' # 必须是backward先开始 167 | for k in range(self.num_steps): 168 | t = t_prev[k] # shape: (1) 169 | t = torch.ones((x.shape[0], 1), device=self.device) * t 170 | # 按照SDE的形式生成路径,前向和逆向现在都统一了 171 | # 1. 维纳过程,就认为F是一个恒等变换,也即F(x) = x 172 | dw = torch.sqrt(torch.tensor(dt)) * torch.randn_like(x) 173 | dw = dw.to(self.device) 174 | 175 | # 线性插值 176 | vec = (x_1 - x) / (1 - t) 177 | pred_x = x + vec * dt 178 | 179 | # pred = x # F(x)预测的就是x 180 | pred_x = pred_x.to(self.device) 181 | # 维纳过程计算x_{k+1} 182 | x = pred_x + torch.sqrt(torch.tensor(2 * self.gamma[k])) * dw 183 | # x = x + pred + torch.sqrt(torch.tensor(2 * self.gamma)) * dw 184 | # 2. 当前的x已经是下一个时刻的值,可以来算损失函数里面F(X[k+1])了 185 | # 同样计算pred_x_next,只要vec是上一个时刻的就好了 186 | pred_x_next = x + vec * dt 187 | 188 | # pred_x_next = x 189 | # 3. 计算这一轮的B的回归目标是什么 190 | target_cur = x + pred_x - pred_x_next # 原论文目标 191 | # target_cur = pred - pred_next # 作者代码实现 192 | # target_cur = -torch.sqrt(torch.tensor( 193 | # 2 * self.gamma)) * dw # 其他实现 194 | # 现在都有都有了,我们来整理整理 195 | path.append(x.detach().clone()) 196 | target.append(target_cur) 197 | t_list.append(torch.ones((x.shape[0], 1)) * t_cur[k]) 198 | 199 | else: 200 | # 每一个时刻,都计算一次,一个batch的每一个样本获得一个路径的位置、时间和目标 201 | for k in range(self.num_steps): 202 | # 这个t是上一轮模型对应的时间 203 | t = t_prev[k] # shape: (1) 204 | # 为了计算,要把t的维度变为和x的batch_size一样,也即是(batch_size, 1) 205 | t = torch.ones((x.shape[0], 1), device=self.device) * t 206 | # 按照SDE的形式生成路径,前向和逆向现在都统一了 207 | # 1. 用上一轮的B或者F生成上一个时刻或下一个时刻的x,也即x[k]或者x[k+1]},同时要保留B(x[k+1])或者F(x[k])的值(bf_x),这两个值也是损失函数的一部分 208 | # 维纳过程dz 209 | dw = torch.sqrt(torch.tensor(dt)) * torch.randn_like(x) 210 | 211 | x = x.to(self.device) 212 | dw = dw.to(self.device) 213 | 214 | pred_x = self.model_dict[prev_mode](x, t) 215 | # x -> x_next 也即 x[k] -> x[k+1] 或 x[k+1] -> x[k] 216 | x = pred_x + torch.sqrt(torch.tensor(2 * self.gamma[k])) * dw 217 | # x = x + pred + torch.sqrt(torch.tensor(2 * self.gamma)) * dw 218 | 219 | # 2. 当前的x已经是上一个或下一个时刻的值,可以来算损失函数里面B(X[k])和F(x[k+1])这两项了 220 | pred_next_x = self.model_dict[prev_mode](x, t) 221 | # 3. 计算这一轮的F或者B的回归目标是什么 222 | target_cur = x + pred_x - pred_next_x # 原论文目标 223 | # target_cur = pred - pred_next # 作者代码实现,在训练的时候的pred再减去x 224 | # target_cur = -pred_next - torch.sqrt( 225 | # torch.tensor(2 * self.gamma)) * dw # 其他实现 226 | # 现在都有都有了,我们来整理整理 227 | path.append(x.detach().clone()) 228 | target.append(target_cur) 229 | t_list.append(torch.ones((x.shape[0], 1)) * t_cur[k]) 230 | 231 | # 最后就是把我们得到的所有信息stack一下,构建新的数据 232 | path = torch.stack(path).to( 233 | self.device) # shape: (num_steps, batch_size, input_dim) 234 | target = torch.stack(target).to( 235 | self.device) # shape: (num_steps, batch_size, input_dim) 236 | t_list = torch.stack(t_list).to( 237 | self.device) # shape: (num_steps, batch_size, 1) 238 | 239 | # 每个batch的每一个样本都采样一个时刻 240 | t_sample = torch.randint( 241 | 0, self.num_steps, (1, x_0.shape[0], 1), 242 | device=self.device) # shape: (1, batch_size, 1) 243 | # 从路径中取出对应时刻的路径和目标,用gather函数 244 | t_list = torch.gather(t_list, 0, t_sample).squeeze(0) 245 | x_t = torch.gather(path, 0, 246 | t_sample.expand(1, x_0.shape[0], 247 | x_0.shape[1])).squeeze(0) 248 | 249 | target = torch.gather(target, 0, 250 | t_sample.expand(1, x_0.shape[0], 251 | x_0.shape[1])).squeeze(0) 252 | 253 | # 恢复训练模式 254 | self.model_dict[mode].train() 255 | # 恢复模型权重 256 | if ema_model is not None: 257 | ema_model.restore() 258 | self.model_dict[mode].load_state_dict(ema_model.model.state_dict()) 259 | 260 | return x_t, target, t_list 261 | 262 | @torch.no_grad() 263 | def sde_sample(self, 264 | x_start, 265 | mode: str = 'b', 266 | num_steps: Optional[int] = None, 267 | return_path: Optional[bool] = True): 268 | # 欧拉丸山法根据模型预测的值,生成结果,用于推理 269 | # 这个函数可以选择输出每一个时刻的值,便于画图 270 | if num_steps is None: 271 | num_steps = self.num_steps 272 | 273 | dt = 1.0 / num_steps 274 | path = [] 275 | x = x_start 276 | x = x.to(self.device) 277 | path.append(x.detach().clone()) 278 | 279 | # 输入模型的时间序列 280 | t = np.arange(num_steps) / num_steps 281 | if mode == 'b': 282 | t = 1 - t 283 | 284 | for k in range(num_steps): 285 | t_cur = torch.ones((x.shape[0], 1), device=self.device) * t[k] 286 | pred = self.model_dict[mode](x, t_cur) 287 | dw = torch.sqrt(torch.tensor(dt)) * torch.randn_like(x) 288 | dw = dw.to(self.device) 289 | if k == num_steps - 1: 290 | # 采样过程最后一步就不加噪声了 291 | x = pred 292 | else: 293 | x = pred + torch.sqrt(torch.tensor(2 * self.gamma[k])) * dw 294 | # x = x + pred + torch.sqrt(torch.tensor(2 * self.gamma)) * dw 295 | 296 | path.append(x.detach().clone()) 297 | 298 | if return_path: 299 | return path 300 | else: # 仅返回最终结果 301 | return x 302 | 303 | # 最基本的MSE loss 304 | def mse_loss(self, pred, target): 305 | return F.mse_loss(pred, target) 306 | 307 | 308 | if __name__ == '__main__': 309 | input_dim = 11 310 | hidden_dim = 10 311 | output_dim = 10 312 | num_layers = 3 313 | 314 | forward_model = DSBModel(input_dim, hidden_dim, output_dim, num_layers) 315 | backward_model = DSBModel(input_dim, hidden_dim, output_dim, num_layers) 316 | 317 | dsb = DSB(forward_model, backward_model) 318 | x_0 = torch.randn(10, 10) 319 | x_1 = torch.randn(10, 10) 320 | path, target, t_list = dsb.generate_path_and_target(x_0, 321 | x_1, 322 | mode='b', 323 | first_it=False) 324 | 325 | print(path.shape, target.shape, t_list.shape) 326 | 327 | path = dsb.sde_sample(x_1, mode='b') 328 | 329 | print(path) 330 | --------------------------------------------------------------------------------