├── 2d_plot_fm_todo
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── fm.cpython-39.pyc
│   │   ├── dataset.cpython-39.pyc
│   │   └── network.cpython-39.pyc
│   ├── chamferdist.py
│   ├── dataset.py
│   ├── network.py
│   └── fm.py
├── .gitignore
├── assets
│   ├── fm_loss.png
│   ├── fm_2d_vis.png
│   └── trajectory_visualization.png
├── requirements.txt
├── LICENSE
├── image_fm_todo
│   ├── sampling.py
│   ├── network.py
│   ├── train.py
│   ├── fm.py
│   ├── module.py
│   └── dataset.py
└── README.md

--------------------------------------------------------------------------------
/2d_plot_fm_todo/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/results/
2 | **/data/
3 | **/__pycache__/
4 | **/.ipynb_checkpoints/
--------------------------------------------------------------------------------
/assets/fm_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment7-Flow/HEAD/assets/fm_loss.png
--------------------------------------------------------------------------------
/assets/fm_2d_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment7-Flow/HEAD/assets/fm_2d_vis.png
--------------------------------------------------------------------------------
/assets/trajectory_visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment7-Flow/HEAD/assets/trajectory_visualization.png
--------------------------------------------------------------------------------
/2d_plot_fm_todo/__pycache__/fm.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment7-Flow/HEAD/2d_plot_fm_todo/__pycache__/fm.cpython-39.pyc
--------------------------------------------------------------------------------
/2d_plot_fm_todo/__pycache__/dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment7-Flow/HEAD/2d_plot_fm_todo/__pycache__/dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/2d_plot_fm_todo/__pycache__/network.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment7-Flow/HEAD/2d_plot_fm_todo/__pycache__/network.cpython-39.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | scikit-learn==1.1.3
 2 | ipython==8.12.0
 3 | jupyter==1.0.0
 4 | matplotlib==3.6.0
 5 | torch==2.0.1
 6 | torch-ema==0.3
 7 | torchvision==0.15.2
 8 | tqdm==4.64.1
 9 | jupyterlab==4.0.2
10 | jaxtyping==0.2.20
11 | pytorch_lightning
12 | dotmap
13 | numpy==1.26.0
--------------------------------------------------------------------------------
/2d_plot_fm_todo/chamferdist.py:
--------------------------------------------------------------------------------
 1 | from scipy.spatial import KDTree
 2 | from scipy.spatial.distance import cdist
 3 | 
 4 | def chamfer_distance(S1, S2) -> float:
 5 |     """
 6 |     Computes the Chamfer distance between two point clouds, defined as:
 7 |     d_CD(S1, S2) = \sum_{x \in S1} \min_{y \in S2} ||x - y||^2 + \sum_{y \in S2} \min_{x \in S1} ||x - y||^2
 8 |     """
 9 |     dist = cdist(S1, S2)
10 |     dist1 = dist.min(axis=1) ** 2
11 |     dist2 = dist.min(axis=0) ** 2
12 |     return dist1.sum() + dist2.sum()
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2025 KAIST Visual AI Group
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/image_fm_todo/sampling.py:
--------------------------------------------------------------------------------
1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from dataset import tensor_to_pil_image 6 | from fm import FlowMatching 7 | from pathlib import Path 8 | 9 | 10 | def main(args): 11 | save_dir = Path(args.save_dir) 12 | save_dir.mkdir(exist_ok=True, parents=True) 13 | 14 | device = f"cuda:{args.gpu}" 15 | 16 | fm = FlowMatching(None, None) 17 | fm.load(args.ckpt_path) 18 | fm.eval() 19 | fm = fm.to(device) 20 | 21 | total_num_samples = 500 22 | num_batches = int(np.ceil(total_num_samples / args.batch_size)) 23 | 24 | for i in range(num_batches): 25 | sidx = i * args.batch_size 26 | eidx = min(sidx + args.batch_size, total_num_samples) 27 | B = eidx - sidx 28 | 29 | if args.use_cfg: # Enable CFG sampling 30 | assert fm.network.use_cfg, "The model was not trained to support CFG."
31 | shape = (B, 3, fm.image_resolution, fm.image_resolution) 32 | samples = fm.sample( 33 | shape, 34 | num_inference_timesteps=20, 35 | class_label=torch.randint(1, 4, (B,)).to(device), 36 | guidance_scale=args.cfg_scale, 37 | ) 38 | else: 39 | raise NotImplementedError("In Assignment 7, we sample images with CFG setup only.") 40 | 41 | pil_images = tensor_to_pil_image(samples) 42 | 43 | for j, img in zip(range(sidx, eidx), pil_images): 44 | img.save(save_dir / f"{j}.png") 45 | print(f"Saved the {j}-th image.") 46 | 47 | 48 | if __name__ == "__main__": 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--batch_size", type=int, default=128) 51 | parser.add_argument("--gpu", type=int, default=0) 52 | parser.add_argument("--ckpt_path", type=str) 53 | parser.add_argument("--save_dir", type=str) 54 | parser.add_argument("--use_cfg", action="store_true") 55 | parser.add_argument("--cfg_scale", type=float, default=7.5) 56 | 57 | args = parser.parse_args() 58 | main(args) 59 | -------------------------------------------------------------------------------- /2d_plot_fm_todo/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn import datasets 4 | from torch.utils.data import DataLoader, Dataset 5 | 6 | 7 | def normalize(ds, scaling_factor=2.0): 8 | return (ds - ds.mean()) / ds.std() * scaling_factor 9 | 10 | 11 | def sample_checkerboard(n): 12 | # https://github.com/ghliu/SB-FBSDE/blob/main/data.py 13 | n_points = 3 * n 14 | n_classes = 2 15 | freq = 5 16 | x = np.random.uniform( 17 | -(freq // 2) * np.pi, (freq // 2) * np.pi, size=(n_points, n_classes) 18 | ) 19 | mask = np.logical_or( 20 | np.logical_and(np.sin(x[:, 0]) > 0.0, np.sin(x[:, 1]) > 0.0), 21 | np.logical_and(np.sin(x[:, 0]) < 0.0, np.sin(x[:, 1]) < 0.0), 22 | ) 23 | y = np.eye(n_classes)[1 * mask] 24 | x0 = x[:, 0] * y[:, 0] 25 | x1 = x[:, 1] * y[:, 0] 26 | sample = np.concatenate([x0[..., None], x1[..., None]], axis=-1) 27 | sqr = np.sum(np.square(sample), axis=-1) 28 | idxs = np.where(sqr == 0) 29 | sample = np.delete(sample, idxs, axis=0) 30 | 31 | return sample 32 | 33 | 34 | def load_twodim(num_samples: int, dataset: str, dimension: int = 2): 35 | 36 | if dataset == "gaussian_centered": 37 | sample = np.random.normal(size=(num_samples, dimension)) 38 | sample = sample 39 | 40 | if dataset == "gaussian_shift": 41 | sample = np.random.normal(size=(num_samples, dimension)) 42 | sample = sample + 1.5 43 | 44 | if dataset == "circle": 45 | X, y = datasets.make_circles( 46 | n_samples=num_samples, noise=0.0, random_state=None, factor=0.5 47 | ) 48 | sample = X * 4 49 | 50 | if dataset == "scurve": 51 | X, y = datasets.make_s_curve( 52 | n_samples=num_samples, noise=0.0, random_state=None 53 | ) 54 | sample = normalize(X[:, [0, 2]]) 55 | 56 | if dataset == "moon": 57 | X, y = datasets.make_moons(n_samples=num_samples, noise=0.0, random_state=None) 58 | sample = normalize(X) 59 | 60 | if dataset == "swiss_roll": 61 | X, y = datasets.make_swiss_roll( 62 | n_samples=num_samples, noise=0.0, random_state=None, hole=True 63 | ) 64 | sample = normalize(X[:, [0, 2]]) 65 | 66 | if dataset == "checkerboard": 67 | sample = normalize(sample_checkerboard(num_samples)) 68 | 69 | return torch.tensor(sample).float() 70 | 71 | 72 | class TwoDimDataClass(Dataset): 73 | def __init__(self, dataset_type: str, N: int, batch_size: int, dimension=2): 74 | 75 | self.X = load_twodim(N, dataset_type, dimension=dimension) 76 | self.name = dataset_type 
77 | self.batch_size = batch_size 78 | self.dimension = dimension 79 | 80 | def __len__(self): 81 | return self.X.shape[0] 82 | 83 | def __getitem__(self, idx): 84 | return self.X[idx] 85 | 86 | def get_dataloader(self, shuffle=True): 87 | return DataLoader( 88 | self, 89 | batch_size=self.batch_size, 90 | shuffle=shuffle, 91 | pin_memory=True, 92 | ) 93 | 94 | 95 | def get_data_iterator(iterable): 96 | iterator = iterable.__iter__() 97 | while True: 98 | try: 99 | yield iterator.__next__() 100 | except StopIteration: 101 | iterator = iterable.__iter__() 102 | 
--------------------------------------------------------------------------------
/2d_plot_fm_todo/network.py:
--------------------------------------------------------------------------------
  1 | import math
  2 | from typing import List
  3 | 
  4 | import torch
  5 | import torch.nn as nn
  6 | import torch.nn.functional as F
  7 | 
  8 | 
  9 | class TimeEmbedding(nn.Module):
 10 |     def __init__(self, hidden_size, frequency_embedding_size=256):
 11 |         super().__init__()
 12 |         self.mlp = nn.Sequential(
 13 |             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
 14 |             nn.SiLU(),
 15 |             nn.Linear(hidden_size, hidden_size, bias=True),
 16 |         )
 17 |         self.frequency_embedding_size = frequency_embedding_size
 18 | 
 19 |     @staticmethod
 20 |     def timestep_embedding(t, dim, max_period=10000):
 21 |         """
 22 |         Create sinusoidal timestep embeddings.
 23 |         :param t: a 1-D Tensor of N indices, one per batch element.
 24 |                   These may be fractional.
 25 |         :param dim: the dimension of the output.
 26 |         :param max_period: controls the minimum frequency of the embeddings.
 27 |         :return: an (N, D) Tensor of positional embeddings.
 28 |         """
 29 |         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
 30 |         half = dim // 2
 31 |         freqs = torch.exp(
 32 |             -math.log(max_period)
 33 |             * torch.arange(start=0, end=half, dtype=torch.float32)
 34 |             / half
 35 |         ).to(device=t.device)
 36 |         args = t[:, None].float() * freqs[None]
 37 |         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
 38 |         if dim % 2:
 39 |             embedding = torch.cat(
 40 |                 [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
 41 |             )
 42 |         return embedding
 43 | 
 44 |     def forward(self, t: torch.Tensor):
 45 |         if t.ndim == 0:
 46 |             t = t.unsqueeze(-1)
 47 |         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
 48 |         t_emb = self.mlp(t_freq)
 49 |         return t_emb
 50 | 
 51 | 
 52 | class TimeLinear(nn.Module):
 53 |     def __init__(self, dim_in: int, dim_out: int, num_timesteps: int):
 54 |         super().__init__()
 55 |         self.dim_in = dim_in
 56 |         self.dim_out = dim_out
 57 |         self.num_timesteps = num_timesteps
 58 | 
 59 |         self.time_embedding = TimeEmbedding(dim_out)
 60 |         self.fc = nn.Linear(dim_in, dim_out)
 61 | 
 62 |     def forward(self, x: torch.Tensor, t: torch.Tensor):
 63 |         x = self.fc(x)
 64 |         alpha = self.time_embedding(t).view(-1, self.dim_out)
 65 | 
 66 |         return alpha * x
 67 | 
 68 | 
 69 | class SimpleNet(nn.Module):
 70 |     def __init__(
 71 |         self, dim_in: int, dim_out: int, dim_hids: List[int], num_timesteps: int
 72 |     ):
 73 |         super().__init__()
 74 |         """
 75 |         (TODO) Build a vector field prediction network.
 76 | 
 77 |         Args:
 78 |             dim_in: dimension of input
 79 |             dim_out: dimension of output
 80 |             dim_hids: dimensions of hidden features
 81 |             num_timesteps: number of timesteps
 82 |         """
 83 | 
 84 |         ######## TODO ########
 85 |         # DO NOT change the code outside this part.
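        # (Editor's sketch, not part of the original file: one plausible way to fill in
        #  this TODO, assuming a stack of the time-conditioned TimeLinear layers defined
        #  above. The attribute name `self.layers` is hypothetical.)
        #
        #     dims = [dim_in] + dim_hids + [dim_out]
        #     self.layers = nn.ModuleList(
        #         [TimeLinear(d_in, d_out, num_timesteps) for d_in, d_out in zip(dims[:-1], dims[1:])]
        #     )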
 86 | 
 87 |         ######################
 88 | 
 89 |     def forward(self, x: torch.Tensor, t: torch.Tensor):
 90 |         """
 91 |         (TODO) Implement the forward pass. This should output
 92 |         the vector field prediction for the input x at timestep t.
 93 | 
 94 |         Args:
 95 |             x: the intermediate sample along the flow at time t
 96 |             t: the current timestep in [0, 1)
 97 |         """
 98 |         ######## TODO ########
 99 |         # DO NOT change the code outside this part.
100 | 
101 |         ######################
102 |         return x
103 | 
--------------------------------------------------------------------------------
/image_fm_todo/network.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from module import DownSample, ResBlock, Swish, TimeEmbedding, UpSample 8 | from torch.nn import init 9 | 10 | 11 | class UNet(nn.Module): 12 | def __init__(self, T=1000, image_resolution=64, ch=128, ch_mult=[1,2,2,2], attn=[1], num_res_blocks=4, dropout=0.1, use_cfg=False, cfg_dropout=0.1, num_classes=None): 13 | super().__init__() 14 | self.image_resolution = image_resolution 15 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' 16 | tdim = ch * 4 17 | # self.time_embedding = TimeEmbedding(T, ch, tdim) 18 | self.time_embedding = TimeEmbedding(tdim) 19 | 20 | # classifier-free guidance 21 | self.use_cfg = use_cfg 22 | self.cfg_dropout = cfg_dropout 23 | if use_cfg: 24 | assert num_classes is not None 25 | cdim = tdim 26 | self.class_embedding = nn.Embedding(num_classes+1, cdim) 27 | 28 | self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1) 29 | self.downblocks = nn.ModuleList() 30 | chs = [ch] # record output channels when downsampling, for upsampling 31 | now_ch = ch 32 | for i, mult in enumerate(ch_mult): 33 | out_ch = ch * mult 34 | for _ in range(num_res_blocks): 35 | self.downblocks.append(ResBlock( 36 | in_ch=now_ch, out_ch=out_ch, tdim=tdim, 37 | dropout=dropout, attn=(i in attn))) 38 | now_ch = out_ch 39 | chs.append(now_ch) 40 | if i != len(ch_mult) - 1: 41 | self.downblocks.append(DownSample(now_ch)) 42 | chs.append(now_ch) 43 | 44 | self.middleblocks = nn.ModuleList([ 45 | ResBlock(now_ch, now_ch, tdim, dropout, attn=True), 46 | ResBlock(now_ch, now_ch, tdim, dropout, attn=False), 47 | ]) 48 | 49 | self.upblocks = nn.ModuleList() 50 | for i, mult in reversed(list(enumerate(ch_mult))): 51 | out_ch = ch * mult 52 | for _ in range(num_res_blocks + 1): 53 | self.upblocks.append(ResBlock( 54 | in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, 55 | dropout=dropout, attn=(i in attn))) 56 | now_ch = out_ch 57 | if i != 0: 58 | self.upblocks.append(UpSample(now_ch)) 59 | assert len(chs) == 0 60 | 61 | self.tail = nn.Sequential( 62 | nn.GroupNorm(32, now_ch), 63 | Swish(), 64 | nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) 65 | ) 66 | self.initialize() 67 | 68 | def initialize(self): 69 | init.xavier_uniform_(self.head.weight) 70 | init.zeros_(self.head.bias) 71 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) 72 | init.zeros_(self.tail[-1].bias) 73 | 74 | def forward(self, x, timestep, class_label=None): 75 | # Timestep embedding 76 | temb = self.time_embedding(timestep) 77 | 78 | if self.use_cfg and class_label is not None: 79 | if self.training: 80 | assert not torch.any(class_label == 0) # 0 for null. 81 | 82 | ######## TODO ######## 83 | # DO NOT change the code outside this part. 84 | # Assignment 2. Implement random null conditioning in CFG training.
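# (Editor's sketch, not part of the original file: the usual CFG recipe drops each
#  label to the null class 0 with probability self.cfg_dropout, e.g.:
#      drop = torch.rand(class_label.shape[0], device=class_label.device) < self.cfg_dropout
#      class_label = torch.where(drop, torch.zeros_like(class_label), class_label)
#  so that the same network also learns an unconditional vector field.)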
85 | raise NotImplementedError("TODO") 86 | ####################### 87 | 88 | ######## TODO ######## 89 | # DO NOT change the code outside this part. 90 | # Assignment 2. Implement class conditioning 91 | raise NotImplementedError("TODO") 92 | ####################### 93 | 94 | # Downsampling 95 | h = self.head(x) 96 | hs = [h] 97 | for layer in self.downblocks: 98 | h = layer(h, temb) 99 | hs.append(h) 100 | # Middle 101 | for layer in self.middleblocks: 102 | h = layer(h, temb) 103 | # Upsampling 104 | for layer in self.upblocks: 105 | if isinstance(layer, ResBlock): 106 | h = torch.cat([h, hs.pop()], dim=1) 107 | h = layer(h, temb) 108 | h = self.tail(h) 109 | 110 | assert len(hs) == 0 111 | return h 112 | -------------------------------------------------------------------------------- /image_fm_todo/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from datetime import datetime 4 | from pathlib import Path 5 | 6 | import matplotlib 7 | import matplotlib.pyplot as plt 8 | import torch 9 | from dataset import AFHQDataModule, get_data_iterator, tensor_to_pil_image 10 | from dotmap import DotMap 11 | from fm import FlowMatching, FMScheduler 12 | from network import UNet 13 | from pytorch_lightning import seed_everything 14 | from torchvision.transforms.functional import to_pil_image 15 | from tqdm import tqdm 16 | 17 | matplotlib.use("Agg") 18 | 19 | 20 | def get_current_time(): 21 | now = datetime.now().strftime("%m-%d-%H%M%S") 22 | return now 23 | 24 | 25 | def main(args): 26 | """config""" 27 | config = DotMap() 28 | config.update(vars(args)) 29 | config.device = f"cuda:{args.gpu}" 30 | 31 | now = get_current_time() 32 | assert args.use_cfg, f"In Assignment 7, we sample images with CFG setup only." 
33 | 34 | if args.use_cfg: 35 | save_dir = Path(f"results/cfg_fm-{now}") 36 | else: 37 | save_dir = Path(f"results/fm-{now}") 38 | save_dir.mkdir(exist_ok=True, parents=True) 39 | print(f"save_dir: {save_dir}") 40 | 41 | seed_everything(config.seed) 42 | 43 | with open(save_dir / "config.json", "w") as f: 44 | json.dump(config, f, indent=2) 45 | """######""" 46 | 47 | image_resolution = 64 48 | ds_module = AFHQDataModule( 49 | "./data", 50 | batch_size=config.batch_size, 51 | num_workers=4, 52 | max_num_images_per_cat=config.max_num_images_per_cat, 53 | image_resolution=image_resolution 54 | ) 55 | 56 | train_dl = ds_module.train_dataloader() 57 | train_it = get_data_iterator(train_dl) 58 | 59 | # Set up the scheduler 60 | fm_scheduler = FMScheduler(sigma_min=args.sigma_min) 61 | 62 | network = UNet( 63 | image_resolution=image_resolution, 64 | ch=128, 65 | ch_mult=[1, 2, 2, 2], 66 | attn=[1], 67 | num_res_blocks=4, 68 | dropout=0.1, 69 | use_cfg=args.use_cfg, 70 | cfg_dropout=args.cfg_dropout, 71 | num_classes=getattr(ds_module, "num_classes", None), 72 | ) 73 | 74 | fm = FlowMatching(network, fm_scheduler) 75 | fm = fm.to(config.device) 76 | 77 | optimizer = torch.optim.Adam(fm.network.parameters(), lr=2e-4) 78 | scheduler = torch.optim.lr_scheduler.LambdaLR( 79 | optimizer, lr_lambda=lambda t: min((t + 1) / config.warmup_steps, 1.0) 80 | ) 81 | 82 | step = 0 83 | losses = [] 84 | with tqdm(initial=step, total=config.train_num_steps) as pbar: 85 | while step < config.train_num_steps: 86 | if step % config.log_interval == 0: 87 | fm.eval() 88 | plt.plot(losses) 89 | plt.savefig(f"{save_dir}/loss.png") 90 | plt.close() 91 | shape = (4, 3, fm.image_resolution, fm.image_resolution) 92 | if args.use_cfg: 93 | class_label = torch.tensor([1,1,2,3]).to(config.device) 94 | samples = fm.sample(shape, class_label=class_label, guidance_scale=7.5, verbose=False) 95 | else: 96 | samples = fm.sample(shape, return_traj=False, verbose=False) 97 | pil_images = tensor_to_pil_image(samples) 98 | for i, img in enumerate(pil_images): 99 | img.save(save_dir / f"step={step}-{i}.png") 100 | 101 | fm.save(f"{save_dir}/last.ckpt") 102 | fm.train() 103 | 104 | img, label = next(train_it) 105 | img, label = img.to(config.device), label.to(config.device) 106 | if args.use_cfg: # Conditional, CFG training 107 | loss = fm.get_loss(img, class_label=label) 108 | else: # Unconditional training 109 | loss = fm.get_loss(img) 110 | pbar.set_description(f"Loss: {loss.item():.4f}") 111 | 112 | optimizer.zero_grad() 113 | loss.backward() 114 | optimizer.step() 115 | scheduler.step() 116 | losses.append(loss.item()) 117 | 118 | step += 1 119 | pbar.update(1) 120 | 121 | 122 | if __name__ == "__main__": 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument("--gpu", type=int, default=0) 125 | parser.add_argument("--batch_size", type=int, default=16) 126 | parser.add_argument( 127 | "--train_num_steps", 128 | type=int, 129 | default=100000, 130 | help="the number of model training steps.", 131 | ) 132 | parser.add_argument("--warmup_steps", type=int, default=200) 133 | parser.add_argument("--log_interval", type=int, default=200) 134 | parser.add_argument( 135 | "--max_num_images_per_cat", 136 | type=int, 137 | default=3000, 138 | help="max number of images per category for AFHQ dataset", 139 | ) 140 | parser.add_argument("--sigma_min", type=float, default=0.001) 141 | parser.add_argument("--seed", type=int, default=63) 142 | parser.add_argument("--image_resolution", type=int, default=64) 143 | 
parser.add_argument("--use_cfg", action="store_true") 144 | parser.add_argument("--cfg_dropout", type=float, default=0.1) 145 | args = parser.parse_args() 146 | main(args) 147 | -------------------------------------------------------------------------------- /2d_plot_fm_todo/fm.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from tqdm import tqdm 8 | 9 | 10 | def expand_t(t, x): 11 | for _ in range(x.ndim - 1): 12 | t = t.unsqueeze(-1) 13 | return t 14 | 15 | 16 | class FMScheduler(nn.Module): 17 | def __init__(self, num_train_timesteps=1000, sigma_min=0.001): 18 | super().__init__() 19 | self.num_train_timesteps = num_train_timesteps 20 | self.sigma_min = sigma_min 21 | 22 | def uniform_sample_t(self, batch_size) -> torch.LongTensor: 23 | ts = ( 24 | np.random.choice(np.arange(self.num_train_timesteps), batch_size) 25 | / self.num_train_timesteps 26 | ) 27 | return torch.from_numpy(ts) 28 | 29 | def compute_psi_t(self, x1, t, x): 30 | """ 31 | Compute the conditional flow psi_t(x | x_1). 32 | 33 | Note that time flows in the opposite direction compared to DDPM/DDIM. 34 | As t moves from 0 to 1, the probability paths shift from a prior distribution p_0(x) 35 | to a more complex data distribution p_1(x). 36 | 37 | Input: 38 | x1 (`torch.Tensor`): Data sample from the data distribution. 39 | t (`torch.Tensor`): Timestep in [0,1). 40 | x (`torch.Tensor`): The input to the conditional psi_t(x). 41 | Output: 42 | psi_t (`torch.Tensor`): The conditional flow at t. 43 | """ 44 | t = expand_t(t, x1) 45 | 46 | ######## TODO ######## 47 | # DO NOT change the code outside this part. 48 | # compute psi_t(x) 49 | 50 | psi_t = x1 51 | ###################### 52 | 53 | return psi_t 54 | 55 | def step(self, xt, vt, dt): 56 | """ 57 | The simplest ode solver as the first-order Euler method: 58 | x_next = xt + dt * vt 59 | """ 60 | 61 | ######## TODO ######## 62 | # DO NOT change the code outside this part. 63 | # implement each step of the first-order Euler method. 64 | x_next = xt 65 | ###################### 66 | 67 | return x_next 68 | 69 | 70 | class FlowMatching(nn.Module): 71 | def __init__(self, network: nn.Module, fm_scheduler: FMScheduler, **kwargs): 72 | super().__init__() 73 | self.network = network 74 | self.fm_scheduler = fm_scheduler 75 | 76 | @property 77 | def device(self): 78 | return next(self.network.parameters()).device 79 | 80 | @property 81 | def image_resolution(self): 82 | return self.network.image_resolution 83 | 84 | def get_loss(self, x1, class_label=None, x0=None): 85 | """ 86 | The conditional flow matching objective, corresponding Eq. 23 in the FM paper. 87 | """ 88 | batch_size = x1.shape[0] 89 | t = self.fm_scheduler.uniform_sample_t(batch_size).to(x1) 90 | if x0 is None: 91 | x0 = torch.randn_like(x1) 92 | 93 | ######## TODO ######## 94 | # DO NOT change the code outside this part. 95 | # Implement the CFM objective. 
 96 |         if class_label is not None:
 97 |             model_out = self.network(x1, t, class_label=class_label)
 98 |         else:
 99 |             model_out = self.network(x1, t)
100 | 
101 |         loss = x1.mean()
102 |         ######################
103 | 
104 |         return loss
105 | 
106 |     def conditional_psi_sample(self, x1, t, x0=None):
107 |         if x0 is None:
108 |             x0 = torch.randn_like(x1)
109 |         return self.fm_scheduler.compute_psi_t(x1, t, x0)
110 | 
111 |     @torch.no_grad()
112 |     def sample(
113 |         self,
114 |         shape,
115 |         num_inference_timesteps=50,
116 |         return_traj=False,
117 |         class_label: Optional[torch.Tensor] = None,
118 |         guidance_scale: Optional[float] = 1.0,
119 |         verbose=False,
120 |     ):
121 |         batch_size = shape[0]
122 |         x_T = torch.randn(shape).to(self.device)
123 |         do_classifier_free_guidance = guidance_scale > 1.0
124 | 
125 |         if do_classifier_free_guidance:
126 |             assert class_label is not None
127 |             assert (
128 |                 len(class_label) == batch_size
129 |             ), f"len(class_label) != batch_size. {len(class_label)} != {batch_size}"
130 | 
131 |         traj = [x_T]
132 | 
133 |         timesteps = [
134 |             i / num_inference_timesteps for i in range(num_inference_timesteps)
135 |         ]
136 |         timesteps = [torch.tensor([t] * x_T.shape[0]).to(x_T) for t in timesteps]
137 |         pbar = tqdm(timesteps) if verbose else timesteps
138 |         xt = x_T
139 |         for i, t in enumerate(pbar):
140 |             t_next = timesteps[i + 1] if i < len(timesteps) - 1 else torch.ones_like(t)
141 | 
142 | 
143 |             ######## TODO ########
144 |             # Complete the sampling loop
145 | 
146 |             xt = self.fm_scheduler.step(xt, torch.zeros_like(xt), torch.zeros_like(t))
147 | 
148 |             ######################
149 | 
150 |             traj[-1] = traj[-1].cpu()
151 |             traj.append(xt.clone().detach())
152 |         if return_traj:
153 |             return traj
154 |         else:
155 |             return traj[-1]
156 | 
157 |     def save(self, file_path):
158 |         hparams = {
159 |             "network": self.network,
160 |             "fm_scheduler": self.fm_scheduler,
161 |         }
162 |         state_dict = self.state_dict()
163 | 
164 |         dic = {"hparams": hparams, "state_dict": state_dict}
165 |         torch.save(dic, file_path)
166 | 
167 |     def load(self, file_path):
168 |         dic = torch.load(file_path, map_location="cpu")
169 |         hparams = dic["hparams"]
170 |         state_dict = dic["state_dict"]
171 | 
172 |         self.network = hparams["network"]
173 |         self.fm_scheduler = hparams["fm_scheduler"]
174 | 
175 |         self.load_state_dict(state_dict)
176 | 
--------------------------------------------------------------------------------
/image_fm_todo/fm.py:
--------------------------------------------------------------------------------
1 | from typing import Optional 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from tqdm import tqdm 8 | 9 | 10 | def expand_t(t, x): 11 | for _ in range(x.ndim - 1): 12 | t = t.unsqueeze(-1) 13 | return t 14 | 15 | 16 | class FMScheduler(nn.Module): 17 | def __init__(self, num_train_timesteps=1000, sigma_min=0.001): 18 | super().__init__() 19 | self.num_train_timesteps = num_train_timesteps 20 | self.sigma_min = sigma_min 21 | 22 | def uniform_sample_t(self, batch_size) -> torch.Tensor: 23 | ts = ( 24 | np.random.choice(np.arange(self.num_train_timesteps), batch_size) 25 | / self.num_train_timesteps 26 | ) 27 | return torch.from_numpy(ts) 28 | 29 | def compute_psi_t(self, x1, t, x): 30 | """ 31 | Compute the conditional flow psi_t(x | x_1). 32 | 33 | Note that time flows in the opposite direction compared to DDPM/DDIM. 34 | As t moves from 0 to 1, the probability paths shift from a prior distribution p_0(x) 35 | to a more complex data distribution p_1(x).
36 | 37 | Input: 38 | x1 (`torch.Tensor`): Data sample from the data distribution. 39 | t (`torch.Tensor`): Timestep in [0,1). 40 | x (`torch.Tensor`): The input to the conditional psi_t(x). 41 | Output: 42 | psi_t (`torch.Tensor`): The conditional flow at t. 43 | """ 44 | t = expand_t(t, x1) 45 | 46 | ######## TODO ######## 47 | # DO NOT change the code outside this part. 48 | # compute psi_t(x) 49 | 50 | psi_t = x1 51 | ###################### 52 | 53 | return psi_t 54 | 55 | def step(self, xt, vt, dt): 56 | """ 57 | The simplest ODE solver, the first-order Euler method: 58 | x_next = xt + dt * vt 59 | """ 60 | 61 | ######## TODO ######## 62 | # DO NOT change the code outside this part. 63 | # implement each step of the first-order Euler method. 64 | x_next = xt 65 | ###################### 66 | 67 | return x_next 68 | 69 | 70 | class FlowMatching(nn.Module): 71 | def __init__(self, network: nn.Module, fm_scheduler: FMScheduler, **kwargs): 72 | super().__init__() 73 | self.network = network 74 | self.fm_scheduler = fm_scheduler 75 | 76 | @property 77 | def device(self): 78 | return next(self.network.parameters()).device 79 | 80 | @property 81 | def image_resolution(self): 82 | return self.network.image_resolution 83 | 84 | def get_loss(self, x1, class_label=None, x0=None): 85 | """ 86 | The conditional flow matching objective, corresponding to Eq. 23 in the FM paper. 87 | """ 88 | batch_size = x1.shape[0] 89 | t = self.fm_scheduler.uniform_sample_t(batch_size).to(x1) 90 | if x0 is None: 91 | x0 = torch.randn_like(x1) 92 | 93 | ######## TODO ######## 94 | # DO NOT change the code outside this part. 95 | # Implement the CFM objective. 96 | if class_label is not None: 97 | model_out = self.network(x1, t, class_label=class_label) 98 | else: 99 | model_out = self.network(x1, t) 100 | 101 | loss = x1.mean() 102 | ###################### 103 | 104 | return loss 105 | 106 | def conditional_psi_sample(self, x1, t, x0=None): 107 | if x0 is None: 108 | x0 = torch.randn_like(x1) 109 | return self.fm_scheduler.compute_psi_t(x1, t, x0) 110 | 111 | @torch.no_grad() 112 | def sample( 113 | self, 114 | shape, 115 | num_inference_timesteps=50, 116 | return_traj=False, 117 | class_label: Optional[torch.Tensor] = None, 118 | guidance_scale: Optional[float] = 1.0, 119 | verbose=False, 120 | ): 121 | batch_size = shape[0] 122 | x_T = torch.randn(shape).to(self.device) 123 | do_classifier_free_guidance = guidance_scale > 1.0 124 | 125 | if do_classifier_free_guidance: 126 | assert class_label is not None 127 | assert ( 128 | len(class_label) == batch_size 129 | ), f"len(class_label) != batch_size. 
{len(class_label)} != {batch_size}" 130 | 131 | traj = [x_T] 132 | 133 | timesteps = [ 134 | i / num_inference_timesteps for i in range(num_inference_timesteps) 135 | ] 136 | timesteps = [torch.tensor([t] * x_T.shape[0]).to(x_T) for t in timesteps] 137 | pbar = tqdm(timesteps) if verbose else timesteps 138 | xt = x_T 139 | for i, t in enumerate(pbar): 140 | t_next = timesteps[i + 1] if i < len(timesteps) - 1 else torch.ones_like(t) 141 | 142 | 143 | ######## TODO ######## 144 | # Complete the sampling loop 145 | 146 | xt = self.fm_scheduler.step(xt, torch.zeros_like(xt), torch.zeros_like(t)) 147 | 148 | ###################### 149 | 150 | traj[-1] = traj[-1].cpu() 151 | traj.append(xt.clone().detach()) 152 | if return_traj: 153 | return traj 154 | else: 155 | return traj[-1] 156 | 157 | def save(self, file_path): 158 | hparams = { 159 | "network": self.network, 160 | "fm_scheduler": self.fm_scheduler, 161 | } 162 | state_dict = self.state_dict() 163 | 164 | dic = {"hparams": hparams, "state_dict": state_dict} 165 | torch.save(dic, file_path) 166 | 167 | def load(self, file_path): 168 | dic = torch.load(file_path, map_location="cpu") 169 | hparams = dic["hparams"] 170 | state_dict = dic["state_dict"] 171 | 172 | self.network = hparams["network"] 173 | self.fm_scheduler = hparams["fm_scheduler"] 174 | 175 | self.load_state_dict(state_dict) 176 | -------------------------------------------------------------------------------- /image_fm_todo/module.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | 8 | 9 | class Swish(nn.Module): 10 | def forward(self, x): 11 | return x * torch.sigmoid(x) 12 | 13 | 14 | class DownSample(nn.Module): 15 | def __init__(self, in_ch): 16 | super().__init__() 17 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1) 18 | self.initialize() 19 | 20 | def initialize(self): 21 | init.xavier_uniform_(self.main.weight) 22 | init.zeros_(self.main.bias) 23 | 24 | def forward(self, x, temb): 25 | x = self.main(x) 26 | return x 27 | 28 | 29 | class UpSample(nn.Module): 30 | def __init__(self, in_ch): 31 | super().__init__() 32 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1) 33 | self.initialize() 34 | 35 | def initialize(self): 36 | init.xavier_uniform_(self.main.weight) 37 | init.zeros_(self.main.bias) 38 | 39 | def forward(self, x, temb): 40 | _, _, H, W = x.shape 41 | x = F.interpolate(x, scale_factor=2, mode="nearest") 42 | x = self.main(x) 43 | return x 44 | 45 | 46 | class AttnBlock(nn.Module): 47 | def __init__(self, in_ch): 48 | super().__init__() 49 | self.group_norm = nn.GroupNorm(32, in_ch) 50 | self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 51 | self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 52 | self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 53 | self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 54 | self.initialize() 55 | 56 | def initialize(self): 57 | for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]: 58 | init.xavier_uniform_(module.weight) 59 | init.zeros_(module.bias) 60 | init.xavier_uniform_(self.proj.weight, gain=1e-5) 61 | 62 | def forward(self, x): 63 | B, C, H, W = x.shape 64 | h = self.group_norm(x) 65 | q = self.proj_q(h) 66 | k = self.proj_k(h) 67 | v = self.proj_v(h) 68 | 69 | q = q.permute(0, 2, 3, 1).view(B, H * W, C) 70 | k = k.view(B, C, H * W) 71 | w = torch.bmm(q, k) * (int(C) ** (-0.5)) 
72 | assert list(w.shape) == [B, H * W, H * W] 73 | w = F.softmax(w, dim=-1) 74 | 75 | v = v.permute(0, 2, 3, 1).view(B, H * W, C) 76 | h = torch.bmm(w, v) 77 | assert list(h.shape) == [B, H * W, C] 78 | h = h.view(B, H, W, C).permute(0, 3, 1, 2) 79 | h = self.proj(h) 80 | 81 | return x + h 82 | 83 | 84 | class ResBlock(nn.Module): 85 | def __init__(self, in_ch, out_ch, tdim, dropout, attn=False): 86 | super().__init__() 87 | self.block1 = nn.Sequential( 88 | nn.GroupNorm(32, in_ch), 89 | Swish(), 90 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), 91 | ) 92 | self.temb_proj = nn.Sequential( 93 | Swish(), 94 | nn.Linear(tdim, out_ch), 95 | ) 96 | self.block2 = nn.Sequential( 97 | nn.GroupNorm(32, out_ch), 98 | Swish(), 99 | nn.Dropout(dropout), 100 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), 101 | ) 102 | if in_ch != out_ch: 103 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) 104 | else: 105 | self.shortcut = nn.Identity() 106 | if attn: 107 | self.attn = AttnBlock(out_ch) 108 | else: 109 | self.attn = nn.Identity() 110 | self.initialize() 111 | 112 | def initialize(self): 113 | for module in self.modules(): 114 | if isinstance(module, (nn.Conv2d, nn.Linear)): 115 | init.xavier_uniform_(module.weight) 116 | init.zeros_(module.bias) 117 | init.xavier_uniform_(self.block2[-1].weight, gain=1e-5) 118 | 119 | def forward(self, x, temb): 120 | h = self.block1(x) 121 | h += self.temb_proj(temb)[:, :, None, None] 122 | h = self.block2(h) 123 | 124 | h = h + self.shortcut(x) 125 | h = self.attn(h) 126 | return h 127 | 128 | 129 | class TimeEmbedding(nn.Module): 130 | def __init__(self, hidden_size, frequency_embedding_size=256): 131 | super().__init__() 132 | self.mlp = nn.Sequential( 133 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 134 | nn.SiLU(), 135 | nn.Linear(hidden_size, hidden_size, bias=True), 136 | ) 137 | self.frequency_embedding_size = frequency_embedding_size 138 | 139 | @staticmethod 140 | def timestep_embedding(t, dim, max_period=10000): 141 | """ 142 | Create sinusoidal timestep embeddings. 143 | :param t: a 1-D Tensor of N indices, one per batch element. 144 | These may be fractional. 145 | :param dim: the dimension of the output. 146 | :param max_period: controls the minimum frequency of the embeddings. 147 | :return: an (N, D) Tensor of positional embeddings. 
148 | """ 149 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 150 | half = dim // 2 151 | freqs = torch.exp( 152 | -math.log(max_period) 153 | * torch.arange(start=0, end=half, dtype=torch.float32) 154 | / half 155 | ).to(device=t.device) 156 | args = t[:, None].float() * freqs[None] 157 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 158 | if dim % 2: 159 | embedding = torch.cat( 160 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 161 | ) 162 | return embedding 163 | 164 | def forward(self, t): 165 | if t.ndim == 0: 166 | t = t.unsqueeze(-1) 167 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 168 | t_emb = self.mlp(t_freq) 169 | return t_emb 170 | -------------------------------------------------------------------------------- /image_fm_todo/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import chain 3 | from multiprocessing import Pool 4 | from pathlib import Path 5 | 6 | import torch 7 | import torchvision.transforms as transforms 8 | from PIL import Image 9 | 10 | 11 | def listdir(dname): 12 | fnames = list( 13 | chain( 14 | *[ 15 | list(Path(dname).rglob("*." + ext)) 16 | for ext in ["png", "jpg", "jpeg", "JPG"] 17 | ] 18 | ) 19 | ) 20 | return fnames 21 | 22 | 23 | def tensor_to_pil_image(x: torch.Tensor, single_image=False): 24 | """ 25 | x: [B,C,H,W] 26 | """ 27 | if x.ndim == 3: 28 | x = x.unsqueeze(0) 29 | single_image = True 30 | 31 | x = (x * 0.5 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy() 32 | images = (x * 255).round().astype("uint8") 33 | images = [Image.fromarray(image) for image in images] 34 | if single_image: 35 | return images[0] 36 | return images 37 | 38 | 39 | def get_data_iterator(iterable): 40 | """Allows training with DataLoaders in a single infinite loop: 41 | for i, data in enumerate(inf_generator(train_loader)): 42 | """ 43 | iterator = iterable.__iter__() 44 | while True: 45 | try: 46 | yield iterator.__next__() 47 | except StopIteration: 48 | iterator = iterable.__iter__() 49 | 50 | 51 | class AFHQDataset(torch.utils.data.Dataset): 52 | def __init__( 53 | self, 54 | root: str, 55 | split: str, 56 | transform=None, 57 | max_num_images_per_cat=-1, 58 | label_offset=1, 59 | ): 60 | super().__init__() 61 | self.root = root 62 | self.split = split 63 | self.transform = transform 64 | self.max_num_images_per_cat = max_num_images_per_cat 65 | self.label_offset = label_offset 66 | 67 | categories = os.listdir(os.path.join(root, split)) 68 | self.num_classes = len(categories) 69 | 70 | fnames, labels = [], [] 71 | for idx, cat in enumerate(sorted(categories)): 72 | category_dir = os.path.join(root, split, cat) 73 | cat_fnames = listdir(category_dir) 74 | cat_fnames = sorted(cat_fnames) 75 | if self.max_num_images_per_cat > 0: 76 | cat_fnames = cat_fnames[: self.max_num_images_per_cat] 77 | fnames += cat_fnames 78 | labels += [idx + label_offset] * len( 79 | cat_fnames 80 | ) # label 0 is for null class. 
81 | 82 | self.fnames = fnames 83 | self.labels = labels 84 | 85 | def __getitem__(self, idx): 86 | img = Image.open(self.fnames[idx]).convert("RGB") 87 | label = self.labels[idx] 88 | assert label >= self.label_offset 89 | if self.transform is not None: 90 | img = self.transform(img) 91 | 92 | return img, label 93 | 94 | def __len__(self): 95 | return len(self.labels) 96 | 97 | 98 | class AFHQDataModule(object): 99 | def __init__( 100 | self, 101 | root: str = "data", 102 | batch_size: int = 32, 103 | num_workers: int = 4, 104 | max_num_images_per_cat: int = 1000, 105 | image_resolution: int = 64, 106 | label_offset=1, 107 | transform=None, 108 | ): 109 | self.root = root 110 | self.batch_size = batch_size 111 | self.num_workers = num_workers 112 | self.afhq_root = os.path.join(root, "afhq") 113 | self.max_num_images_per_cat = max_num_images_per_cat 114 | self.image_resolution = image_resolution 115 | self.label_offset = label_offset 116 | self.transform = transform 117 | 118 | if not os.path.exists(self.afhq_root): 119 | print(f"{self.afhq_root} is empty. Downloading AFHQ dataset...") 120 | self._download_dataset() 121 | 122 | self._set_dataset() 123 | 124 | def _set_dataset(self): 125 | if self.transform is None: 126 | self.transform = transforms.Compose( 127 | [ 128 | transforms.Resize((self.image_resolution, self.image_resolution)), 129 | transforms.ToTensor(), 130 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 131 | ] 132 | ) 133 | self.train_ds = AFHQDataset( 134 | self.afhq_root, 135 | "train", 136 | self.transform, 137 | max_num_images_per_cat=self.max_num_images_per_cat, 138 | label_offset=self.label_offset, 139 | ) 140 | self.val_ds = AFHQDataset( 141 | self.afhq_root, 142 | "val", 143 | self.transform, 144 | max_num_images_per_cat=self.max_num_images_per_cat, 145 | label_offset=self.label_offset, 146 | ) 147 | 148 | self.num_classes = self.train_ds.num_classes 149 | 150 | def _download_dataset(self): 151 | URL = "https://www.dropbox.com/s/t9l9o3vsx2jai3z/afhq.zip?dl=0" 152 | ZIP_FILE = f"./{self.root}/afhq.zip" 153 | os.system(f"mkdir -p {self.root}") 154 | os.system(f"wget -N {URL} -O {ZIP_FILE}") 155 | os.system(f"unzip {ZIP_FILE} -d {self.root}") 156 | os.system(f"rm {ZIP_FILE}") 157 | 158 | def train_dataloader(self): 159 | return torch.utils.data.DataLoader( 160 | self.train_ds, 161 | batch_size=self.batch_size, 162 | num_workers=self.num_workers, 163 | shuffle=True, 164 | drop_last=True, 165 | ) 166 | 167 | def val_dataloader(self): 168 | return torch.utils.data.DataLoader( 169 | self.val_ds, 170 | batch_size=self.batch_size, 171 | num_workers=self.num_workers, 172 | shuffle=False, 173 | drop_last=False, 174 | ) 175 | 176 | 177 | if __name__ == "__main__": 178 | data_module = AFHQDataModule("data", 32, 4, -1, 64, 1) 179 | 180 | eval_dir = Path(data_module.afhq_root) / "eval" 181 | eval_dir.mkdir(exist_ok=True) 182 | 183 | def func(path): 184 | fn = path.name 185 | cmd = f"cp {path} {eval_dir / fn}" 186 | os.system(cmd) 187 | img = Image.open(str(eval_dir / fn)) 188 | img = img.resize((64, 64)) 189 | img.save(str(eval_dir / fn)) 190 | print(fn) 191 | 192 | with Pool(8) as pool: 193 | pool.map(func, data_module.val_ds.fnames) 194 | 195 | print(f"Constructed eval dir at {eval_dir}") 196 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
# Flow Matching

**KAIST CS492(D): Diffusion Models and Their Applications (Fall 2024)**
**Programming Assignment 7**

Instructor: Minhyuk Sung (mhsung [at] kaist.ac.kr)
TA: Juil Koo (63days [at] kaist.ac.kr)

## Abstract
Flow Matching (FM) is a novel generative framework that shares similarities with diffusion models, particularly in how both tackle the Optimal Transport problem through an iterative process. Similar to diffusion models, FM also splits the sampling process into several time-dependent steps. At first glance, FM and diffusion models may seem almost identical due to their shared iterative sampling approach. However, the key differences lie in the objective function and the choice of trajectories in FM.

Regarding the objective function, diffusion models predict the injected noise during training. In contrast, Flow Matching predicts the displacement between the data distribution and the prior distribution.

Moreover, Flow Matching is developed from the perspective of _flow_, a time-dependent transformation function that corresponds to the forward pass in diffusion models. Unlike diffusion models, where the forward pass is fixed to ensure that every intermediate distribution also follows a Gaussian distribution, FM offers much greater flexibility in the choice of _flow_. This flexibility allows for the use of simpler trajectories, such as linear interpolation over time, between the data distribution and the prior distribution.

Experimental results have shown that the FM objective and its simpler trajectory are highly effective in modeling the data distribution, making FM a compelling alternative to diffusion models.

## Setup

Install the required packages listed in `requirements.txt`:
```
pip install -r requirements.txt
```

**Please note that this assignment is heavily dependent on Assignment 2. To begin, you should copy the functions you implemented in [Assignment 2](https://github.com/KAIST-Visual-AI-Group/Diffusion-Assignment2-DDIM-CFG).**

## Code Structure
```
.
├── 2d_plot_fm_todo (Task 1)
│   ├── fm_tutorial.ipynb <--- Main code
│   ├── dataset.py        <--- Define dataset (Swiss-roll, moon, gaussians, etc.)
│   ├── network.py        <--- A vector field prediction network
│   └── fm.py             <--- (TODO) Implement Flow Matching
│
└── image_fm_todo (Task 2)
    ├── dataset.py        <--- Ready-to-use AFHQ dataset code
    ├── fm.py             <--- (TODO) Implement Flow Matching
    ├── module.py         <--- Basic modules of a vector field prediction network
    ├── network.py        <--- U-Net
    ├── sampling.py       <--- Image sampling code
    └── train.py          <--- Flow Matching training code
```


## Task 0: Introduction
### Assignment Tips

In this assignment, we explore Flow Matching (FM), a method that, like diffusion models, deals with time-dependent trajectories but does so from the perspective of vector fields that construct probability density paths. At a high level, FM shares similarities with diffusion models but offers straighter trajectories. Unlike diffusion models, whose trajectories involve nonlinear terms that add unnecessary complexity, the trajectories in FM can be much simpler: they can be represented as a simple linear interpolation between random noise $x_0$ and a data point $x_1$ over time $t$. As before, we recommend thoroughly understanding the equations presented in the papers before starting the assignment. The following is a list of recommended resources:

1. [[Paper](https://arxiv.org/abs/2210.02747)] Flow Matching for Generative Modeling (FM)
2. [[Paper](https://arxiv.org/abs/2209.03003)] Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow (RF)
3. [[Blog](https://mlg.eng.cam.ac.uk/blog/2024/01/20/flow-matching.html)] An Introduction to Flow Matching

Further material is listed at the bottom of the document.

### Modeling the data distribution with flow
In FM, we have a probability density path $p: [0,1] \times \mathbb{R}^d \rightarrow \mathbb{R}\_{>0}$, which is a time-dependent probability density function, i.e., $\int p\_t(x) dx = 1$. Starting from the simple tractable distribution $p\_0(x) = p\_{\text{prior}}$, we aim to transform it into a more complex distribution $p\_1(x) = p\_{\text{data}}$ using a time-dependent diffeomorphic map, called _flow_: $\psi: [0,1] \times \mathbb{R}^d \rightarrow \mathbb{R}^d$. Rather than directly modeling the flow itself, as we've done in diffusion models, we instead model the derivative of the flow with respect to $t$, called the _vector field_: $\frac{d}{dt}\psi\_t(x) = v\_t(\psi\_t(x))$.
Given a target probability density path $p\_t(x)$ and a corresponding vector field $u\_t(x)$, we define the Flow Matching objective as:

$$
\begin{align*}
\mathcal{L}\_{\text{FM}} (\theta) = \mathbb{E}\_{t, p\_t(x)} \Vert v\_t(x;\theta) - u\_t(x) \Vert^2
\end{align*}.
$$

However, since $p\_t$ and $u\_t$ are intractable, we make the training simpler by modeling the conditional probability paths and vector fields:

$$
\begin{align*}
p\_t(x) &= \int p\_t(x | x_1) q(x_1) dx_1 \\
u\_t(x) &= \int u\_t(x | x_1) \frac{p\_t(x | x\_1) q(x\_1)} {p_t(x)} dx_1
\end{align*}
$$

The equations above can be derived from the _continuity equation_: $\partial p_t(x) / \partial t = -\nabla \cdot (p_t(x) v_t(x)).$
See Appendix A of the FM paper for more details.

Given the conditional probability paths $p_t(x|x_1)$ and vector fields $u_t(x | x_1)$, we define the conditional flow matching objective as:

$$
\begin{align*}
\mathcal{L}\_{\text{CFM}} (\theta) = \mathbb{E}\_{t, p\_1(x\_1), p\_t(x | x\_1)} \Vert v\_t(x;\theta) - u\_t(x | x_1) \Vert^2
\end{align*}
$$

Crucially, $\mathcal{L}\_{\text{CFM}}$ and $\mathcal{L}\_{\text{FM}}$ have identical gradients with respect to $\theta$ (Theorem 2 in the FM paper), so minimizing the tractable conditional objective also minimizes the original one.

### Special instances of Flow Matching
While the conditional probability paths and vector fields can be designed in various ways, we opt for the simplest choice, conditional probability paths of Gaussian form:

$$
\begin{align*}
p\_t(x | x_1) &= \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I) \\
\psi_t(x) &= \sigma_t(x_1)x + \mu_t(x_1)
\end{align*}.
$$

Specifically, we set $\mu_t(x_1) = tx_1$ and $\sigma_t(x_1) = 1 - (1 - \sigma\_{\text{min}})t$. Given $\mu_t(x_1)$ and $\sigma_t(x_1)$, the conditional flow is defined as follows: $\psi_t(x) = (1 - (1 - \sigma\_{\text{min}} )t)x + tx_1$.

In this case, the CFM loss takes the following form:

$$
\begin{align*}
\mathcal{L}\_{\text{CFM}} (\theta) = \mathbb{E}\_{t, p\_1(x\_1), p\_t(x | x\_1)} \Vert v\_t(\psi_t(x\_0);\theta) - (x_1 - (1 - \sigma\_{\text{min}}) x_0) \Vert^2
\end{align*}
$$
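To make the quantities above concrete, here is a minimal, self-contained sketch (an editor's illustration, not part of the assignment code; `psi_t` and `target_vector_field` are hypothetical names — the repo's `FMScheduler` exposes the same math through `compute_psi_t()` and `step()`):

```python
import torch

SIGMA_MIN = 0.001  # matches the default sigma_min used in this repo

def psi_t(x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Conditional flow: psi_t(x0 | x1) = (1 - (1 - sigma_min) * t) * x0 + t * x1
    return (1.0 - (1.0 - SIGMA_MIN) * t) * x0 + t * x1

def target_vector_field(x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
    # d/dt psi_t(x0 | x1) = x1 - (1 - sigma_min) * x0, independent of t
    return x1 - (1.0 - SIGMA_MIN) * x0

x1 = torch.randn(4, 2)         # a toy "data" batch
x0 = torch.randn(4, 2)         # a prior (noise) batch
t = torch.full((4, 1), 0.3)

xt = psi_t(x0, x1, t)
vt = target_vector_field(x0, x1)

# Sanity checks: the flow starts at x0 and (almost) reaches x1.
assert torch.allclose(psi_t(x0, x1, torch.zeros_like(t)), x0)
assert torch.allclose(psi_t(x0, x1, torch.ones_like(t)), SIGMA_MIN * x0 + x1)

# One first-order Euler step of size dt along the (here, known) vector field:
dt = 1.0 / 50
x_next = xt + dt * vt
```

Note how the target velocity does not depend on $t$: with Gaussian paths of this form, each conditional trajectory is a straight line.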
## Task 1: FM with Swiss-Roll
As in the previous Assignments 1 and 2, we will first implement Flow Matching (FM) and test it in a simple 2D plot toy experiment setup.

❗️❗️❗️ **You are only allowed to edit the part marked by TODO.** ❗️❗️❗️

### TODO
In this assignment, you will implement all key functions of flow matching for training and sampling.

#### 1-0: Copy the previous completed implementations
You can copy the `2d_plot_ddpm_todo/network.py` and `image_ddpm_todo/network.py` that you've already implemented in Assignments 1 and 2.

#### 1-1: Implement the Flow Matching scheduler
Complete the functions `compute_psi_t()` and `step()` of the `FMScheduler` class in `fm.py`.
The `step()` function performs one step of an ODE solver integrating from $t=0$ to $t=1$. Although more sophisticated numerical methods could be used, you just need to implement the simplest one, the first-order Euler method:

$$
\begin{align*}
x_{t+\Delta t} = x_t + \Delta t \frac{\partial x_t}{\partial t},
\end{align*}
$$

where $\frac{\partial x_t}{\partial t}$ is modeled by a neural network.


#### 1-2: Implement the conditional flow matching objective
Complete the `get_loss()` function of `FlowMatching` in `fm.py`, which corresponds to the conditional flow matching objective written in Eq. 23 of the FM paper.


#### 1-3: Implement the sampling code
Complete the `sample()` function of `FlowMatching` in `fm.py`.

#### 1-4: Training and Evaluation
Once you finish the implementation above, open and run `fm_tutorial.ipynb` via Jupyter Notebook. It will automatically train an FM model and measure the Chamfer distance between 2D particles sampled by the FM model and 2D particles sampled from the target distribution.

Take screenshots of:

1. the training loss curve
2. the Chamfer Distance reported after executing the Jupyter Notebook
3. the visualization of the sampled particles.

Below are examples of (1) and (3).

![Training loss curve](assets/fm_loss.png)

![Sampled particle visualization](assets/fm_2d_vis.png)

Note that you need to run sampling with 50 inference steps, which is set as the default.

## Task 2: Image Generation with FM

### TODO
If you've completed Task 1, finish implementing the `sample()` function and `get_loss()` function to work with a classifier-free guidance setup.
After finishing the implementation for the CFG setup, train FM with the CFG setup: `python train.py --use_cfg`.

❗️❗️❗️ **You are only allowed to edit the part marked by TODO.** ❗️❗️❗️

It will sample images and save a checkpoint every `args.log_interval`. After training a model, sample and save images by

```
python sampling.py --use_cfg --ckpt_path ${CKPT_PATH} --save_dir ${SAVE_DIR_PATH}
```

We recommend starting the training as soon as possible, as it takes around 14 hours.

As done in Assignments 1 and 2, measure the FID score using the pre-trained classifier network provided previously:

```
python dataset.py
python fid/measure_fid.py $GT_IMG_DIR $GEN_IMG_DIR
```
Use the evaluation set of the AFHQ dataset, `data/afhq/eval`, not `data/afhq/val`, as `$GT_IMG_DIR`.

Take a screenshot of the FID score and include at least 8 sampled images.
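As a reference for the CFG setup, here is a minimal sketch of Euler sampling with classifier-free guidance (again an editor's illustration, not the provided code; it assumes a network with the same `(x, t, class_label)` interface as this repo's `UNet`, with label `0` reserved for the null class):

```python
import torch

@torch.no_grad()
def cfg_euler_sample(net, shape, class_label, guidance_scale=7.5, steps=20):
    # Integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (data) with Euler steps.
    xt = torch.randn(shape)
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((shape[0],), i / steps)
        v_cond = net(xt, t, class_label=class_label)                      # conditional
        v_uncond = net(xt, t, class_label=torch.zeros_like(class_label))  # null class
        v = v_uncond + guidance_scale * (v_cond - v_uncond)               # CFG mix
        xt = xt + dt * v
    return xt
```

The provided `sample()` additionally records the trajectory and exposes `num_inference_timesteps`; the guided velocity computation inside the loop is the only new ingredient for Task 2.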
## What to Submit

**Submission Item List**

- [ ] Code without model checkpoints

**Task 1**
- [ ] Loss curve screenshot
- [ ] Chamfer distance results of FM sampling with 50 inference steps
- [ ] Visualization of FM sampling

**Task 2**
- [ ] FID score result obtained with the CFG scale of 7.5
- [ ] At least 8 images generated by Flow Matching

In a single PDF file, write your name and student ID, and include the submission items listed above. Refer to the more detailed instructions written in each task section about what to submit.
Name the document `{NAME}_{STUDENT_ID}.pdf` and submit **both your code and the document** as a **ZIP** file named `{NAME}_{STUDENT_ID}.zip`.
**For this programming assignment**, exclude any model checkpoints and the provided pre-trained classifier checkpoint when compressing the files.
Submit the zip file on GradeScope.

## Grading
**You will receive a zero score if:**
- **you do not submit,**
- **your code is not executable in the Python environment we provided, or**
- **you modify any code outside of the section marked with `TODO` or use different hyperparameters that are supposed to be fixed as given.**

**Plagiarism in any form will also result in a zero score and will be reported to the university.**

**Your score will incur a 10% deduction for each missing item in the submission item list.**

Otherwise, you will receive up to 20 points from this assignment that count toward your final grade.

- Task 1
  - 10 points: Achieve CD lower than 40.
  - 5 points: Achieve CD greater than or equal to 40 and less than 60.
  - 0 points: otherwise.
- Task 2
  - 10 points: Achieve FID lower than **30** with CFG=7.5.
  - 5 points: Achieve FID between **30** and **50** with CFG=7.5.
  - 0 points: otherwise.

## Further Readings

If you are interested in this topic, we encourage you to check out the materials below.

- [Flow Matching for Generative Modeling](https://arxiv.org/abs/2210.02747)
- [Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow](https://arxiv.org/abs/2209.03003)
- [An Introduction to Flow Matching](https://mlg.eng.cam.ac.uk/blog/2024/01/20/flow-matching.html)
- [Neural Ordinary Differential Equations](https://arxiv.org/abs/1806.07366)
--------------------------------------------------------------------------------