├── .gitignore
├── cifar10_experiment
│   ├── TorchSSL
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── DistributedProxySampler.py
│   │   │   ├── dataset.py
│   │   │   ├── augmentation
│   │   │   │   └── randaugment.py
│   │   │   └── data_utils.py
│   │   ├── models
│   │   │   ├── nets
│   │   │   │   ├── __init__.py
│   │   │   │   ├── wrn.py
│   │   │   │   ├── wrn_var.py
│   │   │   │   └── resnet50.py
│   │   │   └── freematch
│   │   │       ├── __init__.py
│   │   │       └── freematch_utils.py
│   │   ├── config
│   │   │   ├── freematch_cifar10_40_1.yaml
│   │   │   └── freematch_cifar10_40_1_aug1000.yaml
│   │   ├── pseudo_dataset.py
│   │   ├── eval.py
│   │   ├── utils.py
│   │   ├── get_labels.py
│   │   └── custom_writer.py
│   ├── edm
│   │   ├── training
│   │   │   ├── __init__.py
│   │   │   ├── loss.py
│   │   │   └── dataset.py
│   │   ├── torch_utils
│   │   │   ├── __init__.py
│   │   │   └── distributed.py
│   │   ├── dnnlib
│   │   │   └── __init__.py
│   │   ├── generate.sh
│   │   └── fid.py
│   └── README.md
├── libs
│   ├── __init__.py
│   ├── clip.py
│   ├── timm.py
│   ├── uvit_t2i.py
│   └── uvit.py
├── dpt.png
├── idx_to_class.pkl
├── scripts
│   ├── convert_fid_stats.py
│   ├── extract_empty_feature.py
│   ├── extract_test_prompt_feature.py
│   ├── extract_imagenet_feature.py
│   ├── extract_imagenet_features.py
│   ├── extract_mscoco_feature.py
│   └── sweep_sample.py
├── LICENSE
├── src
│   ├── sgd.py
│   ├── losses.py
│   ├── utils.py
│   └── data_manager.py
├── configs
│   ├── accelerate_b4_subset1_2img_k_128_large.py
│   ├── accelerate_l7_subset2_zimg_k_128_large.py
│   └── accelerate_b4_subset1_2img_k_128_huge.py
├── extract_imagenet_features_semi.py
├── grid_sample.py
├── utils.py
├── sample_ldm_all.py
├── sample_ldm_discrete_all.py
├── sde.py
└── tools
    └── fid_score.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__*
2 | *cluster*/
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/libs/__init__.py:
--------------------------------------------------------------------------------
1 | # codes from third party
2 |
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/models/nets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/models/freematch/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/dpt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-GSAI/DPT/HEAD/dpt.png
--------------------------------------------------------------------------------
/idx_to_class.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-GSAI/DPT/HEAD/idx_to_class.pkl
--------------------------------------------------------------------------------
/cifar10_experiment/edm/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/cifar10_experiment/edm/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/cifar10_experiment/edm/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | from .util import EasyDict, make_cache_dir_path
9 |
--------------------------------------------------------------------------------
/scripts/convert_fid_stats.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def convert(resolution):
5 | # download ImageNet reference batch from https://github.com/openai/guided-diffusion/tree/main/evaluations
6 | if resolution == 512:
7 | obj = np.load(f'VIRTUAL_imagenet{resolution}.npz')
8 | else:
9 | obj = np.load(f'VIRTUAL_imagenet{resolution}_labeled.npz')
10 | np.savez(f'fid_stats_imagenet{resolution}_guided_diffusion.npz', mu=obj['mu'], sigma=obj['sigma'])
11 |
12 |
13 | convert(resolution=512)
14 |
--------------------------------------------------------------------------------
/scripts/extract_empty_feature.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | import libs.autoencoder
5 | import libs.clip
6 | from datasets import MSCOCODatabase
7 | import argparse
8 | from tqdm import tqdm
9 |
10 |
11 | def main():
12 | prompts = [
13 | '',
14 | ]
15 |
16 | device = 'cuda'
17 | clip = libs.clip.FrozenCLIPEmbedder()
18 | clip.eval()
19 | clip.to(device)
20 |
21 |     save_dir = 'assets/datasets/coco256_features'
22 |     os.makedirs(save_dir, exist_ok=True)  # the directory may not exist yet
23 |     latent = clip.encode(prompts)
24 |     print(latent.shape)
25 |     c = latent[0].detach().cpu().numpy()
26 |     np.save(os.path.join(save_dir, 'empty_context.npy'), c)
27 | 
28 | 
29 | if __name__ == '__main__':
30 |     main()
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/config/freematch_cifar10_40_1.yaml:
--------------------------------------------------------------------------------
1 | save_dir: ./saved_models
2 | save_name: freematch_cifar10_40_1
3 | resume: False
4 | load_path: None
5 | overwrite: True
6 | use_tensorboard: True
7 | epoch: 1
8 | num_train_iter: 1048576
9 | num_eval_iter: 5000
10 | num_labels: 40
11 | batch_size: 64
12 | eval_batch_size: 1024
13 | hard_label: True
14 | T: 0.5
15 | ulb_loss_ratio: 1.0
16 | ent_loss_ratio: 0.0
17 | uratio: 7
18 | ema_m: 0.999
19 | optim: SGD
20 | lr: 0.03
21 | momentum: 0.9
22 | weight_decay: 0.0005
23 | amp: False
24 | net: WideResNet
25 | net_from_name: False
26 | depth: 28
27 | widen_factor: 2
28 | leaky_slope: 0.1
29 | dropout: 0.0
30 | data_dir: ./data
31 | dataset: cifar10
32 | train_sampler: RandomSampler
33 | num_classes: 10
34 | num_workers: 1
35 | alg: freematch
36 | seed: 1
37 | world_size: 1
38 | rank: 0
39 | multiprocessing_distributed: True
40 | dist_url: tcp://127.0.0.1:10043
41 | dist_backend: nccl
42 | gpu: None
43 |
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/config/freematch_cifar10_40_1_aug1000.yaml:
--------------------------------------------------------------------------------
1 | save_dir: ./saved_models
2 | save_name: freematch_cifar10_40_1_aug1000
3 | resume: False
4 | load_path: None
5 | overwrite: True
6 | use_tensorboard: True
7 | epoch: 1
8 | num_train_iter: 1048576
9 | num_eval_iter: 5000
10 | num_labels: 40
11 | batch_size: 64
12 | eval_batch_size: 1024
13 | hard_label: True
14 | T: 0.5
15 | ulb_loss_ratio: 1.0
16 | ent_loss_ratio: 0.0
17 | uratio: 7
18 | ema_m: 0.999
19 | optim: SGD
20 | lr: 0.03
21 | momentum: 0.9
22 | weight_decay: 0.0005
23 | amp: False
24 | net: WideResNet
25 | net_from_name: False
26 | depth: 28
27 | widen_factor: 2
28 | leaky_slope: 0.1
29 | dropout: 0.0
30 | data_dir: ./data
31 | dataset: cifar10
32 | train_sampler: RandomSampler
33 | num_classes: 10
34 | num_workers: 1
35 | alg: freematch
36 | seed: 1
37 | world_size: 1
38 | rank: 0
39 | multiprocessing_distributed: True
40 | dist_url: tcp://127.0.0.1:10043
41 | dist_backend: nccl
42 | gpu: None
43 |
--------------------------------------------------------------------------------
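
Both YAML configs above are flat key/value maps. A minimal sketch, assuming it is run from cifar10_experiment/TorchSSL and with the --c flag name chosen purely for illustration, of merging such a file into argparse args with over_write_args_from_file from utils.py:

import argparse

from utils import over_write_args_from_file  # defined in TorchSSL/utils.py in this repo

parser = argparse.ArgumentParser()
parser.add_argument('--c', type=str, default='config/freematch_cifar10_40_1.yaml')
args = parser.parse_args()

over_write_args_from_file(args, args.c)   # every YAML key becomes an attribute on args
print(args.alg, args.num_labels, args.lr)  # freematch 40 0.03
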
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 yyyouy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/cifar10_experiment/edm/generate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | num=$1
3 | save_dir=$2
4 | model=$3
5 | nproc_per_node=$4
6 |
7 | torchrun --standalone --nproc_per_node=$nproc_per_node generate.py --outdir=$save_dir/0 --seeds=0-$num --class=0 --network=$model
8 | torchrun --standalone --nproc_per_node=$nproc_per_node generate.py --outdir=$save_dir/1 --seeds=0-$num --class=1 --network=$model
9 | torchrun --standalone --nproc_per_node=$nproc_per_node generate.py --outdir=$save_dir/2 --seeds=0-$num --class=2 --network=$model
10 | torchrun --standalone --nproc_per_node=$nproc_per_node generate.py --outdir=$save_dir/3 --seeds=0-$num --class=3 --network=$model
11 | torchrun --standalone --nproc_per_node=$nproc_per_node generate.py --outdir=$save_dir/4 --seeds=0-$num --class=4 --network=$model
12 | torchrun --standalone --nproc_per_node=$nproc_per_node generate.py --outdir=$save_dir/5 --seeds=0-$num --class=5 --network=$model
13 | torchrun --standalone --nproc_per_node=$nproc_per_node generate.py --outdir=$save_dir/6 --seeds=0-$num --class=6 --network=$model
14 | torchrun --standalone --nproc_per_node=$nproc_per_node generate.py --outdir=$save_dir/7 --seeds=0-$num --class=7 --network=$model
15 | torchrun --standalone --nproc_per_node=$nproc_per_node generate.py --outdir=$save_dir/8 --seeds=0-$num --class=8 --network=$model
16 | torchrun --standalone --nproc_per_node=$nproc_per_node generate.py --outdir=$save_dir/9 --seeds=0-$num --class=9 --network=$model
--------------------------------------------------------------------------------
/libs/clip.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from transformers import CLIPTokenizer, CLIPTextModel
3 |
4 |
5 | class AbstractEncoder(nn.Module):
6 | def __init__(self):
7 | super().__init__()
8 |
9 | def encode(self, *args, **kwargs):
10 | raise NotImplementedError
11 |
12 |
13 | class FrozenCLIPEmbedder(AbstractEncoder):
14 | """Uses the CLIP transformer encoder for text (from Hugging Face)"""
15 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
16 | super().__init__()
17 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
18 | self.transformer = CLIPTextModel.from_pretrained(version)
19 | self.device = device
20 | self.max_length = max_length
21 | self.freeze()
22 |
23 | def freeze(self):
24 | self.transformer = self.transformer.eval()
25 | for param in self.parameters():
26 | param.requires_grad = False
27 |
28 | def forward(self, text):
29 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
30 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
31 | tokens = batch_encoding["input_ids"].to(self.device)
32 | outputs = self.transformer(input_ids=tokens)
33 |
34 | z = outputs.last_hidden_state
35 | return z
36 |
37 | def encode(self, text):
38 | return self(text)
39 |
--------------------------------------------------------------------------------
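
A hedged usage sketch of FrozenCLIPEmbedder above (the first call downloads openai/clip-vit-large-patch14 from Hugging Face; a CUDA device is assumed, as in the scripts in this repo):

import libs.clip

clip = libs.clip.FrozenCLIPEmbedder()
clip.eval()
clip.to('cuda')
z = clip.encode(['a photo of a cat'])
print(z.shape)  # torch.Size([1, 77, 768]): max_length tokens x ViT-L/14 hidden size
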
/scripts/extract_test_prompt_feature.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | import libs.autoencoder
5 | import libs.clip
6 | from datasets import MSCOCODatabase
7 | import argparse
8 | from tqdm import tqdm
9 |
10 |
11 | def main():
12 | prompts = [
13 | 'A green train is coming down the tracks.',
14 | 'A group of skiers are preparing to ski down a mountain.',
15 | 'A small kitchen with a low ceiling.',
16 | 'A group of elephants walking in muddy water.',
17 | 'A living area with a television and a table.',
18 | 'A road with traffic lights, street lights and cars.',
19 | 'A bus driving in a city area with traffic signs.',
20 | 'A bus pulls over to the curb close to an intersection.',
21 | 'A group of people are walking and one is holding an umbrella.',
22 | 'A baseball player taking a swing at an incoming ball.',
23 | 'A city street line with brick buildings and trees.',
24 | 'A close up of a plate of broccoli and sauce.',
25 | ]
26 |
27 | device = 'cuda'
28 | clip = libs.clip.FrozenCLIPEmbedder()
29 | clip.eval()
30 | clip.to(device)
31 |
32 |     save_dir = 'assets/datasets/coco256_features/run_vis'
33 |     os.makedirs(save_dir, exist_ok=True)  # the directory may not exist yet
34 |     latent = clip.encode(prompts)
35 |     for i in range(len(latent)):
36 |         c = latent[i].detach().cpu().numpy()
37 |         np.save(os.path.join(save_dir, f'{i}.npy'), (prompts[i], c))
38 | 
39 | 
40 | if __name__ == '__main__':
41 |     main()
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/models/freematch/freematch_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from train_utils import ce_loss
4 |
5 |
6 | class Get_Scalar:
7 | def __init__(self, value):
8 | self.value = value
9 |
10 | def get_value(self, iter):
11 | return self.value
12 |
13 | def __call__(self, iter):
14 | return self.value
15 |
16 |
17 |
18 | def consistency_loss(dataset, logits_s, logits_w, time_p, p_model, name='ce', T=0.5, use_hard_labels=True):  # T was referenced below but undefined
19 | assert name in ['ce', 'L2']
20 | logits_w = logits_w.detach()
21 | if name == 'L2':
22 | assert logits_w.size() == logits_s.size()
23 | return F.mse_loss(logits_s, logits_w, reduction='mean')
24 |
25 | elif name == 'L2_mask':
26 | pass
27 |
28 | elif name == 'ce':
29 | pseudo_label = torch.softmax(logits_w, dim=-1)
30 | max_probs, max_idx = torch.max(pseudo_label, dim=-1)
31 | p_cutoff = time_p
32 | p_model_cutoff = p_model / torch.max(p_model,dim=-1)[0]
33 | threshold = p_cutoff * p_model_cutoff[max_idx]
34 | if dataset == 'svhn':
35 | threshold = torch.clamp(threshold, min=0.9, max=0.95)
36 | mask = max_probs.ge(threshold)
37 | if use_hard_labels:
38 | masked_loss = ce_loss(logits_s, max_idx, use_hard_labels, reduction='none') * mask.float()
39 | else:
40 | pseudo_label = torch.softmax(logits_w / T, dim=-1)
41 | masked_loss = ce_loss(logits_s, pseudo_label, use_hard_labels) * mask.float()
42 | return masked_loss.mean(), mask
43 |
44 | else:
45 |         raise Exception('Not Implemented consistency_loss')
46 |
--------------------------------------------------------------------------------
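
A small numeric sketch (values made up) of the self-adaptive threshold in consistency_loss above: the global confidence time_p is scaled per class by p_model normalized to its largest entry, and only samples whose max probability clears their class threshold contribute to the pseudo-label loss.

import torch

torch.manual_seed(0)
time_p = torch.tensor(0.8)               # global confidence estimate
p_model = torch.tensor([0.5, 0.3, 0.2])  # per-class average probabilities
p_model_cutoff = p_model / torch.max(p_model, dim=-1)[0]  # -> [1.0, 0.6, 0.4]

pseudo_label = torch.softmax(torch.randn(4, 3), dim=-1)
max_probs, max_idx = pseudo_label.max(dim=-1)
threshold = time_p * p_model_cutoff[max_idx]  # one threshold per sample
mask = max_probs.ge(threshold)                # samples kept for the loss
print(threshold, mask)
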
/cifar10_experiment/TorchSSL/datasets/DistributedProxySampler.py:
--------------------------------------------------------------------------------
1 | # copyright: https://github.com/pytorch/pytorch/issues/23430#issuecomment-562350407
2 |
3 | import math
4 | import torch
5 | from torch.utils.data.distributed import DistributedSampler
6 |
7 |
8 | class DistributedProxySampler(DistributedSampler):
9 | """Sampler that restricts data loading to a subset of input sampler indices.
10 |
11 | It is especially useful in conjunction with
12 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
13 | process can pass a DistributedSampler instance as a DataLoader sampler,
14 | and load a subset of the original dataset that is exclusive to it.
15 |
16 | .. note::
17 | Input sampler is assumed to be of constant size.
18 |
19 | Arguments:
20 | sampler: Input data sampler.
21 | num_replicas (optional): Number of processes participating in
22 | distributed training.
23 | rank (optional): Rank of the current process within num_replicas.
24 | """
25 |
26 | def __init__(self, sampler, num_replicas=None, rank=None):
27 | super(DistributedProxySampler, self).__init__(sampler, num_replicas=num_replicas, rank=rank, shuffle=False)
28 | self.sampler = sampler
29 |
30 | def __iter__(self):
31 | # deterministically shuffle based on epoch
32 | torch.manual_seed(self.epoch)
33 | indices = list(self.sampler)
34 |
35 | # add extra samples to make it evenly divisible
36 | indices += indices[:(self.total_size - len(indices))]
37 | if len(indices) != self.total_size:
38 | raise RuntimeError("{} vs {}".format(len(indices), self.total_size))
39 |
40 | # subsample
41 | indices = indices[self.rank:self.total_size:self.num_replicas]
42 | if len(indices) != self.num_samples:
43 | raise RuntimeError("{} vs {}".format(len(indices), self.num_samples))
44 |
45 | return iter(indices)
--------------------------------------------------------------------------------
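
A hedged single-process usage sketch, simulating two ranks by passing num_replicas and rank explicitly (run from cifar10_experiment/TorchSSL so the module import resolves):

import torch
from torch.utils.data import RandomSampler, TensorDataset

from datasets.DistributedProxySampler import DistributedProxySampler

dataset = TensorDataset(torch.arange(10))
sampler = DistributedProxySampler(RandomSampler(dataset), num_replicas=2, rank=0)
sampler.set_epoch(0)  # drives the torch.manual_seed(epoch) call in __iter__
print(list(sampler))  # rank 0's half (5 indices) of the shuffled, padded index list
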
/src/sgd.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | from torch.optim import Optimizer
10 |
11 |
12 | class SGD(Optimizer):
13 |
14 | def __init__(self, params, lr, momentum=0, weight_decay=0, nesterov=False):
15 | if lr < 0.0:
16 | raise ValueError(f'Invalid learning rate: {lr}')
17 | if weight_decay < 0.0:
18 | raise ValueError(f'Invalid weight_decay value: {weight_decay}')
19 |
20 | defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay,
21 | nesterov=nesterov)
22 | super(SGD, self).__init__(params, defaults)
23 |
24 | @torch.no_grad()
25 | def step(self):
26 | for group in self.param_groups:
27 | weight_decay = group['weight_decay']
28 | momentum = group['momentum']
29 | nesterov = group['nesterov']
30 |
31 | for p in group['params']:
32 | if p.grad is None:
33 | continue
34 |
35 | d_p = p.grad
36 | if weight_decay != 0:
37 | d_p = d_p.add(p, alpha=weight_decay)
38 | d_p.mul_(-group['lr'])
39 |
40 | if momentum != 0:
41 | param_state = self.state[p]
42 | if 'momentum_buffer' not in param_state:
43 | buf = param_state['momentum_buffer'] = d_p.clone().detach()
44 | else:
45 | buf = param_state['momentum_buffer']
46 | buf.mul_(momentum).add_(d_p)
47 |
48 | if nesterov:
49 | d_p.add_(buf, alpha=momentum)
50 | else:
51 | d_p = buf
52 |
53 | p.add_(d_p)
54 |
55 | return None
56 |
--------------------------------------------------------------------------------
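
A one-step sanity sketch of this custom optimizer (note that, unlike torch.optim.SGD, the learning rate is folded into d_p before the momentum buffer is updated; the repo root is assumed to be on sys.path):

import torch

from src.sgd import SGD

p = torch.nn.Parameter(torch.tensor([1.0]))
opt = SGD([p], lr=0.1, momentum=0.9)

(p ** 2).sum().backward()  # grad = 2p = 2.0
opt.step()                 # d_p = -lr * grad = -0.2; first step: buf = d_p; p += d_p
print(p.item())            # 0.8
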
/scripts/extract_imagenet_feature.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | import torch
4 | from datasets import ImageNet
5 | from torch.utils.data import DataLoader
6 | from libs.autoencoder import get_model
7 | import argparse
8 | from tqdm import tqdm
9 | torch.manual_seed(0)
10 | np.random.seed(0)
11 |
12 |
13 | def main(resolution=256):
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('path')
16 | args = parser.parse_args()
17 |
18 | dataset = ImageNet(path=args.path, resolution=resolution, random_flip=False)
19 | train_dataset = dataset.get_split(split='train', labeled=True)
20 | train_dataset_loader = DataLoader(train_dataset, batch_size=256, shuffle=False, drop_last=False,
21 | num_workers=8, pin_memory=True, persistent_workers=True)
22 |
23 | model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
24 | model = nn.DataParallel(model)
25 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
26 | model.to(device)
27 |
28 | # features = []
29 | # labels = []
30 |
31 | idx = 0
32 | for batch in tqdm(train_dataset_loader):
33 | img, label = batch
34 | img = torch.cat([img, img.flip(dims=[-1])], dim=0)
35 | img = img.to(device)
36 | moments = model(img, fn='encode_moments')
37 | moments = moments.detach().cpu().numpy()
38 |
39 | label = torch.cat([label, label], dim=0)
40 | label = label.detach().cpu().numpy()
41 |
42 | for moment, lb in zip(moments, label):
43 | np.save(f'assets/datasets/imagenet{resolution}_features/{idx}.npy', (moment, lb))
44 | idx += 1
45 |
46 | print(f'save {idx} files')
47 |
48 | # features = np.concatenate(features, axis=0)
49 | # labels = np.concatenate(labels, axis=0)
50 | # print(f'features.shape={features.shape}')
51 | # print(f'labels.shape={labels.shape}')
52 | # np.save(f'imagenet{resolution}_features.npy', features)
53 | # np.save(f'imagenet{resolution}_labels.npy', labels)
54 |
55 |
56 | if __name__ == "__main__":
57 | main()
58 |
--------------------------------------------------------------------------------
/scripts/extract_imagenet_features.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | import torch
4 | from datasets import ImageNet
5 | from torch.utils.data import DataLoader
6 | from libs.autoencoder import get_model
7 | import argparse
8 | from tqdm import tqdm
9 | torch.manual_seed(0)
10 | np.random.seed(0)
11 |
12 |
13 | def main(resolution=256):
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('path')
16 | args = parser.parse_args()
17 |
18 | dataset = ImageNet(path=args.path, resolution=resolution, random_flip=False)
19 | train_dataset = dataset.get_split(split='train', labeled=True)
20 | train_dataset_loader = DataLoader(train_dataset, batch_size=256, shuffle=False, drop_last=False,
21 | num_workers=8, pin_memory=True, persistent_workers=True)
22 |
23 | model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
24 | model = nn.DataParallel(model)
25 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
26 | model.to(device)
27 |
28 | # features = []
29 | # labels = []
30 |
31 | idx = 0
32 | for batch in tqdm(train_dataset_loader):
33 | img, label = batch
34 | img = torch.cat([img, img.flip(dims=[-1])], dim=0)
35 | img = img.to(device)
36 | moments = model(img, fn='encode_moments')
37 | moments = moments.detach().cpu().numpy()
38 |
39 | label = torch.cat([label, label], dim=0)
40 | label = label.detach().cpu().numpy()
41 |
42 | for moment, lb in zip(moments, label):
43 | np.save(f'assets/datasets/imagenet{resolution}_features/{idx}.npy', (moment, lb))
44 | idx += 1
45 |
46 | print(f'save {idx} files')
47 |
48 | # features = np.concatenate(features, axis=0)
49 | # labels = np.concatenate(labels, axis=0)
50 | # print(f'features.shape={features.shape}')
51 | # print(f'labels.shape={labels.shape}')
52 | # np.save(f'imagenet{resolution}_features.npy', features)
53 | # np.save(f'imagenet{resolution}_labels.npy', labels)
54 |
55 |
56 | if __name__ == "__main__":
57 | main()
58 |
--------------------------------------------------------------------------------
/scripts/extract_mscoco_feature.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | import libs.autoencoder
5 | import libs.clip
6 | from datasets import MSCOCODatabase
7 | import argparse
8 | from tqdm import tqdm
9 |
10 |
11 | def main(resolution=256):
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--split', default='train')
14 | args = parser.parse_args()
15 | print(args)
16 |
17 |
18 | if args.split == "train":
19 | datas = MSCOCODatabase(root='assets/datasets/coco/train2014',
20 | annFile='assets/datasets/coco/annotations/captions_train2014.json',
21 | size=resolution)
22 | save_dir = f'assets/datasets/coco{resolution}_features/train'
23 | elif args.split == "val":
24 | datas = MSCOCODatabase(root='assets/datasets/coco/val2014',
25 | annFile='assets/datasets/coco/annotations/captions_val2014.json',
26 | size=resolution)
27 | save_dir = f'assets/datasets/coco{resolution}_features/val'
28 | else:
29 | raise NotImplementedError("ERROR!")
30 |
31 | device = "cuda"
32 | os.makedirs(save_dir)
33 |
34 | autoencoder = libs.autoencoder.get_model('assets/stable-diffusion/autoencoder_kl.pth')
35 | autoencoder.to(device)
36 | clip = libs.clip.FrozenCLIPEmbedder()
37 | clip.eval()
38 | clip.to(device)
39 |
40 | with torch.no_grad():
41 | for idx, data in tqdm(enumerate(datas)):
42 | x, captions = data
43 |
44 | if len(x.shape) == 3:
45 | x = x[None, ...]
46 | x = torch.tensor(x, device=device)
47 | moments = autoencoder(x, fn='encode_moments').squeeze(0)
48 | moments = moments.detach().cpu().numpy()
49 | np.save(os.path.join(save_dir, f'{idx}.npy'), moments)
50 |
51 | latent = clip.encode(captions)
52 | for i in range(len(latent)):
53 | c = latent[i].detach().cpu().numpy()
54 | np.save(os.path.join(save_dir, f'{idx}_{i}.npy'), c)
55 |
56 |
57 | if __name__ == '__main__':
58 | main()
59 |
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/pseudo_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import pickle
4 | # randomly shuffle targets
5 | import random
6 | import tarfile
7 | 
8 | import numpy as np
9 | import torch
10 | import torchvision
11 | from torch.utils.data import DataLoader
12 | 
13 |
14 | def unpickle(file):
15 | with open(file, 'rb') as fo:
16 | dict = pickle.load(fo, encoding='latin1')
17 | return dict
18 |
19 | if __name__ == "__main__":
20 | import argparse
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--data_path', type=str, default='./data/cifar-10-batches-py')
23 | parser.add_argument('--save_path', type=str, default='./saved_models/freematch_cifar10_40_1')
24 |
25 | args = parser.parse_args()
26 |
27 | labels = [] # save labels
28 | data_dict = {} # data_1, data_2, data_3, data_4, data_5
29 |
30 | for i in range(5):
31 | batch = unpickle(os.path.join(args.data_path, 'data_batch_{}'.format(i+1)))
32 | data_dict['data_{}'.format(i+1)] = batch['data']
33 | labels.extend(batch['labels'])
34 |
35 | pseudo_path = os.path.join(args.save_path, 'pseudo_label.npy')
36 | pseudo_labels = np.load(pseudo_path)
37 |
38 | # print true labels distribution and pseudo labels distribution
39 | true_distribution = torch.bincount(torch.tensor(labels))
40 | pseudo_distribution = torch.bincount(torch.tensor(pseudo_labels))
41 | print("true distribution:", true_distribution)
42 | print("pseudo distribution:", pseudo_distribution)
43 | # print accuracy
44 | print(np.sum(labels == pseudo_labels) / 50000)
45 |
46 | for i in range(5):
47 | data = data_dict['data_{}'.format(i+1)]
48 | targets = pseudo_labels[i*10000:(i+1)*10000]
49 | batch_path = os.path.join(args.save_path, 'cifar-10-batches-py/data_batch_{}'.format(i+1))
50 | os.makedirs(os.path.dirname(batch_path), exist_ok=True)
51 | with open(batch_path, 'wb') as f:
52 | pickle.dump({'data': data, 'labels': targets}, f)
53 |
54 | with tarfile.open(os.path.join(args.save_path, 'cifar-10-python.tar.gz'), 'w:gz') as f:
55 | f.add(os.path.join(args.save_path, 'cifar-10-batches-py'), arcname='cifar-10-batches-py')
56 |
57 |
58 |
--------------------------------------------------------------------------------
/cifar10_experiment/edm/torch_utils/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import os
9 | import torch
10 | from . import training_stats
11 |
12 | #----------------------------------------------------------------------------
13 |
14 | def init():
15 | if 'MASTER_ADDR' not in os.environ:
16 | os.environ['MASTER_ADDR'] = 'localhost'
17 | if 'MASTER_PORT' not in os.environ:
18 | os.environ['MASTER_PORT'] = '29500'
19 | if 'RANK' not in os.environ:
20 | os.environ['RANK'] = '0'
21 | if 'LOCAL_RANK' not in os.environ:
22 | os.environ['LOCAL_RANK'] = '0'
23 | if 'WORLD_SIZE' not in os.environ:
24 | os.environ['WORLD_SIZE'] = '1'
25 |
26 | backend = 'gloo' if os.name == 'nt' else 'nccl'
27 | torch.distributed.init_process_group(backend=backend, init_method='env://')
28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
29 |
30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None
31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device)
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def get_rank():
36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
37 |
38 | #----------------------------------------------------------------------------
39 |
40 | def get_world_size():
41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
42 |
43 | #----------------------------------------------------------------------------
44 |
45 | def should_stop():
46 | return False
47 |
48 | #----------------------------------------------------------------------------
49 |
50 | def update_progress(cur, total):
51 | _ = cur, total
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | def print0(*args, **kwargs):
56 | if get_rank() == 0:
57 | print(*args, **kwargs)
58 |
59 | #----------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
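
A hedged sketch of how these helpers are used by the EDM entry points (a CUDA device and the full torch_utils package, including training_stats, are assumed; under plain python, init() falls back to a single-process group of rank 0):

from torch_utils import distributed as dist

dist.init()
dist.print0(f'rank 0 of {dist.get_world_size()} process(es)')  # printed once, on rank 0
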
/scripts/sweep_sample.py:
--------------------------------------------------------------------------------
1 | import time
2 | import random
3 | import os
4 | import argparse
5 | import subprocess
6 | from pathlib import Path
7 | import datetime
8 |
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--workdir', required=True)
13 | parser.add_argument('--port', type=int)
14 | parser.add_argument('--ckpts', type=str)
15 |
16 | args, unknown = parser.parse_known_args()
17 | args.ckpt_root = os.path.join(args.workdir, 'ckpts')
18 | if args.port is None:
19 | args.port = random.randint(10000, 30000)
20 |
21 | for x in unknown:
22 | assert '=' in x
23 |
24 | return args, ' '.join(unknown)
25 |
26 |
27 | def valid_str(unknown):
28 | items = unknown.split(' ')
29 | res = []
30 | for item in items:
31 | assert item.startswith('--')
32 | res.append(item[2:].replace('/', '_'))
33 | return '_'.join(res)
34 |
35 |
36 | def main():
37 | args, unknown = parse_args()
38 | ckpts = [f'{int(ckpt)}.ckpt' for ckpt in args.ckpts.split(',')]
39 | n_devices = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
40 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
41 | output_path = os.path.join(args.workdir, f'{now}_sweep_sample.log')
42 |
43 | for ckpt in ckpts:
44 | print(f'try sample {os.path.join(args.ckpt_root, ckpt)}')
45 |
46 | while not os.path.exists(os.path.join(args.ckpt_root, ckpt)):
47 | time.sleep(5 + 5 * random.random())
48 | time.sleep(5 + 5 * random.random())
49 |
50 | if os.path.exists(os.path.join(args.ckpt_root, '.state', ckpt, valid_str(unknown)[:100])):
51 | print(f'{ckpt} already evaluated, skip')
52 | continue
53 |
54 | os.makedirs(os.path.join(args.ckpt_root, '.state', ckpt), exist_ok=True) # mark as running
55 | Path(os.path.join(args.ckpt_root, '.state', ckpt, valid_str(unknown)[:100])).touch()
56 |
57 | dct = dict()
58 | nnet_path = os.path.join(args.ckpt_root, ckpt, 'nnet_ema.pth')
59 | dct['nnet_path'] = nnet_path
60 | dct['output_path'] = output_path
61 |
62 | dct_str = ' '.join([f'--{key}={val}' for key, val in dct.items()])
63 |
64 | accelerate_args = f'--multi_gpu --main_process_port {args.port} --num_processes {n_devices} --mixed_precision fp16'
65 | cmd = f'accelerate launch {accelerate_args} sample.py {dct_str} {unknown}'
66 | cmd = list(filter(lambda x: x != '', cmd.split(' ')))
67 | print(cmd)
68 | subprocess.Popen(cmd).wait()
69 |
70 |
71 | if __name__ == "__main__":
72 | main()
73 |
--------------------------------------------------------------------------------
/configs/accelerate_b4_subset1_2img_k_128_large.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.lambd = 0.0025
13 | config.penalty = 'l2'
14 | config.mask = 0.0
15 | config.preload = True
16 | config.fname = 'vitb4_300ep.pth.tar'
17 | config.model_name = 'deit_base_p4'
18 | config.pretrained = 'pretrained/'
19 |
20 | config.normalize = True
21 | config.root_path = '/cache/datasets/ILSVRC/Data/'
22 | config.image_folder = 'CLS-LOC/'
23 | config.image_path = '/cache/datasets/ILSVRC/Data/CLS-LOC'
24 |
25 | config.subset_path = 'imagenet_subsets1/2imgs_class.txt'
26 | config.blocks = 1
27 |
28 | config.seed = 1234
29 | config.pred = 'noise_pred'
30 | config.ema_rate = 0.9999
31 | config.z_shape = (4, 32, 32)
32 | config.resolution = 256
33 |
34 | config.autoencoder = d(
35 | pretrained_path='assets/stable-diffusion/autoencoder_kl.pth'
36 | )
37 |
38 | config.dpm_path = 'assets/DPM'
39 |
40 | # augmentation
41 | config.augmentation_K = 128 # add 128 generated images per class
42 | config.using_true_label = True
43 | config.output_path = ''
44 |
45 | config.train = d(
46 | n_steps=300000,
47 | batch_size=1024,
48 | mode='cond',
49 | log_interval=10,
50 | eval_interval=5000,
51 | save_interval=50000,
52 | )
53 |
54 | config.optimizer = d(
55 | name='adamw',
56 | lr=0.0002,
57 | weight_decay=0.03,
58 | betas=(0.99, 0.99),
59 | )
60 |
61 | config.lr_scheduler = d(
62 | name='customized',
63 | warmup_steps=5000
64 | )
65 |
66 | config.nnet = d(
67 | name='uvit',
68 | img_size=32,
69 | patch_size=2,
70 | in_chans=4,
71 | embed_dim=1024,
72 | depth=20,
73 | num_heads=16,
74 | mlp_ratio=4,
75 | qkv_bias=False,
76 | mlp_time_embed=False,
77 | num_classes=1001,
78 | use_checkpoint=True
79 | )
80 |
81 | config.dataset = d(
82 | name='imagenet256_features',
83 | path='',
84 | cfg=True,
85 | p_uncond=0.15
86 | )
87 |
88 | config.sample = d(
89 | sample_steps=50,
90 | n_samples=50000,
91 | mini_batch_size=20, # the decoder is large
92 | algorithm='dpm_solver',
93 | cfg=True,
94 | scale=0.4,
95 | path=''
96 | )
97 |
98 | return config
99 |
100 | config = get_config()
101 |
--------------------------------------------------------------------------------
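
The configs in this repo are plain Python modules built on ml_collections.ConfigDict; a minimal access sketch, assuming ml_collections is installed and the repo root is on sys.path:

from configs.accelerate_b4_subset1_2img_k_128_large import get_config

config = get_config()
print(config.nnet.embed_dim, config.train.n_steps)  # 1024 300000
config.sample.scale = 0.7  # fields can be overridden in place before use
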
/configs/accelerate_l7_subset2_zimg_k_128_large.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.lambd = 0.0025
13 | config.penalty = 'l2'
14 | config.mask = 0.0
15 | config.preload = True
16 | config.fname = 'vitl7_200ep.pth.tar'
17 | config.model_name = 'deit_large_p7'
18 | config.pretrained = 'pretrained/'
19 |
20 | config.normalize = True
21 | config.root_path = '/cache/datasets/ILSVRC/Data/'
22 | config.image_folder = 'CLS-LOC/'
23 | config.image_path = '/cache/datasets/ILSVRC/Data/CLS-LOC'
24 |
25 | config.subset_path = 'imagenet_subsets2/1imgs_class.txt'
26 | config.blocks = 1
27 |
28 | config.seed = 1234
29 | config.pred = 'noise_pred'
30 | config.ema_rate = 0.9999
31 | config.z_shape = (4, 32, 32)
32 | config.resolution = 256
33 |
34 | config.autoencoder = d(
35 | pretrained_path='assets/stable-diffusion/autoencoder_kl.pth'
36 | )
37 |
38 | config.dpm_path = 'assets/DPM'
39 |
40 | # augmentation
41 | config.augmentation_K = 128 # add 128 generated images per class
42 | config.using_true_label = True
43 | config.output_path = ''
44 |
45 | config.train = d(
46 | n_steps=300000,
47 | batch_size=1024,
48 | mode='cond',
49 | log_interval=10,
50 | eval_interval=5000,
51 | save_interval=50000,
52 | )
53 |
54 | config.optimizer = d(
55 | name='adamw',
56 | lr=0.0002,
57 | weight_decay=0.03,
58 | betas=(0.99, 0.99),
59 | )
60 |
61 | config.lr_scheduler = d(
62 | name='customized',
63 | warmup_steps=5000
64 | )
65 |
66 | config.nnet = d(
67 | name='uvit',
68 | img_size=32,
69 | patch_size=2,
70 | in_chans=4,
71 | embed_dim=1024,
72 | depth=20,
73 | num_heads=16,
74 | mlp_ratio=4,
75 | qkv_bias=False,
76 | mlp_time_embed=False,
77 | num_classes=1001,
78 | use_checkpoint=True
79 | )
80 |
81 | config.dataset = d(
82 | name='imagenet256_features',
83 | path='',
84 | cfg=True,
85 | p_uncond=0.15
86 | )
87 |
88 | config.sample = d(
89 | sample_steps=50,
90 | n_samples=50000,
91 | mini_batch_size=20, # the decoder is large
92 | algorithm='dpm_solver',
93 | cfg=True,
94 | scale=0.4,
95 | path=''
96 | )
97 |
98 | return config
99 |
100 | config = get_config()
101 |
--------------------------------------------------------------------------------
/configs/accelerate_b4_subset1_2img_k_128_huge.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.lambd = 0.0025
13 | config.penalty = 'l2'
14 | config.mask = 0.0
15 | config.preload = True
16 | config.fname = 'vitb4_300ep.pth.tar'
17 | config.model_name = 'deit_base_p4'
18 | config.pretrained = 'pretrained/'
19 |
20 | config.normalize = True
21 | config.root_path = '/cache/datasets/ILSVRC/Data/'
22 | config.image_folder = 'CLS-LOC/'
23 | config.image_path = '/cache/datasets/ILSVRC/Data/CLS-LOC'
24 |
25 | config.subset_path = 'imagenet_subsets1/2imgs_class.txt'
26 | config.blocks = 1
27 |
28 | config.seed = 1234
29 | config.pred = 'noise_pred'
30 | config.ema_rate = 0.9999
31 | config.z_shape = (4, 32, 32)
32 | config.resolution = 256
33 |
34 | config.autoencoder = d(
35 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth'
36 | )
37 |
38 | config.dpm_path = 'assets/DPM'
39 |
40 | # augmentation
41 | config.augmentation_K = 128 # add 128 generated images per class
42 | config.using_true_label = True
43 | config.output_path = ''
44 |
45 | config.train = d(
46 | n_steps=500000,
47 | batch_size=1024,
48 | mode='cond',
49 | log_interval=10,
50 | eval_interval=5000,
51 | save_interval=50000,
52 | )
53 |
54 | config.optimizer = d(
55 | name='adamw',
56 | lr=0.0002,
57 | weight_decay=0.03,
58 | betas=(0.99, 0.99),
59 | )
60 |
61 | config.lr_scheduler = d(
62 | name='customized',
63 | warmup_steps=5000
64 | )
65 |
66 | config.nnet = d(
67 | name='uvit',
68 | img_size=32,
69 | patch_size=2,
70 | in_chans=4,
71 | embed_dim=1152,
72 | depth=28,
73 | num_heads=16,
74 | mlp_ratio=4,
75 | qkv_bias=False,
76 | mlp_time_embed=False,
77 | num_classes=1001,
78 | use_checkpoint=True,
79 | conv=False
80 | )
81 |
82 | config.dataset = d(
83 | name='imagenet256_features',
84 | path='',
85 | cfg=True,
86 | p_uncond=0.1
87 | )
88 |
89 | config.sample = d(
90 | sample_steps=50,
91 | n_samples=50000,
92 | mini_batch_size=20, # the decoder is large
93 | algorithm='dpm_solver',
94 | cfg=True,
95 | scale=0.4,
96 | path=''
97 | )
98 |
99 | return config
100 |
101 | config = get_config()
102 |
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 |
4 | import torch
5 |
6 | from utils import net_builder
7 | from datasets.ssl_dataset import SSL_Dataset
8 | from datasets.data_utils import get_data_loader
9 | import accelerate
10 |
11 | if __name__ == "__main__":
12 | import argparse
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--load_path', type=str, default='./saved_models/fixmatch/model_best.pth')
15 | parser.add_argument('--use_train_model', action='store_true')
16 |
17 | '''
18 | Backbone Net Configurations
19 | '''
20 | parser.add_argument('--net', type=str, default='WideResNet')
21 |     parser.add_argument('--net_from_name', action='store_true')  # type=bool would parse any non-empty string as True
22 | parser.add_argument('--depth', type=int, default=28)
23 | parser.add_argument('--widen_factor', type=int, default=2)
24 | parser.add_argument('--leaky_slope', type=float, default=0.1)
25 | parser.add_argument('--dropout', type=float, default=0.0)
26 |
27 | '''
28 | Data Configurations
29 | '''
30 | parser.add_argument('--batch_size', type=int, default=256)
31 | parser.add_argument('--data_dir', type=str, default='./data')
32 | parser.add_argument('--dataset', type=str, default='cifar10')
33 | parser.add_argument('--num_classes', type=int, default=10)
34 | args = parser.parse_args()
35 |
36 | accelerator = accelerate.Accelerator()
37 |
38 | checkpoint_path = os.path.join(args.load_path)
39 | checkpoint = torch.load(checkpoint_path)
40 | load_model = checkpoint['train_model'] if args.use_train_model else checkpoint['ema_model']
41 |
42 | _net_builder = net_builder(args.net,
43 | args.net_from_name,
44 | {'depth': args.depth,
45 | 'widen_factor': args.widen_factor,
46 | 'leaky_slope': args.leaky_slope,
47 | 'dropRate': args.dropout,
48 | 'use_embed': False})
49 |
50 | net = _net_builder(num_classes=args.num_classes)
51 |
52 | _eval_dset = SSL_Dataset(args, name=args.dataset, train=False,
53 | num_classes=args.num_classes, data_dir=args.data_dir)
54 | eval_dset = _eval_dset.get_dset()
55 |
56 | eval_loader = get_data_loader(eval_dset,
57 | args.batch_size,
58 | num_workers=1, shuffle=False)
59 |
60 | #net, eval_loader = accelerator.prepare(net, eval_loader)
61 |
62 | #net.load_state_dict(load_model)
63 | # if torch.cuda.is_available():
64 | # net.cuda()
65 | weights_dict = {}
66 | for k, v in load_model.items():
67 | new_k = k.replace('module.', '') if 'module' in k else k
68 | weights_dict[new_k] = v
69 |
70 | net.load_state_dict(weights_dict)
71 |
72 | net.eval()
73 |
74 |
75 | acc = 0.0
76 | with torch.no_grad():
77 | for (_, image, target) in eval_loader:
78 | print(_)
79 | image = image.type(torch.FloatTensor)#.cuda()
80 | logit = net(image)
81 | print(logit.shape)
82 |
83 | target, logit = accelerator.gather(target), accelerator.gather(logit)
84 | acc += logit.cpu().max(1)[1].eq(target.cpu()).sum().numpy()
85 |
86 | accelerator.print(f"Test Accuracy: {acc/len(eval_dset)}")
87 |
88 |
--------------------------------------------------------------------------------
/extract_imagenet_features_semi.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | import torch
4 | import os
5 | from datasets import ImageNet
6 | from torch.utils.data import DataLoader
7 | from libs.autoencoder import get_model
8 | import argparse
9 | from tqdm import tqdm
10 | torch.manual_seed(0)
11 | np.random.seed(0)
12 |
13 | def mprint(*args):
14 | print('\n-----------------------------')
15 | print(*args)
16 | print('-----------------------------\n')
17 |
18 | from absl import flags
19 | from absl import app
20 | from ml_collections import config_flags
21 | import sys
22 | from pathlib import Path
23 |
24 | FLAGS = flags.FLAGS
25 | config_flags.DEFINE_config_file(
26 | "config", None, "Training configuration.", lock_config=False)
27 | flags.mark_flags_as_required(["config"])
28 |
29 | def get_config_name():
30 | argv = sys.argv
31 | for i in range(1, len(argv)):
32 | if argv[i].startswith('--config='):
33 | return Path(argv[i].split('=')[-1]).stem
34 |
35 |
36 | def get_hparams():
37 | argv = sys.argv
38 | lst = []
39 | for i in range(1, len(argv)):
40 | assert '=' in argv[i]
41 | if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
42 | hparam, val = argv[i].split('=')
43 | hparam = hparam.split('.')[-1]
44 | if hparam.endswith('path'):
45 | val = Path(val).stem
46 | lst.append(f'{hparam}={val}')
47 | hparams = '-'.join(lst)
48 | if hparams == '':
49 | hparams = 'default'
50 | return hparams
51 |
52 | def main(argv):
53 | config = FLAGS.config
54 | config.config_name = get_config_name()
55 | config.hparams = get_hparams()
56 | cluster_name = config.model_name + '-' + '-'.join(config.subset_path.split('/')).split('.txt')[0]
57 | cluster_path = f'pretrained/cluster/{cluster_name}/imagenet_features_preds.npy'
58 | fnames_path = f'pretrained/cluster/{cluster_name}/imagenet_features_fnames.pth'
59 | autoencoder_path = config.autoencoder.pretrained_path
60 | path = config.image_path
61 |
62 | dataset = ImageNet(path=path, resolution=config.resolution, random_flip=False, cluster_path=cluster_path, fnames_path=fnames_path)
63 |
64 | train_dataset = dataset.get_split(split='train', labeled=True)
65 |
66 | train_batch_size = 128
67 | if config.resolution == 512:
68 | train_batch_size = 64
69 |
70 | train_dataset_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=False, drop_last=False,
71 | num_workers=8, pin_memory=True, persistent_workers=True)
72 |
73 | model = get_model(autoencoder_path)
74 | model = nn.DataParallel(model)
75 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
76 | model.to(device)
77 |
78 | save_path = f'pretrained/datasets/{cluster_name}'
79 | save_features_path = os.path.join(save_path, f'imagenet{config.resolution}_features')
80 | os.system(f'mkdir -p {save_features_path}')
81 |
82 | idx = 0
83 | for batch in tqdm(train_dataset_loader):
84 | img, label = batch
85 | img = torch.cat([img, img.flip(dims=[-1])], dim=0)
86 | img = img.to(device)
87 | moments = model(img, fn='encode_moments')
88 | moments = moments.detach().cpu().numpy()
89 |
90 | label = torch.cat([label, label], dim=0)
91 | label = label.detach().cpu().numpy()
92 |
93 | for moment, lb in zip(moments, label):
94 | np.save(os.path.join(save_features_path, f'{idx}.npy'), (moment, lb))
95 | idx += 1
96 |
97 | mprint(f'save {idx} files')
98 |
99 | if __name__ == "__main__":
100 | app.run(main)
101 |
--------------------------------------------------------------------------------
/cifar10_experiment/edm/training/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Loss functions used in the paper
9 | "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10 |
11 | import torch
12 | from torch_utils import persistence
13 |
14 | #----------------------------------------------------------------------------
15 | # Loss function corresponding to the variance preserving (VP) formulation
16 | # from the paper "Score-Based Generative Modeling through Stochastic
17 | # Differential Equations".
18 |
19 | @persistence.persistent_class
20 | class VPLoss:
21 | def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
22 | self.beta_d = beta_d
23 | self.beta_min = beta_min
24 | self.epsilon_t = epsilon_t
25 |
26 | def __call__(self, net, images, labels, augment_pipe=None):
27 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
28 | sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1))
29 | weight = 1 / sigma ** 2
30 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
31 | n = torch.randn_like(y) * sigma
32 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
33 | loss = weight * ((D_yn - y) ** 2)
34 | return loss
35 |
36 | def sigma(self, t):
37 | t = torch.as_tensor(t)
38 | return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt()
39 |
40 | #----------------------------------------------------------------------------
41 | # Loss function corresponding to the variance exploding (VE) formulation
42 | # from the paper "Score-Based Generative Modeling through Stochastic
43 | # Differential Equations".
44 |
45 | @persistence.persistent_class
46 | class VELoss:
47 | def __init__(self, sigma_min=0.02, sigma_max=100):
48 | self.sigma_min = sigma_min
49 | self.sigma_max = sigma_max
50 |
51 | def __call__(self, net, images, labels, augment_pipe=None):
52 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
53 | sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform)
54 | weight = 1 / sigma ** 2
55 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
56 | n = torch.randn_like(y) * sigma
57 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
58 | loss = weight * ((D_yn - y) ** 2)
59 | return loss
60 |
61 | #----------------------------------------------------------------------------
62 | # Improved loss function proposed in the paper "Elucidating the Design Space
63 | # of Diffusion-Based Generative Models" (EDM).
64 |
65 | @persistence.persistent_class
66 | class EDMLoss:
67 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
68 | self.P_mean = P_mean
69 | self.P_std = P_std
70 | self.sigma_data = sigma_data
71 |
72 | def __call__(self, net, images, labels=None, augment_pipe=None):
73 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
74 | sigma = (rnd_normal * self.P_std + self.P_mean).exp()
75 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
76 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
77 | n = torch.randn_like(y) * sigma
78 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
79 | loss = weight * ((D_yn - y) ** 2)
80 | return loss
81 |
82 | #----------------------------------------------------------------------------
83 |
--------------------------------------------------------------------------------
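
A small sketch of the noise-level sampling in EDMLoss above: ln(sigma) is drawn from N(P_mean, P_std^2), and the weight simplifies to 1/sigma^2 + 1/sigma_data^2.

import torch

P_mean, P_std, sigma_data = -1.2, 1.2, 0.5  # defaults from EDMLoss

sigma = (torch.randn(100_000) * P_std + P_mean).exp()  # log-normal noise levels
weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2
print(sigma.log().mean().item(), sigma.log().std().item())  # ~ -1.2, ~ 1.2
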
/cifar10_experiment/TorchSSL/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from torch.utils.tensorboard import SummaryWriter
4 | import logging
5 | import yaml
6 |
7 |
8 | def over_write_args_from_file(args, yml):
9 | if yml == '':
10 | return
11 | with open(yml, 'r', encoding='utf-8') as f:
12 | dic = yaml.load(f.read(), Loader=yaml.Loader)
13 | for k in dic:
14 | setattr(args, k, dic[k])
15 |
16 |
17 | def setattr_cls_from_kwargs(cls, kwargs):
18 |     # if default values are present on cls,
19 |     # override them with the values in kwargs
20 |     for key in kwargs.keys():
21 |         if hasattr(cls, key):
22 |             print(f"{key} in {cls} is overridden by kwargs: {getattr(cls, key)} -> {kwargs[key]}")
23 |         setattr(cls, key, kwargs[key])
24 |
25 |
26 | def test_setattr_cls_from_kwargs():
27 | class _test_cls:
28 | def __init__(self):
29 | self.a = 1
30 | self.b = 'hello'
31 |
32 | test_cls = _test_cls()
33 | config = {'a': 3, 'b': 'change_hello', 'c': 5}
34 | setattr_cls_from_kwargs(test_cls, config)
35 | for key in config.keys():
36 | print(f"{key}:\t {getattr(test_cls, key)}")
37 |
38 |
39 | def net_builder(net_name, from_name: bool, net_conf=None, is_remix=False):
40 | """
41 | return **class** of backbone network (not instance).
42 | Args
43 | net_name: 'WideResNet' or network names in torchvision.models
44 |         from_name: If True, net_builder takes models from torchvision.models and net_conf is ignored.
45 | net_conf: When from_name is False, net_conf is the configuration of backbone network (now, only WRN is supported).
46 | """
47 | if from_name:
48 | import torchvision.models as models
49 | model_name_list = sorted(name for name in models.__dict__
50 | if name.islower() and not name.startswith("__")
51 | and callable(models.__dict__[name]))
52 |
53 |         if net_name not in model_name_list:
54 |             raise Exception(f"[!] Network name is wrong, check net config, \
55 |                             expected: {model_name_list} \
56 |                             received: {net_name}")
57 |         else:
58 |             return models.__dict__[net_name]
59 |
60 | else:
61 | if net_name == 'WideResNet':
62 | import models.nets.wrn as net
63 | builder = getattr(net, 'build_WideResNet')()
64 | elif net_name == 'WideResNetVar':
65 | import models.nets.wrn_var as net
66 | builder = getattr(net, 'build_WideResNetVar')()
67 | elif net_name == 'ResNet50':
68 | import models.nets.resnet50 as net
69 | builder = getattr(net, 'build_ResNet50')(is_remix)
70 |         else:
71 |             raise NotImplementedError(f"Unknown network name: {net_name}")
72 |
73 | if net_name != 'ResNet50':
74 | setattr_cls_from_kwargs(builder, net_conf)
75 | return builder.build
76 |
77 |
78 | def test_net_builder(net_name, from_name, net_conf=None):
79 | builder = net_builder(net_name, from_name, net_conf)
80 | print(f"net_name: {net_name}, from_name: {from_name}, net_conf: {net_conf}")
81 | print(builder)
82 |
83 |
84 | def get_logger(name, save_path=None, level='INFO'):
85 | logger = logging.getLogger(name)
86 | logger.setLevel(getattr(logging, level))
87 |
88 | log_format = logging.Formatter('[%(asctime)s %(levelname)s] %(message)s')
89 | streamHandler = logging.StreamHandler()
90 | streamHandler.setFormatter(log_format)
91 | logger.addHandler(streamHandler)
92 |
93 |     if save_path is not None:
94 | os.makedirs(save_path, exist_ok=True)
95 | fileHandler = logging.FileHandler(os.path.join(save_path, 'log.txt'))
96 | fileHandler.setFormatter(log_format)
97 | logger.addHandler(fileHandler)
98 |
99 | return logger
100 |
101 |
102 | def count_parameters(model):
103 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
104 |
--------------------------------------------------------------------------------
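
A hedged sketch, run from cifar10_experiment/TorchSSL, of building the WRN-28-2 backbone used by the CIFAR-10 configs through net_builder above (the net_conf keys mirror eval.py):

from utils import count_parameters, net_builder

builder = net_builder('WideResNet', from_name=False,
                      net_conf={'depth': 28, 'widen_factor': 2,
                                'leaky_slope': 0.1, 'dropRate': 0.0,
                                'use_embed': False})
net = builder(num_classes=10)
print(count_parameters(net))  # ~1.5M trainable parameters for WRN-28-2
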
/src/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from logging import getLogger
9 | import torch
10 | import math
11 | from src.utils import AllReduce
12 |
13 |
14 | logger = getLogger()
15 |
16 |
17 | def init_msn_loss(
18 | num_views=1,
19 | tau=0.1,
20 | me_max=True,
21 | return_preds=False
22 | ):
23 | """
24 | Make unsupervised MSN loss
25 |
26 |     :param num_views: number of anchor views
27 | :param tau: cosine similarity temperature
28 | :param me_max: whether to perform me-max regularization
29 | :param return_preds: whether to return anchor predictions
30 | """
31 | softmax = torch.nn.Softmax(dim=1)
32 |
33 | def sharpen(p, T):
34 | sharp_p = p**(1./T)
35 | sharp_p /= torch.sum(sharp_p, dim=1, keepdim=True)
36 | return sharp_p
37 |
38 | def snn(query, supports, support_labels, temp=tau):
39 | """ Soft Nearest Neighbours similarity classifier """
40 | query = torch.nn.functional.normalize(query)
41 | supports = torch.nn.functional.normalize(supports)
42 | return softmax(query @ supports.T / temp) @ support_labels
43 |
44 | def loss(
45 | anchor_views,
46 | target_views,
47 | prototypes,
48 | proto_labels,
49 | T=0.25,
50 | use_entropy=False,
51 | use_sinkhorn=False,
52 | sharpen=sharpen,
53 | snn=snn
54 | ):
55 | # Step 1: compute anchor predictions
56 | probs = snn(anchor_views, prototypes, proto_labels)
57 |
58 | # Step 2: compute targets for anchor predictions
59 | with torch.no_grad():
60 | targets = sharpen(snn(target_views, prototypes, proto_labels), T=T)
61 | if use_sinkhorn:
62 | targets = distributed_sinkhorn(targets)
63 | targets = torch.cat([targets for _ in range(num_views)], dim=0)
64 |
65 | # Step 3: compute cross-entropy loss H(targets, queries)
66 | loss = torch.mean(torch.sum(torch.log(probs**(-targets)), dim=1))
67 |
68 | # Step 4: compute me-max regularizer
69 | rloss = 0.
70 | if me_max:
71 | avg_probs = AllReduce.apply(torch.mean(probs, dim=0))
72 | rloss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs)))
73 |
74 | sloss = 0.
75 | if use_entropy:
76 | sloss = torch.mean(torch.sum(torch.log(probs**(-probs)), dim=1))
77 |
78 | # -- logging
79 | with torch.no_grad():
80 | num_ps = float(len(set(targets.argmax(dim=1).tolist())))
81 | max_t = targets.max(dim=1).values.mean()
82 | min_t = targets.min(dim=1).values.mean()
83 | log_dct = {'np': num_ps, 'max_t': max_t, 'min_t': min_t}
84 |
85 | if return_preds:
86 | return loss, rloss, sloss, log_dct, targets
87 |
88 | return loss, rloss, sloss, log_dct
89 |
90 | return loss
91 |
92 |
93 | @torch.no_grad()
94 | def distributed_sinkhorn(Q, num_itr=3, use_dist=True):
95 | _got_dist = use_dist and torch.distributed.is_available() \
96 | and torch.distributed.is_initialized() \
97 | and (torch.distributed.get_world_size() > 1)
98 |
99 | if _got_dist:
100 | world_size = torch.distributed.get_world_size()
101 | else:
102 | world_size = 1
103 |
104 | Q = Q.T
105 | B = Q.shape[1] * world_size # number of samples to assign
106 | K = Q.shape[0] # how many prototypes
107 |
108 | # make the matrix sums to 1
109 | sum_Q = torch.sum(Q)
110 | if _got_dist:
111 | torch.distributed.all_reduce(sum_Q)
112 | Q /= sum_Q
113 |
114 | for it in range(num_itr):
115 | # normalize each row: total weight per prototype must be 1/K
116 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
117 | if _got_dist:
118 | torch.distributed.all_reduce(sum_of_rows)
119 | Q /= sum_of_rows
120 | Q /= K
121 |
122 | # normalize each column: total weight per sample must be 1/B
123 | Q /= torch.sum(Q, dim=0, keepdim=True)
124 | Q /= B
125 |
126 |     Q *= B  # the columns must sum to 1 so that Q is an assignment
127 | return Q.T
128 |
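A shape-level smoke test for `init_msn_loss` above; this is a sketch with random tensors and assumed dimensions, and it assumes a single-process run (in that case `src.utils.AllReduce` reduces to a no-op, so no distributed setup is needed):

```python
# Sketch: 8 images, 2 anchor views each, 16 prototypes over 10 classes, 32-d embeddings.
import torch
from src.losses import init_msn_loss

msn = init_msn_loss(num_views=2, tau=0.1, me_max=True)
anchor_views = torch.randn(16, 32)    # 8 images x 2 anchor views
target_views = torch.randn(8, 32)     # one target view per image
prototypes = torch.randn(16, 32)      # K = 16 prototypes
proto_labels = torch.softmax(torch.randn(16, 10), dim=1)  # soft labels over 10 classes

loss, rloss, sloss, log_dct = msn(anchor_views, target_views, prototypes, proto_labels, T=0.25)
print(loss.item(), rloss.item(), log_dct)
```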
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/get_labels.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 |
4 | import torch
5 | import numpy as np
6 |
7 | from utils import net_builder
8 | from datasets.ssl_dataset import SSL_Dataset
9 | from datasets.data_utils import get_data_loader
10 | import accelerate
11 |
12 | if __name__ == "__main__":
13 | import argparse
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--load_path', type=str, default='./saved_models/fixmatch/model_best.pth')
16 | parser.add_argument('--use_train_model', action='store_true')
17 |
18 | '''
19 | Backbone Net Configurations
20 | '''
21 | parser.add_argument('--net', type=str, default='WideResNet')
22 |     parser.add_argument('--net_from_name', action='store_true')  # argparse's type=bool treats any non-empty string as True, so use a flag
23 | parser.add_argument('--depth', type=int, default=28)
24 | parser.add_argument('--widen_factor', type=int, default=2)
25 | parser.add_argument('--leaky_slope', type=float, default=0.1)
26 | parser.add_argument('--dropout', type=float, default=0.0)
27 |
28 | '''
29 | Data Configurations
30 | '''
31 | parser.add_argument('--batch_size', type=int, default=256)
32 | parser.add_argument('--data_dir', type=str, default='./data')
33 | parser.add_argument('--dataset', type=str, default='cifar10')
34 | parser.add_argument('--num_classes', type=int, default=10)
35 |
36 | parser.add_argument('--save_path', type=str, default='./saved_models/fixmatch')
37 |
38 | args = parser.parse_args()
39 |
40 | accelerator = accelerate.Accelerator()
41 |
42 | checkpoint_path = os.path.join(args.load_path)
43 | checkpoint = torch.load(checkpoint_path)
44 | load_model = checkpoint['train_model'] if args.use_train_model else checkpoint['ema_model']
45 |
46 | _net_builder = net_builder(args.net,
47 | args.net_from_name,
48 | {'depth': args.depth,
49 | 'widen_factor': args.widen_factor,
50 | 'leaky_slope': args.leaky_slope,
51 | 'dropRate': args.dropout,
52 | 'use_embed': False})
53 |
54 | net = _net_builder(num_classes=args.num_classes)
55 |
56 | _train_dset = SSL_Dataset(args, alg='fullysupervised', name=args.dataset, train=True,
57 | num_classes=args.num_classes, data_dir=args.data_dir)
58 | train_dset = _train_dset.get_label_dset()
59 |
60 |
61 | train_loader = get_data_loader(train_dset,
62 | args.batch_size,
63 | num_workers=1, shuffle=False)
64 |
65 | # net, train_loader = accelerator.prepare(net, train_loader)
66 |
67 | # net.load_state_dict(load_model)
68 | weights_dict = {}
69 | for k, v in load_model.items():
70 | new_k = k.replace('module.', '') if 'module' in k else k
71 | weights_dict[new_k] = v
72 | net.load_state_dict(weights_dict)
73 | # if torch.cuda.is_available():
74 | # net.cuda()
75 | net.eval()
76 |
77 | acc = 0.0
78 | target_list = []
79 | pseudo_label_list = []
80 | with torch.no_grad():
81 | for (_, image, target) in train_loader:
82 | print(_)
83 | image = image.type(torch.FloatTensor)#.cuda()
84 | logit = net(image)
85 | print(logit.shape)
86 |
87 | target, logit = accelerator.gather(target), accelerator.gather(logit)
88 | for x, y in zip(target.cpu(), logit.cpu().max(1)[1]):
89 | target_list.append(x)
90 | pseudo_label_list.append(y)
91 |
92 | acc += logit.cpu().max(1)[1].eq(target.cpu()).sum().numpy()
93 |
94 | accelerator.print(len(train_dset))
95 |
96 | target_list = torch.stack(target_list)
97 | pseudo_label_list = torch.stack(pseudo_label_list)
98 |
99 | np.save(os.path.join(args.save_path, 'pseudo_label.npy'), pseudo_label_list)
100 |
101 | #accelerator.print(pseudo_label_list[:100],'\n', target_list[:100])
102 |
103 | print(pseudo_label_list[:100], '\n', target_list[:100])
104 | print(pseudo_label_list[:100] == target_list[:100])
105 |
106 | result1 = torch.bincount(target_list)
107 | result2 = torch.bincount(pseudo_label_list)
108 |
109 | print(result1)
110 | print(result2)
111 |
112 | accelerator.print(f"Test Accuracy: {acc/len(train_dset)}")
113 |
114 |
115 |
116 |
117 |
118 |
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from torch.utils.data import Dataset
3 | from .data_utils import get_onehot
4 | from .augmentation.randaugment import RandAugment
5 |
6 | import torchvision
7 | from PIL import Image
8 | import numpy as np
9 | import copy
10 |
11 |
12 | class BasicDataset(Dataset):
13 | """
14 |     BasicDataset returns a pair of image and label (target).
15 |     If targets are not given, BasicDataset returns None as the label.
16 |     This class supports strong augmentation for FixMatch,
17 |     and returns both weakly and strongly augmented images.
18 | """
19 |
20 | def __init__(self,
21 | alg,
22 | data,
23 | targets=None,
24 | num_classes=None,
25 | transform=None,
26 | is_ulb=False,
27 | strong_transform=None,
28 | onehot=False,
29 | *args, **kwargs):
30 | """
31 | Args
32 | data: x_data
33 | targets: y_data (if not exist, None)
34 | num_classes: number of label classes
35 | transform: basic transformation of data
36 |             is_ulb: If True, the dataset is treated as unlabeled and returns both weakly and strongly augmented images.
37 | strong_transform: list of transformation functions for strong augmentation
38 | onehot: If True, label is converted into onehot vector.
39 | """
40 | super(BasicDataset, self).__init__()
41 | self.alg = alg
42 | self.data = data
43 | self.targets = targets
44 |
45 | self.num_classes = num_classes
46 | self.is_ulb = is_ulb
47 | self.onehot = onehot
48 |
49 | self.transform = transform
50 | if self.is_ulb:
51 | if strong_transform is None:
52 | self.strong_transform = copy.deepcopy(transform)
53 | self.strong_transform.transforms.insert(0, RandAugment(3, 5))
54 | else:
55 | self.strong_transform = strong_transform
56 |
57 | def __getitem__(self, idx):
58 | """
59 | If strong augmentation is not used,
60 | return weak_augment_image, target
61 | else:
62 | return weak_augment_image, strong_augment_image, target
63 | """
64 |
65 | # set idx-th target
66 | if self.targets is None:
67 | target = None
68 | else:
69 | target_ = self.targets[idx]
70 | target = target_ if not self.onehot else get_onehot(self.num_classes, target_)
71 |
72 | # set augmented images
73 |
74 | img = self.data[idx]
75 | if self.transform is None:
76 | return transforms.ToTensor()(img), target
77 | else:
78 | if isinstance(img, np.ndarray):
79 | img = Image.fromarray(img)
80 | img_w = self.transform(img)
81 | if not self.is_ulb:
82 | return idx, img_w, target
83 | else:
84 | if self.alg == 'fixmatch':
85 | return idx, img_w, self.strong_transform(img)
86 | elif self.alg == 'flexmatch':
87 | return idx, img_w, self.strong_transform(img)
88 | elif self.alg == 'softmatch' or self.alg == 'freematch' or self.alg == 'freematch_entropy':
89 | return idx, img_w, self.strong_transform(img)
90 | elif self.alg == 'pimodel':
91 | return idx, img_w, self.transform(img)
92 | elif self.alg == 'pseudolabel':
93 | return idx, img_w
94 | elif self.alg == 'vat':
95 | return idx, img_w
96 | elif self.alg == 'meanteacher':
97 | return idx, img_w, self.transform(img)
98 | elif self.alg == 'uda':
99 | return idx, img_w, self.strong_transform(img)
100 | elif self.alg == 'mixmatch':
101 | return idx, img_w, self.transform(img)
102 | elif self.alg == 'remixmatch':
103 | rotate_v_list = [0, 90, 180, 270]
104 | rotate_v1 = np.random.choice(rotate_v_list, 1).item()
105 | img_s1 = self.strong_transform(img)
106 | img_s1_rot = torchvision.transforms.functional.rotate(img_s1, rotate_v1)
107 | img_s2 = self.strong_transform(img)
108 | return idx, img_w, img_s1, img_s2, img_s1_rot, rotate_v_list.index(rotate_v1)
109 | elif self.alg == 'fullysupervised':
110 | return idx
111 |
112 | def __len__(self):
113 | return len(self.data)
114 |
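A sketch of the unlabeled branch above, assuming it runs from the TorchSSL directory: for `alg='freematch'` with `is_ulb=True`, `__getitem__` returns the index plus a weak/strong augmentation pair.

```python
import numpy as np
from torchvision import transforms
from datasets.dataset import BasicDataset

# Four random CIFAR-sized images; no targets, so the label is None.
data = np.random.randint(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
weak = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])
dset = BasicDataset('freematch', data, targets=None, num_classes=10,
                    transform=weak, is_ulb=True)  # strong_transform defaults to weak + RandAugment(3, 5)

idx, img_w, img_s = dset[0]
print(idx, img_w.shape, img_s.shape)  # 0 torch.Size([3, 32, 32]) torch.Size([3, 32, 32])
```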
--------------------------------------------------------------------------------
/libs/timm.py:
--------------------------------------------------------------------------------
1 | # code from timm 0.3.2
2 | import torch
3 | import torch.nn as nn
4 | import math
5 | import warnings
6 |
7 |
8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
11 | def norm_cdf(x):
12 | # Computes standard normal cumulative distribution function
13 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
14 |
15 | if (mean < a - 2 * std) or (mean > b + 2 * std):
16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
17 | "The distribution of values may be incorrect.",
18 | stacklevel=2)
19 |
20 | with torch.no_grad():
21 | # Values are generated by using a truncated uniform distribution and
22 | # then using the inverse CDF for the normal distribution.
23 | # Get upper and lower cdf values
24 | l = norm_cdf((a - mean) / std)
25 | u = norm_cdf((b - mean) / std)
26 |
27 | # Uniformly fill tensor with values from [l, u], then translate to
28 | # [2l-1, 2u-1].
29 | tensor.uniform_(2 * l - 1, 2 * u - 1)
30 |
31 | # Use inverse cdf transform for normal distribution to get truncated
32 | # standard normal
33 | tensor.erfinv_()
34 |
35 | # Transform to proper mean, std
36 | tensor.mul_(std * math.sqrt(2.))
37 | tensor.add_(mean)
38 |
39 | # Clamp to ensure it's in the proper range
40 | tensor.clamp_(min=a, max=b)
41 | return tensor
42 |
43 |
44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
45 | # type: (Tensor, float, float, float, float) -> Tensor
46 | r"""Fills the input Tensor with values drawn from a truncated
47 | normal distribution. The values are effectively drawn from the
48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
49 | with values outside :math:`[a, b]` redrawn until they are within
50 | the bounds. The method used for generating the random values works
51 | best when :math:`a \leq \text{mean} \leq b`.
52 | Args:
53 | tensor: an n-dimensional `torch.Tensor`
54 | mean: the mean of the normal distribution
55 | std: the standard deviation of the normal distribution
56 | a: the minimum cutoff value
57 | b: the maximum cutoff value
58 | Examples:
59 | >>> w = torch.empty(3, 5)
60 | >>> nn.init.trunc_normal_(w)
61 | """
62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
63 |
64 |
65 | def drop_path(x, drop_prob: float = 0., training: bool = False):
66 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
67 |
68 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
69 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
70 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
71 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
72 | 'survival rate' as the argument.
73 |
74 | """
75 | if drop_prob == 0. or not training:
76 | return x
77 | keep_prob = 1 - drop_prob
78 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
79 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
80 | random_tensor.floor_() # binarize
81 | output = x.div(keep_prob) * random_tensor
82 | return output
83 |
84 |
85 | class DropPath(nn.Module):
86 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
87 | """
88 | def __init__(self, drop_prob=None):
89 | super(DropPath, self).__init__()
90 | self.drop_prob = drop_prob
91 |
92 | def forward(self, x):
93 | return drop_path(x, self.drop_prob, self.training)
94 |
95 |
96 | class Mlp(nn.Module):
97 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
98 | super().__init__()
99 | out_features = out_features or in_features
100 | hidden_features = hidden_features or in_features
101 | self.fc1 = nn.Linear(in_features, hidden_features)
102 | self.act = act_layer()
103 | self.fc2 = nn.Linear(hidden_features, out_features)
104 | self.drop = nn.Dropout(drop)
105 |
106 | def forward(self, x):
107 | x = self.fc1(x)
108 | x = self.act(x)
109 | x = self.drop(x)
110 | x = self.fc2(x)
111 | x = self.drop(x)
112 | return x
113 |
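A short sketch tying the three helpers above together; the dimensions are made up, and the ViT-style `std=0.02` init is a common convention rather than something this file mandates:

```python
import torch
from libs.timm import trunc_normal_, DropPath, Mlp

proj = torch.nn.Linear(64, 64)
trunc_normal_(proj.weight, std=0.02)   # truncated normal init, values redrawn into [a, b] = [-2, 2]

mlp = Mlp(in_features=64, hidden_features=256)
drop_path = DropPath(drop_prob=0.1)    # randomly skips the residual branch per sample during training
x = torch.randn(8, 197, 64)            # (batch, tokens, dim)
out = x + drop_path(mlp(x))            # typical transformer residual branch
print(out.shape)                       # torch.Size([8, 197, 64])
```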
--------------------------------------------------------------------------------
/grid_sample.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 | import torch
3 | from torch import multiprocessing as mp
4 | import accelerate
5 | import utils
6 | import sde
7 | from datasets import get_dataset
8 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
9 | from absl import logging
10 | import einops
11 | from torchvision.utils import save_image, make_grid
12 |
13 |
14 | def evaluate(config):
15 | if config.get('benchmark', False):
16 | torch.backends.cudnn.benchmark = True
17 | torch.backends.cudnn.deterministic = False
18 |
19 | mp.set_start_method('spawn')
20 | accelerator = accelerate.Accelerator()
21 | device = accelerator.device
22 | accelerate.utils.set_seed(config.seed, device_specific=True)
23 |
24 | config.mixed_precision = accelerator.mixed_precision
25 | config = ml_collections.FrozenConfigDict(config)
26 | utils.set_logger(log_level='info')
27 |
28 | dataset = get_dataset(**config.dataset)
29 |
30 | nnet = utils.get_nnet(**config.nnet)
31 | nnet = accelerator.prepare(nnet)
32 | logging.info(f'load nnet from {config.nnet_path}')
33 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
34 | nnet.eval()
35 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance
36 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}')
37 | def cfg_nnet(x, timesteps, y):
38 | _cond = nnet(x, timesteps, y=y)
39 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device))
40 | return _cond + config.sample.scale * (_cond - _uncond)
41 | score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE())
42 | else:
43 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE())
44 |
45 | logging.info(config.sample)
46 | logging.info(f'mode={config.train.mode}, mixed_precision={config.mixed_precision}')
47 |
48 | def sample_fn(_x_init, _kwargs):
49 | if config.sample.algorithm == 'euler_maruyama_sde':
50 | rsde = sde.ReverseSDE(score_model)
51 | _samples = sde.euler_maruyama(rsde, _x_init, config.sample.sample_steps,
52 | verbose=accelerator.is_main_process, **_kwargs)
53 | elif config.sample.algorithm == 'euler_maruyama_ode':
54 | rsde = sde.ODE(score_model)
55 | _samples = sde.euler_maruyama(rsde, _x_init, config.sample.sample_steps,
56 | verbose=accelerator.is_main_process, **_kwargs)
57 | elif config.sample.algorithm == 'dpm_solver':
58 | noise_schedule = NoiseScheduleVP(schedule='linear')
59 | model_fn = model_wrapper(
60 | score_model.noise_pred,
61 | noise_schedule,
62 | time_input_type='0',
63 | model_kwargs=_kwargs
64 | )
65 | dpm_solver = DPM_Solver(model_fn, noise_schedule)
66 | _samples = dpm_solver.sample(
67 | _x_init,
68 | steps=config.sample.sample_steps,
69 | eps=1e-4,
70 | adaptive_step_size=False,
71 | fast_version=True,
72 | )
73 | else:
74 | raise NotImplementedError
75 |
76 | return _samples
77 |
78 | if config.train.mode == 'uncond':
79 | x_init = torch.randn(100, *dataset.data_shape, device=device)
80 | kwargs = dict()
81 | idx = 0
82 | samples = []
83 | for _batch_size in utils.amortize(len(x_init), config.sample.mini_batch_size):
84 | samples.append(sample_fn(x_init[idx: idx + _batch_size], kwargs))
85 | idx += _batch_size
86 |
87 | elif config.train.mode == 'cond':
88 | x_init = torch.randn(config.nnet.num_classes * 10, *dataset.data_shape, device=device)
89 | y = einops.repeat(torch.arange(config.nnet.num_classes, device=device), 'nrow -> (nrow ncol)', ncol=10)
90 | idx = 0
91 | samples = []
92 | for _batch_size in utils.amortize(len(x_init), config.sample.mini_batch_size):
93 | samples.append(sample_fn(x_init[idx: idx + _batch_size], dict(y=y[idx: idx + _batch_size])))
94 | idx += _batch_size
95 |
96 | else:
97 | raise NotImplementedError
98 |
99 | samples = torch.cat(samples, dim=0)
100 | samples = dataset.unpreprocess(samples)
101 | save_image(make_grid(samples, 10), config.output_path)
102 |
103 |
104 |
105 | from absl import flags
106 | from absl import app
107 | from ml_collections import config_flags
108 |
109 |
110 | FLAGS = flags.FLAGS
111 | config_flags.DEFINE_config_file("config", None, "Training configuration.", lock_config=False)
112 | flags.mark_flags_as_required(["config"])
113 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
114 | flags.DEFINE_string("output_path", None, "The path to output samples.")
115 |
116 |
117 | def main(argv):
118 | config = FLAGS.config
119 | config.nnet_path = FLAGS.nnet_path
120 | config.output_path = FLAGS.output_path
121 | evaluate(config)
122 |
123 |
124 | if __name__ == "__main__":
125 | app.run(main)
126 |
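The `cfg_nnet` closure above implements classifier-free guidance. A toy sketch of the combination rule on dummy tensors (the scale value here is illustrative only):

```python
# Sketch: guided = cond + scale * (cond - uncond), as in cfg_nnet above.
import torch

cond = torch.randn(4, 3, 32, 32)    # conditional model output
uncond = torch.randn(4, 3, 32, 32)  # unconditional output (label set to dataset.K)
scale = 0.4                         # illustrative guidance scale
guided = cond + scale * (cond - uncond)
print(guided.shape)                 # torch.Size([4, 3, 32, 32])
```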
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/datasets/augmentation/randaugment.py:
--------------------------------------------------------------------------------
1 | # copyright: https://github.com/ildoonet/pytorch-randaugment
2 | # code in this file is adapted from rpmcruz/autoaugment
3 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
4 | # This code is a modified version of ildoonet's, adapted for the RandAugment used in FixMatch.
5 |
6 | import random
7 |
8 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
9 | import numpy as np
10 | import torch
11 | import torch.nn.functional as F
12 | from PIL import Image
13 |
14 |
15 | def AutoContrast(img, _):
16 | return PIL.ImageOps.autocontrast(img)
17 |
18 |
19 | def Brightness(img, v):
20 | assert v >= 0.0
21 | return PIL.ImageEnhance.Brightness(img).enhance(v)
22 |
23 |
24 | def Color(img, v):
25 | assert v >= 0.0
26 | return PIL.ImageEnhance.Color(img).enhance(v)
27 |
28 |
29 | def Contrast(img, v):
30 | assert v >= 0.0
31 | return PIL.ImageEnhance.Contrast(img).enhance(v)
32 |
33 |
34 | def Equalize(img, _):
35 | return PIL.ImageOps.equalize(img)
36 |
37 |
38 | def Invert(img, _):
39 | return PIL.ImageOps.invert(img)
40 |
41 |
42 | def Identity(img, v):
43 | return img
44 |
45 |
46 | def Posterize(img, v): # [4, 8]
47 | v = int(v)
48 | v = max(1, v)
49 | return PIL.ImageOps.posterize(img, v)
50 |
51 |
52 | def Rotate(img, v): # [-30, 30]
53 | #assert -30 <= v <= 30
54 | #if random.random() > 0.5:
55 | # v = -v
56 | return img.rotate(v)
57 |
58 |
59 |
60 | def Sharpness(img, v): # [0.1,1.9]
61 | assert v >= 0.0
62 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
63 |
64 |
65 | def ShearX(img, v): # [-0.3, 0.3]
66 | #assert -0.3 <= v <= 0.3
67 | #if random.random() > 0.5:
68 | # v = -v
69 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
70 |
71 |
72 | def ShearY(img, v): # [-0.3, 0.3]
73 | #assert -0.3 <= v <= 0.3
74 | #if random.random() > 0.5:
75 | # v = -v
76 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
77 |
78 |
79 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
80 | #assert -0.3 <= v <= 0.3
81 | #if random.random() > 0.5:
82 | # v = -v
83 | v = v * img.size[0]
84 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
85 |
86 |
87 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
88 | #assert v >= 0.0
89 | #if random.random() > 0.5:
90 | # v = -v
91 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
92 |
93 |
94 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
95 | #assert -0.3 <= v <= 0.3
96 | #if random.random() > 0.5:
97 | # v = -v
98 | v = v * img.size[1]
99 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
100 |
101 |
102 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
103 | #assert 0 <= v
104 | #if random.random() > 0.5:
105 | # v = -v
106 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
107 |
108 |
109 | def Solarize(img, v): # [0, 256]
110 | assert 0 <= v <= 256
111 | return PIL.ImageOps.solarize(img, v)
112 |
113 |
114 | def Cutout(img, v): #[0, 60] => percentage: [0, 0.2] => change to [0, 0.5]
115 | assert 0.0 <= v <= 0.5
116 | if v <= 0.:
117 | return img
118 |
119 | v = v * img.size[0]
120 | return CutoutAbs(img, v)
121 |
122 |
123 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
124 | # assert 0 <= v <= 20
125 | if v < 0:
126 | return img
127 | w, h = img.size
128 | x0 = np.random.uniform(w)
129 | y0 = np.random.uniform(h)
130 |
131 | x0 = int(max(0, x0 - v / 2.))
132 | y0 = int(max(0, y0 - v / 2.))
133 | x1 = min(w, x0 + v)
134 | y1 = min(h, y0 + v)
135 |
136 | xy = (x0, y0, x1, y1)
137 | color = (125, 123, 114)
138 | # color = (0, 0, 0)
139 | img = img.copy()
140 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
141 | return img
142 |
143 |
144 | def augment_list():
145 | l = [
146 | (AutoContrast, 0, 1),
147 | (Brightness, 0.05, 0.95),
148 | (Color, 0.05, 0.95),
149 | (Contrast, 0.05, 0.95),
150 | (Equalize, 0, 1),
151 | (Identity, 0, 1),
152 | (Posterize, 4, 8),
153 | (Rotate, -30, 30),
154 | (Sharpness, 0.05, 0.95),
155 | (ShearX, -0.3, 0.3),
156 | (ShearY, -0.3, 0.3),
157 | (Solarize, 0, 256),
158 | (TranslateX, -0.3, 0.3),
159 | (TranslateY, -0.3, 0.3)
160 | ]
161 | return l
162 |
163 |
164 | class RandAugment:
165 | def __init__(self, n, m):
166 | self.n = n
167 | self.m = m # [0, 30] in fixmatch, deprecated.
168 | self.augment_list = augment_list()
169 |
170 |
171 | def __call__(self, img):
172 | ops = random.choices(self.augment_list, k=self.n)
173 | for op, min_val, max_val in ops:
174 | val = min_val + float(max_val - min_val)*random.random()
175 | img = op(img, val)
176 | cutout_val = random.random() * 0.5
177 | img = Cutout(img, cutout_val) #for fixmatch
178 | return img
179 |
180 |
181 | if __name__ == '__main__':
182 | # randaug = RandAugment(3,5)
183 | # print(randaug)
184 | # for item in randaug.augment_list:
185 | # print(item)
186 | import os
187 |
188 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
189 | img = PIL.Image.open('./u.jpg')
190 | randaug = RandAugment(3,6)
191 | img = randaug(img)
192 | import matplotlib
193 | from matplotlib import pyplot as plt
194 | plt.imshow(img)
195 | plt.show()
196 |
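A quick sketch of the augmentation policy above. Note that `m` is kept for interface compatibility but unused: `__call__` samples each op's magnitude uniformly from its `[min_val, max_val]` range.

```python
import numpy as np
from PIL import Image
from datasets.augmentation.randaugment import RandAugment, augment_list

# Inspect the op table: (op, min_val, max_val) triples.
for op, min_val, max_val in augment_list():
    print(op.__name__, min_val, max_val)

img = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
out = RandAugment(3, 5)(img)  # 3 random ops, then Cutout with v in [0, 0.5)
print(out.size)               # (32, 32)
```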
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/datasets/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torchvision import datasets
4 | from torch.utils.data import sampler, DataLoader
5 | from torch.utils.data.sampler import BatchSampler
6 | import torch.distributed as dist
7 | import numpy as np
8 | import json
9 | import os
10 |
11 | from datasets.DistributedProxySampler import DistributedProxySampler
12 |
13 |
14 | def split_ssl_data(args, data, target, num_labels, num_classes, index=None, include_lb_to_ulb=True):
15 | """
16 |     data & target are split into labeled and unlabeled data.
17 |
18 | Args
19 | index: If np.array of index is given, select the data[index], target[index] as labeled samples.
20 |         include_lb_to_ulb: If True, labeled data is also included in unlabeled data
21 | """
22 | data, target = np.array(data), np.array(target)
23 |     lb_data, lbs, lb_idx = sample_labeled_data(args, data, target, num_labels, num_classes, index)
24 | ulb_idx = np.array(sorted(list(set(range(len(data))) - set(lb_idx)))) # unlabeled_data index of data
25 | if include_lb_to_ulb:
26 | return lb_data, lbs, data, target
27 | else:
28 | return lb_data, lbs, data[ulb_idx], target[ulb_idx]
29 |
30 |
31 | def sample_labeled_data(args, data, target,
32 | num_labels, num_classes,
33 | index=None, name=None):
34 | '''
35 |     sample labeled data
36 |     (sampling with a balanced ratio over classes)
37 | '''
38 | assert num_labels % num_classes == 0
39 |     if index is not None:
40 | index = np.array(index, dtype=np.int32)
41 | return data[index], target[index], index
42 |
43 | dump_path = os.path.join(args.save_dir, args.save_name, 'sampled_label_idx.npy')
44 |
45 | if os.path.exists(dump_path):
46 | lb_idx = np.load(dump_path)
47 | lb_data = data[lb_idx]
48 | lbs = target[lb_idx]
49 | return lb_data, lbs, lb_idx
50 |
51 | samples_per_class = int(num_labels / num_classes)
52 |
53 | lb_data = []
54 | lbs = []
55 | lb_idx = []
56 | for c in range(num_classes):
57 | idx = np.where(target == c)[0]
58 | idx = np.random.choice(idx, samples_per_class, False)
59 | lb_idx.extend(idx)
60 |
61 | lb_data.extend(data[idx])
62 | lbs.extend(target[idx])
63 |
64 | np.save(dump_path, np.array(lb_idx))
65 |
66 | return np.array(lb_data), np.array(lbs), np.array(lb_idx)
67 |
68 |
69 | def get_sampler_by_name(name):
70 | '''
71 | get sampler in torch.utils.data.sampler by name
72 | '''
73 | sampler_name_list = sorted(name for name in torch.utils.data.sampler.__dict__
74 | if not name.startswith('_') and callable(sampler.__dict__[name]))
75 | try:
76 | if name == 'DistributedSampler':
77 | return torch.utils.data.distributed.DistributedSampler
78 | else:
79 | return getattr(torch.utils.data.sampler, name)
80 | except Exception as e:
81 | print(repr(e))
82 | print('[!] select sampler in:\t', sampler_name_list)
83 |
84 |
85 | def get_data_loader(dset,
86 | batch_size=None,
87 | shuffle=False,
88 | num_workers=4,
89 | pin_memory=False,
90 | data_sampler=None,
91 | replacement=True,
92 | num_epochs=None,
93 | num_iters=None,
94 | generator=None,
95 | drop_last=True,
96 | distributed=False):
97 | """
98 | get_data_loader returns torch.utils.data.DataLoader for a Dataset.
99 |     All arguments are compatible with those of the PyTorch DataLoader.
100 | However, if distributed, DistributedProxySampler, which is a wrapper of data_sampler, is used.
101 |
102 | Args
103 | num_epochs: total batch -> (# of batches in dset) * num_epochs
104 | num_iters: total batch -> num_iters
105 | """
106 |
107 | assert batch_size is not None
108 |
109 | if data_sampler is None:
110 | return DataLoader(dset, batch_size=batch_size, shuffle=shuffle,
111 | num_workers=num_workers, pin_memory=pin_memory)
112 |
113 | else:
114 | if isinstance(data_sampler, str):
115 | data_sampler = get_sampler_by_name(data_sampler)
116 |
117 | if distributed:
118 | assert dist.is_available()
119 | num_replicas = dist.get_world_size()
120 | else:
121 | num_replicas = 1
122 |
123 | if (num_epochs is not None) and (num_iters is None):
124 | num_samples = len(dset) * num_epochs
125 | elif (num_epochs is None) and (num_iters is not None):
126 | num_samples = batch_size * num_iters * num_replicas
127 | else:
128 | num_samples = len(dset)
129 |
130 | if data_sampler.__name__ == 'RandomSampler':
131 | data_sampler = data_sampler(dset, replacement, num_samples, generator)
132 | else:
133 | raise RuntimeError(f"{data_sampler.__name__} is not implemented.")
134 |
135 | if distributed:
136 | '''
137 |             Unlike DistributedSampler,
138 |             DistributedProxySampler does not shuffle the data (it is just a wrapper for dist).
139 | '''
140 | data_sampler = DistributedProxySampler(data_sampler)
141 |
142 | batch_sampler = BatchSampler(data_sampler, batch_size, drop_last)
143 | return DataLoader(dset, batch_sampler=batch_sampler,
144 | num_workers=num_workers, pin_memory=pin_memory)
145 |
146 |
147 | def get_onehot(num_classes, idx):
148 | onehot = np.zeros([num_classes], dtype=np.float32)
149 | onehot[idx] += 1.0
150 | return onehot
151 |
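A sketch of the iteration-based loader path above, on a toy `TensorDataset`: with `num_iters` set and `distributed=False`, the `RandomSampler` draws `batch_size * num_iters` samples with replacement, giving exactly `num_iters` batches.

```python
import torch
from torch.utils.data import TensorDataset
from datasets.data_utils import get_data_loader

dset = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))
loader = get_data_loader(dset, batch_size=16, data_sampler='RandomSampler',
                         num_iters=50, distributed=False)
print(len(loader))  # 50
```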
--------------------------------------------------------------------------------
/cifar10_experiment/README.md:
--------------------------------------------------------------------------------
1 | ## Diffusion Models and Semi-Supervised Learners Benefit Mutually with Few Labels
2 | Official PyTorch implementation of the NeurIPS 2023 paper:
3 | **Diffusion Models and Semi-Supervised Learners Benefit Mutually with Few Labels**
4 | You, Zebin and Zhong, Yong and Bao, Fan and Sun, Jiacheng and Li, Chongxuan and Zhu, Jun
5 | https://arxiv.org/abs/2302.10586
6 |
7 | Abstract: *In an effort to further advance semi-supervised generative and classification tasks, we propose a simple yet effective training strategy called dual pseudo training (DPT), built upon strong semi-supervised learners and diffusion models. DPT operates in three stages: training a classifier on partially labeled data to predict pseudo-labels; training a conditional generative model using these pseudo-labels to generate pseudo images; and retraining the classifier with a mix of real and pseudo images. Empirically, DPT consistently achieves SOTA performance of semi-supervised generation and classification across various settings. In particular, with one or two labels per class, DPT achieves a Fréchet Inception Distance (FID) score of 3.08 or 2.52 on ImageNet 256x256. Besides, DPT outperforms competitive semi-supervised baselines substantially on ImageNet classification tasks, achieving top-1 accuracies of 59.0 (+2.8), 69.5 (+3.0), and 74.4 (+2.0) with one, two, or five labels per class, respectively. Notably, our results demonstrate that diffusion can generate realistic images with only a few labels (e.g., <0.1%) and generative augmentation remains viable for semi-supervised classification.*
8 |
9 | ## Requirements
10 | ```.bash
11 | conda create -n dpt_cifar python==3.9
12 | conda activate dpt_cifar
13 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
14 | pip install tensorboard
15 | pip install accelerate
16 | pip install --upgrade scikit-learn
17 | pip install pandas
18 | pip install click
19 | pip install Pillow==9.3.0
20 | ```
21 |
22 | ## Get Started
23 | ### Stage 1: Training the Classifier and Generating Pseudo-Labels
24 | To begin, navigate to the TorchSSL directory and run the following commands to train a classifier, generate pseudo-labels, and prepare the dataset required for EDM.
25 | ```.bash
26 | # Stage 1: Training the Classifier and Generating Pseudo-Labels
27 | # After training, a high-quality classifier will be saved in the directory saved_models/freematch_cifar10_40_1/. We can then use this classifier to obtain pseudo-labels.
28 | python freematch.py --c config/freematch_cifar10_40_1.yaml
29 |
30 | # Generating Pseudo-Labels
31 | python get_labels.py --load_path ./saved_models/freematch_cifar10_40_1/model_best.pth --save_path ./saved_models/freematch_cifar10_40_1/ >> saved_models/freematch_cifar10_40_1/get_labels.log
32 |
33 | # Using pseudo-labels, we create a new dataset, which will be used to train our generative model
34 | # Creating a Pseudo-Dataset
35 | python pseudo_dataset.py --data_path ./data/cifar-10-batches-py --save_path ./saved_models/freematch_cifar10_40_1
36 | ```
37 | The purpose of the above steps is to generate a new dataset, similar to CIFAR-10 but with pseudo-labels instead of true labels.
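As a quick sanity check, you can compare the saved pseudo-labels against the true CIFAR-10 training labels. This is a minimal sketch, assuming the default paths from the commands above and that `get_labels.py` writes `pseudo_label.npy` in training-set order (its loader uses `shuffle=False`):
```.python
import numpy as np
from torchvision.datasets import CIFAR10

pseudo = np.load('./saved_models/freematch_cifar10_40_1/pseudo_label.npy')
labels = np.array(CIFAR10('./data', train=True, download=True).targets)
print('pseudo-label accuracy:', (pseudo == labels).mean())
```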
38 |
39 | ```.bash
40 | # Preparing Pseudo Dataset for Training EDM
41 | python dataset_tool.py --source=./saved_models/freematch_cifar10_40_1/cifar-10-python.tar.gz --dest=./saved_models/freematch_cifar10_40_1/freematch-40-seed1-cifar10-32x32.zip
42 | ```
43 | This step converts the dataset into the format required by EDM.
44 |
45 | ### Stage 2: Training the Generative Model and Generating Pseudo-Images
46 | Following the dataset preparation, navigate to the EDM directory and execute the following commands for the second stage:
47 | ```.bash
48 | # Stage 2: Training the Generative Model
49 | torchrun --standalone --nproc_per_node=4 train.py --outdir=training-runs \
50 | --data=../TorchSSL/saved_models/freematch_cifar10_40_1/freematch-40-seed1-cifar10-32x32.zip --cond=1 --arch=ddpmpp --batch-gpu=32
51 | ```
52 | This command trains the generative model in the second stage of the process. In the above command:
53 | - The **`--data`** flag specifies the dataset created earlier.
54 | - The **`--cond=1`** flag enables conditional generation.
55 | - The **`--arch=ddpmpp`** flag selects the ddpmpp architecture.
56 | - The **`--batch-gpu=32`** flag sets the batch size per GPU to 32.
57 |
58 |
59 | Assuming the EDM training has been completed, and the latest generative model is saved as **'./training-runs/00000-freematch-40-seed1-cifar10-32x32-cond-ddpmpp-edm-gpus4-batch512-fp32/network-snapshot-latest.pkl'**, you can use the following command to generate pseudo-images:
60 | ```.bash
61 | bash generate.sh 1000 ../TorchSSL/data-aug/nums_1000 ./training-runs/00000-freematch-40-seed1-cifar10-32x32-cond-ddpmpp-edm-gpus4-batch512-fp32/network-snapshot-latest.pkl 4
62 | ```
63 | In the above command:
64 | - The first argument, **`1000`**, generates 1001 samples for each class, ranging from 0 to 1000.
65 | - The second argument, **`../TorchSSL/data-aug/nums_1000`**, represents the path where the generated samples will be saved.
66 | - The third argument, **`./training-runs/00000-freematch-40-seed1-cifar10-32x32-cond-ddpmpp-edm-gpus4-batch512-fp32/network-snapshot-latest.pkl`**, is the path to the generative model you want to use for generation.
67 | - The fourth argument, **`4`**, specifies the number of GPUs to use for the generation process.
68 |
69 | ### Stage 3: Training the Classifier with Pseudo-Images
70 | Finally, we can train the classifier with the pseudo-images generated in the previous step. To do so, navigate to the TorchSSL directory and execute the following commands:
71 | ```.bash
72 | # stage3
73 | python freematch.py --c config/freematch_cifar10_40_1_aug1000.yaml --aug_path=./data-aug/nums_1000
74 | ```
75 |
76 | ## Questions
77 | If you have any questions, please feel free to contact us via email: zebin@ruc.edu.cn
78 |
79 | ## Citation
80 |
81 | ```
82 | @inproceedings{you2023diffusion,
83 | author = {You, Zebin and Zhong, Yong and Bao, Fan and Sun, Jiacheng and Li, Chongxuan and Zhu, Jun},
84 | title = {Diffusion models and semi-supervised learners benefit mutually with few labels},
85 | booktitle = {Proc. NeurIPS},
86 | year = {2023}
87 | }
88 | ```
89 |
90 | ## Acknowledgments
91 |
92 | We would like to express our gratitude to the remarkable projects EDM (https://github.com/NVlabs/edm) and TorchSSL (https://github.com/TorchSSL/TorchSSL). Our work is built upon their contributions.
93 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import os
5 | from tqdm import tqdm
6 | from torchvision.utils import save_image
7 | from absl import logging
8 |
9 |
10 | def set_logger(log_level='info', fname=None):
11 | import logging as _logging
12 | handler = logging.get_absl_handler()
13 | formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s')
14 | handler.setFormatter(formatter)
15 | logging.set_verbosity(log_level)
16 | if fname is not None:
17 | handler = _logging.FileHandler(fname)
18 | handler.setFormatter(formatter)
19 | logging.get_absl_logger().addHandler(handler)
20 |
21 |
22 | def dct2str(dct):
23 | return str({k: f'{v:.6g}' for k, v in dct.items()})
24 |
25 |
26 | def get_nnet(name, **kwargs):
27 | if name == 'uvit':
28 | from libs.uvit import UViT
29 | return UViT(**kwargs)
30 | elif name == 'uvit_t2i':
31 | from libs.uvit_t2i import UViT
32 | return UViT(**kwargs)
33 | else:
34 | raise NotImplementedError(name)
35 |
36 |
37 | def set_seed(seed: int):
38 | if seed is not None:
39 | torch.manual_seed(seed)
40 | np.random.seed(seed)
41 |
42 |
43 | def get_optimizer(params, name, **kwargs):
44 | if name == 'adam':
45 | from torch.optim import Adam
46 | return Adam(params, **kwargs)
47 | elif name == 'adamw':
48 | from torch.optim import AdamW
49 | return AdamW(params, **kwargs)
50 | else:
51 | raise NotImplementedError(name)
52 |
53 |
54 | def customized_lr_scheduler(optimizer, warmup_steps=-1):
55 | from torch.optim.lr_scheduler import LambdaLR
56 | def fn(step):
57 | if warmup_steps > 0:
58 | return min(step / warmup_steps, 1)
59 | else:
60 | return 1
61 | return LambdaLR(optimizer, fn)
62 |
63 |
64 | def get_lr_scheduler(optimizer, name, **kwargs):
65 | if name == 'customized':
66 | return customized_lr_scheduler(optimizer, **kwargs)
67 | elif name == 'cosine':
68 | from torch.optim.lr_scheduler import CosineAnnealingLR
69 | return CosineAnnealingLR(optimizer, **kwargs)
70 | else:
71 | raise NotImplementedError(name)
72 |
73 |
74 | def ema(model_dest: nn.Module, model_src: nn.Module, rate):
75 | param_dict_src = dict(model_src.named_parameters())
76 | for p_name, p_dest in model_dest.named_parameters():
77 | p_src = param_dict_src[p_name]
78 | assert p_src is not p_dest
79 | p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
80 |
81 |
82 | class TrainState(object):
83 | def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None):
84 | self.optimizer = optimizer
85 | self.lr_scheduler = lr_scheduler
86 | self.step = step
87 | self.nnet = nnet
88 | self.nnet_ema = nnet_ema
89 |
90 | def ema_update(self, rate=0.9999):
91 | if self.nnet_ema is not None:
92 | ema(self.nnet_ema, self.nnet, rate)
93 |
94 | def save(self, path):
95 | os.makedirs(path, exist_ok=True)
96 | torch.save(self.step, os.path.join(path, 'step.pth'))
97 | for key, val in self.__dict__.items():
98 | if key != 'step' and val is not None:
99 | torch.save(val.state_dict(), os.path.join(path, f'{key}.pth'))
100 |
101 | def load(self, path):
102 | logging.info(f'load from {path}')
103 | self.step = torch.load(os.path.join(path, 'step.pth'))
104 | for key, val in self.__dict__.items():
105 | if key != 'step' and val is not None:
106 | val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu'))
107 |
108 | def resume(self, ckpt_root, step=None):
109 | if not os.path.exists(ckpt_root):
110 | return
111 | if step is None:
112 | ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root)))
113 | if not ckpts:
114 | return
115 | steps = map(lambda x: int(x.split(".")[0]), ckpts)
116 | step = max(steps)
117 | ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt')
118 | logging.info(f'resume from {ckpt_path}')
119 | self.load(ckpt_path)
120 |
121 | def to(self, device):
122 | for key, val in self.__dict__.items():
123 | if isinstance(val, nn.Module):
124 | val.to(device)
125 |
126 |
127 | def cnt_params(model):
128 | return sum(param.numel() for param in model.parameters())
129 |
130 |
131 | def initialize_train_state(config, device):
132 | params = []
133 |
134 | nnet = get_nnet(**config.nnet)
135 | params += nnet.parameters()
136 | nnet_ema = get_nnet(**config.nnet)
137 | nnet_ema.eval()
138 | logging.info(f'nnet has {cnt_params(nnet)} parameters')
139 |
140 | optimizer = get_optimizer(params, **config.optimizer)
141 | lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler)
142 |
143 | train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0,
144 | nnet=nnet, nnet_ema=nnet_ema)
145 | train_state.ema_update(0)
146 | train_state.to(device)
147 | return train_state
148 |
149 |
150 | def amortize(n_samples, batch_size):
151 | k = n_samples // batch_size
152 | r = n_samples % batch_size
153 | return k * [batch_size] if r == 0 else k * [batch_size] + [r]
154 |
155 |
156 | def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None):
157 | os.makedirs(path, exist_ok=True)
158 | idx = 0
159 | batch_size = mini_batch_size * accelerator.num_processes
160 |
161 | for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'):
162 | samples = unpreprocess_fn(sample_fn(mini_batch_size))
163 | samples = accelerator.gather(samples.contiguous())[:_batch_size]
164 | if accelerator.is_main_process:
165 | for sample in samples:
166 | save_image(sample, os.path.join(path, f"{idx}.png"))
167 | idx += 1
168 |
169 |
170 | def grad_norm(model):
171 | total_norm = 0.
172 | for p in model.parameters():
173 | param_norm = p.grad.data.norm(2)
174 | total_norm += param_norm.item() ** 2
175 | total_norm = total_norm ** (1. / 2)
176 | return total_norm
177 |
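A quick sketch of `amortize`, the batching helper used by `sample2dir` above and by `grid_sample.py`: it splits a sample budget into full mini-batches plus one remainder batch.

```python
from utils import amortize

print(amortize(10, 4))  # [4, 4, 2]
print(amortize(8, 4))   # [4, 4]
```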
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/custom_writer.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from pathlib import Path
3 | from typing import Sequence, Union, Tuple
4 | import numpy as np
5 | import json
6 | import torch
7 | import os
8 |
9 | class CustomWriter(object):
10 | '''
11 |     Custom writer for training records.
12 |     Parameters:
13 |     -----------
14 |     log_dir : pathlib.Path or str, path to save logs.
15 |     enabled : bool, whether to enable the writer.
16 | '''
17 | def __init__(self, log_dir, enabled=True):
18 | self.writer = None
19 | self.selected_module = ''
20 |
21 | if enabled:
22 | self.log_dir = str(log_dir)
23 | self.stats = {}
24 | if not os.path.exists(self.log_dir):
25 | os.makedirs(self.log_dir, exist_ok=True)
26 |
27 | # Attributes to record
28 | self.epoch = 0
29 | self.mode = None
30 | self.timer = datetime.datetime.now()
31 | self.tb_writer_funcs = {
32 | 'add_scalar', 'add_scalars',
33 | 'add_image', 'add_images',
34 | 'add_figure',
35 | 'add_audio',
36 | 'add_text',
37 | 'add_histogram',
38 | 'add_pr_curve',
39 | #'add_embedding', # TODO: problem with add_embedding
40 | }
41 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} # TODO : Test these two funcs.
42 |
43 | def dump_stats(self):
44 | with open(f"{self.log_dir}/log", "w") as f:
45 | json.dump(self.stats, f,
46 | indent=4,
47 | ensure_ascii=False,
48 | separators=(",", ": "),
49 | )
50 |
51 | def set_epoch(self, epoch, mode):
52 | '''
53 |         Execute this function to update the epoch attribute and compute the elapsed time of one epoch in seconds.
54 |         It is recommended to run this function every epoch.
55 |         This function MUST be executed before other custom writer functions.
56 |         Parameters:
57 |         ------------
58 |         epoch : int, epoch number.
59 | mode : str, 'train' or 'valid'
60 | '''
61 | if epoch == 0:
62 | self.timer = datetime.datetime.now()
63 | elif epoch != self.epoch:
64 | duration = datetime.datetime.now() - self.timer
65 | second_per_epoch = duration.total_seconds() / (epoch - self.epoch)
66 | self.add_scalar(tag='second_per_epoch', data=second_per_epoch)
67 | self.epoch = epoch
68 | self.mode = mode
69 |
70 | def get_epoch(self) -> int:
71 | return self.epoch
72 |
73 | def get_keys(self, epoch: int = None) -> Tuple[str, ...]:
74 | """Returns keys1 e.g. train,eval."""
75 | if epoch is None:
76 | epoch = self.get_epoch()
77 | return tuple(self.stats[epoch])
78 |
79 | def get_keys2(self, key: str, epoch: int = None) -> Tuple[str, ...]:
80 | """Returns keys2 e.g. loss,acc."""
81 | if epoch is None:
82 | epoch = self.get_epoch()
83 | d = self.stats[epoch][key]
84 | keys2 = tuple(k for k in d if k not in ("time", "total_count"))
85 | return keys2
86 |
87 | def plot_stats(self):
88 | self.matplotlib_plot(self.log_dir)
89 |
90 | def matplotlib_plot(self, output_dir: Union[str, Path]):
91 | """Plot stats using Matplotlib and save images."""
92 | keys2 = set.union(*[set(self.get_keys2(k)) for k in self.get_keys()])
93 | for key2 in keys2:
94 | keys = [k for k in self.get_keys() if key2 in self.get_keys2(k)]
95 | plt = self._plot_stats(keys, key2)
96 | p = Path(output_dir) / f"{key2}.png"
97 | p.parent.mkdir(parents=True, exist_ok=True)
98 | plt.savefig(p)
99 |
100 | def _plot_stats(self, keys: Sequence[str], key2: str):
101 | # str is also Sequence[str]
102 | if isinstance(keys, str):
103 | raise TypeError(f"Input as [{keys}]")
104 |
105 | import matplotlib
106 |
107 | matplotlib.use("agg")
108 | import matplotlib.pyplot as plt
109 | import matplotlib.ticker as ticker
110 |
111 | plt.clf()
112 |
113 | epochs = sorted(list(self.stats.keys()))
114 | for key in keys:
115 | y = [
116 | self.stats[e][key][key2]
117 | if e in self.stats
118 | and key in self.stats[e]
119 | and key2 in self.stats[e][key]
120 | else np.nan
121 | for e in epochs
122 | ]
123 | assert len(epochs) == len(y), "Bug?"
124 |
125 | plt.plot(epochs, y, label=key2, marker="x")
126 | plt.legend()
127 | plt.title(f"iteration vs {key2}")
128 | # Force integer tick for x-axis
129 | plt.gca().get_xaxis().set_major_locator(ticker.MaxNLocator(integer=True))
130 | plt.xlabel("iteration")
131 | plt.ylabel(key2)
132 | plt.grid()
133 | return plt
134 |
135 | def to_numpy(self, a):
136 | if isinstance(a, list):
137 | return np.array(a)
138 | for kind in [torch.Tensor, torch.nn.Parameter]:
139 | if isinstance(a, kind):
140 | if hasattr(a, 'detach'):
141 | a = a.detach()
142 | return a.cpu().numpy()
143 | return a
144 |
145 | def add_scalar(self, tag, data):
146 | data = self.to_numpy(data)
147 | data = float(data)
148 | self.stats.setdefault(self.epoch, {}).setdefault(self.mode, {})[tag] = data
149 |
150 |
151 |     def __getattr__(self, name):
152 |         if name in self.tb_writer_funcs:
153 |             func = getattr(self.writer, name, None)  # look up on the wrapped writer; getattr(self, ...) would recurse into __getattr__
154 |             # Return a wrapper for all functions.
155 |             def wrapper(tag, data, *args, **kwargs):
156 |                 if func is not None:
157 |                     if name not in self.tag_mode_exceptions:
158 |                         tag = f"{tag}/{self.mode}"
159 |                     func(tag, data, *args, global_step=self.epoch, **kwargs)
160 |
161 |             return wrapper
162 |         else:
163 |             # Fall back to default attribute lookup for other attributes.
164 |             try:
165 |                 attr = object.__getattribute__(self, name)
166 |             except AttributeError:
167 |                 raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
168 |             return attr
169 |
170 |
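A minimal usage sketch, assuming it runs from the TorchSSL directory with matplotlib installed: record a scalar per epoch, then dump the JSON log and per-metric plots.

```python
from custom_writer import CustomWriter

writer = CustomWriter('./logs/demo')
for epoch in range(3):
    writer.set_epoch(epoch, mode='train')
    writer.add_scalar('loss', 1.0 / (epoch + 1))
writer.dump_stats()   # writes ./logs/demo/log as JSON
writer.plot_stats()   # renders a PNG per metric present in the final epoch
```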
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/models/nets/wrn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | momentum = 0.001
7 |
8 |
9 | def mish(x):
10 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function (https://arxiv.org/abs/1908.08681)"""
11 | return x * torch.tanh(F.softplus(x))
12 |
13 |
14 | class PSBatchNorm2d(nn.BatchNorm2d):
15 | """How Does BN Increase Collapsed Neural Network Filters? (https://arxiv.org/abs/2001.11216)"""
16 |
17 | def __init__(self, num_features, alpha=0.1, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True):
18 | super().__init__(num_features, eps, momentum, affine, track_running_stats)
19 | self.alpha = alpha
20 |
21 | def forward(self, x):
22 | return super().forward(x) + self.alpha
23 |
24 |
25 | class BasicBlock(nn.Module):
26 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0, activate_before_residual=False):
27 | super(BasicBlock, self).__init__()
28 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001, eps=0.001)
29 | self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=False)
30 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
31 | padding=1, bias=True)
32 | self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001, eps=0.001)
33 | self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=False)
34 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
35 | padding=1, bias=True)
36 | self.drop_rate = drop_rate
37 | self.equalInOut = (in_planes == out_planes)
38 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
39 | padding=0, bias=True) or None
40 | self.activate_before_residual = activate_before_residual
41 |
42 | def forward(self, x):
43 |         if not self.equalInOut and self.activate_before_residual:
44 | x = self.relu1(self.bn1(x))
45 | else:
46 | out = self.relu1(self.bn1(x))
47 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
48 | if self.drop_rate > 0:
49 | out = F.dropout(out, p=self.drop_rate, training=self.training)
50 | out = self.conv2(out)
51 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
52 |
53 |
54 | class NetworkBlock(nn.Module):
55 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, drop_rate=0.0, activate_before_residual=False):
56 | super(NetworkBlock, self).__init__()
57 | self.layer = self._make_layer(
58 | block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual)
59 |
60 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual):
61 | layers = []
62 | for i in range(int(nb_layers)):
63 | layers.append(block(i == 0 and in_planes or out_planes, out_planes,
64 | i == 0 and stride or 1, drop_rate, activate_before_residual))
65 | return nn.Sequential(*layers)
66 |
67 | def forward(self, x):
68 | return self.layer(x)
69 |
70 |
71 | class WideResNet(nn.Module):
72 | def __init__(self, first_stride, num_classes, depth=28, widen_factor=2, drop_rate=0.0, is_remix=False):
73 | super(WideResNet, self).__init__()
74 | channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
75 | assert ((depth - 4) % 6 == 0)
76 | n = (depth - 4) / 6
77 | block = BasicBlock
78 | # 1st conv before any network block
79 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1,
80 | padding=1, bias=True)
81 | # 1st block
82 | self.block1 = NetworkBlock(
83 | n, channels[0], channels[1], block, first_stride, drop_rate, activate_before_residual=True)
84 | # 2nd block
85 | self.block2 = NetworkBlock(
86 | n, channels[1], channels[2], block, 2, drop_rate)
87 | # 3rd block
88 | self.block3 = NetworkBlock(
89 | n, channels[2], channels[3], block, 2, drop_rate)
90 | # global average pooling and classifier
91 | self.bn1 = nn.BatchNorm2d(channels[3], momentum=0.001, eps=0.001)
92 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=False)
93 | self.fc = nn.Linear(channels[3], num_classes)
94 | self.channels = channels[3]
95 |
96 | # rot_classifier for Remix Match
97 | self.is_remix = is_remix
98 | if is_remix:
99 | self.rot_classifier = nn.Linear(self.channels, 4)
100 |
101 | for m in self.modules():
102 | if isinstance(m, nn.Conv2d):
103 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
104 | elif isinstance(m, nn.BatchNorm2d):
105 | m.weight.data.fill_(1)
106 | m.bias.data.zero_()
107 | elif isinstance(m, nn.Linear):
108 | nn.init.xavier_normal_(m.weight.data)
109 | m.bias.data.zero_()
110 |
111 | def forward(self, x, ood_test=False):
112 | out = self.conv1(x)
113 | out = self.block1(out)
114 | out = self.block2(out)
115 | out = self.block3(out)
116 | out = self.relu(self.bn1(out))
117 | out = F.adaptive_avg_pool2d(out, 1)
118 | out = out.view(-1, self.channels)
119 | output = self.fc(out)
120 |
121 | if ood_test:
122 | return output, out
123 | else:
124 | if self.is_remix:
125 | rot_output = self.rot_classifier(out)
126 | return output, rot_output
127 | else:
128 | return output
129 |
130 |
131 | class build_WideResNet:
132 | def __init__(self, first_stride=1, depth=28, widen_factor=2, bn_momentum=0.01, leaky_slope=0.0, dropRate=0.0,
133 | use_embed=False, is_remix=False):
134 | self.first_stride = first_stride
135 | self.depth = depth
136 | self.widen_factor = widen_factor
137 | self.bn_momentum = bn_momentum
138 | self.dropRate = dropRate
139 | self.leaky_slope = leaky_slope
140 | self.use_embed = use_embed
141 | self.is_remix = is_remix
142 |
143 | def build(self, num_classes):
144 | return WideResNet(
145 | first_stride=self.first_stride,
146 | depth=self.depth,
147 | num_classes=num_classes,
148 | widen_factor=self.widen_factor,
149 | drop_rate=self.dropRate,
150 | is_remix=self.is_remix,
151 | )
152 |
153 |
154 | if __name__ == '__main__':
155 | wrn_builder = build_WideResNet(1, 10, 2, 0.01, 0.1, 0.5)
156 | wrn = wrn_builder.build(10)
157 | print(wrn)
158 |
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/models/nets/wrn_var.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | momentum = 0.001
7 |
8 |
9 | def mish(x):
10 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function (https://arxiv.org/abs/1908.08681)"""
11 | return x * torch.tanh(F.softplus(x))
12 |
13 |
14 | class PSBatchNorm2d(nn.BatchNorm2d):
15 | """How Does BN Increase Collapsed Neural Network Filters? (https://arxiv.org/abs/2001.11216)"""
16 |
17 | def __init__(self, num_features, alpha=0.1, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True):
18 | super().__init__(num_features, eps, momentum, affine, track_running_stats)
19 | self.alpha = alpha
20 |
21 | def forward(self, x):
22 | return super().forward(x) + self.alpha
23 |
24 |
25 | class BasicBlock(nn.Module):
26 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0, activate_before_residual=False):
27 | super(BasicBlock, self).__init__()
28 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001, eps=0.001)
29 | self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=False)
30 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
31 | padding=1, bias=True)
32 | self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001, eps=0.001)
33 | self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=False)
34 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
35 | padding=1, bias=True)
36 | self.drop_rate = drop_rate
37 | self.equalInOut = (in_planes == out_planes)
38 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
39 | padding=0, bias=True) or None
40 | self.activate_before_residual = activate_before_residual
41 |
42 | def forward(self, x):
43 |         if not self.equalInOut and self.activate_before_residual:
44 | x = self.relu1(self.bn1(x))
45 | else:
46 | out = self.relu1(self.bn1(x))
47 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
48 | if self.drop_rate > 0:
49 | out = F.dropout(out, p=self.drop_rate, training=self.training)
50 | out = self.conv2(out)
51 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
52 |
53 |
54 | class NetworkBlock(nn.Module):
55 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, drop_rate=0.0, activate_before_residual=False):
56 | super(NetworkBlock, self).__init__()
57 | self.layer = self._make_layer(
58 | block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual)
59 |
60 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual):
61 | layers = []
62 | for i in range(int(nb_layers)):
63 | layers.append(block(i == 0 and in_planes or out_planes, out_planes,
64 | i == 0 and stride or 1, drop_rate, activate_before_residual))
65 | return nn.Sequential(*layers)
66 |
67 | def forward(self, x):
68 | return self.layer(x)
69 |
70 |
71 | class WideResNetVar(nn.Module):
72 | def __init__(self, first_stride, num_classes, depth=28, widen_factor=2, drop_rate=0.0, is_remix=False):
73 | super(WideResNetVar, self).__init__()
74 | channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor, 128 * widen_factor]
75 | assert ((depth - 4) % 6 == 0)
76 | n = (depth - 4) / 6
77 | block = BasicBlock
78 | # 1st conv before any network block
79 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1,
80 | padding=1, bias=True)
81 | # 1st block
82 | self.block1 = NetworkBlock(
83 | n, channels[0], channels[1], block, first_stride, drop_rate, activate_before_residual=True)
84 | # 2nd block
85 | self.block2 = NetworkBlock(
86 | n, channels[1], channels[2], block, 2, drop_rate)
87 | # 3rd block
88 | self.block3 = NetworkBlock(
89 | n, channels[2], channels[3], block, 2, drop_rate)
90 | # 4th block
91 | self.block4 = NetworkBlock(
92 | n, channels[3], channels[4], block, 2, drop_rate)
93 | # global average pooling and classifier
94 | self.bn1 = nn.BatchNorm2d(channels[4], momentum=0.001, eps=0.001)
95 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=False)
96 | self.fc = nn.Linear(channels[4], num_classes)
97 | self.channels = channels[4]
98 |
99 | # rot_classifier for Remix Match
100 | self.is_remix = is_remix
101 | if is_remix:
102 | self.rot_classifier = nn.Linear(self.channels, 4)
103 |
104 | for m in self.modules():
105 | if isinstance(m, nn.Conv2d):
106 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
107 | elif isinstance(m, nn.BatchNorm2d):
108 | m.weight.data.fill_(1)
109 | m.bias.data.zero_()
110 | elif isinstance(m, nn.Linear):
111 | nn.init.xavier_normal_(m.weight.data)
112 | m.bias.data.zero_()
113 |
114 | def forward(self, x, ood_test=False):
115 | out = self.conv1(x)
116 | out = self.block1(out)
117 | out = self.block2(out)
118 | out = self.block3(out)
119 | out = self.block4(out)
120 | out = self.relu(self.bn1(out))
121 | out = F.adaptive_avg_pool2d(out, 1)
122 | out = out.view(-1, self.channels)
123 | output = self.fc(out)
124 |
125 | if ood_test:
126 | return output, out
127 | else:
128 | if self.is_remix:
129 | rot_output = self.rot_classifier(out)
130 | return output, rot_output
131 | else:
132 | return output
133 |
134 |
135 | class build_WideResNetVar:
136 | def __init__(self, first_stride=1, depth=28, widen_factor=2, bn_momentum=0.01, leaky_slope=0.0, dropRate=0.0,
137 | use_embed=False, is_remix=False):
138 | self.first_stride = first_stride
139 | self.depth = depth
140 | self.widen_factor = widen_factor
141 | self.bn_momentum = bn_momentum
142 | self.dropRate = dropRate
143 | self.leaky_slope = leaky_slope
144 | self.use_embed = use_embed
145 | self.is_remix = is_remix
146 |
147 | def build(self, num_classes):
148 | return WideResNetVar(
149 | first_stride=self.first_stride,
150 | depth=self.depth,
151 | num_classes=num_classes,
152 | widen_factor=self.widen_factor,
153 | drop_rate=self.dropRate,
154 | is_remix=self.is_remix,
155 | )
156 |
157 |
158 | if __name__ == '__main__':
159 |     wrn_builder = build_WideResNetVar(1, 10, 2, 0.01, 0.1, 0.5)
160 | wrn = wrn_builder.build(10)
161 | print(wrn)
162 |
--------------------------------------------------------------------------------
/cifar10_experiment/edm/fid.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Script for calculating Frechet Inception Distance (FID)."""
9 |
10 | import os
11 | import click
12 | import tqdm
13 | import pickle
14 | import numpy as np
15 | import scipy.linalg
16 | import torch
17 | import dnnlib
18 | from torch_utils import distributed as dist
19 | from training import dataset
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | def calculate_inception_stats(
24 | image_path, num_expected=None, seed=0, max_batch_size=64,
25 | num_workers=3, prefetch_factor=2, device=torch.device('cuda'),
26 | ):
27 | # Rank 0 goes first.
28 | if dist.get_rank() != 0:
29 | torch.distributed.barrier()
30 |
31 | # Load Inception-v3 model.
32 | # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
33 | dist.print0('Loading Inception-v3 model...')
34 | #detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
35 | detector_url = 'assets/inception-2015-12-05.pkl'
36 | detector_kwargs = dict(return_features=True)
37 | feature_dim = 2048
38 | with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f:
39 | detector_net = pickle.load(f).to(device)
40 |
41 | # List images.
42 | dist.print0(f'Loading images from "{image_path}"...')
43 | dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)
44 | if num_expected is not None and len(dataset_obj) < num_expected:
45 | raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}')
46 | if len(dataset_obj) < 2:
47 | raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics')
48 |
49 | # Other ranks follow.
50 | if dist.get_rank() == 0:
51 | torch.distributed.barrier()
52 |
53 | # Divide images into batches.
54 | num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
55 | all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches)
56 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
57 | data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor)
58 |
59 | # Accumulate statistics.
60 | dist.print0(f'Calculating statistics for {len(dataset_obj)} images...')
61 | mu = torch.zeros([feature_dim], dtype=torch.float64, device=device)
62 | sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
63 | for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)):
64 | torch.distributed.barrier()
65 | if images.shape[0] == 0:
66 | continue
67 | if images.shape[1] == 1:
68 | images = images.repeat([1, 3, 1, 1])
69 | features = detector_net(images.to(device), **detector_kwargs).to(torch.float64)
70 | mu += features.sum(0)
71 | sigma += features.T @ features
72 |
73 | # Calculate grand totals.
74 | torch.distributed.all_reduce(mu)
75 | torch.distributed.all_reduce(sigma)
76 | mu /= len(dataset_obj)
77 | sigma -= mu.ger(mu) * len(dataset_obj)
78 | sigma /= len(dataset_obj) - 1
79 | return mu.cpu().numpy(), sigma.cpu().numpy()
80 |
81 | #----------------------------------------------------------------------------
82 |
83 | def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
84 | m = np.square(mu - mu_ref).sum()
85 | s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
86 | fid = m + np.trace(sigma + sigma_ref - s * 2)
87 | return float(np.real(fid))
88 |
89 | #----------------------------------------------------------------------------
90 |
91 | @click.group()
92 | def main():
93 | """Calculate Frechet Inception Distance (FID).
94 |
95 | Examples:
96 |
97 | \b
98 | # Generate 50000 images and save them as fid-tmp/*/*.png
99 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\
100 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
101 |
102 | \b
103 | # Calculate FID
104 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \\
105 | --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
106 |
107 | \b
108 | # Compute dataset reference statistics
109 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz
110 | """
111 |
112 | #----------------------------------------------------------------------------
113 |
114 | @main.command()
115 | @click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True)
116 | @click.option('--ref', 'ref_path', help='Dataset reference statistics ', metavar='NPZ|URL', type=str, required=True)
117 | @click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True)
118 | @click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True)
119 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
120 |
121 | def calc(image_path, ref_path, num_expected, seed, batch):
122 | """Calculate FID for a given set of images."""
123 | torch.multiprocessing.set_start_method('spawn')
124 | dist.init()
125 |
126 | dist.print0(f'Loading dataset reference statistics from "{ref_path}"...')
127 | ref = None
128 | if dist.get_rank() == 0:
129 | with dnnlib.util.open_url(ref_path) as f:
130 | ref = dict(np.load(f))
131 |
132 | mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch)
133 | dist.print0('Calculating FID...')
134 | if dist.get_rank() == 0:
135 | fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma'])
136 | print(f'{fid:g}')
137 | torch.distributed.barrier()
138 |
139 | #----------------------------------------------------------------------------
140 |
141 | @main.command()
142 | @click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True)
143 | @click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True)
144 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
145 |
146 | def ref(dataset_path, dest_path, batch):
147 | """Calculate dataset reference statistics needed by 'calc'."""
148 | torch.multiprocessing.set_start_method('spawn')
149 | dist.init()
150 |
151 | mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch)
152 | dist.print0(f'Saving dataset reference statistics to "{dest_path}"...')
153 | if dist.get_rank() == 0:
154 | if os.path.dirname(dest_path):
155 | os.makedirs(os.path.dirname(dest_path), exist_ok=True)
156 | np.savez(dest_path, mu=mu, sigma=sigma)
157 |
158 | torch.distributed.barrier()
159 | dist.print0('Done.')
160 |
161 | #----------------------------------------------------------------------------
162 |
163 | if __name__ == "__main__":
164 | main()
165 |
166 | #----------------------------------------------------------------------------
167 |
--------------------------------------------------------------------------------
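A minimal NumPy sketch of the FID formula implemented above, FID = ||mu - mu_ref||^2 + Tr(Sigma + Sigma_ref - 2 (Sigma Sigma_ref)^{1/2}); toy Gaussian features stand in for real Inception activations, and the streaming loop in calculate_inception_stats produces exactly these mean/covariance statistics:

import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
feats = rng.standard_normal((1000, 8))            # toy "generated" features
feats_ref = rng.standard_normal((1000, 8)) + 0.1  # toy "reference" features

def stats(f):
    # sample mean and unbiased covariance, matching mu/sigma above
    return f.mean(axis=0), np.cov(f, rowvar=False)

mu, sigma = stats(feats)
mu_ref, sigma_ref = stats(feats_ref)

m = np.square(mu - mu_ref).sum()
s, _ = scipy.linalg.sqrtm(sigma @ sigma_ref, disp=False)
print(float(np.real(m + np.trace(sigma + sigma_ref - 2 * s))))

--------------------------------------------------------------------------------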
/sample_ldm_all.py:
--------------------------------------------------------------------------------
1 | from tools.fid_score import calculate_fid_given_paths
2 | import ml_collections
3 | import torch
4 | from torch import multiprocessing as mp
5 | import accelerate
6 | import utils
7 | import sde
8 | import einops
9 | from datasets import get_dataset
10 | import tempfile
11 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
12 | from absl import logging
13 | import builtins
14 | import libs.autoencoder
15 | import torch.nn as nn
16 | import numpy as np
17 | import os
18 | from tqdm import tqdm
19 | from torchvision.utils import make_grid, save_image
20 |
21 | import pickle
22 |
23 | def evaluate(config):
24 | if config.get('benchmark', False):
25 | torch.backends.cudnn.benchmark = True
26 | torch.backends.cudnn.deterministic = False
27 |
28 | mp.set_start_method('spawn')
29 | accelerator = accelerate.Accelerator()
30 | device = accelerator.device
31 | accelerate.utils.set_seed(config.seed, device_specific=True)
32 | logging.info(f'Process {accelerator.process_index} using device: {device}')
33 |
34 | config.mixed_precision = accelerator.mixed_precision
35 | config = ml_collections.FrozenConfigDict(config)
36 | if accelerator.is_main_process:
37 | utils.set_logger(log_level='info', fname=config.output_path)
38 | else:
39 | utils.set_logger(log_level='error')
40 |         builtins.print = lambda *args, **kwargs: None
41 |
42 | dataset = get_dataset(**config.dataset)
43 |
44 | nnet = utils.get_nnet(**config.nnet)
45 | nnet = accelerator.prepare(nnet)
46 | cluster_name = config.model_name + '-' + '-'.join(config.subset_path.split('/')).split('.txt')[0]
47 |
48 | nnet_path = f'{config.dpm_path}/{cluster_name}/{config.resolution}/ckpts/{config.train.n_steps}.ckpt/nnet_ema.pth'
49 | logging.info(f'load nnet from {nnet_path}')
50 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(nnet_path, map_location='cpu'))
51 | nnet.eval()
52 |
53 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
54 | autoencoder.to(device)
55 |
56 | @torch.cuda.amp.autocast()
57 | def encode(_batch):
58 | return autoencoder.encode(_batch)
59 |
60 | @torch.cuda.amp.autocast()
61 | def decode(_batch):
62 | return autoencoder.decode(_batch)
63 |
64 | def decode_large_batch(_batch):
65 | decode_mini_batch_size = 20 # use a small batch size since the decoder is large
66 | xs = []
67 | pt = 0
68 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size):
69 | x = decode(_batch[pt: pt + _decode_mini_batch_size])
70 | pt += _decode_mini_batch_size
71 | xs.append(x)
72 | xs = torch.concat(xs, dim=0)
73 | assert xs.size(0) == _batch.size(0)
74 | return xs
75 |
76 |     if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier-free guidance
77 |         logging.info(f'Using classifier-free guidance with scale={config.sample.scale}')
78 | def cfg_nnet(x, timesteps, y):
79 | _cond = nnet(x, timesteps, y=y)
80 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device))
81 | return _cond + config.sample.scale * (_cond - _uncond)
82 | score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE())
83 | else:
84 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE())
85 |
86 | logging.info(config.sample)
87 | assert os.path.exists(dataset.fid_stat)
88 |     logging.info(f'sample: n_samples={config.sample.n_samples} per class, mode={config.train.mode}, mixed_precision={config.mixed_precision}')
89 |
90 | def amortize(n_samples, batch_size):
91 | k = n_samples // batch_size
92 | r = n_samples % batch_size
93 | return k * [batch_size] if r == 0 else k * [batch_size] + [r]
94 |
95 | def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, class_num=0):
96 | os.makedirs(path, exist_ok=True)
97 | idx = 0
98 | batch_size = mini_batch_size * accelerator.num_processes
99 |
100 | for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'):
101 | samples = unpreprocess_fn(sample_fn(mini_batch_size, class_num))
102 | samples = accelerator.gather(samples.contiguous())[:_batch_size]
103 | if accelerator.local_process_index == 0:
104 | for sample in samples:
105 | save_image(sample, os.path.join(path, f"{idx}.png"))
106 | idx += 1
107 |
108 | def sample_fn(_n_samples, class_num):
109 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
110 | if config.train.mode == 'uncond':
111 | kwargs = dict()
112 | elif config.train.mode == 'cond':
113 | torch_arr = torch.ones(_n_samples // 10, device=device, dtype=int) * torch.tensor(int(class_num))
114 | kwargs = dict(y=einops.repeat(torch_arr % dataset.K, 'nrow -> (nrow ncol)', ncol=10))
115 | else:
116 | raise NotImplementedError
117 |
118 | if config.sample.algorithm == 'euler_maruyama_sde':
119 | _z = sde.euler_maruyama(sde.ReverseSDE(score_model), _z_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
120 | elif config.sample.algorithm == 'euler_maruyama_ode':
121 | _z = sde.euler_maruyama(sde.ODE(score_model), _z_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
122 | elif config.sample.algorithm == 'dpm_solver':
123 | noise_schedule = NoiseScheduleVP(schedule='linear')
124 | model_fn = model_wrapper(
125 | score_model.noise_pred,
126 | noise_schedule,
127 | time_input_type='0',
128 | model_kwargs=kwargs
129 | )
130 | dpm_solver = DPM_Solver(model_fn, noise_schedule)
131 | _z = dpm_solver.sample(
132 | _z_init,
133 | steps=config.sample.sample_steps,
134 | eps=1e-4,
135 | adaptive_step_size=False,
136 | fast_version=True,
137 | )
138 | else:
139 | raise NotImplementedError
140 | return decode_large_batch(_z)
141 |     with open('idx_to_class.pkl', 'rb') as f_read:
142 |         dict2 = pickle.load(f_read) # class index -> ImageNet class name
143 |     print(dict2)
144 |
145 | for i in range(1000):
146 | aug_samples_path = f'{config.dpm_path}/{cluster_name}/{config.resolution}/samples_for_classifier/aug_{config.augmentation_K}_samples'
147 |
148 | path = os.path.join(aug_samples_path, f'train/{dict2[i]}')
149 | if accelerator.is_main_process:
150 | os.makedirs(path, exist_ok=True)
151 |         sample2dir(accelerator, path, config.augmentation_K, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess, class_num=i)
152 |
153 |
154 | from absl import flags
155 | from absl import app
156 | from ml_collections import config_flags
157 | import os
158 | import sys
159 | from pathlib import Path
160 |
161 |
162 | FLAGS = flags.FLAGS
163 | config_flags.DEFINE_config_file(
164 | "config", None, "Training configuration.", lock_config=False)
165 | flags.mark_flags_as_required(["config"])
166 | flags.DEFINE_string("output_path", None, "The path to output log.")
167 |
168 |
169 | def get_config_name():
170 | argv = sys.argv
171 | for i in range(1, len(argv)):
172 | if argv[i].startswith('--config='):
173 | return Path(argv[i].split('=')[-1]).stem
174 |
175 |
176 | def get_hparams():
177 | argv = sys.argv
178 | lst = []
179 | for i in range(1, len(argv)):
180 | assert '=' in argv[i]
181 | if argv[i].startswith('--config.'):
182 | hparam, val = argv[i].split('=')
183 | hparam = hparam.split('.')[-1]
184 | if hparam.endswith('path'):
185 | val = Path(val).stem
186 | lst.append(f'{hparam}={val}')
187 | hparams = '-'.join(lst)
188 | if hparams == '':
189 | hparams = 'default'
190 | return hparams
191 |
192 |
193 | def main(argv):
194 | config = FLAGS.config
195 | config_name = get_config_name()
196 | hparams = get_hparams()
197 | config.project = config_name
198 | config.notes = hparams
199 | config.output_path = FLAGS.output_path
200 |
201 | evaluate(config)
202 |
203 |
204 | if __name__ == "__main__":
205 | app.run(main)
206 |
--------------------------------------------------------------------------------
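For reference, cfg_nnet above is the standard classifier-free guidance combination, eps = eps_cond + scale * (eps_cond - eps_uncond), with class index dataset.K serving as the learned "unconditional" token. A hedged, self-contained sketch with a toy predictor standing in for the trained network:

import torch

def toy_nnet(x, t, y): # stand-in for the real conditional network
    return 0.1 * x + y.float().view(-1, 1, 1, 1)

def cfg(nnet, x, t, y, y_null, scale):
    cond = nnet(x, t, y)        # conditional prediction
    uncond = nnet(x, t, y_null) # "null class" prediction
    return cond + scale * (cond - uncond)

K = 10 # assumed number of classes
x, t = torch.randn(4, 3, 8, 8), torch.zeros(4)
y, y_null = torch.arange(4), torch.full((4,), K)
print(cfg(toy_nnet, x, t, y, y_null, scale=0.4).shape) # torch.Size([4, 3, 8, 8])

--------------------------------------------------------------------------------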
/sample_ldm_discrete_all.py:
--------------------------------------------------------------------------------
1 | from tools.fid_score import calculate_fid_given_paths
2 | import ml_collections
3 | import torch
4 | from torch import multiprocessing as mp
5 | import accelerate
6 | import utils
7 | import sde
8 | import einops
9 | from datasets import get_dataset
10 | import tempfile
11 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
12 | from absl import logging
13 | import builtins
14 | import libs.autoencoder
15 | import torch.nn as nn
16 | import numpy as np
17 | import os
18 | from tqdm import tqdm
19 | from torchvision.utils import make_grid, save_image
20 |
21 | import pickle
22 |
23 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
24 | _betas = (
25 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
26 | )
27 | return _betas.numpy()
28 |
29 | def evaluate(config):
30 | if config.get('benchmark', False):
31 | torch.backends.cudnn.benchmark = True
32 | torch.backends.cudnn.deterministic = False
33 |
34 | mp.set_start_method('spawn')
35 | accelerator = accelerate.Accelerator()
36 | device = accelerator.device
37 | accelerate.utils.set_seed(config.seed, device_specific=True)
38 | logging.info(f'Process {accelerator.process_index} using device: {device}')
39 |
40 | config.mixed_precision = accelerator.mixed_precision
41 | config = ml_collections.FrozenConfigDict(config)
42 | if accelerator.is_main_process:
43 | utils.set_logger(log_level='info', fname=config.output_path)
44 | else:
45 | utils.set_logger(log_level='error')
46 |         builtins.print = lambda *args, **kwargs: None
47 |
48 | dataset = get_dataset(**config.dataset)
49 |
50 | nnet = utils.get_nnet(**config.nnet)
51 | nnet = accelerator.prepare(nnet)
52 | if config.nnet_path == '':
53 | cluster_name = config.model_name + '-' + '-'.join(config.subset_path.split('/')).split('.txt')[0]
54 |
55 | nnet_path = f'{config.dpm_path}/{cluster_name}/{config.resolution}/ckpts/{config.train.n_steps}.ckpt/nnet_ema.pth'
56 |
57 | else:
58 | nnet_path = config.nnet_path
59 |
60 | logging.info(f'load nnet from {nnet_path}')
61 |
62 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(nnet_path, map_location='cpu'))
63 | nnet.eval()
64 |
65 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
66 | autoencoder.to(device)
67 |
68 | @torch.cuda.amp.autocast()
69 | def encode(_batch):
70 | return autoencoder.encode(_batch)
71 |
72 | @torch.cuda.amp.autocast()
73 | def decode(_batch):
74 | return autoencoder.decode(_batch)
75 |
76 | def decode_large_batch(_batch):
77 | decode_mini_batch_size = 20 # use a small batch size since the decoder is large
78 | xs = []
79 | pt = 0
80 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size):
81 | x = decode(_batch[pt: pt + _decode_mini_batch_size])
82 | pt += _decode_mini_batch_size
83 | xs.append(x)
84 | xs = torch.concat(xs, dim=0)
85 | assert xs.size(0) == _batch.size(0)
86 | return xs
87 |
88 |     if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier-free guidance
89 |         logging.info(f'Using classifier-free guidance with scale={config.sample.scale}')
90 | def cfg_nnet(x, timesteps, y):
91 | _cond = nnet(x, timesteps, y=y)
92 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device))
93 | return _cond + config.sample.scale * (_cond - _uncond)
94 | else:
95 | def cfg_nnet(x, timesteps, y):
96 | _cond = nnet(x, timesteps, y=y)
97 | return _cond
98 |
99 | logging.info(config.sample)
100 | assert os.path.exists(dataset.fid_stat)
101 |     logging.info(f'sample: n_samples={config.augmentation_K} per class, mode={config.train.mode}, mixed_precision={config.mixed_precision}')
102 |
103 | _betas = stable_diffusion_beta_schedule()
104 | N = len(_betas)
105 |
106 | def amortize(n_samples, batch_size):
107 | k = n_samples // batch_size
108 | r = n_samples % batch_size
109 | return k * [batch_size] if r == 0 else k * [batch_size] + [r]
110 |
111 | def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, class_num=0):
112 | os.makedirs(path, exist_ok=True)
113 | idx = 0
114 | batch_size = mini_batch_size * accelerator.num_processes
115 |
116 | for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'):
117 | samples = unpreprocess_fn(sample_fn(mini_batch_size, class_num))
118 | samples = accelerator.gather(samples.contiguous())[:_batch_size]
119 | if accelerator.local_process_index == 0:
120 | for sample in samples:
121 | save_image(sample, os.path.join(path, f"{idx}.png"))
122 | idx += 1
123 |
124 | def sample_z(_n_samples, _sample_steps, **kwargs):
125 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
126 |
127 | if config.sample.algorithm == 'dpm_solver':
128 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
129 |
130 | def model_fn(x, t_continuous):
131 | t = t_continuous * N
132 | eps_pre = cfg_nnet(x, t, **kwargs)
133 | return eps_pre
134 |
135 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
136 | _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / N, T=1.)
137 |
138 | else:
139 | raise NotImplementedError
140 |
141 | return _z
142 |
143 | def sample_fn(_n_samples, class_num):
144 | if config.train.mode == 'uncond':
145 | kwargs = dict()
146 | elif config.train.mode == 'cond':
147 | torch_arr = torch.ones(_n_samples // 10, device=device, dtype=int) * torch.tensor(int(class_num))
148 | kwargs = dict(y=einops.repeat(torch_arr % dataset.K, 'nrow -> (nrow ncol)', ncol=10))
149 | else:
150 | raise NotImplementedError
151 | _z = sample_z(_n_samples, _sample_steps=config.sample.sample_steps, **kwargs)
152 | return decode_large_batch(_z)
153 |
154 |     with open('idx_to_class.pkl', 'rb') as f_read:
155 |         dict2 = pickle.load(f_read) # class index -> ImageNet class name
156 |     print(dict2)
157 |
158 | for i in range(1000):
159 | if config.sample.path == '':
160 | aug_samples_path = f'{config.dpm_path}/{cluster_name}/{config.resolution}/samples_for_classifier/aug_{config.augmentation_K}_samples'
161 |
162 | path = os.path.join(aug_samples_path, f'train/{dict2[i]}')
163 | else:
164 | path = os.path.join(config.sample.path, f'{dict2[i]}')
165 |
166 | if accelerator.is_main_process:
167 | os.makedirs(path, exist_ok=True)
168 |         sample2dir(accelerator, path, config.augmentation_K, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess, class_num=i)
169 |
170 | if config.sample.path != '' and accelerator.is_main_process:
171 | os.system(f"cd {config.sample.path} && cd .. && tar -zcf samples.tar.gz samples")
172 |
173 | from absl import flags
174 | from absl import app
175 | from ml_collections import config_flags
176 | import os
177 | import sys
178 | from pathlib import Path
179 |
180 |
181 | FLAGS = flags.FLAGS
182 | config_flags.DEFINE_config_file(
183 | "config", None, "Training configuration.", lock_config=False)
184 | flags.mark_flags_as_required(["config"])
185 | flags.DEFINE_string("nnet_path", '', "The nnet to evaluate.")
186 | flags.DEFINE_string("output_path", None, "The path to output log.")
187 |
188 |
189 | def get_config_name():
190 | argv = sys.argv
191 | for i in range(1, len(argv)):
192 | if argv[i].startswith('--config='):
193 | return Path(argv[i].split('=')[-1]).stem
194 |
195 |
196 | def get_hparams():
197 | argv = sys.argv
198 | lst = []
199 | for i in range(1, len(argv)):
200 | assert '=' in argv[i]
201 | if argv[i].startswith('--config.'):
202 | hparam, val = argv[i].split('=')
203 | hparam = hparam.split('.')[-1]
204 | if hparam.endswith('path'):
205 | val = Path(val).stem
206 | lst.append(f'{hparam}={val}')
207 | hparams = '-'.join(lst)
208 | if hparams == '':
209 | hparams = 'default'
210 | return hparams
211 |
212 |
213 | def main(argv):
214 | config = FLAGS.config
215 | config_name = get_config_name()
216 | hparams = get_hparams()
217 | config.project = config_name
218 | config.notes = hparams
219 | config.nnet_path = FLAGS.nnet_path
220 | config.output_path = FLAGS.output_path
221 |
222 | evaluate(config)
223 |
224 |
225 | if __name__ == "__main__":
226 | app.run(main)
227 |
--------------------------------------------------------------------------------
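stable_diffusion_beta_schedule above interpolates linearly in sqrt-space, so the betas themselves run quadratically from linear_start to linear_end; DPM-Solver then works in continuous time t in (0, 1], which model_fn rescales by N back to the discrete step index the network was trained on. A small numeric check under those definitions:

import torch

betas = torch.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, 1000, dtype=torch.float64) ** 2
print(betas[0].item(), betas[-1].item()) # 0.00085 and 0.0120 at the endpoints

N = len(betas) # 1000 discrete training steps
t_continuous = torch.tensor([1.0 / N, 0.5, 1.0])
print(t_continuous * N) # tensor([1., 500., 1000.]) -> timesteps fed to the network

--------------------------------------------------------------------------------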
/libs/uvit_t2i.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from .timm import trunc_normal_, Mlp
5 | import einops
6 | import torch.utils.checkpoint
7 |
8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
9 | ATTENTION_MODE = 'flash'
10 | else:
11 | try:
12 | import xformers
13 | import xformers.ops
14 | ATTENTION_MODE = 'xformers'
15 |     except ImportError:
16 | ATTENTION_MODE = 'math'
17 | print(f'attention mode is {ATTENTION_MODE}')
18 |
19 |
20 | def timestep_embedding(timesteps, dim, max_period=10000):
21 | """
22 | Create sinusoidal timestep embeddings.
23 |
24 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
25 | These may be fractional.
26 | :param dim: the dimension of the output.
27 | :param max_period: controls the minimum frequency of the embeddings.
28 | :return: an [N x dim] Tensor of positional embeddings.
29 | """
30 | half = dim // 2
31 | freqs = torch.exp(
32 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
33 | ).to(device=timesteps.device)
34 | args = timesteps[:, None].float() * freqs[None]
35 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
36 | if dim % 2:
37 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
38 | return embedding
39 |
40 |
41 | def patchify(imgs, patch_size):
42 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
43 | return x
44 |
45 |
46 | def unpatchify(x, channels=3):
47 | patch_size = int((x.shape[2] // channels) ** 0.5)
48 | h = w = int(x.shape[1] ** .5)
49 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
50 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
51 | return x
52 |
53 |
54 | class Attention(nn.Module):
55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
56 | super().__init__()
57 | self.num_heads = num_heads
58 | head_dim = dim // num_heads
59 | self.scale = qk_scale or head_dim ** -0.5
60 |
61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
62 | self.attn_drop = nn.Dropout(attn_drop)
63 | self.proj = nn.Linear(dim, dim)
64 | self.proj_drop = nn.Dropout(proj_drop)
65 |
66 | def forward(self, x):
67 | B, L, C = x.shape
68 |
69 | qkv = self.qkv(x)
70 | if ATTENTION_MODE == 'flash':
71 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
72 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
73 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
74 | x = einops.rearrange(x, 'B H L D -> B L (H D)')
75 | elif ATTENTION_MODE == 'xformers':
76 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
77 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
78 | x = xformers.ops.memory_efficient_attention(q, k, v)
79 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
80 | elif ATTENTION_MODE == 'math':
81 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
82 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
83 | attn = (q @ k.transpose(-2, -1)) * self.scale
84 | attn = attn.softmax(dim=-1)
85 | attn = self.attn_drop(attn)
86 | x = (attn @ v).transpose(1, 2).reshape(B, L, C)
87 | else:
88 |             raise NotImplementedError
89 |
90 | x = self.proj(x)
91 | x = self.proj_drop(x)
92 | return x
93 |
94 |
95 | class Block(nn.Module):
96 |
97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
98 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
99 | super().__init__()
100 | self.norm1 = norm_layer(dim)
101 | self.attn = Attention(
102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
103 | self.norm2 = norm_layer(dim)
104 | mlp_hidden_dim = int(dim * mlp_ratio)
105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
106 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
107 | self.use_checkpoint = use_checkpoint
108 |
109 | def forward(self, x, skip=None):
110 | if self.use_checkpoint:
111 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
112 | else:
113 | return self._forward(x, skip)
114 |
115 | def _forward(self, x, skip=None):
116 | if self.skip_linear is not None:
117 | x = self.skip_linear(torch.cat([x, skip], dim=-1))
118 | x = x + self.attn(self.norm1(x))
119 | x = x + self.mlp(self.norm2(x))
120 | return x
121 |
122 |
123 | class PatchEmbed(nn.Module):
124 | """ Image to Patch Embedding
125 | """
126 | def __init__(self, patch_size, in_chans=3, embed_dim=768):
127 | super().__init__()
128 | self.patch_size = patch_size
129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
130 |
131 | def forward(self, x):
132 | B, C, H, W = x.shape
133 | assert H % self.patch_size == 0 and W % self.patch_size == 0
134 | x = self.proj(x).flatten(2).transpose(1, 2)
135 | return x
136 |
137 |
138 | class UViT(nn.Module):
139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False,
141 | clip_dim=768, num_clip_token=77, conv=True, skip=True):
142 | super().__init__()
143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
144 | self.in_chans = in_chans
145 |
146 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
147 | num_patches = (img_size // patch_size) ** 2
148 |
149 | self.time_embed = nn.Sequential(
150 | nn.Linear(embed_dim, 4 * embed_dim),
151 | nn.SiLU(),
152 | nn.Linear(4 * embed_dim, embed_dim),
153 | ) if mlp_time_embed else nn.Identity()
154 |
155 | self.context_embed = nn.Linear(clip_dim, embed_dim)
156 |
157 | self.extras = 1 + num_clip_token
158 |
159 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim))
160 |
161 | self.in_blocks = nn.ModuleList([
162 | Block(
163 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
164 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
165 | for _ in range(depth // 2)])
166 |
167 | self.mid_block = Block(
168 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
169 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
170 |
171 | self.out_blocks = nn.ModuleList([
172 | Block(
173 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
174 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint)
175 | for _ in range(depth // 2)])
176 |
177 | self.norm = norm_layer(embed_dim)
178 | self.patch_dim = patch_size ** 2 * in_chans
179 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)
180 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity()
181 |
182 | trunc_normal_(self.pos_embed, std=.02)
183 | self.apply(self._init_weights)
184 |
185 | def _init_weights(self, m):
186 | if isinstance(m, nn.Linear):
187 | trunc_normal_(m.weight, std=.02)
188 | if isinstance(m, nn.Linear) and m.bias is not None:
189 | nn.init.constant_(m.bias, 0)
190 | elif isinstance(m, nn.LayerNorm):
191 | nn.init.constant_(m.bias, 0)
192 | nn.init.constant_(m.weight, 1.0)
193 |
194 | @torch.jit.ignore
195 | def no_weight_decay(self):
196 | return {'pos_embed'}
197 |
198 | def forward(self, x, timesteps, context):
199 | x = self.patch_embed(x)
200 | B, L, D = x.shape
201 |
202 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim))
203 | time_token = time_token.unsqueeze(dim=1)
204 | context_token = self.context_embed(context)
205 | x = torch.cat((time_token, context_token, x), dim=1)
206 | x = x + self.pos_embed
207 |
208 | skips = []
209 | for blk in self.in_blocks:
210 | x = blk(x)
211 | skips.append(x)
212 |
213 | x = self.mid_block(x)
214 |
215 | for blk in self.out_blocks:
216 | x = blk(x, skips.pop())
217 |
218 | x = self.norm(x)
219 | x = self.decoder_pred(x)
220 | assert x.size(1) == self.extras + L
221 | x = x[:, self.extras:, :]
222 | x = unpatchify(x, self.in_chans)
223 | x = self.final_layer(x)
224 | return x
225 |
--------------------------------------------------------------------------------
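A quick worked check of timestep_embedding above: the frequencies are geometrically spaced as max_period^(-i/half), so for dim=4 they are 1 and 10000^(-1/2) = 0.01, and the embedding of t concatenates the cosines and sines of t times each frequency:

import math
import torch

t = torch.tensor([0.0, 1.0, 999.0])
half = 2 # dim // 2 for dim=4
freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=torch.float32) / half)
print(freqs) # tensor([1.0000, 0.0100])
args = t[:, None] * freqs[None]
emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
print(emb[0]) # t=0 -> tensor([1., 1., 0., 0.])

--------------------------------------------------------------------------------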
/libs/uvit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from .timm import trunc_normal_, Mlp
5 | import einops
6 | import torch.utils.checkpoint
7 |
8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
9 | ATTENTION_MODE = 'flash'
10 | else:
11 | try:
12 | import xformers
13 | import xformers.ops
14 | ATTENTION_MODE = 'xformers'
15 |     except ImportError:
16 | ATTENTION_MODE = 'math'
17 | print(f'attention mode is {ATTENTION_MODE}')
18 |
19 |
20 | def timestep_embedding(timesteps, dim, max_period=10000):
21 | """
22 | Create sinusoidal timestep embeddings.
23 |
24 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
25 | These may be fractional.
26 | :param dim: the dimension of the output.
27 | :param max_period: controls the minimum frequency of the embeddings.
28 | :return: an [N x dim] Tensor of positional embeddings.
29 | """
30 | half = dim // 2
31 | freqs = torch.exp(
32 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
33 | ).to(device=timesteps.device)
34 | args = timesteps[:, None].float() * freqs[None]
35 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
36 | if dim % 2:
37 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
38 | return embedding
39 |
40 |
41 | def patchify(imgs, patch_size):
42 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
43 | return x
44 |
45 |
46 | def unpatchify(x, channels=3):
47 | patch_size = int((x.shape[2] // channels) ** 0.5)
48 | h = w = int(x.shape[1] ** .5)
49 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
50 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
51 | return x
52 |
53 |
54 | class Attention(nn.Module):
55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
56 | super().__init__()
57 | self.num_heads = num_heads
58 | head_dim = dim // num_heads
59 | self.scale = qk_scale or head_dim ** -0.5
60 |
61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
62 | self.attn_drop = nn.Dropout(attn_drop)
63 | self.proj = nn.Linear(dim, dim)
64 | self.proj_drop = nn.Dropout(proj_drop)
65 |
66 | def forward(self, x):
67 | B, L, C = x.shape
68 |
69 | qkv = self.qkv(x)
70 | if ATTENTION_MODE == 'flash':
71 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
72 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
73 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
74 | x = einops.rearrange(x, 'B H L D -> B L (H D)')
75 | elif ATTENTION_MODE == 'xformers':
76 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
77 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
78 | x = xformers.ops.memory_efficient_attention(q, k, v)
79 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
80 | elif ATTENTION_MODE == 'math':
81 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
82 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
83 | attn = (q @ k.transpose(-2, -1)) * self.scale
84 | attn = attn.softmax(dim=-1)
85 | attn = self.attn_drop(attn)
86 | x = (attn @ v).transpose(1, 2).reshape(B, L, C)
87 | else:
88 |             raise NotImplementedError
89 |
90 | x = self.proj(x)
91 | x = self.proj_drop(x)
92 | return x
93 |
94 |
95 | class Block(nn.Module):
96 |
97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
98 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
99 | super().__init__()
100 | self.norm1 = norm_layer(dim)
101 | self.attn = Attention(
102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
103 | self.norm2 = norm_layer(dim)
104 | mlp_hidden_dim = int(dim * mlp_ratio)
105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
106 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
107 | self.use_checkpoint = use_checkpoint
108 |
109 | def forward(self, x, skip=None):
110 | if self.use_checkpoint:
111 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
112 | else:
113 | return self._forward(x, skip)
114 |
115 | def _forward(self, x, skip=None):
116 | if self.skip_linear is not None:
117 | x = self.skip_linear(torch.cat([x, skip], dim=-1))
118 | x = x + self.attn(self.norm1(x))
119 | x = x + self.mlp(self.norm2(x))
120 | return x
121 |
122 |
123 | class PatchEmbed(nn.Module):
124 | """ Image to Patch Embedding
125 | """
126 | def __init__(self, patch_size, in_chans=3, embed_dim=768):
127 | super().__init__()
128 | self.patch_size = patch_size
129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
130 |
131 | def forward(self, x):
132 | B, C, H, W = x.shape
133 | assert H % self.patch_size == 0 and W % self.patch_size == 0
134 | x = self.proj(x).flatten(2).transpose(1, 2)
135 | return x
136 |
137 |
138 | class UViT(nn.Module):
139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1,
141 | use_checkpoint=False, conv=True, skip=True):
142 | super().__init__()
143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
144 | self.num_classes = num_classes
145 | self.in_chans = in_chans
146 |
147 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
148 | num_patches = (img_size // patch_size) ** 2
149 |
150 | self.time_embed = nn.Sequential(
151 | nn.Linear(embed_dim, 4 * embed_dim),
152 | nn.SiLU(),
153 | nn.Linear(4 * embed_dim, embed_dim),
154 | ) if mlp_time_embed else nn.Identity()
155 |
156 | if self.num_classes > 0:
157 | self.label_emb = nn.Embedding(self.num_classes, embed_dim)
158 | self.extras = 2
159 | else:
160 | self.extras = 1
161 |
162 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim))
163 |
164 | self.in_blocks = nn.ModuleList([
165 | Block(
166 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
167 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
168 | for _ in range(depth // 2)])
169 |
170 | self.mid_block = Block(
171 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
172 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
173 |
174 | self.out_blocks = nn.ModuleList([
175 | Block(
176 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
177 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint)
178 | for _ in range(depth // 2)])
179 |
180 | self.norm = norm_layer(embed_dim)
181 | self.patch_dim = patch_size ** 2 * in_chans
182 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)
183 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity()
184 |
185 | trunc_normal_(self.pos_embed, std=.02)
186 | self.apply(self._init_weights)
187 |
188 | def _init_weights(self, m):
189 | if isinstance(m, nn.Linear):
190 | trunc_normal_(m.weight, std=.02)
191 | if isinstance(m, nn.Linear) and m.bias is not None:
192 | nn.init.constant_(m.bias, 0)
193 | elif isinstance(m, nn.LayerNorm):
194 | nn.init.constant_(m.bias, 0)
195 | nn.init.constant_(m.weight, 1.0)
196 |
197 | @torch.jit.ignore
198 | def no_weight_decay(self):
199 | return {'pos_embed'}
200 |
201 | def forward(self, x, timesteps, y=None):
202 | x = self.patch_embed(x)
203 | B, L, D = x.shape
204 |
205 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim))
206 | time_token = time_token.unsqueeze(dim=1)
207 | x = torch.cat((time_token, x), dim=1)
208 | if y is not None:
209 | label_emb = self.label_emb(y)
210 | label_emb = label_emb.unsqueeze(dim=1)
211 | x = torch.cat((label_emb, x), dim=1)
212 | x = x + self.pos_embed
213 |
214 | skips = []
215 | for blk in self.in_blocks:
216 | x = blk(x)
217 | skips.append(x)
218 |
219 | x = self.mid_block(x)
220 |
221 | for blk in self.out_blocks:
222 | x = blk(x, skips.pop())
223 |
224 | x = self.norm(x)
225 | x = self.decoder_pred(x)
226 | assert x.size(1) == self.extras + L
227 | x = x[:, self.extras:, :]
228 | x = unpatchify(x, self.in_chans)
229 | x = self.final_layer(x)
230 | return x
231 |
--------------------------------------------------------------------------------
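patchify and unpatchify above are exact inverses: an image is cut into non-overlapping p x p patches, each flattened to a token of length p*p*C, and reassembled losslessly. A round-trip check using the same einops patterns:

import torch
import einops

imgs = torch.randn(2, 3, 16, 16)
p = 4
tokens = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=p, p2=p)
print(tokens.shape) # torch.Size([2, 16, 48]): a 4x4 grid of 4x4x3 patches

recon = einops.rearrange(tokens, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=4, p1=p, p2=p)
assert torch.equal(recon, imgs) # lossless round trip

--------------------------------------------------------------------------------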
/sde.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from absl import logging
4 | import numpy as np
5 | import math
6 | from tqdm import tqdm
7 |
8 |
9 | def get_sde(name, **kwargs):
10 | if name == 'vpsde':
11 | return VPSDE(**kwargs)
12 | elif name == 'vpsde_cosine':
13 | return VPSDECosine(**kwargs)
14 | else:
15 | raise NotImplementedError
16 |
17 |
18 | def stp(s, ts: torch.Tensor): # scalar tensor product
19 | if isinstance(s, np.ndarray):
20 | s = torch.from_numpy(s).type_as(ts)
21 | extra_dims = (1,) * (ts.dim() - 1)
22 | return s.view(-1, *extra_dims) * ts
23 |
24 |
25 | def mos(a, start_dim=1): # mean of square
26 | return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)
27 |
28 |
29 | def duplicate(tensor, *size):
30 | return tensor.unsqueeze(dim=0).expand(*size, *tensor.shape)
31 |
32 |
33 | class SDE(object):
34 | r"""
35 | dx = f(x, t)dt + g(t) dw with 0 <= t <= 1
36 | f(x, t) is the drift
37 | g(t) is the diffusion
38 | """
39 | def drift(self, x, t):
40 | raise NotImplementedError
41 |
42 | def diffusion(self, t):
43 | raise NotImplementedError
44 |
45 | def cum_beta(self, t): # the variance of xt|x0
46 | raise NotImplementedError
47 |
48 | def cum_alpha(self, t):
49 | raise NotImplementedError
50 |
51 |     def snr(self, t): # signal-to-noise ratio
52 | raise NotImplementedError
53 |
54 |     def nsr(self, t): # noise-to-signal ratio
55 | raise NotImplementedError
56 |
57 | def marginal_prob(self, x0, t): # the mean and std of q(xt|x0)
58 | alpha = self.cum_alpha(t)
59 | beta = self.cum_beta(t)
60 | mean = stp(alpha ** 0.5, x0) # E[xt|x0]
61 | std = beta ** 0.5 # Cov[xt|x0] ** 0.5
62 | return mean, std
63 |
64 | def sample(self, x0, t_init=0): # sample from q(xn|x0), where n is uniform
65 | t = torch.rand(x0.shape[0], device=x0.device) * (1. - t_init) + t_init
66 | mean, std = self.marginal_prob(x0, t)
67 | eps = torch.randn_like(x0)
68 | xt = mean + stp(std, eps)
69 | return t, eps, xt
70 |
71 |
72 | class VPSDE(SDE):
73 | def __init__(self, beta_min=0.1, beta_max=20):
74 | # 0 <= t <= 1
75 | self.beta_0 = beta_min
76 | self.beta_1 = beta_max
77 |
78 | def drift(self, x, t):
79 | return -0.5 * stp(self.squared_diffusion(t), x)
80 |
81 | def diffusion(self, t):
82 | return self.squared_diffusion(t) ** 0.5
83 |
84 | def squared_diffusion(self, t): # beta(t)
85 | return self.beta_0 + t * (self.beta_1 - self.beta_0)
86 |
87 | def squared_diffusion_integral(self, s, t): # \int_s^t beta(tau) d tau
88 | return self.beta_0 * (t - s) + (self.beta_1 - self.beta_0) * (t ** 2 - s ** 2) * 0.5
89 |
90 | def skip_beta(self, s, t): # beta_{t|s}, Cov[xt|xs]=beta_{t|s} I
91 | return 1. - self.skip_alpha(s, t)
92 |
93 | def skip_alpha(self, s, t): # alpha_{t|s}, E[xt|xs]=alpha_{t|s}**0.5 xs
94 | x = -self.squared_diffusion_integral(s, t)
95 | return x.exp()
96 |
97 | def cum_beta(self, t):
98 | return self.skip_beta(0, t)
99 |
100 | def cum_alpha(self, t):
101 | return self.skip_alpha(0, t)
102 |
103 | def nsr(self, t):
104 | return self.squared_diffusion_integral(0, t).expm1()
105 |
106 | def snr(self, t):
107 | return 1. / self.nsr(t)
108 |
109 | def __str__(self):
110 | return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}'
111 |
112 | def __repr__(self):
113 | return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}'
114 |
115 |
116 | class VPSDECosine(SDE):
117 | r"""
118 | dx = f(x, t)dt + g(t) dw with 0 <= t <= 1
119 | f(x, t) is the drift
120 | g(t) is the diffusion
121 | """
122 | def __init__(self, s=0.008):
123 | self.s = s
124 | self.F = lambda t: torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2
125 | self.F0 = math.cos(s / (1 + s) * math.pi / 2) ** 2
126 |
127 | def drift(self, x, t):
128 | ft = - torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi / 2
129 | return stp(ft, x)
130 |
131 | def diffusion(self, t):
132 | return (torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi) ** 0.5
133 |
134 | def cum_beta(self, t): # the variance of xt|x0
135 | return 1 - self.cum_alpha(t)
136 |
137 | def cum_alpha(self, t):
138 | return self.F(t) / self.F0
139 |
140 |     def snr(self, t): # signal-to-noise ratio
141 | Ft = self.F(t)
142 | return Ft / (self.F0 - Ft)
143 |
144 |     def nsr(self, t): # noise-to-signal ratio
145 | Ft = self.F(t)
146 | return self.F0 / Ft - 1
147 |
148 | def __str__(self):
149 | return 'vpsde_cosine'
150 |
151 | def __repr__(self):
152 | return 'vpsde_cosine'
153 |
154 |
155 | class ScoreModel(object):
156 | r"""
157 | The forward process is q(x_[0,T])
158 | """
159 |
160 | def __init__(self, nnet: nn.Module, pred: str, sde: SDE, T=1):
161 | assert T == 1
162 | self.nnet = nnet
163 | self.pred = pred
164 | self.sde = sde
165 | self.T = T
166 | print(f'ScoreModel with pred={pred}, sde={sde}, T={T}')
167 |
168 | def predict(self, xt, t, **kwargs):
169 | if not isinstance(t, torch.Tensor):
170 | t = torch.tensor(t)
171 | t = t.to(xt.device)
172 | if t.dim() == 0:
173 | t = duplicate(t, xt.size(0))
174 | return self.nnet(xt, t * 999, **kwargs) # follow SDE
175 |
176 | def noise_pred(self, xt, t, **kwargs):
177 | pred = self.predict(xt, t, **kwargs)
178 | if self.pred == 'noise_pred':
179 | noise_pred = pred
180 | elif self.pred == 'x0_pred':
181 | noise_pred = - stp(self.sde.snr(t).sqrt(), pred) + stp(self.sde.cum_beta(t).rsqrt(), xt)
182 | else:
183 | raise NotImplementedError
184 | return noise_pred
185 |
186 | def x0_pred(self, xt, t, **kwargs):
187 | pred = self.predict(xt, t, **kwargs)
188 | if self.pred == 'noise_pred':
189 | x0_pred = stp(self.sde.cum_alpha(t).rsqrt(), xt) - stp(self.sde.nsr(t).sqrt(), pred)
190 | elif self.pred == 'x0_pred':
191 | x0_pred = pred
192 | else:
193 | raise NotImplementedError
194 | return x0_pred
195 |
196 | def score(self, xt, t, **kwargs):
197 | cum_beta = self.sde.cum_beta(t)
198 | noise_pred = self.noise_pred(xt, t, **kwargs)
199 | return stp(-cum_beta.rsqrt(), noise_pred)
200 |
201 |
202 | class ReverseSDE(object):
203 | r"""
204 | dx = [f(x, t) - g(t)^2 s(x, t)] dt + g(t) dw
205 | """
206 | def __init__(self, score_model):
207 | self.sde = score_model.sde # the forward sde
208 | self.score_model = score_model
209 |
210 | def drift(self, x, t, **kwargs):
211 | drift = self.sde.drift(x, t) # f(x, t)
212 | diffusion = self.sde.diffusion(t) # g(t)
213 | score = self.score_model.score(x, t, **kwargs)
214 | return drift - stp(diffusion ** 2, score)
215 |
216 | def diffusion(self, t):
217 | return self.sde.diffusion(t)
218 |
219 |
220 | class ODE(object):
221 | r"""
222 |     dx = [f(x, t) - 0.5 g(t)^2 s(x, t)] dt
223 | """
224 |
225 | def __init__(self, score_model):
226 | self.sde = score_model.sde # the forward sde
227 | self.score_model = score_model
228 |
229 | def drift(self, x, t, **kwargs):
230 | drift = self.sde.drift(x, t) # f(x, t)
231 | diffusion = self.sde.diffusion(t) # g(t)
232 | score = self.score_model.score(x, t, **kwargs)
233 | return drift - 0.5 * stp(diffusion ** 2, score)
234 |
235 | def diffusion(self, t):
236 | return 0
237 |
238 |
239 | def dct2str(dct):
240 | return str({k: f'{v:.6g}' for k, v in dct.items()})
241 |
242 |
243 | @ torch.no_grad()
244 | def euler_maruyama(rsde, x_init, sample_steps, eps=1e-3, T=1, trace=None, verbose=False, **kwargs):
245 | r"""
246 | The Euler Maruyama sampler for reverse SDE / ODE
247 | See `Score-Based Generative Modeling through Stochastic Differential Equations`
248 | """
249 | assert isinstance(rsde, ReverseSDE) or isinstance(rsde, ODE)
250 | print(f"euler_maruyama with sample_steps={sample_steps}")
251 | timesteps = np.append(0., np.linspace(eps, T, sample_steps))
252 | timesteps = torch.tensor(timesteps).to(x_init)
253 | x = x_init
254 | if trace is not None:
255 | trace.append(x)
256 | for s, t in tqdm(list(zip(timesteps, timesteps[1:]))[::-1], disable=not verbose, desc='euler_maruyama'):
257 | drift = rsde.drift(x, t, **kwargs)
258 | diffusion = rsde.diffusion(t)
259 | dt = s - t
260 | mean = x + drift * dt
261 | sigma = diffusion * (-dt).sqrt()
262 | x = mean + stp(sigma, torch.randn_like(x)) if s != 0 else mean
263 | if trace is not None:
264 | trace.append(x)
265 | statistics = dict(s=s, t=t, sigma=sigma.item())
266 | logging.debug(dct2str(statistics))
267 | return x
268 |
269 |
270 | def LSimple(score_model: ScoreModel, x0, pred='noise_pred', **kwargs):
271 | t, noise, xt = score_model.sde.sample(x0)
272 | if pred == 'noise_pred':
273 | noise_pred = score_model.noise_pred(xt, t, **kwargs)
274 | return mos(noise - noise_pred)
275 | elif pred == 'x0_pred':
276 | x0_pred = score_model.x0_pred(xt, t, **kwargs)
277 | return mos(x0 - x0_pred)
278 | else:
279 | raise NotImplementedError(pred)
280 |
--------------------------------------------------------------------------------
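The VPSDE closed forms above follow from beta(t) = beta_0 + t (beta_1 - beta_0): cum_alpha(t) = exp(-\int_0^t beta(s) ds), cum_beta(t) = 1 - cum_alpha(t), and snr(t) = cum_alpha(t) / cum_beta(t) = 1 / expm1(\int_0^t beta(s) ds). A short numeric check of these identities with the default coefficients:

import torch

beta_0, beta_1 = 0.1, 20.0 # the VPSDE defaults above
t = torch.tensor([0.1, 0.5, 1.0])

integral = beta_0 * t + 0.5 * (beta_1 - beta_0) * t ** 2 # \int_0^t beta(s) ds
cum_alpha = torch.exp(-integral) # E[x_t|x_0] = cum_alpha**0.5 * x_0
cum_beta = 1.0 - cum_alpha       # Cov[x_t|x_0] = cum_beta * I
print(cum_alpha, cum_beta)
print(cum_alpha / cum_beta, 1.0 / torch.expm1(integral)) # identical snr(t)

--------------------------------------------------------------------------------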
/src/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import os
9 | import math
10 | import torch
11 | import torch.distributed as dist
12 |
13 | from logging import getLogger
14 |
15 | logger = getLogger()
16 |
17 |
18 | def gpu_timer(closure, log_timings=True):
19 | """ Helper to time gpu-time to execute closure() """
20 | elapsed_time = -1.
21 | if log_timings:
22 | start = torch.cuda.Event(enable_timing=True)
23 | end = torch.cuda.Event(enable_timing=True)
24 | start.record()
25 |
26 | result = closure()
27 |
28 | if log_timings:
29 | end.record()
30 | torch.cuda.synchronize()
31 | elapsed_time = start.elapsed_time(end)
32 |
33 | return result, elapsed_time
34 |
35 |
36 | def init_distributed(port=40111, rank_and_world_size=(None, None)):
37 |
38 | if dist.is_available() and dist.is_initialized():
39 | return dist.get_world_size(), dist.get_rank()
40 |
41 | rank, world_size = rank_and_world_size
42 | os.environ['MASTER_ADDR'] = 'localhost'
43 |
44 | if (rank is None) or (world_size is None):
45 | try:
46 | world_size = int(os.environ['SLURM_NTASKS'])
47 | rank = int(os.environ['SLURM_PROCID'])
48 | os.environ['MASTER_ADDR'] = os.environ['HOSTNAME']
49 | except Exception:
50 | logger.info('SLURM vars not set (distributed training not available)')
51 | world_size, rank = 1, 0
52 | return world_size, rank
53 |
54 | try:
55 | os.environ['MASTER_PORT'] = str(port)
56 | torch.distributed.init_process_group(
57 | backend='nccl',
58 | world_size=world_size,
59 | rank=rank)
60 | except Exception:
61 | world_size, rank = 1, 0
62 | logger.info('distributed training not available')
63 |
64 | return world_size, rank
65 |
66 |
67 | class WarmupCosineSchedule(object):
68 |
69 | def __init__(
70 | self,
71 | optimizer,
72 | warmup_steps,
73 | start_lr,
74 | ref_lr,
75 | T_max,
76 | last_epoch=-1,
77 | final_lr=0.
78 | ):
79 | self.optimizer = optimizer
80 | self.start_lr = start_lr
81 | self.ref_lr = ref_lr
82 | self.final_lr = final_lr
83 | self.warmup_steps = warmup_steps
84 | self.T_max = T_max - warmup_steps
85 | self._step = 0.
86 |
87 | def step(self):
88 | self._step += 1
89 | if self._step < self.warmup_steps:
90 | progress = float(self._step) / float(max(1, self.warmup_steps))
91 | new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr)
92 | else:
93 | # -- progress after warmup
94 | progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max))
95 | new_lr = max(self.final_lr,
96 | self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1. + math.cos(math.pi * progress)))
97 |
98 | for group in self.optimizer.param_groups:
99 | group['lr'] = new_lr
100 |
101 | return new_lr
102 |
103 |
104 | class CosineWDSchedule(object):
105 |
106 | def __init__(
107 | self,
108 | optimizer,
109 | ref_wd,
110 | T_max,
111 | final_wd=0.
112 | ):
113 | self.optimizer = optimizer
114 | self.ref_wd = ref_wd
115 | self.final_wd = final_wd
116 | self.T_max = T_max
117 | self._step = 0.
118 |
119 | def step(self):
120 | self._step += 1
121 | progress = self._step / self.T_max
122 | new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1. + math.cos(math.pi * progress))
123 |
124 | if self.final_wd <= self.ref_wd:
125 | new_wd = max(self.final_wd, new_wd)
126 | else:
127 | new_wd = min(self.final_wd, new_wd)
128 |
129 | for group in self.optimizer.param_groups:
130 | if ('WD_exclude' not in group) or not group['WD_exclude']:
131 | group['weight_decay'] = new_wd
132 | return new_wd
133 |
134 |
135 | class CSVLogger(object):
136 |
137 | def __init__(self, fname, *argv):
138 | self.fname = fname
139 | self.types = []
140 | # -- print headers
141 |         with open(self.fname, 'a+') as f:
142 | for i, v in enumerate(argv, 1):
143 | self.types.append(v[0])
144 | if i < len(argv):
145 | print(v[1], end=',', file=f)
146 | else:
147 | print(v[1], end='\n', file=f)
148 |
149 | def log(self, *argv):
150 |         with open(self.fname, 'a+') as f:
151 | for i, tv in enumerate(zip(self.types, argv), 1):
152 | end = ',' if i < len(argv) else '\n'
153 | print(tv[0] % tv[1], end=end, file=f)
154 |
155 |
156 | class AverageMeter(object):
157 | """computes and stores the average and current value"""
158 |
159 | def __init__(self):
160 | self.reset()
161 |
162 | def reset(self):
163 | self.val = 0
164 | self.avg = 0
165 | self.max = float('-inf')
166 | self.min = float('inf')
167 | self.sum = 0
168 | self.count = 0
169 |
170 | def update(self, val, n=1):
171 | self.val = val
172 | self.max = max(val, self.max)
173 | self.min = min(val, self.min)
174 | self.sum += val * n
175 | self.count += n
176 | self.avg = self.sum / self.count
177 |
178 |
179 | class AllGather(torch.autograd.Function):
180 |
181 | @staticmethod
182 | def forward(ctx, x):
183 | if (
184 | dist.is_available()
185 | and dist.is_initialized()
186 | and (dist.get_world_size() > 1)
187 | ):
188 | outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
189 | dist.all_gather(outputs, x)
190 | return torch.cat(outputs, 0)
191 | return x
192 |
193 | @staticmethod
194 | def backward(ctx, grads):
195 | if (
196 | dist.is_available()
197 | and dist.is_initialized()
198 | and (dist.get_world_size() > 1)
199 | ):
200 | s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank()
201 | e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1)
202 | grads = grads.contiguous()
203 | dist.all_reduce(grads)
204 | return grads[s:e]
205 | return grads
206 |
207 |
208 | class AllReduceSum(torch.autograd.Function):
209 |
210 | @staticmethod
211 | def forward(ctx, x):
212 | if (
213 | dist.is_available()
214 | and dist.is_initialized()
215 | and (dist.get_world_size() > 1)
216 | ):
217 | x = x.contiguous()
218 | dist.all_reduce(x)
219 | return x
220 |
221 | @staticmethod
222 | def backward(ctx, grads):
223 | return grads
224 |
225 |
226 | class AllReduce(torch.autograd.Function):
227 |
228 | @staticmethod
229 | def forward(ctx, x):
230 | if (
231 | dist.is_available()
232 | and dist.is_initialized()
233 | and (dist.get_world_size() > 1)
234 | ):
235 | x = x.contiguous() / dist.get_world_size()
236 | dist.all_reduce(x)
237 | return x
238 |
239 | @staticmethod
240 | def backward(ctx, grads):
241 | return grads
242 |
243 |
244 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
245 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
246 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
247 | def norm_cdf(x):
248 | # Computes standard normal cumulative distribution function
249 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
250 |
251 | with torch.no_grad():
252 | # Values are generated by using a truncated uniform distribution and
253 | # then using the inverse CDF for the normal distribution.
254 | # Get upper and lower cdf values
255 | l = norm_cdf((a - mean) / std)
256 | u = norm_cdf((b - mean) / std)
257 |
258 | # Uniformly fill tensor with values from [l, u], then translate to
259 | # [2l-1, 2u-1].
260 | tensor.uniform_(2 * l - 1, 2 * u - 1)
261 |
262 | # Use inverse cdf transform for normal distribution to get truncated
263 | # standard normal
264 | tensor.erfinv_()
265 |
266 | # Transform to proper mean, std
267 | tensor.mul_(std * math.sqrt(2.))
268 | tensor.add_(mean)
269 |
270 | # Clamp to ensure it's in the proper range
271 | tensor.clamp_(min=a, max=b)
272 | return tensor
273 |
274 |
275 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
276 | # type: (Tensor, float, float, float, float) -> Tensor
277 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
278 |
279 |
280 | def grad_logger(named_params):
281 | stats = AverageMeter()
282 | stats.first_layer = None
283 | stats.last_layer = None
284 | for n, p in named_params:
285 | if (p.grad is not None) and not (n.endswith('.bias') or len(p.shape) == 1):
286 | grad_norm = float(torch.norm(p.grad.data))
287 | stats.update(grad_norm)
288 | if 'qkv' in n:
289 | stats.last_layer = grad_norm
290 | if stats.first_layer is None:
291 | stats.first_layer = grad_norm
292 | if stats.first_layer is None or stats.last_layer is None:
293 | stats.first_layer = stats.last_layer = 0.
294 | return stats
295 |
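296 | 
297 | if __name__ == '__main__':
298 |     # Editor's usage sketch: illustrative addition, not part of the original
299 |     # file. The schedulers are stepped manually, once per optimizer update;
300 |     # here, a linear warmup into a cosine decay over 100 steps.
301 |     opt = torch.optim.SGD(torch.nn.Linear(8, 2).parameters(), lr=0.)
302 |     lr_sched = WarmupCosineSchedule(opt, warmup_steps=10, start_lr=1e-6,
303 |                                     ref_lr=1e-3, T_max=100)
304 |     wd_sched = CosineWDSchedule(opt, ref_wd=0.04, T_max=100, final_wd=0.4)
305 |     for _ in range(100):
306 |         lr, wd = lr_sched.step(), wd_sched.step()
307 |     print(f'final lr={lr:.2e}, final wd={wd:.3f}')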
--------------------------------------------------------------------------------
/cifar10_experiment/edm/training/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Streaming images and labels from datasets created with dataset_tool.py."""
9 |
10 | import os
11 | import numpy as np
12 | import zipfile
13 | import PIL.Image
14 | import json
15 | import torch
16 | import dnnlib
17 |
18 | try:
19 | import pyspng
20 | except ImportError:
21 | pyspng = None
22 |
23 | #----------------------------------------------------------------------------
24 | # Abstract base class for datasets.
25 |
26 | class Dataset(torch.utils.data.Dataset):
27 | def __init__(self,
28 | name, # Name of the dataset.
29 | raw_shape, # Shape of the raw image data (NCHW).
30 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
31 | use_labels = False, # Enable conditioning labels? False = label dimension is zero.
32 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
33 | random_seed = 0, # Random seed to use when applying max_size.
34 | cache = False, # Cache images in CPU memory?
35 | ):
36 | self._name = name
37 | self._raw_shape = list(raw_shape)
38 | self._use_labels = use_labels
39 | self._cache = cache
40 | self._cached_images = dict() # {raw_idx: np.ndarray, ...}
41 | self._raw_labels = None
42 | self._label_shape = None
43 |
44 | # Apply max_size.
45 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
46 | if (max_size is not None) and (self._raw_idx.size > max_size):
47 | np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
48 | self._raw_idx = np.sort(self._raw_idx[:max_size])
49 |
50 | # Apply xflip.
51 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
52 | if xflip:
53 | self._raw_idx = np.tile(self._raw_idx, 2)
54 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
55 |
56 | def _get_raw_labels(self):
57 | if self._raw_labels is None:
58 | self._raw_labels = self._load_raw_labels() if self._use_labels else None
59 | if self._raw_labels is None:
60 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
61 | assert isinstance(self._raw_labels, np.ndarray)
62 | assert self._raw_labels.shape[0] == self._raw_shape[0]
63 | assert self._raw_labels.dtype in [np.float32, np.int64]
64 | if self._raw_labels.dtype == np.int64:
65 | assert self._raw_labels.ndim == 1
66 | assert np.all(self._raw_labels >= 0)
67 | return self._raw_labels
68 |
69 | def close(self): # to be overridden by subclass
70 | pass
71 |
72 | def _load_raw_image(self, raw_idx): # to be overridden by subclass
73 | raise NotImplementedError
74 |
75 | def _load_raw_labels(self): # to be overridden by subclass
76 | raise NotImplementedError
77 |
78 | def __getstate__(self):
79 | return dict(self.__dict__, _raw_labels=None)
80 |
81 | def __del__(self):
82 | try:
83 | self.close()
84 | except:
85 | pass
86 |
87 | def __len__(self):
88 | return self._raw_idx.size
89 |
90 | def __getitem__(self, idx):
91 | raw_idx = self._raw_idx[idx]
92 | image = self._cached_images.get(raw_idx, None)
93 | if image is None:
94 | image = self._load_raw_image(raw_idx)
95 | if self._cache:
96 | self._cached_images[raw_idx] = image
97 | assert isinstance(image, np.ndarray)
98 | assert list(image.shape) == self.image_shape
99 | assert image.dtype == np.uint8
100 | if self._xflip[idx]:
101 | assert image.ndim == 3 # CHW
102 | image = image[:, :, ::-1]
103 | return image.copy(), self.get_label(idx)
104 |
105 | def get_label(self, idx):
106 | label = self._get_raw_labels()[self._raw_idx[idx]]
107 | if label.dtype == np.int64:
108 | onehot = np.zeros(self.label_shape, dtype=np.float32)
109 | onehot[label] = 1
110 | label = onehot
111 | return label.copy()
112 |
113 | def get_details(self, idx):
114 | d = dnnlib.EasyDict()
115 | d.raw_idx = int(self._raw_idx[idx])
116 | d.xflip = (int(self._xflip[idx]) != 0)
117 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
118 | return d
119 |
120 | @property
121 | def name(self):
122 | return self._name
123 |
124 | @property
125 | def image_shape(self):
126 | return list(self._raw_shape[1:])
127 |
128 | @property
129 | def num_channels(self):
130 | assert len(self.image_shape) == 3 # CHW
131 | return self.image_shape[0]
132 |
133 | @property
134 | def resolution(self):
135 | assert len(self.image_shape) == 3 # CHW
136 | assert self.image_shape[1] == self.image_shape[2]
137 | return self.image_shape[1]
138 |
139 | @property
140 | def label_shape(self):
141 | if self._label_shape is None:
142 | raw_labels = self._get_raw_labels()
143 | if raw_labels.dtype == np.int64:
144 | self._label_shape = [int(np.max(raw_labels)) + 1]
145 | else:
146 | self._label_shape = raw_labels.shape[1:]
147 | return list(self._label_shape)
148 |
149 | @property
150 | def label_dim(self):
151 | assert len(self.label_shape) == 1
152 | return self.label_shape[0]
153 |
154 | @property
155 | def has_labels(self):
156 | return any(x != 0 for x in self.label_shape)
157 |
158 | @property
159 | def has_onehot_labels(self):
160 | return self._get_raw_labels().dtype == np.int64
161 |
162 | #----------------------------------------------------------------------------
163 | # Dataset subclass that loads images recursively from the specified directory
164 | # or ZIP file.
165 |
166 | class ImageFolderDataset(Dataset):
167 | def __init__(self,
168 | path, # Path to directory or zip.
169 | resolution = None, # Ensure specific resolution, None = highest available.
170 | use_pyspng = True, # Use pyspng if available?
171 | **super_kwargs, # Additional arguments for the Dataset base class.
172 | ):
173 | self._path = path
174 | self._use_pyspng = use_pyspng
175 | self._zipfile = None
176 |
177 | if os.path.isdir(self._path):
178 | self._type = 'dir'
179 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
180 | elif self._file_ext(self._path) == '.zip':
181 | self._type = 'zip'
182 | self._all_fnames = set(self._get_zipfile().namelist())
183 | else:
184 | raise IOError('Path must point to a directory or zip')
185 |
186 | PIL.Image.init()
187 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
188 | if len(self._image_fnames) == 0:
189 | raise IOError('No image files found in the specified path')
190 |
191 | name = os.path.splitext(os.path.basename(self._path))[0]
192 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
193 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
194 | raise IOError('Image files do not match the specified resolution')
195 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
196 |
197 | @staticmethod
198 | def _file_ext(fname):
199 | return os.path.splitext(fname)[1].lower()
200 |
201 | def _get_zipfile(self):
202 | assert self._type == 'zip'
203 | if self._zipfile is None:
204 | self._zipfile = zipfile.ZipFile(self._path)
205 | return self._zipfile
206 |
207 | def _open_file(self, fname):
208 | if self._type == 'dir':
209 | return open(os.path.join(self._path, fname), 'rb')
210 | if self._type == 'zip':
211 | return self._get_zipfile().open(fname, 'r')
212 | return None
213 |
214 | def close(self):
215 | try:
216 | if self._zipfile is not None:
217 | self._zipfile.close()
218 | finally:
219 | self._zipfile = None
220 |
221 | def __getstate__(self):
222 | return dict(super().__getstate__(), _zipfile=None)
223 |
224 | def _load_raw_image(self, raw_idx):
225 | fname = self._image_fnames[raw_idx]
226 | with self._open_file(fname) as f:
227 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png':
228 | image = pyspng.load(f.read())
229 | else:
230 | image = np.array(PIL.Image.open(f))
231 | if image.ndim == 2:
232 | image = image[:, :, np.newaxis] # HW => HWC
233 | image = image.transpose(2, 0, 1) # HWC => CHW
234 | return image
235 |
236 | def _load_raw_labels(self):
237 | fname = 'dataset.json'
238 | if fname not in self._all_fnames:
239 | return None
240 | with self._open_file(fname) as f:
241 | labels = json.load(f)['labels']
242 | if labels is None:
243 | return None
244 | labels = dict(labels)
245 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
246 | labels = np.array(labels)
247 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
248 | return labels
249 |
250 | #----------------------------------------------------------------------------
251 |
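252 | # --- Editor's usage sketch: illustrative addition, not part of the original
253 | # file. ImageFolderDataset yields (uint8 CHW image, label) pairs, with int64
254 | # labels expanded to one-hot float32 vectors; over a hypothetical
255 | # `cifar10.zip` built with dataset_tool.py:
256 | #
257 | #   dataset = ImageFolderDataset(path='cifar10.zip', use_labels=True, xflip=True)
258 | #   loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
259 | #   images, labels = next(iter(loader))  # images: uint8 tensor [64, C, H, W]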
--------------------------------------------------------------------------------
/cifar10_experiment/TorchSSL/models/nets/resnet50.py:
--------------------------------------------------------------------------------
1 | """
2 | from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
3 | """
4 | import torch
5 | from torch import Tensor
6 | import torch.nn as nn
7 | from typing import Type, Any, Callable, Union, List, Optional
8 |
9 |
10 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
11 | """3x3 convolution with padding"""
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=dilation, groups=groups, bias=False, dilation=dilation)
14 |
15 |
16 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
17 | """1x1 convolution"""
18 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
19 |
20 |
21 | class BasicBlock(nn.Module):
22 | expansion: int = 1
23 |
24 | def __init__(
25 | self,
26 | inplanes: int,
27 | planes: int,
28 | stride: int = 1,
29 | downsample: Optional[nn.Module] = None,
30 | groups: int = 1,
31 | base_width: int = 64,
32 | dilation: int = 1,
33 | norm_layer: Optional[Callable[..., nn.Module]] = None
34 | ) -> None:
35 | super(BasicBlock, self).__init__()
36 | if norm_layer is None:
37 | norm_layer = nn.BatchNorm2d
38 | if groups != 1 or base_width != 64:
39 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
40 | if dilation > 1:
41 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
42 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
43 | self.conv1 = conv3x3(inplanes, planes, stride)
44 | self.bn1 = norm_layer(planes)
45 | self.relu = nn.ReLU(inplace=True)
46 | self.conv2 = conv3x3(planes, planes)
47 | self.bn2 = norm_layer(planes)
48 | self.downsample = downsample
49 | self.stride = stride
50 |
51 | def forward(self, x: Tensor) -> Tensor:
52 | identity = x
53 |
54 | out = self.conv1(x)
55 | out = self.bn1(out)
56 | out = self.relu(out)
57 |
58 | out = self.conv2(out)
59 | out = self.bn2(out)
60 |
61 | if self.downsample is not None:
62 | identity = self.downsample(x)
63 |
64 | out += identity
65 | out = self.relu(out)
66 |
67 | return out
68 |
69 |
70 | class Bottleneck(nn.Module):
71 |     # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
72 |     # while the original implementation places the stride at the first 1x1 convolution (self.conv1),
73 |     # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
74 | # This variant is also known as ResNet V1.5 and improves accuracy according to
75 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
76 |
77 | expansion: int = 4
78 |
79 | def __init__(
80 | self,
81 | inplanes: int,
82 | planes: int,
83 | stride: int = 1,
84 | downsample: Optional[nn.Module] = None,
85 | groups: int = 1,
86 | base_width: int = 64,
87 | dilation: int = 1,
88 | norm_layer: Optional[Callable[..., nn.Module]] = None
89 | ) -> None:
90 | super(Bottleneck, self).__init__()
91 | if norm_layer is None:
92 | norm_layer = nn.BatchNorm2d
93 | width = int(planes * (base_width / 64.)) * groups
94 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
95 | self.conv1 = conv1x1(inplanes, width)
96 | self.bn1 = norm_layer(width)
97 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
98 | self.bn2 = norm_layer(width)
99 | self.conv3 = conv1x1(width, planes * self.expansion)
100 | self.bn3 = norm_layer(planes * self.expansion)
101 | self.relu = nn.ReLU(inplace=True)
102 | self.downsample = downsample
103 | self.stride = stride
104 |
105 | def forward(self, x: Tensor) -> Tensor:
106 | identity = x
107 |
108 | out = self.conv1(x)
109 | out = self.bn1(out)
110 | out = self.relu(out)
111 |
112 | out = self.conv2(out)
113 | out = self.bn2(out)
114 | out = self.relu(out)
115 |
116 | out = self.conv3(out)
117 | out = self.bn3(out)
118 |
119 | if self.downsample is not None:
120 | identity = self.downsample(x)
121 |
122 | out += identity
123 | out = self.relu(out)
124 |
125 | return out
126 |
127 |
128 | class ResNet50(nn.Module):
129 |
130 | def __init__(
131 | self,
132 | block: Type[Union[BasicBlock, Bottleneck]] = Bottleneck,
133 | layers: List[int] = [3, 4, 6, 3],
134 | n_class: int = 1000,
135 | zero_init_residual: bool = False,
136 | groups: int = 1,
137 | width_per_group: int = 64,
138 | replace_stride_with_dilation: Optional[List[bool]] = None,
139 | norm_layer: Optional[Callable[..., nn.Module]] = None,
140 | is_remix=False
141 | ) -> None:
142 | super(ResNet50, self).__init__()
143 | if norm_layer is None:
144 | norm_layer = nn.BatchNorm2d
145 | self._norm_layer = norm_layer
146 |
147 | self.inplanes = 64
148 | self.dilation = 1
149 | if replace_stride_with_dilation is None:
150 | # each element in the tuple indicates if we should replace
151 | # the 2x2 stride with a dilated convolution instead
152 | replace_stride_with_dilation = [False, False, False]
153 | if len(replace_stride_with_dilation) != 3:
154 | raise ValueError("replace_stride_with_dilation should be None "
155 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
156 | self.groups = groups
157 | self.base_width = width_per_group
158 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
159 | bias=False)
160 | self.bn1 = norm_layer(self.inplanes)
161 | self.relu = nn.ReLU(inplace=True)
162 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
163 | self.layer1 = self._make_layer(block, 64, layers[0])
164 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
165 | dilate=replace_stride_with_dilation[0])
166 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
167 | dilate=replace_stride_with_dilation[1])
168 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
169 | dilate=replace_stride_with_dilation[2])
170 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
171 | self.fc = nn.Linear(512 * block.expansion, n_class)
172 |
173 | # rot_classifier for Remix Match
174 | self.is_remix = is_remix
175 | if is_remix:
176 | self.rot_classifier = nn.Linear(2048, 4)
177 |
178 | for m in self.modules():
179 | if isinstance(m, nn.Conv2d):
180 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
181 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
182 | nn.init.constant_(m.weight, 1)
183 | nn.init.constant_(m.bias, 0)
184 |
185 | # Zero-initialize the last BN in each residual branch,
186 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
187 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
188 | if zero_init_residual:
189 | for m in self.modules():
190 | if isinstance(m, Bottleneck):
191 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
192 | elif isinstance(m, BasicBlock):
193 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
194 |
195 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
196 | stride: int = 1, dilate: bool = False) -> nn.Sequential:
197 | norm_layer = self._norm_layer
198 | downsample = None
199 | previous_dilation = self.dilation
200 | if dilate:
201 | self.dilation *= stride
202 | stride = 1
203 | if stride != 1 or self.inplanes != planes * block.expansion:
204 | downsample = nn.Sequential(
205 | conv1x1(self.inplanes, planes * block.expansion, stride),
206 | norm_layer(planes * block.expansion),
207 | )
208 |
209 | layers = []
210 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
211 | self.base_width, previous_dilation, norm_layer))
212 | self.inplanes = planes * block.expansion
213 | for _ in range(1, blocks):
214 | layers.append(block(self.inplanes, planes, groups=self.groups,
215 | base_width=self.base_width, dilation=self.dilation,
216 | norm_layer=norm_layer))
217 |
218 | return nn.Sequential(*layers)
219 |
220 | def _forward_impl(self, x):
221 | # See note [TorchScript super()]
222 | x = self.conv1(x)
223 | x = self.bn1(x)
224 | x = self.relu(x)
225 | x = self.maxpool(x)
226 |
227 | x = self.layer1(x)
228 | x = self.layer2(x)
229 | x = self.layer3(x)
230 | x = self.layer4(x)
231 |
232 | x = self.avgpool(x)
233 | x = torch.flatten(x, 1)
234 | out = self.fc(x)
235 | if self.is_remix:
236 | rot_output = self.rot_classifier(x)
237 | return out, rot_output
238 | else:
239 | return out
240 |
241 | def forward(self, x):
242 | return self._forward_impl(x)
243 |
244 |
245 | class build_ResNet50:
246 | def __init__(self, is_remix=False):
247 | self.is_remix = is_remix
248 |
249 | def build(self, num_classes):
250 | return ResNet50(n_class=num_classes, is_remix=self.is_remix)
251 |
252 |
253 | if __name__ == '__main__':
254 | a = torch.rand(16, 3, 224, 224)
255 | net = ResNet50(is_remix=True)
256 |     x, y = net(a)
257 | print(x.shape)
258 | print(y.shape)
259 |
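260 | # --- Editor's usage sketch: illustrative addition, not part of the original
261 | # file. The builder class is presumably consumed by the training code as:
262 | #
263 | #   net = build_ResNet50(is_remix=False).build(num_classes=10)
264 | #   logits = net(torch.rand(4, 3, 224, 224))  # -> [4, 10]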
--------------------------------------------------------------------------------
/tools/fid_score.py:
--------------------------------------------------------------------------------
1 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs
2 |
3 | The FID metric calculates the distance between two distributions of images.
4 | Typically, we have summary statistics (mean & covariance matrix) of one
5 | of these distributions, while the 2nd distribution is given by a GAN.
6 |
7 | When run as a stand-alone program, it compares the distribution of
8 | images that are stored as PNG/JPEG at a specified location with a
9 | distribution given by summary statistics (in pickle format).
10 |
11 | The FID is calculated by assuming that X_1 and X_2 are the activations of
12 | the pool_3 layer of the inception net for generated samples and real world
13 | samples respectively.
14 |
15 | See --help to see further details.
16 |
17 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18 | of TensorFlow.
19 |
20 | Copyright 2018 Institute of Bioinformatics, JKU Linz
21 |
22 | Licensed under the Apache License, Version 2.0 (the "License");
23 | you may not use this file except in compliance with the License.
24 | You may obtain a copy of the License at
25 |
26 | http://www.apache.org/licenses/LICENSE-2.0
27 |
28 | Unless required by applicable law or agreed to in writing, software
29 | distributed under the License is distributed on an "AS IS" BASIS,
30 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31 | See the License for the specific language governing permissions and
32 | limitations under the License.
33 | """
34 | import os
35 | import pathlib
36 |
37 | import numpy as np
38 | import torch
39 | import torchvision.transforms as TF
40 | from PIL import Image
41 | from scipy import linalg
42 | from torch.nn.functional import adaptive_avg_pool2d
43 |
44 | try:
45 | from tqdm import tqdm
46 | except ImportError:
47 | # If tqdm is not available, provide a mock version of it
48 | def tqdm(x):
49 | return x
50 |
51 | from .inception import InceptionV3
52 |
53 |
54 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
55 | 'tif', 'tiff', 'webp'}
56 |
57 |
58 | class ImagePathDataset(torch.utils.data.Dataset):
59 | def __init__(self, files, transforms=None):
60 | self.files = files
61 | self.transforms = transforms
62 |
63 | def __len__(self):
64 | return len(self.files)
65 |
66 | def __getitem__(self, i):
67 | path = self.files[i]
68 | img = Image.open(path).convert('RGB')
69 | if self.transforms is not None:
70 | img = self.transforms(img)
71 | return img
72 |
73 |
74 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8):
75 | """Calculates the activations of the pool_3 layer for all images.
76 |
77 | Params:
78 | -- files : List of image files paths
79 | -- model : Instance of inception model
80 | -- batch_size : Batch size of images for the model to process at once.
81 | Make sure that the number of samples is a multiple of
82 | the batch size, otherwise some samples are ignored. This
83 | behavior is retained to match the original FID score
84 | implementation.
85 | -- dims : Dimensionality of features returned by Inception
86 | -- device : Device to run calculations
87 | -- num_workers : Number of parallel dataloader workers
88 |
89 | Returns:
90 | -- A numpy array of dimension (num images, dims) that contains the
91 | activations of the given tensor when feeding inception with the
92 | query tensor.
93 | """
94 | model.eval()
95 |
96 | if batch_size > len(files):
97 | print(('Warning: batch size is bigger than the data size. '
98 | 'Setting batch size to data size'))
99 | batch_size = len(files)
100 |
101 | dataset = ImagePathDataset(files, transforms=TF.ToTensor())
102 | dataloader = torch.utils.data.DataLoader(dataset,
103 | batch_size=batch_size,
104 | shuffle=False,
105 | drop_last=False,
106 | num_workers=num_workers)
107 |
108 | pred_arr = np.empty((len(files), dims))
109 |
110 | start_idx = 0
111 |
112 | for batch in tqdm(dataloader):
113 | batch = batch.to(device)
114 |
115 | with torch.no_grad():
116 | pred = model(batch)[0]
117 |
118 | # If model output is not scalar, apply global spatial average pooling.
119 | # This happens if you choose a dimensionality not equal 2048.
120 | if pred.size(2) != 1 or pred.size(3) != 1:
121 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
122 |
123 | pred = pred.squeeze(3).squeeze(2).cpu().numpy()
124 |
125 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred
126 |
127 | start_idx = start_idx + pred.shape[0]
128 |
129 | return pred_arr
130 |
131 |
132 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
133 | """Numpy implementation of the Frechet Distance.
134 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
135 | and X_2 ~ N(mu_2, C_2) is
136 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
137 |
138 | Stable version by Dougal J. Sutherland.
139 |
140 | Params:
141 | -- mu1 : Numpy array containing the activations of a layer of the
142 | inception net (like returned by the function 'get_predictions')
143 | for generated samples.
144 |     -- mu2   : The sample mean over activations, precalculated on a
145 | representative data set.
146 | -- sigma1: The covariance matrix over activations for generated samples.
147 |     -- sigma2: The covariance matrix over activations, precalculated on a
148 | representative data set.
149 |
150 | Returns:
151 | -- : The Frechet Distance.
152 | """
153 |
154 | mu1 = np.atleast_1d(mu1)
155 | mu2 = np.atleast_1d(mu2)
156 |
157 | sigma1 = np.atleast_2d(sigma1)
158 | sigma2 = np.atleast_2d(sigma2)
159 |
160 | assert mu1.shape == mu2.shape, \
161 | 'Training and test mean vectors have different lengths'
162 | assert sigma1.shape == sigma2.shape, \
163 | 'Training and test covariances have different dimensions'
164 |
165 | diff = mu1 - mu2
166 |
167 | # Product might be almost singular
168 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
169 | if not np.isfinite(covmean).all():
170 | msg = ('fid calculation produces singular product; '
171 | 'adding %s to diagonal of cov estimates') % eps
172 | print(msg)
173 | offset = np.eye(sigma1.shape[0]) * eps
174 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
175 |
176 | # Numerical error might give slight imaginary component
177 | if np.iscomplexobj(covmean):
178 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
179 | m = np.max(np.abs(covmean.imag))
180 | raise ValueError('Imaginary component {}'.format(m))
181 | covmean = covmean.real
182 |
183 | tr_covmean = np.trace(covmean)
184 |
185 | return (diff.dot(diff) + np.trace(sigma1)
186 | + np.trace(sigma2) - 2 * tr_covmean)
187 |
188 |
189 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
190 | device='cpu', num_workers=8):
191 | """Calculation of the statistics used by the FID.
192 | Params:
193 | -- files : List of image files paths
194 | -- model : Instance of inception model
195 | -- batch_size : The images numpy array is split into batches with
196 | batch size batch_size. A reasonable batch size
197 | depends on the hardware.
198 | -- dims : Dimensionality of features returned by Inception
199 | -- device : Device to run calculations
200 | -- num_workers : Number of parallel dataloader workers
201 |
202 | Returns:
203 | -- mu : The mean over samples of the activations of the pool_3 layer of
204 | the inception model.
205 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
206 | the inception model.
207 | """
208 | act = get_activations(files, model, batch_size, dims, device, num_workers)
209 | mu = np.mean(act, axis=0)
210 | sigma = np.cov(act, rowvar=False)
211 | return mu, sigma
212 |
213 |
214 | def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8):
215 | if path.endswith('.npz'):
216 | with np.load(path) as f:
217 | m, s = f['mu'][:], f['sigma'][:]
218 | else:
219 | path = pathlib.Path(path)
220 | files = sorted([file for ext in IMAGE_EXTENSIONS
221 | for file in path.glob('*.{}'.format(ext))])
222 | m, s = calculate_activation_statistics(files, model, batch_size,
223 | dims, device, num_workers)
224 |
225 | return m, s
226 |
227 |
228 | def save_statistics_of_path(path, out_path, device=None, batch_size=50, dims=2048, num_workers=8):
229 | if device is None:
230 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
231 | else:
232 | device = torch.device(device)
233 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
234 | model = InceptionV3([block_idx]).to(device)
235 | m1, s1 = compute_statistics_of_path(path, model, batch_size, dims, device, num_workers)
236 | np.savez(out_path, mu=m1, sigma=s1)
237 |
238 |
239 | def calculate_fid_given_paths(paths, device=None, batch_size=50, dims=2048, num_workers=8):
240 | """Calculates the FID of two paths"""
241 | if device is None:
242 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
243 | else:
244 | device = torch.device(device)
245 |
246 | for p in paths:
247 | if not os.path.exists(p):
248 | raise RuntimeError('Invalid path: %s' % p)
249 |
250 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
251 |
252 | model = InceptionV3([block_idx]).to(device)
253 |
254 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
255 | dims, device, num_workers)
256 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
257 | dims, device, num_workers)
258 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
259 |
260 | return fid_value
261 |
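262 | # --- Editor's usage sketch: illustrative addition, not part of the original
263 | # file. Comparing a folder of generated images against precomputed reference
264 | # statistics (hypothetical paths; either argument may be a directory of
265 | # images or a .npz file with 'mu' and 'sigma' arrays):
266 | #
267 | #   fid = calculate_fid_given_paths(['samples/', 'cifar10_stats.npz'],
268 | #                                   batch_size=50, dims=2048)
269 | #   print(f'FID: {fid:.2f}')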
--------------------------------------------------------------------------------
/src/data_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import os
9 | import subprocess
10 | import time
11 |
12 | from logging import getLogger
13 |
14 | from PIL import ImageFilter
15 |
16 | import torch
17 | import torchvision.transforms as transforms
18 | import torchvision
19 |
20 | _GLOBAL_SEED = 0
21 | logger = getLogger()
22 |
23 |
24 | def init_data(
25 | transform,
26 | batch_size,
27 | pin_mem=True,
28 | num_workers=8,
29 | world_size=1,
30 | rank=0,
31 | root_path=None,
32 | image_folder=None,
33 | training=True,
34 | copy_data=False,
35 | drop_last=True,
36 | subset_file=None
37 | ):
38 |
39 | dataset = ImageNet(
40 | root=root_path,
41 | image_folder=image_folder,
42 | transform=transform,
43 | train=training,
44 | copy_data=copy_data)
45 | if subset_file is not None:
46 | dataset = ImageNetSubset(dataset, subset_file)
47 | logger.info('ImageNet dataset created')
48 | dist_sampler = torch.utils.data.distributed.DistributedSampler(
49 | dataset=dataset,
50 | num_replicas=world_size,
51 | rank=rank)
52 | data_loader = torch.utils.data.DataLoader(
53 | dataset,
54 | sampler=dist_sampler,
55 | batch_size=batch_size,
56 | drop_last=drop_last,
57 | pin_memory=pin_mem,
58 | num_workers=num_workers)
59 | logger.info('ImageNet unsupervised data loader created')
60 |
61 | return (data_loader, dist_sampler)
62 |
63 |
64 | def make_transforms(
65 | rand_size=224,
66 | focal_size=96,
67 | rand_crop_scale=(0.3, 1.0),
68 | focal_crop_scale=(0.05, 0.3),
69 | color_jitter=1.0,
70 | rand_views=2,
71 | focal_views=10,
72 | ):
73 | logger.info('making imagenet data transforms')
74 |
75 | def get_color_distortion(s=1.0):
76 | # s is the strength of color distortion.
77 | color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
78 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
79 | rnd_gray = transforms.RandomGrayscale(p=0.2)
80 | color_distort = transforms.Compose([
81 | rnd_color_jitter,
82 | rnd_gray])
83 | return color_distort
84 |
85 | rand_transform = transforms.Compose([
86 | transforms.RandomResizedCrop(rand_size, scale=rand_crop_scale),
87 | transforms.RandomHorizontalFlip(),
88 | get_color_distortion(s=color_jitter),
89 | GaussianBlur(p=0.5),
90 | transforms.ToTensor(),
91 | transforms.Normalize(
92 | (0.485, 0.456, 0.406),
93 | (0.229, 0.224, 0.225))
94 | ])
95 | focal_transform = transforms.Compose([
96 | transforms.RandomResizedCrop(focal_size, scale=focal_crop_scale),
97 | transforms.RandomHorizontalFlip(),
98 | get_color_distortion(s=color_jitter),
99 | GaussianBlur(p=0.5),
100 | transforms.ToTensor(),
101 | transforms.Normalize(
102 | (0.485, 0.456, 0.406),
103 | (0.229, 0.224, 0.225))
104 | ])
105 |
106 | transform = MultiViewTransform(
107 | rand_transform=rand_transform,
108 | focal_transform=focal_transform,
109 | rand_views=rand_views,
110 | focal_views=focal_views
111 | )
112 | return transform
113 |
114 |
115 | class MultiViewTransform(object):
116 |
117 | def __init__(
118 | self,
119 | rand_transform=None,
120 | focal_transform=None,
121 | rand_views=1,
122 | focal_views=1,
123 | ):
124 | self.rand_views = rand_views
125 | self.focal_views = focal_views
126 | self.rand_transform = rand_transform
127 | self.focal_transform = focal_transform
128 |
129 | def __call__(self, img):
130 | img_views = []
131 |
132 | # -- generate random views
133 | if self.rand_views > 0:
134 | img_views += [self.rand_transform(img) for i in range(self.rand_views)]
135 |
136 | # -- generate focal views
137 | if self.focal_views > 0:
138 | img_views += [self.focal_transform(img) for i in range(self.focal_views)]
139 |
140 | return img_views
141 |
142 |
143 | class ImageNet(torchvision.datasets.ImageFolder):
144 |
145 | def __init__(
146 | self,
147 | root,
148 | image_folder='imagenet_full_size/061417/',
149 | tar_folder='imagenet_full_size/',
150 | tar_file='imagenet_full_size-061417.tar',
151 | transform=None,
152 | train=True,
153 | job_id=None,
154 | local_rank=None,
155 | copy_data=True
156 | ):
157 | """
158 | ImageNet
159 |
160 | Dataset wrapper (can copy data locally to machine)
161 |
162 | :param root: root network directory for ImageNet data
163 | :param image_folder: path to images inside root network directory
164 | :param tar_file: zipped image_folder inside root network directory
165 | :param train: whether to load train data (or validation)
166 | :param job_id: scheduler job-id used to create dir on local machine
167 | :param copy_data: whether to copy data from network file locally
168 | """
169 |
170 | suffix = 'train/' if train else 'val/'
171 | data_path = None
172 | if copy_data:
173 | logger.info('copying data locally')
174 | data_path = copy_imgnt_locally(
175 | root=root,
176 | suffix=suffix,
177 | image_folder=image_folder,
178 | tar_folder=tar_folder,
179 | tar_file=tar_file,
180 | job_id=job_id,
181 | local_rank=local_rank)
182 | if (not copy_data) or (data_path is None):
183 | data_path = os.path.join(root, image_folder, suffix)
184 | logger.info(f'data-path {data_path}')
185 |
186 | super(ImageNet, self).__init__(root=data_path, transform=transform)
187 | logger.info('Initialized ImageNet')
188 |
189 | def __getitem__(self, index):
190 | path, target = self.samples[index]
191 | sample = self.loader(path)
192 | if self.transform is not None:
193 | sample = self.transform(sample)
194 | if self.target_transform is not None:
195 | target = self.target_transform(target)
196 | return sample, target, path.split('/')[-1]
197 |
198 |
199 | class ImageNetSubset(object):
200 |
201 | def __init__(self, dataset, subset_file):
202 | """
203 | ImageNetSubset
204 |
205 | :param dataset: ImageNet dataset object
206 | :param subset_file: '.txt' file containing IDs of IN1K images to keep
207 | """
208 | self.dataset = dataset
209 | self.subset_file = subset_file
210 | self.filter_dataset_(subset_file)
211 |
212 | def filter_dataset_(self, subset_file):
213 | """ Filter self.dataset to a subset """
214 | root = self.dataset.root
215 | class_to_idx = self.dataset.class_to_idx
216 | # -- update samples to subset of IN1k targets/samples
217 | new_samples = []
218 | logger.info(f'Using {subset_file}')
219 | with open(subset_file, 'r') as rfile:
220 | for line in rfile:
221 | class_name = line.split('_')[0]
222 | target = class_to_idx[class_name]
223 | img = line.split('\n')[0]
224 | new_samples.append(
225 | (os.path.join(root, class_name, img), target)
226 | )
227 | self.samples = new_samples
228 |
229 | @property
230 | def classes(self):
231 | return self.dataset.classes
232 |
233 | def __len__(self):
234 | return len(self.samples)
235 |
236 | def __getitem__(self, index):
237 | path, target = self.samples[index]
238 | img = self.dataset.loader(path)
239 | if self.dataset.transform is not None:
240 | img = self.dataset.transform(img)
241 | if self.dataset.target_transform is not None:
242 | target = self.dataset.target_transform(target)
243 | return img, target, path.split('/')[-1]
244 |
245 |
246 | class GaussianBlur(object):
247 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
248 | self.prob = p
249 | self.radius_min = radius_min
250 | self.radius_max = radius_max
251 |
252 | def __call__(self, img):
253 | if torch.bernoulli(torch.tensor(self.prob)) == 0:
254 | return img
255 |
256 | radius = self.radius_min + torch.rand(1) * (self.radius_max - self.radius_min)
257 | return img.filter(ImageFilter.GaussianBlur(radius=radius))
258 |
259 |
260 | def copy_imgnt_locally(
261 | root,
262 | suffix,
263 | image_folder='imagenet_full_size/061417/',
264 | tar_folder='imagenet_full_size/',
265 | tar_file='imagenet_full_size-061417.tar',
266 | job_id=None,
267 | local_rank=None
268 | ):
269 | if job_id is None:
270 | try:
271 | job_id = os.environ['SLURM_JOBID']
272 | except Exception:
273 | logger.info('No job-id, will load directly from network file')
274 | return None
275 |
276 | if local_rank is None:
277 | try:
278 | local_rank = int(os.environ['SLURM_LOCALID'])
279 | except Exception:
280 |             logger.info('No local-rank, will load directly from network file')
281 | return None
282 |
283 | source_file = os.path.join(root, tar_folder, tar_file)
284 | target = f'/scratch/slurm_tmpdir/{job_id}/'
285 | target_file = os.path.join(target, tar_file)
286 | data_path = os.path.join(target, image_folder, suffix)
287 | logger.info(f'{source_file}\n{target}\n{target_file}\n{data_path}')
288 |
289 | tmp_sgnl_file = os.path.join(target, 'copy_signal.txt')
290 |
291 | if not os.path.exists(data_path):
292 | if local_rank == 0:
293 | commands = [
294 | ['tar', '-xf', source_file, '-C', target]]
295 | for cmnd in commands:
296 | start_time = time.time()
297 | logger.info(f'Executing {cmnd}')
298 | subprocess.run(cmnd)
299 | logger.info(f'Cmnd took {(time.time()-start_time)/60.} min.')
300 |             with open(tmp_sgnl_file, 'w+') as f:
301 | print('Done copying locally.', file=f)
302 | else:
303 | while not os.path.exists(tmp_sgnl_file):
304 | time.sleep(60)
305 | logger.info(f'{local_rank}: Checking {tmp_sgnl_file}')
306 |
307 | return data_path
308 |
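309 | # --- Editor's usage sketch: illustrative addition, not part of the original
310 | # file. Building the multi-view loader (hypothetical paths); each batch item
311 | # is a list of rand_views + focal_views augmented crops:
312 | #
313 | #   transform = make_transforms(rand_views=2, focal_views=10)
314 | #   loader, sampler = init_data(transform, batch_size=64,
315 | #                               root_path='/datasets', image_folder='imagenet/',
316 | #                               subset_file='subsets/1percent.txt')
317 | #   views, targets, fnames = next(iter(loader))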
--------------------------------------------------------------------------------