├── .gitignore ├── README.md ├── config ├── train_config.yaml └── train_reflow_config.yaml ├── datasets ├── __init__.py └── reflow_dataset.py ├── draw_result_fig.py ├── fig ├── loss_curve.png ├── loss_curve_cfg.png ├── loss_curve_cfg_reflow.png ├── results_fig.png ├── results_fig_cfg.png ├── results_fig_cfg_2steps.png ├── results_fig_cfg_3steps.png ├── results_fig_cfg_4steps.png ├── results_fig_cfg_reflow_2steps.png ├── results_fig_cfg_reflow_3steps.png └── results_fig_cfg_reflow_4steps.png ├── infer.py ├── model.py ├── plot_loss_curve.py ├── rectified_flow.py ├── reflow_sample_generate.py ├── train.py └── train_reflow.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.pt 3 | *.pth 4 | *.safetensors 5 | *.bin 6 | data/ 7 | **/__pycache__/ 8 | results/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 从零手搓Flow Matching(Rectified Flow) 2 | 3 | 作者:Tong Tong 4 | 5 | B站主页:[Double童发发](https://space.bilibili.com/323109608) 6 | 7 | 本套代码有相关讲解视频,详见B站:[从零手搓Flow Matching(Rectified Flow)](https://www.bilibili.com/video/BV1Sjv4ezEDN/),同时强烈建议先看一下本人B站关于[Flow Matching](https://www.bilibili.com/video/BV1Wv3xeNEds/)和[Recitified Flow](https://www.bilibili.com/video/BV19m421G7W8/)的算法讲解视频,会对理解代码有很大帮助。 8 | 9 | **特别推荐看一下本人的[扩散模型之老司机开车理论视频](https://www.bilibili.com/video/BV1qW42197dv/),对你理解扩散模型有很大帮助~** 10 | 11 | **TODO**: 12 | - [ ] v1.4版本计划增加文本条件输入(计划仅做简单实验,语言模型较大无法满足让大家都能上手的目标) 13 | - [ ] v1.3版本计划增加distillation 14 | - [x] 开放reflow(2-Rectified Flow)模型权重和数据 15 | - [x] v1.2版本增加reflow 16 | - [x] 开放v1.1版本相关模型权重文件(百度网盘形式) 17 | - [x] v1.1版本计划增加MNIST条件生成 18 | - [x] v1.0开放模型预训练权重(百度网盘形式) 19 | 20 | 21 | **一些bug修复说明**: 22 | - 感谢B站粉丝大佬@EchozL提醒,MiniUnet编的草率了,现已更新,最高分辨率的特征也concat啦~ 23 | 24 | **温馨提示(跪求支持):** 25 | 项目更新速度受大家支持程度的影响,最新[reflow视频](https://www.bilibili.com/video/BV14XDkYNEVN/)**点赞+投币**数目大于500,我立即爆肝更新下一期视频。此外,周一到周五晚上直播有概率手搓下一期视频代码内容,大家可以期待一下~ 26 | * 目前视频点赞+投币进度(截止2024年12月7日): 482/500 27 | 28 | ## 项目说明 29 | * 本项目代码基于MNIST数据集实现算法的训练与推理,可实现有条件或无条件生成0-9手写字体,目前有条件生成仅支持使用类别label,也即0-9整型数字,使用文本作为条件计划下个版本支持。 30 | * 本项目完全**从零手搓**,尽可能不参考其他任何代码,从论文原理出发逐步实现,因此算是**极简实现**的一种,并**不能保证最优性能**,各位大佬可以逐步修改完善。 31 | * 为了让大家都能上手,本代码只基于深度学习框架Pytorch和一些其他必要的库。数据集选择MNIST作为案例,该数据集Pytorch本身自带,数据集规模较小,也方便展示效果,最重要的是**即使是使用CPU都能训练**!!! 
32 | * 模型结构自己手搓了一个MiniUnet,大家可以根据自己的需求修改,也可以使用其他更复杂的模型,比如Unet、DiT等。 33 | * 代码中有很多注释,希望能帮助大家理解代码,如果有问题欢迎留言交流。 34 | * 代码环境要求很低,甚至不需要GPU都可以 35 | * Python 3.8+ 36 | * Pytorch 2.0+ 37 | * Numpy 38 | * Matplotlib 39 | * 其他的就缺啥装啥 40 | * 代码运行方式 41 | * 如果需要训练代码请务必先查看config文件夹里的配置文件,并根据实际情况修改相关参数,尤其是是否使用classifier-free guidance,是否使用GPU等,设置好了再开始训练 42 | * 训练:`python train.py`,训练参数配置文件为`config/train_config.yaml` 43 | * reflow训练:`python train_reflow.py`,训练参数配置文件为`config/train_reflow_config.yaml` 44 | * 推理:`python infer.py` 45 | * 画loss曲线:`python plot_loss_curve.py` 46 | * 结果图像展示(100张生成图像拼图生成):`python draw_result_fig.py` 47 | * 各版本权重代码和数据[点击下载](https://pan.baidu.com/s/1ZV1z9OSSXRYX5E5Ws8xvow?pwd=9hmi),提取码9hmi,把checkpoints和data文件夹放到根目录下即可,**注意!代码或模型版本更新导致文件同步更新!请下载最新文件,更新日期2024年11月10日** 48 | 49 | ## 版本说明 50 | ### V1.2: Reflow 51 | * V1.2版本在V1.1版本的基础上进一步支持reflow训练 52 | * Reflow模型需要构建新的数据集,根据实验结果**所需数据量极大,算力成本较高,带来的提升确不够明显,对于MNIST这种简单数据集实用性不强**。6万张MNIST数据集需要**100万个**通过原生rectified flow模型(也即1-Rectified Flow模型)的样本对$`(Z_{0}^{1}, Z_{1}^{1})`$训练20个epoch,才有能看出来的效果 53 | * Reflow过程模型初始权重为1-Rectified Flow模型的权重 54 | * 模型收敛较好 55 | 56 | ![loss curve](/fig/loss_curve_cfg_reflow.png) 57 | * 生成效果展示,每一行为一个类别的生成结果,从0-9,上图为2-Rectified Flow模型**2步**生成效果,下图为1-Rectified Flow模型的**2步**生成效果 58 | 59 | ![results](/fig/results_fig_cfg_reflow_2steps.png) 60 | ![results](/fig/results_fig_cfg_2steps.png) 61 | 62 | 63 | 64 | ### V1.1: Flow Matching(Rectified Flow)条件生成 65 | * V1.1版本同时支持无条件生成和条件生成 66 | * 模型收敛较好 67 | 68 | ![loss curve](/fig/loss_curve_cfg.png) 69 | * 生成效果展示,每一行为一个类别的生成结果,从0-9 70 | 71 | ![results](/fig/results_fig_cfg.png) 72 | 73 | ### V1.0:Flow Matching(Rectified Flow)无条件生成 74 | * V1.0版本仅支持无条件生成 75 | * 模型收敛较好 76 | 77 | ![loss curve](/fig/loss_curve.png) 78 | * 生成效果展示 79 | 80 | ![results](/fig/results_fig.png) 81 | 82 | --- 83 | * 代码实现原理参考论文 84 | * Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow 85 | * Flow Matching for Generative Modeling 86 | * Classifier-Free Diffusion Guidance 87 | -------------------------------------------------------------------------------- /config/train_config.yaml: -------------------------------------------------------------------------------- 1 | base_channels: 64 # base_channels大一些有好处 2 | epochs: 50 # 训练多少个epoch 3 | batch_size: 16 # batch_size小一些 4 | lr_adjust_epoch: 25 # 学习率调整的epoch,降为原有的10% 5 | batch_print_interval: 100 # 打印间隔,以batch为单位 6 | checkpoint_save_interval: 10 # 模型保存间隔,以epoch为单位 7 | save_path: './checkpoints/v1.1-cfg' # 模型保存路径 8 | use_cfg: True # 是否使用classifier-free guidance,开启就可以训练条件生成模型了 9 | device: 'cuda' # cuda、cpu、mps(only macbook) 10 | -------------------------------------------------------------------------------- /config/train_reflow_config.yaml: -------------------------------------------------------------------------------- 1 | base_channels: 64 # base_channels大一些有好处 2 | epochs: 20 # 训练多少个epoch 3 | batch_size: 16 # batch_size小一些 4 | lr: 0.00001 # 学习率 5 | lr_adjust_epoch: 10 # 学习率调整的epoch,降为原有的10% 6 | batch_print_interval: 100 # 打印间隔,以batch为单位 7 | checkpoint_save_interval: 10 # 模型保存间隔,以epoch为单位 8 | save_path: './checkpoints/v1.2-reflow-cfg' # 模型保存路径 9 | use_cfg: True # 是否使用classifier-free guidance,开启就可以训练条件生成模型了 10 | device: 'cuda' # cuda、cpu、mps(only macbook) 11 | 12 | # reflow新增参数 13 | img_root_path: './data/reflow_img' # 图像的地址 14 | noise_root_path: './data/reflow_noise' # 噪声的地址 15 | checkpoint_path: './checkpoints/v1.1-cfg/miniunet_49.pth' # 模型的地址 finetune -------------------------------------------------------------------------------- 
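For quick reference, both training scripts read these YAML files with PyYAML and fall back to a default for every missing key. The sketch below is a minimal illustration of that pattern, not a replacement for the scripts; it uses `yaml.safe_load` for simplicity, whereas `train.py` and `train_reflow.py` call `yaml.load` with `FullLoader`, and the key names are exactly the ones in the configs above.

```python
import yaml

# Read the training config; every key is optional and falls back to a default,
# mirroring the config.get(...) calls in train.py.
with open('./config/train_config.yaml', 'r') as f:
    config = yaml.safe_load(f)

base_channels = config.get('base_channels', 16)   # MiniUnet width
epochs = config.get('epochs', 10)
batch_size = config.get('batch_size', 128)
use_cfg = config.get('use_cfg', False)            # classifier-free guidance on/off
device = config.get('device', 'cuda')             # 'cuda', 'cpu' or 'mps'

print(base_channels, epochs, batch_size, use_cfg, device)
```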
/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .reflow_dataset import ReflowDataset 2 | -------------------------------------------------------------------------------- /datasets/reflow_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | from torch.utils.data import Dataset 4 | from typing import List, Union, Tuple, Optional, Any 5 | from PIL import Image 6 | import numpy as np 7 | from torchvision.transforms import ToTensor 8 | import os 9 | 10 | 11 | class ReflowDataset(Dataset): 12 | """ReflowDataset 13 | 用于训练Reflow模型的数据集 14 | 15 | Args: 16 | img_root_path (str): 图像的根路径 17 | noise_root_path (str): 噪声的根路径 18 | transform (optional): 图像transform. Defaults to None. 19 | """ 20 | 21 | def __init__(self, 22 | img_root_path: str, 23 | noise_root_path: str, 24 | transform: Optional[Any] = None): 25 | # 通过根路径获得所有图片的路径 26 | self.img_path = [] 27 | self.noise_path = [] 28 | self.labels = [] 29 | 30 | for label in os.listdir(img_root_path): 31 | img_path = os.path.join(img_root_path, label) 32 | noise_path = os.path.join(noise_root_path, label) 33 | for img_name in os.listdir(img_path): 34 | self.labels.append(int(label)) 35 | self.img_path.append(os.path.join(img_path, img_name)) 36 | self.noise_path.append( 37 | os.path.join(noise_path, img_name.replace('.png', '.npy'))) 38 | 39 | self.transform = transform 40 | 41 | def __len__(self): 42 | return len(self.img_path) 43 | 44 | def __getitem__(self, idx): 45 | img_path = self.img_path[idx] 46 | noise_path = self.noise_path[idx] 47 | label = self.labels[idx] 48 | 49 | # 读取png图片 50 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 51 | # 读取npy文件 52 | noise = np.load(noise_path) 53 | 54 | if self.transform: 55 | img = self.transform(img) 56 | 57 | # noise已经自动变为tensor 58 | noise = torch.tensor(noise) 59 | # 删除一维 60 | noise = noise.squeeze(0) 61 | 62 | return {'img': img, 'noise': noise, 'label': label} 63 | 64 | 65 | if __name__ == '__main__': 66 | transform = ToTensor() 67 | dataset = ReflowDataset('./data/reflow_img', './data/reflow_noise', 68 | transform) 69 | img, noise, label = dataset[100] 70 | print(len(dataset)) 71 | print(img.shape, noise.shape, label) 72 | print(img.max(), img.min()) 73 | -------------------------------------------------------------------------------- /draw_result_fig.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import cv2 3 | import os 4 | 5 | # 读取results文件夹的100张图片 6 | # img_folder = './results/reflow-cfg' # v1.2 7 | img_folder = 'results/cfg' # v1.1 8 | img_files = [ 9 | os.path.join(img_folder, f) for f in os.listdir(img_folder) 10 | if f.endswith('.png') 11 | ][:100] 12 | # 按照自然数顺序排列 13 | img_files.sort(key=lambda x: int(os.path.basename(x).split('.')[0])) 14 | 15 | fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(10, 10)) 16 | 17 | for ax, img_file in zip(axes.flatten(), img_files): 18 | img = cv2.imread(img_file) 19 | img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 20 | ax.imshow(img_rgb) 21 | ax.axis('off') 22 | 23 | plt.subplots_adjust(wspace=0.1, hspace=0.1) 24 | plt.tight_layout() 25 | plt.show() 26 | -------------------------------------------------------------------------------- /fig/loss_curve.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TongTong313/rectified-flow/7de2f274911cda0d1298fa1240d4c50903e78ec5/fig/loss_curve.png -------------------------------------------------------------------------------- /fig/loss_curve_cfg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TongTong313/rectified-flow/7de2f274911cda0d1298fa1240d4c50903e78ec5/fig/loss_curve_cfg.png -------------------------------------------------------------------------------- /fig/loss_curve_cfg_reflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TongTong313/rectified-flow/7de2f274911cda0d1298fa1240d4c50903e78ec5/fig/loss_curve_cfg_reflow.png -------------------------------------------------------------------------------- /fig/results_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TongTong313/rectified-flow/7de2f274911cda0d1298fa1240d4c50903e78ec5/fig/results_fig.png -------------------------------------------------------------------------------- /fig/results_fig_cfg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TongTong313/rectified-flow/7de2f274911cda0d1298fa1240d4c50903e78ec5/fig/results_fig_cfg.png -------------------------------------------------------------------------------- /fig/results_fig_cfg_2steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TongTong313/rectified-flow/7de2f274911cda0d1298fa1240d4c50903e78ec5/fig/results_fig_cfg_2steps.png -------------------------------------------------------------------------------- /fig/results_fig_cfg_3steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TongTong313/rectified-flow/7de2f274911cda0d1298fa1240d4c50903e78ec5/fig/results_fig_cfg_3steps.png -------------------------------------------------------------------------------- /fig/results_fig_cfg_4steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TongTong313/rectified-flow/7de2f274911cda0d1298fa1240d4c50903e78ec5/fig/results_fig_cfg_4steps.png -------------------------------------------------------------------------------- /fig/results_fig_cfg_reflow_2steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TongTong313/rectified-flow/7de2f274911cda0d1298fa1240d4c50903e78ec5/fig/results_fig_cfg_reflow_2steps.png -------------------------------------------------------------------------------- /fig/results_fig_cfg_reflow_3steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TongTong313/rectified-flow/7de2f274911cda0d1298fa1240d4c50903e78ec5/fig/results_fig_cfg_reflow_3steps.png -------------------------------------------------------------------------------- /fig/results_fig_cfg_reflow_4steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TongTong313/rectified-flow/7de2f274911cda0d1298fa1240d4c50903e78ec5/fig/results_fig_cfg_reflow_4steps.png -------------------------------------------------------------------------------- /infer.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from model import MiniUnet 3 | from rectified_flow import RectifiedFlow 4 | import cv2 5 | import os 6 | import numpy as np 7 | 8 | 9 | def infer( 10 | checkpoint_path, 11 | base_channels=16, 12 | step=50, # 采样步数(Euler方法的迭代次数) 10步效果就很好 1步效果不好 13 | num_imgs=5, 14 | y=None, 15 | cfg_scale=7.0, 16 | save_path='./results', 17 | save_noise_path=None, 18 | device='cuda'): 19 | """flow matching模型推理 20 | 21 | Args: 22 | checkpoint_path (str): 模型路径 23 | base_channels (int, optional): MiniUnet的基础通道数,默认值为16。 24 | step (int, optional): 采样步数(Euler方法的迭代次数),默认值为50。 25 | num_imgs (int, optional): 推理一次生成图片数量,默认值为5。 26 | y (torch.Tensor, optional): 条件生成中的条件,可以为数据标签(每一个标签是一个类别int型)或text文本(下一版本支持),维度为[B]或[B, L],其中B要么与num_imgs相等,要么为1(所有图像依照同一个条件生成)。 27 | cfg_scale (float, optional): Classifier-free Guidance的缩放因子,默认值为7.0,y如果是None,无论这个值是几都是无条件生成。这个值越大,多样性下降,但生成图像更符合条件要求。这个值越小,多样性增加,但生成图像可能不符合条件要求。 28 | save_path (str, optional): 保存路径,默认值为'./results'。 29 | save_noise_path (str, optional): 保存噪声路径,默认值为None。 30 | device (str, optional): 推理设备,默认值为'cuda'。 31 | """ 32 | os.makedirs(save_path, exist_ok=True) 33 | if save_noise_path is not None: 34 | os.makedirs(save_noise_path, exist_ok=True) 35 | 36 | if y is not None: 37 | assert len(y.shape) == 1 or len( 38 | y.shape) == 2, 'y must be 1D or 2D tensor' 39 | assert y.shape[0] == num_imgs or y.shape[ 40 | 0] == 1, 'y.shape[0] must be equal to num_imgs or 1' 41 | if y.shape[0] == 1: 42 | y = y.repeat(num_imgs, 1).reshape(num_imgs) 43 | y = y.to(device) 44 | # 生成一些图片 45 | # 加载模型 46 | model = MiniUnet(base_channels=base_channels) 47 | model.to(device) 48 | model.eval() 49 | 50 | # 加载RectifiedFlow 51 | rf = RectifiedFlow() 52 | 53 | checkpoint = torch.load(checkpoint_path) 54 | model.load_state_dict(checkpoint['model']) 55 | 56 | # with torch.no_grad(): # 无需梯度,加速,降显存 57 | with torch.no_grad(): 58 | # 无条件或有条件生成图片 59 | for i in range(num_imgs): 60 | print(f'Generating {i}th image...') 61 | # Euler法间隔 62 | dt = 1.0 / step 63 | 64 | # 初始的x_t就是x_0,标准高斯噪声 65 | x_t = torch.randn(1, 1, 28, 28).to(device) 66 | noise = x_t.detach().cpu().numpy() 67 | 68 | # 提取第i个图像的标签条件y_i 69 | if y is not None: 70 | y_i = y[i].unsqueeze(0) 71 | 72 | for j in range(step): 73 | if j % 10 == 0: 74 | print(f'Generating {i}th image, step {j}...') 75 | t = j * dt 76 | t = torch.tensor([t]).to(device) 77 | 78 | if y is not None: 79 | # classifier-free guidance需要同时预测有条件和无条件的输出 80 | # 利用CFG的公式:x = x_uncond + cfg_scale * (x_cond - x_uncond) 81 | # 为什么用score推导的公式放到预测向量场v的情形可以直接用? 
SDE ODE 82 | v_pred_uncond = model(x=x_t, t=t) 83 | v_pred_cond = model(x=x_t, t=t, y=y_i) 84 | v_pred = v_pred_uncond + cfg_scale * (v_pred_cond - 85 | v_pred_uncond) 86 | else: 87 | v_pred = model(x=x_t, t=t) 88 | 89 | # 使用Euler法计算下一个时间的x_t 90 | x_t = rf.euler(x_t, v_pred, dt) 91 | 92 | # 最后一步的x_t就是生成的图片 93 | # 先去掉batch维度 94 | x_t = x_t[0] 95 | # 归一化到0到1 96 | # x_t = (x_t / 2 + 0.5).clamp(0, 1) 97 | x_t = x_t.clamp(0, 1) 98 | img = x_t.detach().cpu().numpy() 99 | img = img[0] * 255 100 | img = img.astype('uint8') 101 | cv2.imwrite(os.path.join(save_path, f'{i}.png'), img) 102 | if save_noise_path is not None: 103 | # 保存为一个.npy格式的文件 104 | np.save(os.path.join(save_noise_path, f'{i}.npy'), noise) 105 | 106 | 107 | if __name__ == '__main__': 108 | # 每个条件生成10张图像 109 | # label一个数字出现十次 110 | y = [] 111 | for i in range(10): 112 | y.extend([i] * 10) 113 | # v1.1 1-RF 114 | infer(checkpoint_path='./checkpoints/v1.1-cfg/miniunet_49.pth', 115 | base_channels=64, 116 | step=2, 117 | num_imgs=100, 118 | y=torch.tensor(y), 119 | cfg_scale=5.0, 120 | save_path='./results/cfg', 121 | device='cuda') 122 | 123 | # v1.2 2-RF 124 | infer(checkpoint_path='./checkpoints/v1.2-reflow-cfg/miniunet_19.pth', 125 | base_channels=64, 126 | step=2, 127 | num_imgs=100, 128 | y=torch.tensor(y), 129 | cfg_scale=5.0, 130 | save_path='./results/reflow-cfg', 131 | device='cuda') 132 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # MiniUnet MNIST 28*28 4090 3G左右显存 6 | class DownLayer(nn.Module): 7 | """MiniUnet的下采样层 Resnet 8 | """ 9 | 10 | def __init__(self, 11 | in_channels, 12 | out_channels, 13 | time_emb_dim=16, 14 | downsample=False): 15 | super(DownLayer, self).__init__() 16 | 17 | self.conv1 = nn.Conv2d(in_channels, 18 | out_channels, 19 | kernel_size=3, 20 | padding=1) 21 | self.conv2 = nn.Conv2d(out_channels, 22 | out_channels, 23 | kernel_size=3, 24 | padding=1) 25 | self.bn1 = nn.BatchNorm2d(out_channels) 26 | self.bn2 = nn.BatchNorm2d(out_channels) 27 | 28 | self.act = nn.ReLU() 29 | 30 | # 线性层,用于时间编码换通道 [B, dim] -> [B, in_channels] 31 | self.fc = nn.Linear(time_emb_dim, in_channels) 32 | 33 | if in_channels != out_channels: 34 | self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1) 35 | else: 36 | self.shortcut = None 37 | 38 | # 降采样 39 | self.downsample = downsample 40 | if downsample: 41 | self.pool = nn.MaxPool2d(2) 42 | 43 | self.in_channels = in_channels 44 | 45 | def forward(self, x, temb): 46 | # x: [B, C, H, W] 47 | res = x 48 | x += self.fc(temb)[:, :, None, None] # [B, in_channels, 1, 1] 49 | x = self.conv1(x) 50 | x = self.bn1(x) 51 | x = self.act(x) 52 | x = self.conv2(x) 53 | x = self.bn2(x) 54 | x = self.act(x) 55 | 56 | if self.shortcut is not None: 57 | res = self.shortcut(res) 58 | 59 | x = x + res 60 | 61 | if self.downsample: 62 | x = self.pool(x) 63 | 64 | return x 65 | 66 | 67 | class UpLayer(nn.Module): 68 | """MiniUnet的上采样层 69 | """ 70 | 71 | def __init__(self, 72 | in_channels, 73 | out_channels, 74 | time_emb_dim=16, 75 | upsample=False): 76 | super(UpLayer, self).__init__() 77 | 78 | self.conv1 = nn.Conv2d(in_channels, 79 | out_channels, 80 | kernel_size=3, 81 | padding=1) 82 | self.conv2 = nn.Conv2d(out_channels, 83 | out_channels, 84 | kernel_size=3, 85 | padding=1) 86 | self.bn1 = nn.BatchNorm2d(out_channels) 87 | self.bn2 = nn.BatchNorm2d(out_channels) 88 | 89 | self.act = nn.ReLU() 90 | 
91 | # 线性层,用于时间编码换通道 92 | self.fc = nn.Linear(time_emb_dim, in_channels) 93 | 94 | if in_channels != out_channels: 95 | self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1) 96 | else: 97 | self.shortcut = None 98 | 99 | self.upsample = upsample 100 | if upsample: 101 | self.upsample = nn.Upsample(scale_factor=2) 102 | 103 | def forward(self, x, temb): 104 | # 上采样 105 | if self.upsample: 106 | x = self.upsample(x) 107 | res = x 108 | 109 | x += self.fc(temb)[:, :, None, None] 110 | x = self.conv1(x) 111 | x = self.bn1(x) 112 | x = self.act(x) 113 | x = self.conv2(x) 114 | x = self.bn2(x) 115 | x = self.act(x) 116 | 117 | if self.shortcut is not None: 118 | res = self.shortcut(res) 119 | x = x + res 120 | 121 | return x 122 | 123 | 124 | class MiddleLayer(nn.Module): 125 | """MiniUnet的中间层 126 | """ 127 | 128 | def __init__(self, in_channels, out_channels, time_emb_dim=16): 129 | super(MiddleLayer, self).__init__() 130 | 131 | self.conv1 = nn.Conv2d(in_channels, 132 | out_channels, 133 | kernel_size=3, 134 | padding=1) 135 | self.conv2 = nn.Conv2d(out_channels, 136 | out_channels, 137 | kernel_size=3, 138 | padding=1) 139 | self.bn1 = nn.BatchNorm2d(out_channels) 140 | self.bn2 = nn.BatchNorm2d(out_channels) 141 | 142 | self.act = nn.ReLU() 143 | 144 | # 线性层,用于时间编码换通道 145 | self.fc = nn.Linear(time_emb_dim, in_channels) 146 | 147 | if in_channels != out_channels: 148 | self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1) 149 | else: 150 | self.shortcut = None 151 | 152 | def forward(self, x, temb): 153 | res = x 154 | 155 | x += self.fc(temb)[:, :, None, None] 156 | x = self.conv1(x) 157 | x = self.bn1(x) 158 | x = self.act(x) 159 | x = self.conv2(x) 160 | x = self.bn2(x) 161 | x = self.act(x) 162 | 163 | if self.shortcut is not None: 164 | x = self.shortcut(x) 165 | x = x + res 166 | 167 | return x 168 | 169 | 170 | class MiniUnet(nn.Module): 171 | """采用MiniUnet,对MNIST数据做生成 172 | 两个下采样block 一个中间block 两个上采样block 173 | """ 174 | 175 | def __init__(self, base_channels=16, time_emb_dim=None): 176 | super(MiniUnet, self).__init__() 177 | 178 | if time_emb_dim is None: 179 | self.time_emb_dim = base_channels 180 | 181 | self.base_channels = base_channels 182 | 183 | self.conv_in = nn.Conv2d(1, base_channels, kernel_size=3, padding=1) 184 | 185 | # 多个Layer构成block 186 | self.down1 = nn.ModuleList([ 187 | DownLayer(base_channels, 188 | base_channels * 2, 189 | time_emb_dim=self.time_emb_dim, 190 | downsample=False), 191 | DownLayer(base_channels * 2, 192 | base_channels * 2, 193 | time_emb_dim=self.time_emb_dim) 194 | ]) 195 | self.maxpool1 = nn.MaxPool2d(2) 196 | 197 | self.down2 = nn.ModuleList([ 198 | DownLayer(base_channels * 2, 199 | base_channels * 4, 200 | time_emb_dim=self.time_emb_dim, 201 | downsample=False), 202 | DownLayer(base_channels * 4, 203 | base_channels * 4, 204 | time_emb_dim=self.time_emb_dim) 205 | ]) 206 | self.maxpool2 = nn.MaxPool2d(2) 207 | 208 | self.middle = MiddleLayer(base_channels * 4, 209 | base_channels * 4, 210 | time_emb_dim=self.time_emb_dim) 211 | 212 | self.upsample1 = nn.Upsample(scale_factor=2) 213 | self.up1 = nn.ModuleList([ 214 | UpLayer( 215 | base_channels * 8, # concat 216 | base_channels * 2, 217 | time_emb_dim=self.time_emb_dim, 218 | upsample=False), 219 | UpLayer(base_channels * 2, 220 | base_channels * 2, 221 | time_emb_dim=self.time_emb_dim) 222 | ]) 223 | self.upsample2 = nn.Upsample(scale_factor=2) 224 | self.up2 = nn.ModuleList([ 225 | UpLayer(base_channels * 4, 226 | base_channels, 227 | 
time_emb_dim=self.time_emb_dim, 228 | upsample=False), 229 | UpLayer(base_channels, 230 | base_channels, 231 | time_emb_dim=self.time_emb_dim) 232 | ]) 233 | 234 | self.conv_out = nn.Conv2d(base_channels, 1, kernel_size=1, padding=0) 235 | 236 | def time_emb(self, t, dim): 237 | """对时间进行正弦函数的编码,单一维度 238 | 目标:让模型感知到输入x_t的时刻t 239 | 实现方式:多种多样 240 | 输入x:[B, C, H, W] x += temb 与空间无关的,也即每个空间位置(H, W),都需要加上一个相同的时间编码向量[B, C] 241 | 假设B=1 t=0.1 242 | 1. 简单粗暴法 243 | temb = [0.1] * C -> [0.1, 0.1, 0.1, ……] 244 | x += temb.reshape(1, C, 1, 1) 245 | 2. 类似绝对位置编码方式 246 | 本代码实现方式 247 | 3. 通过学习的方式(保证T是离散的0, 1, 2, 3,……,T) 248 | temb_learn = nn.Parameter(T+1, dim) 249 | x += temb_learn[t, :].reshape(1, C, 1, 1) 250 | 251 | 252 | Args: 253 | t (float): 时间,维度为[B] 254 | dim (int): 编码的维度 255 | 256 | Returns: 257 | torch.Tensor: 编码后的时间,维度为[B, dim] 输入是[B, C, H, W] 258 | """ 259 | # 生成正弦编码 260 | # 把t映射到[0, 1000] 261 | t = t * 1000 262 | # 10000^k k=torch.linspace…… 263 | freqs = torch.pow(10000, torch.linspace(0, 1, dim // 2)).to(t.device) 264 | sin_emb = torch.sin(t[:, None] / freqs) 265 | cos_emb = torch.cos(t[:, None] / freqs) 266 | 267 | return torch.cat([sin_emb, cos_emb], dim=-1) 268 | 269 | def label_emb(self, y, dim): 270 | """对类别标签进行编码,同样采用正弦编码 271 | 272 | Args: 273 | y (torch.Tensor): 图像标签,维度为[B] label:0-9 274 | dim (int): 编码的维度 275 | 276 | Returns: 277 | torch.Tensor: 编码后的标签,维度为[B, dim] 278 | """ 279 | y = y * 1000 280 | 281 | freqs = torch.pow(10000, torch.linspace(0, 1, dim // 2)).to(y.device) 282 | sin_emb = torch.sin(y[:, None] / freqs) 283 | cos_emb = torch.cos(y[:, None] / freqs) 284 | 285 | return torch.cat([sin_emb, cos_emb], dim=-1) 286 | 287 | def forward(self, x, t, y=None): 288 | """前向传播函数 289 | 290 | Args: 291 | x (torch.Tensor): 输入数据,维度为[B, C, H, W] 292 | t (torch.Tensor): 时间,维度为[B] 293 | y (torch.Tensor, optional): 数据标签(每一个标签是一个类别int型)或text文本(下一版本支持),维度为[B]或[B, L]。 Defaults to None. 294 | """ 295 | # x:(B, C, H, W) 296 | # 时间编码加上 297 | x = self.conv_in(x) 298 | # 时间编码 299 | temb = self.time_emb(t, self.base_channels) 300 | # 这里注意,我们把temb和labelemb加起来,作为一个整体的temb输入到MiniUnet中,让模型进行感知!二者编码维度一样,可以直接相加!就把label的条件信息融入进去了! 
301 | if y is not None: 302 | # 判断y是label还是token 303 | if len(y.shape) == 1: 304 | # label编码,-1表示无条件生成,仅用于训练区分,推理的时候不需要 305 | # 把y中等于-1的部分找出来不进行任何编码,其余的进行编码 306 | yemb = self.label_emb(y, self.base_channels) 307 | # 把y等于-1的index找出来,然后把对应的y_emb设置为0 308 | yemb[y == -1] = 0.0 309 | temb += yemb 310 | else: # 文字版本 311 | pass 312 | # 下采样 313 | for layer in self.down1: 314 | x = layer(x, temb) 315 | x1 = x 316 | x = self.maxpool1(x) 317 | for layer in self.down2: 318 | x = layer(x, temb) 319 | x2 = x 320 | x = self.maxpool2(x) 321 | 322 | # 中间层 323 | x = self.middle(x, temb) 324 | 325 | # 上采样 326 | x = torch.cat([self.upsample1(x), x2], dim=1) 327 | for layer in self.up1: 328 | x = layer(x, temb) 329 | x = torch.cat([self.upsample2(x), x1], dim=1) 330 | for layer in self.up2: 331 | x = layer(x, temb) 332 | 333 | x = self.conv_out(x) 334 | return x 335 | 336 | 337 | if __name__ == '__main__': 338 | device = 'mps' 339 | model = MiniUnet() 340 | model = model.to(device) 341 | x = torch.randn(2, 1, 28, 28).to(device) 342 | t = torch.randn(2).to(device) 343 | y = torch.tensor([1, 2]).to(device) 344 | 345 | out = model(x, t, y) 346 | print(out.shape) 347 | # torch.Size([2, 16, 28, 28]) 348 | -------------------------------------------------------------------------------- /plot_loss_curve.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | 4 | if __name__ == '__main__': 5 | 6 | # 画Loss曲线看收敛情况 7 | # 读取pth文件,获得loss_list 8 | checkpoint = torch.load('./checkpoints/v1.2-reflow-cfg/miniunet_19.pth') 9 | loss_list = checkpoint['loss_list'] 10 | 11 | # 画图 12 | plt.plot(loss_list) 13 | plt.xlabel('Iteration') 14 | plt.ylabel('Loss') 15 | plt.title('Loss Curve') 16 | plt.tight_layout() 17 | plt.show() 18 | -------------------------------------------------------------------------------- /rectified_flow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | # 老司机开车理论->三要素:路线、车、司机 5 | 6 | 7 | class RectifiedFlow: 8 | # 车:图像生成一个迭代公式 ODE f(t+dt) = f(t) + dt*f'(t) 9 | def euler(self, x_t, v, dt): 10 | """ 使用欧拉方法计算下一个时间步长的值 11 | 12 | Args: 13 | x_t: 当前的值,维度为 [B, C, H, W] 14 | v: 当前的速度,维度为 [B, C, H, W] 15 | dt: 时间步长 16 | """ 17 | x_t = x_t + v * dt 18 | 19 | return x_t 20 | 21 | # 路线 22 | # v1.2: reflow增加x_0的输入 23 | def create_flow(self, x_1, t, x_0=None): 24 | """ 使用x_t = t * x_1 + (1 - t) * x_0公式构建x_0到x_1的流 25 | 26 | X_1是原始图像 X_0是噪声图像(服从标准高斯分布) 27 | 28 | Args: 29 | x_1: 原始图像,维度为 [B, C, H, W] 30 | t: 一个标量,表示时间,时间范围为 [0, 1],维度为 [B] 31 | x_0: 噪声图像,维度为 [B, C, H, W],默认值为None 32 | 33 | Returns: 34 | x_t: 在时间t的图像,维度为 [B, C, H, W] 35 | x_0: 噪声图像,维度为 [B, C, H, W] 36 | 37 | """ 38 | 39 | # 需要一个x0,x0服从高斯噪声 40 | if x_0 is None: 41 | x_0 = torch.randn_like(x_1) 42 | 43 | t = t[:, None, None, None] # [B, 1, 1, 1] 44 | 45 | # 获得xt的值 46 | x_t = t * x_1 + (1 - t) * x_0 47 | 48 | return x_t, x_0 49 | 50 | # 司机 51 | def mse_loss(self, v, x_1, x_0): 52 | """ 计算RectifiedFlow的损失函数 53 | L = MSE(x_1 - x_0 - v(t)) 匀速直线运动 54 | 55 | Args: 56 | v: 速度,维度为 [B, C, H, W] 57 | x_1: 原始图像,维度为 [B, C, H, W] 58 | x_0: 噪声图像,维度为 [B, C, H, W] 59 | """ 60 | 61 | # 求loss函数,是一个MSE,最后维度是[B] 62 | 63 | loss = F.mse_loss(x_1 - x_0, v) 64 | # loss = torch.mean((x_1 - x_0 - v)**2) 65 | 66 | return loss 67 | 68 | 69 | if __name__ == '__main__': 70 | # 时间越大,越是接近原始图像 71 | 72 | rf = RectifiedFlow() 73 | 74 | x_t = rf.create_flow(torch.ones(2, 3, 4, 4), 0.999) 75 | 76 | print(x_t) 77 | 
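A note on the `__main__` smoke test in `rectified_flow.py` above: `create_flow` indexes the time as `t[:, None, None, None]`, so `t` must be a 1-D tensor of shape `[B]`; passing the bare float `0.999` as written would raise a `TypeError`, and the call also returns an `(x_t, x_0)` tuple rather than a single tensor. A corrected, self-contained usage sketch (run from the repo root):

```python
import torch
from rectified_flow import RectifiedFlow

rf = RectifiedFlow()

x_1 = torch.ones(2, 3, 4, 4)              # stand-in "images"
t = torch.full((x_1.size(0),), 0.999)     # one time value per sample, shape [B]

x_t, x_0 = rf.create_flow(x_1, t)         # x_t = t * x_1 + (1 - t) * x_0, with x_0 ~ N(0, I)
v_target = x_1 - x_0                      # constant velocity of the straight-line path
loss = rf.mse_loss(v_target, x_1, x_0)    # exactly zero for the true target velocity

print(x_t.shape, loss.item())
```

Near `t = 1`, `x_t` is almost the clean image, which is what the original test's comment ("the larger the time, the closer to the original image") intends to check.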
-------------------------------------------------------------------------------- /reflow_sample_generate.py: -------------------------------------------------------------------------------- 1 | from infer import infer 2 | import torch 3 | import os 4 | 5 | if __name__ == '__main__': 6 | # 每个数字生成100000张图像 7 | # 为了做reflow,生成了10W*10=100W张图像 8 | # reflow可以让加噪过程的交点数目更少,采样速度更快,但会牺牲采样质量 9 | # 1-RF -> 2-RF 10 | 11 | for i in range(10): 12 | save_path = f'./data/reflow_img/{i}' 13 | save_noise_path = f'./data/reflow_noise/{i}' 14 | y = [i] * 100000 15 | 16 | infer(checkpoint_path='./checkpoints/v1.1-cfg/miniunet_49.pth', 17 | base_channels=64, 18 | step=20, 19 | num_imgs=100000, 20 | y=torch.tensor(y), 21 | cfg_scale=7.0, 22 | save_path=save_path, 23 | save_noise_path=save_noise_path, 24 | device='cuda') 25 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import yaml 4 | from torchvision.datasets import MNIST 5 | from torchvision.transforms import ToTensor, Compose, Normalize 6 | from torch.utils.data import DataLoader 7 | from model import MiniUnet 8 | from torch.optim import Adam, AdamW 9 | from torch.optim.lr_scheduler import StepLR 10 | from rectified_flow import RectifiedFlow 11 | 12 | 13 | def train(config: str): 14 | """训练flow matching模型 15 | 16 | Args: 17 | config (str): yaml配置文件路径,包含以下参数: 18 | base_channels (int, optional): MiniUnet的基础通道数,默认值为16。 19 | epochs (int, optional): 训练轮数,默认值为10。 20 | batch_size (int, optional): 批大小,默认值为128。 21 | lr_adjust_epoch (int, optional): 学习率调整轮数,默认值为50。 22 | batch_print_interval (int, optional): batch打印信息间隔,默认值为100。 23 | checkpoint_save_interval (int, optional): checkpopint保存间隔(单位为epoch),默认值为1。 24 | save_path (str, optional): 模型保存路径,默认值为'./checkpoints'。 25 | use_cfg (bool, optional): 是否使用Classifier-free Guidance训练条件生成模型,默认值为False。 26 | device (str, optional): 训练设备,默认值为'cuda'。 27 | 28 | """ 29 | # 读取yaml配置文件 30 | config = yaml.load(open(config, 'rb'), Loader=yaml.FullLoader) 31 | # 解析参数数据,有默认值 32 | base_channels = config.get('base_channels', 16) 33 | epochs = config.get('epochs', 10) 34 | batch_size = config.get('batch_size', 128) 35 | lr_adjust_epoch = config.get('lr_adjust_epoch', 50) 36 | batch_print_interval = config.get('batch_print_interval', 100) 37 | checkpoint_save_interval = config.get('checkpoint_save_interval', 1) 38 | save_path = config.get('save_path', './checkpoints') 39 | use_cfg = config.get('use_cfg', False) 40 | device = config.get('device', 'cuda') 41 | 42 | # 打印训练参数 43 | print('Training config:') 44 | print(f'base_channels: {base_channels}') 45 | print(f'epochs: {epochs}') 46 | print(f'batch_size: {batch_size}') 47 | print(f'lr_adjust_epoch: {lr_adjust_epoch}') 48 | print(f'batch_print_interval: {batch_print_interval}') 49 | print(f'checkpoint_save_interval: {checkpoint_save_interval}') 50 | print(f'save_path: {save_path}') 51 | print(f'use_cfg: {use_cfg}') 52 | print(f'device: {device}') 53 | 54 | # 训练flow matching模型 55 | 56 | # 数据集加载 57 | # 把PIL转为tensor 58 | transform = Compose([ToTensor()]) # 变换成tensor + 变为[0, 1] 59 | 60 | dataset = MNIST( 61 | root='./data', 62 | train=True, # 6w 63 | download=True, 64 | transform=transform) 65 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 66 | 67 | # 模型加载 68 | model = MiniUnet(base_channels) 69 | model.to(device) 70 | 71 | # 优化器加载 Rectified Flow的论文里面有的用的就是AdamW 72 | optimizer = AdamW(model.parameters(), lr=1e-4, 
weight_decay=0.1) 73 | 74 | # 学习率调整 75 | scheduler = StepLR(optimizer, step_size=lr_adjust_epoch, gamma=0.1) 76 | 77 | # RF加载 78 | rf = RectifiedFlow() 79 | 80 | # 记录训练时候每一轮的loss 81 | loss_list = [] 82 | 83 | # 一些文件夹提前创建 84 | os.makedirs(save_path, exist_ok=True) 85 | 86 | # 训练循环 87 | for epoch in range(epochs): 88 | for batch, data in enumerate(dataloader): 89 | x_1, y = data # x_1原始图像,y是标签,用于CFG 90 | # 均匀采样[0, 1]的时间t randn 标准正态分布 91 | t = torch.rand(x_1.size(0)) 92 | 93 | # 生成flow(实际上是一个点) 94 | x_t, x_0 = rf.create_flow(x_1, t) 95 | 96 | # 4090 大概占用显存3G 97 | x_t = x_t.to(device) 98 | x_0 = x_0.to(device) 99 | x_1 = x_1.to(device) 100 | t = t.to(device) 101 | 102 | optimizer.zero_grad() 103 | 104 | # 这里我们要做一个数据的复制和拼接,复制原始x_1,把一半的y替换成-1表示无条件生成,这里也可以直接有条件、无条件累计两次计算两次loss的梯度 105 | # 一定的概率,把有条件生成换为无条件的 50%的概率 [x_t, x_t] [t, t] 106 | if use_cfg: 107 | x_t = torch.cat([x_t, x_t.clone()], dim=0) 108 | t = torch.cat([t, t.clone()], dim=0) 109 | y = torch.cat([y, -torch.ones_like(y)], dim=0) 110 | x_1 = torch.cat([x_1, x_1.clone()], dim=0) 111 | x_0 = torch.cat([x_0, x_0.clone()], dim=0) 112 | y = y.to(device) 113 | else: 114 | y = None 115 | 116 | v_pred = model(x=x_t, t=t, y=y) 117 | 118 | loss = rf.mse_loss(v_pred, x_1, x_0) 119 | 120 | loss.backward() 121 | optimizer.step() 122 | 123 | if batch % batch_print_interval == 0: 124 | print(f'[Epoch {epoch}] [batch {batch}] loss: {loss.item()}') 125 | 126 | loss_list.append(loss.item()) 127 | 128 | scheduler.step() 129 | 130 | if epoch % checkpoint_save_interval == 0 or epoch == epochs - 1 or epoch == 0: 131 | # 第一轮也保存一下,快速测试用,大家可以删除 132 | # 保存模型 133 | print(f'Saving model {epoch} to {save_path}...') 134 | save_dict = dict(model=model.state_dict(), 135 | optimizer=optimizer.state_dict(), 136 | epoch=epoch, 137 | loss_list=loss_list) 138 | torch.save(save_dict, 139 | os.path.join(save_path, f'miniunet_{epoch}.pth')) 140 | 141 | 142 | if __name__ == '__main__': 143 | train(config='./config/train_config.yaml') 144 | -------------------------------------------------------------------------------- /train_reflow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import yaml 4 | from datasets import ReflowDataset 5 | from torchvision.transforms import ToTensor, Compose, Normalize 6 | from torch.utils.data import DataLoader 7 | from model import MiniUnet 8 | from torch.optim import Adam, AdamW 9 | from torch.optim.lr_scheduler import StepLR 10 | from rectified_flow import RectifiedFlow 11 | 12 | # 1. 
reflow的训练要从上一个1-rectified flow(v1.1)模型的权重作为预训练权重 13 | 14 | 15 | def train(config: str): 16 | """训练reflow模型 17 | 18 | Args: 19 | config (str): yaml配置文件路径,包含以下参数: 20 | base_channels (int, optional): MiniUnet的基础通道数,默认值为16。 21 | epochs (int, optional): 训练轮数,默认值为10。 22 | batch_size (int, optional): 批大小,默认值为128。 23 | lr (float, optional): 学习率,默认值为1e-5。 24 | lr_adjust_epoch (int, optional): 学习率调整轮数,默认值为50。 25 | batch_print_interval (int, optional): batch打印信息间隔,默认值为100。 26 | checkpoint_save_interval (int, optional): checkpopint保存间隔(单位为epoch),默认值为1。 27 | save_path (str, optional): 模型保存路径,默认值为'./checkpoints'。 28 | use_cfg (bool, optional): 是否使用Classifier-free Guidance训练条件生成模型,默认值为False。 29 | img_root_path (str, optional): reflow图像根路径,默认值为None。 30 | noise_root_path (str, optional): reflow噪声根路径,默认值为None。 31 | checkpoint_path (str, optional): 预训练模型路径,默认值为None。 32 | device (str, optional): 训练设备,默认值为'cuda'。 33 | 34 | """ 35 | # 读取yaml配置文件 36 | config = yaml.load(open(config, 'rb'), Loader=yaml.FullLoader) 37 | # 解析参数数据,有默认值 38 | base_channels = config.get('base_channels', 16) 39 | epochs = config.get('epochs', 10) 40 | batch_size = config.get('batch_size', 128) 41 | lr_adjust_epoch = config.get('lr_adjust_epoch', 50) 42 | batch_print_interval = config.get('batch_print_interval', 100) 43 | checkpoint_save_interval = config.get('checkpoint_save_interval', 1) 44 | save_path = config.get('save_path', './checkpoints') 45 | use_cfg = config.get('use_cfg', False) 46 | device = config.get('device', 'cuda') 47 | 48 | # v1.2 reflow增加参数 49 | lr = config.get('lr', 1e-5) 50 | img_root_path = config.get('img_root_path', None) 51 | noise_root_path = config.get('noise_root_path', None) 52 | checkpoint_path = config.get('checkpoint_path', None) 53 | 54 | # 打印训练参数 55 | print('Training config:') 56 | print(f'base_channels: {base_channels}') 57 | print(f'epochs: {epochs}') 58 | print(f'batch_size: {batch_size}') 59 | print(f'learning rate: {lr}') 60 | print(f'lr_adjust_epoch: {lr_adjust_epoch}') 61 | print(f'batch_print_interval: {batch_print_interval}') 62 | print(f'checkpoint_save_interval: {checkpoint_save_interval}') 63 | print(f'save_path: {save_path}') 64 | print(f'use_cfg: {use_cfg}') 65 | print(f'img_root_path: {img_root_path}') 66 | print(f'noise_root_path: {noise_root_path}') 67 | print(f'checkpoint_path: {checkpoint_path}') 68 | print(f'device: {device}') 69 | 70 | # 训练flow matching模型 71 | 72 | # 数据集加载 73 | transform = Compose([ToTensor()]) # 变换成tensor + 变为[0, 1] 74 | 75 | print(f'Loading dataset from {img_root_path} and {noise_root_path}...') 76 | dataset = ReflowDataset(img_root_path=img_root_path, 77 | noise_root_path=noise_root_path, 78 | transform=transform) 79 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 80 | 81 | print(f'Dataset Loaded: {len(dataset)} samples') 82 | 83 | # 模型加载 84 | model = MiniUnet(base_channels) 85 | model.to(device) 86 | 87 | # v1.2 reflow增加预训练权重加载 1-RF 88 | print(f'Loading checkpoint from {checkpoint_path}...') 89 | model.load_state_dict(torch.load(checkpoint_path)['model']) 90 | print('Checkpoint loaded.') 91 | 92 | # 优化器加载 Rectified Flow的论文里面有的用的就是AdamW 93 | optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.1) 94 | 95 | # 学习率调整 96 | scheduler = StepLR(optimizer, step_size=lr_adjust_epoch, gamma=0.1) 97 | 98 | # RF加载 99 | rf = RectifiedFlow() 100 | 101 | # 记录训练时候每一轮的loss 102 | loss_list = [] 103 | 104 | # 模型权重保存路径文件夹提前创建 105 | os.makedirs(save_path, exist_ok=True) 106 | 107 | # 训练循环 108 | for epoch in range(epochs): 109 | for batch, data in 
enumerate(dataloader): 110 | 111 | # reflow的数据集有三个输出,x_1是原始图像,y是标签 112 | x_1 = data['img'] 113 | x_0 = data['noise'] 114 | y = data['label'] 115 | 116 | # 均匀采样[0, 1]的时间t randn 标准正态分布 117 | t = torch.rand(x_1.size(0)) 118 | 119 | # 生成flow(实际上是一个点) 120 | x_t, _ = rf.create_flow(x_1, t, x_0) 121 | 122 | # 4090 大概占用显存3G 123 | x_t = x_t.to(device) 124 | x_0 = x_0.to(device) 125 | x_1 = x_1.to(device) 126 | t = t.to(device) 127 | 128 | optimizer.zero_grad() 129 | 130 | # 这里我们要做一个数据的复制和拼接,复制原始x_1,把一半的y替换成-1表示无条件生成,这里也可以直接有条件、无条件累计两次计算两次loss的梯度 131 | # 一定的概率,把有条件生成换为无条件的 50%的概率 [x_t, x_t] [t, t] 132 | if use_cfg: 133 | x_t = torch.cat([x_t, x_t.clone()], dim=0) 134 | t = torch.cat([t, t.clone()], dim=0) 135 | y = torch.cat([y, -torch.ones_like(y)], dim=0) 136 | x_1 = torch.cat([x_1, x_1.clone()], dim=0) 137 | x_0 = torch.cat([x_0, x_0.clone()], dim=0) 138 | y = y.to(device) 139 | else: 140 | y = None 141 | 142 | v_pred = model(x=x_t, t=t, y=y) 143 | 144 | loss = rf.mse_loss(v_pred, x_1, x_0) 145 | 146 | loss.backward() 147 | optimizer.step() 148 | 149 | if batch % batch_print_interval == 0: 150 | print(f'[Epoch {epoch}] [batch {batch}] loss: {loss.item()}') 151 | 152 | loss_list.append(loss.item()) 153 | 154 | scheduler.step() 155 | 156 | if epoch % checkpoint_save_interval == 0 or epoch == epochs - 1 or epoch == 0: 157 | # 第一轮也保存一下,快速测试用,大家可以删除 158 | # 保存模型 159 | print(f'Saving model {epoch} to {save_path}...') 160 | save_dict = dict(model=model.state_dict(), 161 | optimizer=optimizer.state_dict(), 162 | epoch=epoch, 163 | loss_list=loss_list) 164 | torch.save(save_dict, 165 | os.path.join(save_path, f'miniunet_{epoch}.pth')) 166 | 167 | 168 | if __name__ == '__main__': 169 | train(config='./config/train_reflow_config.yaml') 170 | --------------------------------------------------------------------------------
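To close the loop, here is a condensed sketch of the few-step, classifier-free-guided sampling that `infer.py` performs, wired to the 2-Rectified-Flow checkpoint used elsewhere in this repo. The checkpoint path, `base_channels=64`, `step=2` and `cfg_scale=5.0` mirror the values above; treat this as an illustrative sketch rather than a drop-in replacement for `infer.py`.

```python
import torch
from model import MiniUnet
from rectified_flow import RectifiedFlow

device = 'cuda' if torch.cuda.is_available() else 'cpu'
ckpt = torch.load('./checkpoints/v1.2-reflow-cfg/miniunet_19.pth', map_location=device)

model = MiniUnet(base_channels=64).to(device).eval()
model.load_state_dict(ckpt['model'])
rf = RectifiedFlow()

step, cfg_scale = 2, 5.0
dt = 1.0 / step
x_t = torch.randn(1, 1, 28, 28, device=device)    # start from pure noise x_0
y = torch.tensor([7], device=device)              # digit class to generate

with torch.no_grad():
    for j in range(step):
        t = torch.full((1,), j * dt, device=device)
        v_uncond = model(x=x_t, t=t)                      # unconditional velocity
        v_cond = model(x=x_t, t=t, y=y)                   # class-conditional velocity
        v = v_uncond + cfg_scale * (v_cond - v_uncond)    # classifier-free guidance
        x_t = rf.euler(x_t, v, dt)                        # one Euler step along the flow

img = (x_t.clamp(0, 1)[0, 0] * 255).to(torch.uint8).cpu().numpy()  # 28x28 uint8 digit
```

The README's V1.2 section compares exactly this kind of 2-step sampling between the 1-Rectified-Flow and 2-Rectified-Flow checkpoints.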