├── dataloaders
│   ├── __init__.py
│   ├── utils.py
│   ├── wrapper.py
│   ├── datasetGen.py
│   └── base.py
├── imgs
│   └── teaser_imagenet.png
├── guide
│   ├── __init__.py
│   ├── fake_mpi.py
│   ├── validation.py
│   ├── losses.py
│   ├── layers.py
│   ├── dist_util.py
│   ├── respace.py
│   ├── nn.py
│   ├── resample.py
│   ├── script_args.py
│   ├── resnet.py
│   ├── fp16_util.py
│   ├── script_util.py
│   ├── logger.py
│   └── evaluator.py
├── .gitignore
├── setup.py
├── scripts
│   ├── mrunner_train.py
│   ├── image_sample.py
│   ├── classifier_sample_universal.py
│   └── image_train.py
├── cl_methods
│   ├── utils.py
│   ├── base.py
│   ├── generative_replay_disjoint_classifier_guidance.py
│   └── generative_replay.py
└── README.md

/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/imgs/teaser_imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cywinski/guide/HEAD/imgs/teaser_imagenet.png
--------------------------------------------------------------------------------
/guide/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Codebase for "Improved Denoising Diffusion Probabilistic Models".
3 | """
4 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | guide.egg-info/
2 | .DS_Store
3 | __pycache__/
4 | classify_image_graph_def.pb
5 | .idea
6 | data
7 | results
8 | wandb
9 | .vscode
10 | venv
11 | *.out
12 | slurm_out
13 | *.pt
14 | *.csv
15 | .ipynb_checkpoints
16 | *.ipynb
17 | *.gif
18 | *.npz
19 | results*/
20 | *.txt
21 | *.pdf
--------------------------------------------------------------------------------
/guide/fake_mpi.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace
2 | 
3 | 
4 | class CommWorld:
5 |     def __init__(self):
6 |         self.rank = 0
7 |         self.size = 1
8 | 
9 |     def Get_rank(self):
10 |         return 0
11 | 
12 |     def Get_size(self):
13 |         return 1
14 | 
15 |     def bcast(self, value, root=0):
16 |         return value
17 | 
18 | 
19 | MPI = Namespace(COMM_WORLD=CommWorld())
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | setup(
4 |     name="guide",
5 |     py_modules=["guide"],
6 |     install_requires=[
7 |         "blobfile>=1.0.5",
8 |         "tqdm==4.66.2",
9 |         "scikit-learn==1.3.0",
10 |         "torchmetrics==1.3.2",
11 |         "kornia==0.7.1",
12 |         "wandb==0.16.5",
13 |         "matplotlib==3.7.5",
14 |         "pytz==2024.1",
15 |     ],
16 | )
--------------------------------------------------------------------------------
/scripts/mrunner_train.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace
2 | 
3 | from mrunner.helpers.client_helper import get_configuration
4 | 
5 | from scripts.image_train import run_training_with_args
6 | 
7 | 
8 | def main():
9 |     params = get_configuration(print_diagnostics=True, with_neptune=False)
10 |     run_training_with_args(Namespace(**params))
11 | 
12 | 
13 | if __name__ == "__main__":
14 |     main()
--------------------------------------------------------------------------------
/cl_methods/utils.py:
--------------------------------------------------------------------------------
1 | from cl_methods.generative_replay import GenerativeReplay
2 | from cl_methods.generative_replay_disjoint_classifier_guidance import (
3 |     GenerativeReplayDisjointClassifierGuidance,
4 | )
5 | 
6 | 
7 | def get_cl_method(args):
8 |     if args.cl_method == "generative_replay":
9 |         return GenerativeReplay(args)
10 |     if args.cl_method == "generative_replay_disjoint_classifier_guidance":
11 |         return GenerativeReplayDisjointClassifierGuidance(args)
12 |     raise ValueError(f"Unknown cl_method: {args.cl_method}")
13 | 
--------------------------------------------------------------------------------
/cl_methods/base.py:
--------------------------------------------------------------------------------
1 | """CL-method-specific data preparation."""
2 | 
3 | from torch.utils.data import DataLoader
4 | 
5 | from dataloaders.utils import yielder
6 | 
7 | 
8 | class CLMethod:
9 |     def __init__(self, args):
10 |         self.args = args
11 | 
12 |     def get_data_for_task(self, dataset, task_id, train_loop):
13 |         loader = DataLoader(
14 |             dataset=dataset,
15 |             batch_size=self.args.batch_size,
16 |             shuffle=True,
17 |             drop_last=True,
18 |         )
19 |         return yielder(loader), loader, None
20 | 
--------------------------------------------------------------------------------
/cl_methods/generative_replay_disjoint_classifier_guidance.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | 
3 | from cl_methods.base import CLMethod
4 | from dataloaders.utils import yielder
5 | 
6 | 
7 | class GenerativeReplayDisjointClassifierGuidance(CLMethod):
8 |     def get_data_for_task(
9 |         self,
10 |         dataset,
11 |         task_id,
12 |         train_loop,
13 |         generator=None,
14 |         step=None,
15 |     ):
16 |         loader = DataLoader(
17 |             dataset=dataset,
18 |             batch_size=self.args.batch_size // (task_id + 1),
19 |             shuffle=True,
20 |             drop_last=True,
21 |             generator=generator,
22 |         )
23 |         return yielder(loader), loader, None
24 | 
--------------------------------------------------------------------------------
/guide/validation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | 
5 | @torch.no_grad()
6 | def calculate_accuracy_with_classifier(
7 |     model,
8 |     task_id,
9 |     val_loader,
10 |     device,
11 |     train_loader=None,
12 |     max_class=0,
13 |     train_with_disjoint_classifier=False,
14 | ):
15 |     model.eval()
16 |     loader = (
17 |         {"test": val_loader, "train": train_loader}
18 |         if train_loader is not None
19 |         else {"test": val_loader}
20 |     )
21 |     correct = {"test": 0.0, "train": 0.0}
22 |     total = {"test": 0.0, "train": 0.0}
23 |     loss = {"test": 0.0, "train": 0.0}
24 |     print("Calculating accuracy:")
25 |     for phase in loader.keys():
26 |         for idx, batch in enumerate(loader[phase]):
27 |             x, cond = batch
28 |             x = x.to(device)
29 |             y = cond["y"].to(device)
30 |             out_classifier = model(x)
31 |             preds = torch.argmax(out_classifier, 1)
32 |             correct[phase] += (preds == torch.argmax(y, 1)).sum()
33 |             total[phase] += len(y)
34 |             loss[phase] += F.cross_entropy(
35 |                 out_classifier[:, : max_class + 1], y[:, : max_class + 1]
36 |             )
37 |         loss[phase] /= idx + 1  # idx is 0-based, so idx + 1 batches were accumulated
38 |         correct[phase] /= total[phase]
39 |     model.train()
40 |     return {
41 |         "loss": loss,
42 |         "accuracy": correct,
43 |     }
44 | 
--------------------------------------------------------------------------------
/dataloaders/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import ConcatDataset, DataLoader
4 | 
5 | 
6 | def yielder(loader):
7 |     while 
True: 8 | yield from loader 9 | 10 | 11 | def concat_yielder(loaders): 12 | iterators = [iter(loader) for loader in loaders] 13 | while True: 14 | result = [] 15 | for i in range(len(iterators)): 16 | try: 17 | x = next(iterators[i]) 18 | except StopIteration: 19 | iterators[i] = iter(loaders[i]) 20 | x = next(iterators[i]) 21 | result.append(x) 22 | 23 | yield recursive_concat(result) 24 | 25 | 26 | def recursive_concat(l): 27 | """Concat elements in list l. Each element can be a tensor, dicts of tensors, tuple, etc.""" 28 | if isinstance(l[0], torch.Tensor): 29 | return torch.cat(l) 30 | if isinstance(l[0], dict): 31 | keys = set(l[0].keys()) 32 | for x in l[1:]: 33 | assert set(x.keys()) == keys 34 | return {k: recursive_concat([x[k] for x in l]) for k in keys} 35 | if isinstance(l[0], tuple) or isinstance(l[0], list): 36 | length = len(l[0]) 37 | for x in l[1:]: 38 | assert len(x) == length 39 | return tuple([recursive_concat([x[i] for x in l]) for i in range(length)]) 40 | 41 | 42 | def get_stratified_subset(frac_selected, labels, seed=0): 43 | """Returns indices of a subset with a given percentage of elements for each class.""" 44 | labels = np.array(labels) 45 | rng = np.random.default_rng(seed=seed) 46 | res = [] 47 | for l in np.unique(labels): 48 | all_indices = np.nonzero(labels == l)[0] 49 | num_selected = int(frac_selected * len(all_indices)) 50 | res.append(rng.choice(all_indices, num_selected, replace=False)) 51 | res = np.concatenate(res) 52 | return res 53 | 54 | 55 | def prepare_eval_loaders( 56 | train_dataset_splits, val_dataset_splits, args, include_train, generator=False 57 | ): 58 | eval_loaders = [] 59 | for task_id in range(args.num_tasks): 60 | if include_train: 61 | eval_data = ConcatDataset( 62 | [train_dataset_splits[task_id], val_dataset_splits[task_id]] 63 | ) 64 | else: 65 | eval_data = val_dataset_splits[task_id] 66 | eval_loader = DataLoader( 67 | dataset=eval_data, 68 | batch_size=args.batch_size, 69 | shuffle=False, 70 | generator=generator, 71 | ) 72 | eval_loaders.append(eval_loader) 73 | 74 | return eval_loaders 75 | -------------------------------------------------------------------------------- /guide/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | import torch as th 9 | 10 | 11 | def normal_kl(mean1, logvar1, mean2, logvar2): 12 | """ 13 | Compute the KL divergence between two gaussians. 14 | 15 | Shapes are automatically broadcasted, so batches can be compared to 16 | scalars, among other use cases. 17 | """ 18 | tensor = None 19 | for obj in (mean1, logvar1, mean2, logvar2): 20 | if isinstance(obj, th.Tensor): 21 | tensor = obj 22 | break 23 | assert tensor is not None, "at least one argument must be a Tensor" 24 | 25 | # Force variances to be Tensors. Broadcasting helps convert scalars to 26 | # Tensors, but it does not work for th.exp(). 
27 | logvar1, logvar2 = [ 28 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 29 | for x in (logvar1, logvar2) 30 | ] 31 | 32 | return 0.5 * ( 33 | -1.0 34 | + logvar2 35 | - logvar1 36 | + th.exp(logvar1 - logvar2) 37 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 38 | ) 39 | 40 | 41 | def approx_standard_normal_cdf(x): 42 | """ 43 | A fast approximation of the cumulative distribution function of the 44 | standard normal. 45 | """ 46 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 47 | 48 | 49 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 50 | """ 51 | Compute the log-likelihood of a Gaussian distribution discretizing to a 52 | given image. 53 | 54 | :param x: the target images. It is assumed that this was uint8 values, 55 | rescaled to the range [-1, 1]. 56 | :param means: the Gaussian mean Tensor. 57 | :param log_scales: the Gaussian log stddev Tensor. 58 | :return: a tensor like x of log probabilities (in nats). 59 | """ 60 | assert x.shape == means.shape == log_scales.shape 61 | centered_x = x - means 62 | inv_stdv = th.exp(-log_scales) 63 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 64 | cdf_plus = approx_standard_normal_cdf(plus_in) 65 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 66 | cdf_min = approx_standard_normal_cdf(min_in) 67 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 68 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 69 | cdf_delta = cdf_plus - cdf_min 70 | log_probs = th.where( 71 | x < -0.999, 72 | log_cdf_plus, 73 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 74 | ) 75 | assert log_probs.shape == x.shape 76 | return log_probs 77 | -------------------------------------------------------------------------------- /cl_methods/generative_replay.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch.utils.data import ConcatDataset, DataLoader, TensorDataset 3 | 4 | from cl_methods.base import CLMethod 5 | from dataloaders.utils import yielder 6 | from dataloaders.wrapper import AppendName 7 | from guide import dist_util 8 | from guide.logger import get_rank_without_mpi_import, log_generated_examples 9 | 10 | 11 | class GenerativeReplay(CLMethod): 12 | def get_data_for_task( 13 | self, 14 | dataset, 15 | task_id, 16 | train_loop, 17 | generator=None, 18 | step=None, 19 | ): 20 | if task_id == 0: 21 | train_dataset_loader = DataLoader( 22 | dataset=dataset, 23 | batch_size=self.args.batch_size, 24 | shuffle=True, 25 | drop_last=True, 26 | generator=generator, 27 | ) 28 | dataset_yielder = yielder(train_dataset_loader) 29 | else: 30 | print("Preparing dataset for rehearsal...") 31 | if self.args.gr_n_generated_examples_per_task <= self.args.batch_size: 32 | batch_size = self.args.gr_n_generated_examples_per_task 33 | else: 34 | batch_size = self.args.batch_size 35 | ( 36 | generated_previous_examples, 37 | generated_previous_examples_labels, 38 | generated_previous_examples_confidences, 39 | ) = train_loop.generate_examples( 40 | task_id - 1, 41 | self.args.gr_n_generated_examples_per_task, 42 | batch_size=batch_size, 43 | equal_n_examples_per_class=True, 44 | use_old_grad=False, 45 | use_new_grad=False, 46 | ) 47 | generated_dataset = AppendName( 48 | TensorDataset( 49 | generated_previous_examples, generated_previous_examples_labels 50 | ), 51 | generated_previous_examples_labels.cpu().numpy(), 52 | True, 53 | False, 54 | ) 55 | joined_dataset = ConcatDataset([dataset, 
generated_dataset]) 56 | train_dataset_loader = DataLoader( 57 | dataset=joined_dataset, 58 | batch_size=self.args.batch_size, 59 | shuffle=True, 60 | drop_last=True, 61 | generator=generator, 62 | ) 63 | dataset_yielder = yielder(train_dataset_loader) 64 | if get_rank_without_mpi_import() == 0: 65 | log_generated_examples( 66 | generated_previous_examples, 67 | th.argmax(generated_previous_examples_labels, 1), 68 | generated_previous_examples_confidences, 69 | task_id, 70 | step=step, 71 | ) 72 | 73 | return ( 74 | dataset_yielder, 75 | train_dataset_loader, 76 | generated_previous_examples if task_id != 0 else None, 77 | ) 78 | -------------------------------------------------------------------------------- /guide/layers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from torch import nn 4 | 5 | 6 | class ConvBlock(nn.Module): 7 | def __init__( 8 | self, 9 | opt, 10 | in_channels, 11 | out_channels, 12 | kernel_size, 13 | stride=1, 14 | padding=0, 15 | bias=False, 16 | groups=1, 17 | ): 18 | super(ConvBlock, self).__init__() 19 | self.in_channels = in_channels 20 | self.out_channels = out_channels 21 | self.kernel_size = kernel_size 22 | conv = nn.Conv2d( 23 | in_channels, 24 | out_channels, 25 | kernel_size=kernel_size, 26 | stride=stride, 27 | padding=padding, 28 | bias=bias, 29 | groups=groups, 30 | ) 31 | 32 | layer = [conv] 33 | if opt.bn: 34 | if opt.preact: 35 | bn = getattr(nn, opt.normtype + "2d")( 36 | num_features=in_channels, affine=opt.affine_bn, eps=opt.bn_eps 37 | ) 38 | layer = [bn] 39 | else: 40 | bn = getattr(nn, opt.normtype + "2d")( 41 | num_features=out_channels, affine=opt.affine_bn, eps=opt.bn_eps 42 | ) 43 | layer = [conv, bn] 44 | 45 | if opt.activetype != "None": 46 | active = getattr(nn, opt.activetype)() 47 | layer.append(active) 48 | 49 | if opt.bn and opt.preact: 50 | layer.append(conv) 51 | 52 | self.block = nn.Sequential(*layer) 53 | 54 | def forward(self, input): 55 | return self.block.forward(input) 56 | 57 | 58 | class FCBlock(nn.Module): 59 | def __init__(self, opt, in_channels, out_channels, bias=False): 60 | super(FCBlock, self).__init__() 61 | self.in_channels = in_channels 62 | self.out_channels = out_channels 63 | lin = nn.Linear(in_channels, out_channels, bias=bias) 64 | 65 | layer = [lin] 66 | if opt.bn: 67 | if opt.preact: 68 | bn = getattr(nn, opt.normtype + "1d")( 69 | num_features=in_channels, affine=opt.affine_bn, eps=opt.bn_eps 70 | ) 71 | layer = [bn] 72 | else: 73 | bn = getattr(nn, opt.normtype + "1d")( 74 | num_features=out_channels, affine=opt.affine_bn, eps=opt.bn_eps 75 | ) 76 | layer = [lin, bn] 77 | 78 | if opt.activetype != "None": 79 | active = getattr(nn, opt.activetype)() 80 | layer.append(active) 81 | 82 | if opt.bn and opt.preact: 83 | layer.append(lin) 84 | 85 | self.block = nn.Sequential(*layer) 86 | 87 | def forward(self, input): 88 | return self.block.forward(input) 89 | 90 | 91 | def FinalBlock(opt, in_channels, bias=False): 92 | out_channels = opt.model_num_classes 93 | opt = copy.deepcopy(opt) 94 | if not opt.preact: 95 | opt.activetype = "None" 96 | return FCBlock( 97 | opt=opt, in_channels=in_channels, out_channels=out_channels, bias=bias 98 | ) 99 | 100 | 101 | def InitialBlock(opt, out_channels, kernel_size, stride=1, padding=0, bias=False): 102 | in_channels = opt.in_channels 103 | opt = copy.deepcopy(opt) 104 | return ConvBlock( 105 | opt=opt, 106 | in_channels=in_channels, 107 | out_channels=out_channels, 108 | kernel_size=kernel_size, 109 | 
stride=stride, 110 | padding=padding, 111 | bias=bias, 112 | ) 113 | -------------------------------------------------------------------------------- /guide/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | import warnings 9 | 10 | import blobfile as bf 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | try: 15 | from mpi4py import MPI 16 | except ImportError: 17 | from guide.fake_mpi import MPI 18 | 19 | warnings.warn("Using fake MPI!") 20 | 21 | # Change this to reflect your cluster layout. 22 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 23 | GPUS_PER_NODE = 8 24 | # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 25 | SETUP_RETRY_COUNT = 3 26 | # GPU_ID = "0" 27 | 28 | 29 | def setup_dist(args): 30 | """ 31 | Setup a distributed process group. 32 | """ 33 | # global GPU_ID 34 | # if args.gpu_id == -1: 35 | # os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 36 | # GPU_ID = "" 37 | # elif args.gpu_id != -2: 38 | # # GPU_ID = f":{args.gpu_id}" 39 | # os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id) 40 | 41 | if dist.is_initialized(): 42 | return 43 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 44 | print("visible devices:", os.environ["CUDA_VISIBLE_DEVICES"]) 45 | comm = MPI.COMM_WORLD 46 | backend = "gloo" if not th.cuda.is_available() else "nccl" 47 | 48 | if backend == "gloo": 49 | hostname = "localhost" 50 | else: 51 | hostname = socket.gethostbyname(socket.getfqdn()) 52 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 53 | os.environ["RANK"] = str(comm.rank) 54 | os.environ["WORLD_SIZE"] = str(comm.size) 55 | print("world size:", os.environ["WORLD_SIZE"]) 56 | 57 | port = comm.bcast(_find_free_port(), root=0) 58 | os.environ["MASTER_PORT"] = str(port) 59 | dist.init_process_group(backend=backend, init_method="env://") 60 | 61 | 62 | def dev(): 63 | """ 64 | Get the device to use for torch.distributed. 65 | """ 66 | # global GPU_ID 67 | if th.cuda.is_available(): 68 | # return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}") 69 | return th.device("cuda") 70 | return th.device("cpu") 71 | 72 | 73 | def load_state_dict(path, **kwargs): 74 | """ 75 | Load a PyTorch file without redundant fetches across MPI ranks. 76 | """ 77 | chunk_size = 2**30 # MPI has a relatively small size limit 78 | if MPI.COMM_WORLD.Get_rank() == 0: 79 | with bf.BlobFile(path, "rb") as f: 80 | data = f.read() 81 | num_chunks = len(data) // chunk_size 82 | if len(data) % chunk_size: 83 | num_chunks += 1 84 | MPI.COMM_WORLD.bcast(num_chunks) 85 | for i in range(0, len(data), chunk_size): 86 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 87 | else: 88 | num_chunks = MPI.COMM_WORLD.bcast(None) 89 | data = bytes() 90 | for _ in range(num_chunks): 91 | data += MPI.COMM_WORLD.bcast(None) 92 | 93 | return th.load(io.BytesIO(data), **kwargs) 94 | 95 | 96 | def sync_params(params): 97 | """ 98 | Synchronize a sequence of Tensors across ranks from rank 0. 
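    A typical call (illustrative, not in the original file) is
    sync_params(model.parameters()) once after the model has been constructed,
    so that every rank starts from the weights held by rank 0.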
99 | """ 100 | # if GPU_ID!="": 101 | # rank = int(GPU_ID[-1:]) 102 | # else: 103 | # rank = 0 104 | # if GPU_ID == "": 105 | for p in params: 106 | with th.no_grad(): 107 | dist.broadcast(p, 0) 108 | 109 | 110 | def _find_free_port(): 111 | try: 112 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 113 | s.bind(("", 0)) 114 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 115 | return s.getsockname()[1] 116 | finally: 117 | s.close() 118 | -------------------------------------------------------------------------------- /scripts/image_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch as th 12 | import torch.distributed as dist 13 | from torchvision.utils import make_grid 14 | 15 | from guide import dist_util, logger 16 | from guide.script_args import ( 17 | add_dict_to_argparser, 18 | all_training_defaults, 19 | args_to_dict, 20 | preprocess_args, 21 | ) 22 | from guide.script_util import create_model_and_diffusion, model_and_diffusion_defaults 23 | 24 | 25 | def main(): 26 | args = create_argparser().parse_args() 27 | preprocess_args(args) 28 | 29 | os.environ["OPENAI_LOGDIR"] = f"sampled/{args.wandb_experiment_name}" 30 | 31 | dist_util.setup_dist(args) 32 | logger.configure() 33 | 34 | logger.log("creating model and diffusion...") 35 | model, diffusion = create_model_and_diffusion( 36 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 37 | ) 38 | model.load_state_dict( 39 | dist_util.load_state_dict(args.model_path, map_location="cpu") 40 | ) 41 | model.to(dist_util.dev()) 42 | if args.use_fp16: 43 | model.convert_to_fp16() 44 | model.eval() 45 | 46 | logger.log("sampling...") 47 | all_images = [] 48 | all_labels = [] 49 | while len(all_images) < args.num_samples: 50 | model_kwargs = {} 51 | if args.class_cond: 52 | classes = th.randint( 53 | low=0, 54 | high=args.num_classes, 55 | size=(args.batch_size,), 56 | device=dist_util.dev(), 57 | ) 58 | model_kwargs["y"] = classes 59 | sample_fn = ( 60 | diffusion.ddim_sample_loop if args.use_ddim else diffusion.p_sample_loop 61 | ) 62 | sample = sample_fn( 63 | model, 64 | (args.batch_size, args.in_channels, args.image_size, args.image_size), 65 | clip_denoised=args.clip_denoised, 66 | model_kwargs=model_kwargs, 67 | ) 68 | all_images.extend(sample.cpu().numpy()) 69 | # gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 70 | # dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 71 | # all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 72 | # if args.class_cond: 73 | # gathered_labels = [ 74 | # th.zeros_like(classes) for _ in range(dist.get_world_size()) 75 | # ] 76 | # dist.all_gather(gathered_labels, classes) 77 | # all_labels.extend([labels.cpu().numpy() for labels in gathered_labels]) 78 | logger.log(f"created {len(all_images)} samples") 79 | 80 | arr = np.concatenate([all_images]) 81 | arr = arr[: args.num_samples] 82 | if args.class_cond: 83 | label_arr = np.concatenate(all_labels, axis=0) 84 | label_arr = label_arr[: args.num_samples] 85 | if dist.get_rank() == 0: 86 | shape_str = "x".join([str(x) for x in arr.shape]) 87 | out_path = os.path.join( 88 | logger.get_dir(), 
f"samples_{args.model_path.split('/')[-1][:-3]}.npz" 89 | ) 90 | logger.log(f"saving to {out_path}") 91 | if args.class_cond: 92 | np.savez(out_path, arr, label_arr) 93 | else: 94 | np.savez(out_path, arr) 95 | 96 | dist.barrier() 97 | plt.figure() 98 | plt.axis("off") 99 | samples_grid = make_grid(th.from_numpy(arr[:16]), 4, normalize=True).permute( 100 | 1, 2, 0 101 | ) 102 | plt.imshow(samples_grid) 103 | out_plot = os.path.join( 104 | logger.get_dir(), f"samples_{args.model_path.split('/')[-1][:-3]}" 105 | ) 106 | plt.savefig(out_plot) 107 | logger.log("sampling complete") 108 | 109 | 110 | def create_argparser(): 111 | defaults = all_training_defaults() 112 | defaults.update(model_and_diffusion_defaults()) 113 | defaults.update( 114 | dict( 115 | model_path="", 116 | num_samples=2, 117 | image_size=28, 118 | in_channels=1, 119 | model_num_classes=10, 120 | batch_size=16, 121 | wandb_experiment_name="test", 122 | ) 123 | ) 124 | parser = argparse.ArgumentParser() 125 | add_dict_to_argparser(parser, defaults) 126 | return parser 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /dataloaders/wrapper.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from os import path 3 | 4 | import numpy as np 5 | import torch 6 | import torch.utils.data as data 7 | 8 | 9 | class CacheClassLabel(data.Dataset): 10 | """ 11 | A dataset wrapper that has a quick access to all labels of data. 12 | """ 13 | 14 | def __init__(self, dataset, target_transform=None): 15 | super(CacheClassLabel, self).__init__() 16 | self.dataset = dataset 17 | self.labels = torch.zeros(len(dataset)).long() 18 | if target_transform: 19 | self.labels = target_transform(self.labels) 20 | label_cache_filename = path.join( 21 | dataset.root, str(type(dataset)) + "_" + str(len(dataset)) + ".pth" 22 | ) 23 | if path.exists(label_cache_filename): 24 | self.labels = torch.load(label_cache_filename) 25 | else: 26 | for i, data in enumerate(dataset): 27 | self.labels[i] = data[1] 28 | torch.save(self.labels, label_cache_filename) 29 | self.number_classes = len(torch.unique(self.labels)) 30 | if target_transform: 31 | self.number_classes = self.labels.shape[1] 32 | 33 | def __len__(self): 34 | return len(self.dataset) 35 | 36 | def __getitem__(self, index): 37 | img, target = self.dataset[index] 38 | return img, target 39 | 40 | 41 | class CacheClassLabelForTensor(CacheClassLabel): 42 | """ 43 | A dataset wrapper that has a quick access to all labels of data. 
44 | """ 45 | 46 | def __init__(self, tensor_dataset, labels): 47 | super(super(CacheClassLabelForTensor, self)).__init__() 48 | self.dataset = tensor_dataset 49 | self.labels = labels 50 | self.number_classes = len(torch.unique(self.labels)) 51 | 52 | 53 | class AppendName(data.Dataset): 54 | """ 55 | A dataset wrapper that also return the name of the dataset/task 56 | """ 57 | 58 | def __init__( 59 | self, 60 | dataset, 61 | task_ids, 62 | return_classes=False, 63 | return_task_as_class=False, 64 | first_class_ind=0, 65 | ): 66 | super(AppendName, self).__init__() 67 | self.dataset = dataset 68 | self.first_class_ind = first_class_ind 69 | self.task_ids = task_ids # For remapping the class index 70 | self.return_classes = return_classes 71 | self.return_task_as_class = return_task_as_class 72 | 73 | def __len__(self): 74 | return len(self.dataset) 75 | 76 | def __getitem__(self, index): 77 | img, target = self.dataset[index] 78 | target = target + self.first_class_ind 79 | out_dict = {} 80 | if self.return_task_as_class: 81 | out_dict["y"] = np.array( 82 | self.task_ids[index] 83 | ) # np.array(self.task_id, dtype=np.int64) 84 | elif self.return_classes: 85 | if isinstance(target, torch.Tensor): 86 | out_dict["y"] = target.cpu() 87 | elif isinstance(target, int): 88 | out_dict["y"] = torch.from_numpy(np.array(target, dtype=np.int64)) 89 | return img, out_dict # target #, self.name 90 | 91 | 92 | class Subclass(data.Dataset): 93 | """ 94 | A dataset wrapper that return the task name and remove the offset of labels (Let the labels start from 0) 95 | """ 96 | 97 | def __init__(self, dataset, class_list, remap=True): 98 | """ 99 | :param dataset: (CacheClassLabel) 100 | :param class_list: (list) A list of integers 101 | :param remap: (bool) Ex: remap class [2,4,6 ...] to [0,1,2 ...] 
102 | """ 103 | super(Subclass, self).__init__() 104 | assert isinstance( 105 | dataset, CacheClassLabel 106 | ), "dataset must be wrapped by CacheClassLabel" 107 | self.dataset = dataset 108 | self.class_list = deepcopy(class_list) 109 | self.remap = remap 110 | self.indices = [] 111 | 112 | for c in class_list: 113 | self.indices.extend((dataset.labels == c).nonzero().flatten().tolist()) 114 | 115 | if remap: 116 | self.class_mapping = {c: i for i, c in enumerate(class_list)} 117 | 118 | def __len__(self): 119 | return len(self.indices) 120 | 121 | def __getitem__(self, index): 122 | img, target = self.dataset[self.indices[index]] 123 | if self.remap: 124 | raw_target = target.item() if isinstance(target, torch.Tensor) else target 125 | target = self.class_mapping[raw_target] 126 | return img, target 127 | 128 | 129 | class Permutation(data.Dataset): 130 | """ 131 | A dataset wrapper that permute the position of features 132 | """ 133 | 134 | def __init__(self, dataset, permute_idx): 135 | super(Permutation, self).__init__() 136 | self.dataset = dataset 137 | self.permute_idx = permute_idx 138 | 139 | def __len__(self): 140 | return len(self.dataset) 141 | 142 | def __getitem__(self, index): 143 | img, target = self.dataset[index] 144 | shape = img.size() 145 | img = img.view(-1)[self.permute_idx].view(shape) 146 | return img, target 147 | 148 | 149 | class Storage(data.Subset): 150 | def reduce(self, m): 151 | self.indices = self.indices[:m] 152 | -------------------------------------------------------------------------------- /guide/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 
28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def condition_mean(self, cond_fn, *args, **kwargs): 99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 100 | 101 | def condition_score(self, cond_fn, *args, **kwargs): 102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def _wrap_model(self, model): 105 | if isinstance(model, _WrappedModel): 106 | return model 107 | return _WrappedModel( 108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 109 | ) 110 | 111 | def _scale_timesteps(self, t): 112 | # Scaling is done by the wrapped model. 
113 | return t 114 | 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | if self.rescale_timesteps: 127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /guide/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 
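    Each timestep t is mapped to [cos(t * f_0), ..., cos(t * f_{half-1}),
    sin(t * f_0), ..., sin(t * f_{half-1})], where half = dim // 2 and the
    frequencies are geometrically spaced as f_i = max_period ** (-i / half).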
106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 
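            # view_as(x) wraps each input in a fresh non-leaf Tensor that
            # shares x's storage, so in-place ops inside run_function happen on
            # the view (which autograd allows) instead of directly on the
            # detached leaf tensors that require grad.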
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /dataloaders/datasetGen.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Subset 6 | 7 | from .utils import get_stratified_subset 8 | from .wrapper import AppendName 9 | 10 | 11 | def data_split( 12 | dataset, 13 | return_classes=False, 14 | return_task_as_class=False, 15 | num_tasks=5, 16 | num_classes=10, 17 | limit_classes=-1, 18 | validation_frac=0.3, 19 | data_seed=0, 20 | shared_classes=False, 21 | first_task_num_classes=0, 22 | ): 23 | train_dataset_splits = {} 24 | val_dataset_splits = {} 25 | 26 | if not shared_classes: 27 | if limit_classes > 0: 28 | assert limit_classes <= num_classes 29 | num_classes = limit_classes 30 | 31 | # assert num_classes % num_tasks == 0 32 | if first_task_num_classes > 0: 33 | classes_per_task = (num_classes - first_task_num_classes) // (num_tasks - 1) 34 | class_split = { 35 | i: list( 36 | range( 37 | ((i - 1) * classes_per_task) + first_task_num_classes, 38 | (i * classes_per_task) + first_task_num_classes, 39 | ) 40 | ) 41 | for i in range(1, num_tasks) 42 | } 43 | class_split[0] = list(range(0, first_task_num_classes)) 44 | class_split = dict(sorted(class_split.items())) 45 | else: 46 | classes_per_task = num_classes // num_tasks 47 | class_split = { 48 | i: list(range(i * classes_per_task, (i + 1) * classes_per_task)) 49 | for i in range(num_tasks) 50 | } 51 | labels = ( 52 | dataset.labels 53 | if dataset.labels.shape[1] == 1 54 | else torch.argmax(dataset.labels, 1) 55 | ) 56 | class_indices = torch.LongTensor(labels) 57 | task_indices_1hot = torch.zeros( 58 | len(dataset), num_tasks 59 | ) # 1hot array describing to which tasks datapoints belong. 
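        # Illustrative example: with num_classes=10, num_tasks=5 and
        # first_task_num_classes=0, class_split == {0: [0, 1], 1: [2, 3],
        # 2: [4, 5], 3: [6, 7], 4: [8, 9]}, so task_indices_1hot[i, t] == 1
        # exactly when example i belongs to one of task t's classes.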
60 | for task, classes in class_split.items(): 61 | task_indices_1hot[ 62 | (class_indices[..., None] == torch.tensor(classes)).any(-1), task 63 | ] = 1 64 | 65 | train_set_indices_bitmask = torch.ones(len(dataset)) 66 | validation_indices = get_stratified_subset( 67 | validation_frac, labels, seed=data_seed 68 | ) 69 | train_set_indices_bitmask[validation_indices] = 0 70 | 71 | for task, classes in class_split.items(): 72 | cur_task_indices_bitmask = task_indices_1hot[:, task] == 1 73 | cur_train_indices_bitmask = ( 74 | train_set_indices_bitmask * cur_task_indices_bitmask 75 | ) 76 | cur_val_indices_bitmask = ( 77 | 1 - train_set_indices_bitmask 78 | ) * cur_task_indices_bitmask 79 | 80 | train_subset = Subset( 81 | dataset, torch.where(cur_train_indices_bitmask == 1)[0] 82 | ) 83 | train_subset.class_list = classes 84 | 85 | val_subset = Subset(dataset, torch.where(cur_val_indices_bitmask == 1)[0]) 86 | val_subset.class_list = classes 87 | 88 | train_dataset_splits[task] = AppendName( 89 | train_subset, 90 | [task] * len(train_subset), 91 | return_classes=return_classes, 92 | return_task_as_class=return_task_as_class, 93 | ) 94 | val_dataset_splits[task] = AppendName( 95 | val_subset, 96 | [task] * len(train_subset), 97 | return_classes=return_classes, 98 | return_task_as_class=return_task_as_class, 99 | ) 100 | 101 | else: 102 | # Each class in every task 103 | class_examples = defaultdict(list) 104 | for idx in range(len(dataset)): 105 | _, label = dataset[idx] 106 | class_examples[label].append(idx) 107 | 108 | classes_per_task = num_classes 109 | 110 | # Calculate the number of examples per class for each part and train/validation sets 111 | examples_per_class_per_part = len(class_examples[0]) // num_tasks 112 | train_examples = int(examples_per_class_per_part * (1 - validation_frac)) 113 | val_examples = examples_per_class_per_part - train_examples 114 | for task in range(num_tasks): 115 | train_indices = [] 116 | val_indices = [] 117 | 118 | for class_idx in class_examples.keys(): 119 | start_idx = task * examples_per_class_per_part 120 | train_end_idx = start_idx + train_examples 121 | val_end_idx = train_end_idx + val_examples 122 | 123 | train_indices.extend(class_examples[class_idx][start_idx:train_end_idx]) 124 | val_indices.extend(class_examples[class_idx][train_end_idx:val_end_idx]) 125 | 126 | train_subset = Subset(dataset, train_indices) 127 | train_subset.class_list = list(range(num_classes)) 128 | val_subset = Subset(dataset, val_indices) 129 | val_subset.class_list = list(range(num_classes)) 130 | 131 | train_dataset_splits[task] = AppendName( 132 | train_subset, 133 | [task] * len(train_subset), 134 | return_classes=return_classes, 135 | return_task_as_class=return_task_as_class, 136 | ) 137 | val_dataset_splits[task] = AppendName( 138 | val_subset, 139 | [task] * len(train_subset), 140 | return_classes=return_classes, 141 | return_task_as_class=return_task_as_class, 142 | ) 143 | 144 | print( 145 | f"Prepared dataset with splits: {[(idx, len(data)) for idx, data in enumerate(train_dataset_splits.values())]}" 146 | ) 147 | print( 148 | f"Validation dataset with splits: {[(idx, len(data)) for idx, data in enumerate(val_dataset_splits.values())]}" 149 | ) 150 | if hasattr(dataset.dataset, "classes"): 151 | print( 152 | f"Prepared class order: {[(idx, [np.array(dataset.dataset.classes)[data.dataset.class_list]]) for idx, data in enumerate(train_dataset_splits.values())]}" 153 | ) 154 | 155 | return train_dataset_splits, val_dataset_splits, classes_per_task 156 | 
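# Illustrative usage (assumed, not part of the original file): split a
# label-cached dataset (see dataloaders.wrapper.CacheClassLabel) into 5
# class-incremental tasks of 2 classes each:
#
#   train_splits, val_splits, classes_per_task = data_split(
#       dataset, return_classes=True, num_tasks=5, num_classes=10,
#       validation_frac=0.3, data_seed=0,
#   )
#
# classes_per_task == 2, and train_splits[0] / val_splits[0] contain only the
# first two classes.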
-------------------------------------------------------------------------------- /scripts/classifier_sample_universal.py: -------------------------------------------------------------------------------- 1 | """ 2 | Like image_sample.py, but use a noisy image classifier to guide the sampling 3 | process towards more realistic images. 4 | 5 | python scripts/classifier_sample_universal.py --num_samples=100 --use_ddim=True --timestep_respacing=ddim25 --model_path=results_new/bc_cifar100_ci5_class_cond_diffusion_long/ema_0.9999_100000_0.pt --classifier_path=models/resnet18_cifar100_task0.pt --num_classes_sample=20 --model_num_classes=100 --batch_size=50 6 | """ 7 | 8 | import argparse 9 | import os 10 | 11 | import numpy as np 12 | import torch as th 13 | import torch.nn.functional as F 14 | 15 | from guide import dist_util, logger 16 | from guide.script_args import add_dict_to_argparser, args_to_dict, classifier_defaults 17 | from guide.script_util import ( 18 | create_model_and_diffusion, 19 | create_resnet_classifier, 20 | model_and_diffusion_defaults, 21 | ) 22 | 23 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 24 | 25 | 26 | def main(): 27 | args = create_argparser().parse_args() 28 | print("Using manual seed = {}".format(args.seed)) 29 | th.manual_seed(args.seed) 30 | th.cuda.manual_seed(args.seed) 31 | th.backends.cudnn.deterministic = True 32 | th.backends.cudnn.benchmark = False 33 | os.environ["OPENAI_LOGDIR"] = f"out/samples/{args.wandb_experiment_name}" 34 | 35 | assert args.num_samples % args.batch_size == 0 36 | 37 | logger.configure() 38 | 39 | logger.log("creating model and diffusion...") 40 | model, diffusion = create_model_and_diffusion( 41 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 42 | ) 43 | model.load_state_dict( 44 | dist_util.load_state_dict(args.model_path, map_location="cpu") 45 | ) 46 | model.to(dist_util.dev()) 47 | if args.use_fp16: 48 | model.convert_to_fp16() 49 | model.eval() 50 | 51 | logger.log("loading classifier...") 52 | defaults = classifier_defaults() 53 | parser = argparse.ArgumentParser() 54 | add_dict_to_argparser(parser, defaults) 55 | opts = parser.parse_args([]) 56 | opts.model_num_classes = int(args.model_num_classes) 57 | opts.in_channels = 3 58 | opts.depth = 18 59 | opts.noised = False 60 | classifier = create_resnet_classifier(opts) 61 | if args.classifier_path: 62 | classifier.load_state_dict(th.load(args.classifier_path, map_location="cpu")) 63 | classifier.to(dist_util.dev()) 64 | print(classifier) 65 | 66 | # NOTE: Possible further improvements from http://arxiv.org/abs/2302.07121, but with them 67 | # sampling becomes very time-consuming. 
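    # cond_fn below implements classifier guidance: it predicts x_0 from the
    # current noisy sample x_t via p_mean_variance, scores that prediction with
    # the clean-image classifier, and returns the gradient of the classifier's
    # log-probability of y (evaluated on pred_xstart) with respect to x_t,
    # scaled per-example by classfier_scale_vec; the sampler uses this gradient
    # to shift the model mean toward class y.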
68 | def cond_fn(x, t, y=None): 69 | assert y is not None 70 | with th.enable_grad(): 71 | x = x.detach().requires_grad_(True) 72 | my_t = th.tensor( 73 | [diffusion.timestep_map.index(ts) for ts in t], device=dist_util.dev() 74 | ) 75 | out = diffusion.p_mean_variance( 76 | model, x, my_t, clip_denoised=True, model_kwargs=model_kwargs 77 | ) 78 | x_in = out["pred_xstart"] 79 | 80 | logit = classifier(x_in) 81 | if args.trim_logits: 82 | logit = logit[:, : int(args.num_classes_sample)] 83 | 84 | loss = -F.cross_entropy(logit, y, reduction="none") 85 | 86 | grad = th.autograd.grad(loss.sum(), x)[0] 87 | return grad * classfier_scale_vec.view(-1, 1, 1, 1) 88 | 89 | def model_fn(x, t, y=None): 90 | return model(x, t, y if args.class_cond else None) 91 | 92 | logger.log("sampling...") 93 | all_images = [] 94 | all_labels = [] 95 | batch_num = 0 96 | while len(all_images) < args.num_samples: 97 | model_kwargs = {} 98 | model_kwargs["y"] = th.randint( 99 | args.min_class_sample, args.max_class_sample, (args.batch_size,) 100 | ).to(dist_util.dev()) 101 | 102 | sample_fn = ( 103 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 104 | ) 105 | classfier_scale_vec = ( 106 | th.from_numpy( 107 | np.random.uniform( 108 | low=args.classifier_scale_min, 109 | high=args.classifier_scale_max, 110 | size=(len(model_kwargs["y"]),), 111 | ) 112 | ) 113 | .float() 114 | .to(dist_util.dev()) 115 | ) 116 | sample = sample_fn( 117 | model, 118 | (args.batch_size, args.in_channels, args.image_size, args.image_size), 119 | clip_denoised=True, 120 | model_kwargs=model_kwargs, 121 | cond_fn=( 122 | None 123 | if args.classifier_scale_min == 0.0 and args.classifier_scale_max == 0.0 124 | else cond_fn 125 | ), 126 | device=dist_util.dev(), 127 | ) 128 | sample = ((sample + 1) * 127.5).clamp(0, 255) 129 | sample = sample.permute(0, 2, 3, 1) 130 | sample = sample.contiguous() 131 | 132 | all_images.extend([sample.cpu().numpy() for sample in sample]) 133 | all_labels.extend([labels.cpu().numpy() for labels in model_kwargs["y"]]) 134 | logger.log(f"created {len(all_images)} samples") 135 | 136 | batch_num += 1 137 | 138 | arr = all_images 139 | arr = arr[: args.num_samples] 140 | label_arr = all_labels 141 | label_arr = label_arr[: args.num_samples] 142 | out_path = os.path.join( 143 | os.path.dirname(args.model_path), f"{args.wandb_experiment_name}.npz" 144 | ) 145 | logger.log(f"saving to {out_path}") 146 | np.savez(out_path, arr, label_arr) 147 | 148 | logger.log("sampling complete") 149 | 150 | 151 | def create_argparser(): 152 | defaults = dict( 153 | clip_denoised=True, 154 | num_samples=10000, 155 | batch_size=16, 156 | use_ddim=False, 157 | timestep_respacing="", 158 | model_path="", 159 | classifier_path="", 160 | classifier_scale_min=0.0, 161 | classifier_scale_max=0.0, 162 | wandb_experiment_name="test", 163 | model_num_classes=10, 164 | trim_logits=True, 165 | min_class_sample=0, 166 | max_class_sample=0, 167 | seed=0, 168 | ) 169 | defaults.update(model_and_diffusion_defaults()) 170 | parser = argparse.ArgumentParser() 171 | add_dict_to_argparser(parser, defaults) 172 | return parser 173 | 174 | 175 | if __name__ == "__main__": 176 | main() 177 | -------------------------------------------------------------------------------- /guide/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | import wandb 8 | 9 | 10 
| def create_named_schedule_sampler(name, diffusion, args): 11 | """ 12 | Create a ScheduleSampler from a library of pre-defined samplers. 13 | 14 | :param name: the name of the sampler. 15 | :param diffusion: the diffusion object to sample for. 16 | """ 17 | if name == "uniform": 18 | return UniformSampler(diffusion) 19 | elif name == "beta": 20 | return BetaSampler(diffusion, args.alpha, args.beta) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | elif name == "task_aware": 24 | return TaskAwareSampler(diffusion, args.alpha, args.beta) 25 | else: 26 | raise NotImplementedError(f"unknown schedule sampler: {name}") 27 | 28 | 29 | class ScheduleSampler(ABC): 30 | """ 31 | A distribution over timesteps in the diffusion process, intended to reduce 32 | variance of the objective. 33 | 34 | By default, samplers perform unbiased importance sampling, in which the 35 | objective's mean is unchanged. 36 | However, subclasses may override sample() to change how the resampled 37 | terms are reweighted, allowing for actual changes in the objective. 38 | """ 39 | 40 | @abstractmethod 41 | def weights(self): 42 | """ 43 | Get a numpy array of weights, one per diffusion step. 44 | 45 | The weights needn't be normalized, but must be positive. 46 | """ 47 | 48 | def sample(self, batch_size, device): 49 | """ 50 | Importance-sample timesteps for a batch. 51 | 52 | :param batch_size: the number of timesteps. 53 | :param device: the torch device to save to. 54 | :return: a tuple (timesteps, weights): 55 | - timesteps: a tensor of timestep indices. 56 | - weights: a tensor of weights to scale the resulting losses. 57 | """ 58 | w = self.weights() 59 | # wandb.log({f"w_{i}":w[i] for i in range(len(w))}) 60 | p = w / np.sum(w) 61 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 62 | indices = th.from_numpy(indices_np).long().to(device) 63 | weights_np = 1 / (len(p) * p[indices_np]) 64 | weights = th.from_numpy(weights_np).float().to(device) 65 | return indices, weights 66 | 67 | 68 | class UniformSampler(ScheduleSampler): 69 | def __init__(self, diffusion): 70 | self.diffusion = diffusion 71 | self._weights = np.ones([diffusion.num_timesteps]) 72 | 73 | def weights(self): 74 | return self._weights 75 | 76 | 77 | class TaskAwareSampler: 78 | def __init__(self, diffusion, alfa=4, beta=1.2): 79 | self.diffusion = diffusion 80 | self.beta_sampler = BetaSampler(diffusion, alfa, beta, weights_smoothing=0) 81 | self.uniform_sampler = UniformSampler(diffusion) 82 | 83 | def sample(self, batch_size, device, task_ids, current_task_id): 84 | curr_task_indices, curr_task_weights = self.uniform_sampler.sample( 85 | (task_ids == current_task_id).sum().item(), device 86 | ) 87 | prev_task_indices, prev_task_weights = self.beta_sampler.sample( 88 | (task_ids != current_task_id).sum().item(), device 89 | ) 90 | indices = th.zeros(batch_size, device=device).long() 91 | weights = th.zeros(batch_size, device=device).float() 92 | 93 | indices[task_ids == current_task_id] = curr_task_indices 94 | indices[task_ids != current_task_id] = prev_task_indices 95 | 96 | weights[task_ids == current_task_id] = curr_task_weights 97 | weights[task_ids != current_task_id] = prev_task_weights 98 | return indices, weights 99 | 100 | 101 | class BetaSampler(ScheduleSampler): 102 | def __init__(self, diffusion, alfa=4, beta=1.2, weights_smoothing=1): 103 | beta_dist = th.distributions.beta.Beta(alfa, beta) 104 | w = th.exp( 105 | beta_dist.log_prob( 106 | (th.arange(0, 
diffusion.num_timesteps) / diffusion.num_timesteps) 107 | ) 108 | ) 109 | self.diffusion = diffusion 110 | self._weights = ( 111 | w.numpy() + weights_smoothing 112 | ) # np.ones([diffusion.num_timesteps]) 113 | 114 | def weights(self): 115 | return self._weights 116 | 117 | 118 | class LossAwareSampler(ScheduleSampler): 119 | def update_with_local_losses(self, local_ts, local_losses): 120 | """ 121 | Update the reweighting using losses from a model. 122 | 123 | Call this method from each rank with a batch of timesteps and the 124 | corresponding losses for each of those timesteps. 125 | This method will perform synchronization to make sure all of the ranks 126 | maintain the exact same reweighting. 127 | 128 | :param local_ts: an integer Tensor of timesteps. 129 | :param local_losses: a 1D Tensor of losses. 130 | """ 131 | batch_sizes = [ 132 | th.tensor([0], dtype=th.int32, device=local_ts.device) 133 | for _ in range(dist.get_world_size()) 134 | ] 135 | dist.all_gather( 136 | batch_sizes, 137 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 138 | ) 139 | 140 | # Pad all_gather batches to be the maximum batch size. 141 | batch_sizes = [x.item() for x in batch_sizes] 142 | max_bs = max(batch_sizes) 143 | 144 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 145 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 146 | dist.all_gather(timestep_batches, local_ts) 147 | dist.all_gather(loss_batches, local_losses) 148 | timesteps = [ 149 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 150 | ] 151 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 152 | self.update_with_all_losses(timesteps, losses) 153 | 154 | @abstractmethod 155 | def update_with_all_losses(self, ts, losses): 156 | """ 157 | Update the reweighting using losses from a model. 158 | 159 | Sub-classes should override this method to update the reweighting 160 | using losses from the model. 161 | 162 | This method directly updates the reweighting without synchronizing 163 | between workers. It is called by update_with_local_losses from all 164 | ranks with identical arguments. Thus, it should have deterministic 165 | behavior to maintain state across workers. 166 | 167 | :param ts: a list of int timesteps. 168 | :param losses: a list of float losses, one per timestep. 169 | """ 170 | 171 | 172 | class LossSecondMomentResampler(LossAwareSampler): 173 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 174 | self.diffusion = diffusion 175 | self.history_per_term = history_per_term 176 | self.uniform_prob = uniform_prob 177 | self._loss_history = np.zeros( 178 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 179 | ) 180 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 181 | 182 | def weights(self): 183 | if not self._warmed_up(): 184 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 185 | weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) 186 | weights /= np.sum(weights) 187 | weights *= 1 - self.uniform_prob 188 | weights += self.uniform_prob / len(weights) 189 | return weights 190 | 191 | def update_with_all_losses(self, ts, losses): 192 | for t, loss in zip(ts, losses): 193 | if self._loss_counts[t] == self.history_per_term: 194 | # Shift out the oldest loss term. 
195 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 196 | self._loss_history[t, -1] = loss 197 | else: 198 | self._loss_history[t, self._loss_counts[t]] = loss 199 | self._loss_counts[t] += 1 200 | 201 | def _warmed_up(self): 202 | return (self._loss_counts == self.history_per_term).all() 203 | -------------------------------------------------------------------------------- /guide/script_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def diffusion_defaults(): 6 | """ 7 | Defaults for image and classifier training. 8 | """ 9 | return dict( 10 | learn_sigma=False, 11 | sigma_small=False, 12 | diffusion_steps=1000, 13 | noise_schedule="linear", 14 | timestep_respacing="", 15 | use_kl=False, 16 | predict_xstart=False, 17 | rescale_timesteps=False, 18 | rescale_learned_sigmas=False, 19 | ) 20 | 21 | 22 | def classifier_defaults(): 23 | """ 24 | Defaults for classifier models. 25 | """ 26 | return dict( 27 | image_size=64, 28 | classifier_use_fp16=False, 29 | classifier_width=128, 30 | classifier_depth=2, 31 | classifier_attention_resolutions="32,16,8", # 16 32 | classifier_use_scale_shift_norm=True, # False 33 | classifier_resblock_updown=True, # False 34 | classifier_pool="attention", 35 | # ResNet disjoint classifier parameters 36 | activetype="ReLU", # activation type 37 | pooltype="MaxPool2d", # Pooling type 38 | normtype="BatchNorm", # Batch norm type 39 | preact=True, # Places norms and activations before linear/conv layer. 40 | bn=True, # Apply Batchnorm. 41 | affine_bn=True, # Apply affine transform in BN. 42 | bn_eps=1e-6, # Affine transform for batch norm 43 | ) 44 | 45 | 46 | def model_and_diffusion_defaults(): 47 | """ 48 | Defaults for image training. 49 | """ 50 | res = dict( 51 | num_channels=128, 52 | num_res_blocks=3, 53 | num_heads=4, 54 | num_heads_upsample=-1, 55 | num_head_channels=-1, 56 | attention_resolutions="16,8", 57 | channel_mult="", 58 | dropout=0.1, 59 | use_checkpoint=False, 60 | use_scale_shift_norm=True, 61 | resblock_updown=False, 62 | use_fp16=False, 63 | use_new_attention_order=False, 64 | image_size=32, 65 | in_channels=3, 66 | model_switching_timestep=30, 67 | model_name="UNetModel", 68 | embedding_kind="concat_time_1hot", # embedding used for time and "class", possible values in EMBEDDING_KINDS 69 | model_num_classes=None, 70 | train_noised_classifier=False, 71 | ) 72 | res.update(diffusion_defaults()) 73 | return res 74 | 75 | 76 | def all_training_defaults(): 77 | defaults = dict( 78 | seed=13, 79 | data_seed=0, # Seed used for data generation (mostly train/valid split). Typically, no need to set it. 
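        # Every key in this dict is exposed as a command-line flag by add_dict_to_argparser below (booleans are parsed with str2bool).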
80 |         wandb_api_key="",
81 |         wandb_experiment_name="test",
82 |         wandb_project_name="project",
83 |         wandb_entity="entity",
84 |         dataroot="data/",
85 |         dataset="CIFAR10",
86 |         schedule_sampler="uniform",
87 |         alpha=4,
88 |         beta=1.2,
89 |         lr=2e-4,
90 |         disjoint_classifier_lr=1e-2,
91 |         weight_decay=0.0,
92 |         lr_anneal_steps=0,
93 |         batch_size=64,
94 |         microbatch=-1,  # -1 disables microbatches
95 |         ema_rate="0.9999",  # comma-separated list of EMA values
96 |         log_interval=500,
97 |         skip_save=False,
98 |         save_interval=5000,
99 |         guid_generation_interval=1,  # generate new examples from diffusion model every guid_generation_interval steps
100 |         resume_checkpoint="",
101 |         resume_checkpoint_classifier="",
102 |         resume_checkpoint_classifier_noised="",
103 |         use_fp16=False,
104 |         fp16_scale_growth=1e-3,
105 |         gpu_id=-1,
106 |         reverse=False,
107 |         num_tasks=5,
108 |         limit_tasks=-1,
109 |         limit_classes=-1,
110 |         shared_classes=False,
111 |         train_aug=False,
112 |         skip_normalization=False,
113 |         num_steps=20000,
114 |         scheduler_rate=1.0,
115 |         first_task_num_classes=0,
116 |         first_task_num_steps=-1,  # if -1, set to the same as num_steps.
117 |         skip_gradient_thr=-1,
118 |         log_gradient_stats=False,
119 |         clip_denoised=True,
120 |         cl_method="generative_replay_disjoint_classifier_guidance",  # possible values are defined in CL_METHODS
121 |         use_ddim=False,
122 |         classifier_scale_min_old=None,
123 |         classifier_scale_min_new=None,
124 |         classifier_scale_max_old=None,
125 |         classifier_scale_max_new=None,
126 |         first_task=0,
127 |         gr_n_generated_examples_per_task=32,
128 |         use_old_grad=False,
129 |         use_new_grad=False,
130 |         guid_to_new_classes=False,
131 |         trim_logits=True,
132 |         train_with_disjoint_classifier=False,
133 |         disjoint_classifier_init_num_steps=5000,  # Steps in first task
134 |         disjoint_classifier_num_steps=2000,
135 |         classifier_init_lr=0.1,  # First task learning rate
136 |         classifier_lr=0.05,  # Learning rate
137 |         classifier_weight_decay=5e-4,  # Weight decay
138 |         depth=18,  # ResNet depth
139 |         classifier_augmentation=True,
140 |         diffusion_pretrained_dir=None,  # Directory containing trained diffusion models for each task. It effectively disables any training of the diffusion model.
141 | negate_old_grad=False, # negate old gradient 142 | classifier_first_task_dir=None, 143 | ) 144 | defaults.update(model_and_diffusion_defaults()) 145 | return defaults 146 | 147 | 148 | def classifier_and_diffusion_defaults(): 149 | res = classifier_defaults() 150 | res.update(diffusion_defaults()) 151 | return res 152 | 153 | 154 | def combine_with_defaults(config): 155 | res = all_training_defaults() 156 | for k, v in config.items(): 157 | assert k in res.keys(), "{} not in default values".format(k) 158 | res[k] = v 159 | return res 160 | 161 | 162 | CL_METHODS = [ 163 | "generative_replay", 164 | "generative_replay_disjoint_classifier_guidance", # GR with one disjoint classifier 165 | ] 166 | 167 | EMBEDDING_KINDS = [ 168 | "none", 169 | "concat_time_1hot", 170 | "add_time_learned", # original from the paper 171 | ] 172 | 173 | 174 | def preprocess_args(args): 175 | """Perform simple validity checks and do a simple initial processing of training args.""" 176 | 177 | assert args.cl_method in CL_METHODS 178 | assert args.embedding_kind in EMBEDDING_KINDS 179 | 180 | if args.first_task_num_steps == -1: 181 | args.first_task_num_steps = args.num_steps 182 | 183 | if not args.dataroot: 184 | args.dataroot = os.environ.get("DIFFUSION_DATA", "") 185 | 186 | if ( 187 | args.classifier_scale_min_old is not None 188 | and args.classifier_scale_max_old is None 189 | ): 190 | args.classifier_scale_max_old = args.classifier_scale_min_old 191 | 192 | if ( 193 | args.classifier_scale_min_old is not None 194 | and args.classifier_scale_max_old is None 195 | ): 196 | args.classifier_scale_max_old = args.classifier_scale_min_old 197 | 198 | if ( 199 | args.classifier_scale_max_old is not None 200 | and args.classifier_scale_min_old is None 201 | ): 202 | args.classifier_scale_min_old = args.classifier_scale_max_old 203 | 204 | if ( 205 | args.classifier_scale_min_new is not None 206 | and args.classifier_scale_max_new is None 207 | ): 208 | args.classifier_scale_max_new = args.classifier_scale_min_new 209 | 210 | if ( 211 | args.classifier_scale_min_new is not None 212 | and args.classifier_scale_max_new is None 213 | ): 214 | args.classifier_scale_max_new = args.classifier_scale_min_new 215 | 216 | if ( 217 | args.classifier_scale_max_new is not None 218 | and args.classifier_scale_min_new is None 219 | ): 220 | args.classifier_scale_min_new = args.classifier_scale_max_new 221 | 222 | 223 | def add_dict_to_argparser(parser, default_dict): 224 | for k, v in default_dict.items(): 225 | v_type = type(v) 226 | if v is None: 227 | v_type = str 228 | elif isinstance(v, bool): 229 | v_type = str2bool 230 | parser.add_argument(f"--{k}", default=v, type=v_type) 231 | 232 | 233 | def args_to_dict(args, keys): 234 | return {k: getattr(args, k) for k in keys} 235 | 236 | 237 | def str2bool(v): 238 | """ 239 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 240 | """ 241 | if isinstance(v, bool): 242 | return v 243 | if v.lower() in ("yes", "true", "t", "y", "1"): 244 | return True 245 | elif v.lower() in ("no", "false", "f", "n", "0"): 246 | return False 247 | else: 248 | raise argparse.ArgumentTypeError("boolean value expected") 249 | -------------------------------------------------------------------------------- /guide/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .layers import ConvBlock, FinalBlock, InitialBlock 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | expansion = 1 9 
| 10 | def __init__(self, opt, inChannels, outChannels, stride=1, downsample=None): 11 | super(BasicBlock, self).__init__() 12 | self.downsample = downsample 13 | expansion = 1 14 | self.conv1 = ConvBlock( 15 | opt=opt, 16 | in_channels=inChannels, 17 | out_channels=outChannels, 18 | kernel_size=3, 19 | stride=stride, 20 | padding=1, 21 | bias=False, 22 | ) 23 | self.conv2 = ConvBlock( 24 | opt=opt, 25 | in_channels=outChannels, 26 | out_channels=outChannels * expansion, 27 | kernel_size=3, 28 | stride=1, 29 | padding=1, 30 | bias=False, 31 | ) 32 | 33 | def forward(self, x): 34 | _out = self.conv1(x) 35 | _out = self.conv2(_out) 36 | if self.downsample is not None: 37 | shortcut = self.downsample(x) 38 | else: 39 | shortcut = x 40 | _out = _out + shortcut 41 | return _out 42 | 43 | 44 | class BottleneckBlock(nn.Module): 45 | expansion = 4 46 | 47 | def __init__(self, opt, inChannels, outChannels, stride=1, downsample=None): 48 | super(BottleneckBlock, self).__init__() 49 | expansion = 4 50 | self.conv1 = ConvBlock( 51 | opt=opt, 52 | in_channels=inChannels, 53 | out_channels=outChannels, 54 | kernel_size=1, 55 | stride=1, 56 | padding=0, 57 | bias=False, 58 | ) 59 | self.conv2 = ConvBlock( 60 | opt=opt, 61 | in_channels=outChannels, 62 | out_channels=outChannels, 63 | kernel_size=3, 64 | stride=stride, 65 | padding=1, 66 | bias=False, 67 | ) 68 | self.conv3 = ConvBlock( 69 | opt=opt, 70 | in_channels=outChannels, 71 | out_channels=outChannels * expansion, 72 | kernel_size=1, 73 | stride=1, 74 | padding=0, 75 | bias=False, 76 | ) 77 | self.downsample = downsample 78 | 79 | def forward(self, x): 80 | _out = self.conv1(x) 81 | _out = self.conv2(_out) 82 | _out = self.conv3(_out) 83 | if self.downsample is not None: 84 | shortcut = self.downsample(x) 85 | else: 86 | shortcut = x 87 | _out = _out + shortcut 88 | return _out 89 | 90 | 91 | class ResidualBlock(nn.Module): 92 | def __init__(self, opt, block, inChannels, outChannels, depth, stride=1): 93 | super(ResidualBlock, self).__init__() 94 | if stride != 1 or inChannels != outChannels * block.expansion: 95 | downsample = ConvBlock( 96 | opt=opt, 97 | in_channels=inChannels, 98 | out_channels=outChannels * block.expansion, 99 | kernel_size=1, 100 | stride=stride, 101 | padding=0, 102 | bias=False, 103 | ) 104 | else: 105 | downsample = None 106 | self.blocks = nn.Sequential() 107 | self.blocks.add_module( 108 | "block0", block(opt, inChannels, outChannels, stride, downsample) 109 | ) 110 | inChannels = outChannels * block.expansion 111 | for i in range(1, depth): 112 | self.blocks.add_module( 113 | "block{}".format(i), block(opt, inChannels, outChannels) 114 | ) 115 | 116 | def forward(self, x): 117 | return self.blocks(x) 118 | 119 | 120 | class ResNet(nn.Module): 121 | def __init__(self, opt): 122 | super(ResNet, self).__init__() 123 | depth = opt.depth 124 | if depth in [20, 32, 44, 56, 110, 1202]: 125 | blocktype, self.nettype = "BasicBlock", "cifar" 126 | elif depth in [164, 1001]: 127 | blocktype, self.nettype = "BottleneckBlock", "cifar" 128 | elif depth in [18, 34]: 129 | blocktype, self.nettype = "BasicBlock", "imagenet" 130 | elif depth in [50, 101, 152]: 131 | blocktype, self.nettype = "BottleneckBlock", "imagenet" 132 | assert depth in [20, 32, 44, 56, 110, 1202, 164, 1001, 18, 34, 50, 101, 152] 133 | 134 | if blocktype == "BasicBlock" and self.nettype == "cifar": 135 | assert ( 136 | depth - 2 137 | ) % 6 == 0, ( 138 | "Depth should be 6n+2, and preferably one of 20, 32, 44, 56, 110, 1202" 139 | ) 140 | n = (depth - 2) // 6 141 
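            # e.g. depth=20 gives n=3: three stages of n BasicBlocks (2 convs each),
            # plus the initial conv and the final classifier, i.e. 6n + 2 = 20 layers.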
| block = BasicBlock 142 | in_planes, out_planes = 16, 64 143 | elif blocktype == "BottleneckBlock" and self.nettype == "cifar": 144 | assert ( 145 | depth - 2 146 | ) % 9 == 0, "Depth should be 9n+2, and preferably one of 164 or 1001" 147 | n = (depth - 2) // 9 148 | block = BottleneckBlock 149 | in_planes, out_planes = 16, 64 150 | elif blocktype == "BasicBlock" and self.nettype == "imagenet": 151 | assert depth in [18, 34] 152 | num_blocks = [2, 2, 2, 2] if depth == 18 else [3, 4, 6, 3] 153 | block = BasicBlock 154 | in_planes, out_planes = 64, 512 # 20, 160 155 | elif blocktype == "BottleneckBlock" and self.nettype == "imagenet": 156 | assert depth in [50, 101, 152] 157 | if depth == 50: 158 | num_blocks = [3, 4, 6, 3] 159 | elif depth == 101: 160 | num_blocks = [3, 4, 23, 3] 161 | elif depth == 152: 162 | num_blocks = [3, 8, 36, 3] 163 | block = BottleneckBlock 164 | in_planes, out_planes = 64, 512 165 | else: 166 | assert 1 == 2 167 | 168 | self.num_classes = opt.model_num_classes 169 | self.initial = InitialBlock( 170 | opt=opt, out_channels=in_planes, kernel_size=3, stride=1, padding=1 171 | ) 172 | if self.nettype == "cifar": 173 | self.group1 = ResidualBlock(opt, block, 16, 16, n, stride=1) 174 | self.group2 = ResidualBlock( 175 | opt, block, 16 * block.expansion, 32, n, stride=2 176 | ) 177 | self.group3 = ResidualBlock( 178 | opt, block, 32 * block.expansion, 64, n, stride=2 179 | ) 180 | elif self.nettype == "imagenet": 181 | self.group1 = ResidualBlock( 182 | opt, block, 64, 64, num_blocks[0], stride=1 183 | ) # For ResNet-S, convert this to 20,20 184 | self.group2 = ResidualBlock( 185 | opt, block, 64 * block.expansion, 128, num_blocks[1], stride=2 186 | ) # For ResNet-S, convert this to 20,40 187 | self.group3 = ResidualBlock( 188 | opt, block, 128 * block.expansion, 256, num_blocks[2], stride=2 189 | ) # For ResNet-S, convert this to 40,80 190 | self.group4 = ResidualBlock( 191 | opt, block, 256 * block.expansion, 512, num_blocks[3], stride=2 192 | ) # For ResNet-S, convert this to 80,160 193 | else: 194 | assert 1 == 2 195 | self.pool = nn.AdaptiveAvgPool2d(1) 196 | self.dim_out = out_planes * block.expansion 197 | self.noised = opt.noised 198 | 199 | in_size = ( 200 | out_planes * block.expansion + 1 201 | if opt.noised 202 | else out_planes * block.expansion 203 | ) 204 | self.final = FinalBlock(opt=opt, in_channels=in_size) 205 | 206 | for m in self.modules(): 207 | if isinstance(m, nn.Conv2d): 208 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 209 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 210 | nn.init.constant_(m.weight, 1) 211 | nn.init.constant_(m.bias, 0) 212 | 213 | def forward(self, x, t=None, return_h=False): 214 | out = self.initial(x) 215 | out = self.group1(out) 216 | out = self.group2(out) 217 | out = self.group3(out) 218 | if self.nettype == "imagenet": 219 | out = self.group4(out) 220 | out = self.pool(out) 221 | out = out.view(x.size(0), -1) 222 | 223 | if return_h: 224 | return out 225 | 226 | if self.noised: 227 | if t is None: 228 | t = torch.zeros(x.shape[0], dtype=torch.long, device=x.device) 229 | out = torch.cat([out, t.unsqueeze(1)], 1) 230 | 231 | out = self.final(out) 232 | 233 | return out 234 | -------------------------------------------------------------------------------- /guide/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 
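
A minimal usage sketch of the MixedPrecisionTrainer defined below; `model`,
`opt`, and `loss` are assumed to come from the surrounding training loop, and
with use_fp16=True the model must expose convert_to_fp16() (as the UNet model
in this codebase does):

    mp_trainer = MixedPrecisionTrainer(model=model, use_fp16=True)
    mp_trainer.zero_grad()
    mp_trainer.backward(loss)
    step_was_taken = mp_trainer.optimize(opt)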
3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 
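    # param_groups_and_shapes comes from get_param_groups_and_shapes(), which
    # returns a list, so it can safely be zipped against master_params here.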
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | skip_gradient_thr=-1, 157 | ): 158 | self.model = model 159 | self.use_fp16 = use_fp16 160 | self.fp16_scale_growth = fp16_scale_growth 161 | self.skip_gradient_thr = skip_gradient_thr 162 | 163 | self.model_params = list(self.model.parameters()) 164 | self.master_params = self.model_params 165 | self.param_groups_and_shapes = None 166 | self.lg_loss_scale = initial_lg_loss_scale 167 | 168 | if self.use_fp16: 169 | self.param_groups_and_shapes = get_param_groups_and_shapes( 170 | self.model.named_parameters() 171 | ) 172 | self.master_params = make_master_params(self.param_groups_and_shapes) 173 | self.model.convert_to_fp16() 174 | 175 | def zero_grad(self): 176 | zero_grad(self.model_params) 177 | 178 | def backward(self, loss: th.Tensor): 179 | if 
self.use_fp16: 180 | loss_scale = 2**self.lg_loss_scale 181 | (loss * loss_scale).backward() 182 | else: 183 | loss.backward() 184 | 185 | def optimize(self, opt: th.optim.Optimizer): 186 | if self.use_fp16: 187 | return self._optimize_fp16(opt) 188 | else: 189 | return self._optimize_normal(opt) 190 | 191 | def _optimize_fp16(self, opt: th.optim.Optimizer): 192 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 193 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 194 | grad_norm, param_norm = self._compute_norms(grad_scale=2**self.lg_loss_scale) 195 | if check_overflow(grad_norm): 196 | self.lg_loss_scale -= 1 197 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 198 | zero_master_grads(self.master_params) 199 | return False 200 | 201 | logger.logkv_mean("grad_norm", grad_norm) 202 | logger.logkv_mean("param_norm", param_norm) 203 | 204 | self.master_params[0].grad.mul_(1.0 / (2**self.lg_loss_scale)) 205 | # TODO: add skip gradients here 206 | if self.skip_gradient_thr == -1.0 or grad_norm < self.skip_gradient_thr: 207 | logger.logkv_mean("skip_update", 0) 208 | opt.step() 209 | step_made = True 210 | else: 211 | logger.logkv_mean("skip_update", 1) 212 | step_made = False 213 | zero_master_grads(self.master_params) 214 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 215 | self.lg_loss_scale += self.fp16_scale_growth 216 | return step_made 217 | 218 | def _optimize_normal(self, opt: th.optim.Optimizer): 219 | grad_norm, param_norm = self._compute_norms() 220 | logger.logkv_mean("grad_norm", grad_norm) 221 | logger.logkv_mean("param_norm", param_norm) 222 | # TODO: add skip gradients here 223 | if self.skip_gradient_thr == -1.0 or grad_norm < self.skip_gradient_thr: 224 | logger.logkv_mean("skip_update", 0) 225 | opt.step() 226 | return True 227 | else: 228 | logger.logkv_mean("skip_update", 1) 229 | return False 230 | 231 | def _compute_norms(self, grad_scale=1.0): 232 | grad_norm = 0.0 233 | param_norm = 0.0 234 | for p in self.master_params: 235 | with th.no_grad(): 236 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 237 | if p.grad is not None: 238 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 239 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 240 | 241 | def master_params_to_state_dict(self, master_params): 242 | return master_params_to_state_dict( 243 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 244 | ) 245 | 246 | def state_dict_to_master_params(self, state_dict): 247 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 248 | 249 | 250 | def check_overflow(value): 251 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 252 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GUIDE: Guidance-based Incremental Learning with Diffusion Models 2 | [![arXiv](https://img.shields.io/badge/arXiv-2403.03938-b31b1b.svg?style=flat)](https://arxiv.org/abs/2403.03938) 3 | 4 | This repository is the official implementation of [GUIDE: Guidance-based Incremental Learning with Diffusion Models](https://arxiv.org/abs/2403.03938). 5 | 6 | drawing 7 | 8 | **Rehearsal sampling in GUIDE.** 9 | We guide the denoising process of a diffusion model trained on the previous task (blue) toward classes from the current task (orange). 
10 | The replay samples, highlighted with **blue borders**, share features with the examples from the current task, which may be related to characteristics such as color or background (e.g., fishes on a snowy background when guided to *snowmobile*). 11 | Continual training of a classifier on such samples positioned near its decision boundary successfully mitigates catastrophic forgetting. 12 | 13 | ## Setup 14 | 15 | ### Clone repo 16 | 17 | ```bash 18 | git clone https://github.com/cywinski/guide.git 19 | cd guide 20 | ``` 21 | 22 | ### Prepare Conda environment 23 | 24 | ```bash 25 | conda create -n guide_env python=3.8 26 | conda activate guide_env 27 | ``` 28 | 29 | ### Install torch 30 | 31 | Install `torch` and `torchvision` according to instructions on [offical website](https://pytorch.org/). 32 | 33 | ### Install required **packages** 34 | 35 | ``` 36 | pip install . 37 | ``` 38 | 39 | ### Login to wandb 40 | 41 | ```bash 42 | wandb login 43 | ``` 44 | 45 | ## Reproduction 46 | 47 | Below we present training commands for a single GPU setup. To run the training in distributed manner, run the same command with `mpiexec`: 48 | 49 | ```bash 50 | mpiexec -n $NUM_GPUS python scripts.image_train ... 51 | ``` 52 | 53 | When training in a distributed manner, you must manually divide the `--batch_size` argument by the number of ranks. In lieu of distributed training, you may use `--microbatch 16` (or `--microbatch 1` in extreme memory-limited cases) to reduce memory usage. 54 | 55 | ### Diffusion models training 56 | 57 | Continual training of diffusion models with self-rehearsal. Trained models will be stored under `results/` 58 | 59 | **CIFAR-10/2** 60 | 61 | ```bash 62 | python -m scripts.image_train --wandb_experiment_name=c10_ci2_class_cond_diffusion --wandb_project_name=project --wandb_entity=entity --batch_size=256 --num_steps=100000 --dataset=CIFAR10 --num_tasks=2 --save_interval=100000 --gr_n_generated_examples_per_task=25000 --first_task_num_steps=100000 --seed=0 --timestep_respacing=1000 --use_ddim=False --log_interval=1000 --cl_method=generative_replay --train_with_disjoint_classifier=False --embedding_kind=concat_time_1hot --train_aug=True 63 | ``` 64 | 65 | **CIFAR-10/5** 66 | 67 | ```bash 68 | python -m scripts.image_train --wandb_experiment_name=c10_ci5_class_cond_diffusion --wandb_project_name=project --wandb_entity=entity --batch_size=256 --num_steps=50000 --dataset=CIFAR10 --num_tasks=5 --save_interval=50000 --gr_n_generated_examples_per_task=10000 --first_task_num_steps=100000 --seed=0 --timestep_respacing=1000 --use_ddim=False --log_interval=1000 --cl_method=generative_replay --train_with_disjoint_classifier=False --embedding_kind=concat_time_1hot --train_aug=True 69 | ``` 70 | 71 | **CIFAR-100/5** 72 | 73 | ```bash 74 | python -m scripts.image_train --wandb_experiment_name=c100_ci5_class_cond_diffusion --wandb_project_name=project --wandb_entity=entity --batch_size=256 --num_steps=50000 --dataset=CIFAR100 --num_tasks=5 --save_interval=50000 --gr_n_generated_examples_per_task=10000 --first_task_num_steps=100000 --seed=0 --timestep_respacing=1000 --use_ddim=False --log_interval=1000 --cl_method=generative_replay --train_with_disjoint_classifier=False --embedding_kind=concat_time_1hot --train_aug=True 75 | ``` 76 | 77 | **CIFAR-100/10** 78 | 79 | ```bash 80 | python -m scripts.image_train --wandb_experiment_name=c100_ci10_class_cond_diffusion --wandb_project_name=project --wandb_entity=entity --batch_size=256 --num_steps=100000 --dataset=CIFAR100 --num_tasks=10 
--save_interval=100000 --gr_n_generated_examples_per_task=5000 --first_task_num_steps=100000 --seed=0 --timestep_respacing=1000 --use_ddim=False --log_interval=1000 --cl_method=generative_replay --train_with_disjoint_classifier=False --embedding_kind=concat_time_1hot --train_aug=True 81 | ``` 82 | 83 | **ImageNet100-64/5** 84 | 85 | ```bash 86 | python -m scripts.image_train --wandb_experiment_name=i100_ci5_class_cond_diffusion --wandb_project_name=project --wandb_entity=entity --batch_size=100 --num_steps=50000 --dataset=ImageNet100 --num_tasks=5 --save_interval=50000 --gr_n_generated_examples_per_task=26000 --first_task_num_steps=100000 --seed=0 --timestep_respacing=ddim250 --use_ddim=True --log_interval=1000 --cl_method=generative_replay --train_with_disjoint_classifier=False --embedding_kind=concat_time_1hot --train_aug=True --attention_resolutions 32,16,8 --lr 1e-4 --resblock_updown True --use_new_attention_order True --use_scale_shift_norm True --num_channels 192 --num_head_channels 64 87 | ``` 88 | 89 | ### Classifier trainings 90 | 91 | Continual classifier trainings with generative replay according to GUIDE method. To run classifier trainings you first need to train the diffusion models (according to instructions presented above) and store `ema` checkpoints in `--diffusion_pretrained_dir`. 92 | 93 | **CIFAR-10/2** 94 | 95 | ```bash 96 | python -m scripts.image_train --wandb_experiment_name=c10_ci2_guide --wandb_project_name=project --wandb_entity=entity --batch_size=256 --dataset=CIFAR10 --num_tasks=2 --seed=0 --timestep_respacing=ddim50 --use_ddim=True --classifier_scale_min_new=0.2 --classifier_scale_max_new=0.2 --cl_method=generative_replay_disjoint_classifier_guidance --train_with_disjoint_classifier=True --use_old_grad=False --use_new_grad=True --guid_to_new_classes=True --embedding_kind=concat_time_1hot --classifier_init_lr=0.1 --classifier_lr=0.01 --disjoint_classifier_init_num_steps=5000 --disjoint_classifier_num_steps=2000 --classifier_augmentation=True --log_interval=200 --diffusion_pretrained_dir=results/c10_ci2_class_cond_diffusion 97 | ``` 98 | 99 | **CIFAR-10/5** 100 | 101 | ```bash 102 | python -m scripts.image_train --wandb_experiment_name=c10_ci5_guide --wandb_project_name=project --wandb_entity=entity --batch_size=256 --dataset=CIFAR10 --num_tasks=5 --seed=0 --timestep_respacing=ddim50 --use_ddim=True --classifier_scale_min_new=0.5 --classifier_scale_max_new=0.5 --cl_method=generative_replay_disjoint_classifier_guidance --train_with_disjoint_classifier=True --use_old_grad=False --use_new_grad=True --guid_to_new_classes=True --embedding_kind=concat_time_1hot --classifier_init_lr=0.1 --classifier_lr=0.01 --disjoint_classifier_init_num_steps=5000 --disjoint_classifier_num_steps=2000 --classifier_augmentation=True --log_interval=200 --guid_generation_interval=5 --diffusion_pretrained_dir=results/c10_ci5_class_cond_diffusion 103 | ``` 104 | 105 | **CIFAR-100/5** 106 | 107 | ```bash 108 | python -m scripts.image_train --wandb_experiment_name=c100_ci5_guide --wandb_project_name=project --wandb_entity=entity --batch_size=256 --dataset=CIFAR100 --num_tasks=5 --seed=0 --timestep_respacing=ddim100 --use_ddim=True --classifier_scale_min_new=0.5 --classifier_scale_max_new=0.5 --cl_method=generative_replay_disjoint_classifier_guidance --train_with_disjoint_classifier=True --use_old_grad=False --use_new_grad=True --guid_to_new_classes=True --embedding_kind=concat_time_1hot --classifier_init_lr=0.1 --classifier_lr=0.05 --disjoint_classifier_init_num_steps=10000 
--disjoint_classifier_num_steps=2000 --classifier_augmentation=True --log_interval=200 --guid_generation_interval=10 --diffusion_pretrained_dir=results/c100_ci5_class_cond_diffusion 109 | ``` 110 | 111 | **CIFAR-100/10** 112 | 113 | ```bash 114 | python -m scripts.image_train --wandb_experiment_name=c100_ci10_guide --wandb_project_name=project --wandb_entity=entity --batch_size=256 --dataset=CIFAR100 --num_tasks=10 --seed=0 --timestep_respacing=ddim100 --use_ddim=True --classifier_scale_min_new=1.0 --classifier_scale_max_new=1.0 --cl_method=generative_replay_disjoint_classifier_guidance --train_with_disjoint_classifier=True --use_old_grad=False --use_new_grad=True --guid_to_new_classes=True --embedding_kind=concat_time_1hot --classifier_init_lr=0.1 --classifier_lr=0.05 --disjoint_classifier_init_num_steps=10000 --disjoint_classifier_num_steps=2000 --classifier_augmentation=True --log_interval=200 --guid_generation_interval=10 --diffusion_pretrained_dir=results/c100_ci10_class_cond_diffusion 115 | ``` 116 | 117 | **ImageNet100-64/5** 118 | 119 | ```bash 120 | python -m scripts.image_train --wandb_experiment_name=i100_ci5_guide --wandb_project_name=project --wandb_entity=entity --batch_size=100 --dataset=ImageNet100 --num_tasks=5 --seed=0 --timestep_respacing=ddim50 --use_ddim=True --classifier_scale_min_new=1.0 --classifier_scale_max_new=1.0 --cl_method=generative_replay_disjoint_classifier_guidance --train_with_disjoint_classifier=True --use_old_grad=False --use_new_grad=True --guid_to_new_classes=True --embedding_kind=concat_time_1hot --classifier_init_lr=0.1 --classifier_lr=0.001 --disjoint_classifier_init_num_steps=20000 --disjoint_classifier_num_steps=20000 --classifier_augmentation=False --log_interval=200 --guid_generation_interval=15 --attention_resolutions 32,16,8 --lr 1e-4 --resblock_updown True --use_new_attention_order True --use_scale_shift_norm True --num_channels 192 --num_head_channels 64 --diffusion_pretrained_dir=results/i100_ci5_class_cond_diffusion 121 | ``` 122 | 123 | 124 | ## BibTeX 125 | 126 | If you find this work useful, please consider citing it: 127 | 128 | ```bibtex 129 | @article{cywinski2024guide, 130 | title={GUIDE: Guidance-based Incremental Learning with Diffusion Models}, 131 | author={Cywi{\'n}ski, Bartosz and Deja, Kamil and Trzci{\'n}ski, Tomasz and Twardowski, Bart{\l}omiej and Kuci{\'n}ski, {\L}ukasz}, 132 | journal={arXiv preprint arXiv:2403.03938}, 133 | year={2024} 134 | } 135 | ``` 136 | 137 | ## Acknowledgments 138 | 139 | This codebase borrows from [OpenAI's guided diffusion repo](https://github.com/openai/guided-diffusion) and [Continual-Learning-Benchmark repo](https://github.com/GT-RIPL/Continual-Learning-Benchmark). 140 | -------------------------------------------------------------------------------- /scripts/image_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 
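
Example single-GPU invocation (the flags are illustrative; see README.md for
the full reproduction commands and guide/script_args.py for all available
options):

    python -m scripts.image_train --dataset=CIFAR10 --num_tasks=5 \
        --batch_size=256 --num_steps=50000 --cl_method=generative_replay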
3 | """ 4 | 5 | import sys 6 | 7 | sys.path.append(".") 8 | import argparse 9 | import copy 10 | import os 11 | import time 12 | from collections import OrderedDict 13 | 14 | import numpy as np 15 | import torch as th 16 | 17 | import wandb 18 | from cl_methods.utils import get_cl_method 19 | from dataloaders import base 20 | from dataloaders.datasetGen import * 21 | from dataloaders.utils import prepare_eval_loaders 22 | from guide import dist_util, logger 23 | from guide.logger import wandb_safe_log 24 | from guide.resample import create_named_schedule_sampler 25 | from guide.script_args import ( 26 | add_dict_to_argparser, 27 | all_training_defaults, 28 | args_to_dict, 29 | classifier_defaults, 30 | preprocess_args, 31 | ) 32 | from guide.script_util import ( 33 | create_model_and_diffusion, 34 | create_resnet_classifier, 35 | model_and_diffusion_defaults, 36 | results_to_log, 37 | ) 38 | from guide.train_util import TrainLoop 39 | from guide.validation import calculate_accuracy_with_classifier 40 | 41 | # os.environ["WANDB_MODE"] = "disabled" 42 | 43 | 44 | def main(): 45 | args = create_argparser().parse_args() 46 | run_training_with_args(args) 47 | 48 | 49 | def run_training_with_args(args): 50 | preprocess_args(args) 51 | 52 | dist_util.setup_dist(args) 53 | 54 | if logger.get_rank_without_mpi_import() == 0: 55 | if args.wandb_api_key: 56 | os.environ["WANDB_API_KEY"] = args.wandb_api_key 57 | wandb.init( 58 | project=args.wandb_project_name, 59 | name=args.wandb_experiment_name, 60 | config=args, 61 | entity=args.wandb_entity, 62 | ) 63 | 64 | args.seed = args.seed + logger.get_rank_without_mpi_import() 65 | random_generator = seed_everything(args.seed) 66 | os.environ["OPENAI_LOGDIR"] = f"results/{args.wandb_experiment_name}" 67 | os.makedirs(os.path.join(logger.get_dir(), "generated_examples"), exist_ok=True) 68 | logger.configure() 69 | logger.log("Using manual seed = {}".format(args.seed)) 70 | 71 | ( 72 | train_dataset, 73 | val_dataset, 74 | image_size, 75 | image_channels, 76 | train_transform_classifier, 77 | train_transform_diffusion, 78 | n_classes, 79 | ) = base.__dict__[args.dataset]( 80 | args.dataroot, 81 | train_aug=args.train_aug, 82 | skip_normalization=args.skip_normalization, 83 | classifier_augmentation=args.classifier_augmentation, 84 | ) 85 | 86 | args.image_size = image_size 87 | args.in_channels = image_channels 88 | args.model_num_classes = n_classes 89 | 90 | logger.log("creating model and diffusion...") 91 | model, diffusion = create_model_and_diffusion( 92 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 93 | ) 94 | if args.log_gradient_stats and not os.environ.get("WANDB_MODE") == "disabled": 95 | wandb.watch(model, log_freq=10) 96 | # if we are not training diffusion, we will not need this model 97 | if not args.train_with_disjoint_classifier: 98 | model.to(dist_util.dev()) 99 | 100 | classifier = None 101 | if args.train_with_disjoint_classifier: 102 | logger.log("creating disjoint classifier...") 103 | defaults = classifier_defaults() 104 | parser = argparse.ArgumentParser() 105 | add_dict_to_argparser(parser, defaults) 106 | opts = parser.parse_args([]) 107 | opts.model_num_classes = n_classes 108 | opts.in_channels = args.in_channels 109 | opts.depth = args.depth 110 | opts.noised = args.train_noised_classifier 111 | classifier = create_resnet_classifier(opts) 112 | classifier.to(dist_util.dev()) 113 | # Needed for creating correct EMAs and fp16 parameters. 
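        # sync_params broadcasts the parameters from rank 0 so that every rank
        # starts from the same classifier weights.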
114 | dist_util.sync_params(classifier.parameters()) 115 | 116 | if args.resume_checkpoint_classifier: 117 | classifier.load_state_dict( 118 | dist_util.load_state_dict( 119 | args.resume_checkpoint_classifier, map_location=dist_util.dev() 120 | ) 121 | ) 122 | logger.log( 123 | f"loading classifier from {args.resume_checkpoint_classifier}..." 124 | ) 125 | 126 | schedule_sampler = create_named_schedule_sampler( 127 | args.schedule_sampler, diffusion, args 128 | ) 129 | 130 | logger.log("creating data loaders...") 131 | train_dataset_splits, _, classes_per_task = data_split( 132 | dataset=train_dataset, 133 | return_classes=True, 134 | return_task_as_class=False, 135 | num_tasks=args.num_tasks, 136 | num_classes=n_classes, 137 | limit_classes=args.limit_classes, 138 | data_seed=args.data_seed, 139 | shared_classes=args.shared_classes, 140 | first_task_num_classes=args.first_task_num_classes, 141 | validation_frac=0.0, 142 | ) 143 | 144 | val_dataset_splits, _, classes_per_task = data_split( 145 | dataset=val_dataset, 146 | return_classes=True, 147 | return_task_as_class=False, 148 | num_tasks=args.num_tasks, 149 | num_classes=n_classes, 150 | limit_classes=args.limit_classes, 151 | data_seed=args.data_seed, 152 | shared_classes=args.shared_classes, 153 | first_task_num_classes=args.first_task_num_classes, 154 | validation_frac=0.0, 155 | ) 156 | 157 | train_loaders = [] 158 | validation_loaders = prepare_eval_loaders( 159 | train_dataset_splits=train_dataset_splits, 160 | val_dataset_splits=val_dataset_splits, 161 | args=args, 162 | include_train=False, 163 | generator=random_generator, 164 | ) 165 | test_acc_table = OrderedDict() 166 | train_acc_table = OrderedDict() 167 | 168 | train_loop = None 169 | cl_method = get_cl_method(args) 170 | global_step = 0 171 | dataset_yielder = None 172 | train_loader = None 173 | generated_previous_examples = None 174 | 175 | if args.limit_tasks != -1: 176 | n_tasks = args.limit_tasks 177 | else: 178 | n_tasks = args.num_tasks 179 | 180 | for task_id in range(n_tasks): 181 | if args.first_task_num_classes > 0 and task_id == 0: 182 | max_class = args.first_task_num_classes 183 | else: 184 | max_class = ( 185 | ((task_id + 1) * (n_classes // n_tasks)) - 1 186 | ) + args.first_task_num_classes 187 | 188 | if task_id == 0: 189 | if not args.train_with_disjoint_classifier: 190 | num_steps = args.first_task_num_steps 191 | else: 192 | num_steps = args.disjoint_classifier_init_num_steps 193 | else: 194 | if not args.train_with_disjoint_classifier: 195 | num_steps = args.num_steps 196 | else: 197 | num_steps = args.disjoint_classifier_num_steps 198 | 199 | train_loop = TrainLoop( 200 | params=args, 201 | model=model, 202 | prev_model=copy.deepcopy(model).to(dist_util.dev()), 203 | diffusion=diffusion, 204 | task_id=task_id, 205 | data=train_dataset_splits[task_id], 206 | data_yielder=None, 207 | data_loader=None, 208 | batch_size=args.batch_size, 209 | microbatch=args.microbatch, 210 | lr=args.lr, 211 | scheduler_rate=args.scheduler_rate, 212 | ema_rate=args.ema_rate, 213 | log_interval=args.log_interval, 214 | skip_save=args.skip_save, 215 | save_interval=args.save_interval, 216 | resume_checkpoint=( 217 | args.resume_checkpoint if task_id == args.first_task else None 218 | ), 219 | use_fp16=args.use_fp16, 220 | fp16_scale_growth=args.fp16_scale_growth, 221 | schedule_sampler=schedule_sampler, 222 | weight_decay=args.weight_decay, 223 | lr_anneal_steps=args.lr_anneal_steps, 224 | num_steps=num_steps, 225 | image_size=args.image_size, 226 | 
in_channels=args.in_channels, 227 | max_class=max_class, 228 | global_steps_before=global_step, 229 | cl_method=cl_method, 230 | classes_per_task=classes_per_task, 231 | use_ddim=args.use_ddim, 232 | classifier_scale_min_old=args.classifier_scale_min_old, 233 | classifier_scale_min_new=args.classifier_scale_min_new, 234 | classifier_scale_max_old=args.classifier_scale_max_old, 235 | classifier_scale_max_new=args.classifier_scale_max_new, 236 | guid_generation_interval=args.guid_generation_interval, 237 | use_old_grad=args.use_old_grad, 238 | use_new_grad=args.use_new_grad, 239 | guid_to_new_classes=args.guid_to_new_classes, 240 | trim_logits=args.trim_logits, 241 | disjoint_classifier=classifier, 242 | prev_disjoint_classifier=copy.deepcopy(classifier), 243 | diffusion_pretrained_dir=args.diffusion_pretrained_dir, 244 | train_transform_classifier=train_transform_classifier, 245 | train_transform_diffusion=train_transform_diffusion, 246 | n_classes=n_classes, 247 | random_generator=random_generator, 248 | classifier_first_task_dir=args.classifier_first_task_dir, 249 | train_noised_classifier=args.train_noised_classifier, 250 | ) 251 | 252 | if task_id >= args.first_task: 253 | ( 254 | dataset_yielder, 255 | train_loader, 256 | generated_previous_examples, 257 | ) = cl_method.get_data_for_task( 258 | dataset=train_dataset_splits[task_id], 259 | task_id=task_id, 260 | train_loop=train_loop, 261 | generator=random_generator, 262 | step=global_step, 263 | ) 264 | train_loaders.append(train_loader) 265 | logger.log(f"training task {task_id}") 266 | 267 | train_loop.data_yielder = dataset_yielder 268 | train_loop.data_loader = train_loader 269 | 270 | train_loop_start_time = time.time() 271 | if task_id >= args.first_task: 272 | train_loop.run_loop() 273 | global_step += num_steps 274 | train_loop_time = time.time() - train_loop_start_time 275 | wandb_safe_log({"train_loop_time": train_loop_time}, step=global_step) 276 | 277 | logger.log("validation...") 278 | test_acc_table[task_id] = OrderedDict() 279 | train_acc_table[task_id] = OrderedDict() 280 | validation_start_time = time.time() 281 | for j in range(task_id + 1): 282 | if args.train_with_disjoint_classifier: 283 | clf_results = calculate_accuracy_with_classifier( 284 | model=( 285 | model if not args.train_with_disjoint_classifier else classifier 286 | ), 287 | task_id=j, 288 | val_loader=validation_loaders[j], 289 | device=dist_util.dev(), 290 | train_loader=( 291 | train_loaders[j - args.first_task] 292 | if j >= args.first_task 293 | else None 294 | ), 295 | max_class=max_class, 296 | train_with_disjoint_classifier=args.train_with_disjoint_classifier, 297 | ) 298 | test_acc_table[j][task_id] = clf_results["accuracy"]["test"] 299 | train_acc_table[j][task_id] = clf_results["accuracy"]["train"] 300 | logger.log(f"Test accuracy task {j}: {test_acc_table[j][task_id]}") 301 | logger.log(f"Train accuracy task {j}: {train_acc_table[j][task_id]}") 302 | else: 303 | test_acc_table[j][task_id] = 0.0 304 | train_acc_table[j][task_id] = 0.0 305 | 306 | validation_time = time.time() - validation_start_time 307 | if logger.get_rank_without_mpi_import() == 0: 308 | results_to_log( 309 | test_acc_table, 310 | train_acc_table, 311 | validation_time=validation_time, 312 | step=global_step, 313 | task_id=task_id, 314 | ) 315 | if generated_previous_examples is not None: 316 | th.save( 317 | generated_previous_examples, 318 | os.path.join( 319 | logger.get_dir(), f"generated_examples/task_{task_id:02d}.pt" 320 | ), 321 | ) 322 | train_loop.prev_ddp_model 
= copy.deepcopy(model) 323 | logger.log("TEST ACCURACY TABLE:") 324 | logger.log(test_acc_table) 325 | logger.log("TRAIN ACCURACY TABLE:") 326 | logger.log(train_acc_table) 327 | 328 | 329 | def seed_everything(seed): 330 | th.manual_seed(seed) 331 | np.random.seed(seed) 332 | th.cuda.manual_seed(seed) 333 | th.backends.cudnn.deterministic = True 334 | th.backends.cudnn.benchmark = False 335 | random_generator = th.Generator() 336 | random_generator.manual_seed(seed) 337 | return random_generator 338 | 339 | 340 | def create_argparser(): 341 | defaults = all_training_defaults() 342 | parser = argparse.ArgumentParser() 343 | add_dict_to_argparser(parser, defaults) 344 | return parser 345 | 346 | 347 | if __name__ == "__main__": 348 | main() 349 | -------------------------------------------------------------------------------- /guide/script_util.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | import matplotlib 4 | import numpy as np 5 | 6 | from . import gaussian_diffusion as gd 7 | from .logger import wandb_safe_log 8 | from .resnet import ResNet 9 | from .respace import SpacedDiffusion, space_timesteps 10 | from .script_args import model_and_diffusion_defaults 11 | from .unet import EncoderUNetModel, SuperResModel 12 | 13 | # NUM_CLASSES = 1000 14 | 15 | 16 | def create_model_and_diffusion( 17 | image_size, 18 | in_channels, 19 | learn_sigma, 20 | sigma_small, 21 | num_channels, 22 | num_res_blocks, 23 | channel_mult, 24 | num_heads, 25 | num_head_channels, 26 | num_heads_upsample, 27 | attention_resolutions, 28 | dropout, 29 | diffusion_steps, 30 | noise_schedule, 31 | timestep_respacing, 32 | use_kl, 33 | predict_xstart, 34 | rescale_timesteps, 35 | rescale_learned_sigmas, 36 | use_checkpoint, 37 | use_scale_shift_norm, 38 | resblock_updown, 39 | use_fp16, 40 | use_new_attention_order, 41 | model_name, 42 | model_switching_timestep, 43 | embedding_kind, 44 | model_num_classes=None, 45 | noise_marg_reg=False, 46 | train_noised_classifier=False, 47 | classifier_augmentation=True, 48 | ): 49 | model = create_model( 50 | image_size, 51 | in_channels, 52 | num_channels, 53 | num_res_blocks, 54 | model_name=model_name, 55 | model_switching_timestep=model_switching_timestep, 56 | embedding_kind=embedding_kind, 57 | channel_mult=channel_mult, 58 | learn_sigma=learn_sigma, 59 | use_checkpoint=use_checkpoint, 60 | attention_resolutions=attention_resolutions, 61 | num_heads=num_heads, 62 | num_head_channels=num_head_channels, 63 | num_heads_upsample=num_heads_upsample, 64 | use_scale_shift_norm=use_scale_shift_norm, 65 | dropout=dropout, 66 | resblock_updown=resblock_updown, 67 | use_fp16=use_fp16, 68 | use_new_attention_order=use_new_attention_order, 69 | num_classes=model_num_classes, 70 | classifier_augmentation=classifier_augmentation, 71 | ) 72 | diffusion = create_gaussian_diffusion( 73 | steps=diffusion_steps, 74 | learn_sigma=learn_sigma, 75 | sigma_small=sigma_small, 76 | noise_schedule=noise_schedule, 77 | use_kl=use_kl, 78 | predict_xstart=predict_xstart, 79 | rescale_timesteps=rescale_timesteps, 80 | rescale_learned_sigmas=rescale_learned_sigmas, 81 | timestep_respacing=timestep_respacing, 82 | noise_marg_reg=noise_marg_reg, 83 | train_noised_classifier=train_noised_classifier, 84 | ) 85 | 86 | return model, diffusion 87 | 88 | 89 | def create_model( 90 | image_size, 91 | in_channels, 92 | num_channels, 93 | num_res_blocks, 94 | model_name, 95 | model_switching_timestep, 96 | embedding_kind, 97 | channel_mult="", 98 | 
learn_sigma=False, 99 | use_checkpoint=False, 100 | attention_resolutions="16", 101 | num_heads=1, 102 | num_head_channels=-1, 103 | num_heads_upsample=-1, 104 | use_scale_shift_norm=False, 105 | dropout=0, 106 | resblock_updown=False, 107 | use_fp16=False, 108 | use_new_attention_order=False, 109 | num_classes=None, 110 | classifier_augmentation=False, 111 | ): 112 | if channel_mult == "": 113 | if image_size == 512: 114 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 115 | elif image_size == 256: 116 | channel_mult = (1, 1, 2, 2, 4, 4) 117 | elif image_size in {128, 224}: 118 | channel_mult = (1, 1, 2, 3, 4) 119 | elif image_size == 64: 120 | channel_mult = (1, 2, 3, 4) 121 | elif image_size == 32: 122 | channel_mult = (1, 2, 2, 2) 123 | elif image_size == 28: 124 | channel_mult = (1, 2, 2) 125 | elif image_size == 1: 126 | channel_mult = (1, 1, 1) 127 | else: 128 | raise ValueError(f"unsupported image size: {image_size}") 129 | else: 130 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 131 | 132 | attention_ds = [] 133 | for res in attention_resolutions.split(","): 134 | attention_ds.append(image_size // int(res)) 135 | 136 | if model_name == "UNetModel": 137 | print("Using single model") 138 | from .unet import UNetModel as Model 139 | elif model_name == "MLPModel": 140 | from .mlp import MLPModel as Model 141 | else: 142 | raise NotImplementedError 143 | return Model( 144 | image_size=image_size, 145 | in_channels=in_channels, 146 | model_channels=num_channels, 147 | out_channels=(in_channels if not learn_sigma else in_channels * 2), 148 | num_res_blocks=num_res_blocks, 149 | attention_resolutions=tuple(attention_ds), 150 | dropout=dropout, 151 | channel_mult=channel_mult, 152 | num_classes=num_classes, 153 | use_checkpoint=use_checkpoint, 154 | use_fp16=use_fp16, 155 | num_heads=num_heads, 156 | num_head_channels=num_head_channels, 157 | num_heads_upsample=num_heads_upsample, 158 | use_scale_shift_norm=use_scale_shift_norm, 159 | resblock_updown=resblock_updown, 160 | use_new_attention_order=use_new_attention_order, 161 | model_switching_timestep=model_switching_timestep, 162 | embedding_kind=embedding_kind, 163 | classifier_augmentation=classifier_augmentation, 164 | ) 165 | 166 | 167 | def create_classifier_and_diffusion( 168 | image_size, 169 | classifier_use_fp16, 170 | classifier_width, 171 | classifier_depth, 172 | classifier_attention_resolutions, 173 | classifier_use_scale_shift_norm, 174 | classifier_resblock_updown, 175 | classifier_pool, 176 | learn_sigma, 177 | sigma_small, 178 | diffusion_steps, 179 | noise_schedule, 180 | timestep_respacing, 181 | use_kl, 182 | predict_xstart, 183 | rescale_timesteps, 184 | rescale_learned_sigmas, 185 | ): 186 | classifier = create_classifier( 187 | image_size, 188 | classifier_use_fp16, 189 | classifier_width, 190 | classifier_depth, 191 | classifier_attention_resolutions, 192 | classifier_use_scale_shift_norm, 193 | classifier_resblock_updown, 194 | classifier_pool, 195 | ) 196 | diffusion = create_gaussian_diffusion( 197 | steps=diffusion_steps, 198 | learn_sigma=learn_sigma, 199 | sigma_small=sigma_small, 200 | noise_schedule=noise_schedule, 201 | use_kl=use_kl, 202 | predict_xstart=predict_xstart, 203 | rescale_timesteps=rescale_timesteps, 204 | rescale_learned_sigmas=rescale_learned_sigmas, 205 | timestep_respacing=timestep_respacing, 206 | ) 207 | return classifier, diffusion 208 | 209 | 210 | def create_classifier( 211 | image_size, 212 | classifier_use_fp16, 213 | classifier_width, 214 | classifier_depth, 
215 | classifier_attention_resolutions, 216 | classifier_use_scale_shift_norm, 217 | classifier_resblock_updown, 218 | classifier_pool, 219 | ): 220 | if image_size == 512: 221 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 222 | elif image_size == 256: 223 | channel_mult = (1, 1, 2, 2, 4, 4) 224 | elif image_size in {128, 224}: 225 | channel_mult = (1, 1, 2, 3, 4) 226 | elif image_size == 64: 227 | channel_mult = (1, 2, 3, 4) 228 | elif image_size == 32: 229 | channel_mult = (1, 2, 2, 2) 230 | else: 231 | raise ValueError(f"unsupported image size: {image_size}") 232 | 233 | attention_ds = [] 234 | for res in classifier_attention_resolutions.split(","): 235 | attention_ds.append(image_size // int(res)) 236 | 237 | return EncoderUNetModel( 238 | image_size=image_size, 239 | in_channels=3, 240 | model_channels=classifier_width, 241 | out_channels=1000, 242 | num_res_blocks=classifier_depth, 243 | attention_resolutions=tuple(attention_ds), 244 | channel_mult=channel_mult, 245 | use_fp16=classifier_use_fp16, 246 | num_head_channels=64, 247 | use_scale_shift_norm=classifier_use_scale_shift_norm, 248 | resblock_updown=classifier_resblock_updown, 249 | pool=classifier_pool, 250 | ) 251 | 252 | 253 | def create_resnet_classifier(args): 254 | classifier = ResNet(args) 255 | return classifier 256 | 257 | 258 | def sr_model_and_diffusion_defaults(): 259 | res = model_and_diffusion_defaults() 260 | res["large_size"] = 256 261 | res["small_size"] = 64 262 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 263 | for k in res.copy().keys(): 264 | if k not in arg_names: 265 | del res[k] 266 | return res 267 | 268 | 269 | def sr_create_model_and_diffusion( 270 | large_size, 271 | small_size, 272 | learn_sigma, 273 | sigma_small, 274 | num_channels, 275 | num_res_blocks, 276 | num_heads, 277 | num_head_channels, 278 | num_heads_upsample, 279 | attention_resolutions, 280 | dropout, 281 | diffusion_steps, 282 | noise_schedule, 283 | timestep_respacing, 284 | use_kl, 285 | predict_xstart, 286 | rescale_timesteps, 287 | rescale_learned_sigmas, 288 | use_checkpoint, 289 | use_scale_shift_norm, 290 | resblock_updown, 291 | use_fp16, 292 | ): 293 | model = sr_create_model( 294 | large_size, 295 | small_size, 296 | num_channels, 297 | num_res_blocks, 298 | learn_sigma=learn_sigma, 299 | use_checkpoint=use_checkpoint, 300 | attention_resolutions=attention_resolutions, 301 | num_heads=num_heads, 302 | num_head_channels=num_head_channels, 303 | num_heads_upsample=num_heads_upsample, 304 | use_scale_shift_norm=use_scale_shift_norm, 305 | dropout=dropout, 306 | resblock_updown=resblock_updown, 307 | use_fp16=use_fp16, 308 | ) 309 | diffusion = create_gaussian_diffusion( 310 | steps=diffusion_steps, 311 | learn_sigma=learn_sigma, 312 | sigma_small=sigma_small, 313 | noise_schedule=noise_schedule, 314 | use_kl=use_kl, 315 | predict_xstart=predict_xstart, 316 | rescale_timesteps=rescale_timesteps, 317 | rescale_learned_sigmas=rescale_learned_sigmas, 318 | timestep_respacing=timestep_respacing, 319 | ) 320 | return model, diffusion 321 | 322 | 323 | def sr_create_model( 324 | large_size, 325 | small_size, 326 | num_channels, 327 | num_res_blocks, 328 | learn_sigma, 329 | use_checkpoint, 330 | attention_resolutions, 331 | num_heads, 332 | num_head_channels, 333 | num_heads_upsample, 334 | use_scale_shift_norm, 335 | dropout, 336 | resblock_updown, 337 | use_fp16, 338 | ): 339 | _ = small_size # hack to prevent unused variable 340 | 341 | if large_size == 512: 342 | channel_mult = (1, 1, 2, 2, 4, 4) 343 | elif 
large_size == 256: 344 | channel_mult = (1, 1, 2, 2, 4, 4) 345 | elif large_size == 64: 346 | channel_mult = (1, 2, 3, 4) 347 | else: 348 | raise ValueError(f"unsupported large size: {large_size}") 349 | 350 | attention_ds = [] 351 | for res in attention_resolutions.split(","): 352 | attention_ds.append(large_size // int(res)) 353 | 354 | return SuperResModel( 355 | image_size=large_size, 356 | in_channels=3, 357 | model_channels=num_channels, 358 | out_channels=(3 if not learn_sigma else 6), 359 | num_res_blocks=num_res_blocks, 360 | attention_resolutions=tuple(attention_ds), 361 | dropout=dropout, 362 | channel_mult=channel_mult, 363 | use_checkpoint=use_checkpoint, 364 | num_heads=num_heads, 365 | num_head_channels=num_head_channels, 366 | num_heads_upsample=num_heads_upsample, 367 | use_scale_shift_norm=use_scale_shift_norm, 368 | resblock_updown=resblock_updown, 369 | use_fp16=use_fp16, 370 | ) 371 | 372 | 373 | def create_gaussian_diffusion( 374 | *, 375 | steps=1000, 376 | learn_sigma=False, 377 | sigma_small=False, 378 | noise_schedule="linear", 379 | use_kl=False, 380 | predict_xstart=False, 381 | predict_xprevious=False, 382 | rescale_timesteps=False, 383 | rescale_learned_sigmas=False, 384 | timestep_respacing="", 385 | noise_marg_reg=False, 386 | train_noised_classifier=False, 387 | ): 388 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 389 | if use_kl: 390 | loss_type = gd.LossType.RESCALED_KL 391 | elif rescale_learned_sigmas: 392 | loss_type = gd.LossType.RESCALED_MSE 393 | else: 394 | loss_type = gd.LossType.MSE 395 | if not timestep_respacing: 396 | timestep_respacing = [steps] 397 | 398 | if predict_xstart: 399 | model_mean_type = gd.ModelMeanType.START_X 400 | elif predict_xprevious: 401 | model_mean_type = gd.ModelMeanType.PREVIOUS_X 402 | else: 403 | model_mean_type = gd.ModelMeanType.EPSILON 404 | 405 | return SpacedDiffusion( 406 | use_timesteps=space_timesteps(steps, timestep_respacing), 407 | betas=betas, 408 | model_mean_type=model_mean_type, 409 | model_var_type=( 410 | ( 411 | gd.ModelVarType.FIXED_LARGE 412 | if not sigma_small 413 | else gd.ModelVarType.FIXED_SMALL 414 | ) 415 | if not learn_sigma 416 | else gd.ModelVarType.LEARNED_RANGE 417 | ), 418 | loss_type=loss_type, 419 | rescale_timesteps=rescale_timesteps, 420 | noise_marg_reg=noise_marg_reg, 421 | train_noised_classifier=train_noised_classifier, 422 | ) 423 | 424 | 425 | def dict2array(results): 426 | tasks = len(results[0]) 427 | array = np.zeros((tasks, tasks)) 428 | for e, (key, val) in enumerate(reversed(results.items())): 429 | for e1, (k, v) in enumerate(reversed(val.items())): 430 | array[tasks - int(e1) - 1, tasks - int(e) - 1] = round(v, 3) 431 | return np.transpose(array, axes=(1, 0)) 432 | 433 | 434 | def grid_plot(ax, array, type): 435 | if type == "fid": 436 | round = 1 437 | else: 438 | round = 2 439 | avg_array = np.around(array, round) 440 | num_tasks = array.shape[1] 441 | cmap = matplotlib.colors.LinearSegmentedColormap.from_list( 442 | "", ["#287233", "#4c1c24"] 443 | ) 444 | ax.imshow(avg_array, vmin=50, vmax=300, cmap=cmap) 445 | for i in range(len(avg_array)): 446 | for j in range(avg_array.shape[1]): 447 | if j >= i: 448 | ax.text( 449 | j, 450 | i, 451 | avg_array[i, j], 452 | va="center", 453 | ha="center", 454 | c="w", 455 | fontsize=70 / num_tasks, 456 | ) 457 | ax.set_yticks(np.arange(num_tasks)) 458 | ax.set_ylabel("Number of tasks") 459 | ax.set_xticks(np.arange(num_tasks)) 460 | ax.set_xlabel("Tasks finished") 461 | ax.set_title( 462 | f"{type} -- 
{np.round(np.mean(array[:, -1]), 3)} -- std {np.round(np.std(array[:, -1]), 2)}" 463 | ) 464 | 465 | 466 | def results_to_log( 467 | test_acc_table, 468 | train_acc_table, 469 | validation_time, 470 | step, 471 | task_id, 472 | ): 473 | log_dict = {"validation_time": validation_time} 474 | avg_acc_train = 0.0 475 | avg_acc_test = 0.0 476 | avg_forgetting_train = 0.0 477 | avg_forgetting_test = 0.0 478 | for j in range(task_id + 1): 479 | log_dict.update( 480 | { 481 | f"test/accuracy/{j}": test_acc_table[j][task_id], 482 | f"train/accuracy/{j}": train_acc_table[j][task_id], 483 | } 484 | ) 485 | avg_acc_train += train_acc_table[j][task_id] 486 | avg_acc_test += test_acc_table[j][task_id] 487 | 488 | max_forgetting_train = 0.0 if (j == task_id) else -float("inf") 489 | max_forgetting_test = 0.0 if (j == task_id) else -float("inf") 490 | for k in range(j, task_id): 491 | max_forgetting_train = max( 492 | max_forgetting_train, 493 | train_acc_table[j][k] - train_acc_table[j][task_id], 494 | ) 495 | max_forgetting_test = max( 496 | max_forgetting_test, test_acc_table[j][k] - test_acc_table[j][task_id] 497 | ) 498 | avg_forgetting_train += max_forgetting_train 499 | avg_forgetting_test += max_forgetting_test 500 | 501 | log_dict.update( 502 | { 503 | "test/avg_accuracy": avg_acc_test / (task_id + 1), 504 | "train/avg_accuracy": avg_acc_train / (task_id + 1), 505 | "train/avg_forgetting": ( 506 | (avg_forgetting_train / task_id) if task_id > 0 else 0.0 507 | ), 508 | "test/avg_forgetting": ( 509 | (avg_forgetting_test / task_id) if task_id > 0 else 0.0 510 | ), 511 | } 512 | ) 513 | 514 | wandb_safe_log(log_dict, step=step) 515 | 516 | # return PIL.Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb()) 517 | -------------------------------------------------------------------------------- /guide/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import datetime 7 | import json 8 | import os 9 | import os.path as osp 10 | import sys 11 | import tempfile 12 | import time 13 | import warnings 14 | from collections import defaultdict 15 | from contextlib import contextmanager 16 | 17 | import numpy as np 18 | import torch as th 19 | from matplotlib import pyplot as plt 20 | from torchvision.utils import make_grid 21 | 22 | import wandb 23 | 24 | DEBUG = 10 25 | INFO = 20 26 | WARN = 30 27 | ERROR = 40 28 | 29 | DISABLED = 50 30 | 31 | 32 | class KVWriter(object): 33 | def writekvs(self, kvs): 34 | raise NotImplementedError 35 | 36 | 37 | class SeqWriter(object): 38 | def writeseq(self, seq): 39 | raise NotImplementedError 40 | 41 | 42 | class HumanOutputFormat(KVWriter, SeqWriter): 43 | def __init__(self, filename_or_file): 44 | if isinstance(filename_or_file, str): 45 | self.file = open(filename_or_file, "wt") 46 | self.own_file = True 47 | else: 48 | assert hasattr(filename_or_file, "read"), ( 49 | "expected file or str, got %s" % filename_or_file 50 | ) 51 | self.file = filename_or_file 52 | self.own_file = False 53 | 54 | def writekvs(self, kvs): 55 | # Create strings for printing 56 | key2str = {} 57 | for key, val in sorted(kvs.items()): 58 | if hasattr(val, "__float__"): 59 | valstr = "%-8.3g" % val 60 | else: 61 | valstr = str(val) 62 | key2str[self._truncate(key)] = self._truncate(valstr) 63 | 64 | # Find max widths 65 | 
if len(key2str) == 0: 66 | print("WARNING: tried to write empty key-value dict") 67 | return 68 | else: 69 | keywidth = max(map(len, key2str.keys())) 70 | valwidth = max(map(len, key2str.values())) 71 | 72 | # Write out the data 73 | dashes = "-" * (keywidth + valwidth + 7) 74 | lines = [dashes] 75 | for key, val in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 76 | lines.append( 77 | "| %s%s | %s%s |" 78 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 79 | ) 80 | lines.append(dashes) 81 | self.file.write("\n".join(lines) + "\n") 82 | 83 | # Flush the output to the file 84 | self.file.flush() 85 | 86 | def _truncate(self, s): 87 | maxlen = 30 88 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s 89 | 90 | def writeseq(self, seq): 91 | seq = list(seq) 92 | for i, elem in enumerate(seq): 93 | self.file.write(elem) 94 | if i < len(seq) - 1: # add space unless this is the last one 95 | self.file.write(" ") 96 | self.file.write("\n") 97 | self.file.flush() 98 | 99 | def close(self): 100 | if self.own_file: 101 | self.file.close() 102 | 103 | 104 | class JSONOutputFormat(KVWriter): 105 | def __init__(self, filename): 106 | self.file = open(filename, "wt") 107 | 108 | def writekvs(self, kvs): 109 | for k, v in sorted(kvs.items()): 110 | if hasattr(v, "dtype"): 111 | kvs[k] = float(v) 112 | self.file.write(json.dumps(kvs) + "\n") 113 | self.file.flush() 114 | 115 | def close(self): 116 | self.file.close() 117 | 118 | 119 | class CSVOutputFormat(KVWriter): 120 | def __init__(self, filename): 121 | self.file = open(filename, "w+t") 122 | self.keys = [] 123 | self.sep = "," 124 | 125 | def writekvs(self, kvs): 126 | # Add our current row to the history 127 | extra_keys = list(kvs.keys() - self.keys) 128 | extra_keys.sort() 129 | if extra_keys: 130 | self.keys.extend(extra_keys) 131 | self.file.seek(0) 132 | lines = self.file.readlines() 133 | self.file.seek(0) 134 | for i, k in enumerate(self.keys): 135 | if i > 0: 136 | self.file.write(",") 137 | self.file.write(k) 138 | self.file.write("\n") 139 | for line in lines[1:]: 140 | self.file.write(line[:-1]) 141 | self.file.write(self.sep * len(extra_keys)) 142 | self.file.write("\n") 143 | for i, k in enumerate(self.keys): 144 | if i > 0: 145 | self.file.write(",") 146 | v = kvs.get(k) 147 | if v is not None: 148 | self.file.write(str(v)) 149 | self.file.write("\n") 150 | self.file.flush() 151 | 152 | def close(self): 153 | self.file.close() 154 | 155 | 156 | class TensorBoardOutputFormat(KVWriter): 157 | """ 158 | Dumps key/value pairs into TensorBoard's numeric format. 
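    Events are written with TensorFlow's low-level EventsWriter and an
    internally incremented step counter, so the resulting directory can be
    inspected with the usual `tensorboard --logdir <dir>` command.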
159 | """ 160 | 161 | def __init__(self, dir): 162 | os.makedirs(dir, exist_ok=True) 163 | self.dir = dir 164 | self.step = 1 165 | prefix = "events" 166 | path = osp.join(osp.abspath(dir), prefix) 167 | import tensorflow as tf 168 | from tensorflow.core.util import event_pb2 169 | from tensorflow.python import pywrap_tensorflow 170 | from tensorflow.python.util import compat 171 | 172 | self.tf = tf 173 | self.event_pb2 = event_pb2 174 | self.pywrap_tensorflow = pywrap_tensorflow 175 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 176 | 177 | def writekvs(self, kvs): 178 | def summary_val(k, v): 179 | kwargs = {"tag": k, "simple_value": float(v)} 180 | return self.tf.Summary.Value(**kwargs) 181 | 182 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 183 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 184 | event.step = ( 185 | self.step 186 | ) # is there any reason why you'd want to specify the step? 187 | self.writer.WriteEvent(event) 188 | self.writer.Flush() 189 | self.step += 1 190 | 191 | def close(self): 192 | if self.writer: 193 | self.writer.Close() 194 | self.writer = None 195 | 196 | 197 | def make_output_format(format, ev_dir, log_suffix=""): 198 | os.makedirs(ev_dir, exist_ok=True) 199 | if format == "stdout": 200 | return HumanOutputFormat(sys.stdout) 201 | elif format == "log": 202 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 203 | elif format == "json": 204 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 205 | elif format == "csv": 206 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 207 | elif format == "tensorboard": 208 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 209 | else: 210 | raise ValueError("Unknown format specified: %s" % (format,)) 211 | 212 | 213 | # ================================================================ 214 | # API 215 | # ================================================================ 216 | 217 | 218 | def logkv(key, val): 219 | """ 220 | Log a value of some diagnostic 221 | Call this once for each diagnostic quantity, each iteration 222 | If called many times, last value will be used. 223 | """ 224 | get_current().logkv(key, val) 225 | 226 | 227 | def logkv_mean(key, val): 228 | """ 229 | The same as logkv(), but if called many times, values averaged. 230 | """ 231 | get_current().logkv_mean(key, val) 232 | 233 | 234 | def logkvs(d): 235 | """ 236 | Log a dictionary of key-value pairs 237 | """ 238 | for k, v in d.items(): 239 | logkv(k, v) 240 | 241 | 242 | def dumpkvs(): 243 | """ 244 | Write all of the diagnostics from the current iteration 245 | """ 246 | return get_current().dumpkvs() 247 | 248 | 249 | def getkvs(): 250 | return get_current().name2val 251 | 252 | 253 | def log(*args, level=INFO): 254 | """ 255 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 256 | """ 257 | get_current().log(*args, level=level) 258 | 259 | 260 | def debug(*args): 261 | log(*args, level=DEBUG) 262 | 263 | 264 | def info(*args): 265 | log(*args, level=INFO) 266 | 267 | 268 | def warn(*args): 269 | log(*args, level=WARN) 270 | 271 | 272 | def error(*args): 273 | log(*args, level=ERROR) 274 | 275 | 276 | def set_level(level): 277 | """ 278 | Set logging threshold on current logger. 
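    Messages logged with a level below this threshold (e.g. debug() output
    when the level is INFO) are silently dropped.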
279 | """ 280 | get_current().set_level(level) 281 | 282 | 283 | def set_comm(comm): 284 | get_current().set_comm(comm) 285 | 286 | 287 | def get_dir(): 288 | """ 289 | Get directory that log files are being written to. 290 | will be None if there is no output directory (i.e., if you didn't call start) 291 | """ 292 | return get_current().get_dir() 293 | 294 | 295 | record_tabular = logkv 296 | dump_tabular = dumpkvs 297 | 298 | 299 | @contextmanager 300 | def profile_kv(scopename): 301 | logkey = "wait_" + scopename 302 | tstart = time.time() 303 | try: 304 | yield 305 | finally: 306 | get_current().name2val[logkey] += time.time() - tstart 307 | 308 | 309 | def profile(n): 310 | """ 311 | Usage: 312 | @profile("my_func") 313 | def my_func(): code 314 | """ 315 | 316 | def decorator_with_name(func): 317 | def func_wrapper(*args, **kwargs): 318 | with profile_kv(n): 319 | return func(*args, **kwargs) 320 | 321 | return func_wrapper 322 | 323 | return decorator_with_name 324 | 325 | 326 | # ================================================================ 327 | # Backend 328 | # ================================================================ 329 | 330 | 331 | def get_current(): 332 | if Logger.CURRENT is None: 333 | _configure_default_logger() 334 | 335 | return Logger.CURRENT 336 | 337 | 338 | class Logger(object): 339 | DEFAULT = None # A logger with no output files. (See right below class definition) 340 | # So that you can still log to the terminal without setting up any output files 341 | CURRENT = None # Current logger being used by the free functions above 342 | 343 | def __init__(self, dir, output_formats, comm=None): 344 | self.name2val = defaultdict(float) # values this iteration 345 | self.name2cnt = defaultdict(int) 346 | self.level = INFO 347 | self.dir = dir 348 | self.output_formats = output_formats 349 | self.comm = comm 350 | 351 | # Logging API, forwarded 352 | # ---------------------------------------- 353 | def logkv(self, key, val): 354 | self.name2val[key] = val 355 | 356 | def logkv_mean(self, key, val): 357 | oldval, cnt = self.name2val[key], self.name2cnt[key] 358 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 359 | self.name2cnt[key] = cnt + 1 360 | 361 | def dumpkvs(self): 362 | if self.comm is None: 363 | d = self.name2val 364 | else: 365 | d = mpi_weighted_mean( 366 | self.comm, 367 | { 368 | name: (val, self.name2cnt.get(name, 1)) 369 | for (name, val) in self.name2val.items() 370 | }, 371 | ) 372 | if self.comm.rank != 0: 373 | d["dummy"] = 1 # so we don't get a warning about empty dict 374 | out = d.copy() # Return the dict for unit testing purposes 375 | for fmt in self.output_formats: 376 | if isinstance(fmt, KVWriter): 377 | fmt.writekvs(d) 378 | self.name2val.clear() 379 | self.name2cnt.clear() 380 | return out 381 | 382 | def log(self, *args, level=INFO): 383 | if self.level <= level: 384 | self._do_log(args) 385 | 386 | # Configuration 387 | # ---------------------------------------- 388 | def set_level(self, level): 389 | self.level = level 390 | 391 | def set_comm(self, comm): 392 | self.comm = comm 393 | 394 | def get_dir(self): 395 | return self.dir 396 | 397 | def close(self): 398 | for fmt in self.output_formats: 399 | fmt.close() 400 | 401 | # Misc 402 | # ---------------------------------------- 403 | def _do_log(self, args): 404 | for fmt in self.output_formats: 405 | if isinstance(fmt, SeqWriter): 406 | fmt.writeseq(map(str, args)) 407 | 408 | 409 | def get_rank_without_mpi_import(): 410 | # check environment variables here 
instead of importing mpi4py 411 | # to avoid calling MPI_Init() when this module is imported 412 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 413 | if varname in os.environ: 414 | return int(os.environ[varname]) 415 | return 0 416 | 417 | 418 | def mpi_weighted_mean(comm, local_name2valcount): 419 | """ 420 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 421 | Perform a weighted average over dicts that are each on a different node 422 | Input: local_name2valcount: dict mapping key -> (value, count) 423 | Returns: key -> mean 424 | """ 425 | all_name2valcount = comm.gather(local_name2valcount) 426 | if comm.rank == 0: 427 | name2sum = defaultdict(float) 428 | name2count = defaultdict(float) 429 | for n2vc in all_name2valcount: 430 | for name, (val, count) in n2vc.items(): 431 | try: 432 | val = float(val) 433 | except ValueError: 434 | if comm.rank == 0: 435 | warnings.warn( 436 | "WARNING: tried to compute mean on non-float {}={}".format( 437 | name, val 438 | ) 439 | ) 440 | else: 441 | name2sum[name] += val * count 442 | name2count[name] += count 443 | return {name: name2sum[name] / name2count[name] for name in name2sum} 444 | else: 445 | return {} 446 | 447 | 448 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 449 | """ 450 | If comm is provided, average all numerical stats across that comm 451 | """ 452 | if dir is None: 453 | dir = os.getenv("OPENAI_LOGDIR") 454 | if dir is None: 455 | dir = osp.join( 456 | tempfile.gettempdir(), 457 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 458 | ) 459 | assert isinstance(dir, str) 460 | dir = os.path.expanduser(dir) 461 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 462 | 463 | rank = get_rank_without_mpi_import() 464 | if rank > 0: 465 | log_suffix = log_suffix + "-rank%03i" % rank 466 | 467 | if format_strs is None: 468 | if rank == 0: 469 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 470 | else: 471 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 472 | format_strs = filter(None, format_strs) 473 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 474 | 475 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 476 | if output_formats: 477 | log("Logging to %s" % dir) 478 | 479 | 480 | def _configure_default_logger(): 481 | configure() 482 | Logger.DEFAULT = Logger.CURRENT 483 | 484 | 485 | def reset(): 486 | if Logger.CURRENT is not Logger.DEFAULT: 487 | Logger.CURRENT.close() 488 | Logger.CURRENT = Logger.DEFAULT 489 | log("Reset logger") 490 | 491 | 492 | @contextmanager 493 | def scoped_configure(dir=None, format_strs=None, comm=None): 494 | prevlogger = Logger.CURRENT 495 | configure(dir=dir, format_strs=format_strs, comm=comm) 496 | try: 497 | yield 498 | finally: 499 | Logger.CURRENT.close() 500 | Logger.CURRENT = prevlogger 501 | 502 | 503 | def wandb_safe_log(*args, **kwargs): 504 | # Try several times. 
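    # Up to 10 attempts with a 20-second pause after each failure; note that
    # the bare `except:` below also swallows KeyboardInterrupt, so narrowing
    # it to `except Exception:` would be the safer choice.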
505 | for _ in range(10): 506 | try: 507 | wandb.log(*args, **kwargs) 508 | except: 509 | time.sleep(20) 510 | else: 511 | break 512 | 513 | 514 | def log_generated_examples( 515 | examples, labels, confidences, task_id, step, n_examples_to_log=4 516 | ): 517 | fig, ax = plt.subplots() 518 | _, _, bars = ax.hist( 519 | [c.item() for c in labels], 520 | bins=np.arange(labels.max().item() + 2) - 0.5, 521 | edgecolor="black", 522 | ) 523 | ax.set_xticks(range(int(labels.max().item()) + 1)) 524 | ax.set_xlabel("Generated class") 525 | ax.set_ylabel("Count") 526 | ax.bar_label(bars) 527 | ax.set_title(f"generated_classes/{task_id}") 528 | wandb.log({f"generated_classes/{task_id}": wandb.Image(fig)}, step=step) 529 | 530 | if confidences is not None: 531 | fig, ax = plt.subplots() 532 | _, _, bars = ax.hist( 533 | [c.item() for c in confidences], 534 | bins=np.linspace(0, 1, 6), 535 | edgecolor="black", 536 | ) 537 | ax.set_xlabel("Confidence") 538 | ax.set_ylabel("Count") 539 | ax.bar_label(bars) 540 | ax.set_title(f"generated_examples_confidences/{task_id}") 541 | wandb.log( 542 | {f"generated_examples_confidences/{task_id}": wandb.Image(fig)}, step=step 543 | ) 544 | 545 | # Additionaly plot the generated examples 546 | unique_generated_classes = labels.unique() 547 | min_n_of_samples_per_class = n_examples_to_log 548 | images_to_log = [] 549 | for class_label in unique_generated_classes: 550 | class_indices = (labels == class_label).nonzero(as_tuple=True)[0] 551 | selected_indices = th.randperm(len(class_indices))[:min_n_of_samples_per_class] 552 | if len(class_indices) < min_n_of_samples_per_class: 553 | min_n_of_samples_per_class = len(class_indices) 554 | selected_images = examples[class_indices[selected_indices]].cpu() 555 | images_to_log.append(selected_images) 556 | if min_n_of_samples_per_class > 0: 557 | # assure the equal number of samples per class 558 | images_to_log = [ 559 | samples[:min_n_of_samples_per_class] for samples in images_to_log 560 | ] 561 | samples_grid = make_grid( 562 | th.cat(images_to_log, dim=0), 563 | nrow=min_n_of_samples_per_class, 564 | normalize=True, 565 | ) 566 | wandb.log( 567 | {f"generated_examples/{task_id}": wandb.Image(samples_grid)}, step=step 568 | ) 569 | -------------------------------------------------------------------------------- /dataloaders/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import kornia as K 4 | import torch 5 | import torchvision 6 | from torch.utils.data import DataLoader, Dataset 7 | from torchvision import transforms 8 | 9 | from .wrapper import CacheClassLabel 10 | 11 | 12 | class FastDataset(Dataset): 13 | def __init__(self, data, labels): 14 | self.dataset = data 15 | self.labels = labels 16 | 17 | def __len__(self): 18 | return len(self.dataset) 19 | 20 | def __getitem__(self, index): 21 | return self.dataset[index], self.labels[index] 22 | 23 | 24 | def CIFAR10( 25 | dataroot, skip_normalization=False, train_aug=False, classifier_augmentation=False 26 | ): 27 | train_transform_clf = None 28 | train_transform_diff = None 29 | # augmentation for diffusion training 30 | if train_aug: 31 | train_transform_diff = K.augmentation.ImageSequential( 32 | K.augmentation.RandomHorizontalFlip(), 33 | ) 34 | 35 | # augmentation for classifier training 36 | if classifier_augmentation: 37 | train_transform_clf = K.augmentation.ImageSequential( 38 | K.augmentation.Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 39 | K.augmentation.RandomCrop((32, 32), padding=4), 40 | 
K.augmentation.RandomRotation(30), 41 | K.augmentation.RandomHorizontalFlip(), 42 | K.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1), 43 | K.augmentation.RandomErasing(scale=(0.1, 0.5)), 44 | K.augmentation.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 45 | ) 46 | 47 | target_transform = transforms.Lambda(lambda y: torch.eye(10)[y]) 48 | 49 | train_dataset = torchvision.datasets.CIFAR10( 50 | root=dataroot, 51 | train=True, 52 | download=True, 53 | transform=transforms.Compose( 54 | [ 55 | transforms.ToTensor(), 56 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 57 | ] 58 | ), 59 | target_transform=target_transform, 60 | ) 61 | train_dataset = CacheClassLabel( 62 | train_dataset, 63 | target_transform=target_transform, 64 | ) 65 | 66 | val_dataset = torchvision.datasets.CIFAR10( 67 | root=dataroot, 68 | train=False, 69 | download=True, 70 | transform=transforms.Compose( 71 | [ 72 | transforms.ToTensor(), 73 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 74 | ] 75 | ), 76 | target_transform=target_transform, 77 | ) 78 | val_dataset = CacheClassLabel( 79 | val_dataset, 80 | target_transform=target_transform, 81 | ) 82 | print("Loading data") 83 | save_path = f"{dataroot}/fast_cifar10_train" 84 | if os.path.exists(save_path): 85 | fast_cifar_train = torch.load(save_path) 86 | else: 87 | train_loader = DataLoader(train_dataset, batch_size=len(train_dataset)) 88 | data = next(iter(train_loader)) 89 | fast_cifar_train = FastDataset(data[0], data[1]) 90 | torch.save(fast_cifar_train, save_path) 91 | 92 | save_path = f"{dataroot}/fast_cifar10_val" 93 | if os.path.exists(save_path): 94 | fast_cifar_val = torch.load(save_path) 95 | else: 96 | val_loader = DataLoader(val_dataset, batch_size=len(val_dataset)) 97 | data = next(iter(val_loader)) 98 | fast_cifar_val = FastDataset(data[0], data[1]) 99 | torch.save(fast_cifar_val, save_path) 100 | 101 | return ( 102 | fast_cifar_train, 103 | fast_cifar_val, 104 | 32, 105 | 3, 106 | train_transform_clf, 107 | train_transform_diff, 108 | 10, 109 | ) 110 | 111 | 112 | def CIFAR100( 113 | dataroot, skip_normalization=False, train_aug=False, classifier_augmentation=False 114 | ): 115 | train_transform_clf = None 116 | train_transform_diff = None 117 | # augmentation for diffusion training 118 | if train_aug: 119 | train_transform_diff = K.augmentation.ImageSequential( 120 | K.augmentation.RandomHorizontalFlip(), 121 | ) 122 | 123 | # augmentation for classifier training 124 | if classifier_augmentation: 125 | train_transform_clf = K.augmentation.ImageSequential( 126 | K.augmentation.Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 127 | K.augmentation.RandomCrop((32, 32), padding=4), 128 | K.augmentation.RandomRotation(30), 129 | K.augmentation.RandomHorizontalFlip(), 130 | K.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1), 131 | K.augmentation.RandomErasing(scale=(0.1, 0.5)), 132 | K.augmentation.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 133 | ) 134 | 135 | target_transform = transforms.Lambda(lambda y: torch.eye(100)[y]) 136 | 137 | train_dataset = torchvision.datasets.CIFAR100( 138 | root=dataroot, 139 | train=True, 140 | download=True, 141 | transform=transforms.Compose( 142 | [ 143 | transforms.ToTensor(), 144 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 145 | ] 146 | ), 147 | target_transform=target_transform, 148 | ) 149 | train_dataset = CacheClassLabel( 150 | train_dataset, 151 | target_transform=target_transform, 152 | ) 153 | 154 | val_dataset = torchvision.datasets.CIFAR100( 155 | root=dataroot, 156 | train=False, 157 | 
download=True, 158 | transform=transforms.Compose( 159 | [ 160 | transforms.ToTensor(), 161 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 162 | ] 163 | ), 164 | target_transform=target_transform, 165 | ) 166 | val_dataset = CacheClassLabel( 167 | val_dataset, 168 | target_transform=target_transform, 169 | ) 170 | 171 | print("Loading data") 172 | save_path = f"{dataroot}/fast_cifar100_train" 173 | if os.path.exists(save_path): 174 | fast_cifar_train = torch.load(save_path) 175 | else: 176 | train_loader = DataLoader(train_dataset, batch_size=len(train_dataset)) 177 | data = next(iter(train_loader)) 178 | fast_cifar_train = FastDataset(data[0], data[1]) 179 | torch.save(fast_cifar_train, save_path) 180 | 181 | save_path = f"{dataroot}/fast_cifar100_val" 182 | if os.path.exists(save_path): 183 | fast_cifar_val = torch.load(save_path) 184 | else: 185 | val_loader = DataLoader(val_dataset, batch_size=len(val_dataset)) 186 | data = next(iter(val_loader)) 187 | fast_cifar_val = FastDataset(data[0], data[1]) 188 | torch.save(fast_cifar_val, save_path) 189 | 190 | return ( 191 | fast_cifar_train, 192 | fast_cifar_val, 193 | 32, 194 | 3, 195 | train_transform_clf, 196 | train_transform_diff, 197 | 100, 198 | ) 199 | 200 | 201 | def ImageNet100( 202 | dataroot, skip_normalization=False, train_aug=False, classifier_augmentation=False 203 | ): 204 | train_transform_clf = None 205 | train_transform_diff = None 206 | # augmentation for diffusion training 207 | if train_aug: 208 | train_transform_diff = K.augmentation.ImageSequential( 209 | K.augmentation.RandomHorizontalFlip(), 210 | ) 211 | 212 | print("Loading data") 213 | save_path = f"{dataroot}/fast_imagenet100_train" 214 | if os.path.exists(save_path): 215 | fast_imagenet_train = torch.load(save_path) 216 | else: 217 | target_transform = transforms.Lambda(lambda y: torch.eye(100)[y]) 218 | 219 | train_dataset = torchvision.datasets.ImageFolder( 220 | root=os.path.join(dataroot, "imagenet100", "train"), 221 | transform=transforms.Compose( 222 | [ 223 | transforms.Resize((76, 76)), 224 | transforms.CenterCrop((64, 64)), 225 | transforms.ToTensor(), 226 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 227 | ] 228 | ), 229 | target_transform=target_transform, 230 | ) 231 | train_dataset = CacheClassLabel( 232 | train_dataset, 233 | target_transform=target_transform, 234 | ) 235 | 236 | val_dataset = torchvision.datasets.ImageFolder( 237 | root=os.path.join(dataroot, "imagenet100", "val"), 238 | transform=transforms.Compose( 239 | [ 240 | transforms.Resize((76, 76)), 241 | transforms.CenterCrop((64, 64)), 242 | transforms.ToTensor(), 243 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 244 | ] 245 | ), 246 | target_transform=target_transform, 247 | ) 248 | 249 | val_dataset = CacheClassLabel( 250 | val_dataset, 251 | target_transform=target_transform, 252 | ) 253 | 254 | train_loader = DataLoader(train_dataset, batch_size=len(train_dataset)) 255 | data = next(iter(train_loader)) 256 | fast_imagenet_train = FastDataset(data[0], data[1]) 257 | torch.save(fast_imagenet_train, save_path) 258 | 259 | save_path = f"{dataroot}/fast_imagenet100_val" 260 | if os.path.exists(save_path): 261 | fast_imagenet_val = torch.load(save_path) 262 | else: 263 | val_loader = DataLoader(val_dataset, batch_size=len(val_dataset)) 264 | data = next(iter(val_loader)) 265 | fast_imagenet_val = FastDataset(data[0], data[1]) 266 | torch.save(fast_imagenet_val, save_path) 267 | 268 | return ( 269 | fast_imagenet_train, 270 | fast_imagenet_val, 271 | 
64, 272 | 3, 273 | train_transform_clf, 274 | train_transform_diff, 275 | 100, 276 | ) 277 | 278 | 279 | def ImageNet100128( 280 | dataroot, skip_normalization=False, train_aug=False, classifier_augmentation=False 281 | ): 282 | train_transform_clf = None 283 | train_transform_diff = None 284 | # augmentation for diffusion training 285 | if train_aug: 286 | train_transform_diff = K.augmentation.ImageSequential( 287 | K.augmentation.RandomHorizontalFlip(), 288 | ) 289 | 290 | print("Loading data") 291 | save_path = f"{dataroot}/fast_imagenet100128_train" 292 | if os.path.exists(save_path): 293 | fast_imagenet_train = torch.load(save_path) 294 | else: 295 | target_transform = transforms.Lambda(lambda y: torch.eye(100)[y]) 296 | 297 | train_dataset = torchvision.datasets.ImageFolder( 298 | root=os.path.join(dataroot, "imagenet100", "train"), 299 | transform=transforms.Compose( 300 | [ 301 | transforms.Resize((152, 152)), 302 | transforms.CenterCrop((128, 128)), 303 | transforms.ToTensor(), 304 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 305 | ] 306 | ), 307 | target_transform=target_transform, 308 | ) 309 | train_dataset = CacheClassLabel( 310 | train_dataset, 311 | target_transform=target_transform, 312 | ) 313 | 314 | val_dataset = torchvision.datasets.ImageFolder( 315 | root=os.path.join(dataroot, "imagenet100", "val"), 316 | transform=transforms.Compose( 317 | [ 318 | transforms.Resize((152, 152)), 319 | transforms.CenterCrop((128, 128)), 320 | transforms.ToTensor(), 321 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 322 | ] 323 | ), 324 | target_transform=target_transform, 325 | ) 326 | 327 | val_dataset = CacheClassLabel( 328 | val_dataset, 329 | target_transform=target_transform, 330 | ) 331 | 332 | train_loader = DataLoader(train_dataset, batch_size=len(train_dataset)) 333 | data = next(iter(train_loader)) 334 | fast_imagenet_train = FastDataset(data[0], data[1]) 335 | torch.save(fast_imagenet_train, save_path) 336 | 337 | save_path = f"{dataroot}/fast_imagenet100128_val" 338 | if os.path.exists(save_path): 339 | fast_imagenet_val = torch.load(save_path) 340 | else: 341 | val_loader = DataLoader(val_dataset, batch_size=len(val_dataset)) 342 | data = next(iter(val_loader)) 343 | fast_imagenet_val = FastDataset(data[0], data[1]) 344 | torch.save(fast_imagenet_val, save_path) 345 | 346 | return ( 347 | fast_imagenet_train, 348 | fast_imagenet_val, 349 | 128, 350 | 3, 351 | train_transform_clf, 352 | train_transform_diff, 353 | 100, 354 | ) 355 | 356 | def ImageNet100224( 357 | dataroot, skip_normalization=False, train_aug=False, classifier_augmentation=False 358 | ): 359 | train_transform_clf = None 360 | train_transform_diff = None 361 | # augmentation for diffusion training 362 | if train_aug: 363 | train_transform_diff = K.augmentation.ImageSequential( 364 | K.augmentation.RandomHorizontalFlip(), 365 | ) 366 | 367 | print("Loading data") 368 | save_path = f"{dataroot}/fast_imagenet100224_train" 369 | if os.path.exists(save_path): 370 | fast_imagenet_train = torch.load(save_path) 371 | else: 372 | target_transform = transforms.Lambda(lambda y: torch.eye(100)[y]) 373 | 374 | train_dataset = torchvision.datasets.ImageFolder( 375 | root=os.path.join(dataroot, "imagenet100", "train"), 376 | transform=transforms.Compose( 377 | [ 378 | transforms.Resize((266, 266)), 379 | transforms.CenterCrop((224, 224)), 380 | transforms.ToTensor(), 381 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 382 | ] 383 | ), 384 | target_transform=target_transform, 385 | ) 386 | 
train_dataset = CacheClassLabel( 387 | train_dataset, 388 | target_transform=target_transform, 389 | ) 390 | 391 | val_dataset = torchvision.datasets.ImageFolder( 392 | root=os.path.join(dataroot, "imagenet100", "val"), 393 | transform=transforms.Compose( 394 | [ 395 | transforms.Resize((266, 266)), 396 | transforms.CenterCrop((224, 224)), 397 | transforms.ToTensor(), 398 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 399 | ] 400 | ), 401 | target_transform=target_transform, 402 | ) 403 | 404 | val_dataset = CacheClassLabel( 405 | val_dataset, 406 | target_transform=target_transform, 407 | ) 408 | 409 | train_loader = DataLoader(train_dataset, batch_size=len(train_dataset)) 410 | data = next(iter(train_loader)) 411 | fast_imagenet_train = FastDataset(data[0], data[1]) 412 | torch.save(fast_imagenet_train, save_path) 413 | 414 | save_path = f"{dataroot}/fast_imagenet100224_val" 415 | if os.path.exists(save_path): 416 | fast_imagenet_val = torch.load(save_path) 417 | else: 418 | val_loader = DataLoader(val_dataset, batch_size=len(val_dataset)) 419 | data = next(iter(val_loader)) 420 | fast_imagenet_val = FastDataset(data[0], data[1]) 421 | torch.save(fast_imagenet_val, save_path) 422 | 423 | return ( 424 | fast_imagenet_train, 425 | fast_imagenet_val, 426 | 224, 427 | 3, 428 | train_transform_clf, 429 | train_transform_diff, 430 | 100, 431 | ) 432 | 433 | def Flowers102( 434 | dataroot, skip_normalization=False, train_aug=False, classifier_augmentation=False 435 | ): 436 | train_transform_clf = None 437 | train_transform_diff = None 438 | # augmentation for diffusion training 439 | if train_aug: 440 | train_transform_diff = K.augmentation.ImageSequential( 441 | K.augmentation.RandomHorizontalFlip(), 442 | ) 443 | 444 | # augmentation for classifier training 445 | if classifier_augmentation: 446 | train_transform_clf = K.augmentation.ImageSequential( 447 | K.augmentation.Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 448 | K.augmentation.RandomRotation(30), 449 | K.augmentation.RandomHorizontalFlip(), 450 | K.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1), 451 | K.augmentation.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 452 | ) 453 | 454 | target_transform = transforms.Lambda(lambda y: torch.eye(102)[y]) 455 | 456 | train_dataset1 = torchvision.datasets.Flowers102( 457 | root=dataroot, 458 | split="train", 459 | download=True, 460 | transform=transforms.Compose( 461 | [ 462 | transforms.Resize((224, 224)), 463 | transforms.ToTensor(), 464 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 465 | ] 466 | ), 467 | target_transform=target_transform, 468 | ) 469 | train_dataset2 = torchvision.datasets.Flowers102( 470 | root=dataroot, 471 | split="val", 472 | download=True, 473 | transform=transforms.Compose( 474 | [ 475 | transforms.Resize((224, 224)), 476 | transforms.ToTensor(), 477 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 478 | ] 479 | ), 480 | target_transform=target_transform, 481 | ) 482 | train_dataset = torch.utils.data.ConcatDataset([train_dataset1, train_dataset2]) 483 | train_dataset.root = dataroot 484 | train_dataset = CacheClassLabel( 485 | train_dataset, 486 | target_transform=target_transform, 487 | ) 488 | val_dataset = torchvision.datasets.Flowers102( 489 | root=dataroot, 490 | split="test", 491 | download=True, 492 | transform=transforms.Compose( 493 | [ 494 | transforms.Resize((224, 224)), 495 | transforms.ToTensor(), 496 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 497 | ] 498 | ), 499 | target_transform=target_transform, 500 | ) 
501 | val_dataset = CacheClassLabel( 502 | val_dataset, 503 | target_transform=target_transform, 504 | ) 505 | 506 | print("Loading data") 507 | save_path = f"{dataroot}/fast_flowers_train" 508 | if os.path.exists(save_path): 509 | fast_flowers_train = torch.load(save_path) 510 | else: 511 | train_loader = DataLoader(train_dataset, batch_size=len(train_dataset)) 512 | data = next(iter(train_loader)) 513 | fast_flowers_train = FastDataset(data[0], data[1]) 514 | torch.save(fast_flowers_train, save_path) 515 | 516 | save_path = f"{dataroot}/fast_flowers_val" 517 | if os.path.exists(save_path): 518 | fast_flowers_val = torch.load(save_path) 519 | else: 520 | val_loader = DataLoader(val_dataset, batch_size=len(val_dataset)) 521 | data = next(iter(val_loader)) 522 | fast_flowers_val = FastDataset(data[0], data[1]) 523 | torch.save(fast_flowers_val, save_path) 524 | 525 | return ( 526 | fast_flowers_train, 527 | fast_flowers_val, 528 | 224, 529 | 3, 530 | train_transform_clf, 531 | train_transform_diff, 532 | 102, 533 | ) 534 | -------------------------------------------------------------------------------- /guide/evaluator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import io 3 | import os 4 | import random 5 | import warnings 6 | import zipfile 7 | from abc import ABC, abstractmethod 8 | from contextlib import contextmanager 9 | from functools import partial 10 | from multiprocessing import cpu_count 11 | from multiprocessing.pool import ThreadPool 12 | from typing import Iterable, Optional, Tuple 13 | 14 | import numpy as np 15 | import requests 16 | import tensorflow.compat.v1 as tf 17 | from scipy import linalg 18 | from tqdm.auto import tqdm 19 | 20 | INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb" 21 | INCEPTION_V3_PATH = "classify_image_graph_def.pb" 22 | 23 | FID_POOL_NAME = "pool_3:0" 24 | FID_SPATIAL_NAME = "mixed_6/conv:0" 25 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument( 31 | "--ref_batch", help="path to reference batch npz acts file", required=False 32 | ) 33 | parser.add_argument( 34 | "--ref_acts", help="path to reference batch npz acts file", required=False 35 | ) 36 | parser.add_argument("--sample_batch", help="path to sample batch npz file") 37 | args = parser.parse_args() 38 | 39 | config = tf.ConfigProto( 40 | allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph 41 | ) 42 | config.gpu_options.allow_growth = True 43 | evaluator = Evaluator(tf.Session(config=config)) 44 | 45 | print("warming up TensorFlow...") 46 | # This will cause TF to print a bunch of verbose stuff now rather 47 | # than after the next print(), to help prevent confusion. 
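    # warmup() runs a single all-zeros batch through the Inception feature
    # graph, so the one-time TensorFlow initialization output shows up here
    # instead of interleaved with the results printed below.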
48 | evaluator.warmup() 49 | 50 | print("computing reference batch activations...") 51 | if args.ref_acts is None: 52 | ref_acts = evaluator.read_activations(args.ref_batch) 53 | # ref_acts = evaluator.read_activations(args.ref_batch) 54 | else: 55 | ref_acts = ( 56 | np.load(args.ref_acts)["arr_0"], 57 | np.load(args.ref_acts)["arr_1"], 58 | ) 59 | 60 | print("computing/reading reference batch statistics...") 61 | ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts) 62 | 63 | print("computing sample batch activations...") 64 | sample_acts = evaluator.read_activations(args.sample_batch) 65 | print("computing/reading sample batch statistics...") 66 | sample_stats, sample_stats_spatial = evaluator.read_statistics( 67 | args.sample_batch, sample_acts 68 | ) 69 | 70 | print("Computing evaluations...") 71 | print("Inception Score:", evaluator.compute_inception_score(sample_acts[0])) 72 | print("FID:", sample_stats.frechet_distance(ref_stats)) 73 | print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial)) 74 | prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0]) 75 | print("Precision:", prec) 76 | print("Recall:", recall) 77 | 78 | 79 | class InvalidFIDException(Exception): 80 | pass 81 | 82 | 83 | class FIDStatistics: 84 | def __init__(self, mu: np.ndarray, sigma: np.ndarray): 85 | self.mu = mu 86 | self.sigma = sigma 87 | 88 | def frechet_distance(self, other, eps=1e-6): 89 | """ 90 | Compute the Frechet distance between two sets of statistics. 91 | """ 92 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132 93 | mu1, sigma1 = self.mu, self.sigma 94 | mu2, sigma2 = other.mu, other.sigma 95 | 96 | mu1 = np.atleast_1d(mu1) 97 | mu2 = np.atleast_1d(mu2) 98 | 99 | sigma1 = np.atleast_2d(sigma1) 100 | sigma2 = np.atleast_2d(sigma2) 101 | 102 | assert ( 103 | mu1.shape == mu2.shape 104 | ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}" 105 | assert ( 106 | sigma1.shape == sigma2.shape 107 | ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}" 108 | 109 | diff = mu1 - mu2 110 | 111 | # product might be almost singular 112 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 113 | if not np.isfinite(covmean).all(): 114 | msg = ( 115 | "fid calculation produces singular product; adding %s to diagonal of cov estimates" 116 | % eps 117 | ) 118 | warnings.warn(msg) 119 | offset = np.eye(sigma1.shape[0]) * eps 120 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 121 | 122 | # numerical error might give slight imaginary component 123 | if np.iscomplexobj(covmean): 124 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 125 | m = np.max(np.abs(covmean.imag)) 126 | raise ValueError("Imaginary component {}".format(m)) 127 | covmean = covmean.real 128 | 129 | tr_covmean = np.trace(covmean) 130 | 131 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 132 | 133 | 134 | class Evaluator: 135 | def __init__( 136 | self, 137 | session, 138 | batch_size=64, 139 | softmax_batch_size=512, 140 | ): 141 | self.sess = session 142 | self.batch_size = batch_size 143 | self.softmax_batch_size = softmax_batch_size 144 | self.manifold_estimator = ManifoldEstimator(session) 145 | with self.sess.graph.as_default(): 146 | self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3]) 147 | self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048]) 148 | 
self.pool_features, self.spatial_features = _create_feature_graph( 149 | self.image_input 150 | ) 151 | self.softmax = _create_softmax_graph(self.softmax_input) 152 | 153 | def warmup(self): 154 | self.compute_activations(np.zeros([1, 8, 64, 64, 3])) 155 | 156 | def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]: 157 | with open_npz_array(npz_path, "arr_0") as reader: 158 | return self.compute_activations(reader.read_batches(self.batch_size)) 159 | 160 | def compute_activations( 161 | self, batches: Iterable[np.ndarray] 162 | ) -> Tuple[np.ndarray, np.ndarray]: 163 | """ 164 | Compute image features for downstream evals. 165 | 166 | :param batches: a iterator over NHWC numpy arrays in [0, 255]. 167 | :return: a tuple of numpy arrays of shape [N x X], where X is a feature 168 | dimension. The tuple is (pool_3, spatial). 169 | """ 170 | preds = [] 171 | spatial_preds = [] 172 | for batch in tqdm(batches): 173 | batch = batch.astype(np.float32) 174 | pred, spatial_pred = self.sess.run( 175 | [self.pool_features, self.spatial_features], {self.image_input: batch} 176 | ) 177 | preds.append(pred.reshape([pred.shape[0], -1])) 178 | spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1])) 179 | return ( 180 | np.concatenate(preds, axis=0), 181 | np.concatenate(spatial_preds, axis=0), 182 | ) 183 | 184 | def read_statistics( 185 | self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray] 186 | ) -> Tuple[FIDStatistics, FIDStatistics]: 187 | obj = np.load(npz_path) 188 | if "mu" in list(obj.keys()): 189 | return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics( 190 | obj["mu_s"], obj["sigma_s"] 191 | ) 192 | return tuple(self.compute_statistics(x) for x in activations) 193 | 194 | def compute_statistics(self, activations: np.ndarray) -> FIDStatistics: 195 | mu = np.mean(activations, axis=0) 196 | sigma = np.cov(activations, rowvar=False) 197 | return FIDStatistics(mu, sigma) 198 | 199 | def compute_inception_score( 200 | self, activations: np.ndarray, split_size: int = 5000 201 | ) -> float: 202 | softmax_out = [] 203 | for i in range(0, len(activations), self.softmax_batch_size): 204 | acts = activations[i : i + self.softmax_batch_size] 205 | softmax_out.append( 206 | self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}) 207 | ) 208 | preds = np.concatenate(softmax_out, axis=0) 209 | # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46 210 | scores = [] 211 | for i in range(0, len(preds), split_size): 212 | part = preds[i : i + split_size] 213 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 214 | kl = np.mean(np.sum(kl, 1)) 215 | scores.append(np.exp(kl)) 216 | return float(np.mean(scores)) 217 | 218 | def compute_prec_recall( 219 | self, activations_ref: np.ndarray, activations_sample: np.ndarray 220 | ) -> Tuple[float, float]: 221 | radii_1 = self.manifold_estimator.manifold_radii(activations_ref) 222 | radii_2 = self.manifold_estimator.manifold_radii(activations_sample) 223 | pr = self.manifold_estimator.evaluate_pr( 224 | activations_ref, radii_1, activations_sample, radii_2 225 | ) 226 | return (float(pr[0][0]), float(pr[1][0])) 227 | 228 | 229 | class ManifoldEstimator: 230 | """ 231 | A helper for comparing manifolds of feature vectors. 
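    Each reference feature is wrapped in a hypersphere whose radius is the
    distance to its k-th nearest neighbour (k given by `nhood_sizes`); a query
    sample counts as lying on the manifold when it falls inside at least one
    such hypersphere, and precision/recall are the covered fractions computed
    in each direction by `evaluate_pr`.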
232 | 233 | Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57 234 | """ 235 | 236 | def __init__( 237 | self, 238 | session, 239 | row_batch_size=10000, 240 | col_batch_size=10000, 241 | nhood_sizes=(3,), 242 | clamp_to_percentile=None, 243 | eps=1e-5, 244 | ): 245 | """ 246 | Estimate the manifold of given feature vectors. 247 | 248 | :param session: the TensorFlow session. 249 | :param row_batch_size: row batch size to compute pairwise distances 250 | (parameter to trade-off between memory usage and performance). 251 | :param col_batch_size: column batch size to compute pairwise distances. 252 | :param nhood_sizes: number of neighbors used to estimate the manifold. 253 | :param clamp_to_percentile: prune hyperspheres that have radius larger than 254 | the given percentile. 255 | :param eps: small number for numerical stability. 256 | """ 257 | self.distance_block = DistanceBlock(session) 258 | self.row_batch_size = row_batch_size 259 | self.col_batch_size = col_batch_size 260 | self.nhood_sizes = nhood_sizes 261 | self.num_nhoods = len(nhood_sizes) 262 | self.clamp_to_percentile = clamp_to_percentile 263 | self.eps = eps 264 | 265 | def warmup(self): 266 | feats, radii = ( 267 | np.zeros([1, 2048], dtype=np.float32), 268 | np.zeros([1, 1], dtype=np.float32), 269 | ) 270 | self.evaluate_pr(feats, radii, feats, radii) 271 | 272 | def manifold_radii(self, features: np.ndarray) -> np.ndarray: 273 | num_images = len(features) 274 | 275 | # Estimate manifold of features by calculating distances to k-NN of each sample. 276 | radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32) 277 | distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32) 278 | seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32) 279 | 280 | for begin1 in range(0, num_images, self.row_batch_size): 281 | end1 = min(begin1 + self.row_batch_size, num_images) 282 | row_batch = features[begin1:end1] 283 | 284 | for begin2 in range(0, num_images, self.col_batch_size): 285 | end2 = min(begin2 + self.col_batch_size, num_images) 286 | col_batch = features[begin2:end2] 287 | 288 | # Compute distances between batches. 289 | distance_batch[0 : end1 - begin1, begin2:end2] = ( 290 | self.distance_block.pairwise_distances(row_batch, col_batch) 291 | ) 292 | 293 | # Find the k-nearest neighbor from the current batch. 294 | radii[begin1:end1, :] = np.concatenate( 295 | [ 296 | x[:, self.nhood_sizes] 297 | for x in _numpy_partition( 298 | distance_batch[0 : end1 - begin1, :], seq, axis=1 299 | ) 300 | ], 301 | axis=0, 302 | ) 303 | 304 | if self.clamp_to_percentile is not None: 305 | max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0) 306 | radii[radii > max_distances] = 0 307 | return radii 308 | 309 | def evaluate( 310 | self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray 311 | ): 312 | """ 313 | Evaluate if new feature vectors are at the manifold. 
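        For every evaluated vector this returns whether it falls inside at
        least one reference hypersphere, together with a realism score and the
        index of its nearest reference sample.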
314 | """ 315 | num_eval_images = eval_features.shape[0] 316 | num_ref_images = radii.shape[0] 317 | distance_batch = np.zeros( 318 | [self.row_batch_size, num_ref_images], dtype=np.float32 319 | ) 320 | batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32) 321 | max_realism_score = np.zeros([num_eval_images], dtype=np.float32) 322 | nearest_indices = np.zeros([num_eval_images], dtype=np.int32) 323 | 324 | for begin1 in range(0, num_eval_images, self.row_batch_size): 325 | end1 = min(begin1 + self.row_batch_size, num_eval_images) 326 | feature_batch = eval_features[begin1:end1] 327 | 328 | for begin2 in range(0, num_ref_images, self.col_batch_size): 329 | end2 = min(begin2 + self.col_batch_size, num_ref_images) 330 | ref_batch = features[begin2:end2] 331 | 332 | distance_batch[0 : end1 - begin1, begin2:end2] = ( 333 | self.distance_block.pairwise_distances(feature_batch, ref_batch) 334 | ) 335 | 336 | # From the minibatch of new feature vectors, determine if they are in the estimated manifold. 337 | # If a feature vector is inside a hypersphere of some reference sample, then 338 | # the new sample lies at the estimated manifold. 339 | # The radii of the hyperspheres are determined from distances of neighborhood size k. 340 | samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii 341 | batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype( 342 | np.int32 343 | ) 344 | 345 | max_realism_score[begin1:end1] = np.max( 346 | radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1 347 | ) 348 | nearest_indices[begin1:end1] = np.argmin( 349 | distance_batch[0 : end1 - begin1, :], axis=1 350 | ) 351 | 352 | return { 353 | "fraction": float(np.mean(batch_predictions)), 354 | "batch_predictions": batch_predictions, 355 | "max_realisim_score": max_realism_score, 356 | "nearest_indices": nearest_indices, 357 | } 358 | 359 | def evaluate_pr( 360 | self, 361 | features_1: np.ndarray, 362 | radii_1: np.ndarray, 363 | features_2: np.ndarray, 364 | radii_2: np.ndarray, 365 | ) -> Tuple[np.ndarray, np.ndarray]: 366 | """ 367 | Evaluate precision and recall efficiently. 368 | 369 | :param features_1: [N1 x D] feature vectors for reference batch. 370 | :param radii_1: [N1 x K1] radii for reference vectors. 371 | :param features_2: [N2 x D] feature vectors for the other batch. 372 | :param radii_2: [N x K2] radii for other vectors. 
373 | :return: a tuple of arrays for (precision, recall): 374 | - precision: an np.ndarray of length K1 375 | - recall: an np.ndarray of length K2 376 | """ 377 | features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool) 378 | features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool) 379 | for begin_1 in range(0, len(features_1), self.row_batch_size): 380 | end_1 = begin_1 + self.row_batch_size 381 | batch_1 = features_1[begin_1:end_1] 382 | for begin_2 in range(0, len(features_2), self.col_batch_size): 383 | end_2 = begin_2 + self.col_batch_size 384 | batch_2 = features_2[begin_2:end_2] 385 | batch_1_in, batch_2_in = self.distance_block.less_thans( 386 | batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2] 387 | ) 388 | features_1_status[begin_1:end_1] |= batch_1_in 389 | features_2_status[begin_2:end_2] |= batch_2_in 390 | return ( 391 | np.mean(features_2_status.astype(np.float64), axis=0), 392 | np.mean(features_1_status.astype(np.float64), axis=0), 393 | ) 394 | 395 | 396 | class DistanceBlock: 397 | """ 398 | Calculate pairwise distances between vectors. 399 | 400 | Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34 401 | """ 402 | 403 | def __init__(self, session): 404 | self.session = session 405 | 406 | # Initialize TF graph to calculate pairwise distances. 407 | with session.graph.as_default(): 408 | self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None]) 409 | self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None]) 410 | distance_block_16 = _batch_pairwise_distances( 411 | tf.cast(self._features_batch1, tf.float16), 412 | tf.cast(self._features_batch2, tf.float16), 413 | ) 414 | self.distance_block = tf.cond( 415 | tf.reduce_all(tf.math.is_finite(distance_block_16)), 416 | lambda: tf.cast(distance_block_16, tf.float32), 417 | lambda: _batch_pairwise_distances( 418 | self._features_batch1, self._features_batch2 419 | ), 420 | ) 421 | 422 | # Extra logic for less thans. 423 | self._radii1 = tf.placeholder(tf.float32, shape=[None, None]) 424 | self._radii2 = tf.placeholder(tf.float32, shape=[None, None]) 425 | dist32 = tf.cast(self.distance_block, tf.float32)[..., None] 426 | self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1) 427 | self._batch_2_in = tf.math.reduce_any( 428 | dist32 <= self._radii1[:, None], axis=0 429 | ) 430 | 431 | def pairwise_distances(self, U, V): 432 | """ 433 | Evaluate pairwise distances between two batches of feature vectors. 434 | """ 435 | return self.session.run( 436 | self.distance_block, 437 | feed_dict={self._features_batch1: U, self._features_batch2: V}, 438 | ) 439 | 440 | def less_thans(self, batch_1, radii_1, batch_2, radii_2): 441 | return self.session.run( 442 | [self._batch_1_in, self._batch_2_in], 443 | feed_dict={ 444 | self._features_batch1: batch_1, 445 | self._features_batch2: batch_2, 446 | self._radii1: radii_1, 447 | self._radii2: radii_2, 448 | }, 449 | ) 450 | 451 | 452 | def _batch_pairwise_distances(U, V): 453 | """ 454 | Compute pairwise distances between two batches of feature vectors. 455 | """ 456 | with tf.variable_scope("pairwise_dist_block"): 457 | # Squared norms of each row in U and V. 458 | norm_u = tf.reduce_sum(tf.square(U), 1) 459 | norm_v = tf.reduce_sum(tf.square(V), 1) 460 | 461 | # norm_u as a column and norm_v as a row vectors. 
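        # The lines below use the identity ||u - v||^2 = ||u||^2 - 2*u.v + ||v||^2
        # so that all pairwise squared distances come out of a single matmul;
        # an equivalent NumPy sketch (names purely illustrative) would be
        #   D = (U**2).sum(1)[:, None] - 2.0 * U @ V.T + (V**2).sum(1)[None, :]
        # and the tf.maximum(..., 0.0) clamp guards against small negative
        # values caused by floating-point round-off.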
462 | norm_u = tf.reshape(norm_u, [-1, 1]) 463 | norm_v = tf.reshape(norm_v, [1, -1]) 464 | 465 | # Pairwise squared Euclidean distances. 466 | D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0) 467 | 468 | return D 469 | 470 | 471 | class NpzArrayReader(ABC): 472 | @abstractmethod 473 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 474 | pass 475 | 476 | @abstractmethod 477 | def remaining(self) -> int: 478 | pass 479 | 480 | def read_batches(self, batch_size: int) -> Iterable[np.ndarray]: 481 | def gen_fn(): 482 | while True: 483 | batch = self.read_batch(batch_size) 484 | if batch is None: 485 | break 486 | yield batch 487 | 488 | rem = self.remaining() 489 | num_batches = rem // batch_size + int(rem % batch_size != 0) 490 | return BatchIterator(gen_fn, num_batches) 491 | 492 | 493 | class BatchIterator: 494 | def __init__(self, gen_fn, length): 495 | self.gen_fn = gen_fn 496 | self.length = length 497 | 498 | def __len__(self): 499 | return self.length 500 | 501 | def __iter__(self): 502 | return self.gen_fn() 503 | 504 | 505 | class StreamingNpzArrayReader(NpzArrayReader): 506 | def __init__(self, arr_f, shape, dtype): 507 | self.arr_f = arr_f 508 | self.shape = shape 509 | self.dtype = dtype 510 | self.idx = 0 511 | 512 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 513 | if self.idx >= self.shape[0]: 514 | return None 515 | 516 | bs = min(batch_size, self.shape[0] - self.idx) 517 | self.idx += bs 518 | 519 | if self.dtype.itemsize == 0: 520 | return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype) 521 | 522 | read_count = bs * np.prod(self.shape[1:]) 523 | read_size = int(read_count * self.dtype.itemsize) 524 | data = _read_bytes(self.arr_f, read_size, "array data") 525 | return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]]) 526 | 527 | def remaining(self) -> int: 528 | return max(0, self.shape[0] - self.idx) 529 | 530 | 531 | class MemoryNpzArrayReader(NpzArrayReader): 532 | def __init__(self, arr): 533 | self.arr = arr 534 | self.idx = 0 535 | 536 | @classmethod 537 | def load(cls, path: str, arr_name: str): 538 | with open(path, "rb") as f: 539 | arr = np.load(f)[arr_name] 540 | return cls(arr) 541 | 542 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 543 | if self.idx >= self.arr.shape[0]: 544 | return None 545 | 546 | res = self.arr[self.idx : self.idx + batch_size] 547 | self.idx += batch_size 548 | return res 549 | 550 | def remaining(self) -> int: 551 | return max(0, self.arr.shape[0] - self.idx) 552 | 553 | 554 | @contextmanager 555 | def open_npz_array(path: str, arr_name: str) -> NpzArrayReader: 556 | with _open_npy_file(path, arr_name) as arr_f: 557 | version = np.lib.format.read_magic(arr_f) 558 | if version == (1, 0): 559 | header = np.lib.format.read_array_header_1_0(arr_f) 560 | elif version == (2, 0): 561 | header = np.lib.format.read_array_header_2_0(arr_f) 562 | else: 563 | yield MemoryNpzArrayReader.load(path, arr_name) 564 | return 565 | shape, fortran, dtype = header 566 | if fortran or dtype.hasobject: 567 | yield MemoryNpzArrayReader.load(path, arr_name) 568 | else: 569 | yield StreamingNpzArrayReader(arr_f, shape, dtype) 570 | 571 | 572 | def _read_bytes(fp, size, error_template="ran out of data"): 573 | """ 574 | Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886 575 | 576 | Read from file-like object until size bytes are read. 
577 | Raises ValueError if not EOF is encountered before size bytes are read. 578 | Non-blocking objects only supported if they derive from io objects. 579 | Required as e.g. ZipExtFile in python 2.6 can return less data than 580 | requested. 581 | """ 582 | data = bytes() 583 | while True: 584 | # io files (default in python3) return None or raise on 585 | # would-block, python2 file will truncate, probably nothing can be 586 | # done about that. note that regular files can't be non-blocking 587 | try: 588 | r = fp.read(size - len(data)) 589 | data += r 590 | if len(r) == 0 or len(data) == size: 591 | break 592 | except io.BlockingIOError: 593 | pass 594 | if len(data) != size: 595 | msg = "EOF: reading %s, expected %d bytes got %d" 596 | raise ValueError(msg % (error_template, size, len(data))) 597 | else: 598 | return data 599 | 600 | 601 | @contextmanager 602 | def _open_npy_file(path: str, arr_name: str): 603 | with open(path, "rb") as f: 604 | with zipfile.ZipFile(f, "r") as zip_f: 605 | if f"{arr_name}.npy" not in zip_f.namelist(): 606 | raise ValueError(f"missing {arr_name} in npz file") 607 | with zip_f.open(f"{arr_name}.npy", "r") as arr_f: 608 | yield arr_f 609 | 610 | 611 | def _download_inception_model(): 612 | if os.path.exists(INCEPTION_V3_PATH): 613 | return 614 | print("downloading InceptionV3 model...") 615 | with requests.get(INCEPTION_V3_URL, stream=True) as r: 616 | r.raise_for_status() 617 | tmp_path = INCEPTION_V3_PATH + ".tmp" 618 | with open(tmp_path, "wb") as f: 619 | for chunk in tqdm(r.iter_content(chunk_size=8192)): 620 | f.write(chunk) 621 | os.rename(tmp_path, INCEPTION_V3_PATH) 622 | 623 | 624 | def _create_feature_graph(input_batch): 625 | _download_inception_model() 626 | prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" 627 | with open(INCEPTION_V3_PATH, "rb") as f: 628 | graph_def = tf.GraphDef() 629 | graph_def.ParseFromString(f.read()) 630 | pool3, spatial = tf.import_graph_def( 631 | graph_def, 632 | input_map={f"ExpandDims:0": input_batch}, 633 | return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME], 634 | name=prefix, 635 | ) 636 | _update_shapes(pool3) 637 | spatial = spatial[..., :7] 638 | return pool3, spatial 639 | 640 | 641 | def _create_softmax_graph(input_batch): 642 | _download_inception_model() 643 | prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" 644 | with open(INCEPTION_V3_PATH, "rb") as f: 645 | graph_def = tf.GraphDef() 646 | graph_def.ParseFromString(f.read()) 647 | (matmul,) = tf.import_graph_def( 648 | graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix 649 | ) 650 | w = matmul.inputs[1] 651 | logits = tf.matmul(input_batch, w) 652 | return tf.nn.softmax(logits) 653 | 654 | 655 | def _update_shapes(pool3): 656 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63 657 | ops = pool3.graph.get_operations() 658 | for op in ops: 659 | for o in op.outputs: 660 | shape = o.get_shape() 661 | if shape._dims is not None: # pylint: disable=protected-access 662 | # shape = [s.value for s in shape] TF 1.x 663 | shape = [s for s in shape] # TF 2.x 664 | new_shape = [] 665 | for j, s in enumerate(shape): 666 | if s == 1 and j == 0: 667 | new_shape.append(None) 668 | else: 669 | new_shape.append(s) 670 | o.__dict__["_shape_val"] = tf.TensorShape(new_shape) 671 | return pool3 672 | 673 | 674 | def _numpy_partition(arr, kth, **kwargs): 675 | num_workers = min(cpu_count(), len(arr)) 676 | chunk_size = len(arr) // num_workers 677 | extra = len(arr) % num_workers 678 
| 679 | start_idx = 0 680 | batches = [] 681 | for i in range(num_workers): 682 | size = chunk_size + (1 if i < extra else 0) 683 | batches.append(arr[start_idx : start_idx + size]) 684 | start_idx += size 685 | 686 | with ThreadPool(num_workers) as pool: 687 | return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches)) 688 | 689 | 690 | if __name__ == "__main__": 691 | main() 692 | --------------------------------------------------------------------------------
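A minimal usage sketch for guide/evaluator.py above (file names are illustrative, not part of the repository): `python guide/evaluator.py --ref_batch ref_batch.npz --sample_batch samples.npz` reads both .npz batches, computes Inception-V3 activations, and prints Inception Score, FID, sFID, precision, and recall; supplying `--ref_acts` additionally lets the script load precomputed reference activations (arrays "arr_0" and "arr_1") instead of recomputing them from the reference batch.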