├── dataloaders
│   ├── __init__.py
│   ├── utils.py
│   ├── wrapper.py
│   ├── datasetGen.py
│   └── base.py
├── imgs
│   └── teaser_imagenet.png
├── guide
│   ├── __init__.py
│   ├── fake_mpi.py
│   ├── validation.py
│   ├── losses.py
│   ├── layers.py
│   ├── dist_util.py
│   ├── respace.py
│   ├── nn.py
│   ├── resample.py
│   ├── script_args.py
│   ├── resnet.py
│   ├── fp16_util.py
│   ├── script_util.py
│   ├── logger.py
│   └── evaluator.py
├── .gitignore
├── setup.py
├── scripts
│   ├── mrunner_train.py
│   ├── image_sample.py
│   ├── classifier_sample_universal.py
│   └── image_train.py
├── cl_methods
│   ├── utils.py
│   ├── base.py
│   ├── generative_replay_disjoint_classifier_guidance.py
│   └── generative_replay.py
└── README.md
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imgs/teaser_imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cywinski/guide/HEAD/imgs/teaser_imagenet.png
--------------------------------------------------------------------------------
/guide/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Codebase for "Improved Denoising Diffusion Probabilistic Models".
3 | """
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | guide.egg-info/
2 | .DS_Store
3 | __pycache__/
4 | classify_image_graph_def.pb
5 | .idea
6 | data
7 | results
8 | wandb
9 | .vscode
10 | venv
11 | *.out
12 | slurm_out
13 | *.pt
14 | *.csv
15 | .ipynb_checkpoints
16 | *.ipynb
17 | *.gif
18 | *.npz
19 | results*/
20 | *.txt
21 | *.pdf
22 |
--------------------------------------------------------------------------------
/guide/fake_mpi.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace
2 |
3 |
4 | class CommWorld:
5 | def __init__(self):
6 | self.rank = 0
7 | self.size = 1
8 |
9 | def Get_rank(self):
10 | return 0
11 |
12 | def Get_size(self):
13 | return 1
14 |
15 | def bcast(self, value, root=0):
16 | return value
17 |
18 |
19 | MPI = Namespace(COMM_WORLD=CommWorld())
20 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name="guide",
5 | py_modules=["guide"],
6 | install_requires=[
7 | "blobfile>=1.0.5",
8 | "tqdm==4.66.2",
9 | "scikit-learn==1.3.0",
10 | "torchmetrics==1.3.2",
11 | "kornia==0.7.1",
12 | "wandb==0.16.5",
13 | "matplotlib==3.7.5",
14 | "pytz==2024.1",
15 | ],
16 | )
17 |
--------------------------------------------------------------------------------
/scripts/mrunner_train.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace
2 |
3 | from mrunner.helpers.client_helper import get_configuration
4 |
5 | from scripts.image_train import run_training_with_args
6 |
7 |
8 | def main():
9 | params = get_configuration(print_diagnostics=True, with_neptune=False)
10 | run_training_with_args(Namespace(**params))
11 |
12 |
13 | if __name__ == "__main__":
14 | main()
15 |
--------------------------------------------------------------------------------
/cl_methods/utils.py:
--------------------------------------------------------------------------------
1 | from cl_methods.generative_replay import GenerativeReplay
2 | from cl_methods.generative_replay_disjoint_classifier_guidance import (
3 | GenerativeReplayDisjointClassifierGuidance,
4 | )
5 |
6 |
7 | def get_cl_method(args):
8 | if args.cl_method == "generative_replay":
9 | return GenerativeReplay(args)
10 | if args.cl_method == "generative_replay_disjoint_classifier_guidance":
11 | return GenerativeReplayDisjointClassifierGuidance(args)
12 |     raise NotImplementedError(f"Unknown cl_method: {args.cl_method}")
13 |
--------------------------------------------------------------------------------
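A small illustrative call of the factory above; the Namespace fields shown are placeholders for the corresponding entries defined in guide.script_args.

from argparse import Namespace

from cl_methods.utils import get_cl_method

args = Namespace(
    cl_method="generative_replay",
    batch_size=128,
    gr_n_generated_examples_per_task=5000,
)
cl_method = get_cl_method(args)
print(type(cl_method).__name__)  # GenerativeReplay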
/cl_methods/base.py:
--------------------------------------------------------------------------------
1 | """CL-method-specific data preparation."""
2 |
3 | from torch.utils.data import DataLoader
4 |
5 | from dataloaders.utils import yielder
6 |
7 |
8 | class CLMethod:
9 | def __init__(self, args):
10 | self.args = args
11 |
12 | def get_data_for_task(self, dataset, task_id, train_loop):
13 | loader = DataLoader(
14 | dataset=dataset,
15 | batch_size=self.args.batch_size,
16 | shuffle=True,
17 | drop_last=True,
18 | )
19 | return yielder(loader), loader, None
20 |
--------------------------------------------------------------------------------
/cl_methods/generative_replay_disjoint_classifier_guidance.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 |
3 | from cl_methods.base import CLMethod
4 | from dataloaders.utils import yielder
5 |
6 |
7 | class GenerativeReplayDisjointClassifierGuidance(CLMethod):
8 | def get_data_for_task(
9 | self,
10 | dataset,
11 | task_id,
12 | train_loop,
13 | generator=None,
14 | step=None,
15 | ):
16 | loader = DataLoader(
17 | dataset=dataset,
18 | batch_size=self.args.batch_size // (task_id + 1),
19 | shuffle=True,
20 | drop_last=True,
21 | generator=generator,
22 | )
23 | return yielder(loader), loader, None
24 |
--------------------------------------------------------------------------------
/guide/validation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | @torch.no_grad()
6 | def calculate_accuracy_with_classifier(
7 | model,
8 | task_id,
9 | val_loader,
10 | device,
11 | train_loader=None,
12 | max_class=0,
13 | train_with_disjoint_classifier=False,
14 | ):
15 | model.eval()
16 | loader = (
17 | {"test": val_loader, "train": train_loader}
18 | if train_loader is not None
19 | else {"test": val_loader}
20 | )
21 | correct = {"test": 0.0, "train": 0.0}
22 | total = {"test": 0.0, "train": 0.0}
23 | loss = {"test": 0.0, "train": 0.0}
24 | print("Calculating accuracy:")
25 | for phase in loader.keys():
26 | for idx, batch in enumerate(loader[phase]):
27 | x, cond = batch
28 | x = x.to(device)
29 | y = cond["y"].to(device)
30 | out_classifier = model(x)
31 | preds = torch.argmax(out_classifier, 1)
32 | correct[phase] += (preds == torch.argmax(y, 1)).sum()
33 | total[phase] += len(y)
34 | loss[phase] += F.cross_entropy(
35 | out_classifier[:, : max_class + 1], y[:, : max_class + 1]
36 | )
37 |         loss[phase] /= idx + 1  # average over the number of batches seen
38 | correct[phase] /= total[phase]
39 | model.train()
40 | return {
41 | "loss": loss,
42 | "accuracy": correct,
43 | }
44 |
--------------------------------------------------------------------------------
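A minimal sketch (not from the repo) of how this helper can be called: a toy linear classifier and a synthetic loader whose cond dict carries one-hot labels, since the function argmaxes and slices cond["y"] along dimension 1.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from guide.validation import calculate_accuracy_with_classifier

num_classes, n = 10, 64
xs = torch.randn(n, 3 * 32 * 32)
ys = F.one_hot(torch.randint(0, num_classes, (n,)), num_classes).float()

class ToyDataset(Dataset):
    def __len__(self):
        return n

    def __getitem__(self, i):
        return xs[i], {"y": ys[i]}

loader = DataLoader(ToyDataset(), batch_size=16)
model = nn.Linear(3 * 32 * 32, num_classes)
stats = calculate_accuracy_with_classifier(
    model, task_id=0, val_loader=loader, device="cpu", max_class=num_classes - 1
)
print(stats["accuracy"]["test"], stats["loss"]["test"])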
/dataloaders/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import ConcatDataset, DataLoader
4 |
5 |
6 | def yielder(loader):
7 | while True:
8 | yield from loader
9 |
10 |
11 | def concat_yielder(loaders):
12 | iterators = [iter(loader) for loader in loaders]
13 | while True:
14 | result = []
15 | for i in range(len(iterators)):
16 | try:
17 | x = next(iterators[i])
18 | except StopIteration:
19 | iterators[i] = iter(loaders[i])
20 | x = next(iterators[i])
21 | result.append(x)
22 |
23 | yield recursive_concat(result)
24 |
25 |
26 | def recursive_concat(l):
27 | """Concat elements in list l. Each element can be a tensor, dicts of tensors, tuple, etc."""
28 | if isinstance(l[0], torch.Tensor):
29 | return torch.cat(l)
30 | if isinstance(l[0], dict):
31 | keys = set(l[0].keys())
32 | for x in l[1:]:
33 | assert set(x.keys()) == keys
34 | return {k: recursive_concat([x[k] for x in l]) for k in keys}
35 | if isinstance(l[0], tuple) or isinstance(l[0], list):
36 | length = len(l[0])
37 | for x in l[1:]:
38 | assert len(x) == length
39 | return tuple([recursive_concat([x[i] for x in l]) for i in range(length)])
40 |
41 |
42 | def get_stratified_subset(frac_selected, labels, seed=0):
43 | """Returns indices of a subset with a given percentage of elements for each class."""
44 | labels = np.array(labels)
45 | rng = np.random.default_rng(seed=seed)
46 | res = []
47 | for l in np.unique(labels):
48 | all_indices = np.nonzero(labels == l)[0]
49 | num_selected = int(frac_selected * len(all_indices))
50 | res.append(rng.choice(all_indices, num_selected, replace=False))
51 | res = np.concatenate(res)
52 | return res
53 |
54 |
55 | def prepare_eval_loaders(
56 | train_dataset_splits, val_dataset_splits, args, include_train, generator=False
57 | ):
58 | eval_loaders = []
59 | for task_id in range(args.num_tasks):
60 | if include_train:
61 | eval_data = ConcatDataset(
62 | [train_dataset_splits[task_id], val_dataset_splits[task_id]]
63 | )
64 | else:
65 | eval_data = val_dataset_splits[task_id]
66 | eval_loader = DataLoader(
67 | dataset=eval_data,
68 | batch_size=args.batch_size,
69 | shuffle=False,
70 | generator=generator,
71 | )
72 | eval_loaders.append(eval_loader)
73 |
74 | return eval_loaders
75 |
--------------------------------------------------------------------------------
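A small sketch of the helpers above on toy tensors (illustrative only, not part of the repo): a stratified subset draw and one batch from concat_yielder, which restarts exhausted loaders and concatenates matching fields.

import torch
from torch.utils.data import DataLoader, TensorDataset

from dataloaders.utils import concat_yielder, get_stratified_subset

# Stratified 50% subset: roughly half the indices of each class are kept.
labels = [0, 0, 0, 0, 1, 1, 1, 1]
idx = get_stratified_subset(0.5, labels, seed=0)
print(sorted(idx))  # two indices per class

# One merged batch drawn from two loaders of different sizes.
loader_a = DataLoader(TensorDataset(torch.zeros(6, 3)), batch_size=2)
loader_b = DataLoader(TensorDataset(torch.ones(4, 3)), batch_size=2)
batch = next(concat_yielder([loader_a, loader_b]))
print(batch[0].shape)  # torch.Size([4, 3]): 2 rows from each loader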
/guide/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for various likelihood-based losses. These are ported from the original
3 | Ho et al. diffusion models codebase:
4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5 | """
6 |
7 | import numpy as np
8 | import torch as th
9 |
10 |
11 | def normal_kl(mean1, logvar1, mean2, logvar2):
12 | """
13 | Compute the KL divergence between two gaussians.
14 |
15 | Shapes are automatically broadcasted, so batches can be compared to
16 | scalars, among other use cases.
17 | """
18 | tensor = None
19 | for obj in (mean1, logvar1, mean2, logvar2):
20 | if isinstance(obj, th.Tensor):
21 | tensor = obj
22 | break
23 | assert tensor is not None, "at least one argument must be a Tensor"
24 |
25 | # Force variances to be Tensors. Broadcasting helps convert scalars to
26 | # Tensors, but it does not work for th.exp().
27 | logvar1, logvar2 = [
28 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
29 | for x in (logvar1, logvar2)
30 | ]
31 |
32 | return 0.5 * (
33 | -1.0
34 | + logvar2
35 | - logvar1
36 | + th.exp(logvar1 - logvar2)
37 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
38 | )
39 |
40 |
41 | def approx_standard_normal_cdf(x):
42 | """
43 | A fast approximation of the cumulative distribution function of the
44 | standard normal.
45 | """
46 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
47 |
48 |
49 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
50 | """
51 | Compute the log-likelihood of a Gaussian distribution discretizing to a
52 | given image.
53 |
54 | :param x: the target images. It is assumed that this was uint8 values,
55 | rescaled to the range [-1, 1].
56 | :param means: the Gaussian mean Tensor.
57 | :param log_scales: the Gaussian log stddev Tensor.
58 | :return: a tensor like x of log probabilities (in nats).
59 | """
60 | assert x.shape == means.shape == log_scales.shape
61 | centered_x = x - means
62 | inv_stdv = th.exp(-log_scales)
63 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
64 | cdf_plus = approx_standard_normal_cdf(plus_in)
65 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
66 | cdf_min = approx_standard_normal_cdf(min_in)
67 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
68 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
69 | cdf_delta = cdf_plus - cdf_min
70 | log_probs = th.where(
71 | x < -0.999,
72 | log_cdf_plus,
73 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
74 | )
75 | assert log_probs.shape == x.shape
76 | return log_probs
77 |
--------------------------------------------------------------------------------
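Two quick sanity checks of the helpers above (illustrative only): the KL between identical Gaussians is zero, and the discretized log-likelihood returns one value per pixel for inputs scaled to [-1, 1] as the docstring requires.

import torch as th

from guide.losses import discretized_gaussian_log_likelihood, normal_kl

mean = th.zeros(4)
logvar = th.zeros(4)
print(normal_kl(mean, logvar, mean, logvar))  # tensor of zeros

x = th.rand(2, 3, 8, 8) * 2 - 1
log_probs = discretized_gaussian_log_likelihood(
    x, means=th.zeros_like(x), log_scales=th.zeros_like(x)
)
print(log_probs.shape)  # same shape as x, values in nats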
/cl_methods/generative_replay.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | from torch.utils.data import ConcatDataset, DataLoader, TensorDataset
3 |
4 | from cl_methods.base import CLMethod
5 | from dataloaders.utils import yielder
6 | from dataloaders.wrapper import AppendName
7 | from guide import dist_util
8 | from guide.logger import get_rank_without_mpi_import, log_generated_examples
9 |
10 |
11 | class GenerativeReplay(CLMethod):
12 | def get_data_for_task(
13 | self,
14 | dataset,
15 | task_id,
16 | train_loop,
17 | generator=None,
18 | step=None,
19 | ):
20 | if task_id == 0:
21 | train_dataset_loader = DataLoader(
22 | dataset=dataset,
23 | batch_size=self.args.batch_size,
24 | shuffle=True,
25 | drop_last=True,
26 | generator=generator,
27 | )
28 | dataset_yielder = yielder(train_dataset_loader)
29 | else:
30 | print("Preparing dataset for rehearsal...")
31 | if self.args.gr_n_generated_examples_per_task <= self.args.batch_size:
32 | batch_size = self.args.gr_n_generated_examples_per_task
33 | else:
34 | batch_size = self.args.batch_size
35 | (
36 | generated_previous_examples,
37 | generated_previous_examples_labels,
38 | generated_previous_examples_confidences,
39 | ) = train_loop.generate_examples(
40 | task_id - 1,
41 | self.args.gr_n_generated_examples_per_task,
42 | batch_size=batch_size,
43 | equal_n_examples_per_class=True,
44 | use_old_grad=False,
45 | use_new_grad=False,
46 | )
47 | generated_dataset = AppendName(
48 | TensorDataset(
49 | generated_previous_examples, generated_previous_examples_labels
50 | ),
51 | generated_previous_examples_labels.cpu().numpy(),
52 | True,
53 | False,
54 | )
55 | joined_dataset = ConcatDataset([dataset, generated_dataset])
56 | train_dataset_loader = DataLoader(
57 | dataset=joined_dataset,
58 | batch_size=self.args.batch_size,
59 | shuffle=True,
60 | drop_last=True,
61 | generator=generator,
62 | )
63 | dataset_yielder = yielder(train_dataset_loader)
64 | if get_rank_without_mpi_import() == 0:
65 | log_generated_examples(
66 | generated_previous_examples,
67 | th.argmax(generated_previous_examples_labels, 1),
68 | generated_previous_examples_confidences,
69 | task_id,
70 | step=step,
71 | )
72 |
73 | return (
74 | dataset_yielder,
75 | train_dataset_loader,
76 | generated_previous_examples if task_id != 0 else None,
77 | )
78 |
--------------------------------------------------------------------------------
/guide/layers.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | from torch import nn
4 |
5 |
6 | class ConvBlock(nn.Module):
7 | def __init__(
8 | self,
9 | opt,
10 | in_channels,
11 | out_channels,
12 | kernel_size,
13 | stride=1,
14 | padding=0,
15 | bias=False,
16 | groups=1,
17 | ):
18 | super(ConvBlock, self).__init__()
19 | self.in_channels = in_channels
20 | self.out_channels = out_channels
21 | self.kernel_size = kernel_size
22 | conv = nn.Conv2d(
23 | in_channels,
24 | out_channels,
25 | kernel_size=kernel_size,
26 | stride=stride,
27 | padding=padding,
28 | bias=bias,
29 | groups=groups,
30 | )
31 |
32 | layer = [conv]
33 | if opt.bn:
34 | if opt.preact:
35 | bn = getattr(nn, opt.normtype + "2d")(
36 | num_features=in_channels, affine=opt.affine_bn, eps=opt.bn_eps
37 | )
38 | layer = [bn]
39 | else:
40 | bn = getattr(nn, opt.normtype + "2d")(
41 | num_features=out_channels, affine=opt.affine_bn, eps=opt.bn_eps
42 | )
43 | layer = [conv, bn]
44 |
45 | if opt.activetype != "None":
46 | active = getattr(nn, opt.activetype)()
47 | layer.append(active)
48 |
49 | if opt.bn and opt.preact:
50 | layer.append(conv)
51 |
52 | self.block = nn.Sequential(*layer)
53 |
54 | def forward(self, input):
55 | return self.block.forward(input)
56 |
57 |
58 | class FCBlock(nn.Module):
59 | def __init__(self, opt, in_channels, out_channels, bias=False):
60 | super(FCBlock, self).__init__()
61 | self.in_channels = in_channels
62 | self.out_channels = out_channels
63 | lin = nn.Linear(in_channels, out_channels, bias=bias)
64 |
65 | layer = [lin]
66 | if opt.bn:
67 | if opt.preact:
68 | bn = getattr(nn, opt.normtype + "1d")(
69 | num_features=in_channels, affine=opt.affine_bn, eps=opt.bn_eps
70 | )
71 | layer = [bn]
72 | else:
73 | bn = getattr(nn, opt.normtype + "1d")(
74 | num_features=out_channels, affine=opt.affine_bn, eps=opt.bn_eps
75 | )
76 | layer = [lin, bn]
77 |
78 | if opt.activetype != "None":
79 | active = getattr(nn, opt.activetype)()
80 | layer.append(active)
81 |
82 | if opt.bn and opt.preact:
83 | layer.append(lin)
84 |
85 | self.block = nn.Sequential(*layer)
86 |
87 | def forward(self, input):
88 | return self.block.forward(input)
89 |
90 |
91 | def FinalBlock(opt, in_channels, bias=False):
92 | out_channels = opt.model_num_classes
93 | opt = copy.deepcopy(opt)
94 | if not opt.preact:
95 | opt.activetype = "None"
96 | return FCBlock(
97 | opt=opt, in_channels=in_channels, out_channels=out_channels, bias=bias
98 | )
99 |
100 |
101 | def InitialBlock(opt, out_channels, kernel_size, stride=1, padding=0, bias=False):
102 | in_channels = opt.in_channels
103 | opt = copy.deepcopy(opt)
104 | return ConvBlock(
105 | opt=opt,
106 | in_channels=in_channels,
107 | out_channels=out_channels,
108 | kernel_size=kernel_size,
109 | stride=stride,
110 | padding=padding,
111 | bias=bias,
112 | )
113 |
--------------------------------------------------------------------------------
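A sketch of the option namespace these blocks expect; the field names are taken from the code above, while the concrete values and the tiny stem/block/head stack are only an example.

from argparse import Namespace

import torch

from guide.layers import ConvBlock, FinalBlock, InitialBlock

opt = Namespace(
    bn=True,              # wrap conv/linear with a norm layer
    preact=True,          # norm + activation before the conv/linear
    normtype="BatchNorm",
    affine_bn=True,
    bn_eps=1e-5,
    activetype="ReLU",
    in_channels=3,        # used by InitialBlock
    model_num_classes=10, # used by FinalBlock
)

stem = InitialBlock(opt, out_channels=16, kernel_size=3, padding=1)
block = ConvBlock(opt, in_channels=16, out_channels=32, kernel_size=3, padding=1)
head = FinalBlock(opt, in_channels=32)

x = torch.randn(2, 3, 8, 8)
h = block(stem(x))                 # (2, 32, 8, 8)
logits = head(h.mean(dim=(2, 3)))  # (2, 10)
print(logits.shape)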
/guide/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import io
6 | import os
7 | import socket
8 | import warnings
9 |
10 | import blobfile as bf
11 | import torch as th
12 | import torch.distributed as dist
13 |
14 | try:
15 | from mpi4py import MPI
16 | except ImportError:
17 | from guide.fake_mpi import MPI
18 |
19 | warnings.warn("Using fake MPI!")
20 |
21 | # Change this to reflect your cluster layout.
22 | # The GPU for a given rank is (rank % GPUS_PER_NODE).
23 | GPUS_PER_NODE = 8
24 | # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
25 | SETUP_RETRY_COUNT = 3
26 | # GPU_ID = "0"
27 |
28 |
29 | def setup_dist(args):
30 | """
31 | Setup a distributed process group.
32 | """
33 | # global GPU_ID
34 | # if args.gpu_id == -1:
35 | # os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
36 | # GPU_ID = ""
37 | # elif args.gpu_id != -2:
38 | # # GPU_ID = f":{args.gpu_id}"
39 | # os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
40 |
41 | if dist.is_initialized():
42 | return
43 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
44 | print("visible devices:", os.environ["CUDA_VISIBLE_DEVICES"])
45 | comm = MPI.COMM_WORLD
46 | backend = "gloo" if not th.cuda.is_available() else "nccl"
47 |
48 | if backend == "gloo":
49 | hostname = "localhost"
50 | else:
51 | hostname = socket.gethostbyname(socket.getfqdn())
52 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
53 | os.environ["RANK"] = str(comm.rank)
54 | os.environ["WORLD_SIZE"] = str(comm.size)
55 | print("world size:", os.environ["WORLD_SIZE"])
56 |
57 | port = comm.bcast(_find_free_port(), root=0)
58 | os.environ["MASTER_PORT"] = str(port)
59 | dist.init_process_group(backend=backend, init_method="env://")
60 |
61 |
62 | def dev():
63 | """
64 | Get the device to use for torch.distributed.
65 | """
66 | # global GPU_ID
67 | if th.cuda.is_available():
68 | # return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
69 | return th.device("cuda")
70 | return th.device("cpu")
71 |
72 |
73 | def load_state_dict(path, **kwargs):
74 | """
75 | Load a PyTorch file without redundant fetches across MPI ranks.
76 | """
77 | chunk_size = 2**30 # MPI has a relatively small size limit
78 | if MPI.COMM_WORLD.Get_rank() == 0:
79 | with bf.BlobFile(path, "rb") as f:
80 | data = f.read()
81 | num_chunks = len(data) // chunk_size
82 | if len(data) % chunk_size:
83 | num_chunks += 1
84 | MPI.COMM_WORLD.bcast(num_chunks)
85 | for i in range(0, len(data), chunk_size):
86 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
87 | else:
88 | num_chunks = MPI.COMM_WORLD.bcast(None)
89 | data = bytes()
90 | for _ in range(num_chunks):
91 | data += MPI.COMM_WORLD.bcast(None)
92 |
93 | return th.load(io.BytesIO(data), **kwargs)
94 |
95 |
96 | def sync_params(params):
97 | """
98 | Synchronize a sequence of Tensors across ranks from rank 0.
99 | """
100 | # if GPU_ID!="":
101 | # rank = int(GPU_ID[-1:])
102 | # else:
103 | # rank = 0
104 | # if GPU_ID == "":
105 | for p in params:
106 | with th.no_grad():
107 | dist.broadcast(p, 0)
108 |
109 |
110 | def _find_free_port():
111 | try:
112 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
113 | s.bind(("", 0))
114 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
115 | return s.getsockname()[1]
116 | finally:
117 | s.close()
118 |
--------------------------------------------------------------------------------
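A single-process sketch: when mpi4py is unavailable the fake_mpi shim above is used, so the world size is 1 and setup_dist initializes a one-rank process group (gloo on CPU-only machines, nccl otherwise). The empty Namespace stands in for the parsed script args, which setup_dist does not currently read.

from argparse import Namespace

import torch as th

from guide import dist_util

dist_util.setup_dist(Namespace())
device = dist_util.dev()               # cuda if available, else cpu
params = [th.zeros(3, device=device)]
dist_util.sync_params(params)          # broadcast from rank 0; a no-op with one rank
print(device, params[0])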
/scripts/image_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of image samples from a model and save them as a large
3 | numpy array. This can be used to produce samples for FID evaluation.
4 | """
5 |
6 | import argparse
7 | import os
8 |
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import torch as th
12 | import torch.distributed as dist
13 | from torchvision.utils import make_grid
14 |
15 | from guide import dist_util, logger
16 | from guide.script_args import (
17 | add_dict_to_argparser,
18 | all_training_defaults,
19 | args_to_dict,
20 | preprocess_args,
21 | )
22 | from guide.script_util import create_model_and_diffusion, model_and_diffusion_defaults
23 |
24 |
25 | def main():
26 | args = create_argparser().parse_args()
27 | preprocess_args(args)
28 |
29 | os.environ["OPENAI_LOGDIR"] = f"sampled/{args.wandb_experiment_name}"
30 |
31 | dist_util.setup_dist(args)
32 | logger.configure()
33 |
34 | logger.log("creating model and diffusion...")
35 | model, diffusion = create_model_and_diffusion(
36 | **args_to_dict(args, model_and_diffusion_defaults().keys())
37 | )
38 | model.load_state_dict(
39 | dist_util.load_state_dict(args.model_path, map_location="cpu")
40 | )
41 | model.to(dist_util.dev())
42 | if args.use_fp16:
43 | model.convert_to_fp16()
44 | model.eval()
45 |
46 | logger.log("sampling...")
47 | all_images = []
48 | all_labels = []
49 | while len(all_images) < args.num_samples:
50 | model_kwargs = {}
51 | if args.class_cond:
52 | classes = th.randint(
53 | low=0,
54 | high=args.num_classes,
55 | size=(args.batch_size,),
56 | device=dist_util.dev(),
57 | )
58 |             model_kwargs["y"] = classes
59 | sample_fn = (
60 | diffusion.ddim_sample_loop if args.use_ddim else diffusion.p_sample_loop
61 | )
62 | sample = sample_fn(
63 | model,
64 | (args.batch_size, args.in_channels, args.image_size, args.image_size),
65 | clip_denoised=args.clip_denoised,
66 | model_kwargs=model_kwargs,
67 | )
68 | all_images.extend(sample.cpu().numpy())
69 | # gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
70 | # dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
71 | # all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
72 |         if args.class_cond:
73 |             # Keep the sampled labels so the class_cond branch below can
74 |             # concatenate and save them; the multi-GPU equivalent would
75 |             # all_gather `classes` across ranks, like the commented-out
76 |             # sample gathering above.
77 |             all_labels.extend(classes.cpu().numpy())
78 | logger.log(f"created {len(all_images)} samples")
79 |
80 | arr = np.concatenate([all_images])
81 | arr = arr[: args.num_samples]
82 | if args.class_cond:
83 | label_arr = np.concatenate(all_labels, axis=0)
84 | label_arr = label_arr[: args.num_samples]
85 | if dist.get_rank() == 0:
86 | shape_str = "x".join([str(x) for x in arr.shape])
87 | out_path = os.path.join(
88 | logger.get_dir(), f"samples_{args.model_path.split('/')[-1][:-3]}.npz"
89 | )
90 | logger.log(f"saving to {out_path}")
91 | if args.class_cond:
92 | np.savez(out_path, arr, label_arr)
93 | else:
94 | np.savez(out_path, arr)
95 |
96 | dist.barrier()
97 | plt.figure()
98 | plt.axis("off")
99 | samples_grid = make_grid(th.from_numpy(arr[:16]), 4, normalize=True).permute(
100 | 1, 2, 0
101 | )
102 | plt.imshow(samples_grid)
103 | out_plot = os.path.join(
104 | logger.get_dir(), f"samples_{args.model_path.split('/')[-1][:-3]}"
105 | )
106 | plt.savefig(out_plot)
107 | logger.log("sampling complete")
108 |
109 |
110 | def create_argparser():
111 | defaults = all_training_defaults()
112 | defaults.update(model_and_diffusion_defaults())
113 | defaults.update(
114 | dict(
115 | model_path="",
116 | num_samples=2,
117 | image_size=28,
118 | in_channels=1,
119 | model_num_classes=10,
120 | batch_size=16,
121 | wandb_experiment_name="test",
122 | )
123 | )
124 | parser = argparse.ArgumentParser()
125 | add_dict_to_argparser(parser, defaults)
126 | return parser
127 |
128 |
129 | if __name__ == "__main__":
130 | main()
131 |
--------------------------------------------------------------------------------
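An illustrative invocation in the spirit of the example command in classifier_sample_universal.py; the checkpoint path and image dimensions are placeholders, and the remaining flags fall back to all_training_defaults and model_and_diffusion_defaults.

python scripts/image_sample.py --model_path=results/model.pt --num_samples=64 --batch_size=16 --image_size=32 --in_channels=3 --model_num_classes=10 --wandb_experiment_name=sample_test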
/dataloaders/wrapper.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from os import path
3 |
4 | import numpy as np
5 | import torch
6 | import torch.utils.data as data
7 |
8 |
9 | class CacheClassLabel(data.Dataset):
10 | """
11 | A dataset wrapper that has a quick access to all labels of data.
12 | """
13 |
14 | def __init__(self, dataset, target_transform=None):
15 | super(CacheClassLabel, self).__init__()
16 | self.dataset = dataset
17 | self.labels = torch.zeros(len(dataset)).long()
18 | if target_transform:
19 | self.labels = target_transform(self.labels)
20 | label_cache_filename = path.join(
21 | dataset.root, str(type(dataset)) + "_" + str(len(dataset)) + ".pth"
22 | )
23 | if path.exists(label_cache_filename):
24 | self.labels = torch.load(label_cache_filename)
25 | else:
26 | for i, data in enumerate(dataset):
27 | self.labels[i] = data[1]
28 | torch.save(self.labels, label_cache_filename)
29 | self.number_classes = len(torch.unique(self.labels))
30 | if target_transform:
31 | self.number_classes = self.labels.shape[1]
32 |
33 | def __len__(self):
34 | return len(self.dataset)
35 |
36 | def __getitem__(self, index):
37 | img, target = self.dataset[index]
38 | return img, target
39 |
40 |
41 | class CacheClassLabelForTensor(CacheClassLabel):
42 | """
43 | A dataset wrapper that has a quick access to all labels of data.
44 | """
45 |
46 | def __init__(self, tensor_dataset, labels):
47 |         super(CacheClassLabel, self).__init__()  # skip CacheClassLabel.__init__, go straight to data.Dataset
48 | self.dataset = tensor_dataset
49 | self.labels = labels
50 | self.number_classes = len(torch.unique(self.labels))
51 |
52 |
53 | class AppendName(data.Dataset):
54 | """
55 | A dataset wrapper that also return the name of the dataset/task
56 | """
57 |
58 | def __init__(
59 | self,
60 | dataset,
61 | task_ids,
62 | return_classes=False,
63 | return_task_as_class=False,
64 | first_class_ind=0,
65 | ):
66 | super(AppendName, self).__init__()
67 | self.dataset = dataset
68 | self.first_class_ind = first_class_ind
69 | self.task_ids = task_ids # For remapping the class index
70 | self.return_classes = return_classes
71 | self.return_task_as_class = return_task_as_class
72 |
73 | def __len__(self):
74 | return len(self.dataset)
75 |
76 | def __getitem__(self, index):
77 | img, target = self.dataset[index]
78 | target = target + self.first_class_ind
79 | out_dict = {}
80 | if self.return_task_as_class:
81 | out_dict["y"] = np.array(
82 | self.task_ids[index]
83 | ) # np.array(self.task_id, dtype=np.int64)
84 | elif self.return_classes:
85 | if isinstance(target, torch.Tensor):
86 | out_dict["y"] = target.cpu()
87 | elif isinstance(target, int):
88 | out_dict["y"] = torch.from_numpy(np.array(target, dtype=np.int64))
89 | return img, out_dict # target #, self.name
90 |
91 |
92 | class Subclass(data.Dataset):
93 | """
94 | A dataset wrapper that return the task name and remove the offset of labels (Let the labels start from 0)
95 | """
96 |
97 | def __init__(self, dataset, class_list, remap=True):
98 | """
99 | :param dataset: (CacheClassLabel)
100 | :param class_list: (list) A list of integers
101 | :param remap: (bool) Ex: remap class [2,4,6 ...] to [0,1,2 ...]
102 | """
103 | super(Subclass, self).__init__()
104 | assert isinstance(
105 | dataset, CacheClassLabel
106 | ), "dataset must be wrapped by CacheClassLabel"
107 | self.dataset = dataset
108 | self.class_list = deepcopy(class_list)
109 | self.remap = remap
110 | self.indices = []
111 |
112 | for c in class_list:
113 | self.indices.extend((dataset.labels == c).nonzero().flatten().tolist())
114 |
115 | if remap:
116 | self.class_mapping = {c: i for i, c in enumerate(class_list)}
117 |
118 | def __len__(self):
119 | return len(self.indices)
120 |
121 | def __getitem__(self, index):
122 | img, target = self.dataset[self.indices[index]]
123 | if self.remap:
124 | raw_target = target.item() if isinstance(target, torch.Tensor) else target
125 | target = self.class_mapping[raw_target]
126 | return img, target
127 |
128 |
129 | class Permutation(data.Dataset):
130 | """
131 | A dataset wrapper that permute the position of features
132 | """
133 |
134 | def __init__(self, dataset, permute_idx):
135 | super(Permutation, self).__init__()
136 | self.dataset = dataset
137 | self.permute_idx = permute_idx
138 |
139 | def __len__(self):
140 | return len(self.dataset)
141 |
142 | def __getitem__(self, index):
143 | img, target = self.dataset[index]
144 | shape = img.size()
145 | img = img.view(-1)[self.permute_idx].view(shape)
146 | return img, target
147 |
148 |
149 | class Storage(data.Subset):
150 | def reduce(self, m):
151 | self.indices = self.indices[:m]
152 |
--------------------------------------------------------------------------------
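An illustrative composition of the wrappers above on a tiny in-memory dataset; CacheClassLabel expects the wrapped dataset to expose a `root` directory where it can cache labels, so the toy class below provides one.

import tempfile

import torch
import torch.utils.data as data

from dataloaders.wrapper import AppendName, CacheClassLabel, Subclass

class ToyDataset(data.Dataset):
    def __init__(self, root):
        self.root = root
        self.x = torch.randn(10, 4)
        self.y = torch.arange(10) % 5  # labels 0..4

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return self.x[i], int(self.y[i])

base = CacheClassLabel(ToyDataset(tempfile.mkdtemp()))
print(base.number_classes)  # 5

# Keep only classes {1, 3} and remap them to {0, 1}.
sub = Subclass(base, class_list=[1, 3], remap=True)
print(len(sub), sub[0][1])

# Attach task ids and return labels in the cond dict used by the trainer.
named = AppendName(sub, task_ids=[0] * len(sub), return_classes=True)
img, cond = named[0]
print(cond["y"])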
/guide/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | from .gaussian_diffusion import GaussianDiffusion
5 |
6 |
7 | def space_timesteps(num_timesteps, section_counts):
8 | """
9 | Create a list of timesteps to use from an original diffusion process,
10 | given the number of timesteps we want to take from equally-sized portions
11 | of the original process.
12 |
13 | For example, if there's 300 timesteps and the section counts are [10,15,20]
14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
15 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
16 |
17 | If the stride is a string starting with "ddim", then the fixed striding
18 | from the DDIM paper is used, and only one section is allowed.
19 |
20 | :param num_timesteps: the number of diffusion steps in the original
21 | process to divide up.
22 | :param section_counts: either a list of numbers, or a string containing
23 | comma-separated numbers, indicating the step count
24 | per section. As a special case, use "ddimN" where N
25 | is a number of steps to use the striding from the
26 | DDIM paper.
27 | :return: a set of diffusion steps from the original process to use.
28 | """
29 | if isinstance(section_counts, str):
30 | if section_counts.startswith("ddim"):
31 | desired_count = int(section_counts[len("ddim") :])
32 | for i in range(1, num_timesteps):
33 | if len(range(0, num_timesteps, i)) == desired_count:
34 | return set(range(0, num_timesteps, i))
35 | raise ValueError(
36 |                 f"cannot create exactly {desired_count} steps with an integer stride"
37 | )
38 | section_counts = [int(x) for x in section_counts.split(",")]
39 | size_per = num_timesteps // len(section_counts)
40 | extra = num_timesteps % len(section_counts)
41 | start_idx = 0
42 | all_steps = []
43 | for i, section_count in enumerate(section_counts):
44 | size = size_per + (1 if i < extra else 0)
45 | if size < section_count:
46 | raise ValueError(
47 | f"cannot divide section of {size} steps into {section_count}"
48 | )
49 | if section_count <= 1:
50 | frac_stride = 1
51 | else:
52 | frac_stride = (size - 1) / (section_count - 1)
53 | cur_idx = 0.0
54 | taken_steps = []
55 | for _ in range(section_count):
56 | taken_steps.append(start_idx + round(cur_idx))
57 | cur_idx += frac_stride
58 | all_steps += taken_steps
59 | start_idx += size
60 | return set(all_steps)
61 |
62 |
63 | class SpacedDiffusion(GaussianDiffusion):
64 | """
65 | A diffusion process which can skip steps in a base diffusion process.
66 |
67 | :param use_timesteps: a collection (sequence or set) of timesteps from the
68 | original diffusion process to retain.
69 | :param kwargs: the kwargs to create the base diffusion process.
70 | """
71 |
72 | def __init__(self, use_timesteps, **kwargs):
73 | self.use_timesteps = set(use_timesteps)
74 | self.timestep_map = []
75 | self.original_num_steps = len(kwargs["betas"])
76 |
77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
78 | last_alpha_cumprod = 1.0
79 | new_betas = []
80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
81 | if i in self.use_timesteps:
82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
83 | last_alpha_cumprod = alpha_cumprod
84 | self.timestep_map.append(i)
85 | kwargs["betas"] = np.array(new_betas)
86 | super().__init__(**kwargs)
87 |
88 | def p_mean_variance(
89 | self, model, *args, **kwargs
90 | ): # pylint: disable=signature-differs
91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
92 |
93 | def training_losses(
94 | self, model, *args, **kwargs
95 | ): # pylint: disable=signature-differs
96 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
97 |
98 | def condition_mean(self, cond_fn, *args, **kwargs):
99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
100 |
101 | def condition_score(self, cond_fn, *args, **kwargs):
102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
103 |
104 | def _wrap_model(self, model):
105 | if isinstance(model, _WrappedModel):
106 | return model
107 | return _WrappedModel(
108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
109 | )
110 |
111 | def _scale_timesteps(self, t):
112 | # Scaling is done by the wrapped model.
113 | return t
114 |
115 |
116 | class _WrappedModel:
117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
118 | self.model = model
119 | self.timestep_map = timestep_map
120 | self.rescale_timesteps = rescale_timesteps
121 | self.original_num_steps = original_num_steps
122 |
123 | def __call__(self, x, ts, **kwargs):
124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
125 | new_ts = map_tensor[ts]
126 | if self.rescale_timesteps:
127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
128 | return self.model(x, new_ts, **kwargs)
129 |
--------------------------------------------------------------------------------
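A quick sanity check of space_timesteps against the docstring example above: three sections of 100 original steps strided to 10, 15, and 20 retained steps, plus a fixed DDIM stride.

from guide.respace import space_timesteps

steps = space_timesteps(300, [10, 15, 20])
print(len(steps))        # 45 retained steps out of 300

ddim_steps = space_timesteps(1000, "ddim50")
print(len(ddim_steps))   # 50 evenly strided steps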
/guide/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * th.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def conv_nd(dims, *args, **kwargs):
23 | """
24 | Create a 1D, 2D, or 3D convolution module.
25 | """
26 | if dims == 1:
27 | return nn.Conv1d(*args, **kwargs)
28 | elif dims == 2:
29 | return nn.Conv2d(*args, **kwargs)
30 | elif dims == 3:
31 | return nn.Conv3d(*args, **kwargs)
32 | raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
35 | def linear(*args, **kwargs):
36 | """
37 | Create a linear module.
38 | """
39 | return nn.Linear(*args, **kwargs)
40 |
41 |
42 | def avg_pool_nd(dims, *args, **kwargs):
43 | """
44 | Create a 1D, 2D, or 3D average pooling module.
45 | """
46 | if dims == 1:
47 | return nn.AvgPool1d(*args, **kwargs)
48 | elif dims == 2:
49 | return nn.AvgPool2d(*args, **kwargs)
50 | elif dims == 3:
51 | return nn.AvgPool3d(*args, **kwargs)
52 | raise ValueError(f"unsupported dimensions: {dims}")
53 |
54 |
55 | def update_ema(target_params, source_params, rate=0.99):
56 | """
57 | Update target parameters to be closer to those of source parameters using
58 | an exponential moving average.
59 |
60 | :param target_params: the target parameter sequence.
61 | :param source_params: the source parameter sequence.
62 | :param rate: the EMA rate (closer to 1 means slower).
63 | """
64 | for targ, src in zip(target_params, source_params):
65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66 |
67 |
68 | def zero_module(module):
69 | """
70 | Zero out the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().zero_()
74 | return module
75 |
76 |
77 | def scale_module(module, scale):
78 | """
79 | Scale the parameters of a module and return it.
80 | """
81 | for p in module.parameters():
82 | p.detach().mul_(scale)
83 | return module
84 |
85 |
86 | def mean_flat(tensor):
87 | """
88 | Take the mean over all non-batch dimensions.
89 | """
90 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
91 |
92 |
93 | def normalization(channels):
94 | """
95 | Make a standard normalization layer.
96 |
97 | :param channels: number of input channels.
98 | :return: an nn.Module for normalization.
99 | """
100 | return GroupNorm32(32, channels)
101 |
102 |
103 | def timestep_embedding(timesteps, dim, max_period=10000):
104 | """
105 | Create sinusoidal timestep embeddings.
106 |
107 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
108 | These may be fractional.
109 | :param dim: the dimension of the output.
110 | :param max_period: controls the minimum frequency of the embeddings.
111 | :return: an [N x dim] Tensor of positional embeddings.
112 | """
113 | half = dim // 2
114 | freqs = th.exp(
115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116 | ).to(device=timesteps.device)
117 | args = timesteps[:, None].float() * freqs[None]
118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119 | if dim % 2:
120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121 | return embedding
122 |
123 |
124 | def checkpoint(func, inputs, params, flag):
125 | """
126 | Evaluate a function without caching intermediate activations, allowing for
127 | reduced memory at the expense of extra compute in the backward pass.
128 |
129 | :param func: the function to evaluate.
130 | :param inputs: the argument sequence to pass to `func`.
131 | :param params: a sequence of parameters `func` depends on but does not
132 | explicitly take as arguments.
133 | :param flag: if False, disable gradient checkpointing.
134 | """
135 | if flag:
136 | args = tuple(inputs) + tuple(params)
137 | return CheckpointFunction.apply(func, len(inputs), *args)
138 | else:
139 | return func(*inputs)
140 |
141 |
142 | class CheckpointFunction(th.autograd.Function):
143 | @staticmethod
144 | def forward(ctx, run_function, length, *args):
145 | ctx.run_function = run_function
146 | ctx.input_tensors = list(args[:length])
147 | ctx.input_params = list(args[length:])
148 | with th.no_grad():
149 | output_tensors = ctx.run_function(*ctx.input_tensors)
150 | return output_tensors
151 |
152 | @staticmethod
153 | def backward(ctx, *output_grads):
154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155 | with th.enable_grad():
156 | # Fixes a bug where the first op in run_function modifies the
157 | # Tensor storage in place, which is not allowed for detach()'d
158 | # Tensors.
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160 | output_tensors = ctx.run_function(*shallow_copies)
161 | input_grads = th.autograd.grad(
162 | output_tensors,
163 | ctx.input_tensors + ctx.input_params,
164 | output_grads,
165 | allow_unused=True,
166 | )
167 | del ctx.input_tensors
168 | del ctx.input_params
169 | del output_tensors
170 | return (None, None) + input_grads
171 |
--------------------------------------------------------------------------------
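A small illustration of two helpers above (shapes and one update step only; not part of the repo).

import torch as th

from guide.nn import timestep_embedding, update_ema

# Sinusoidal embeddings: one row per timestep index, `dim` features each.
t = th.tensor([0, 10, 500])
emb = timestep_embedding(t, dim=128)
print(emb.shape)     # torch.Size([3, 128])

# The EMA update pulls the target 1% of the way toward the source per call.
target = [th.zeros(4)]
source = [th.ones(4)]
update_ema(target, source, rate=0.99)
print(target[0])     # tensor of 0.01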
/dataloaders/datasetGen.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import Subset
6 |
7 | from .utils import get_stratified_subset
8 | from .wrapper import AppendName
9 |
10 |
11 | def data_split(
12 | dataset,
13 | return_classes=False,
14 | return_task_as_class=False,
15 | num_tasks=5,
16 | num_classes=10,
17 | limit_classes=-1,
18 | validation_frac=0.3,
19 | data_seed=0,
20 | shared_classes=False,
21 | first_task_num_classes=0,
22 | ):
23 | train_dataset_splits = {}
24 | val_dataset_splits = {}
25 |
26 | if not shared_classes:
27 | if limit_classes > 0:
28 | assert limit_classes <= num_classes
29 | num_classes = limit_classes
30 |
31 | # assert num_classes % num_tasks == 0
32 | if first_task_num_classes > 0:
33 | classes_per_task = (num_classes - first_task_num_classes) // (num_tasks - 1)
34 | class_split = {
35 | i: list(
36 | range(
37 | ((i - 1) * classes_per_task) + first_task_num_classes,
38 | (i * classes_per_task) + first_task_num_classes,
39 | )
40 | )
41 | for i in range(1, num_tasks)
42 | }
43 | class_split[0] = list(range(0, first_task_num_classes))
44 | class_split = dict(sorted(class_split.items()))
45 | else:
46 | classes_per_task = num_classes // num_tasks
47 | class_split = {
48 | i: list(range(i * classes_per_task, (i + 1) * classes_per_task))
49 | for i in range(num_tasks)
50 | }
51 | labels = (
52 | dataset.labels
53 | if dataset.labels.shape[1] == 1
54 | else torch.argmax(dataset.labels, 1)
55 | )
56 | class_indices = torch.LongTensor(labels)
57 | task_indices_1hot = torch.zeros(
58 | len(dataset), num_tasks
59 | ) # 1hot array describing to which tasks datapoints belong.
60 | for task, classes in class_split.items():
61 | task_indices_1hot[
62 | (class_indices[..., None] == torch.tensor(classes)).any(-1), task
63 | ] = 1
64 |
65 | train_set_indices_bitmask = torch.ones(len(dataset))
66 | validation_indices = get_stratified_subset(
67 | validation_frac, labels, seed=data_seed
68 | )
69 | train_set_indices_bitmask[validation_indices] = 0
70 |
71 | for task, classes in class_split.items():
72 | cur_task_indices_bitmask = task_indices_1hot[:, task] == 1
73 | cur_train_indices_bitmask = (
74 | train_set_indices_bitmask * cur_task_indices_bitmask
75 | )
76 | cur_val_indices_bitmask = (
77 | 1 - train_set_indices_bitmask
78 | ) * cur_task_indices_bitmask
79 |
80 | train_subset = Subset(
81 | dataset, torch.where(cur_train_indices_bitmask == 1)[0]
82 | )
83 | train_subset.class_list = classes
84 |
85 | val_subset = Subset(dataset, torch.where(cur_val_indices_bitmask == 1)[0])
86 | val_subset.class_list = classes
87 |
88 | train_dataset_splits[task] = AppendName(
89 | train_subset,
90 | [task] * len(train_subset),
91 | return_classes=return_classes,
92 | return_task_as_class=return_task_as_class,
93 | )
94 | val_dataset_splits[task] = AppendName(
95 | val_subset,
96 |                 [task] * len(val_subset),
97 | return_classes=return_classes,
98 | return_task_as_class=return_task_as_class,
99 | )
100 |
101 | else:
102 | # Each class in every task
103 | class_examples = defaultdict(list)
104 | for idx in range(len(dataset)):
105 | _, label = dataset[idx]
106 | class_examples[label].append(idx)
107 |
108 | classes_per_task = num_classes
109 |
110 | # Calculate the number of examples per class for each part and train/validation sets
111 | examples_per_class_per_part = len(class_examples[0]) // num_tasks
112 | train_examples = int(examples_per_class_per_part * (1 - validation_frac))
113 | val_examples = examples_per_class_per_part - train_examples
114 | for task in range(num_tasks):
115 | train_indices = []
116 | val_indices = []
117 |
118 | for class_idx in class_examples.keys():
119 | start_idx = task * examples_per_class_per_part
120 | train_end_idx = start_idx + train_examples
121 | val_end_idx = train_end_idx + val_examples
122 |
123 | train_indices.extend(class_examples[class_idx][start_idx:train_end_idx])
124 | val_indices.extend(class_examples[class_idx][train_end_idx:val_end_idx])
125 |
126 | train_subset = Subset(dataset, train_indices)
127 | train_subset.class_list = list(range(num_classes))
128 | val_subset = Subset(dataset, val_indices)
129 | val_subset.class_list = list(range(num_classes))
130 |
131 | train_dataset_splits[task] = AppendName(
132 | train_subset,
133 | [task] * len(train_subset),
134 | return_classes=return_classes,
135 | return_task_as_class=return_task_as_class,
136 | )
137 | val_dataset_splits[task] = AppendName(
138 | val_subset,
139 |                 [task] * len(val_subset),
140 | return_classes=return_classes,
141 | return_task_as_class=return_task_as_class,
142 | )
143 |
144 | print(
145 | f"Prepared dataset with splits: {[(idx, len(data)) for idx, data in enumerate(train_dataset_splits.values())]}"
146 | )
147 | print(
148 | f"Validation dataset with splits: {[(idx, len(data)) for idx, data in enumerate(val_dataset_splits.values())]}"
149 | )
150 | if hasattr(dataset.dataset, "classes"):
151 | print(
152 | f"Prepared class order: {[(idx, [np.array(dataset.dataset.classes)[data.dataset.class_list]]) for idx, data in enumerate(train_dataset_splits.values())]}"
153 | )
154 |
155 | return train_dataset_splits, val_dataset_splits, classes_per_task
156 |
--------------------------------------------------------------------------------
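An illustrative call of data_split on a toy dataset with one-hot labels, since the non-shared branch reads `dataset.labels` as a 2-D (N x num_classes) tensor; the ToyOneHot class and its sizes are placeholders for a CacheClassLabel-wrapped dataset.

import torch
import torch.nn.functional as F
import torch.utils.data as data

from dataloaders.datasetGen import data_split

class ToyOneHot(data.Dataset):
    def __init__(self, n=100, num_classes=10):
        y = torch.arange(n) % num_classes
        self.labels = F.one_hot(y, num_classes).float()
        self.x = torch.randn(n, 4)
        self.dataset = self  # data_split peeks at .dataset.classes when present

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return self.x[i], self.labels[i]

train_splits, val_splits, classes_per_task = data_split(
    ToyOneHot(), return_classes=True, num_tasks=5, num_classes=10,
    validation_frac=0.3, data_seed=0,
)
print(classes_per_task, len(train_splits[0]), len(val_splits[0]))  # 2 classes per task; task-0 train/val sizes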
/scripts/classifier_sample_universal.py:
--------------------------------------------------------------------------------
1 | """
2 | Like image_sample.py, but use a noisy image classifier to guide the sampling
3 | process towards more realistic images.
4 |
5 | python scripts/classifier_sample_universal.py --num_samples=100 --use_ddim=True --timestep_respacing=ddim25 --model_path=results_new/bc_cifar100_ci5_class_cond_diffusion_long/ema_0.9999_100000_0.pt --classifier_path=models/resnet18_cifar100_task0.pt --num_classes_sample=20 --model_num_classes=100 --batch_size=50
6 | """
7 |
8 | import argparse
9 | import os
10 |
11 | import numpy as np
12 | import torch as th
13 | import torch.nn.functional as F
14 |
15 | from guide import dist_util, logger
16 | from guide.script_args import add_dict_to_argparser, args_to_dict, classifier_defaults
17 | from guide.script_util import (
18 | create_model_and_diffusion,
19 | create_resnet_classifier,
20 | model_and_diffusion_defaults,
21 | )
22 |
23 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
24 |
25 |
26 | def main():
27 | args = create_argparser().parse_args()
28 | print("Using manual seed = {}".format(args.seed))
29 | th.manual_seed(args.seed)
30 | th.cuda.manual_seed(args.seed)
31 | th.backends.cudnn.deterministic = True
32 | th.backends.cudnn.benchmark = False
33 | os.environ["OPENAI_LOGDIR"] = f"out/samples/{args.wandb_experiment_name}"
34 |
35 | assert args.num_samples % args.batch_size == 0
36 |
37 | logger.configure()
38 |
39 | logger.log("creating model and diffusion...")
40 | model, diffusion = create_model_and_diffusion(
41 | **args_to_dict(args, model_and_diffusion_defaults().keys())
42 | )
43 | model.load_state_dict(
44 | dist_util.load_state_dict(args.model_path, map_location="cpu")
45 | )
46 | model.to(dist_util.dev())
47 | if args.use_fp16:
48 | model.convert_to_fp16()
49 | model.eval()
50 |
51 | logger.log("loading classifier...")
52 | defaults = classifier_defaults()
53 | parser = argparse.ArgumentParser()
54 | add_dict_to_argparser(parser, defaults)
55 | opts = parser.parse_args([])
56 | opts.model_num_classes = int(args.model_num_classes)
57 | opts.in_channels = 3
58 | opts.depth = 18
59 | opts.noised = False
60 | classifier = create_resnet_classifier(opts)
61 | if args.classifier_path:
62 | classifier.load_state_dict(th.load(args.classifier_path, map_location="cpu"))
63 | classifier.to(dist_util.dev())
64 | print(classifier)
65 |
66 | # NOTE: Possible further improvements from http://arxiv.org/abs/2302.07121, but with them
67 | # sampling becomes very time-consuming.
68 | def cond_fn(x, t, y=None):
69 | assert y is not None
70 | with th.enable_grad():
71 | x = x.detach().requires_grad_(True)
72 | my_t = th.tensor(
73 | [diffusion.timestep_map.index(ts) for ts in t], device=dist_util.dev()
74 | )
75 | out = diffusion.p_mean_variance(
76 | model, x, my_t, clip_denoised=True, model_kwargs=model_kwargs
77 | )
78 | x_in = out["pred_xstart"]
79 |
80 | logit = classifier(x_in)
81 | if args.trim_logits:
82 | logit = logit[:, : int(args.num_classes_sample)]
83 |
84 | loss = -F.cross_entropy(logit, y, reduction="none")
85 |
86 | grad = th.autograd.grad(loss.sum(), x)[0]
87 | return grad * classfier_scale_vec.view(-1, 1, 1, 1)
88 |
89 | def model_fn(x, t, y=None):
90 | return model(x, t, y if args.class_cond else None)
91 |
92 | logger.log("sampling...")
93 | all_images = []
94 | all_labels = []
95 | batch_num = 0
96 | while len(all_images) < args.num_samples:
97 | model_kwargs = {}
98 | model_kwargs["y"] = th.randint(
99 | args.min_class_sample, args.max_class_sample, (args.batch_size,)
100 | ).to(dist_util.dev())
101 |
102 | sample_fn = (
103 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
104 | )
105 | classfier_scale_vec = (
106 | th.from_numpy(
107 | np.random.uniform(
108 | low=args.classifier_scale_min,
109 | high=args.classifier_scale_max,
110 | size=(len(model_kwargs["y"]),),
111 | )
112 | )
113 | .float()
114 | .to(dist_util.dev())
115 | )
116 | sample = sample_fn(
117 | model,
118 | (args.batch_size, args.in_channels, args.image_size, args.image_size),
119 | clip_denoised=True,
120 | model_kwargs=model_kwargs,
121 | cond_fn=(
122 | None
123 | if args.classifier_scale_min == 0.0 and args.classifier_scale_max == 0.0
124 | else cond_fn
125 | ),
126 | device=dist_util.dev(),
127 | )
128 | sample = ((sample + 1) * 127.5).clamp(0, 255)
129 | sample = sample.permute(0, 2, 3, 1)
130 | sample = sample.contiguous()
131 |
132 | all_images.extend([sample.cpu().numpy() for sample in sample])
133 | all_labels.extend([labels.cpu().numpy() for labels in model_kwargs["y"]])
134 | logger.log(f"created {len(all_images)} samples")
135 |
136 | batch_num += 1
137 |
138 | arr = all_images
139 | arr = arr[: args.num_samples]
140 | label_arr = all_labels
141 | label_arr = label_arr[: args.num_samples]
142 | out_path = os.path.join(
143 | os.path.dirname(args.model_path), f"{args.wandb_experiment_name}.npz"
144 | )
145 | logger.log(f"saving to {out_path}")
146 | np.savez(out_path, arr, label_arr)
147 |
148 | logger.log("sampling complete")
149 |
150 |
151 | def create_argparser():
152 | defaults = dict(
153 | clip_denoised=True,
154 | num_samples=10000,
155 | batch_size=16,
156 | use_ddim=False,
157 | timestep_respacing="",
158 | model_path="",
159 | classifier_path="",
160 | classifier_scale_min=0.0,
161 | classifier_scale_max=0.0,
162 | wandb_experiment_name="test",
163 | model_num_classes=10,
164 | trim_logits=True,
165 | min_class_sample=0,
166 | max_class_sample=0,
167 | seed=0,
168 | )
169 | defaults.update(model_and_diffusion_defaults())
170 | parser = argparse.ArgumentParser()
171 | add_dict_to_argparser(parser, defaults)
172 | return parser
173 |
174 |
175 | if __name__ == "__main__":
176 | main()
177 |
--------------------------------------------------------------------------------
/guide/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 | import wandb
8 |
9 |
10 | def create_named_schedule_sampler(name, diffusion, args):
11 | """
12 | Create a ScheduleSampler from a library of pre-defined samplers.
13 |
14 | :param name: the name of the sampler.
15 | :param diffusion: the diffusion object to sample for.
16 | """
17 | if name == "uniform":
18 | return UniformSampler(diffusion)
19 | elif name == "beta":
20 | return BetaSampler(diffusion, args.alpha, args.beta)
21 | elif name == "loss-second-moment":
22 | return LossSecondMomentResampler(diffusion)
23 | elif name == "task_aware":
24 | return TaskAwareSampler(diffusion, args.alpha, args.beta)
25 | else:
26 | raise NotImplementedError(f"unknown schedule sampler: {name}")
27 |
28 |
29 | class ScheduleSampler(ABC):
30 | """
31 | A distribution over timesteps in the diffusion process, intended to reduce
32 | variance of the objective.
33 |
34 | By default, samplers perform unbiased importance sampling, in which the
35 | objective's mean is unchanged.
36 | However, subclasses may override sample() to change how the resampled
37 | terms are reweighted, allowing for actual changes in the objective.
38 | """
39 |
40 | @abstractmethod
41 | def weights(self):
42 | """
43 | Get a numpy array of weights, one per diffusion step.
44 |
45 | The weights needn't be normalized, but must be positive.
46 | """
47 |
48 | def sample(self, batch_size, device):
49 | """
50 | Importance-sample timesteps for a batch.
51 |
52 | :param batch_size: the number of timesteps.
53 | :param device: the torch device to save to.
54 | :return: a tuple (timesteps, weights):
55 | - timesteps: a tensor of timestep indices.
56 | - weights: a tensor of weights to scale the resulting losses.
57 | """
58 | w = self.weights()
59 | # wandb.log({f"w_{i}":w[i] for i in range(len(w))})
60 | p = w / np.sum(w)
61 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
62 | indices = th.from_numpy(indices_np).long().to(device)
63 | weights_np = 1 / (len(p) * p[indices_np])
64 | weights = th.from_numpy(weights_np).float().to(device)
65 | return indices, weights
66 |
67 |
68 | class UniformSampler(ScheduleSampler):
69 | def __init__(self, diffusion):
70 | self.diffusion = diffusion
71 | self._weights = np.ones([diffusion.num_timesteps])
72 |
73 | def weights(self):
74 | return self._weights
75 |
76 |
77 | class TaskAwareSampler:
78 | def __init__(self, diffusion, alfa=4, beta=1.2):
79 | self.diffusion = diffusion
80 | self.beta_sampler = BetaSampler(diffusion, alfa, beta, weights_smoothing=0)
81 | self.uniform_sampler = UniformSampler(diffusion)
82 |
83 | def sample(self, batch_size, device, task_ids, current_task_id):
84 | curr_task_indices, curr_task_weights = self.uniform_sampler.sample(
85 | (task_ids == current_task_id).sum().item(), device
86 | )
87 | prev_task_indices, prev_task_weights = self.beta_sampler.sample(
88 | (task_ids != current_task_id).sum().item(), device
89 | )
90 | indices = th.zeros(batch_size, device=device).long()
91 | weights = th.zeros(batch_size, device=device).float()
92 |
93 | indices[task_ids == current_task_id] = curr_task_indices
94 | indices[task_ids != current_task_id] = prev_task_indices
95 |
96 | weights[task_ids == current_task_id] = curr_task_weights
97 | weights[task_ids != current_task_id] = prev_task_weights
98 | return indices, weights
99 |
100 |
101 | class BetaSampler(ScheduleSampler):
102 | def __init__(self, diffusion, alfa=4, beta=1.2, weights_smoothing=1):
103 | beta_dist = th.distributions.beta.Beta(alfa, beta)
104 | w = th.exp(
105 | beta_dist.log_prob(
106 | (th.arange(0, diffusion.num_timesteps) / diffusion.num_timesteps)
107 | )
108 | )
109 | self.diffusion = diffusion
110 | self._weights = (
111 | w.numpy() + weights_smoothing
112 | ) # np.ones([diffusion.num_timesteps])
113 |
114 | def weights(self):
115 | return self._weights
116 |
117 |
118 | class LossAwareSampler(ScheduleSampler):
119 | def update_with_local_losses(self, local_ts, local_losses):
120 | """
121 | Update the reweighting using losses from a model.
122 |
123 | Call this method from each rank with a batch of timesteps and the
124 | corresponding losses for each of those timesteps.
125 | This method will perform synchronization to make sure all of the ranks
126 | maintain the exact same reweighting.
127 |
128 | :param local_ts: an integer Tensor of timesteps.
129 | :param local_losses: a 1D Tensor of losses.
130 | """
131 | batch_sizes = [
132 | th.tensor([0], dtype=th.int32, device=local_ts.device)
133 | for _ in range(dist.get_world_size())
134 | ]
135 | dist.all_gather(
136 | batch_sizes,
137 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
138 | )
139 |
140 | # Pad all_gather batches to be the maximum batch size.
141 | batch_sizes = [x.item() for x in batch_sizes]
142 | max_bs = max(batch_sizes)
143 |
144 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
145 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
146 | dist.all_gather(timestep_batches, local_ts)
147 | dist.all_gather(loss_batches, local_losses)
148 | timesteps = [
149 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
150 | ]
151 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
152 | self.update_with_all_losses(timesteps, losses)
153 |
154 | @abstractmethod
155 | def update_with_all_losses(self, ts, losses):
156 | """
157 | Update the reweighting using losses from a model.
158 |
159 | Sub-classes should override this method to update the reweighting
160 | using losses from the model.
161 |
162 | This method directly updates the reweighting without synchronizing
163 | between workers. It is called by update_with_local_losses from all
164 | ranks with identical arguments. Thus, it should have deterministic
165 | behavior to maintain state across workers.
166 |
167 | :param ts: a list of int timesteps.
168 | :param losses: a list of float losses, one per timestep.
169 | """
170 |
171 |
172 | class LossSecondMomentResampler(LossAwareSampler):
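    """
    Loss-aware sampler from "Improved Denoising Diffusion Probabilistic Models":
    timesteps are drawn with probability proportional to sqrt(E[L_t^2]), where
    the expectation is estimated from the last `history_per_term` losses recorded
    for each timestep and mixed with a small uniform component (`uniform_prob`)
    so that every timestep keeps nonzero probability. Until every timestep has a
    full loss history, sampling falls back to uniform.
    """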
173 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
174 | self.diffusion = diffusion
175 | self.history_per_term = history_per_term
176 | self.uniform_prob = uniform_prob
177 | self._loss_history = np.zeros(
178 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
179 | )
180 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
181 |
182 | def weights(self):
183 | if not self._warmed_up():
184 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
185 | weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
186 | weights /= np.sum(weights)
187 | weights *= 1 - self.uniform_prob
188 | weights += self.uniform_prob / len(weights)
189 | return weights
190 |
191 | def update_with_all_losses(self, ts, losses):
192 | for t, loss in zip(ts, losses):
193 | if self._loss_counts[t] == self.history_per_term:
194 | # Shift out the oldest loss term.
195 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
196 | self._loss_history[t, -1] = loss
197 | else:
198 | self._loss_history[t, self._loss_counts[t]] = loss
199 | self._loss_counts[t] += 1
200 |
201 | def _warmed_up(self):
202 | return (self._loss_counts == self.history_per_term).all()
203 |
--------------------------------------------------------------------------------
/guide/script_args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 |
5 | def diffusion_defaults():
6 | """
7 | Defaults for image and classifier training.
8 | """
9 | return dict(
10 | learn_sigma=False,
11 | sigma_small=False,
12 | diffusion_steps=1000,
13 | noise_schedule="linear",
14 | timestep_respacing="",
15 | use_kl=False,
16 | predict_xstart=False,
17 | rescale_timesteps=False,
18 | rescale_learned_sigmas=False,
19 | )
20 |
21 |
22 | def classifier_defaults():
23 | """
24 | Defaults for classifier models.
25 | """
26 | return dict(
27 | image_size=64,
28 | classifier_use_fp16=False,
29 | classifier_width=128,
30 | classifier_depth=2,
31 | classifier_attention_resolutions="32,16,8", # 16
32 | classifier_use_scale_shift_norm=True, # False
33 | classifier_resblock_updown=True, # False
34 | classifier_pool="attention",
35 | # ResNet disjoint classifier parameters
36 | activetype="ReLU", # activation type
37 | pooltype="MaxPool2d", # Pooling type
38 | normtype="BatchNorm", # Batch norm type
39 | preact=True, # Places norms and activations before linear/conv layer.
40 | bn=True, # Apply Batchnorm.
41 | affine_bn=True, # Apply affine transform in BN.
42 |         bn_eps=1e-6,  # Epsilon added in batch norm for numerical stability
43 | )
44 |
45 |
46 | def model_and_diffusion_defaults():
47 | """
48 | Defaults for image training.
49 | """
50 | res = dict(
51 | num_channels=128,
52 | num_res_blocks=3,
53 | num_heads=4,
54 | num_heads_upsample=-1,
55 | num_head_channels=-1,
56 | attention_resolutions="16,8",
57 | channel_mult="",
58 | dropout=0.1,
59 | use_checkpoint=False,
60 | use_scale_shift_norm=True,
61 | resblock_updown=False,
62 | use_fp16=False,
63 | use_new_attention_order=False,
64 | image_size=32,
65 | in_channels=3,
66 | model_switching_timestep=30,
67 | model_name="UNetModel",
68 | embedding_kind="concat_time_1hot", # embedding used for time and "class", possible values in EMBEDDING_KINDS
69 | model_num_classes=None,
70 | train_noised_classifier=False,
71 | )
72 | res.update(diffusion_defaults())
73 | return res
74 |
75 |
76 | def all_training_defaults():
77 | defaults = dict(
78 | seed=13,
79 | data_seed=0, # Seed used for data generation (mostly train/valid split). Typically, no need to set it.
80 | wandb_api_key="",
81 | wandb_experiment_name="test",
82 | wandb_project_name="project",
83 | wandb_entity="entity",
84 | dataroot="data/",
85 | dataset="CIFAR10",
86 | schedule_sampler="uniform",
87 | alpha=4,
88 | beta=1.2,
89 | lr=2e-4,
90 | disjoint_classifier_lr=1e-2,
91 | weight_decay=0.0,
92 | lr_anneal_steps=0,
93 | batch_size=64,
94 | microbatch=-1, # -1 disables microbatches
95 | ema_rate="0.9999", # comma-separated list of EMA values
96 | log_interval=500,
97 | skip_save=False,
98 | save_interval=5000,
99 | guid_generation_interval=1, # generate new examples from diffusion model every guid_generation_interval steps
100 | resume_checkpoint="",
101 | resume_checkpoint_classifier="",
102 | resume_checkpoint_classifier_noised="",
103 | use_fp16=False,
104 | fp16_scale_growth=1e-3,
105 | gpu_id=-1,
106 | reverse=False,
107 | num_tasks=5,
108 | limit_tasks=-1,
109 | limit_classes=-1,
110 | shared_classes=False,
111 | train_aug=False,
112 | skip_normalization=False,
113 | num_steps=20000,
114 | scheduler_rate=1.0,
115 | first_task_num_classes=0,
116 | first_task_num_steps=-1, # if -1, set to the same as num_steps.
117 | skip_gradient_thr=-1,
118 | log_gradient_stats=False,
119 | clip_denoised=True,
120 | cl_method="generative_replay_disjoint_classifier_guidance", # possible values are defined in CL_METHODS
121 | use_ddim=False,
122 | classifier_scale_min_old=None,
123 | classifier_scale_min_new=None,
124 | classifier_scale_max_old=None,
125 | classifier_scale_max_new=None,
126 | first_task=0,
127 | gr_n_generated_examples_per_task=32,
128 | use_old_grad=False,
129 | use_new_grad=False,
130 | guid_to_new_classes=False,
131 | trim_logits=True,
132 | train_with_disjoint_classifier=False,
133 | disjoint_classifier_init_num_steps=5000, # Steps in first task
134 | disjoint_classifier_num_steps=2000,
135 | classifier_init_lr=0.1, # First task learning rate
136 | classifier_lr=0.05, # Learning rate
137 | classifier_weight_decay=5e-4, # Weight decay
138 | depth=18, # ResNet depth
139 | classifier_augmentation=True,
140 |         diffusion_pretrained_dir=None,  # Directory containing diffusion models trained on each task. Setting it effectively disables any training of the diffusion model.
141 | negate_old_grad=False, # negate old gradient
142 | classifier_first_task_dir=None,
143 | )
144 | defaults.update(model_and_diffusion_defaults())
145 | return defaults
146 |
147 |
148 | def classifier_and_diffusion_defaults():
149 | res = classifier_defaults()
150 | res.update(diffusion_defaults())
151 | return res
152 |
153 |
154 | def combine_with_defaults(config):
155 | res = all_training_defaults()
156 | for k, v in config.items():
157 | assert k in res.keys(), "{} not in default values".format(k)
158 | res[k] = v
159 | return res
160 |
161 |
162 | CL_METHODS = [
163 | "generative_replay",
164 | "generative_replay_disjoint_classifier_guidance", # GR with one disjoint classifier
165 | ]
166 |
167 | EMBEDDING_KINDS = [
168 | "none",
169 | "concat_time_1hot",
170 | "add_time_learned", # original from the paper
171 | ]
172 |
173 |
174 | def preprocess_args(args):
175 |     """Perform simple validity checks and initial processing of the training args."""
176 |
177 | assert args.cl_method in CL_METHODS
178 | assert args.embedding_kind in EMBEDDING_KINDS
179 |
180 | if args.first_task_num_steps == -1:
181 | args.first_task_num_steps = args.num_steps
182 |
183 | if not args.dataroot:
184 | args.dataroot = os.environ.get("DIFFUSION_DATA", "")
185 |
186 | if (
187 | args.classifier_scale_min_old is not None
188 | and args.classifier_scale_max_old is None
189 | ):
190 | args.classifier_scale_max_old = args.classifier_scale_min_old
191 |
198 | if (
199 | args.classifier_scale_max_old is not None
200 | and args.classifier_scale_min_old is None
201 | ):
202 | args.classifier_scale_min_old = args.classifier_scale_max_old
203 |
204 | if (
205 | args.classifier_scale_min_new is not None
206 | and args.classifier_scale_max_new is None
207 | ):
208 | args.classifier_scale_max_new = args.classifier_scale_min_new
209 |
216 | if (
217 | args.classifier_scale_max_new is not None
218 | and args.classifier_scale_min_new is None
219 | ):
220 | args.classifier_scale_min_new = args.classifier_scale_max_new
221 |
222 |
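# Editor's illustrative sketch (not used anywhere in the codebase): building a
# full argument namespace from a partial config. Note how `preprocess_args`
# mirrors a classifier-scale bound to its missing counterpart.
def _example_build_args():
    config = combine_with_defaults(
        {"dataset": "CIFAR10", "classifier_scale_min_new": 0.5}
    )
    args = argparse.Namespace(**config)
    preprocess_args(args)
    assert args.classifier_scale_max_new == 0.5
    return args
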
223 | def add_dict_to_argparser(parser, default_dict):
224 | for k, v in default_dict.items():
225 | v_type = type(v)
226 | if v is None:
227 | v_type = str
228 | elif isinstance(v, bool):
229 | v_type = str2bool
230 | parser.add_argument(f"--{k}", default=v, type=v_type)
231 |
232 |
233 | def args_to_dict(args, keys):
234 | return {k: getattr(args, k) for k in keys}
235 |
236 |
237 | def str2bool(v):
238 | """
239 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
240 | """
241 | if isinstance(v, bool):
242 | return v
243 | if v.lower() in ("yes", "true", "t", "y", "1"):
244 | return True
245 | elif v.lower() in ("no", "false", "f", "n", "0"):
246 | return False
247 | else:
248 | raise argparse.ArgumentTypeError("boolean value expected")
249 |
--------------------------------------------------------------------------------
/guide/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .layers import ConvBlock, FinalBlock, InitialBlock
5 |
6 |
7 | class BasicBlock(nn.Module):
8 | expansion = 1
9 |
10 | def __init__(self, opt, inChannels, outChannels, stride=1, downsample=None):
11 | super(BasicBlock, self).__init__()
12 | self.downsample = downsample
13 | expansion = 1
14 | self.conv1 = ConvBlock(
15 | opt=opt,
16 | in_channels=inChannels,
17 | out_channels=outChannels,
18 | kernel_size=3,
19 | stride=stride,
20 | padding=1,
21 | bias=False,
22 | )
23 | self.conv2 = ConvBlock(
24 | opt=opt,
25 | in_channels=outChannels,
26 | out_channels=outChannels * expansion,
27 | kernel_size=3,
28 | stride=1,
29 | padding=1,
30 | bias=False,
31 | )
32 |
33 | def forward(self, x):
34 | _out = self.conv1(x)
35 | _out = self.conv2(_out)
36 | if self.downsample is not None:
37 | shortcut = self.downsample(x)
38 | else:
39 | shortcut = x
40 | _out = _out + shortcut
41 | return _out
42 |
43 |
44 | class BottleneckBlock(nn.Module):
45 | expansion = 4
46 |
47 | def __init__(self, opt, inChannels, outChannels, stride=1, downsample=None):
48 | super(BottleneckBlock, self).__init__()
49 | expansion = 4
50 | self.conv1 = ConvBlock(
51 | opt=opt,
52 | in_channels=inChannels,
53 | out_channels=outChannels,
54 | kernel_size=1,
55 | stride=1,
56 | padding=0,
57 | bias=False,
58 | )
59 | self.conv2 = ConvBlock(
60 | opt=opt,
61 | in_channels=outChannels,
62 | out_channels=outChannels,
63 | kernel_size=3,
64 | stride=stride,
65 | padding=1,
66 | bias=False,
67 | )
68 | self.conv3 = ConvBlock(
69 | opt=opt,
70 | in_channels=outChannels,
71 | out_channels=outChannels * expansion,
72 | kernel_size=1,
73 | stride=1,
74 | padding=0,
75 | bias=False,
76 | )
77 | self.downsample = downsample
78 |
79 | def forward(self, x):
80 | _out = self.conv1(x)
81 | _out = self.conv2(_out)
82 | _out = self.conv3(_out)
83 | if self.downsample is not None:
84 | shortcut = self.downsample(x)
85 | else:
86 | shortcut = x
87 | _out = _out + shortcut
88 | return _out
89 |
90 |
91 | class ResidualBlock(nn.Module):
92 | def __init__(self, opt, block, inChannels, outChannels, depth, stride=1):
93 | super(ResidualBlock, self).__init__()
94 | if stride != 1 or inChannels != outChannels * block.expansion:
95 | downsample = ConvBlock(
96 | opt=opt,
97 | in_channels=inChannels,
98 | out_channels=outChannels * block.expansion,
99 | kernel_size=1,
100 | stride=stride,
101 | padding=0,
102 | bias=False,
103 | )
104 | else:
105 | downsample = None
106 | self.blocks = nn.Sequential()
107 | self.blocks.add_module(
108 | "block0", block(opt, inChannels, outChannels, stride, downsample)
109 | )
110 | inChannels = outChannels * block.expansion
111 | for i in range(1, depth):
112 | self.blocks.add_module(
113 | "block{}".format(i), block(opt, inChannels, outChannels)
114 | )
115 |
116 | def forward(self, x):
117 | return self.blocks(x)
118 |
119 |
120 | class ResNet(nn.Module):
121 | def __init__(self, opt):
122 | super(ResNet, self).__init__()
123 | depth = opt.depth
124 | if depth in [20, 32, 44, 56, 110, 1202]:
125 | blocktype, self.nettype = "BasicBlock", "cifar"
126 | elif depth in [164, 1001]:
127 | blocktype, self.nettype = "BottleneckBlock", "cifar"
128 | elif depth in [18, 34]:
129 | blocktype, self.nettype = "BasicBlock", "imagenet"
130 | elif depth in [50, 101, 152]:
131 | blocktype, self.nettype = "BottleneckBlock", "imagenet"
132 | assert depth in [20, 32, 44, 56, 110, 1202, 164, 1001, 18, 34, 50, 101, 152]
133 |
134 | if blocktype == "BasicBlock" and self.nettype == "cifar":
135 | assert (
136 | depth - 2
137 | ) % 6 == 0, (
138 | "Depth should be 6n+2, and preferably one of 20, 32, 44, 56, 110, 1202"
139 | )
140 | n = (depth - 2) // 6
141 | block = BasicBlock
142 | in_planes, out_planes = 16, 64
143 | elif blocktype == "BottleneckBlock" and self.nettype == "cifar":
144 | assert (
145 | depth - 2
146 | ) % 9 == 0, "Depth should be 9n+2, and preferably one of 164 or 1001"
147 | n = (depth - 2) // 9
148 | block = BottleneckBlock
149 | in_planes, out_planes = 16, 64
150 | elif blocktype == "BasicBlock" and self.nettype == "imagenet":
151 | assert depth in [18, 34]
152 | num_blocks = [2, 2, 2, 2] if depth == 18 else [3, 4, 6, 3]
153 | block = BasicBlock
154 | in_planes, out_planes = 64, 512 # 20, 160
155 | elif blocktype == "BottleneckBlock" and self.nettype == "imagenet":
156 | assert depth in [50, 101, 152]
157 | if depth == 50:
158 | num_blocks = [3, 4, 6, 3]
159 | elif depth == 101:
160 | num_blocks = [3, 4, 23, 3]
161 | elif depth == 152:
162 | num_blocks = [3, 8, 36, 3]
163 | block = BottleneckBlock
164 | in_planes, out_planes = 64, 512
165 | else:
166 |             raise ValueError(f"Unsupported depth: {depth}")
167 |
168 | self.num_classes = opt.model_num_classes
169 | self.initial = InitialBlock(
170 | opt=opt, out_channels=in_planes, kernel_size=3, stride=1, padding=1
171 | )
172 | if self.nettype == "cifar":
173 | self.group1 = ResidualBlock(opt, block, 16, 16, n, stride=1)
174 | self.group2 = ResidualBlock(
175 | opt, block, 16 * block.expansion, 32, n, stride=2
176 | )
177 | self.group3 = ResidualBlock(
178 | opt, block, 32 * block.expansion, 64, n, stride=2
179 | )
180 | elif self.nettype == "imagenet":
181 | self.group1 = ResidualBlock(
182 | opt, block, 64, 64, num_blocks[0], stride=1
183 | ) # For ResNet-S, convert this to 20,20
184 | self.group2 = ResidualBlock(
185 | opt, block, 64 * block.expansion, 128, num_blocks[1], stride=2
186 | ) # For ResNet-S, convert this to 20,40
187 | self.group3 = ResidualBlock(
188 | opt, block, 128 * block.expansion, 256, num_blocks[2], stride=2
189 | ) # For ResNet-S, convert this to 40,80
190 | self.group4 = ResidualBlock(
191 | opt, block, 256 * block.expansion, 512, num_blocks[3], stride=2
192 | ) # For ResNet-S, convert this to 80,160
193 | else:
194 |             raise ValueError(f"Unsupported network type: {self.nettype}")
195 | self.pool = nn.AdaptiveAvgPool2d(1)
196 | self.dim_out = out_planes * block.expansion
197 | self.noised = opt.noised
198 |
199 | in_size = (
200 | out_planes * block.expansion + 1
201 | if opt.noised
202 | else out_planes * block.expansion
203 | )
204 | self.final = FinalBlock(opt=opt, in_channels=in_size)
205 |
206 | for m in self.modules():
207 | if isinstance(m, nn.Conv2d):
208 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
209 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
210 | nn.init.constant_(m.weight, 1)
211 | nn.init.constant_(m.bias, 0)
212 |
213 | def forward(self, x, t=None, return_h=False):
214 | out = self.initial(x)
215 | out = self.group1(out)
216 | out = self.group2(out)
217 | out = self.group3(out)
218 | if self.nettype == "imagenet":
219 | out = self.group4(out)
220 | out = self.pool(out)
221 | out = out.view(x.size(0), -1)
222 |
223 | if return_h:
224 | return out
225 |
226 | if self.noised:
227 | if t is None:
228 | t = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
229 | out = torch.cat([out, t.unsqueeze(1)], 1)
230 |
231 | out = self.final(out)
232 |
233 | return out
234 |
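# Editor's illustrative sketch (not used anywhere in the codebase): constructing
# the disjoint ResNet-18 classifier the same way scripts/image_train.py does, by
# filling a namespace with classifier_defaults() plus the ResNet-specific fields.
# It assumes, as that script does, that these fields are all the layers require.
def _example_build_resnet18(num_classes=10):
    from argparse import Namespace

    from .script_args import classifier_defaults

    opt = Namespace(**classifier_defaults())
    opt.depth = 18  # ImageNet-style BasicBlock ResNet
    opt.in_channels = 3
    opt.model_num_classes = num_classes
    opt.noised = False  # do not append the diffusion timestep to the features
    return ResNet(opt)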
--------------------------------------------------------------------------------
/guide/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import numpy as np
6 | import torch as th
7 | import torch.nn as nn
8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9 |
10 | from . import logger
11 |
12 | INITIAL_LOG_LOSS_SCALE = 20.0
13 |
14 |
15 | def convert_module_to_f16(l):
16 | """
17 | Convert primitive modules to float16.
18 | """
19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
20 | l.weight.data = l.weight.data.half()
21 | if l.bias is not None:
22 | l.bias.data = l.bias.data.half()
23 |
24 |
25 | def convert_module_to_f32(l):
26 | """
27 | Convert primitive modules to float32, undoing convert_module_to_f16().
28 | """
29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30 | l.weight.data = l.weight.data.float()
31 | if l.bias is not None:
32 | l.bias.data = l.bias.data.float()
33 |
34 |
35 | def make_master_params(param_groups_and_shapes):
36 | """
37 | Copy model parameters into a (differently-shaped) list of full-precision
38 | parameters.
39 | """
40 | master_params = []
41 | for param_group, shape in param_groups_and_shapes:
42 | master_param = nn.Parameter(
43 | _flatten_dense_tensors(
44 | [param.detach().float() for (_, param) in param_group]
45 | ).view(shape)
46 | )
47 | master_param.requires_grad = True
48 | master_params.append(master_param)
49 | return master_params
50 |
51 |
52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params):
53 | """
54 | Copy the gradients from the model parameters into the master parameters
55 | from make_master_params().
56 | """
57 | for master_param, (param_group, shape) in zip(
58 | master_params, param_groups_and_shapes
59 | ):
60 | master_param.grad = _flatten_dense_tensors(
61 | [param_grad_or_zeros(param) for (_, param) in param_group]
62 | ).view(shape)
63 |
64 |
65 | def master_params_to_model_params(param_groups_and_shapes, master_params):
66 | """
67 | Copy the master parameter data back into the model parameters.
68 | """
69 | # Without copying to a list, if a generator is passed, this will
70 | # silently not copy any parameters.
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
72 | for (_, param), unflat_master_param in zip(
73 | param_group, unflatten_master_params(param_group, master_param.view(-1))
74 | ):
75 | param.detach().copy_(unflat_master_param)
76 |
77 |
78 | def unflatten_master_params(param_group, master_param):
79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
80 |
81 |
82 | def get_param_groups_and_shapes(named_model_params):
83 | named_model_params = list(named_model_params)
84 | scalar_vector_named_params = (
85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
86 | (-1),
87 | )
88 | matrix_named_params = (
89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1],
90 | (1, -1),
91 | )
92 | return [scalar_vector_named_params, matrix_named_params]
93 |
94 |
95 | def master_params_to_state_dict(
96 | model, param_groups_and_shapes, master_params, use_fp16
97 | ):
98 | if use_fp16:
99 | state_dict = model.state_dict()
100 | for master_param, (param_group, _) in zip(
101 | master_params, param_groups_and_shapes
102 | ):
103 | for (name, _), unflat_master_param in zip(
104 | param_group, unflatten_master_params(param_group, master_param.view(-1))
105 | ):
106 | assert name in state_dict
107 | state_dict[name] = unflat_master_param
108 | else:
109 | state_dict = model.state_dict()
110 | for i, (name, _value) in enumerate(model.named_parameters()):
111 | assert name in state_dict
112 | state_dict[name] = master_params[i]
113 | return state_dict
114 |
115 |
116 | def state_dict_to_master_params(model, state_dict, use_fp16):
117 | if use_fp16:
118 | named_model_params = [
119 | (name, state_dict[name]) for name, _ in model.named_parameters()
120 | ]
121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
122 | master_params = make_master_params(param_groups_and_shapes)
123 | else:
124 | master_params = [state_dict[name] for name, _ in model.named_parameters()]
125 | return master_params
126 |
127 |
128 | def zero_master_grads(master_params):
129 | for param in master_params:
130 | param.grad = None
131 |
132 |
133 | def zero_grad(model_params):
134 | for param in model_params:
135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
136 | if param.grad is not None:
137 | param.grad.detach_()
138 | param.grad.zero_()
139 |
140 |
141 | def param_grad_or_zeros(param):
142 | if param.grad is not None:
143 | return param.grad.data.detach()
144 | else:
145 | return th.zeros_like(param)
146 |
147 |
148 | class MixedPrecisionTrainer:
149 | def __init__(
150 | self,
151 | *,
152 | model,
153 | use_fp16=False,
154 | fp16_scale_growth=1e-3,
155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156 | skip_gradient_thr=-1,
157 | ):
158 | self.model = model
159 | self.use_fp16 = use_fp16
160 | self.fp16_scale_growth = fp16_scale_growth
161 | self.skip_gradient_thr = skip_gradient_thr
162 |
163 | self.model_params = list(self.model.parameters())
164 | self.master_params = self.model_params
165 | self.param_groups_and_shapes = None
166 | self.lg_loss_scale = initial_lg_loss_scale
167 |
168 | if self.use_fp16:
169 | self.param_groups_and_shapes = get_param_groups_and_shapes(
170 | self.model.named_parameters()
171 | )
172 | self.master_params = make_master_params(self.param_groups_and_shapes)
173 | self.model.convert_to_fp16()
174 |
175 | def zero_grad(self):
176 | zero_grad(self.model_params)
177 |
178 | def backward(self, loss: th.Tensor):
179 | if self.use_fp16:
180 | loss_scale = 2**self.lg_loss_scale
181 | (loss * loss_scale).backward()
182 | else:
183 | loss.backward()
184 |
185 | def optimize(self, opt: th.optim.Optimizer):
186 | if self.use_fp16:
187 | return self._optimize_fp16(opt)
188 | else:
189 | return self._optimize_normal(opt)
190 |
191 | def _optimize_fp16(self, opt: th.optim.Optimizer):
192 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
193 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
194 | grad_norm, param_norm = self._compute_norms(grad_scale=2**self.lg_loss_scale)
195 | if check_overflow(grad_norm):
196 | self.lg_loss_scale -= 1
197 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
198 | zero_master_grads(self.master_params)
199 | return False
200 |
201 | logger.logkv_mean("grad_norm", grad_norm)
202 | logger.logkv_mean("param_norm", param_norm)
203 |
204 |         for mp in self.master_params:  # un-scale the gradients of every master param group
205 |             mp.grad.mul_(1.0 / (2**self.lg_loss_scale))
206 | if self.skip_gradient_thr == -1.0 or grad_norm < self.skip_gradient_thr:
207 | logger.logkv_mean("skip_update", 0)
208 | opt.step()
209 | step_made = True
210 | else:
211 | logger.logkv_mean("skip_update", 1)
212 | step_made = False
213 | zero_master_grads(self.master_params)
214 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
215 | self.lg_loss_scale += self.fp16_scale_growth
216 | return step_made
217 |
218 | def _optimize_normal(self, opt: th.optim.Optimizer):
219 | grad_norm, param_norm = self._compute_norms()
220 | logger.logkv_mean("grad_norm", grad_norm)
221 | logger.logkv_mean("param_norm", param_norm)
222 |         # Skip the update if the gradient norm exceeds the configured threshold.
223 | if self.skip_gradient_thr == -1.0 or grad_norm < self.skip_gradient_thr:
224 | logger.logkv_mean("skip_update", 0)
225 | opt.step()
226 | return True
227 | else:
228 | logger.logkv_mean("skip_update", 1)
229 | return False
230 |
231 | def _compute_norms(self, grad_scale=1.0):
232 | grad_norm = 0.0
233 | param_norm = 0.0
234 | for p in self.master_params:
235 | with th.no_grad():
236 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
237 | if p.grad is not None:
238 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
239 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
240 |
241 | def master_params_to_state_dict(self, master_params):
242 | return master_params_to_state_dict(
243 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16
244 | )
245 |
246 | def state_dict_to_master_params(self, state_dict):
247 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
248 |
249 |
250 | def check_overflow(value):
251 | return (value == float("inf")) or (value == -float("inf")) or (value != value)
252 |
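# Editor's illustrative sketch (not used anywhere in the codebase): a single
# optimization step with MixedPrecisionTrainer. `compute_loss` stands in for the
# diffusion (or classifier) loss; the trainer is constructed once, e.g.
# MixedPrecisionTrainer(model=model, use_fp16=True), and reused across steps.
def _example_training_step(trainer, opt, batch, compute_loss):
    trainer.zero_grad()
    loss = compute_loss(trainer.model, batch)
    trainer.backward(loss)  # scales the loss by 2**lg_loss_scale when fp16 is enabled
    return trainer.optimize(opt)  # un-scales grads, handles overflow, steps the optimizer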
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GUIDE: Guidance-based Incremental Learning with Diffusion Models
2 | [](https://arxiv.org/abs/2403.03938)
3 |
4 | This repository is the official implementation of [GUIDE: Guidance-based Incremental Learning with Diffusion Models](https://arxiv.org/abs/2403.03938).
5 |
6 |
7 |
8 | **Rehearsal sampling in GUIDE.**
9 | We guide the denoising process of a diffusion model trained on the previous task (blue) toward classes from the current task (orange).
10 | The replay samples, highlighted with **blue borders**, share features with the examples from the current task, which may be related to characteristics such as color or background (e.g., fishes on a snowy background when guided to *snowmobile*).
11 | Continual training of a classifier on such samples positioned near its decision boundary successfully mitigates catastrophic forgetting.
12 |
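For orientation, the snippet below sketches the standard classifier-guidance update that this rehearsal sampling builds on: at each denoising step, the mean predicted by the diffusion model trained on previous tasks is shifted along the gradient of the classifier's log-probability of a class from the current task. This is a simplified illustration with hypothetical names, not the repository's exact code; the precise formulation and scaling used in GUIDE are described in the paper.

```python
import torch

def guided_mean(mean, variance, x_t, t, y_new, classifier, scale):
    """Shift the reverse-process mean toward samples the classifier assigns to y_new."""
    with torch.enable_grad():
        x_in = x_t.detach().requires_grad_(True)
        log_probs = torch.log_softmax(classifier(x_in, t), dim=-1)
        selected = log_probs[range(len(y_new)), y_new].sum()
        grad = torch.autograd.grad(selected, x_in)[0]
    return mean + scale * variance * grad
```
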
13 | ## Setup
14 |
15 | ### Clone repo
16 |
17 | ```bash
18 | git clone https://github.com/cywinski/guide.git
19 | cd guide
20 | ```
21 |
22 | ### Prepare Conda environment
23 |
24 | ```bash
25 | conda create -n guide_env python=3.8
26 | conda activate guide_env
27 | ```
28 |
29 | ### Install torch
30 |
31 | Install `torch` and `torchvision` according to the instructions on the [official website](https://pytorch.org/).
32 |
33 | ### Install required packages
34 |
35 | ```bash
36 | pip install .
37 | ```
38 |
39 | ### Login to wandb
40 |
41 | ```bash
42 | wandb login
43 | ```
44 |
45 | ## Reproduction
46 |
47 | Below we present training commands for a single-GPU setup. To run the training in a distributed manner, run the same command with `mpiexec`:
48 |
49 | ```bash
50 | mpiexec -n $NUM_GPUS python -m scripts.image_train ...
51 | ```
52 |
53 | When training in a distributed manner, manually divide the `--batch_size` argument by the number of ranks so that the effective batch size stays the same (see the example below). If distributed training is not an option, you can instead use `--microbatch 16` (or `--microbatch 1` in extremely memory-constrained cases) to reduce memory usage.
54 |
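For example (illustrative values), a single-GPU run that uses `--batch_size=256` becomes the following on 4 GPUs, keeping the same effective batch size:

```bash
mpiexec -n 4 python -m scripts.image_train --batch_size=64 --dataset=CIFAR10 --num_tasks=5 ...
```
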
55 | ### Diffusion models training
56 |
57 | Continual training of diffusion models with self-rehearsal. Trained models will be stored under `results/`.
58 |
59 | **CIFAR-10/2**
60 |
61 | ```bash
62 | python -m scripts.image_train --wandb_experiment_name=c10_ci2_class_cond_diffusion --wandb_project_name=project --wandb_entity=entity --batch_size=256 --num_steps=100000 --dataset=CIFAR10 --num_tasks=2 --save_interval=100000 --gr_n_generated_examples_per_task=25000 --first_task_num_steps=100000 --seed=0 --timestep_respacing=1000 --use_ddim=False --log_interval=1000 --cl_method=generative_replay --train_with_disjoint_classifier=False --embedding_kind=concat_time_1hot --train_aug=True
63 | ```
64 |
65 | **CIFAR-10/5**
66 |
67 | ```bash
68 | python -m scripts.image_train --wandb_experiment_name=c10_ci5_class_cond_diffusion --wandb_project_name=project --wandb_entity=entity --batch_size=256 --num_steps=50000 --dataset=CIFAR10 --num_tasks=5 --save_interval=50000 --gr_n_generated_examples_per_task=10000 --first_task_num_steps=100000 --seed=0 --timestep_respacing=1000 --use_ddim=False --log_interval=1000 --cl_method=generative_replay --train_with_disjoint_classifier=False --embedding_kind=concat_time_1hot --train_aug=True
69 | ```
70 |
71 | **CIFAR-100/5**
72 |
73 | ```bash
74 | python -m scripts.image_train --wandb_experiment_name=c100_ci5_class_cond_diffusion --wandb_project_name=project --wandb_entity=entity --batch_size=256 --num_steps=50000 --dataset=CIFAR100 --num_tasks=5 --save_interval=50000 --gr_n_generated_examples_per_task=10000 --first_task_num_steps=100000 --seed=0 --timestep_respacing=1000 --use_ddim=False --log_interval=1000 --cl_method=generative_replay --train_with_disjoint_classifier=False --embedding_kind=concat_time_1hot --train_aug=True
75 | ```
76 |
77 | **CIFAR-100/10**
78 |
79 | ```bash
80 | python -m scripts.image_train --wandb_experiment_name=c100_ci10_class_cond_diffusion --wandb_project_name=project --wandb_entity=entity --batch_size=256 --num_steps=100000 --dataset=CIFAR100 --num_tasks=10 --save_interval=100000 --gr_n_generated_examples_per_task=5000 --first_task_num_steps=100000 --seed=0 --timestep_respacing=1000 --use_ddim=False --log_interval=1000 --cl_method=generative_replay --train_with_disjoint_classifier=False --embedding_kind=concat_time_1hot --train_aug=True
81 | ```
82 |
83 | **ImageNet100-64/5**
84 |
85 | ```bash
86 | python -m scripts.image_train --wandb_experiment_name=i100_ci5_class_cond_diffusion --wandb_project_name=project --wandb_entity=entity --batch_size=100 --num_steps=50000 --dataset=ImageNet100 --num_tasks=5 --save_interval=50000 --gr_n_generated_examples_per_task=26000 --first_task_num_steps=100000 --seed=0 --timestep_respacing=ddim250 --use_ddim=True --log_interval=1000 --cl_method=generative_replay --train_with_disjoint_classifier=False --embedding_kind=concat_time_1hot --train_aug=True --attention_resolutions 32,16,8 --lr 1e-4 --resblock_updown True --use_new_attention_order True --use_scale_shift_norm True --num_channels 192 --num_head_channels 64
87 | ```
88 |
89 | ### Classifier trainings
90 |
91 | Continual classifier training with generative replay, following the GUIDE method. To run the classifier trainings, you first need to train the diffusion models (following the instructions above) and store their `ema` checkpoints in the directory passed as `--diffusion_pretrained_dir`.
92 |
93 | **CIFAR-10/2**
94 |
95 | ```bash
96 | python -m scripts.image_train --wandb_experiment_name=c10_ci2_guide --wandb_project_name=project --wandb_entity=entity --batch_size=256 --dataset=CIFAR10 --num_tasks=2 --seed=0 --timestep_respacing=ddim50 --use_ddim=True --classifier_scale_min_new=0.2 --classifier_scale_max_new=0.2 --cl_method=generative_replay_disjoint_classifier_guidance --train_with_disjoint_classifier=True --use_old_grad=False --use_new_grad=True --guid_to_new_classes=True --embedding_kind=concat_time_1hot --classifier_init_lr=0.1 --classifier_lr=0.01 --disjoint_classifier_init_num_steps=5000 --disjoint_classifier_num_steps=2000 --classifier_augmentation=True --log_interval=200 --diffusion_pretrained_dir=results/c10_ci2_class_cond_diffusion
97 | ```
98 |
99 | **CIFAR-10/5**
100 |
101 | ```bash
102 | python -m scripts.image_train --wandb_experiment_name=c10_ci5_guide --wandb_project_name=project --wandb_entity=entity --batch_size=256 --dataset=CIFAR10 --num_tasks=5 --seed=0 --timestep_respacing=ddim50 --use_ddim=True --classifier_scale_min_new=0.5 --classifier_scale_max_new=0.5 --cl_method=generative_replay_disjoint_classifier_guidance --train_with_disjoint_classifier=True --use_old_grad=False --use_new_grad=True --guid_to_new_classes=True --embedding_kind=concat_time_1hot --classifier_init_lr=0.1 --classifier_lr=0.01 --disjoint_classifier_init_num_steps=5000 --disjoint_classifier_num_steps=2000 --classifier_augmentation=True --log_interval=200 --guid_generation_interval=5 --diffusion_pretrained_dir=results/c10_ci5_class_cond_diffusion
103 | ```
104 |
105 | **CIFAR-100/5**
106 |
107 | ```bash
108 | python -m scripts.image_train --wandb_experiment_name=c100_ci5_guide --wandb_project_name=project --wandb_entity=entity --batch_size=256 --dataset=CIFAR100 --num_tasks=5 --seed=0 --timestep_respacing=ddim100 --use_ddim=True --classifier_scale_min_new=0.5 --classifier_scale_max_new=0.5 --cl_method=generative_replay_disjoint_classifier_guidance --train_with_disjoint_classifier=True --use_old_grad=False --use_new_grad=True --guid_to_new_classes=True --embedding_kind=concat_time_1hot --classifier_init_lr=0.1 --classifier_lr=0.05 --disjoint_classifier_init_num_steps=10000 --disjoint_classifier_num_steps=2000 --classifier_augmentation=True --log_interval=200 --guid_generation_interval=10 --diffusion_pretrained_dir=results/c100_ci5_class_cond_diffusion
109 | ```
110 |
111 | **CIFAR-100/10**
112 |
113 | ```bash
114 | python -m scripts.image_train --wandb_experiment_name=c100_ci10_guide --wandb_project_name=project --wandb_entity=entity --batch_size=256 --dataset=CIFAR100 --num_tasks=10 --seed=0 --timestep_respacing=ddim100 --use_ddim=True --classifier_scale_min_new=1.0 --classifier_scale_max_new=1.0 --cl_method=generative_replay_disjoint_classifier_guidance --train_with_disjoint_classifier=True --use_old_grad=False --use_new_grad=True --guid_to_new_classes=True --embedding_kind=concat_time_1hot --classifier_init_lr=0.1 --classifier_lr=0.05 --disjoint_classifier_init_num_steps=10000 --disjoint_classifier_num_steps=2000 --classifier_augmentation=True --log_interval=200 --guid_generation_interval=10 --diffusion_pretrained_dir=results/c100_ci10_class_cond_diffusion
115 | ```
116 |
117 | **ImageNet100-64/5**
118 |
119 | ```bash
120 | python -m scripts.image_train --wandb_experiment_name=i100_ci5_guide --wandb_project_name=project --wandb_entity=entity --batch_size=100 --dataset=ImageNet100 --num_tasks=5 --seed=0 --timestep_respacing=ddim50 --use_ddim=True --classifier_scale_min_new=1.0 --classifier_scale_max_new=1.0 --cl_method=generative_replay_disjoint_classifier_guidance --train_with_disjoint_classifier=True --use_old_grad=False --use_new_grad=True --guid_to_new_classes=True --embedding_kind=concat_time_1hot --classifier_init_lr=0.1 --classifier_lr=0.001 --disjoint_classifier_init_num_steps=20000 --disjoint_classifier_num_steps=20000 --classifier_augmentation=False --log_interval=200 --guid_generation_interval=15 --attention_resolutions 32,16,8 --lr 1e-4 --resblock_updown True --use_new_attention_order True --use_scale_shift_norm True --num_channels 192 --num_head_channels 64 --diffusion_pretrained_dir=results/i100_ci5_class_cond_diffusion
121 | ```
122 |
123 |
124 | ## BibTeX
125 |
126 | If you find this work useful, please consider citing it:
127 |
128 | ```bibtex
129 | @article{cywinski2024guide,
130 | title={GUIDE: Guidance-based Incremental Learning with Diffusion Models},
131 | author={Cywi{\'n}ski, Bartosz and Deja, Kamil and Trzci{\'n}ski, Tomasz and Twardowski, Bart{\l}omiej and Kuci{\'n}ski, {\L}ukasz},
132 | journal={arXiv preprint arXiv:2403.03938},
133 | year={2024}
134 | }
135 | ```
136 |
137 | ## Acknowledgments
138 |
139 | This codebase borrows from [OpenAI's guided diffusion repo](https://github.com/openai/guided-diffusion) and [Continual-Learning-Benchmark repo](https://github.com/GT-RIPL/Continual-Learning-Benchmark).
140 |
--------------------------------------------------------------------------------
/scripts/image_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a diffusion model on images.
3 | """
4 |
5 | import sys
6 |
7 | sys.path.append(".")
8 | import argparse
9 | import copy
10 | import os
11 | import time
12 | from collections import OrderedDict
13 |
14 | import numpy as np
15 | import torch as th
16 |
17 | import wandb
18 | from cl_methods.utils import get_cl_method
19 | from dataloaders import base
20 | from dataloaders.datasetGen import *
21 | from dataloaders.utils import prepare_eval_loaders
22 | from guide import dist_util, logger
23 | from guide.logger import wandb_safe_log
24 | from guide.resample import create_named_schedule_sampler
25 | from guide.script_args import (
26 | add_dict_to_argparser,
27 | all_training_defaults,
28 | args_to_dict,
29 | classifier_defaults,
30 | preprocess_args,
31 | )
32 | from guide.script_util import (
33 | create_model_and_diffusion,
34 | create_resnet_classifier,
35 | model_and_diffusion_defaults,
36 | results_to_log,
37 | )
38 | from guide.train_util import TrainLoop
39 | from guide.validation import calculate_accuracy_with_classifier
40 |
41 | # os.environ["WANDB_MODE"] = "disabled"
42 |
43 |
44 | def main():
45 | args = create_argparser().parse_args()
46 | run_training_with_args(args)
47 |
48 |
49 | def run_training_with_args(args):
50 | preprocess_args(args)
51 |
52 | dist_util.setup_dist(args)
53 |
54 | if logger.get_rank_without_mpi_import() == 0:
55 | if args.wandb_api_key:
56 | os.environ["WANDB_API_KEY"] = args.wandb_api_key
57 | wandb.init(
58 | project=args.wandb_project_name,
59 | name=args.wandb_experiment_name,
60 | config=args,
61 | entity=args.wandb_entity,
62 | )
63 |
64 | args.seed = args.seed + logger.get_rank_without_mpi_import()
65 | random_generator = seed_everything(args.seed)
66 | os.environ["OPENAI_LOGDIR"] = f"results/{args.wandb_experiment_name}"
67 | os.makedirs(os.path.join(logger.get_dir(), "generated_examples"), exist_ok=True)
68 | logger.configure()
69 | logger.log("Using manual seed = {}".format(args.seed))
70 |
71 | (
72 | train_dataset,
73 | val_dataset,
74 | image_size,
75 | image_channels,
76 | train_transform_classifier,
77 | train_transform_diffusion,
78 | n_classes,
79 | ) = base.__dict__[args.dataset](
80 | args.dataroot,
81 | train_aug=args.train_aug,
82 | skip_normalization=args.skip_normalization,
83 | classifier_augmentation=args.classifier_augmentation,
84 | )
85 |
86 | args.image_size = image_size
87 | args.in_channels = image_channels
88 | args.model_num_classes = n_classes
89 |
90 | logger.log("creating model and diffusion...")
91 | model, diffusion = create_model_and_diffusion(
92 | **args_to_dict(args, model_and_diffusion_defaults().keys())
93 | )
94 | if args.log_gradient_stats and not os.environ.get("WANDB_MODE") == "disabled":
95 | wandb.watch(model, log_freq=10)
96 | # if we are not training diffusion, we will not need this model
97 | if not args.train_with_disjoint_classifier:
98 | model.to(dist_util.dev())
99 |
100 | classifier = None
101 | if args.train_with_disjoint_classifier:
102 | logger.log("creating disjoint classifier...")
103 | defaults = classifier_defaults()
104 | parser = argparse.ArgumentParser()
105 | add_dict_to_argparser(parser, defaults)
106 | opts = parser.parse_args([])
107 | opts.model_num_classes = n_classes
108 | opts.in_channels = args.in_channels
109 | opts.depth = args.depth
110 | opts.noised = args.train_noised_classifier
111 | classifier = create_resnet_classifier(opts)
112 | classifier.to(dist_util.dev())
113 | # Needed for creating correct EMAs and fp16 parameters.
114 | dist_util.sync_params(classifier.parameters())
115 |
116 | if args.resume_checkpoint_classifier:
117 | classifier.load_state_dict(
118 | dist_util.load_state_dict(
119 | args.resume_checkpoint_classifier, map_location=dist_util.dev()
120 | )
121 | )
122 | logger.log(
123 | f"loading classifier from {args.resume_checkpoint_classifier}..."
124 | )
125 |
126 | schedule_sampler = create_named_schedule_sampler(
127 | args.schedule_sampler, diffusion, args
128 | )
129 |
130 | logger.log("creating data loaders...")
131 | train_dataset_splits, _, classes_per_task = data_split(
132 | dataset=train_dataset,
133 | return_classes=True,
134 | return_task_as_class=False,
135 | num_tasks=args.num_tasks,
136 | num_classes=n_classes,
137 | limit_classes=args.limit_classes,
138 | data_seed=args.data_seed,
139 | shared_classes=args.shared_classes,
140 | first_task_num_classes=args.first_task_num_classes,
141 | validation_frac=0.0,
142 | )
143 |
144 | val_dataset_splits, _, classes_per_task = data_split(
145 | dataset=val_dataset,
146 | return_classes=True,
147 | return_task_as_class=False,
148 | num_tasks=args.num_tasks,
149 | num_classes=n_classes,
150 | limit_classes=args.limit_classes,
151 | data_seed=args.data_seed,
152 | shared_classes=args.shared_classes,
153 | first_task_num_classes=args.first_task_num_classes,
154 | validation_frac=0.0,
155 | )
156 |
157 | train_loaders = []
158 | validation_loaders = prepare_eval_loaders(
159 | train_dataset_splits=train_dataset_splits,
160 | val_dataset_splits=val_dataset_splits,
161 | args=args,
162 | include_train=False,
163 | generator=random_generator,
164 | )
165 | test_acc_table = OrderedDict()
166 | train_acc_table = OrderedDict()
167 |
168 | train_loop = None
169 | cl_method = get_cl_method(args)
170 | global_step = 0
171 | dataset_yielder = None
172 | train_loader = None
173 | generated_previous_examples = None
174 |
175 | if args.limit_tasks != -1:
176 | n_tasks = args.limit_tasks
177 | else:
178 | n_tasks = args.num_tasks
179 |
180 | for task_id in range(n_tasks):
181 | if args.first_task_num_classes > 0 and task_id == 0:
182 | max_class = args.first_task_num_classes
183 | else:
184 | max_class = (
185 | ((task_id + 1) * (n_classes // n_tasks)) - 1
186 | ) + args.first_task_num_classes
187 |
188 | if task_id == 0:
189 | if not args.train_with_disjoint_classifier:
190 | num_steps = args.first_task_num_steps
191 | else:
192 | num_steps = args.disjoint_classifier_init_num_steps
193 | else:
194 | if not args.train_with_disjoint_classifier:
195 | num_steps = args.num_steps
196 | else:
197 | num_steps = args.disjoint_classifier_num_steps
198 |
199 | train_loop = TrainLoop(
200 | params=args,
201 | model=model,
202 | prev_model=copy.deepcopy(model).to(dist_util.dev()),
203 | diffusion=diffusion,
204 | task_id=task_id,
205 | data=train_dataset_splits[task_id],
206 | data_yielder=None,
207 | data_loader=None,
208 | batch_size=args.batch_size,
209 | microbatch=args.microbatch,
210 | lr=args.lr,
211 | scheduler_rate=args.scheduler_rate,
212 | ema_rate=args.ema_rate,
213 | log_interval=args.log_interval,
214 | skip_save=args.skip_save,
215 | save_interval=args.save_interval,
216 | resume_checkpoint=(
217 | args.resume_checkpoint if task_id == args.first_task else None
218 | ),
219 | use_fp16=args.use_fp16,
220 | fp16_scale_growth=args.fp16_scale_growth,
221 | schedule_sampler=schedule_sampler,
222 | weight_decay=args.weight_decay,
223 | lr_anneal_steps=args.lr_anneal_steps,
224 | num_steps=num_steps,
225 | image_size=args.image_size,
226 | in_channels=args.in_channels,
227 | max_class=max_class,
228 | global_steps_before=global_step,
229 | cl_method=cl_method,
230 | classes_per_task=classes_per_task,
231 | use_ddim=args.use_ddim,
232 | classifier_scale_min_old=args.classifier_scale_min_old,
233 | classifier_scale_min_new=args.classifier_scale_min_new,
234 | classifier_scale_max_old=args.classifier_scale_max_old,
235 | classifier_scale_max_new=args.classifier_scale_max_new,
236 | guid_generation_interval=args.guid_generation_interval,
237 | use_old_grad=args.use_old_grad,
238 | use_new_grad=args.use_new_grad,
239 | guid_to_new_classes=args.guid_to_new_classes,
240 | trim_logits=args.trim_logits,
241 | disjoint_classifier=classifier,
242 | prev_disjoint_classifier=copy.deepcopy(classifier),
243 | diffusion_pretrained_dir=args.diffusion_pretrained_dir,
244 | train_transform_classifier=train_transform_classifier,
245 | train_transform_diffusion=train_transform_diffusion,
246 | n_classes=n_classes,
247 | random_generator=random_generator,
248 | classifier_first_task_dir=args.classifier_first_task_dir,
249 | train_noised_classifier=args.train_noised_classifier,
250 | )
251 |
252 | if task_id >= args.first_task:
253 | (
254 | dataset_yielder,
255 | train_loader,
256 | generated_previous_examples,
257 | ) = cl_method.get_data_for_task(
258 | dataset=train_dataset_splits[task_id],
259 | task_id=task_id,
260 | train_loop=train_loop,
261 | generator=random_generator,
262 | step=global_step,
263 | )
264 | train_loaders.append(train_loader)
265 | logger.log(f"training task {task_id}")
266 |
267 | train_loop.data_yielder = dataset_yielder
268 | train_loop.data_loader = train_loader
269 |
270 | train_loop_start_time = time.time()
271 | if task_id >= args.first_task:
272 | train_loop.run_loop()
273 | global_step += num_steps
274 | train_loop_time = time.time() - train_loop_start_time
275 | wandb_safe_log({"train_loop_time": train_loop_time}, step=global_step)
276 |
277 | logger.log("validation...")
278 | test_acc_table[task_id] = OrderedDict()
279 | train_acc_table[task_id] = OrderedDict()
280 | validation_start_time = time.time()
281 | for j in range(task_id + 1):
282 | if args.train_with_disjoint_classifier:
283 | clf_results = calculate_accuracy_with_classifier(
284 | model=(
285 | model if not args.train_with_disjoint_classifier else classifier
286 | ),
287 | task_id=j,
288 | val_loader=validation_loaders[j],
289 | device=dist_util.dev(),
290 | train_loader=(
291 | train_loaders[j - args.first_task]
292 | if j >= args.first_task
293 | else None
294 | ),
295 | max_class=max_class,
296 | train_with_disjoint_classifier=args.train_with_disjoint_classifier,
297 | )
298 | test_acc_table[j][task_id] = clf_results["accuracy"]["test"]
299 | train_acc_table[j][task_id] = clf_results["accuracy"]["train"]
300 | logger.log(f"Test accuracy task {j}: {test_acc_table[j][task_id]}")
301 | logger.log(f"Train accuracy task {j}: {train_acc_table[j][task_id]}")
302 | else:
303 | test_acc_table[j][task_id] = 0.0
304 | train_acc_table[j][task_id] = 0.0
305 |
306 | validation_time = time.time() - validation_start_time
307 | if logger.get_rank_without_mpi_import() == 0:
308 | results_to_log(
309 | test_acc_table,
310 | train_acc_table,
311 | validation_time=validation_time,
312 | step=global_step,
313 | task_id=task_id,
314 | )
315 | if generated_previous_examples is not None:
316 | th.save(
317 | generated_previous_examples,
318 | os.path.join(
319 | logger.get_dir(), f"generated_examples/task_{task_id:02d}.pt"
320 | ),
321 | )
322 | train_loop.prev_ddp_model = copy.deepcopy(model)
323 | logger.log("TEST ACCURACY TABLE:")
324 | logger.log(test_acc_table)
325 | logger.log("TRAIN ACCURACY TABLE:")
326 | logger.log(train_acc_table)
327 |
328 |
329 | def seed_everything(seed):
330 | th.manual_seed(seed)
331 | np.random.seed(seed)
332 | th.cuda.manual_seed(seed)
333 | th.backends.cudnn.deterministic = True
334 | th.backends.cudnn.benchmark = False
335 | random_generator = th.Generator()
336 | random_generator.manual_seed(seed)
337 | return random_generator
338 |
339 |
340 | def create_argparser():
341 | defaults = all_training_defaults()
342 | parser = argparse.ArgumentParser()
343 | add_dict_to_argparser(parser, defaults)
344 | return parser
345 |
346 |
347 | if __name__ == "__main__":
348 | main()
349 |
--------------------------------------------------------------------------------
/guide/script_util.py:
--------------------------------------------------------------------------------
1 | import inspect
2 |
3 | import matplotlib
4 | import numpy as np
5 |
6 | from . import gaussian_diffusion as gd
7 | from .logger import wandb_safe_log
8 | from .resnet import ResNet
9 | from .respace import SpacedDiffusion, space_timesteps
10 | from .script_args import model_and_diffusion_defaults
11 | from .unet import EncoderUNetModel, SuperResModel
12 |
13 | # NUM_CLASSES = 1000
14 |
15 |
16 | def create_model_and_diffusion(
17 | image_size,
18 | in_channels,
19 | learn_sigma,
20 | sigma_small,
21 | num_channels,
22 | num_res_blocks,
23 | channel_mult,
24 | num_heads,
25 | num_head_channels,
26 | num_heads_upsample,
27 | attention_resolutions,
28 | dropout,
29 | diffusion_steps,
30 | noise_schedule,
31 | timestep_respacing,
32 | use_kl,
33 | predict_xstart,
34 | rescale_timesteps,
35 | rescale_learned_sigmas,
36 | use_checkpoint,
37 | use_scale_shift_norm,
38 | resblock_updown,
39 | use_fp16,
40 | use_new_attention_order,
41 | model_name,
42 | model_switching_timestep,
43 | embedding_kind,
44 | model_num_classes=None,
45 | noise_marg_reg=False,
46 | train_noised_classifier=False,
47 | classifier_augmentation=True,
48 | ):
49 | model = create_model(
50 | image_size,
51 | in_channels,
52 | num_channels,
53 | num_res_blocks,
54 | model_name=model_name,
55 | model_switching_timestep=model_switching_timestep,
56 | embedding_kind=embedding_kind,
57 | channel_mult=channel_mult,
58 | learn_sigma=learn_sigma,
59 | use_checkpoint=use_checkpoint,
60 | attention_resolutions=attention_resolutions,
61 | num_heads=num_heads,
62 | num_head_channels=num_head_channels,
63 | num_heads_upsample=num_heads_upsample,
64 | use_scale_shift_norm=use_scale_shift_norm,
65 | dropout=dropout,
66 | resblock_updown=resblock_updown,
67 | use_fp16=use_fp16,
68 | use_new_attention_order=use_new_attention_order,
69 | num_classes=model_num_classes,
70 | classifier_augmentation=classifier_augmentation,
71 | )
72 | diffusion = create_gaussian_diffusion(
73 | steps=diffusion_steps,
74 | learn_sigma=learn_sigma,
75 | sigma_small=sigma_small,
76 | noise_schedule=noise_schedule,
77 | use_kl=use_kl,
78 | predict_xstart=predict_xstart,
79 | rescale_timesteps=rescale_timesteps,
80 | rescale_learned_sigmas=rescale_learned_sigmas,
81 | timestep_respacing=timestep_respacing,
82 | noise_marg_reg=noise_marg_reg,
83 | train_noised_classifier=train_noised_classifier,
84 | )
85 |
86 | return model, diffusion
87 |
88 |
89 | def create_model(
90 | image_size,
91 | in_channels,
92 | num_channels,
93 | num_res_blocks,
94 | model_name,
95 | model_switching_timestep,
96 | embedding_kind,
97 | channel_mult="",
98 | learn_sigma=False,
99 | use_checkpoint=False,
100 | attention_resolutions="16",
101 | num_heads=1,
102 | num_head_channels=-1,
103 | num_heads_upsample=-1,
104 | use_scale_shift_norm=False,
105 | dropout=0,
106 | resblock_updown=False,
107 | use_fp16=False,
108 | use_new_attention_order=False,
109 | num_classes=None,
110 | classifier_augmentation=False,
111 | ):
112 | if channel_mult == "":
113 | if image_size == 512:
114 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
115 | elif image_size == 256:
116 | channel_mult = (1, 1, 2, 2, 4, 4)
117 | elif image_size in {128, 224}:
118 | channel_mult = (1, 1, 2, 3, 4)
119 | elif image_size == 64:
120 | channel_mult = (1, 2, 3, 4)
121 | elif image_size == 32:
122 | channel_mult = (1, 2, 2, 2)
123 | elif image_size == 28:
124 | channel_mult = (1, 2, 2)
125 | elif image_size == 1:
126 | channel_mult = (1, 1, 1)
127 | else:
128 | raise ValueError(f"unsupported image size: {image_size}")
129 | else:
130 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
131 |
132 | attention_ds = []
133 | for res in attention_resolutions.split(","):
134 | attention_ds.append(image_size // int(res))
135 |
136 | if model_name == "UNetModel":
137 | print("Using single model")
138 | from .unet import UNetModel as Model
139 | elif model_name == "MLPModel":
140 | from .mlp import MLPModel as Model
141 | else:
142 | raise NotImplementedError
143 | return Model(
144 | image_size=image_size,
145 | in_channels=in_channels,
146 | model_channels=num_channels,
147 | out_channels=(in_channels if not learn_sigma else in_channels * 2),
148 | num_res_blocks=num_res_blocks,
149 | attention_resolutions=tuple(attention_ds),
150 | dropout=dropout,
151 | channel_mult=channel_mult,
152 | num_classes=num_classes,
153 | use_checkpoint=use_checkpoint,
154 | use_fp16=use_fp16,
155 | num_heads=num_heads,
156 | num_head_channels=num_head_channels,
157 | num_heads_upsample=num_heads_upsample,
158 | use_scale_shift_norm=use_scale_shift_norm,
159 | resblock_updown=resblock_updown,
160 | use_new_attention_order=use_new_attention_order,
161 | model_switching_timestep=model_switching_timestep,
162 | embedding_kind=embedding_kind,
163 | classifier_augmentation=classifier_augmentation,
164 | )
165 |
166 |
167 | def create_classifier_and_diffusion(
168 | image_size,
169 | classifier_use_fp16,
170 | classifier_width,
171 | classifier_depth,
172 | classifier_attention_resolutions,
173 | classifier_use_scale_shift_norm,
174 | classifier_resblock_updown,
175 | classifier_pool,
176 | learn_sigma,
177 | sigma_small,
178 | diffusion_steps,
179 | noise_schedule,
180 | timestep_respacing,
181 | use_kl,
182 | predict_xstart,
183 | rescale_timesteps,
184 | rescale_learned_sigmas,
185 | ):
186 | classifier = create_classifier(
187 | image_size,
188 | classifier_use_fp16,
189 | classifier_width,
190 | classifier_depth,
191 | classifier_attention_resolutions,
192 | classifier_use_scale_shift_norm,
193 | classifier_resblock_updown,
194 | classifier_pool,
195 | )
196 | diffusion = create_gaussian_diffusion(
197 | steps=diffusion_steps,
198 | learn_sigma=learn_sigma,
199 | sigma_small=sigma_small,
200 | noise_schedule=noise_schedule,
201 | use_kl=use_kl,
202 | predict_xstart=predict_xstart,
203 | rescale_timesteps=rescale_timesteps,
204 | rescale_learned_sigmas=rescale_learned_sigmas,
205 | timestep_respacing=timestep_respacing,
206 | )
207 | return classifier, diffusion
208 |
209 |
210 | def create_classifier(
211 | image_size,
212 | classifier_use_fp16,
213 | classifier_width,
214 | classifier_depth,
215 | classifier_attention_resolutions,
216 | classifier_use_scale_shift_norm,
217 | classifier_resblock_updown,
218 | classifier_pool,
219 | ):
220 | if image_size == 512:
221 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
222 | elif image_size == 256:
223 | channel_mult = (1, 1, 2, 2, 4, 4)
224 | elif image_size in {128, 224}:
225 | channel_mult = (1, 1, 2, 3, 4)
226 | elif image_size == 64:
227 | channel_mult = (1, 2, 3, 4)
228 | elif image_size == 32:
229 | channel_mult = (1, 2, 2, 2)
230 | else:
231 | raise ValueError(f"unsupported image size: {image_size}")
232 |
233 | attention_ds = []
234 | for res in classifier_attention_resolutions.split(","):
235 | attention_ds.append(image_size // int(res))
236 |
237 | return EncoderUNetModel(
238 | image_size=image_size,
239 | in_channels=3,
240 | model_channels=classifier_width,
241 | out_channels=1000,
242 | num_res_blocks=classifier_depth,
243 | attention_resolutions=tuple(attention_ds),
244 | channel_mult=channel_mult,
245 | use_fp16=classifier_use_fp16,
246 | num_head_channels=64,
247 | use_scale_shift_norm=classifier_use_scale_shift_norm,
248 | resblock_updown=classifier_resblock_updown,
249 | pool=classifier_pool,
250 | )
251 |
252 |
253 | def create_resnet_classifier(args):
254 | classifier = ResNet(args)
255 | return classifier
256 |
257 |
258 | def sr_model_and_diffusion_defaults():
259 | res = model_and_diffusion_defaults()
260 | res["large_size"] = 256
261 | res["small_size"] = 64
262 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
263 | for k in res.copy().keys():
264 | if k not in arg_names:
265 | del res[k]
266 | return res
267 |
268 |
269 | def sr_create_model_and_diffusion(
270 | large_size,
271 | small_size,
272 | learn_sigma,
273 | sigma_small,
274 | num_channels,
275 | num_res_blocks,
276 | num_heads,
277 | num_head_channels,
278 | num_heads_upsample,
279 | attention_resolutions,
280 | dropout,
281 | diffusion_steps,
282 | noise_schedule,
283 | timestep_respacing,
284 | use_kl,
285 | predict_xstart,
286 | rescale_timesteps,
287 | rescale_learned_sigmas,
288 | use_checkpoint,
289 | use_scale_shift_norm,
290 | resblock_updown,
291 | use_fp16,
292 | ):
293 | model = sr_create_model(
294 | large_size,
295 | small_size,
296 | num_channels,
297 | num_res_blocks,
298 | learn_sigma=learn_sigma,
299 | use_checkpoint=use_checkpoint,
300 | attention_resolutions=attention_resolutions,
301 | num_heads=num_heads,
302 | num_head_channels=num_head_channels,
303 | num_heads_upsample=num_heads_upsample,
304 | use_scale_shift_norm=use_scale_shift_norm,
305 | dropout=dropout,
306 | resblock_updown=resblock_updown,
307 | use_fp16=use_fp16,
308 | )
309 | diffusion = create_gaussian_diffusion(
310 | steps=diffusion_steps,
311 | learn_sigma=learn_sigma,
312 | sigma_small=sigma_small,
313 | noise_schedule=noise_schedule,
314 | use_kl=use_kl,
315 | predict_xstart=predict_xstart,
316 | rescale_timesteps=rescale_timesteps,
317 | rescale_learned_sigmas=rescale_learned_sigmas,
318 | timestep_respacing=timestep_respacing,
319 | )
320 | return model, diffusion
321 |
322 |
323 | def sr_create_model(
324 | large_size,
325 | small_size,
326 | num_channels,
327 | num_res_blocks,
328 | learn_sigma,
329 | use_checkpoint,
330 | attention_resolutions,
331 | num_heads,
332 | num_head_channels,
333 | num_heads_upsample,
334 | use_scale_shift_norm,
335 | dropout,
336 | resblock_updown,
337 | use_fp16,
338 | ):
339 |     _ = small_size  # hack to prevent an unused-variable warning
340 |
341 | if large_size == 512:
342 | channel_mult = (1, 1, 2, 2, 4, 4)
343 | elif large_size == 256:
344 | channel_mult = (1, 1, 2, 2, 4, 4)
345 | elif large_size == 64:
346 | channel_mult = (1, 2, 3, 4)
347 | else:
348 | raise ValueError(f"unsupported large size: {large_size}")
349 |
350 | attention_ds = []
351 | for res in attention_resolutions.split(","):
352 | attention_ds.append(large_size // int(res))
353 |
354 | return SuperResModel(
355 | image_size=large_size,
356 | in_channels=3,
357 | model_channels=num_channels,
358 | out_channels=(3 if not learn_sigma else 6),
359 | num_res_blocks=num_res_blocks,
360 | attention_resolutions=tuple(attention_ds),
361 | dropout=dropout,
362 | channel_mult=channel_mult,
363 | use_checkpoint=use_checkpoint,
364 | num_heads=num_heads,
365 | num_head_channels=num_head_channels,
366 | num_heads_upsample=num_heads_upsample,
367 | use_scale_shift_norm=use_scale_shift_norm,
368 | resblock_updown=resblock_updown,
369 | use_fp16=use_fp16,
370 | )
371 |
372 |
373 | def create_gaussian_diffusion(
374 | *,
375 | steps=1000,
376 | learn_sigma=False,
377 | sigma_small=False,
378 | noise_schedule="linear",
379 | use_kl=False,
380 | predict_xstart=False,
381 | predict_xprevious=False,
382 | rescale_timesteps=False,
383 | rescale_learned_sigmas=False,
384 | timestep_respacing="",
385 | noise_marg_reg=False,
386 | train_noised_classifier=False,
387 | ):
388 | betas = gd.get_named_beta_schedule(noise_schedule, steps)
389 | if use_kl:
390 | loss_type = gd.LossType.RESCALED_KL
391 | elif rescale_learned_sigmas:
392 | loss_type = gd.LossType.RESCALED_MSE
393 | else:
394 | loss_type = gd.LossType.MSE
395 | if not timestep_respacing:
396 | timestep_respacing = [steps]
397 |
398 | if predict_xstart:
399 | model_mean_type = gd.ModelMeanType.START_X
400 | elif predict_xprevious:
401 | model_mean_type = gd.ModelMeanType.PREVIOUS_X
402 | else:
403 | model_mean_type = gd.ModelMeanType.EPSILON
404 |
405 | return SpacedDiffusion(
406 | use_timesteps=space_timesteps(steps, timestep_respacing),
407 | betas=betas,
408 | model_mean_type=model_mean_type,
409 | model_var_type=(
410 | (
411 | gd.ModelVarType.FIXED_LARGE
412 | if not sigma_small
413 | else gd.ModelVarType.FIXED_SMALL
414 | )
415 | if not learn_sigma
416 | else gd.ModelVarType.LEARNED_RANGE
417 | ),
418 | loss_type=loss_type,
419 | rescale_timesteps=rescale_timesteps,
420 | noise_marg_reg=noise_marg_reg,
421 | train_noised_classifier=train_noised_classifier,
422 | )
423 |
424 |
425 | def dict2array(results):
426 | tasks = len(results[0])
427 | array = np.zeros((tasks, tasks))
428 | for e, (key, val) in enumerate(reversed(results.items())):
429 | for e1, (k, v) in enumerate(reversed(val.items())):
430 | array[tasks - int(e1) - 1, tasks - int(e) - 1] = round(v, 3)
431 | return np.transpose(array, axes=(1, 0))
432 |
433 |
434 | def grid_plot(ax, array, type):
435 | if type == "fid":
436 | round = 1
437 | else:
438 | round = 2
439 | avg_array = np.around(array, round)
440 | num_tasks = array.shape[1]
441 | cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
442 | "", ["#287233", "#4c1c24"]
443 | )
444 | ax.imshow(avg_array, vmin=50, vmax=300, cmap=cmap)
445 | for i in range(len(avg_array)):
446 | for j in range(avg_array.shape[1]):
447 | if j >= i:
448 | ax.text(
449 | j,
450 | i,
451 | avg_array[i, j],
452 | va="center",
453 | ha="center",
454 | c="w",
455 | fontsize=70 / num_tasks,
456 | )
457 | ax.set_yticks(np.arange(num_tasks))
458 | ax.set_ylabel("Number of tasks")
459 | ax.set_xticks(np.arange(num_tasks))
460 | ax.set_xlabel("Tasks finished")
461 | ax.set_title(
462 | f"{type} -- {np.round(np.mean(array[:, -1]), 3)} -- std {np.round(np.std(array[:, -1]), 2)}"
463 | )
464 |
465 |
466 | def results_to_log(
467 | test_acc_table,
468 | train_acc_table,
469 | validation_time,
470 | step,
471 | task_id,
472 | ):
473 | log_dict = {"validation_time": validation_time}
474 | avg_acc_train = 0.0
475 | avg_acc_test = 0.0
476 | avg_forgetting_train = 0.0
477 | avg_forgetting_test = 0.0
478 | for j in range(task_id + 1):
479 | log_dict.update(
480 | {
481 | f"test/accuracy/{j}": test_acc_table[j][task_id],
482 | f"train/accuracy/{j}": train_acc_table[j][task_id],
483 | }
484 | )
485 | avg_acc_train += train_acc_table[j][task_id]
486 | avg_acc_test += test_acc_table[j][task_id]
487 |
488 | max_forgetting_train = 0.0 if (j == task_id) else -float("inf")
489 | max_forgetting_test = 0.0 if (j == task_id) else -float("inf")
490 | for k in range(j, task_id):
491 | max_forgetting_train = max(
492 | max_forgetting_train,
493 | train_acc_table[j][k] - train_acc_table[j][task_id],
494 | )
495 | max_forgetting_test = max(
496 | max_forgetting_test, test_acc_table[j][k] - test_acc_table[j][task_id]
497 | )
498 | avg_forgetting_train += max_forgetting_train
499 | avg_forgetting_test += max_forgetting_test
500 |
501 | log_dict.update(
502 | {
503 | "test/avg_accuracy": avg_acc_test / (task_id + 1),
504 | "train/avg_accuracy": avg_acc_train / (task_id + 1),
505 | "train/avg_forgetting": (
506 | (avg_forgetting_train / task_id) if task_id > 0 else 0.0
507 | ),
508 | "test/avg_forgetting": (
509 | (avg_forgetting_test / task_id) if task_id > 0 else 0.0
510 | ),
511 | }
512 | )
513 |
514 | wandb_safe_log(log_dict, step=step)
515 |
516 | # return PIL.Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb())
517 |
--------------------------------------------------------------------------------
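A minimal usage sketch for the factory functions above, assuming this file is importable as guide.script_util (as the setup.py packaging suggests): build a 1000-step diffusion and respace it to 250 sampling steps. Leaving timestep_respacing empty keeps all 1000 timesteps.

from guide.script_util import create_gaussian_diffusion

# Hypothetical configuration: a linear-schedule diffusion trained with 1000 steps,
# respaced to 250 steps for faster sampling; learn_sigma=True selects learned-range variances.
diffusion = create_gaussian_diffusion(
    steps=1000,
    learn_sigma=True,
    noise_schedule="linear",
    timestep_respacing="250",
)
print(type(diffusion).__name__)  # expected: SpacedDiffusion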
/guide/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4 | """
5 |
6 | import datetime
7 | import json
8 | import os
9 | import os.path as osp
10 | import sys
11 | import tempfile
12 | import time
13 | import warnings
14 | from collections import defaultdict
15 | from contextlib import contextmanager
16 |
17 | import numpy as np
18 | import torch as th
19 | from matplotlib import pyplot as plt
20 | from torchvision.utils import make_grid
21 |
22 | import wandb
23 |
24 | DEBUG = 10
25 | INFO = 20
26 | WARN = 30
27 | ERROR = 40
28 |
29 | DISABLED = 50
30 |
31 |
32 | class KVWriter(object):
33 | def writekvs(self, kvs):
34 | raise NotImplementedError
35 |
36 |
37 | class SeqWriter(object):
38 | def writeseq(self, seq):
39 | raise NotImplementedError
40 |
41 |
42 | class HumanOutputFormat(KVWriter, SeqWriter):
43 | def __init__(self, filename_or_file):
44 | if isinstance(filename_or_file, str):
45 | self.file = open(filename_or_file, "wt")
46 | self.own_file = True
47 | else:
48 | assert hasattr(filename_or_file, "read"), (
49 | "expected file or str, got %s" % filename_or_file
50 | )
51 | self.file = filename_or_file
52 | self.own_file = False
53 |
54 | def writekvs(self, kvs):
55 | # Create strings for printing
56 | key2str = {}
57 | for key, val in sorted(kvs.items()):
58 | if hasattr(val, "__float__"):
59 | valstr = "%-8.3g" % val
60 | else:
61 | valstr = str(val)
62 | key2str[self._truncate(key)] = self._truncate(valstr)
63 |
64 | # Find max widths
65 | if len(key2str) == 0:
66 | print("WARNING: tried to write empty key-value dict")
67 | return
68 | else:
69 | keywidth = max(map(len, key2str.keys()))
70 | valwidth = max(map(len, key2str.values()))
71 |
72 | # Write out the data
73 | dashes = "-" * (keywidth + valwidth + 7)
74 | lines = [dashes]
75 | for key, val in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
76 | lines.append(
77 | "| %s%s | %s%s |"
78 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
79 | )
80 | lines.append(dashes)
81 | self.file.write("\n".join(lines) + "\n")
82 |
83 | # Flush the output to the file
84 | self.file.flush()
85 |
86 | def _truncate(self, s):
87 | maxlen = 30
88 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s
89 |
90 | def writeseq(self, seq):
91 | seq = list(seq)
92 | for i, elem in enumerate(seq):
93 | self.file.write(elem)
94 | if i < len(seq) - 1: # add space unless this is the last one
95 | self.file.write(" ")
96 | self.file.write("\n")
97 | self.file.flush()
98 |
99 | def close(self):
100 | if self.own_file:
101 | self.file.close()
102 |
103 |
104 | class JSONOutputFormat(KVWriter):
105 | def __init__(self, filename):
106 | self.file = open(filename, "wt")
107 |
108 | def writekvs(self, kvs):
109 | for k, v in sorted(kvs.items()):
110 | if hasattr(v, "dtype"):
111 | kvs[k] = float(v)
112 | self.file.write(json.dumps(kvs) + "\n")
113 | self.file.flush()
114 |
115 | def close(self):
116 | self.file.close()
117 |
118 |
119 | class CSVOutputFormat(KVWriter):
120 | def __init__(self, filename):
121 | self.file = open(filename, "w+t")
122 | self.keys = []
123 | self.sep = ","
124 |
125 | def writekvs(self, kvs):
126 | # Add our current row to the history
127 | extra_keys = list(kvs.keys() - self.keys)
128 | extra_keys.sort()
129 | if extra_keys:
130 | self.keys.extend(extra_keys)
131 | self.file.seek(0)
132 | lines = self.file.readlines()
133 | self.file.seek(0)
134 | for i, k in enumerate(self.keys):
135 | if i > 0:
136 | self.file.write(",")
137 | self.file.write(k)
138 | self.file.write("\n")
139 | for line in lines[1:]:
140 | self.file.write(line[:-1])
141 | self.file.write(self.sep * len(extra_keys))
142 | self.file.write("\n")
143 | for i, k in enumerate(self.keys):
144 | if i > 0:
145 | self.file.write(",")
146 | v = kvs.get(k)
147 | if v is not None:
148 | self.file.write(str(v))
149 | self.file.write("\n")
150 | self.file.flush()
151 |
152 | def close(self):
153 | self.file.close()
154 |
155 |
156 | class TensorBoardOutputFormat(KVWriter):
157 | """
158 | Dumps key/value pairs into TensorBoard's numeric format.
159 | """
160 |
161 | def __init__(self, dir):
162 | os.makedirs(dir, exist_ok=True)
163 | self.dir = dir
164 | self.step = 1
165 | prefix = "events"
166 | path = osp.join(osp.abspath(dir), prefix)
167 | import tensorflow as tf
168 | from tensorflow.core.util import event_pb2
169 | from tensorflow.python import pywrap_tensorflow
170 | from tensorflow.python.util import compat
171 |
172 | self.tf = tf
173 | self.event_pb2 = event_pb2
174 | self.pywrap_tensorflow = pywrap_tensorflow
175 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
176 |
177 | def writekvs(self, kvs):
178 | def summary_val(k, v):
179 | kwargs = {"tag": k, "simple_value": float(v)}
180 | return self.tf.Summary.Value(**kwargs)
181 |
182 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
183 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
184 | event.step = (
185 | self.step
186 | ) # is there any reason why you'd want to specify the step?
187 | self.writer.WriteEvent(event)
188 | self.writer.Flush()
189 | self.step += 1
190 |
191 | def close(self):
192 | if self.writer:
193 | self.writer.Close()
194 | self.writer = None
195 |
196 |
197 | def make_output_format(format, ev_dir, log_suffix=""):
198 | os.makedirs(ev_dir, exist_ok=True)
199 | if format == "stdout":
200 | return HumanOutputFormat(sys.stdout)
201 | elif format == "log":
202 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
203 | elif format == "json":
204 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
205 | elif format == "csv":
206 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
207 | elif format == "tensorboard":
208 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
209 | else:
210 | raise ValueError("Unknown format specified: %s" % (format,))
211 |
212 |
213 | # ================================================================
214 | # API
215 | # ================================================================
216 |
217 |
218 | def logkv(key, val):
219 | """
220 |     Log a value of some diagnostic.
221 |     Call this once for each diagnostic quantity, each iteration.
222 |     If called many times, the last value will be used.
223 | """
224 | get_current().logkv(key, val)
225 |
226 |
227 | def logkv_mean(key, val):
228 | """
229 |     The same as logkv(), but if called many times, the values are averaged.
230 | """
231 | get_current().logkv_mean(key, val)
232 |
233 |
234 | def logkvs(d):
235 | """
236 | Log a dictionary of key-value pairs
237 | """
238 | for k, v in d.items():
239 | logkv(k, v)
240 |
241 |
242 | def dumpkvs():
243 | """
244 | Write all of the diagnostics from the current iteration
245 | """
246 | return get_current().dumpkvs()
247 |
248 |
249 | def getkvs():
250 | return get_current().name2val
251 |
252 |
253 | def log(*args, level=INFO):
254 | """
255 |     Write the sequence of args, separated by spaces, to the console and output files (if you've configured an output file).
256 | """
257 | get_current().log(*args, level=level)
258 |
259 |
260 | def debug(*args):
261 | log(*args, level=DEBUG)
262 |
263 |
264 | def info(*args):
265 | log(*args, level=INFO)
266 |
267 |
268 | def warn(*args):
269 | log(*args, level=WARN)
270 |
271 |
272 | def error(*args):
273 | log(*args, level=ERROR)
274 |
275 |
276 | def set_level(level):
277 | """
278 | Set logging threshold on current logger.
279 | """
280 | get_current().set_level(level)
281 |
282 |
283 | def set_comm(comm):
284 | get_current().set_comm(comm)
285 |
286 |
287 | def get_dir():
288 | """
289 |     Get the directory that log files are being written to.
290 |     Will be None if there is no output directory (i.e., if you didn't call configure).
291 | """
292 | return get_current().get_dir()
293 |
294 |
295 | record_tabular = logkv
296 | dump_tabular = dumpkvs
297 |
298 |
299 | @contextmanager
300 | def profile_kv(scopename):
301 | logkey = "wait_" + scopename
302 | tstart = time.time()
303 | try:
304 | yield
305 | finally:
306 | get_current().name2val[logkey] += time.time() - tstart
307 |
308 |
309 | def profile(n):
310 | """
311 | Usage:
312 | @profile("my_func")
313 | def my_func(): code
314 | """
315 |
316 | def decorator_with_name(func):
317 | def func_wrapper(*args, **kwargs):
318 | with profile_kv(n):
319 | return func(*args, **kwargs)
320 |
321 | return func_wrapper
322 |
323 | return decorator_with_name
324 |
325 |
326 | # ================================================================
327 | # Backend
328 | # ================================================================
329 |
330 |
331 | def get_current():
332 | if Logger.CURRENT is None:
333 | _configure_default_logger()
334 |
335 | return Logger.CURRENT
336 |
337 |
338 | class Logger(object):
339 | DEFAULT = None # A logger with no output files. (See right below class definition)
340 | # So that you can still log to the terminal without setting up any output files
341 | CURRENT = None # Current logger being used by the free functions above
342 |
343 | def __init__(self, dir, output_formats, comm=None):
344 | self.name2val = defaultdict(float) # values this iteration
345 | self.name2cnt = defaultdict(int)
346 | self.level = INFO
347 | self.dir = dir
348 | self.output_formats = output_formats
349 | self.comm = comm
350 |
351 | # Logging API, forwarded
352 | # ----------------------------------------
353 | def logkv(self, key, val):
354 | self.name2val[key] = val
355 |
356 | def logkv_mean(self, key, val):
357 | oldval, cnt = self.name2val[key], self.name2cnt[key]
358 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
359 | self.name2cnt[key] = cnt + 1
360 |
361 | def dumpkvs(self):
362 | if self.comm is None:
363 | d = self.name2val
364 | else:
365 | d = mpi_weighted_mean(
366 | self.comm,
367 | {
368 | name: (val, self.name2cnt.get(name, 1))
369 | for (name, val) in self.name2val.items()
370 | },
371 | )
372 | if self.comm.rank != 0:
373 | d["dummy"] = 1 # so we don't get a warning about empty dict
374 | out = d.copy() # Return the dict for unit testing purposes
375 | for fmt in self.output_formats:
376 | if isinstance(fmt, KVWriter):
377 | fmt.writekvs(d)
378 | self.name2val.clear()
379 | self.name2cnt.clear()
380 | return out
381 |
382 | def log(self, *args, level=INFO):
383 | if self.level <= level:
384 | self._do_log(args)
385 |
386 | # Configuration
387 | # ----------------------------------------
388 | def set_level(self, level):
389 | self.level = level
390 |
391 | def set_comm(self, comm):
392 | self.comm = comm
393 |
394 | def get_dir(self):
395 | return self.dir
396 |
397 | def close(self):
398 | for fmt in self.output_formats:
399 | fmt.close()
400 |
401 | # Misc
402 | # ----------------------------------------
403 | def _do_log(self, args):
404 | for fmt in self.output_formats:
405 | if isinstance(fmt, SeqWriter):
406 | fmt.writeseq(map(str, args))
407 |
408 |
409 | def get_rank_without_mpi_import():
410 | # check environment variables here instead of importing mpi4py
411 | # to avoid calling MPI_Init() when this module is imported
412 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
413 | if varname in os.environ:
414 | return int(os.environ[varname])
415 | return 0
416 |
417 |
418 | def mpi_weighted_mean(comm, local_name2valcount):
419 | """
420 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
421 | Perform a weighted average over dicts that are each on a different node
422 | Input: local_name2valcount: dict mapping key -> (value, count)
423 | Returns: key -> mean
424 | """
425 | all_name2valcount = comm.gather(local_name2valcount)
426 | if comm.rank == 0:
427 | name2sum = defaultdict(float)
428 | name2count = defaultdict(float)
429 | for n2vc in all_name2valcount:
430 | for name, (val, count) in n2vc.items():
431 | try:
432 | val = float(val)
433 | except ValueError:
434 | if comm.rank == 0:
435 | warnings.warn(
436 | "WARNING: tried to compute mean on non-float {}={}".format(
437 | name, val
438 | )
439 | )
440 | else:
441 | name2sum[name] += val * count
442 | name2count[name] += count
443 | return {name: name2sum[name] / name2count[name] for name in name2sum}
444 | else:
445 | return {}
446 |
447 |
448 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
449 | """
450 | If comm is provided, average all numerical stats across that comm
451 | """
452 | if dir is None:
453 | dir = os.getenv("OPENAI_LOGDIR")
454 | if dir is None:
455 | dir = osp.join(
456 | tempfile.gettempdir(),
457 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
458 | )
459 | assert isinstance(dir, str)
460 | dir = os.path.expanduser(dir)
461 | os.makedirs(os.path.expanduser(dir), exist_ok=True)
462 |
463 | rank = get_rank_without_mpi_import()
464 | if rank > 0:
465 | log_suffix = log_suffix + "-rank%03i" % rank
466 |
467 | if format_strs is None:
468 | if rank == 0:
469 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
470 | else:
471 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
472 | format_strs = filter(None, format_strs)
473 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
474 |
475 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
476 | if output_formats:
477 | log("Logging to %s" % dir)
478 |
479 |
480 | def _configure_default_logger():
481 | configure()
482 | Logger.DEFAULT = Logger.CURRENT
483 |
484 |
485 | def reset():
486 | if Logger.CURRENT is not Logger.DEFAULT:
487 | Logger.CURRENT.close()
488 | Logger.CURRENT = Logger.DEFAULT
489 | log("Reset logger")
490 |
491 |
492 | @contextmanager
493 | def scoped_configure(dir=None, format_strs=None, comm=None):
494 | prevlogger = Logger.CURRENT
495 | configure(dir=dir, format_strs=format_strs, comm=comm)
496 | try:
497 | yield
498 | finally:
499 | Logger.CURRENT.close()
500 | Logger.CURRENT = prevlogger
501 |
502 |
503 | def wandb_safe_log(*args, **kwargs):
504 | # Try several times.
505 | for _ in range(10):
506 | try:
507 | wandb.log(*args, **kwargs)
508 | except:
509 | time.sleep(20)
510 | else:
511 | break
512 |
513 |
514 | def log_generated_examples(
515 | examples, labels, confidences, task_id, step, n_examples_to_log=4
516 | ):
517 | fig, ax = plt.subplots()
518 | _, _, bars = ax.hist(
519 | [c.item() for c in labels],
520 | bins=np.arange(labels.max().item() + 2) - 0.5,
521 | edgecolor="black",
522 | )
523 | ax.set_xticks(range(int(labels.max().item()) + 1))
524 | ax.set_xlabel("Generated class")
525 | ax.set_ylabel("Count")
526 | ax.bar_label(bars)
527 | ax.set_title(f"generated_classes/{task_id}")
528 | wandb.log({f"generated_classes/{task_id}": wandb.Image(fig)}, step=step)
529 |
530 | if confidences is not None:
531 | fig, ax = plt.subplots()
532 | _, _, bars = ax.hist(
533 | [c.item() for c in confidences],
534 | bins=np.linspace(0, 1, 6),
535 | edgecolor="black",
536 | )
537 | ax.set_xlabel("Confidence")
538 | ax.set_ylabel("Count")
539 | ax.bar_label(bars)
540 | ax.set_title(f"generated_examples_confidences/{task_id}")
541 | wandb.log(
542 | {f"generated_examples_confidences/{task_id}": wandb.Image(fig)}, step=step
543 | )
544 |
545 |     # Additionally, plot the generated examples
546 | unique_generated_classes = labels.unique()
547 | min_n_of_samples_per_class = n_examples_to_log
548 | images_to_log = []
549 | for class_label in unique_generated_classes:
550 | class_indices = (labels == class_label).nonzero(as_tuple=True)[0]
551 | selected_indices = th.randperm(len(class_indices))[:min_n_of_samples_per_class]
552 | if len(class_indices) < min_n_of_samples_per_class:
553 | min_n_of_samples_per_class = len(class_indices)
554 | selected_images = examples[class_indices[selected_indices]].cpu()
555 | images_to_log.append(selected_images)
556 | if min_n_of_samples_per_class > 0:
557 |         # ensure an equal number of samples per class
558 | images_to_log = [
559 | samples[:min_n_of_samples_per_class] for samples in images_to_log
560 | ]
561 | samples_grid = make_grid(
562 | th.cat(images_to_log, dim=0),
563 | nrow=min_n_of_samples_per_class,
564 | normalize=True,
565 | )
566 | wandb.log(
567 | {f"generated_examples/{task_id}": wandb.Image(samples_grid)}, step=step
568 | )
569 |
--------------------------------------------------------------------------------
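The API above mirrors the OpenAI baselines logger. A short, illustrative sketch of typical use (the output directory and formats are placeholders, and the module-level wandb, matplotlib, and torch imports must be installed):

from guide import logger

logger.configure(dir="/tmp/guide-logs", format_strs=["stdout", "csv"])
for step in range(3):
    logger.logkv("step", step)                    # last value per key wins within an iteration
    logger.logkv_mean("loss", 1.0 / (step + 1))   # repeated calls are averaged before dumping
    logger.dumpkvs()                              # flush one row to stdout and progress.csv
logger.log("training finished", level=logger.INFO)

When dir is None, configure() falls back to OPENAI_LOGDIR or a timestamped temporary directory.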
/dataloaders/base.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import kornia as K
4 | import torch
5 | import torchvision
6 | from torch.utils.data import DataLoader, Dataset
7 | from torchvision import transforms
8 |
9 | from .wrapper import CacheClassLabel
10 |
11 |
12 | class FastDataset(Dataset):
13 | def __init__(self, data, labels):
14 | self.dataset = data
15 | self.labels = labels
16 |
17 | def __len__(self):
18 | return len(self.dataset)
19 |
20 | def __getitem__(self, index):
21 | return self.dataset[index], self.labels[index]
22 |
23 |
24 | def CIFAR10(
25 | dataroot, skip_normalization=False, train_aug=False, classifier_augmentation=False
26 | ):
27 | train_transform_clf = None
28 | train_transform_diff = None
29 | # augmentation for diffusion training
30 | if train_aug:
31 | train_transform_diff = K.augmentation.ImageSequential(
32 | K.augmentation.RandomHorizontalFlip(),
33 | )
34 |
35 | # augmentation for classifier training
36 | if classifier_augmentation:
37 | train_transform_clf = K.augmentation.ImageSequential(
38 | K.augmentation.Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
39 | K.augmentation.RandomCrop((32, 32), padding=4),
40 | K.augmentation.RandomRotation(30),
41 | K.augmentation.RandomHorizontalFlip(),
42 | K.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1),
43 | K.augmentation.RandomErasing(scale=(0.1, 0.5)),
44 | K.augmentation.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
45 | )
46 |
47 | target_transform = transforms.Lambda(lambda y: torch.eye(10)[y])
48 |
49 | train_dataset = torchvision.datasets.CIFAR10(
50 | root=dataroot,
51 | train=True,
52 | download=True,
53 | transform=transforms.Compose(
54 | [
55 | transforms.ToTensor(),
56 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
57 | ]
58 | ),
59 | target_transform=target_transform,
60 | )
61 | train_dataset = CacheClassLabel(
62 | train_dataset,
63 | target_transform=target_transform,
64 | )
65 |
66 | val_dataset = torchvision.datasets.CIFAR10(
67 | root=dataroot,
68 | train=False,
69 | download=True,
70 | transform=transforms.Compose(
71 | [
72 | transforms.ToTensor(),
73 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
74 | ]
75 | ),
76 | target_transform=target_transform,
77 | )
78 | val_dataset = CacheClassLabel(
79 | val_dataset,
80 | target_transform=target_transform,
81 | )
82 | print("Loading data")
83 | save_path = f"{dataroot}/fast_cifar10_train"
84 | if os.path.exists(save_path):
85 | fast_cifar_train = torch.load(save_path)
86 | else:
87 | train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
88 | data = next(iter(train_loader))
89 | fast_cifar_train = FastDataset(data[0], data[1])
90 | torch.save(fast_cifar_train, save_path)
91 |
92 | save_path = f"{dataroot}/fast_cifar10_val"
93 | if os.path.exists(save_path):
94 | fast_cifar_val = torch.load(save_path)
95 | else:
96 | val_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
97 | data = next(iter(val_loader))
98 | fast_cifar_val = FastDataset(data[0], data[1])
99 | torch.save(fast_cifar_val, save_path)
100 |
101 | return (
102 | fast_cifar_train,
103 | fast_cifar_val,
104 | 32,
105 | 3,
106 | train_transform_clf,
107 | train_transform_diff,
108 | 10,
109 | )
110 |
111 |
112 | def CIFAR100(
113 | dataroot, skip_normalization=False, train_aug=False, classifier_augmentation=False
114 | ):
115 | train_transform_clf = None
116 | train_transform_diff = None
117 | # augmentation for diffusion training
118 | if train_aug:
119 | train_transform_diff = K.augmentation.ImageSequential(
120 | K.augmentation.RandomHorizontalFlip(),
121 | )
122 |
123 | # augmentation for classifier training
124 | if classifier_augmentation:
125 | train_transform_clf = K.augmentation.ImageSequential(
126 | K.augmentation.Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
127 | K.augmentation.RandomCrop((32, 32), padding=4),
128 | K.augmentation.RandomRotation(30),
129 | K.augmentation.RandomHorizontalFlip(),
130 | K.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1),
131 | K.augmentation.RandomErasing(scale=(0.1, 0.5)),
132 | K.augmentation.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
133 | )
134 |
135 | target_transform = transforms.Lambda(lambda y: torch.eye(100)[y])
136 |
137 | train_dataset = torchvision.datasets.CIFAR100(
138 | root=dataroot,
139 | train=True,
140 | download=True,
141 | transform=transforms.Compose(
142 | [
143 | transforms.ToTensor(),
144 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
145 | ]
146 | ),
147 | target_transform=target_transform,
148 | )
149 | train_dataset = CacheClassLabel(
150 | train_dataset,
151 | target_transform=target_transform,
152 | )
153 |
154 | val_dataset = torchvision.datasets.CIFAR100(
155 | root=dataroot,
156 | train=False,
157 | download=True,
158 | transform=transforms.Compose(
159 | [
160 | transforms.ToTensor(),
161 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
162 | ]
163 | ),
164 | target_transform=target_transform,
165 | )
166 | val_dataset = CacheClassLabel(
167 | val_dataset,
168 | target_transform=target_transform,
169 | )
170 |
171 | print("Loading data")
172 | save_path = f"{dataroot}/fast_cifar100_train"
173 | if os.path.exists(save_path):
174 | fast_cifar_train = torch.load(save_path)
175 | else:
176 | train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
177 | data = next(iter(train_loader))
178 | fast_cifar_train = FastDataset(data[0], data[1])
179 | torch.save(fast_cifar_train, save_path)
180 |
181 | save_path = f"{dataroot}/fast_cifar100_val"
182 | if os.path.exists(save_path):
183 | fast_cifar_val = torch.load(save_path)
184 | else:
185 | val_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
186 | data = next(iter(val_loader))
187 | fast_cifar_val = FastDataset(data[0], data[1])
188 | torch.save(fast_cifar_val, save_path)
189 |
190 | return (
191 | fast_cifar_train,
192 | fast_cifar_val,
193 | 32,
194 | 3,
195 | train_transform_clf,
196 | train_transform_diff,
197 | 100,
198 | )
199 |
200 |
201 | def ImageNet100(
202 | dataroot, skip_normalization=False, train_aug=False, classifier_augmentation=False
203 | ):
204 | train_transform_clf = None
205 | train_transform_diff = None
206 | # augmentation for diffusion training
207 | if train_aug:
208 | train_transform_diff = K.augmentation.ImageSequential(
209 | K.augmentation.RandomHorizontalFlip(),
210 | )
211 |
212 | print("Loading data")
213 | save_path = f"{dataroot}/fast_imagenet100_train"
214 | if os.path.exists(save_path):
215 | fast_imagenet_train = torch.load(save_path)
216 | else:
217 | target_transform = transforms.Lambda(lambda y: torch.eye(100)[y])
218 |
219 | train_dataset = torchvision.datasets.ImageFolder(
220 | root=os.path.join(dataroot, "imagenet100", "train"),
221 | transform=transforms.Compose(
222 | [
223 | transforms.Resize((76, 76)),
224 | transforms.CenterCrop((64, 64)),
225 | transforms.ToTensor(),
226 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
227 | ]
228 | ),
229 | target_transform=target_transform,
230 | )
231 | train_dataset = CacheClassLabel(
232 | train_dataset,
233 | target_transform=target_transform,
234 | )
235 |
236 | val_dataset = torchvision.datasets.ImageFolder(
237 | root=os.path.join(dataroot, "imagenet100", "val"),
238 | transform=transforms.Compose(
239 | [
240 | transforms.Resize((76, 76)),
241 | transforms.CenterCrop((64, 64)),
242 | transforms.ToTensor(),
243 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
244 | ]
245 | ),
246 | target_transform=target_transform,
247 | )
248 |
249 | val_dataset = CacheClassLabel(
250 | val_dataset,
251 | target_transform=target_transform,
252 | )
253 |
254 | train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
255 | data = next(iter(train_loader))
256 | fast_imagenet_train = FastDataset(data[0], data[1])
257 | torch.save(fast_imagenet_train, save_path)
258 |
259 | save_path = f"{dataroot}/fast_imagenet100_val"
260 | if os.path.exists(save_path):
261 | fast_imagenet_val = torch.load(save_path)
262 | else:
263 | val_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
264 | data = next(iter(val_loader))
265 | fast_imagenet_val = FastDataset(data[0], data[1])
266 | torch.save(fast_imagenet_val, save_path)
267 |
268 | return (
269 | fast_imagenet_train,
270 | fast_imagenet_val,
271 | 64,
272 | 3,
273 | train_transform_clf,
274 | train_transform_diff,
275 | 100,
276 | )
277 |
278 |
279 | def ImageNet100128(
280 | dataroot, skip_normalization=False, train_aug=False, classifier_augmentation=False
281 | ):
282 | train_transform_clf = None
283 | train_transform_diff = None
284 | # augmentation for diffusion training
285 | if train_aug:
286 | train_transform_diff = K.augmentation.ImageSequential(
287 | K.augmentation.RandomHorizontalFlip(),
288 | )
289 |
290 | print("Loading data")
291 | save_path = f"{dataroot}/fast_imagenet100128_train"
292 | if os.path.exists(save_path):
293 | fast_imagenet_train = torch.load(save_path)
294 | else:
295 | target_transform = transforms.Lambda(lambda y: torch.eye(100)[y])
296 |
297 | train_dataset = torchvision.datasets.ImageFolder(
298 | root=os.path.join(dataroot, "imagenet100", "train"),
299 | transform=transforms.Compose(
300 | [
301 | transforms.Resize((152, 152)),
302 | transforms.CenterCrop((128, 128)),
303 | transforms.ToTensor(),
304 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
305 | ]
306 | ),
307 | target_transform=target_transform,
308 | )
309 | train_dataset = CacheClassLabel(
310 | train_dataset,
311 | target_transform=target_transform,
312 | )
313 |
314 | val_dataset = torchvision.datasets.ImageFolder(
315 | root=os.path.join(dataroot, "imagenet100", "val"),
316 | transform=transforms.Compose(
317 | [
318 | transforms.Resize((152, 152)),
319 | transforms.CenterCrop((128, 128)),
320 | transforms.ToTensor(),
321 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
322 | ]
323 | ),
324 | target_transform=target_transform,
325 | )
326 |
327 | val_dataset = CacheClassLabel(
328 | val_dataset,
329 | target_transform=target_transform,
330 | )
331 |
332 | train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
333 | data = next(iter(train_loader))
334 | fast_imagenet_train = FastDataset(data[0], data[1])
335 | torch.save(fast_imagenet_train, save_path)
336 |
337 | save_path = f"{dataroot}/fast_imagenet100128_val"
338 | if os.path.exists(save_path):
339 | fast_imagenet_val = torch.load(save_path)
340 | else:
341 | val_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
342 | data = next(iter(val_loader))
343 | fast_imagenet_val = FastDataset(data[0], data[1])
344 | torch.save(fast_imagenet_val, save_path)
345 |
346 | return (
347 | fast_imagenet_train,
348 | fast_imagenet_val,
349 | 128,
350 | 3,
351 | train_transform_clf,
352 | train_transform_diff,
353 | 100,
354 | )
355 |
356 | def ImageNet100224(
357 | dataroot, skip_normalization=False, train_aug=False, classifier_augmentation=False
358 | ):
359 | train_transform_clf = None
360 | train_transform_diff = None
361 | # augmentation for diffusion training
362 | if train_aug:
363 | train_transform_diff = K.augmentation.ImageSequential(
364 | K.augmentation.RandomHorizontalFlip(),
365 | )
366 |
367 | print("Loading data")
368 | save_path = f"{dataroot}/fast_imagenet100224_train"
369 | if os.path.exists(save_path):
370 | fast_imagenet_train = torch.load(save_path)
371 | else:
372 | target_transform = transforms.Lambda(lambda y: torch.eye(100)[y])
373 |
374 | train_dataset = torchvision.datasets.ImageFolder(
375 | root=os.path.join(dataroot, "imagenet100", "train"),
376 | transform=transforms.Compose(
377 | [
378 | transforms.Resize((266, 266)),
379 | transforms.CenterCrop((224, 224)),
380 | transforms.ToTensor(),
381 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
382 | ]
383 | ),
384 | target_transform=target_transform,
385 | )
386 | train_dataset = CacheClassLabel(
387 | train_dataset,
388 | target_transform=target_transform,
389 | )
390 |
391 | val_dataset = torchvision.datasets.ImageFolder(
392 | root=os.path.join(dataroot, "imagenet100", "val"),
393 | transform=transforms.Compose(
394 | [
395 | transforms.Resize((266, 266)),
396 | transforms.CenterCrop((224, 224)),
397 | transforms.ToTensor(),
398 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
399 | ]
400 | ),
401 | target_transform=target_transform,
402 | )
403 |
404 | val_dataset = CacheClassLabel(
405 | val_dataset,
406 | target_transform=target_transform,
407 | )
408 |
409 | train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
410 | data = next(iter(train_loader))
411 | fast_imagenet_train = FastDataset(data[0], data[1])
412 | torch.save(fast_imagenet_train, save_path)
413 |
414 | save_path = f"{dataroot}/fast_imagenet100224_val"
415 | if os.path.exists(save_path):
416 | fast_imagenet_val = torch.load(save_path)
417 | else:
418 | val_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
419 | data = next(iter(val_loader))
420 | fast_imagenet_val = FastDataset(data[0], data[1])
421 | torch.save(fast_imagenet_val, save_path)
422 |
423 | return (
424 | fast_imagenet_train,
425 | fast_imagenet_val,
426 | 224,
427 | 3,
428 | train_transform_clf,
429 | train_transform_diff,
430 | 100,
431 | )
432 |
433 | def Flowers102(
434 | dataroot, skip_normalization=False, train_aug=False, classifier_augmentation=False
435 | ):
436 | train_transform_clf = None
437 | train_transform_diff = None
438 | # augmentation for diffusion training
439 | if train_aug:
440 | train_transform_diff = K.augmentation.ImageSequential(
441 | K.augmentation.RandomHorizontalFlip(),
442 | )
443 |
444 | # augmentation for classifier training
445 | if classifier_augmentation:
446 | train_transform_clf = K.augmentation.ImageSequential(
447 | K.augmentation.Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
448 | K.augmentation.RandomRotation(30),
449 | K.augmentation.RandomHorizontalFlip(),
450 | K.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1),
451 | K.augmentation.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
452 | )
453 |
454 | target_transform = transforms.Lambda(lambda y: torch.eye(102)[y])
455 |
456 | train_dataset1 = torchvision.datasets.Flowers102(
457 | root=dataroot,
458 | split="train",
459 | download=True,
460 | transform=transforms.Compose(
461 | [
462 | transforms.Resize((224, 224)),
463 | transforms.ToTensor(),
464 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
465 | ]
466 | ),
467 | target_transform=target_transform,
468 | )
469 | train_dataset2 = torchvision.datasets.Flowers102(
470 | root=dataroot,
471 | split="val",
472 | download=True,
473 | transform=transforms.Compose(
474 | [
475 | transforms.Resize((224, 224)),
476 | transforms.ToTensor(),
477 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
478 | ]
479 | ),
480 | target_transform=target_transform,
481 | )
482 | train_dataset = torch.utils.data.ConcatDataset([train_dataset1, train_dataset2])
483 | train_dataset.root = dataroot
484 | train_dataset = CacheClassLabel(
485 | train_dataset,
486 | target_transform=target_transform,
487 | )
488 | val_dataset = torchvision.datasets.Flowers102(
489 | root=dataroot,
490 | split="test",
491 | download=True,
492 | transform=transforms.Compose(
493 | [
494 | transforms.Resize((224, 224)),
495 | transforms.ToTensor(),
496 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
497 | ]
498 | ),
499 | target_transform=target_transform,
500 | )
501 | val_dataset = CacheClassLabel(
502 | val_dataset,
503 | target_transform=target_transform,
504 | )
505 |
506 | print("Loading data")
507 | save_path = f"{dataroot}/fast_flowers_train"
508 | if os.path.exists(save_path):
509 | fast_flowers_train = torch.load(save_path)
510 | else:
511 | train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
512 | data = next(iter(train_loader))
513 | fast_flowers_train = FastDataset(data[0], data[1])
514 | torch.save(fast_flowers_train, save_path)
515 |
516 | save_path = f"{dataroot}/fast_flowers_val"
517 | if os.path.exists(save_path):
518 | fast_flowers_val = torch.load(save_path)
519 | else:
520 | val_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
521 | data = next(iter(val_loader))
522 | fast_flowers_val = FastDataset(data[0], data[1])
523 | torch.save(fast_flowers_val, save_path)
524 |
525 | return (
526 | fast_flowers_train,
527 | fast_flowers_val,
528 | 224,
529 | 3,
530 | train_transform_clf,
531 | train_transform_diff,
532 | 102,
533 | )
534 |
--------------------------------------------------------------------------------
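Every factory in this file returns the same 7-tuple (train set, val set, image side, channels, classifier augmentation, diffusion augmentation, number of classes). A sketch of how CIFAR10 might be consumed, assuming a writable ./data directory and that downloading and caching the full dataset in memory on first use is acceptable:

from dataloaders.base import CIFAR10

train_set, val_set, image_size, channels, clf_aug, diff_aug, n_classes = CIFAR10(
    dataroot="./data", train_aug=True, classifier_augmentation=True
)
print(image_size, channels, n_classes)  # 32 3 10
x, y = train_set[0]                     # normalized image tensor and one-hot label
print(x.shape, y.shape)                 # torch.Size([3, 32, 32]) torch.Size([10])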
/guide/evaluator.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import io
3 | import os
4 | import random
5 | import warnings
6 | import zipfile
7 | from abc import ABC, abstractmethod
8 | from contextlib import contextmanager
9 | from functools import partial
10 | from multiprocessing import cpu_count
11 | from multiprocessing.pool import ThreadPool
12 | from typing import Iterable, Optional, Tuple
13 |
14 | import numpy as np
15 | import requests
16 | import tensorflow.compat.v1 as tf
17 | from scipy import linalg
18 | from tqdm.auto import tqdm
19 |
20 | INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
21 | INCEPTION_V3_PATH = "classify_image_graph_def.pb"
22 |
23 | FID_POOL_NAME = "pool_3:0"
24 | FID_SPATIAL_NAME = "mixed_6/conv:0"
25 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
26 |
27 |
28 | def main():
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument(
31 | "--ref_batch", help="path to reference batch npz acts file", required=False
32 | )
33 | parser.add_argument(
34 | "--ref_acts", help="path to reference batch npz acts file", required=False
35 | )
36 | parser.add_argument("--sample_batch", help="path to sample batch npz file")
37 | args = parser.parse_args()
38 |
39 | config = tf.ConfigProto(
40 | allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
41 | )
42 | config.gpu_options.allow_growth = True
43 | evaluator = Evaluator(tf.Session(config=config))
44 |
45 | print("warming up TensorFlow...")
46 | # This will cause TF to print a bunch of verbose stuff now rather
47 | # than after the next print(), to help prevent confusion.
48 | evaluator.warmup()
49 |
50 | print("computing reference batch activations...")
51 | if args.ref_acts is None:
52 | ref_acts = evaluator.read_activations(args.ref_batch)
53 | # ref_acts = evaluator.read_activations(args.ref_batch)
54 | else:
55 | ref_acts = (
56 | np.load(args.ref_acts)["arr_0"],
57 | np.load(args.ref_acts)["arr_1"],
58 | )
59 |
60 | print("computing/reading reference batch statistics...")
61 | ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
62 |
63 | print("computing sample batch activations...")
64 | sample_acts = evaluator.read_activations(args.sample_batch)
65 | print("computing/reading sample batch statistics...")
66 | sample_stats, sample_stats_spatial = evaluator.read_statistics(
67 | args.sample_batch, sample_acts
68 | )
69 |
70 | print("Computing evaluations...")
71 | print("Inception Score:", evaluator.compute_inception_score(sample_acts[0]))
72 | print("FID:", sample_stats.frechet_distance(ref_stats))
73 | print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial))
74 | prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
75 | print("Precision:", prec)
76 | print("Recall:", recall)
77 |
78 |
79 | class InvalidFIDException(Exception):
80 | pass
81 |
82 |
83 | class FIDStatistics:
84 | def __init__(self, mu: np.ndarray, sigma: np.ndarray):
85 | self.mu = mu
86 | self.sigma = sigma
87 |
88 | def frechet_distance(self, other, eps=1e-6):
89 | """
90 | Compute the Frechet distance between two sets of statistics.
91 | """
92 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
93 | mu1, sigma1 = self.mu, self.sigma
94 | mu2, sigma2 = other.mu, other.sigma
95 |
96 | mu1 = np.atleast_1d(mu1)
97 | mu2 = np.atleast_1d(mu2)
98 |
99 | sigma1 = np.atleast_2d(sigma1)
100 | sigma2 = np.atleast_2d(sigma2)
101 |
102 | assert (
103 | mu1.shape == mu2.shape
104 | ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
105 | assert (
106 | sigma1.shape == sigma2.shape
107 | ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
108 |
109 | diff = mu1 - mu2
110 |
111 | # product might be almost singular
112 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
113 | if not np.isfinite(covmean).all():
114 | msg = (
115 | "fid calculation produces singular product; adding %s to diagonal of cov estimates"
116 | % eps
117 | )
118 | warnings.warn(msg)
119 | offset = np.eye(sigma1.shape[0]) * eps
120 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
121 |
122 | # numerical error might give slight imaginary component
123 | if np.iscomplexobj(covmean):
124 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
125 | m = np.max(np.abs(covmean.imag))
126 | raise ValueError("Imaginary component {}".format(m))
127 | covmean = covmean.real
128 |
129 | tr_covmean = np.trace(covmean)
130 |
131 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
132 |
133 |
134 | class Evaluator:
135 | def __init__(
136 | self,
137 | session,
138 | batch_size=64,
139 | softmax_batch_size=512,
140 | ):
141 | self.sess = session
142 | self.batch_size = batch_size
143 | self.softmax_batch_size = softmax_batch_size
144 | self.manifold_estimator = ManifoldEstimator(session)
145 | with self.sess.graph.as_default():
146 | self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
147 | self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
148 | self.pool_features, self.spatial_features = _create_feature_graph(
149 | self.image_input
150 | )
151 | self.softmax = _create_softmax_graph(self.softmax_input)
152 |
153 | def warmup(self):
154 | self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
155 |
156 | def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
157 | with open_npz_array(npz_path, "arr_0") as reader:
158 | return self.compute_activations(reader.read_batches(self.batch_size))
159 |
160 | def compute_activations(
161 | self, batches: Iterable[np.ndarray]
162 | ) -> Tuple[np.ndarray, np.ndarray]:
163 | """
164 | Compute image features for downstream evals.
165 |
166 |         :param batches: an iterator over NHWC numpy arrays in [0, 255].
167 | :return: a tuple of numpy arrays of shape [N x X], where X is a feature
168 | dimension. The tuple is (pool_3, spatial).
169 | """
170 | preds = []
171 | spatial_preds = []
172 | for batch in tqdm(batches):
173 | batch = batch.astype(np.float32)
174 | pred, spatial_pred = self.sess.run(
175 | [self.pool_features, self.spatial_features], {self.image_input: batch}
176 | )
177 | preds.append(pred.reshape([pred.shape[0], -1]))
178 | spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
179 | return (
180 | np.concatenate(preds, axis=0),
181 | np.concatenate(spatial_preds, axis=0),
182 | )
183 |
184 | def read_statistics(
185 | self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
186 | ) -> Tuple[FIDStatistics, FIDStatistics]:
187 | obj = np.load(npz_path)
188 | if "mu" in list(obj.keys()):
189 | return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
190 | obj["mu_s"], obj["sigma_s"]
191 | )
192 | return tuple(self.compute_statistics(x) for x in activations)
193 |
194 | def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
195 | mu = np.mean(activations, axis=0)
196 | sigma = np.cov(activations, rowvar=False)
197 | return FIDStatistics(mu, sigma)
198 |
199 | def compute_inception_score(
200 | self, activations: np.ndarray, split_size: int = 5000
201 | ) -> float:
202 | softmax_out = []
203 | for i in range(0, len(activations), self.softmax_batch_size):
204 | acts = activations[i : i + self.softmax_batch_size]
205 | softmax_out.append(
206 | self.sess.run(self.softmax, feed_dict={self.softmax_input: acts})
207 | )
208 | preds = np.concatenate(softmax_out, axis=0)
209 | # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
210 | scores = []
211 | for i in range(0, len(preds), split_size):
212 | part = preds[i : i + split_size]
213 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
214 | kl = np.mean(np.sum(kl, 1))
215 | scores.append(np.exp(kl))
216 | return float(np.mean(scores))
217 |
218 | def compute_prec_recall(
219 | self, activations_ref: np.ndarray, activations_sample: np.ndarray
220 | ) -> Tuple[float, float]:
221 | radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
222 | radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
223 | pr = self.manifold_estimator.evaluate_pr(
224 | activations_ref, radii_1, activations_sample, radii_2
225 | )
226 | return (float(pr[0][0]), float(pr[1][0]))
227 |
228 |
229 | class ManifoldEstimator:
230 | """
231 | A helper for comparing manifolds of feature vectors.
232 |
233 | Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
234 | """
235 |
236 | def __init__(
237 | self,
238 | session,
239 | row_batch_size=10000,
240 | col_batch_size=10000,
241 | nhood_sizes=(3,),
242 | clamp_to_percentile=None,
243 | eps=1e-5,
244 | ):
245 | """
246 | Estimate the manifold of given feature vectors.
247 |
248 | :param session: the TensorFlow session.
249 | :param row_batch_size: row batch size to compute pairwise distances
250 |             (parameter to trade off memory usage against performance).
251 | :param col_batch_size: column batch size to compute pairwise distances.
252 | :param nhood_sizes: number of neighbors used to estimate the manifold.
253 | :param clamp_to_percentile: prune hyperspheres that have radius larger than
254 | the given percentile.
255 | :param eps: small number for numerical stability.
256 | """
257 | self.distance_block = DistanceBlock(session)
258 | self.row_batch_size = row_batch_size
259 | self.col_batch_size = col_batch_size
260 | self.nhood_sizes = nhood_sizes
261 | self.num_nhoods = len(nhood_sizes)
262 | self.clamp_to_percentile = clamp_to_percentile
263 | self.eps = eps
264 |
265 | def warmup(self):
266 | feats, radii = (
267 | np.zeros([1, 2048], dtype=np.float32),
268 | np.zeros([1, 1], dtype=np.float32),
269 | )
270 | self.evaluate_pr(feats, radii, feats, radii)
271 |
272 | def manifold_radii(self, features: np.ndarray) -> np.ndarray:
273 | num_images = len(features)
274 |
275 | # Estimate manifold of features by calculating distances to k-NN of each sample.
276 | radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
277 | distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
278 | seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
279 |
280 | for begin1 in range(0, num_images, self.row_batch_size):
281 | end1 = min(begin1 + self.row_batch_size, num_images)
282 | row_batch = features[begin1:end1]
283 |
284 | for begin2 in range(0, num_images, self.col_batch_size):
285 | end2 = min(begin2 + self.col_batch_size, num_images)
286 | col_batch = features[begin2:end2]
287 |
288 | # Compute distances between batches.
289 | distance_batch[0 : end1 - begin1, begin2:end2] = (
290 | self.distance_block.pairwise_distances(row_batch, col_batch)
291 | )
292 |
293 | # Find the k-nearest neighbor from the current batch.
294 | radii[begin1:end1, :] = np.concatenate(
295 | [
296 | x[:, self.nhood_sizes]
297 | for x in _numpy_partition(
298 | distance_batch[0 : end1 - begin1, :], seq, axis=1
299 | )
300 | ],
301 | axis=0,
302 | )
303 |
304 | if self.clamp_to_percentile is not None:
305 | max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
306 | radii[radii > max_distances] = 0
307 | return radii
308 |
309 | def evaluate(
310 | self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray
311 | ):
312 | """
313 |         Evaluate whether new feature vectors lie on the estimated manifold.
314 | """
315 | num_eval_images = eval_features.shape[0]
316 | num_ref_images = radii.shape[0]
317 | distance_batch = np.zeros(
318 | [self.row_batch_size, num_ref_images], dtype=np.float32
319 | )
320 | batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
321 | max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
322 | nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
323 |
324 | for begin1 in range(0, num_eval_images, self.row_batch_size):
325 | end1 = min(begin1 + self.row_batch_size, num_eval_images)
326 | feature_batch = eval_features[begin1:end1]
327 |
328 | for begin2 in range(0, num_ref_images, self.col_batch_size):
329 | end2 = min(begin2 + self.col_batch_size, num_ref_images)
330 | ref_batch = features[begin2:end2]
331 |
332 | distance_batch[0 : end1 - begin1, begin2:end2] = (
333 | self.distance_block.pairwise_distances(feature_batch, ref_batch)
334 | )
335 |
336 | # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
337 | # If a feature vector is inside a hypersphere of some reference sample, then
338 |             # the new sample lies on the estimated manifold.
339 | # The radii of the hyperspheres are determined from distances of neighborhood size k.
340 | samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
341 | batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(
342 | np.int32
343 | )
344 |
345 | max_realism_score[begin1:end1] = np.max(
346 | radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
347 | )
348 | nearest_indices[begin1:end1] = np.argmin(
349 | distance_batch[0 : end1 - begin1, :], axis=1
350 | )
351 |
352 | return {
353 | "fraction": float(np.mean(batch_predictions)),
354 | "batch_predictions": batch_predictions,
355 | "max_realisim_score": max_realism_score,
356 | "nearest_indices": nearest_indices,
357 | }
358 |
359 | def evaluate_pr(
360 | self,
361 | features_1: np.ndarray,
362 | radii_1: np.ndarray,
363 | features_2: np.ndarray,
364 | radii_2: np.ndarray,
365 | ) -> Tuple[np.ndarray, np.ndarray]:
366 | """
367 | Evaluate precision and recall efficiently.
368 |
369 | :param features_1: [N1 x D] feature vectors for reference batch.
370 | :param radii_1: [N1 x K1] radii for reference vectors.
371 | :param features_2: [N2 x D] feature vectors for the other batch.
372 |         :param radii_2: [N2 x K2] radii for the other vectors.
373 | :return: a tuple of arrays for (precision, recall):
374 | - precision: an np.ndarray of length K1
375 | - recall: an np.ndarray of length K2
376 | """
377 | features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool)
378 | features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool)
379 | for begin_1 in range(0, len(features_1), self.row_batch_size):
380 | end_1 = begin_1 + self.row_batch_size
381 | batch_1 = features_1[begin_1:end_1]
382 | for begin_2 in range(0, len(features_2), self.col_batch_size):
383 | end_2 = begin_2 + self.col_batch_size
384 | batch_2 = features_2[begin_2:end_2]
385 | batch_1_in, batch_2_in = self.distance_block.less_thans(
386 | batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
387 | )
388 | features_1_status[begin_1:end_1] |= batch_1_in
389 | features_2_status[begin_2:end_2] |= batch_2_in
390 | return (
391 | np.mean(features_2_status.astype(np.float64), axis=0),
392 | np.mean(features_1_status.astype(np.float64), axis=0),
393 | )
394 |
395 |
396 | class DistanceBlock:
397 | """
398 | Calculate pairwise distances between vectors.
399 |
400 | Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
401 | """
402 |
403 | def __init__(self, session):
404 | self.session = session
405 |
406 | # Initialize TF graph to calculate pairwise distances.
407 | with session.graph.as_default():
408 | self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
409 | self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
410 | distance_block_16 = _batch_pairwise_distances(
411 | tf.cast(self._features_batch1, tf.float16),
412 | tf.cast(self._features_batch2, tf.float16),
413 | )
414 | self.distance_block = tf.cond(
415 | tf.reduce_all(tf.math.is_finite(distance_block_16)),
416 | lambda: tf.cast(distance_block_16, tf.float32),
417 | lambda: _batch_pairwise_distances(
418 | self._features_batch1, self._features_batch2
419 | ),
420 | )
421 |
422 | # Extra logic for less thans.
423 | self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
424 | self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
425 | dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
426 | self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
427 | self._batch_2_in = tf.math.reduce_any(
428 | dist32 <= self._radii1[:, None], axis=0
429 | )
430 |
431 | def pairwise_distances(self, U, V):
432 | """
433 | Evaluate pairwise distances between two batches of feature vectors.
434 | """
435 | return self.session.run(
436 | self.distance_block,
437 | feed_dict={self._features_batch1: U, self._features_batch2: V},
438 | )
439 |
440 | def less_thans(self, batch_1, radii_1, batch_2, radii_2):
441 | return self.session.run(
442 | [self._batch_1_in, self._batch_2_in],
443 | feed_dict={
444 | self._features_batch1: batch_1,
445 | self._features_batch2: batch_2,
446 | self._radii1: radii_1,
447 | self._radii2: radii_2,
448 | },
449 | )
450 |
451 |
452 | def _batch_pairwise_distances(U, V):
453 | """
454 | Compute pairwise distances between two batches of feature vectors.
455 | """
456 | with tf.variable_scope("pairwise_dist_block"):
457 | # Squared norms of each row in U and V.
458 | norm_u = tf.reduce_sum(tf.square(U), 1)
459 | norm_v = tf.reduce_sum(tf.square(V), 1)
460 |
461 | # norm_u as a column and norm_v as a row vectors.
462 | norm_u = tf.reshape(norm_u, [-1, 1])
463 | norm_v = tf.reshape(norm_v, [1, -1])
464 |
465 | # Pairwise squared Euclidean distances.
466 | D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
467 |
468 | return D
469 |
470 |
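# --- Editor's sketch (not part of the original file): the expansion used in
# _batch_pairwise_distances is ||u - v||^2 = ||u||^2 - 2<u, v> + ||v||^2; the
# helper below checks that identity in NumPy and is never called here.
def _check_pairwise_distance_identity(U, V):
    direct = ((U[:, None, :] - V[None, :, :]) ** 2).sum(-1)
    expanded = (U**2).sum(1)[:, None] - 2 * U @ V.T + (V**2).sum(1)[None, :]
    return np.allclose(direct, expanded)
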
471 | class NpzArrayReader(ABC):
472 |     @abstractmethod
473 |     def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
474 |         pass
475 |
476 |     @abstractmethod
477 |     def remaining(self) -> int:
478 |         pass
479 |
480 |     def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
481 |         def gen_fn():
482 |             while True:
483 |                 batch = self.read_batch(batch_size)
484 |                 if batch is None:
485 |                     break
486 |                 yield batch
487 |
488 |         rem = self.remaining()
489 |         num_batches = rem // batch_size + int(rem % batch_size != 0)
490 |         return BatchIterator(gen_fn, num_batches)
491 |
492 |
493 | class BatchIterator:
494 |     def __init__(self, gen_fn, length):
495 |         self.gen_fn = gen_fn
496 |         self.length = length
497 |
498 |     def __len__(self):
499 |         return self.length
500 |
501 |     def __iter__(self):
502 |         return self.gen_fn()
503 |
504 |
505 | class StreamingNpzArrayReader(NpzArrayReader):
506 |     def __init__(self, arr_f, shape, dtype):
507 |         self.arr_f = arr_f
508 |         self.shape = shape
509 |         self.dtype = dtype
510 |         self.idx = 0
511 |
512 |     def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
513 |         if self.idx >= self.shape[0]:
514 |             return None
515 |
516 |         bs = min(batch_size, self.shape[0] - self.idx)
517 |         self.idx += bs
518 |
519 |         if self.dtype.itemsize == 0:
520 |             return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
521 |
522 |         read_count = bs * np.prod(self.shape[1:])
523 |         read_size = int(read_count * self.dtype.itemsize)
524 |         data = _read_bytes(self.arr_f, read_size, "array data")
525 |         return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
526 |
527 |     def remaining(self) -> int:
528 |         return max(0, self.shape[0] - self.idx)
529 |
530 |
531 | class MemoryNpzArrayReader(NpzArrayReader):
532 |     def __init__(self, arr):
533 |         self.arr = arr
534 |         self.idx = 0
535 |
536 |     @classmethod
537 |     def load(cls, path: str, arr_name: str):
538 |         with open(path, "rb") as f:
539 |             arr = np.load(f)[arr_name]
540 |         return cls(arr)
541 |
542 |     def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
543 |         if self.idx >= self.arr.shape[0]:
544 |             return None
545 |
546 |         res = self.arr[self.idx : self.idx + batch_size]
547 |         self.idx += batch_size
548 |         return res
549 |
550 |     def remaining(self) -> int:
551 |         return max(0, self.arr.shape[0] - self.idx)
552 |
553 |
554 | @contextmanager
555 | def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
556 |     with _open_npy_file(path, arr_name) as arr_f:
557 |         version = np.lib.format.read_magic(arr_f)
558 |         if version == (1, 0):
559 |             header = np.lib.format.read_array_header_1_0(arr_f)
560 |         elif version == (2, 0):
561 |             header = np.lib.format.read_array_header_2_0(arr_f)
562 |         else:
563 |             yield MemoryNpzArrayReader.load(path, arr_name)
564 |             return
565 |         shape, fortran, dtype = header
566 |         if fortran or dtype.hasobject:
567 |             yield MemoryNpzArrayReader.load(path, arr_name)
568 |         else:
569 |             yield StreamingNpzArrayReader(arr_f, shape, dtype)
570 |
571 |
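# --- Editor's usage sketch (not part of the original file): how the reader
# API above is meant to be consumed. "samples.npz" and the array name "arr_0"
# are hypothetical.
#
#     with open_npz_array("samples.npz", "arr_0") as reader:
#         for batch in reader.read_batches(64):
#             ...  # each batch is an np.ndarray of at most 64 rows
#
# Streaming is used when the .npy header permits it; otherwise the whole
# array is loaded into memory via MemoryNpzArrayReader.
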
572 | def _read_bytes(fp, size, error_template="ran out of data"):
573 | """
574 | Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
575 |
576 | Read from file-like object until size bytes are read.
577 | Raises ValueError if not EOF is encountered before size bytes are read.
578 | Non-blocking objects only supported if they derive from io objects.
579 | Required as e.g. ZipExtFile in python 2.6 can return less data than
580 | requested.
581 | """
582 | data = bytes()
583 | while True:
584 | # io files (default in python3) return None or raise on
585 | # would-block, python2 file will truncate, probably nothing can be
586 | # done about that. note that regular files can't be non-blocking
587 | try:
588 | r = fp.read(size - len(data))
589 | data += r
590 | if len(r) == 0 or len(data) == size:
591 | break
592 | except io.BlockingIOError:
593 | pass
594 | if len(data) != size:
595 | msg = "EOF: reading %s, expected %d bytes got %d"
596 | raise ValueError(msg % (error_template, size, len(data)))
597 | else:
598 | return data
599 |
600 |
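# --- Editor's note (not part of the original file): _read_bytes loops because
# file-like objects such as ZipExtFile may return short reads. A hypothetical
# call reading a 16-byte chunk from an open .npy member would be:
#
#     header_bytes = _read_bytes(arr_f, 16, "npy header")
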
601 | @contextmanager
602 | def _open_npy_file(path: str, arr_name: str):
603 |     with open(path, "rb") as f:
604 |         with zipfile.ZipFile(f, "r") as zip_f:
605 |             if f"{arr_name}.npy" not in zip_f.namelist():
606 |                 raise ValueError(f"missing {arr_name} in npz file")
607 |             with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
608 |                 yield arr_f
609 |
610 |
611 | def _download_inception_model():
612 |     if os.path.exists(INCEPTION_V3_PATH):
613 |         return
614 |     print("downloading InceptionV3 model...")
615 |     with requests.get(INCEPTION_V3_URL, stream=True) as r:
616 |         r.raise_for_status()
617 |         tmp_path = INCEPTION_V3_PATH + ".tmp"
618 |         with open(tmp_path, "wb") as f:
619 |             for chunk in tqdm(r.iter_content(chunk_size=8192)):
620 |                 f.write(chunk)
621 |         os.rename(tmp_path, INCEPTION_V3_PATH)
622 |
623 |
624 | def _create_feature_graph(input_batch):
625 |     _download_inception_model()
626 |     prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
627 |     with open(INCEPTION_V3_PATH, "rb") as f:
628 |         graph_def = tf.GraphDef()
629 |         graph_def.ParseFromString(f.read())
630 |     pool3, spatial = tf.import_graph_def(
631 |         graph_def,
632 |         input_map={"ExpandDims:0": input_batch},
633 |         return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
634 |         name=prefix,
635 |     )
636 |     _update_shapes(pool3)
637 |     spatial = spatial[..., :7]
638 |     return pool3, spatial
639 |
640 |
641 | def _create_softmax_graph(input_batch):
642 |     _download_inception_model()
643 |     prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
644 |     with open(INCEPTION_V3_PATH, "rb") as f:
645 |         graph_def = tf.GraphDef()
646 |         graph_def.ParseFromString(f.read())
647 |     (matmul,) = tf.import_graph_def(
648 |         graph_def, return_elements=["softmax/logits/MatMul"], name=prefix
649 |     )
650 |     w = matmul.inputs[1]
651 |     logits = tf.matmul(input_batch, w)
652 |     return tf.nn.softmax(logits)
653 |
654 |
655 | def _update_shapes(pool3):
656 |     # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
657 |     ops = pool3.graph.get_operations()
658 |     for op in ops:
659 |         for o in op.outputs:
660 |             shape = o.get_shape()
661 |             if shape._dims is not None:  # pylint: disable=protected-access
662 |                 # shape = [s.value for s in shape]  # TF 1.x
663 |                 shape = [s for s in shape]  # TF 2.x
664 |                 new_shape = []
665 |                 for j, s in enumerate(shape):
666 |                     if s == 1 and j == 0:
667 |                         new_shape.append(None)
668 |                     else:
669 |                         new_shape.append(s)
670 |                 o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
671 |     return pool3
672 |
673 |
674 | def _numpy_partition(arr, kth, **kwargs):
675 |     num_workers = min(cpu_count(), len(arr))
676 |     chunk_size = len(arr) // num_workers
677 |     extra = len(arr) % num_workers
678 |
679 |     start_idx = 0
680 |     batches = []
681 |     for i in range(num_workers):
682 |         size = chunk_size + (1 if i < extra else 0)
683 |         batches.append(arr[start_idx : start_idx + size])
684 |         start_idx += size
685 |
686 |     with ThreadPool(num_workers) as pool:
687 |         return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
688 |
689 |
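# --- Editor's usage sketch (not part of the original file): _numpy_partition
# splits the rows of `arr` into chunks and applies np.partition to each chunk
# in a thread pool. For a hypothetical [N x M] array of squared distances, the
# value at sorted position k of every row could be read off as:
#
#     parts = _numpy_partition(dist_rows, kth=k, axis=-1)
#     radii = np.concatenate(parts, axis=0)[:, k]
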
690 | if __name__ == "__main__":
691 |     main()
692 |
--------------------------------------------------------------------------------