├── 2d_plot_fm_todo
│   ├── __init__.py
│   ├── chamferdist.py
│   ├── dataset.py
│   ├── network.py
│   └── fm.py
├── .gitignore
├── assets
│   ├── fm_loss.png
│   ├── fm_2d_vis.png
│   └── trajectory_visualization.png
├── requirements.txt
├── LICENSE
├── image_fm_todo
│   ├── sampling.py
│   ├── network.py
│   ├── train.py
│   ├── fm.py
│   ├── module.py
│   └── dataset.py
└── README.md

/2d_plot_fm_todo/__init__.py:
--------------------------------------------------------------------------------
1 | 

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/results/
2 | **/data/
3 | **/__pycache__/
4 | **/.ipynb_checkpoints/

--------------------------------------------------------------------------------
/assets/fm_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment7-Flow/HEAD/assets/fm_loss.png

--------------------------------------------------------------------------------
/assets/fm_2d_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment7-Flow/HEAD/assets/fm_2d_vis.png

--------------------------------------------------------------------------------
/assets/trajectory_visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment7-Flow/HEAD/assets/trajectory_visualization.png

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit-learn==1.1.3
2 | ipython==8.12.0
3 | jupyter==1.0.0
4 | matplotlib==3.6.0
5 | torch==2.0.1
6 | torch-ema==0.3
7 | torchvision==0.15.2
8 | tqdm==4.64.1
9 | jupyterlab==4.0.2
10 | jaxtyping==0.2.20
11 | pytorch_lightning
12 | dotmap
13 | numpy==1.26.0

--------------------------------------------------------------------------------
/2d_plot_fm_todo/chamferdist.py:
--------------------------------------------------------------------------------
1 | from scipy.spatial.distance import cdist
2 | 
3 | def chamfer_distance(S1, S2) -> float:
4 |     """
5 | 
    Computes the Chamfer distance between two point clouds, defined as:
6 |         d_CD(S1, S2) = \sum_{x \in S1} \min_{y \in S2} ||x - y||^2 + \sum_{y \in S2} \min_{x \in S1} ||x - y||^2
7 |     """
8 |     dist = cdist(S1, S2)
9 |     dist1 = dist.min(axis=1) ** 2
10 |     dist2 = dist.min(axis=0) ** 2
11 |     return dist1.sum() + dist2.sum()
12 | 

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 KAIST Visual AI Group
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 

--------------------------------------------------------------------------------
/image_fm_todo/sampling.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import numpy as np
4 | import torch
5 | from dataset import tensor_to_pil_image
6 | from fm import FlowMatching
7 | from pathlib import Path
8 | 
9 | 
10 | def main(args):
11 |     save_dir = Path(args.save_dir)
12 |     save_dir.mkdir(exist_ok=True, parents=True)
13 | 
14 |     device = f"cuda:{args.gpu}"
15 | 
16 |     fm = FlowMatching(None, None)
17 |     fm.load(args.ckpt_path)
18 |     fm.eval()
19 |     fm = fm.to(device)
20 | 
21 |     total_num_samples = 500
22 |     num_batches = int(np.ceil(total_num_samples / args.batch_size))
23 | 
24 |     for i in range(num_batches):
25 |         sidx = i * args.batch_size
26 |         eidx = min(sidx + args.batch_size, total_num_samples)
27 |         B = eidx - sidx
28 | 
29 |         if args.use_cfg:  # Enable CFG sampling
30 |             assert fm.network.use_cfg, "The model was not trained to support CFG."
31 |             shape = (B, 3, fm.image_resolution, fm.image_resolution)
32 |             samples = fm.sample(
33 |                 shape,
34 |                 num_inference_timesteps=20,
35 |                 class_label=torch.randint(1, 4, (B,)).to(device),
36 |                 guidance_scale=args.cfg_scale,
37 |             )
38 |         else:
39 |             raise NotImplementedError("In Assignment 7, we sample images with CFG only.")
40 | 
41 |         pil_images = tensor_to_pil_image(samples)
42 | 
43 |         for j, img in zip(range(sidx, eidx), pil_images):
44 |             img.save(save_dir / f"{j}.png")
45 |             print(f"Saved the {j}-th image.")
46 | 
47 | 
48 | if __name__ == "__main__":
49 |     parser = argparse.ArgumentParser()
50 |     parser.add_argument("--batch_size", type=int, default=128)
51 |     parser.add_argument("--gpu", type=int, default=0)
52 |     parser.add_argument("--ckpt_path", type=str)
53 |     parser.add_argument("--save_dir", type=str)
54 |     parser.add_argument("--use_cfg", action="store_true")
55 |     parser.add_argument("--cfg_scale", type=float, default=7.5)
56 | 
57 |     args = parser.parse_args()
58 |     main(args)
59 | 

--------------------------------------------------------------------------------
/2d_plot_fm_todo/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn import datasets
4 | from torch.utils.data import DataLoader, Dataset
5 | 
6 | 
7 | def normalize(ds, scaling_factor=2.0):
8 |     return (ds - ds.mean()) / ds.std() * scaling_factor
9 | 
10 | 
11 | def sample_checkerboard(n):
12 |     # https://github.com/ghliu/SB-FBSDE/blob/main/data.py
13 |     n_points = 3 * n
14 |     n_classes = 2
15 |     freq = 5
16 |     x = np.random.uniform(
17 |         -(freq // 2) * np.pi, (freq // 2) * np.pi, size=(n_points, n_classes)
18 |     )
19 |     mask = np.logical_or(
20 |         np.logical_and(np.sin(x[:, 0]) > 0.0, np.sin(x[:, 1]) > 0.0),
21 |         np.logical_and(np.sin(x[:, 0]) < 0.0, np.sin(x[:, 1]) < 0.0),
22 |     )
23 |     y = np.eye(n_classes)[1 * mask]
24 |     x0 = x[:, 0] * y[:, 0]
25 |     x1 = x[:, 1] * y[:, 0]
26 |     sample = np.concatenate([x0[..., None], x1[..., None]], axis=-1)
27 |     sqr = np.sum(np.square(sample), axis=-1)
28 |     idxs = np.where(sqr == 0)
29 |     sample = np.delete(sample, idxs, axis=0)
30 | 
31 |     return sample
32 | 
33 | 
34 | def load_twodim(num_samples: int, dataset: str, dimension: int = 2):
35 | 
36 |     if dataset == "gaussian_centered":
37 |         sample = np.random.normal(size=(num_samples, dimension))
38 | 
39 |     if dataset == "gaussian_shift":
40 |         sample = np.random.normal(size=(num_samples, dimension))
41 |         sample = sample + 1.5
42 | 
43 |     if dataset == "circle":
44 |         X, y = datasets.make_circles(
45 |             n_samples=num_samples, noise=0.0, random_state=None, factor=0.5
46 |         )
47 |         sample = X * 4
48 | 
49 |     if dataset == "scurve":
50 |         X, y = datasets.make_s_curve(
51 |             n_samples=num_samples, noise=0.0, random_state=None
52 |         )
53 |         sample = normalize(X[:, [0, 2]])
54 | 
55 |     if dataset == "moon":
56 |         X, y = datasets.make_moons(n_samples=num_samples, noise=0.0, random_state=None)
57 |         sample = normalize(X)
58 | 
59 |     if dataset == "swiss_roll":
60 |         X, y = datasets.make_swiss_roll(
61 |             n_samples=num_samples, noise=0.0, random_state=None, hole=True
62 |         )
63 |         sample = normalize(X[:, [0, 2]])
64 | 
65 |     if dataset == "checkerboard":
66 |         sample = normalize(sample_checkerboard(num_samples))
67 | 
68 |     return torch.tensor(sample).float()
69 | 
70 | 
71 | class TwoDimDataClass(Dataset):
72 |     def __init__(self, dataset_type: str, N: int, batch_size: int, dimension=2):
73 | 
74 |         self.X = load_twodim(N, dataset_type, dimension=dimension)
75 |         self.name = dataset_type
76 |         self.batch_size = batch_size
77 |         self.dimension = dimension
78 | 
79 |     def __len__(self):
80 |         return self.X.shape[0]
81 | 
82 |     def __getitem__(self, idx):
83 |         return self.X[idx]
84 | 
85 |     def get_dataloader(self, shuffle=True):
86 |         return DataLoader(
87 |             self,
88 |             batch_size=self.batch_size,
89 |             shuffle=shuffle,
90 |             pin_memory=True,
91 |         )
92 | 
93 | 
94 | def get_data_iterator(iterable):
95 |     iterator = iterable.__iter__()
96 |     while True:
97 |         try:
98 |             yield iterator.__next__()
99 |         except StopIteration:
100 |             iterator = iterable.__iter__()
101 | 

--------------------------------------------------------------------------------
/2d_plot_fm_todo/network.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | 
8 | 
9 | class TimeEmbedding(nn.Module):
10 |     def __init__(self, hidden_size, frequency_embedding_size=256):
11 |         super().__init__()
12 |         self.mlp = nn.Sequential(
13 |             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
14 |             nn.SiLU(),
15 |             nn.Linear(hidden_size, hidden_size, bias=True),
16 |         )
17 |         self.frequency_embedding_size = frequency_embedding_size
18 | 
19 |     @staticmethod
20 |     def timestep_embedding(t, dim, max_period=10000):
21 |         """
22 |         Create sinusoidal timestep embeddings.
23 |         :param t: a 1-D Tensor of N indices, one per batch element.
24 |                   These may be fractional.
25 |         :param dim: the dimension of the output.
26 |         :param max_period: controls the minimum frequency of the embeddings.
27 |         :return: an (N, D) Tensor of positional embeddings.
28 |         """
29 |         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
30 |         half = dim // 2
31 |         freqs = torch.exp(
32 |             -math.log(max_period)
33 |             * torch.arange(start=0, end=half, dtype=torch.float32)
34 |             / half
35 |         ).to(device=t.device)
36 |         args = t[:, None].float() * freqs[None]
37 |         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
38 |         if dim % 2:
39 |             embedding = torch.cat(
40 |                 [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
41 |             )
42 |         return embedding
43 | 
44 |     def forward(self, t: torch.Tensor):
45 |         if t.ndim == 0:
46 |             t = t.unsqueeze(-1)
47 |         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
48 |         t_emb = self.mlp(t_freq)
49 |         return t_emb
50 | 
51 | 
52 | class TimeLinear(nn.Module):
53 |     def __init__(self, dim_in: int, dim_out: int, num_timesteps: int):
54 |         super().__init__()
55 |         self.dim_in = dim_in
56 |         self.dim_out = dim_out
57 |         self.num_timesteps = num_timesteps
58 | 
59 |         self.time_embedding = TimeEmbedding(dim_out)
60 |         self.fc = nn.Linear(dim_in, dim_out)
61 | 
62 |     def forward(self, x: torch.Tensor, t: torch.Tensor):
63 |         x = self.fc(x)
64 |         alpha = self.time_embedding(t).view(-1, self.dim_out)
65 | 
66 |         return alpha * x
67 | 
68 | 
69 | class SimpleNet(nn.Module):
70 |     def __init__(
71 |         self, dim_in: int, dim_out: int, dim_hids: List[int], num_timesteps: int
72 |     ):
73 |         super().__init__()
74 |         """
75 |         (TODO) Build a network that estimates the vector field of the flow.
76 | 
77 |         Args:
78 |             dim_in: dimension of input
79 |             dim_out: dimension of output
80 |             dim_hids: dimensions of hidden features
81 |             num_timesteps: number of timesteps
82 |         """
83 | 
84 |         ######## TODO ########
85 |         # DO NOT change the code outside this part.
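        # A minimal sketch of one possible implementation (an editor's
        # assumption, not the reference solution): an MLP whose linear layers
        # are all time-modulated via TimeLinear, e.g.
        #
        #   dims = [dim_in] + list(dim_hids) + [dim_out]
        #   layers = []
        #   for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        #       layers.append(TimeLinear(d_in, d_out, num_timesteps))
        #       if i < len(dims) - 2:
        #           layers.append(nn.ReLU())
        #   self.layers = nn.ModuleList(layers)
        #
        # with a matching forward pass that applies t only to TimeLinear layers:
        #
        #   for layer in self.layers:
        #       x = layer(x, t) if isinstance(layer, TimeLinear) else layer(x)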
This should output 92 | the noise prediction of the noisy input x at timestep t. 93 | 94 | Args: 95 | x: the noisy data after t period diffusion 96 | t: the time that the forward diffusion has been running 97 | """ 98 | ######## TODO ######## 99 | # DO NOT change the code outside this part. 100 | 101 | ###################### 102 | return x 103 | -------------------------------------------------------------------------------- /image_fm_todo/network.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from module import DownSample, ResBlock, Swish, TimeEmbedding, UpSample 8 | from torch.nn import init 9 | 10 | 11 | class UNet(nn.Module): 12 | def __init__(self, T=1000, image_resolution=64, ch=128, ch_mult=[1,2,2,2], attn=[1], num_res_blocks=4, dropout=0.1, use_cfg=False, cfg_dropout=0.1, num_classes=None): 13 | super().__init__() 14 | self.image_resolution = image_resolution 15 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' 16 | tdim = ch * 4 17 | # self.time_embedding = TimeEmbedding(T, ch, tdim) 18 | self.time_embedding = TimeEmbedding(tdim) 19 | 20 | # classifier-free guidance 21 | self.use_cfg = use_cfg 22 | self.cfg_dropout = cfg_dropout 23 | if use_cfg: 24 | assert num_classes is not None 25 | cdim = tdim 26 | self.class_embedding = nn.Embedding(num_classes+1, cdim) 27 | 28 | self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1) 29 | self.downblocks = nn.ModuleList() 30 | chs = [ch] # record output channel when dowmsample for upsample 31 | now_ch = ch 32 | for i, mult in enumerate(ch_mult): 33 | out_ch = ch * mult 34 | for _ in range(num_res_blocks): 35 | self.downblocks.append(ResBlock( 36 | in_ch=now_ch, out_ch=out_ch, tdim=tdim, 37 | dropout=dropout, attn=(i in attn))) 38 | now_ch = out_ch 39 | chs.append(now_ch) 40 | if i != len(ch_mult) - 1: 41 | self.downblocks.append(DownSample(now_ch)) 42 | chs.append(now_ch) 43 | 44 | self.middleblocks = nn.ModuleList([ 45 | ResBlock(now_ch, now_ch, tdim, dropout, attn=True), 46 | ResBlock(now_ch, now_ch, tdim, dropout, attn=False), 47 | ]) 48 | 49 | self.upblocks = nn.ModuleList() 50 | for i, mult in reversed(list(enumerate(ch_mult))): 51 | out_ch = ch * mult 52 | for _ in range(num_res_blocks + 1): 53 | self.upblocks.append(ResBlock( 54 | in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, 55 | dropout=dropout, attn=(i in attn))) 56 | now_ch = out_ch 57 | if i != 0: 58 | self.upblocks.append(UpSample(now_ch)) 59 | assert len(chs) == 0 60 | 61 | self.tail = nn.Sequential( 62 | nn.GroupNorm(32, now_ch), 63 | Swish(), 64 | nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) 65 | ) 66 | self.initialize() 67 | 68 | def initialize(self): 69 | init.xavier_uniform_(self.head.weight) 70 | init.zeros_(self.head.bias) 71 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) 72 | init.zeros_(self.tail[-1].bias) 73 | 74 | def forward(self, x, timestep, class_label=None): 75 | # Timestep embedding 76 | temb = self.time_embedding(timestep) 77 | 78 | if self.use_cfg and class_label is not None: 79 | if self.training: 80 | assert not torch.any(class_label == 0) # 0 for null. 81 | 82 | ######## TODO ######## 83 | # DO NOT change the code outside this part. 84 | # Assignment 2. Implement random null conditioning in CFG training. 
85 |                 raise NotImplementedError("TODO")
86 |                 #######################
87 | 
88 |             ######## TODO ########
89 |             # DO NOT change the code outside this part.
90 |             # Assignment 2. Implement class conditioning
91 |             raise NotImplementedError("TODO")
92 |             #######################
93 | 
94 |         # Downsampling
95 |         h = self.head(x)
96 |         hs = [h]
97 |         for layer in self.downblocks:
98 |             h = layer(h, temb)
99 |             hs.append(h)
100 |         # Middle
101 |         for layer in self.middleblocks:
102 |             h = layer(h, temb)
103 |         # Upsampling
104 |         for layer in self.upblocks:
105 |             if isinstance(layer, ResBlock):
106 |                 h = torch.cat([h, hs.pop()], dim=1)
107 |             h = layer(h, temb)
108 |         h = self.tail(h)
109 | 
110 |         assert len(hs) == 0
111 |         return h
112 | 

--------------------------------------------------------------------------------
/image_fm_todo/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from datetime import datetime
4 | from pathlib import Path
5 | 
6 | import matplotlib
7 | import matplotlib.pyplot as plt
8 | import torch
9 | from dataset import AFHQDataModule, get_data_iterator, tensor_to_pil_image
10 | from dotmap import DotMap
11 | from fm import FlowMatching, FMScheduler
12 | from network import UNet
13 | from pytorch_lightning import seed_everything
14 | from torchvision.transforms.functional import to_pil_image
15 | from tqdm import tqdm
16 | 
17 | matplotlib.use("Agg")
18 | 
19 | 
20 | def get_current_time():
21 |     now = datetime.now().strftime("%m-%d-%H%M%S")
22 |     return now
23 | 
24 | 
25 | def main(args):
26 |     """config"""
27 |     config = DotMap()
28 |     config.update(vars(args))
29 |     config.device = f"cuda:{args.gpu}"
30 | 
31 |     now = get_current_time()
32 |     assert args.use_cfg, "In Assignment 7, we sample images with CFG only."
33 | 
34 |     if args.use_cfg:
35 |         save_dir = Path(f"results/cfg_fm-{now}")
36 |     else:
37 |         save_dir = Path(f"results/fm-{now}")
38 |     save_dir.mkdir(exist_ok=True, parents=True)
39 |     print(f"save_dir: {save_dir}")
40 | 
41 |     seed_everything(config.seed)
42 | 
43 |     with open(save_dir / "config.json", "w") as f:
44 |         json.dump(config, f, indent=2)
45 |     """######"""
46 | 
47 |     image_resolution = config.image_resolution
48 |     ds_module = AFHQDataModule(
49 |         "./data",
50 |         batch_size=config.batch_size,
51 |         num_workers=4,
52 |         max_num_images_per_cat=config.max_num_images_per_cat,
53 |         image_resolution=image_resolution
54 |     )
55 | 
56 |     train_dl = ds_module.train_dataloader()
57 |     train_it = get_data_iterator(train_dl)
58 | 
59 |     # Set up the scheduler
60 |     fm_scheduler = FMScheduler(sigma_min=args.sigma_min)
61 | 
62 |     network = UNet(
63 |         image_resolution=image_resolution,
64 |         ch=128,
65 |         ch_mult=[1, 2, 2, 2],
66 |         attn=[1],
67 |         num_res_blocks=4,
68 |         dropout=0.1,
69 |         use_cfg=args.use_cfg,
70 |         cfg_dropout=args.cfg_dropout,
71 |         num_classes=getattr(ds_module, "num_classes", None),
72 |     )
73 | 
74 |     fm = FlowMatching(network, fm_scheduler)
75 |     fm = fm.to(config.device)
76 | 
77 |     optimizer = torch.optim.Adam(fm.network.parameters(), lr=2e-4)
78 |     scheduler = torch.optim.lr_scheduler.LambdaLR(
79 |         optimizer, lr_lambda=lambda t: min((t + 1) / config.warmup_steps, 1.0)
80 |     )
81 | 
82 |     step = 0
83 |     losses = []
84 |     with tqdm(initial=step, total=config.train_num_steps) as pbar:
85 |         while step < config.train_num_steps:
86 |             if step % config.log_interval == 0:
87 |                 fm.eval()
88 |                 plt.plot(losses)
89 |                 plt.savefig(f"{save_dir}/loss.png")
90 |                 plt.close()
91 |                 shape = (4, 3, fm.image_resolution, fm.image_resolution)
92 |                 if args.use_cfg:
93 |                     class_label = torch.tensor([1,1,2,3]).to(config.device)
94 |                     samples = fm.sample(shape, class_label=class_label, guidance_scale=7.5, verbose=False)
95 |                 else:
96 |                     samples = fm.sample(shape, return_traj=False, verbose=False)
97 |                 pil_images = tensor_to_pil_image(samples)
98 |                 for i, img in enumerate(pil_images):
99 |                     img.save(save_dir / f"step={step}-{i}.png")
100 | 
101 |                 fm.save(f"{save_dir}/last.ckpt")
102 |                 fm.train()
103 | 
104 |             img, label = next(train_it)
105 |             img, label = img.to(config.device), label.to(config.device)
106 |             if args.use_cfg:  # Conditional, CFG training
107 |                 loss = fm.get_loss(img, class_label=label)
108 |             else:  # Unconditional training
109 |                 loss = fm.get_loss(img)
110 |             pbar.set_description(f"Loss: {loss.item():.4f}")
111 | 
112 |             optimizer.zero_grad()
113 |             loss.backward()
114 |             optimizer.step()
115 |             scheduler.step()
116 |             losses.append(loss.item())
117 | 
118 |             step += 1
119 |             pbar.update(1)
120 | 
121 | 
122 | if __name__ == "__main__":
123 |     parser = argparse.ArgumentParser()
124 |     parser.add_argument("--gpu", type=int, default=0)
125 |     parser.add_argument("--batch_size", type=int, default=16)
126 |     parser.add_argument(
127 |         "--train_num_steps",
128 |         type=int,
129 |         default=100000,
130 |         help="the number of model training steps.",
131 |     )
132 |     parser.add_argument("--warmup_steps", type=int, default=200)
133 |     parser.add_argument("--log_interval", type=int, default=200)
134 |     parser.add_argument(
135 |         "--max_num_images_per_cat",
136 |         type=int,
137 |         default=3000,
138 |         help="max number of images per category for AFHQ dataset",
139 |     )
140 |     parser.add_argument("--sigma_min", type=float, default=0.001)
141 |     parser.add_argument("--seed", type=int, default=63)
142 |     parser.add_argument("--image_resolution", type=int, default=64)
143 | 
parser.add_argument("--use_cfg", action="store_true")
144 |     parser.add_argument("--cfg_dropout", type=float, default=0.1)
145 |     args = parser.parse_args()
146 |     main(args)
147 | 

--------------------------------------------------------------------------------
/2d_plot_fm_todo/fm.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from tqdm import tqdm
8 | 
9 | 
10 | def expand_t(t, x):
11 |     for _ in range(x.ndim - 1):
12 |         t = t.unsqueeze(-1)
13 |     return t
14 | 
15 | 
16 | class FMScheduler(nn.Module):
17 |     def __init__(self, num_train_timesteps=1000, sigma_min=0.001):
18 |         super().__init__()
19 |         self.num_train_timesteps = num_train_timesteps
20 |         self.sigma_min = sigma_min
21 | 
22 |     def uniform_sample_t(self, batch_size) -> torch.Tensor:
23 |         ts = (
24 |             np.random.choice(np.arange(self.num_train_timesteps), batch_size)
25 |             / self.num_train_timesteps
26 |         )
27 |         return torch.from_numpy(ts)
28 | 
29 |     def compute_psi_t(self, x1, t, x):
30 |         """
31 |         Compute the conditional flow psi_t(x | x_1).
32 | 
33 |         Note that time flows in the opposite direction compared to DDPM/DDIM.
34 |         As t moves from 0 to 1, the probability path shifts from the prior distribution p_0(x)
35 |         to the more complex data distribution p_1(x).
36 | 
37 |         Input:
38 |             x1 (`torch.Tensor`): Data sample from the data distribution.
39 |             t (`torch.Tensor`): Timestep in [0,1).
40 |             x (`torch.Tensor`): The input to the conditional psi_t(x).
41 |         Output:
42 |             psi_t (`torch.Tensor`): The conditional flow at t.
43 |         """
44 |         t = expand_t(t, x1)
45 | 
46 |         ######## TODO ########
47 |         # DO NOT change the code outside this part.
48 |         # compute psi_t(x)
49 | 
50 |         psi_t = x1
51 |         ######################
52 | 
53 |         return psi_t
54 | 
55 |     def step(self, xt, vt, dt):
56 |         """
57 |         The simplest ODE solver, the first-order Euler method:
58 |             x_next = xt + dt * vt
59 |         """
60 | 
61 |         ######## TODO ########
62 |         # DO NOT change the code outside this part.
63 |         # Implement one step of the first-order Euler method.
64 |         x_next = xt
65 |         ######################
66 | 
67 |         return x_next
68 | 
69 | 
70 | class FlowMatching(nn.Module):
71 |     def __init__(self, network: nn.Module, fm_scheduler: FMScheduler, **kwargs):
72 |         super().__init__()
73 |         self.network = network
74 |         self.fm_scheduler = fm_scheduler
75 | 
76 |     @property
77 |     def device(self):
78 |         return next(self.network.parameters()).device
79 | 
80 |     @property
81 |     def image_resolution(self):
82 |         return self.network.image_resolution
83 | 
84 |     def get_loss(self, x1, class_label=None, x0=None):
85 |         """
86 |         The conditional flow matching objective, corresponding to Eq. 23 in the FM paper.
87 |         """
88 |         batch_size = x1.shape[0]
89 |         t = self.fm_scheduler.uniform_sample_t(batch_size).to(x1)
90 |         if x0 is None:
91 |             x0 = torch.randn_like(x1)
92 | 
93 |         ######## TODO ########
94 |         # DO NOT change the code outside this part.
95 |         # Implement the CFM objective.
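        # A minimal sketch of Eq. 23 (an editor's assumption, not the reference
        # solution): regress the network output at xt = psi_t(x0 | x1) onto the
        # target velocity d/dt psi_t(x0 | x1) = x1 - (1 - sigma_min) * x0, e.g.
        #
        #   xt = self.conditional_psi_sample(x1, t, x0)
        #   target = x1 - (1 - self.fm_scheduler.sigma_min) * x0
        #   # then feed xt (not x1) to the network as done below, and
        #   loss = F.mse_loss(model_out, target)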
96 |         if class_label is not None:
97 |             model_out = self.network(x1, t, class_label=class_label)
98 |         else:
99 |             model_out = self.network(x1, t)
100 | 
101 |         loss = x1.mean()
102 |         ######################
103 | 
104 |         return loss
105 | 
106 |     def conditional_psi_sample(self, x1, t, x0=None):
107 |         if x0 is None:
108 |             x0 = torch.randn_like(x1)
109 |         return self.fm_scheduler.compute_psi_t(x1, t, x0)
110 | 
111 |     @torch.no_grad()
112 |     def sample(
113 |         self,
114 |         shape,
115 |         num_inference_timesteps=50,
116 |         return_traj=False,
117 |         class_label: Optional[torch.Tensor] = None,
118 |         guidance_scale: Optional[float] = 1.0,
119 |         verbose=False,
120 |     ):
121 |         batch_size = shape[0]
122 |         x_T = torch.randn(shape).to(self.device)
123 |         do_classifier_free_guidance = guidance_scale > 1.0
124 | 
125 |         if do_classifier_free_guidance:
126 |             assert class_label is not None
127 |             assert (
128 |                 len(class_label) == batch_size
129 |             ), f"len(class_label) != batch_size. {len(class_label)} != {batch_size}"
130 | 
131 |         traj = [x_T]
132 | 
133 |         timesteps = [
134 |             i / num_inference_timesteps for i in range(num_inference_timesteps)
135 |         ]
136 |         timesteps = [torch.tensor([t] * x_T.shape[0]).to(x_T) for t in timesteps]
137 |         pbar = tqdm(timesteps) if verbose else timesteps
138 |         xt = x_T
139 |         for i, t in enumerate(pbar):
140 |             t_next = timesteps[i + 1] if i < len(timesteps) - 1 else torch.ones_like(t)
141 | 
142 | 
143 |             ######## TODO ########
144 |             # Complete the sampling loop
145 | 
146 |             xt = self.fm_scheduler.step(xt, torch.zeros_like(xt), torch.zeros_like(t))
147 | 
148 |             ######################
149 | 
150 |             traj[-1] = traj[-1].cpu()
151 |             traj.append(xt.clone().detach())
152 |         if return_traj:
153 |             return traj
154 |         else:
155 |             return traj[-1]
156 | 
157 |     def save(self, file_path):
158 |         hparams = {
159 |             "network": self.network,
160 |             "fm_scheduler": self.fm_scheduler,
161 |         }
162 |         state_dict = self.state_dict()
163 | 
164 |         dic = {"hparams": hparams, "state_dict": state_dict}
165 |         torch.save(dic, file_path)
166 | 
167 |     def load(self, file_path):
168 |         dic = torch.load(file_path, map_location="cpu")
169 |         hparams = dic["hparams"]
170 |         state_dict = dic["state_dict"]
171 | 
172 |         self.network = hparams["network"]
173 |         self.fm_scheduler = hparams["fm_scheduler"]
174 | 
175 |         self.load_state_dict(state_dict)
176 | 

--------------------------------------------------------------------------------
/image_fm_todo/fm.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from tqdm import tqdm
8 | 
9 | 
10 | def expand_t(t, x):
11 |     for _ in range(x.ndim - 1):
12 |         t = t.unsqueeze(-1)
13 |     return t
14 | 
15 | 
16 | class FMScheduler(nn.Module):
17 |     def __init__(self, num_train_timesteps=1000, sigma_min=0.001):
18 |         super().__init__()
19 |         self.num_train_timesteps = num_train_timesteps
20 |         self.sigma_min = sigma_min
21 | 
22 |     def uniform_sample_t(self, batch_size) -> torch.Tensor:
23 |         ts = (
24 |             np.random.choice(np.arange(self.num_train_timesteps), batch_size)
25 |             / self.num_train_timesteps
26 |         )
27 |         return torch.from_numpy(ts)
28 | 
29 |     def compute_psi_t(self, x1, t, x):
30 |         """
31 |         Compute the conditional flow psi_t(x | x_1).
32 | 
33 |         Note that time flows in the opposite direction compared to DDPM/DDIM.
34 |         As t moves from 0 to 1, the probability path shifts from the prior distribution p_0(x)
35 |         to the more complex data distribution p_1(x).
36 | 
37 |         Input:
38 |             x1 (`torch.Tensor`): Data sample from the data distribution.
39 |             t (`torch.Tensor`): Timestep in [0,1).
40 |             x (`torch.Tensor`): The input to the conditional psi_t(x).
41 |         Output:
42 |             psi_t (`torch.Tensor`): The conditional flow at t.
43 |         """
44 |         t = expand_t(t, x1)
45 | 
46 |         ######## TODO ########
47 |         # DO NOT change the code outside this part.
48 |         # compute psi_t(x)
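        # A minimal sketch (an editor's assumption, not the reference
        # solution): the OT conditional flow of the FM paper, Eq. 22, would be
        #
        #   psi_t = (1 - (1 - self.sigma_min) * t) * x + t * x1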
49 | 
50 |         psi_t = x1
51 |         ######################
52 | 
53 |         return psi_t
54 | 
55 |     def step(self, xt, vt, dt):
56 |         """
57 |         The simplest ODE solver, the first-order Euler method:
58 |             x_next = xt + dt * vt
59 |         """
60 | 
61 |         ######## TODO ########
62 |         # DO NOT change the code outside this part.
63 |         # Implement one step of the first-order Euler method.
64 |         x_next = xt
65 |         ######################
66 | 
67 |         return x_next
68 | 
69 | 
70 | class FlowMatching(nn.Module):
71 |     def __init__(self, network: nn.Module, fm_scheduler: FMScheduler, **kwargs):
72 |         super().__init__()
73 |         self.network = network
74 |         self.fm_scheduler = fm_scheduler
75 | 
76 |     @property
77 |     def device(self):
78 |         return next(self.network.parameters()).device
79 | 
80 |     @property
81 |     def image_resolution(self):
82 |         return self.network.image_resolution
83 | 
84 |     def get_loss(self, x1, class_label=None, x0=None):
85 |         """
86 |         The conditional flow matching objective, corresponding to Eq. 23 in the FM paper.
87 |         """
88 |         batch_size = x1.shape[0]
89 |         t = self.fm_scheduler.uniform_sample_t(batch_size).to(x1)
90 |         if x0 is None:
91 |             x0 = torch.randn_like(x1)
92 | 
93 |         ######## TODO ########
94 |         # DO NOT change the code outside this part.
95 |         # Implement the CFM objective.
96 |         if class_label is not None:
97 |             model_out = self.network(x1, t, class_label=class_label)
98 |         else:
99 |             model_out = self.network(x1, t)
100 | 
101 |         loss = x1.mean()
102 |         ######################
103 | 
104 |         return loss
105 | 
106 |     def conditional_psi_sample(self, x1, t, x0=None):
107 |         if x0 is None:
108 |             x0 = torch.randn_like(x1)
109 |         return self.fm_scheduler.compute_psi_t(x1, t, x0)
110 | 
111 |     @torch.no_grad()
112 |     def sample(
113 |         self,
114 |         shape,
115 |         num_inference_timesteps=50,
116 |         return_traj=False,
117 |         class_label: Optional[torch.Tensor] = None,
118 |         guidance_scale: Optional[float] = 1.0,
119 |         verbose=False,
120 |     ):
121 |         batch_size = shape[0]
122 |         x_T = torch.randn(shape).to(self.device)
123 |         do_classifier_free_guidance = guidance_scale > 1.0
124 | 
125 |         if do_classifier_free_guidance:
126 |             assert class_label is not None
127 |             assert (
128 |                 len(class_label) == batch_size
129 |             ), f"len(class_label) != batch_size. {len(class_label)} != {batch_size}"
130 | 
131 |         traj = [x_T]
132 | 
133 |         timesteps = [
134 |             i / num_inference_timesteps for i in range(num_inference_timesteps)
135 |         ]
136 |         timesteps = [torch.tensor([t] * x_T.shape[0]).to(x_T) for t in timesteps]
137 |         pbar = tqdm(timesteps) if verbose else timesteps
138 |         xt = x_T
139 |         for i, t in enumerate(pbar):
140 |             t_next = timesteps[i + 1] if i < len(timesteps) - 1 else torch.ones_like(t)
141 | 
142 | 
143 |             ######## TODO ########
144 |             # Complete the sampling loop
145 | 
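            # A minimal sketch (an editor's assumption, not the reference
            # solution): estimate the velocity, apply CFG when enabled, and
            # take one Euler step of size dt = t_next - t, e.g.
            #
            #   if do_classifier_free_guidance:
            #       v_cond = self.network(xt, t, class_label=class_label)
            #       v_uncond = self.network(xt, t, class_label=torch.zeros_like(class_label))
            #       vt = v_uncond + guidance_scale * (v_cond - v_uncond)
            #   else:
            #       vt = self.network(xt, t)
            #   xt = self.fm_scheduler.step(xt, vt, expand_t(t_next - t, xt))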
146 |             xt = self.fm_scheduler.step(xt, torch.zeros_like(xt), torch.zeros_like(t))
147 | 
148 |             ######################
149 | 
150 |             traj[-1] = traj[-1].cpu()
151 |             traj.append(xt.clone().detach())
152 |         if return_traj:
153 |             return traj
154 |         else:
155 |             return traj[-1]
156 | 
157 |     def save(self, file_path):
158 |         hparams = {
159 |             "network": self.network,
160 |             "fm_scheduler": self.fm_scheduler,
161 |         }
162 |         state_dict = self.state_dict()
163 | 
164 |         dic = {"hparams": hparams, "state_dict": state_dict}
165 |         torch.save(dic, file_path)
166 | 
167 |     def load(self, file_path):
168 |         dic = torch.load(file_path, map_location="cpu")
169 |         hparams = dic["hparams"]
170 |         state_dict = dic["state_dict"]
171 | 
172 |         self.network = hparams["network"]
173 |         self.fm_scheduler = hparams["fm_scheduler"]
174 | 
175 |         self.load_state_dict(state_dict)
176 | 

--------------------------------------------------------------------------------
/image_fm_todo/module.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import init
7 | 
8 | 
9 | class Swish(nn.Module):
10 |     def forward(self, x):
11 |         return x * torch.sigmoid(x)
12 | 
13 | 
14 | class DownSample(nn.Module):
15 |     def __init__(self, in_ch):
16 |         super().__init__()
17 |         self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
18 |         self.initialize()
19 | 
20 |     def initialize(self):
21 |         init.xavier_uniform_(self.main.weight)
22 |         init.zeros_(self.main.bias)
23 | 
24 |     def forward(self, x, temb):
25 |         x = self.main(x)
26 |         return x
27 | 
28 | 
29 | class UpSample(nn.Module):
30 |     def __init__(self, in_ch):
31 |         super().__init__()
32 |         self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
33 |         self.initialize()
34 | 
35 |     def initialize(self):
36 |         init.xavier_uniform_(self.main.weight)
37 |         init.zeros_(self.main.bias)
38 | 
39 |     def forward(self, x, temb):
40 |         _, _, H, W = x.shape
41 |         x = F.interpolate(x, scale_factor=2, mode="nearest")
42 |         x = self.main(x)
43 |         return x
44 | 
45 | 
46 | class AttnBlock(nn.Module):
47 |     def __init__(self, in_ch):
48 |         super().__init__()
49 |         self.group_norm = nn.GroupNorm(32, in_ch)
50 |         self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
51 |         self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
52 |         self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
53 |         self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
54 |         self.initialize()
55 | 
56 |     def initialize(self):
57 |         for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
58 |             init.xavier_uniform_(module.weight)
59 |             init.zeros_(module.bias)
60 |         init.xavier_uniform_(self.proj.weight, gain=1e-5)
61 | 
62 |     def forward(self, x):
63 |         B, C, H, W = x.shape
64 |         h = self.group_norm(x)
65 |         q = self.proj_q(h)
66 |         k = self.proj_k(h)
67 |         v = self.proj_v(h)
68 | 
69 |         q = q.permute(0, 2, 3, 1).view(B, H * W, C)
70 |         k = k.view(B, C, H * W)
71 |         w = torch.bmm(q, k) * (int(C) ** (-0.5))
72 | assert list(w.shape) == [B, H * W, H * W] 73 | w = F.softmax(w, dim=-1) 74 | 75 | v = v.permute(0, 2, 3, 1).view(B, H * W, C) 76 | h = torch.bmm(w, v) 77 | assert list(h.shape) == [B, H * W, C] 78 | h = h.view(B, H, W, C).permute(0, 3, 1, 2) 79 | h = self.proj(h) 80 | 81 | return x + h 82 | 83 | 84 | class ResBlock(nn.Module): 85 | def __init__(self, in_ch, out_ch, tdim, dropout, attn=False): 86 | super().__init__() 87 | self.block1 = nn.Sequential( 88 | nn.GroupNorm(32, in_ch), 89 | Swish(), 90 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), 91 | ) 92 | self.temb_proj = nn.Sequential( 93 | Swish(), 94 | nn.Linear(tdim, out_ch), 95 | ) 96 | self.block2 = nn.Sequential( 97 | nn.GroupNorm(32, out_ch), 98 | Swish(), 99 | nn.Dropout(dropout), 100 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), 101 | ) 102 | if in_ch != out_ch: 103 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) 104 | else: 105 | self.shortcut = nn.Identity() 106 | if attn: 107 | self.attn = AttnBlock(out_ch) 108 | else: 109 | self.attn = nn.Identity() 110 | self.initialize() 111 | 112 | def initialize(self): 113 | for module in self.modules(): 114 | if isinstance(module, (nn.Conv2d, nn.Linear)): 115 | init.xavier_uniform_(module.weight) 116 | init.zeros_(module.bias) 117 | init.xavier_uniform_(self.block2[-1].weight, gain=1e-5) 118 | 119 | def forward(self, x, temb): 120 | h = self.block1(x) 121 | h += self.temb_proj(temb)[:, :, None, None] 122 | h = self.block2(h) 123 | 124 | h = h + self.shortcut(x) 125 | h = self.attn(h) 126 | return h 127 | 128 | 129 | class TimeEmbedding(nn.Module): 130 | def __init__(self, hidden_size, frequency_embedding_size=256): 131 | super().__init__() 132 | self.mlp = nn.Sequential( 133 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 134 | nn.SiLU(), 135 | nn.Linear(hidden_size, hidden_size, bias=True), 136 | ) 137 | self.frequency_embedding_size = frequency_embedding_size 138 | 139 | @staticmethod 140 | def timestep_embedding(t, dim, max_period=10000): 141 | """ 142 | Create sinusoidal timestep embeddings. 143 | :param t: a 1-D Tensor of N indices, one per batch element. 144 | These may be fractional. 145 | :param dim: the dimension of the output. 146 | :param max_period: controls the minimum frequency of the embeddings. 147 | :return: an (N, D) Tensor of positional embeddings. 
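        For example, with dim=4 the frequencies below are
        w_k = exp(-ln(max_period) * k / 2) for k in {0, 1}, so
        timestep_embedding(torch.tensor([0.5]), 4) returns
        [[cos(0.5*w_0), cos(0.5*w_1), sin(0.5*w_0), sin(0.5*w_1)]].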
148 |         """
149 |         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
150 |         half = dim // 2
151 |         freqs = torch.exp(
152 |             -math.log(max_period)
153 |             * torch.arange(start=0, end=half, dtype=torch.float32)
154 |             / half
155 |         ).to(device=t.device)
156 |         args = t[:, None].float() * freqs[None]
157 |         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
158 |         if dim % 2:
159 |             embedding = torch.cat(
160 |                 [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
161 |             )
162 |         return embedding
163 | 
164 |     def forward(self, t):
165 |         if t.ndim == 0:
166 |             t = t.unsqueeze(-1)
167 |         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
168 |         t_emb = self.mlp(t_freq)
169 |         return t_emb
170 | 

--------------------------------------------------------------------------------
/image_fm_todo/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import chain
3 | from multiprocessing import Pool
4 | from pathlib import Path
5 | 
6 | import torch
7 | import torchvision.transforms as transforms
8 | from PIL import Image
9 | 
10 | 
11 | def listdir(dname):
12 |     fnames = list(
13 |         chain(
14 |             *[
15 |                 list(Path(dname).rglob("*." + ext))
16 |                 for ext in ["png", "jpg", "jpeg", "JPG"]
17 |             ]
18 |         )
19 |     )
20 |     return fnames
21 | 
22 | 
23 | def tensor_to_pil_image(x: torch.Tensor, single_image=False):
24 |     """
25 |     x: [B,C,H,W]
26 |     """
27 |     if x.ndim == 3:
28 |         x = x.unsqueeze(0)
29 |         single_image = True
30 | 
31 |     x = (x * 0.5 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
32 |     images = (x * 255).round().astype("uint8")
33 |     images = [Image.fromarray(image) for image in images]
34 |     if single_image:
35 |         return images[0]
36 |     return images
37 | 
38 | 
39 | def get_data_iterator(iterable):
40 |     """Allows training with DataLoaders in a single infinite loop:
41 |     for i, data in enumerate(get_data_iterator(train_loader)):
42 |     """
43 |     iterator = iterable.__iter__()
44 |     while True:
45 |         try:
46 |             yield iterator.__next__()
47 |         except StopIteration:
48 |             iterator = iterable.__iter__()
49 | 
50 | 
51 | class AFHQDataset(torch.utils.data.Dataset):
52 |     def __init__(
53 |         self,
54 |         root: str,
55 |         split: str,
56 |         transform=None,
57 |         max_num_images_per_cat=-1,
58 |         label_offset=1,
59 |     ):
60 |         super().__init__()
61 |         self.root = root
62 |         self.split = split
63 |         self.transform = transform
64 |         self.max_num_images_per_cat = max_num_images_per_cat
65 |         self.label_offset = label_offset
66 | 
67 |         categories = os.listdir(os.path.join(root, split))
68 |         self.num_classes = len(categories)
69 | 
70 |         fnames, labels = [], []
71 |         for idx, cat in enumerate(sorted(categories)):
72 |             category_dir = os.path.join(root, split, cat)
73 |             cat_fnames = listdir(category_dir)
74 |             cat_fnames = sorted(cat_fnames)
75 |             if self.max_num_images_per_cat > 0:
76 |                 cat_fnames = cat_fnames[: self.max_num_images_per_cat]
77 |             fnames += cat_fnames
78 |             labels += [idx + label_offset] * len(
79 |                 cat_fnames
80 |             )  # label 0 is for null class.
81 | 82 | self.fnames = fnames 83 | self.labels = labels 84 | 85 | def __getitem__(self, idx): 86 | img = Image.open(self.fnames[idx]).convert("RGB") 87 | label = self.labels[idx] 88 | assert label >= self.label_offset 89 | if self.transform is not None: 90 | img = self.transform(img) 91 | 92 | return img, label 93 | 94 | def __len__(self): 95 | return len(self.labels) 96 | 97 | 98 | class AFHQDataModule(object): 99 | def __init__( 100 | self, 101 | root: str = "data", 102 | batch_size: int = 32, 103 | num_workers: int = 4, 104 | max_num_images_per_cat: int = 1000, 105 | image_resolution: int = 64, 106 | label_offset=1, 107 | transform=None, 108 | ): 109 | self.root = root 110 | self.batch_size = batch_size 111 | self.num_workers = num_workers 112 | self.afhq_root = os.path.join(root, "afhq") 113 | self.max_num_images_per_cat = max_num_images_per_cat 114 | self.image_resolution = image_resolution 115 | self.label_offset = label_offset 116 | self.transform = transform 117 | 118 | if not os.path.exists(self.afhq_root): 119 | print(f"{self.afhq_root} is empty. Downloading AFHQ dataset...") 120 | self._download_dataset() 121 | 122 | self._set_dataset() 123 | 124 | def _set_dataset(self): 125 | if self.transform is None: 126 | self.transform = transforms.Compose( 127 | [ 128 | transforms.Resize((self.image_resolution, self.image_resolution)), 129 | transforms.ToTensor(), 130 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 131 | ] 132 | ) 133 | self.train_ds = AFHQDataset( 134 | self.afhq_root, 135 | "train", 136 | self.transform, 137 | max_num_images_per_cat=self.max_num_images_per_cat, 138 | label_offset=self.label_offset, 139 | ) 140 | self.val_ds = AFHQDataset( 141 | self.afhq_root, 142 | "val", 143 | self.transform, 144 | max_num_images_per_cat=self.max_num_images_per_cat, 145 | label_offset=self.label_offset, 146 | ) 147 | 148 | self.num_classes = self.train_ds.num_classes 149 | 150 | def _download_dataset(self): 151 | URL = "https://www.dropbox.com/s/t9l9o3vsx2jai3z/afhq.zip?dl=0" 152 | ZIP_FILE = f"./{self.root}/afhq.zip" 153 | os.system(f"mkdir -p {self.root}") 154 | os.system(f"wget -N {URL} -O {ZIP_FILE}") 155 | os.system(f"unzip {ZIP_FILE} -d {self.root}") 156 | os.system(f"rm {ZIP_FILE}") 157 | 158 | def train_dataloader(self): 159 | return torch.utils.data.DataLoader( 160 | self.train_ds, 161 | batch_size=self.batch_size, 162 | num_workers=self.num_workers, 163 | shuffle=True, 164 | drop_last=True, 165 | ) 166 | 167 | def val_dataloader(self): 168 | return torch.utils.data.DataLoader( 169 | self.val_ds, 170 | batch_size=self.batch_size, 171 | num_workers=self.num_workers, 172 | shuffle=False, 173 | drop_last=False, 174 | ) 175 | 176 | 177 | if __name__ == "__main__": 178 | data_module = AFHQDataModule("data", 32, 4, -1, 64, 1) 179 | 180 | eval_dir = Path(data_module.afhq_root) / "eval" 181 | eval_dir.mkdir(exist_ok=True) 182 | 183 | def func(path): 184 | fn = path.name 185 | cmd = f"cp {path} {eval_dir / fn}" 186 | os.system(cmd) 187 | img = Image.open(str(eval_dir / fn)) 188 | img = img.resize((64, 64)) 189 | img.save(str(eval_dir / fn)) 190 | print(fn) 191 | 192 | with Pool(8) as pool: 193 | pool.map(func, data_module.val_ds.fnames) 194 | 195 | print(f"Constructed eval dir at {eval_dir}") 196 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
KAIST CS492(D): Diffusion Models and Their Applications (Fall 2024)
2 | Programming Assignment 7
3 | 
4 | Instructor: Minhyuk Sung (mhsung [at] kaist.ac.kr)
5 | TA: Juil Koo (63days [at] kaist.ac.kr)
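6 | 
7 | Quick start (a sketch; the flags follow the argparse definitions in image_fm_todo/train.py and image_fm_todo/sampling.py, run from inside image_fm_todo/):
8 | 
9 |     python train.py --use_cfg
10 |     python sampling.py --ckpt_path results/cfg_fm-<timestamp>/last.ckpt --save_dir samples/ --use_cfg --cfg_scale 7.5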