├── README.md
├── classifier_load_dataset.py
├── config_category.json
├── linear_classifier_3classes.py
├── load_dataset.py
├── plotting.py
├── preprocess_celeba-3classes.py
├── properties_single-body_2d_3classes.json
└── train.py

/README.md:
--------------------------------------------------------------------------------
1 | # Compositional Abilities Emerge Multiplicatively: Exploring Diffusion Models on a Synthetic Task
2 | 
3 | ## Introduction
4 | This project is the codebase for the paper "Compositional Abilities Emerge Multiplicatively: Exploring Diffusion Models on a Synthetic Task," presented at NeurIPS 2023. The code implements Denoising Diffusion Probabilistic Models (DDPM) and is built using PyTorch.
5 | 
6 | ## Paper
7 | - **Title:** Compositional Abilities Emerge Multiplicatively: Exploring Diffusion Models on a Synthetic Task
8 | - **Authors:** Maya Okawa, Ekdeep Singh Lubana, Robert P. Dick, Hidenori Tanaka
9 | - **Venue:** Advances in Neural Information Processing Systems (NeurIPS)
10 | - **Published:** 2023
11 | - **Link:** [arXiv link](https://arxiv.org/abs/2310.09336)
12 | 
13 | ## Requirements
14 | - PyTorch
15 | - torchvision
16 | - numpy
17 | - matplotlib
18 | - tqdm
19 | - einops
20 | 
21 | ## Dataset
22 | 
23 | You can download the datasets required for this project:
24 | 
25 | - **Synthetic Data:**
26 | The synthetic data for this project can be downloaded from the following link:
27 | [Download Synthetic Data](https://www.dropbox.com/scl/fi/6zzb5h4bly2gbignwn4yz/single-body_2d_3classes.zip?rlkey=0uizen48trsl6cm4oaui2ze41&dl=0)
28 | 
29 | - **Real Data:**
30 | The preprocessed CelebA data for this project can be downloaded from the following link:
31 | [Download Preprocessed Data](https://www.dropbox.com/scl/fi/j8vwioyhmecibqeglxlok/celeba-3classes-smiling-10000_100.zip?rlkey=grz1de7psug5tlm89h5l7w7ia&dl=0)
32 | 
33 | 
34 | ## Usage
35 | First, create an `input` directory at the root level of this project. Then, place the data files under the `input` directory.
36 | 
37 | To train the model with the synthetic data, run the `train.py` script with the desired parameters. For example:
38 | 
39 | `python3 train.py --dataset single-body_2d_3classes`
40 | 
41 | To train the model using the CelebA dataset:
42 | 
43 | `python3 train.py --dataset celeba-3classes-10000`
44 | 
45 | 
46 | ## Structure
47 | - `train.py`: Main script for training the DDPM model.
48 | - `load_dataset.py`: Script for loading and processing datasets.
49 | 
50 | 
51 | ## References
52 | - The `DDPM` class in `train.py` is based on the implementation found at [TeaPearce/Conditional_Diffusion_MNIST](https://github.com/TeaPearce/Conditional_Diffusion_MNIST/blob/main/script.py).
53 | - The `CrossAttention` class in `train.py` is inspired by the code in [Animadversio/DiffusionFromScratch/StableDiff_UNet_model.py](https://github.com/Animadversio/DiffusionFromScratch/blob/master/StableDiff_UNet_model.py).
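
## Outputs
`train.py` writes its results under `./output/<dataset>/<experiment>/<run parameters>/`. For each evaluated configuration it saves the generated samples as `image_<config>_ep<epoch>.npz` (array key `x_gen`); at the final epoch it also saves the intermediate denoising snapshots as `gen_store_<config>_ep<epoch>.npz` (key `x_gen_store`) and the per-batch train/test losses as `training_log_<epoch>.json`. Below is a minimal sketch for inspecting one of these files; the run directory and file name are placeholders and depend on the hyperparameters of your run.

```python
import numpy as np
import matplotlib.pyplot as plt

# Placeholder path: substitute the run directory created by your own training run.
run_dir = "output/single-body_2d_3classes/H32-train1/<run parameters>/"
samples = np.load(run_dir + "image_111_ep99.npz")["x_gen"]  # shape: (n_sample, channels, pixel_size, pixel_size)

# The synthetic samples have four channels; show the first three as RGB for the first sample.
plt.imshow(np.transpose(samples[0, :3], (1, 2, 0)).clip(0, 1))
plt.axis("off")
plt.show()
```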
54 | 55 | 56 | -------------------------------------------------------------------------------- /classifier_load_dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | from torch.utils.data import DataLoader, Dataset 4 | from PIL import Image 5 | from torchvision import transforms 6 | import os 7 | import json 8 | import matplotlib.pyplot as plt 9 | import glob 10 | import numpy as np 11 | import random 12 | import json 13 | 14 | 15 | class my_dataset(Dataset): 16 | def __init__(self, transform=None, num_samples=5000, dataset="", configs="", training=True, n_class_color=None, n_class_size=None, test_size=1.3, alpha=1.0, flag_four=False): 17 | self.training = training 18 | self.test_size = test_size 19 | self.flag_four = flag_four 20 | 21 | self.dataset = dataset 22 | self.n_class_size = n_class_size 23 | self.n_class_color = n_class_color 24 | 25 | if training: 26 | self.train_image_paths = [] 27 | for config in configs: 28 | new_paths = glob.glob(dataset+"/*/CLEVR_"+config+"_*.png") 29 | self.train_image_paths += new_paths 30 | else: 31 | self.test_image_paths = glob.glob(dataset+"/test/CLEVR_"+configs+"_*.png") 32 | 33 | if self.training: 34 | self.len_data = len(self.train_image_paths) - 1 35 | else: 36 | self.len_data = len(self.test_image_paths) - 1 37 | 38 | self.num_samples = num_samples 39 | self.transform = transform 40 | 41 | 42 | def __getitem__(self, index): 43 | if self.training: 44 | ipath = random.randint(0, len(self.train_image_paths)-1) 45 | img_path = self.train_image_paths[ipath] 46 | else: 47 | ipath = random.randint(0, len(self.test_image_paths)-1) 48 | img_path = self.test_image_paths[ipath] 49 | 50 | img = Image.open(img_path) #.convert('RGB') 51 | if self.transform is not None: 52 | img = self.transform(img) 53 | 54 | json_path = img_path.replace(".png", ".json") 55 | with open(json_path, 'r') as f: 56 | features = json.load(f) 57 | if self.dataset == "single-body_2d_3classes": 58 | size = features[0] 59 | color = features[1][0] 60 | if self.dataset == "single-body_3d_3classes": 61 | size = features[0] 62 | color = features[1][0] 63 | if "single-body_2d_4classes" in self.dataset: 64 | size = features[0] 65 | color = features[1][0] 66 | position = features[-1] 67 | 68 | name_labels = img_path.split("_")[-2] 69 | if "single-body_2d_3classes" in self.dataset or "single-body_3d_3classes" in self.dataset: 70 | label = {0: int(name_labels[0]), 1: int(name_labels[1]), 2: int(name_labels[2])} 71 | elif "single-body_2d_4classes" in self.dataset: 72 | label = {0: int(name_labels[0]), 1: int(name_labels[1]), 2: int(name_labels[2]), 3: int(name_labels[3])} 73 | 74 | return img, label 75 | 76 | def __len__(self): 77 | return self.num_samples 78 | 79 | 80 | if __name__ == '__main__': 81 | transform = transforms.Compose([transforms.ToTensor()]) 82 | dataset = my_dataset(transform, dataset="single-body_2d_color_0.05", n_class_size=1, n_class_color=1, configs=["000","010","100","001"]) 83 | dataloader = DataLoader(dataset, batch_size=4) 84 | 85 | -------------------------------------------------------------------------------- /config_category.json: -------------------------------------------------------------------------------- 1 | { 2 | "H22-train1": {"train": ["00","01","10"], "test": ["11"]}, 3 | "H23-train1": {"train": ["011","211","201","210"], "test": ["200"]}, 4 | "H23-train2": {"train": ["011","211","201","210","010"], "test": ["200"]}, 5 | "H42-train1": {"train": 
["0000","1000","0100","0001","0010"], "test": ["0011","0101","1001","0110","1010","1100","0111","1011","1101","1110","1111"]}, 6 | "H32-train1-finetune200": {"train": ["200"], "test": ["011","110","101","111","210","211","201"]}, 7 | "H32-train1": {"train": ["000","001","100","010"], "test": ["011","110","101","111"]}, 8 | "H32-train2": {"train": ["000","001","011","101"], "test": ["010","100","111","110"]}, 9 | "H32-train3": {"train": ["000","010","011","110"], "test": ["001","100","111","101"]}, 10 | "H32-train4": {"train": ["000","100","110","101"], "test": ["001","010","111","011"]}, 11 | "H32-train5": {"train": ["000","100","010","111"], "test": ["001","011","110","101"]}, 12 | "H32-train6": {"train": ["001","011","101","110"], "test": ["000","010","100","111"]}, 13 | "H32-train7": {"train": ["000","010","110","101"], "test": ["100","001","011","111"]}, 14 | "H32-train8": {"train": ["000","100","110","011"], "test": ["010","001","101","111"]}, 15 | "H32-fail1": {"train": ["000","100","010","111"], "test": ["001","011","110","101"]}, 16 | "H32-fail2": {"train": ["000","100","010","111","200","201","300","301"], "test": ["001","011","110","101"]}, 17 | "H32-fail1-finetune001": {"train": ["001"], "test": ["001","011","110","101"]}, 18 | "H32-fail1-finetune110": {"train": ["110"], "test": ["001","011","101","001"]}, 19 | "H32-train5-finetune110": {"train": ["110"], "test": ["001","011","110","101"]}, 20 | "H32-train9": {"train": ["000","011","110","101"], "test": ["111","010","001","100"]}, 21 | "H32-train10": {"train": ["000","001","100","010","011","110","101"], "test": ["011","110","101","111"]} 22 | } 23 | -------------------------------------------------------------------------------- /linear_classifier_3classes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import torch.utils.data as data 6 | 7 | import torchvision.transforms as transforms 8 | import torchvision.datasets as datasets 9 | 10 | from sklearn import metrics 11 | from sklearn import decomposition 12 | from sklearn import manifold 13 | from tqdm.notebook import trange, tqdm 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | 17 | import copy 18 | import random 19 | import time 20 | import classifier_load_dataset 21 | import itertools 22 | import json 23 | 24 | 25 | 26 | class MLP(nn.Module): 27 | def __init__(self, input_dim, output_dims): 28 | super().__init__() 29 | 30 | self.output_fc0 = nn.Linear(input_dim, output_dims[0]) 31 | self.output_fc1 = nn.Linear(input_dim, output_dims[1]) 32 | self.output_fc2 = nn.Linear(input_dim, output_dims[2]) 33 | 34 | def forward(self, x): 35 | batch_size = x.shape[0] 36 | x = x[:,:3,:,:].reshape(batch_size, -1) 37 | 38 | y_pred = {} 39 | y_pred[0] = self.output_fc0(x) 40 | y_pred[1] = self.output_fc1(x) 41 | y_pred[2] = self.output_fc2(x) 42 | 43 | return y_pred 44 | 45 | 46 | 47 | def train(model, iterator, optimizer, criterion, device): 48 | 49 | epoch_loss = 0 50 | epoch_acc = {0: 0, 1: 0, 2: 0} 51 | model.train() 52 | 53 | for (x, y) in tqdm(iterator, desc="Training", leave=False): 54 | x = x.to(device) 55 | optimizer.zero_grad() 56 | y_pred = model(x) 57 | loss = criterion(y_pred[0], y[0]) + criterion(y_pred[1], y[1]) + criterion(y_pred[2], y[2]) 58 | acc = {} 59 | acc[0] = calculate_accuracy(y_pred[0], y[0]) 60 | acc[1] = calculate_accuracy(y_pred[1], y[1]) 61 | acc[2] = calculate_accuracy(y_pred[2], y[2]) 62 | 
loss.backward() 63 | optimizer.step() 64 | epoch_loss += loss.item() 65 | epoch_acc[0] += acc[0].item() 66 | epoch_acc[1] += acc[1].item() 67 | epoch_acc[2] += acc[2].item() 68 | 69 | epoch_acc[0] /= len(iterator) 70 | epoch_acc[1] /= len(iterator) 71 | epoch_acc[2] /= len(iterator) 72 | 73 | return epoch_loss / len(iterator), epoch_acc 74 | 75 | def calc_mse(pred, gt): 76 | return torch.sqrt(torch.mean((pred-gt)**2)) 77 | 78 | 79 | def evaluate(model, iterator, criterion, device): 80 | 81 | epoch_loss = 0 82 | #epoch_acc = 0 83 | epoch_acc = {0: 0, 1: 0, 2: 0} 84 | 85 | model.eval() 86 | with torch.no_grad(): 87 | for (x, y) in tqdm(iterator, desc="Evaluating", leave=False): 88 | x = x.to(device) 89 | #y = _y[key].to(device) 90 | 91 | y_pred = model(x) 92 | loss = criterion(y_pred[0], y[0]) + criterion(y_pred[1], y[1]) + criterion(y_pred[2], y[2]) 93 | acc = {} 94 | acc[0] = calculate_accuracy(y_pred[0], y[0]) 95 | acc[1] = calculate_accuracy(y_pred[1], y[1]) 96 | acc[2] = calculate_accuracy(y_pred[2], y[2]) 97 | epoch_loss += loss.item() 98 | #epoch_acc += acc.item() 99 | epoch_acc[0] += acc[0].item() 100 | epoch_acc[1] += acc[1].item() 101 | epoch_acc[2] += acc[2].item() 102 | 103 | epoch_acc[0] /= len(iterator) 104 | epoch_acc[1] /= len(iterator) 105 | epoch_acc[2] /= len(iterator) 106 | 107 | return epoch_loss / len(iterator), epoch_acc #/ len(iterator) 108 | 109 | 110 | def calculate_accuracy(y_pred, y): 111 | top_pred = y_pred.argmax(1, keepdim=True) 112 | correct = top_pred.eq(y.view_as(top_pred)).sum() 113 | acc = correct.float() / y.shape[0] 114 | return acc 115 | 116 | 117 | def epoch_time(start_time, end_time): 118 | elapsed_time = end_time - start_time 119 | elapsed_mins = int(elapsed_time / 60) 120 | elapsed_secs = int(elapsed_time - (elapsed_mins * 60)) 121 | return elapsed_mins, elapsed_secs 122 | 123 | 124 | if __name__ == "__main__": 125 | 126 | dataset = "single-body_2d_3classes" 127 | properties_json = "properties_"+dataset+".json" 128 | with open(properties_json, 'r') as f: 129 | properties = json.load(f) 130 | 131 | keys, values = zip(*properties.items()) 132 | permutations = [dict(zip(keys, v)) for v in itertools.product(*values)] 133 | 134 | configs = [] 135 | for permutation in permutations: 136 | configs.append("".join(permutation.values())) 137 | 138 | 139 | pixel_size = 28 140 | n_class_color = 2 #3 141 | 142 | tf = transforms.Compose([transforms.Resize((pixel_size,pixel_size)), transforms.ToTensor()]) 143 | train_dataset = classifier_load_dataset.my_dataset(tf, 5000, dataset, configs=configs, training=True, n_class_color=n_class_color) 144 | test_data = classifier_load_dataset.my_dataset(tf, 500, dataset, configs=configs, training=True, n_class_color=n_class_color) 145 | 146 | train_data, valid_data = data.random_split(train_dataset, [4500, 500]) 147 | BATCH_SIZE = 128 148 | valid_data = copy.deepcopy(valid_data) 149 | train_iterator = data.DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE) 150 | valid_iterator = data.DataLoader(valid_data, batch_size=BATCH_SIZE) 151 | test_iterator = data.DataLoader(test_data, batch_size=BATCH_SIZE) 152 | 153 | 154 | INPUT_DIM = pixel_size * pixel_size * 3 #4 155 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 156 | 157 | 158 | #for key in ["shapes","colors","sizes"]: 159 | OUTPUT_DIMS = [len(properties[key]) for key in ["shapes","colors","sizes"]] 160 | OUTPUT_DIMS[1] = n_class_color 161 | model = MLP(INPUT_DIM, OUTPUT_DIMS) 162 | optimizer = optim.Adam(model.parameters()) 163 | criterion = 
nn.CrossEntropyLoss() 164 | #model = model.to(device) 165 | criterion = criterion.to(device) 166 | 167 | EPOCHS = 10 #200 168 | best_valid_loss = float('inf') 169 | 170 | for epoch in trange(EPOCHS): 171 | 172 | start_time = time.monotonic() 173 | 174 | train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device) 175 | valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, device) 176 | 177 | if valid_loss < best_valid_loss: 178 | best_valid_loss = valid_loss 179 | torch.save(model.state_dict(), 'working/linear-classifier_'+dataset+'_multi-class.pt') 180 | 181 | end_time = time.monotonic() 182 | 183 | epoch_mins, epoch_secs = epoch_time(start_time, end_time) 184 | 185 | print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s') 186 | print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc[0]*100:.2f}% {train_acc[1]*100:.2f}% {train_acc[2]*100:.2f}%') 187 | print(f'\tValid Loss: {valid_loss:.3f} | valid Acc: {valid_acc[0]*100:.2f}% {valid_acc[1]*100:.2f}% {valid_acc[2]*100:.2f}%') 188 | 189 | 190 | test_loss, test_acc = evaluate(model, test_iterator, criterion, device) 191 | #print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc[0]*100:.2f}%') 192 | print(f'\tTest Loss: {test_loss:.3f} | test Acc: {test_acc[0]*100:.2f}% {test_acc[1]*100:.2f}% {test_acc[2]*100:.2f}%') 193 | 194 | 195 | -------------------------------------------------------------------------------- /load_dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | from torch.utils.data import DataLoader, Dataset 4 | from PIL import Image 5 | from torchvision import transforms 6 | import os 7 | import json 8 | import matplotlib.pyplot as plt 9 | import glob 10 | import numpy as np 11 | import random 12 | import json 13 | 14 | 15 | 16 | class my_dataset(Dataset): 17 | def __init__(self, transform=None, num_samples=5000, dataset="", configs="", training=True, test_size=None, alpha=1.0, beta=2.0, remove_node=None, flag_double=1): 18 | self.training = training 19 | self.test_size = test_size 20 | self.dataset = dataset 21 | 22 | prefix = "celeba" if "celeba" in dataset else "CLEVR" 23 | ext = ".jpg" if prefix == "celeba" else ".png" 24 | 25 | if training: 26 | self.train_image_paths = [] 27 | for config in configs: 28 | if config == "000" and alpha != 1500 and remove_node != "100": 29 | path_pattern = f"input/{dataset}/train_{remove_node}/{prefix}_000_*{ext}" 30 | else: 31 | path_pattern = f"working/{dataset}/train/{prefix}_{config}_*{ext}" 32 | new_paths = glob.glob(path_pattern) 33 | 34 | if remove_node == config: 35 | new_paths = new_paths[:alpha] 36 | 37 | self.train_image_paths.extend(new_paths) 38 | self.len_data = len(self.train_image_paths) 39 | else: 40 | self.test_image_paths = glob.glob(f"input/{dataset}/test/{prefix}_{configs}_*{ext}") 41 | self.len_data = len(self.test_image_paths) 42 | 43 | 44 | self.num_samples = num_samples 45 | self.transform = transform 46 | 47 | 48 | def __getitem__(self, index): 49 | if self.training: 50 | ipath = random.randint(0, len(self.train_image_paths)-1) 51 | img_path = self.train_image_paths[ipath] 52 | else: 53 | ipath = random.randint(0, len(self.test_image_paths)-1) 54 | img_path = self.test_image_paths[ipath] 55 | 56 | img = Image.open(img_path) #.convert('RGB') 57 | if self.transform is not None: 58 | img = self.transform(img) 59 | 60 | name_labels = img_path.split("_")[-2] 61 | 62 | if self.dataset == "single-body_2d_3classes": 63 | with 
open(img_path.replace(".png", ".json"), 'r') as f:
64 | my_dict = json.loads(f.read())
65 | _size = my_dict[0]
66 | _color = my_dict[1][:3]
67 | 
68 | if self.training:
69 | size, color = _size, _color
70 | else:
71 | # Define colors mapping
72 | colors_map = {
73 | '0': [0.9, 0.1, 0.1],
74 | '1': [0.1, 0.1, 0.9],
75 | '2': [0.1, 0.9, 0.1]
76 | }
77 | # Assign size and color based on label values
78 | size = 2.6 if int(name_labels[2]) == 0 else self.test_size
79 | color = colors_map[name_labels[1]]
80 | 
81 | # Convert size and color to numpy arrays
82 | size = np.array(size, dtype=np.float32)
83 | color = np.array(color, dtype=np.float32)
84 | 
85 | # Create the label dictionary
86 | label = {0: int(name_labels[0]), 1: color, 2: size}
87 | 
88 | elif "celeba" in self.dataset:
89 | label = {i: int(name_labels[i]) for i in range(3)}
90 | 
91 | 
92 | return img, label
93 | 
94 | def __len__(self):
95 | return self.num_samples
96 | 
97 | 
98 | if __name__ == '__main__':
99 | #transform = transforms.Compose([transforms.Resize((54,54)), transforms.ToTensor()])
100 | transform = transforms.Compose([transforms.ToTensor()])
101 | dataset = my_dataset(transform, dataset="single-body_2d_3classes", configs=["000","010","100","001"])
102 | dataloader = DataLoader(dataset, batch_size=4)
103 | 
104 | for img, label in dataloader:
105 | print('label=',label)
106 | print(img.shape)
107 | plt.imshow(np.transpose(img[0].numpy(), (2,1,0)))
108 | plt.show()
109 | print('img.shape=',img.shape)
110 | exit()
111 | 
112 | 
--------------------------------------------------------------------------------
/plotting.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import json
3 | import numpy as np
4 | from PIL import Image
5 | from matplotlib.animation import FuncAnimation, PillowWriter
6 | from torchvision import models, transforms
7 | from matplotlib import cm
8 | import glob
9 | import matplotlib
10 | import linear_classifier_3classes
11 | import mlp_classifier_4classes
12 | import torch
13 | import seaborn as sns
14 | from matplotlib.offsetbox import AnchoredText
15 | import matplotlib.image as mpimg
16 | from torchvision import models, transforms
17 | import torch.nn as nn
18 | import pandas as pd
19 | from sklearn.metrics import log_loss
20 | font = {'size': 15}
21 | matplotlib.rc('font', **font)
22 | cp = sns.color_palette("colorblind")
23 | criterion = nn.CrossEntropyLoss()
24 | 
25 | 
26 | pixel_size = 28
27 | INPUT_DIM = pixel_size * pixel_size * 3
28 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29 | n_class_color = 2
30 | 
31 | scale_factor = 78
32 | tf = transforms.Compose([transforms.Resize((pixel_size,pixel_size)), transforms.ToTensor()])
33 | 
34 | 
35 | 
36 | def hamming_distance(target, train_configs):
37 | distance = 999
38 | for train_config in train_configs:
39 | tmpdist = 0
40 | for ic in range(len(target)):
41 | tmpdist += np.abs(int(train_config[ic])-int(target[ic]))
42 | if tmpdist < distance:
43 | distance = tmpdist
44 | return distance
...

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
...
48 | class ResidualConvBlock(nn.Module):
49 | def __init__(self, in_channels: int, out_channels: int, is_res: bool = False) -> None:
50 | super().__init__()
51 | '''
52 | standard ResNet style convolutional block
53 | '''
54 | self.same_channels = in_channels==out_channels
55 | self.is_res = is_res
56 | self.conv1 = nn.Sequential(
57 | nn.Conv2d(in_channels, out_channels, 3, 1, 1),
58 | nn.BatchNorm2d(out_channels),
59 | nn.GELU(),
60 | )
61 | self.conv2 = nn.Sequential(
62 | nn.Conv2d(out_channels, out_channels, 3, 1, 1),
63 | nn.BatchNorm2d(out_channels),
64 | nn.GELU(),
65 | )
66 | 
67 | def forward(self, x: torch.Tensor) 
-> torch.Tensor: 68 | if self.is_res: 69 | x1 = self.conv1(x) 70 | x2 = self.conv2(x1) 71 | if self.same_channels: 72 | out = x + x2 73 | else: 74 | out = x1 + x2 75 | return out 76 | else: 77 | x1 = self.conv1(x) 78 | x2 = self.conv2(x1) 79 | return x2 80 | 81 | 82 | def l2norm(t): 83 | return F.normalize(t, dim = -1) 84 | 85 | def exists(val): 86 | return val is not None 87 | 88 | class Residual(nn.Module): 89 | def __init__(self, fn): 90 | super().__init__() 91 | self.fn = fn 92 | 93 | def forward(self, x, **kwargs): 94 | return self.fn(x, **kwargs) + x 95 | 96 | class RearrangeToSequence(nn.Module): 97 | def __init__(self, fn): 98 | super().__init__() 99 | self.fn = fn 100 | 101 | def forward(self, x): 102 | x = rearrange(x, 'b c ... -> b ... c') 103 | x, ps = pack([x], 'b * c') 104 | 105 | x = self.fn(x) 106 | 107 | x, = unpack(x, ps, 'b * c') 108 | x = rearrange(x, 'b ... c -> b c ...') 109 | return x 110 | 111 | class LayerNorm(nn.Module): 112 | def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False): 113 | super().__init__() 114 | self.eps = eps 115 | self.fp16_eps = fp16_eps 116 | self.stable = stable 117 | self.g = nn.Parameter(torch.ones(dim)) 118 | 119 | def forward(self, x): 120 | eps = self.eps if x.dtype == torch.float32 else self.fp16_eps 121 | 122 | if self.stable: 123 | x = x / x.amax(dim = -1, keepdim = True).detach() 124 | 125 | var = torch.var(x, dim = -1, unbiased = False, keepdim = True) 126 | mean = torch.mean(x, dim = -1, keepdim = True) 127 | return (x - mean) * (var + eps).rsqrt() * self.g 128 | 129 | class Attention(nn.Module): 130 | def __init__( 131 | self, 132 | dim, 133 | *, 134 | dim_head = 64, 135 | heads = 8, 136 | dropout = 0., 137 | causal = False, 138 | rotary_emb = None, 139 | cosine_sim = True, 140 | cosine_sim_scale = 16 141 | ): 142 | super().__init__() 143 | self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5) 144 | self.cosine_sim = cosine_sim 145 | 146 | self.heads = heads 147 | inner_dim = dim_head * heads 148 | 149 | self.causal = causal 150 | self.norm = LayerNorm(dim) 151 | self.dropout = nn.Dropout(dropout) 152 | 153 | self.null_kv = nn.Parameter(torch.randn(2, dim_head)) 154 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 155 | self.to_kv = nn.Linear(dim, dim_head * 2, bias = False) 156 | 157 | self.rotary_emb = rotary_emb 158 | 159 | self.to_out = nn.Sequential( 160 | nn.Linear(inner_dim, dim, bias = False), 161 | LayerNorm(dim) 162 | ) 163 | 164 | def forward(self, x, mask = None, attn_bias = None): 165 | b, n, device = *x.shape[:2], x.device 166 | 167 | x = self.norm(x) 168 | q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)) 169 | 170 | q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) 171 | q = q * self.scale 172 | 173 | # rotary embeddings 174 | if exists(self.rotary_emb): 175 | q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k)) 176 | 177 | # add null key / value for classifier free guidance in prior net 178 | nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2)) 179 | k = torch.cat((nk, k), dim = -2) 180 | v = torch.cat((nv, v), dim = -2) 181 | 182 | # whether to use cosine sim 183 | if self.cosine_sim: 184 | q, k = map(l2norm, (q, k)) 185 | 186 | q, k = map(lambda t: t * math.sqrt(self.scale), (q, k)) 187 | 188 | # calculate query / key similarities 189 | sim = torch.einsum('b h i d, b j d -> b h i j', q, k) 190 | 191 | # relative positional encoding (T5 style) 192 | 193 | if exists(attn_bias): 194 | sim = sim + attn_bias 195 | 196 | # masking 
197 | max_neg_value = -torch.finfo(sim.dtype).max 198 | 199 | if exists(mask): 200 | mask = F.pad(mask, (1, 0), value = True) 201 | mask = rearrange(mask, 'b j -> b 1 1 j') 202 | sim = sim.masked_fill(~mask, max_neg_value) 203 | 204 | if self.causal: 205 | i, j = sim.shape[-2:] 206 | causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1) 207 | sim = sim.masked_fill(causal_mask, max_neg_value) 208 | 209 | # attention 210 | attn = sim.softmax(dim = -1, dtype = torch.float32) 211 | attn = attn.type(sim.dtype) 212 | 213 | attn = self.dropout(attn) 214 | 215 | # aggregate values 216 | out = torch.einsum('b h i j, b j d -> b h i d', attn, v) 217 | 218 | out = rearrange(out, 'b h n d -> b n (h d)') 219 | return self.to_out(out) 220 | 221 | 222 | # Self and Cross Attention mechanism (Checked) 223 | class CrossAttention(nn.Module): 224 | '''General implementation of Cross & Self Attention multi-head 225 | ''' 226 | def __init__(self, embed_dim, hidden_dim, context_dim=None, num_heads=8, ): 227 | super(CrossAttention, self).__init__() 228 | self.hidden_dim = hidden_dim 229 | self.context_dim = context_dim 230 | self.embed_dim = embed_dim 231 | self.num_heads = num_heads 232 | self.head_dim = embed_dim // num_heads 233 | self.to_q = nn.Linear(hidden_dim, embed_dim, bias=False) 234 | if context_dim is None: 235 | # Self Attention 236 | self.to_k = nn.Linear(hidden_dim, embed_dim, bias=False) 237 | self.to_v = nn.Linear(hidden_dim, embed_dim, bias=False) 238 | self.self_attn = True 239 | else: 240 | # Cross Attention 241 | self.to_k = nn.Linear(context_dim, embed_dim, bias=False) 242 | self.to_v = nn.Linear(context_dim, embed_dim, bias=False) 243 | self.self_attn = False 244 | self.to_out = nn.Sequential( 245 | nn.Linear(embed_dim, hidden_dim, bias=True) 246 | ) # this could be omitted 247 | 248 | def forward(self, tokens, context=None): 249 | Q = self.to_q(tokens) 250 | K = self.to_k(tokens) if self.self_attn else self.to_k(context) 251 | V = self.to_v(tokens) if self.self_attn else self.to_v(context) 252 | # print(Q.shape, K.shape, V.shape) 253 | # transform heads onto batch dimension 254 | Q = rearrange(Q, 'B T (H D) -> (B H) T D', H=self.num_heads, D=self.head_dim) 255 | K = rearrange(K, 'B T (H D) -> (B H) T D', H=self.num_heads, D=self.head_dim) 256 | V = rearrange(V, 'B T (H D) -> (B H) T D', H=self.num_heads, D=self.head_dim) 257 | # print(Q.shape, K.shape, V.shape) 258 | scoremats = torch.einsum("BTD,BSD->BTS", Q, K) 259 | attnmats = F.softmax(scoremats / math.sqrt(self.head_dim), dim=-1) 260 | # print(scoremats.shape, attnmats.shape, ) 261 | ctx_vecs = torch.einsum("BTS,BSD->BTD", attnmats, V) 262 | # split the heads transform back to hidden. 263 | ctx_vecs = rearrange(ctx_vecs, '(B H) T D -> B T (H D)', H=self.num_heads, D=self.head_dim) 264 | # TODO: note this `to_out` is also a linear layer, could be in principle merged into the to_value layer. 
265 | return self.to_out(ctx_vecs) 266 | 267 | 268 | 269 | # Define the U-Net downsampling and upsampling components 270 | class UnetDown(nn.Module): 271 | def __init__(self, in_channels, out_channels, type_attention): 272 | super(UnetDown, self).__init__() 273 | ''' 274 | process and downscale the image feature maps 275 | ''' 276 | layers = [ResidualConvBlock(in_channels, out_channels), nn.MaxPool2d(2)] 277 | attention = nn.Identity() 278 | if type_attention=='self': 279 | create_self_attn = lambda dim: RearrangeToSequence(Residual(Attention(dim))) 280 | attention = create_self_attn(out_channels) 281 | if type_attention=='cross': 282 | create_self_attn = lambda dim: RearrangeToSequence(Residual(CrossAttention(dim, dim))) 283 | attention = create_self_attn(out_channels) 284 | self.model = nn.Sequential(*[ResidualConvBlock(in_channels, out_channels), attention, nn.MaxPool2d(2)]) 285 | 286 | def forward(self, x): 287 | return self.model(x) 288 | 289 | 290 | class UnetUp(nn.Module): 291 | def __init__(self, in_channels, out_channels, type_attention): 292 | super(UnetUp, self).__init__() 293 | ''' 294 | process and upscale the image feature maps 295 | ''' 296 | attention = nn.Identity() 297 | if type_attention=="self": 298 | create_self_attn = lambda dim: RearrangeToSequence(Residual(Attention(dim))) 299 | attention = create_self_attn(out_channels) 300 | if type_attention=="cross": 301 | create_self_attn = lambda dim: RearrangeToSequence(Residual(CrossAttention(dim, dim))) 302 | attention = create_self_attn(out_channels) 303 | layers = [ 304 | nn.ConvTranspose2d(in_channels, out_channels, 2, 2), 305 | ResidualConvBlock(out_channels, out_channels), 306 | attention, 307 | ResidualConvBlock(out_channels, out_channels), 308 | ] 309 | self.model = nn.Sequential(*layers) 310 | 311 | def forward(self, x, skip): 312 | x = torch.cat((x, skip), 1) 313 | x = self.model(x) 314 | return x 315 | 316 | 317 | class EmbedFC(nn.Module): 318 | def __init__(self, input_dim, emb_dim): 319 | super(EmbedFC, self).__init__() 320 | ''' 321 | generic one layer FC NN for embedding things 322 | ''' 323 | self.input_dim = input_dim 324 | layers = [ 325 | nn.Linear(input_dim, emb_dim), 326 | nn.GELU(), 327 | nn.Linear(emb_dim, emb_dim), 328 | ] 329 | self.model = nn.Sequential(*layers) 330 | 331 | def forward(self, x): 332 | x = x.view(-1, self.input_dim) 333 | return self.model(x) 334 | 335 | 336 | class ContextUnet(nn.Module): 337 | def __init__(self, in_channels, n_feat = 256, n_classes=10, dataset="", type_attention=""): 338 | super(ContextUnet, self).__init__() 339 | 340 | self.in_channels = in_channels 341 | self.n_contexts = len(n_classes) 342 | self.n_feat = 2 * n_feat 343 | self.n_classes = n_classes 344 | 345 | self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True) 346 | 347 | self.down1 = UnetDown(n_feat, n_feat, type_attention) 348 | self.down2 = UnetDown(n_feat, 2 * n_feat, type_attention) 349 | 350 | self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU()) 351 | 352 | self.timeembed1 = EmbedFC(1, 2*n_feat) 353 | self.timeembed2 = EmbedFC(1, 1*n_feat) 354 | 355 | ### embedding shape 356 | self.dataset = dataset 357 | self.n_out1 = 2*n_feat 358 | self.n_out2 = n_feat 359 | 360 | self.contextembed1 = nn.ModuleList([EmbedFC(self.n_classes[iclass], self.n_out1) for iclass in range(len(self.n_classes))]) 361 | self.contextembed2 = nn.ModuleList([EmbedFC(self.n_classes[iclass], self.n_out2) for iclass in range(len(self.n_classes))]) 362 | 363 | 364 | n_conv = 7 365 | self.up0 = nn.Sequential( 366 | 
nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, n_conv, n_conv), 367 | nn.GroupNorm(8, 2 * n_feat), 368 | nn.ReLU(), 369 | ) 370 | 371 | self.up1 = UnetUp(4 * n_feat, n_feat, type_attention) 372 | self.up2 = UnetUp(2 * n_feat, n_feat, type_attention) 373 | self.out = nn.Sequential( 374 | nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1), 375 | nn.GroupNorm(8, n_feat), 376 | nn.ReLU(), 377 | nn.Conv2d(n_feat, self.in_channels, 3, 1, 1), 378 | ) 379 | 380 | def forward(self, x, c, t, context_mask=None): 381 | # x is (noisy) image, c is context label, t is timestep, 382 | 383 | x = self.init_conv(x) 384 | down1 = self.down1(x) 385 | down2 = self.down2(down1) 386 | hiddenvec = self.to_vec(down2) 387 | 388 | temb1 = self.timeembed1(t).view(-1, int(self.n_feat), 1, 1) 389 | temb2 = self.timeembed2(t).view(-1, int(self.n_feat/2), 1, 1) 390 | 391 | # embed context, time step 392 | cemb1 = 0 393 | cemb2 = 0 394 | for ic in range(len(self.n_classes)): 395 | tmpc = c[ic] 396 | if tmpc.dtype==torch.int64: 397 | tmpc = nn.functional.one_hot(tmpc, num_classes=self.n_classes[ic]).type(torch.float) 398 | cemb1 += self.contextembed1[ic](tmpc).view(-1, int(self.n_out1/1.), 1, 1) 399 | cemb2 += self.contextembed2[ic](tmpc).view(-1, int(self.n_out2/1.), 1, 1) 400 | 401 | up1 = self.up0(hiddenvec) 402 | up2 = self.up1(cemb1*up1 + temb1, down2) 403 | up3 = self.up2(cemb2*up2 + temb2, down1) 404 | out = self.out(torch.cat((up3, x), 1)) 405 | return out 406 | 407 | 408 | def ddpm_schedules(beta1, beta2, T): 409 | """ 410 | Returns pre-computed schedules for DDPM sampling, training process. 411 | """ 412 | assert beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)" 413 | 414 | beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1 415 | sqrt_beta_t = torch.sqrt(beta_t) 416 | alpha_t = 1 - beta_t 417 | log_alpha_t = torch.log(alpha_t) 418 | alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp() 419 | 420 | sqrtab = torch.sqrt(alphabar_t) 421 | oneover_sqrta = 1 / torch.sqrt(alpha_t) 422 | 423 | sqrtmab = torch.sqrt(1 - alphabar_t) 424 | mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab 425 | 426 | return { 427 | "alpha_t": alpha_t, # \alpha_t 428 | "oneover_sqrta": oneover_sqrta, # 1/\sqrt{\alpha_t} 429 | "sqrt_beta_t": sqrt_beta_t, # \sqrt{\beta_t} 430 | "alphabar_t": alphabar_t, # \bar{\alpha_t} 431 | "sqrtab": sqrtab, # \sqrt{\bar{\alpha_t}} 432 | "sqrtmab": sqrtmab, # \sqrt{1-\bar{\alpha_t}} 433 | "mab_over_sqrtmab": mab_over_sqrtmab_inv, # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}} 434 | } 435 | 436 | 437 | class DDPM(nn.Module): 438 | def __init__(self, nn_model, betas, n_T, device, drop_prob=0.1, n_classes=None, flag_weight=0): 439 | super(DDPM, self).__init__() 440 | self.nn_model = nn_model.to(device) 441 | self.n_classes = n_classes 442 | 443 | for k, v in ddpm_schedules(betas[0], betas[1], n_T).items(): 444 | self.register_buffer(k, v) 445 | 446 | self.betas = torch.linspace(betas[0], betas[1], n_T).to(device) 447 | self.alphas = 1. 
- self.betas 448 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 449 | 450 | self.n_T = n_T 451 | self.device = device 452 | self.flag_weight = flag_weight 453 | self.drop_prob = drop_prob 454 | self.loss_mse = nn.MSELoss() 455 | 456 | def forward(self, x, c): 457 | """ 458 | this method is used in training, so samples t and noise randomly 459 | """ 460 | 461 | _ts = torch.randint(1, self.n_T+1, (x.shape[0],)).to(self.device) # t ~ Uniform(0, n_T) 462 | noise = torch.randn_like(x) # eps ~ N(0, 1) 463 | 464 | x_t = ( 465 | self.sqrtab[_ts, None, None, None] * x 466 | + self.sqrtmab[_ts, None, None, None] * noise 467 | ) 468 | 469 | return self.loss_mse(noise, self.nn_model(x_t, c, _ts / self.n_T)) #, context_mask)) 470 | 471 | def sample(self, n_sample, c_gen, size, device, guide_w = 0.0): 472 | 473 | x_i = torch.randn(n_sample, *size).to(device) # x_T ~ N(0, 1), sample initial noise 474 | _c_gen = [tmpc_gen[:n_sample].to(device) for tmpc_gen in c_gen.values()] 475 | 476 | #context_mask = torch.zeros_like(_c_gen[0]).to(device) 477 | 478 | x_i_store = [] 479 | print() 480 | for i in range(self.n_T, 0, -1): 481 | print(f'sampling timestep {i}',end='\r') 482 | t_is = torch.tensor([i / self.n_T]).to(device) 483 | t_is = t_is.repeat(n_sample,1,1,1) 484 | 485 | z = torch.randn(n_sample, *size).to(device) if i > 1 else 0 486 | eps = self.nn_model(x_i, _c_gen, t_is) #, context_mask) 487 | x_i = ( 488 | self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i]) 489 | + self.sqrt_beta_t[i] * z 490 | ) 491 | if i%20==0: 492 | x_i_store.append(x_i.detach().cpu().numpy()) 493 | 494 | x_i_store = np.array(x_i_store) 495 | return x_i, x_i_store 496 | 497 | 498 | def ddim_step(self, x_t, t, noise_pred): 499 | """ 500 | DDIM step to predict the next state of the image. 501 | """ 502 | alpha_t = self.alphas_cumprod[t] 503 | alpha_t_1 = torch.where(t > 0, self.alphas_cumprod[t-1], torch.tensor(1.0).to(self.device)) 504 | sigma_t = torch.sqrt((1 - alpha_t_1) / (1 - alpha_t) * (1 - alpha_t / alpha_t_1)) 505 | alpha_t = alpha_t.view(-1,1,1,1) 506 | sigma_t = sigma_t.view(-1,1,1,1) 507 | alpha_t_1 = alpha_t_1.view(-1,1,1,1) 508 | 509 | x_0_pred = (x_t - sigma_t * noise_pred) / torch.sqrt(alpha_t) 510 | x_t_1 = torch.sqrt(alpha_t_1) * x_0_pred + sigma_t * torch.randn_like(x_t) 511 | return x_t_1 512 | 513 | def sample_ddim(self, n_sample, c_gen, size, device): 514 | """ 515 | Sample using the DDIM scheduler. 
516 | """ 517 | x_t = torch.randn(n_sample, *size).to(device) # Initialize with noise 518 | 519 | _c_gen = {k: v.to(device) for k, v in c_gen.items()} 520 | 521 | x_i_store = [] 522 | for i in reversed(range(0, self.n_T)): 523 | print(f'sampling timestep {i}',end='\r') 524 | t = torch.full((n_sample,), i, device=device, dtype=torch.long) 525 | noise_pred = self.nn_model(x_t, _c_gen, t.float() / self.n_T) 526 | x_t = self.ddim_step(x_t, t, noise_pred) 527 | 528 | if i%20==0: 529 | x_i_store.append(x_t.detach().cpu().numpy()) 530 | 531 | x_i_store = np.array(x_i_store) 532 | return x_t, x_i_store 533 | 534 | 535 | 536 | def training(args): 537 | 538 | n_epoch = args.n_epoch 539 | batch_size = args.batch_size 540 | n_T = args.n_T 541 | n_feat = args.n_feat 542 | lrate = args.lrate 543 | alpha = args.alpha 544 | beta = args.beta 545 | test_size = args.test_size 546 | dataset = args.dataset 547 | num_samples = args.num_samples 548 | pixel_size = args.pixel_size 549 | experiment = args.experiment 550 | n_sample = args.n_sample 551 | type_attention = args.type_attention 552 | remove_node = args.remove_node 553 | seed = args.seed 554 | scheduler = args.scheduler 555 | in_channels = 3 if "celeba" in dataset else 4 556 | 557 | 558 | torch.manual_seed(seed) 559 | if torch.cuda.is_available(): 560 | torch.cuda.manual_seed(seed) 561 | torch.cuda.manual_seed_all(seed) 562 | torch.backends.cudnn.deterministic = True 563 | torch.backends.cudnn.benchmark = False 564 | np.random.seed(seed) 565 | random.seed(seed) 566 | 567 | 568 | with open("config_category.json", 'r') as f: 569 | configs = json.load(f)[experiment] 570 | 571 | 572 | experiment_classes = { 573 | "H42-train1": [2, 3, 1, 1], 574 | "H22-train1": [2, 2], 575 | "default": [2, 3, 1] 576 | } 577 | n_classes = experiment_classes.get(experiment, experiment_classes["default"]) 578 | 579 | if "celeba" in dataset: 580 | n_classes = [2,2,2] 581 | 582 | tf = transforms.Compose([transforms.Resize((pixel_size,pixel_size)), transforms.ToTensor()]) 583 | 584 | 585 | save_dir = './output/'+dataset+'/'+experiment+'/' 586 | if not os.path.isdir(save_dir): os.makedirs(save_dir) 587 | save_dir = save_dir + str(num_samples) + "_" + str(test_size) + "_" + str(n_feat) + "_" + str(n_T) + "_" + str(n_epoch) \ 588 | + "_" + str(lrate) + "_" + remove_node + "_" + str(alpha) + "_" + str(beta) + "_" + str(seed) + "/" #+ str(type_attention) + "/" 589 | if not os.path.isdir(save_dir): os.makedirs(save_dir) 590 | 591 | ddpm = DDPM(nn_model=ContextUnet(in_channels=in_channels, n_feat=n_feat, n_classes=n_classes, dataset=dataset, type_attention=type_attention), 592 | betas=(lrate, 0.02), n_T=n_T, device=device, drop_prob=0.1, n_classes=n_classes) 593 | ddpm.to(device) 594 | 595 | 596 | train_dataset = load_dataset.my_dataset(tf, num_samples, dataset, configs=configs["train"], training=True, alpha=alpha, remove_node=remove_node) 597 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1) 598 | 599 | 600 | test_dataloaders = {} 601 | log_dict = {'train_loss_per_batch': [], 602 | 'test_loss_per_batch': {key: [] for key in configs["test"]}} 603 | output_configs = list(set(configs["test"] + configs["train"])) 604 | for config in output_configs: 605 | test_dataset = load_dataset.my_dataset(tf, n_sample, dataset, configs=config, training=False, test_size=test_size) 606 | test_dataloaders[config] = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1) 607 | 608 | optim = torch.optim.Adam(ddpm.parameters(), lr=lrate) 609 
| 
610 | for ep in range(n_epoch):
611 | print(f'epoch {ep}')
612 | 
613 | ddpm.train()
614 | 
615 | # linear lrate decay
616 | optim.param_groups[0]['lr'] = lrate*(1-ep/n_epoch)
617 | 
618 | pbar = tqdm(train_dataloader)
619 | for x, c in pbar:
620 | optim.zero_grad()
621 | x = x.to(device)
622 | _c = [tmpc.to(device) for tmpc in c.values()]
623 | loss = ddpm(x, _c)
624 | log_dict['train_loss_per_batch'].append(loss.item())
625 | loss.backward()
626 | loss_ema = loss.item()
627 | pbar.set_description(f"loss: {loss_ema:.4f}")
628 | optim.step()
629 | 
630 | 
631 | ddpm.eval()
632 | with torch.no_grad():
633 | 
634 | for test_config in configs["test"]:
635 | for test_x, test_c in test_dataloaders[test_config]:
636 | test_x = test_x.to(device)
637 | _test_c = [tmptest_c.to(device) for tmptest_c in test_c.values()]
638 | test_loss = ddpm(test_x, _test_c)
639 | log_dict['test_loss_per_batch'][test_config].append(test_loss.item())
640 | 
641 | if (ep + 1) % 100 == 0 or ep >= (n_epoch - 5):
642 | for test_config in output_configs:
643 | x_real, c_gen = next(iter(test_dataloaders[test_config]))
644 | x_real = x_real[:n_sample].to(device)
645 | if scheduler=="DDIM":
646 | x_gen, x_gen_store = ddpm.sample_ddim(n_sample, c_gen, (in_channels, pixel_size, pixel_size), device)
647 | else:
648 | x_gen, x_gen_store = ddpm.sample(n_sample, c_gen, (in_channels, pixel_size, pixel_size), device, guide_w=0.0)
649 | np.savez_compressed(save_dir + f"image_"+test_config+"_ep"+str(ep)+".npz", x_gen=x_gen.detach().cpu().numpy())
650 | print('saved image at ' + save_dir + f"image_"+test_config+"_ep"+str(ep)+".npz")
651 | 
652 | if ep + 1 == n_epoch:
653 | np.savez_compressed(save_dir + f"gen_store_"+test_config+"_ep"+str(ep)+".npz", x_gen_store=x_gen_store)
654 | print('saved image file at ' + save_dir + f"gen_store_"+test_config+"_ep"+str(ep)+".npz")
655 | 
656 | 
657 | if (ep + 1) == n_epoch:
658 | with open(save_dir + f"training_log_"+str(ep)+".json", "w") as outfile:
659 | json.dump(log_dict, outfile)
660 | 
661 | 
662 | 
663 | if __name__ == "__main__":
664 | args = parser.parse_args()
665 | training(args)
666 | 
667 | 
668 | 
--------------------------------------------------------------------------------
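
For reference, the schedule buffers precomputed in `ddpm_schedules` and the updates used in `DDPM.forward` and `DDPM.sample` above correspond to the standard DDPM noising and ancestral-sampling equations (with $T$ = `n_T`, $\epsilon_\theta$ the `ContextUnet`, and $z = 0$ at the final step):

$$\beta_t = \beta_1 + \frac{t}{T}(\beta_2 - \beta_1), \qquad \alpha_t = 1 - \beta_t, \qquad \bar{\alpha}_t = \prod_{s \le t} \alpha_s$$

$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \mathcal{L} = \lVert \epsilon - \epsilon_\theta(x_t, c, t/T) \rVert^2$$

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \Bigl( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, c, t/T) \Bigr) + \sqrt{\beta_t}\, z, \qquad z \sim \mathcal{N}(0, I)$$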