├── .gitignore
├── cifar10_experiment
│   ├── TorchSSL
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── DistributedProxySampler.py
│   │   │   ├── dataset.py
│   │   │   ├── augmentation
│   │   │   │   └── randaugment.py
│   │   │   └── data_utils.py
│   │   ├── models
│   │   │   ├── nets
│   │   │   │   ├── __init__.py
│   │   │   │   ├── wrn.py
│   │   │   │   ├── wrn_var.py
│   │   │   │   └── resnet50.py
│   │   │   └── freematch
│   │   │       ├── __init__.py
│   │   │       └── freematch_utils.py
│   │   ├── config
│   │   │   ├── freematch_cifar10_40_1.yaml
│   │   │   └── freematch_cifar10_40_1_aug1000.yaml
│   │   ├── pseudo_dataset.py
│   │   ├── eval.py
│   │   ├── utils.py
│   │   ├── get_labels.py
│   │   └── custom_writer.py
│   ├── edm
│   │   ├── training
│   │   │   ├── __init__.py
│   │   │   ├── loss.py
│   │   │   └── dataset.py
│   │   ├── torch_utils
│   │   │   ├── __init__.py
│   │   │   └── distributed.py
│   │   ├── dnnlib
│   │   │   └── __init__.py
│   │   ├── generate.sh
│   │   └── fid.py
│   └── README.md
├── libs
│   ├── __init__.py
│   ├── clip.py
│   ├── timm.py
│   ├── uvit_t2i.py
│   └── uvit.py
├── dpt.png
├── idx_to_class.pkl
├── scripts
│   ├── convert_fid_stats.py
│   ├── extract_empty_feature.py
│   ├── extract_test_prompt_feature.py
│   ├── extract_imagenet_feature.py
│   ├── extract_imagenet_features.py
│   ├── extract_mscoco_feature.py
│   └── sweep_sample.py
├── LICENSE
├── src
│   ├── sgd.py
│   ├── losses.py
│   ├── utils.py
│   └── data_manager.py
├── configs
│   ├── accelerate_b4_subset1_2img_k_128_large.py
│   ├── accelerate_l7_subset2_zimg_k_128_large.py
│   └── accelerate_b4_subset1_2img_k_128_huge.py
├── extract_imagenet_features_semi.py
├── grid_sample.py
├── utils.py
├── sample_ldm_all.py
├── sample_ldm_discrete_all.py
├── sde.py
└── tools
    └── fid_score.py
/.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | *cluster*/ -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /libs/__init__.py: -------------------------------------------------------------------------------- 1 | # codes from third party 2 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/models/nets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/models/freematch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dpt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/DPT/HEAD/dpt.png -------------------------------------------------------------------------------- /idx_to_class.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/DPT/HEAD/idx_to_class.pkl -------------------------------------------------------------------------------- /cifar10_experiment/edm/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work.
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /cifar10_experiment/edm/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /cifar10_experiment/edm/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | from .util import EasyDict, make_cache_dir_path 9 | -------------------------------------------------------------------------------- /scripts/convert_fid_stats.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def convert(resolution): 5 | # download ImageNet reference batch from https://github.com/openai/guided-diffusion/tree/main/evaluations 6 | if resolution == 512: 7 | obj = np.load(f'VIRTUAL_imagenet{resolution}.npz') 8 | else: 9 | obj = np.load(f'VIRTUAL_imagenet{resolution}_labeled.npz') 10 | np.savez(f'fid_stats_imagenet{resolution}_guided_diffusion.npz', mu=obj['mu'], sigma=obj['sigma']) 11 | 12 | 13 | convert(resolution=512) 14 | -------------------------------------------------------------------------------- /scripts/extract_empty_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import libs.autoencoder 5 | import libs.clip 6 | from datasets import MSCOCODatabase 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | 11 | def main(): 12 | prompts = [ 13 | '', 14 | ] 15 | 16 | device = 'cuda' 17 | clip = libs.clip.FrozenCLIPEmbedder() 18 | clip.eval() 19 | clip.to(device) 20 | 21 | save_dir = 'assets/datasets/coco256_features' 22 | os.makedirs(save_dir, exist_ok=True)  # ensure the output directory exists before saving 23 | latent = clip.encode(prompts) 24 | print(latent.shape) 25 | c = latent[0].detach().cpu().numpy() 26 | np.save(os.path.join(save_dir, 'empty_context.npy'), c) 27 | 28 | 29 | if __name__ == '__main__': 30 | main() 31 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/config/freematch_cifar10_40_1.yaml: -------------------------------------------------------------------------------- 1 | save_dir: ./saved_models 2 | save_name: freematch_cifar10_40_1 3 | resume: False 4 | load_path: None 5 | overwrite: True 6 | use_tensorboard: True 7 | epoch: 1 8 | num_train_iter: 1048576 9 | num_eval_iter: 5000 10 | num_labels: 40 11 | batch_size: 64 12 | eval_batch_size: 1024 13 | hard_label: True 14 | T: 0.5 15 | ulb_loss_ratio: 1.0 16 | ent_loss_ratio: 0.0 17 | uratio: 7 18 | ema_m: 0.999 19 | optim: SGD 20 | lr: 0.03 21 | momentum: 0.9 22 | weight_decay: 0.0005 23 | amp: False 24 | net: WideResNet 25 | net_from_name: False 26 |
depth: 28 27 | widen_factor: 2 28 | leaky_slope: 0.1 29 | dropout: 0.0 30 | data_dir: ./data 31 | dataset: cifar10 32 | train_sampler: RandomSampler 33 | num_classes: 10 34 | num_workers: 1 35 | alg: freematch 36 | seed: 1 37 | world_size: 1 38 | rank: 0 39 | multiprocessing_distributed: True 40 | dist_url: tcp://127.0.0.1:10043 41 | dist_backend: nccl 42 | gpu: None 43 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/config/freematch_cifar10_40_1_aug1000.yaml: -------------------------------------------------------------------------------- 1 | save_dir: ./saved_models 2 | save_name: freematch_cifar10_40_1_aug1000 3 | resume: False 4 | load_path: None 5 | overwrite: True 6 | use_tensorboard: True 7 | epoch: 1 8 | num_train_iter: 1048576 9 | num_eval_iter: 5000 10 | num_labels: 40 11 | batch_size: 64 12 | eval_batch_size: 1024 13 | hard_label: True 14 | T: 0.5 15 | ulb_loss_ratio: 1.0 16 | ent_loss_ratio: 0.0 17 | uratio: 7 18 | ema_m: 0.999 19 | optim: SGD 20 | lr: 0.03 21 | momentum: 0.9 22 | weight_decay: 0.0005 23 | amp: False 24 | net: WideResNet 25 | net_from_name: False 26 | depth: 28 27 | widen_factor: 2 28 | leaky_slope: 0.1 29 | dropout: 0.0 30 | data_dir: ./data 31 | dataset: cifar10 32 | train_sampler: RandomSampler 33 | num_classes: 10 34 | num_workers: 1 35 | alg: freematch 36 | seed: 1 37 | world_size: 1 38 | rank: 0 39 | multiprocessing_distributed: True 40 | dist_url: tcp://127.0.0.1:10043 41 | dist_backend: nccl 42 | gpu: None 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 yyyouy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /cifar10_experiment/edm/generate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | num=$1 3 | save_dir=$2 4 | model=$3 5 | nproc_per_node=$4 6 | 7 | # one conditional EDM sampling run per CIFAR-10 class (0-9) 8 | for cls in $(seq 0 9); do 9 | torchrun --standalone --nproc_per_node=$nproc_per_node generate.py --outdir=$save_dir/$cls --seeds=0-$num --class=$cls --network=$model 10 | done -------------------------------------------------------------------------------- /libs/clip.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import CLIPTokenizer, CLIPTextModel 3 | 4 | 5 | class AbstractEncoder(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def encode(self, *args, **kwargs): 10 | raise NotImplementedError 11 | 12 | 13 | class FrozenCLIPEmbedder(AbstractEncoder): 14 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 15 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): 16 | super().__init__() 17 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 18 | self.transformer = CLIPTextModel.from_pretrained(version) 19 | self.device = device 20 | self.max_length = max_length 21 | self.freeze() 22 | 23 | def freeze(self): 24 | self.transformer = self.transformer.eval() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def forward(self, text): 29 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 30 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 31 | tokens = batch_encoding["input_ids"].to(self.device) 32 | outputs = self.transformer(input_ids=tokens) 33 | 34 | z = outputs.last_hidden_state 35 | return z 36 | 37 | def encode(self, text): 38 | return self(text) 39 | -------------------------------------------------------------------------------- /scripts/extract_test_prompt_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import libs.autoencoder 5 | import libs.clip 6 | from datasets import MSCOCODatabase 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | 11 | def main(): 12 | prompts = [ 13 | 'A green train is coming down the tracks.', 14 | 'A group of skiers are preparing to ski down a mountain.', 15 | 'A small kitchen with a low ceiling.', 16 | 'A group of elephants walking in muddy water.', 17 | 'A living area with a television and a table.', 18 | 'A road with traffic lights, street lights and cars.', 19 | 'A bus driving in a city area with traffic signs.', 20 | 'A bus pulls over to the curb close to an intersection.', 21 | 'A group of people are walking and one is holding an umbrella.', 22 | 'A baseball player taking a swing at an incoming ball.', 23 | 'A city street line with brick buildings and trees.', 24 | 'A close up of a plate of broccoli and sauce.', 25 | ] 26 | 27 | device = 'cuda' 28 | clip = libs.clip.FrozenCLIPEmbedder() 29 | clip.eval() 30 | clip.to(device) 31 | 32 | save_dir = 'assets/datasets/coco256_features/run_vis' 33 | os.makedirs(save_dir, exist_ok=True)  # ensure the output directory exists before saving 34 | latent = clip.encode(prompts) 35 | for i in range(len(latent)): 36 | c = latent[i].detach().cpu().numpy() 37 | np.save(os.path.join(save_dir, f'{i}.npy'), (prompts[i], c)) 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/models/freematch/freematch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from train_utils import ce_loss 4 | 5 | 6 | class Get_Scalar: 7 | def __init__(self, value): 8 | self.value = value 9 | 10 | def get_value(self, iter): 11 | return self.value 12 | 13 | def __call__(self, iter): 14 | return self.value 15 | 16 | 17 | 18 | def consistency_loss(dataset, logits_s, logits_w, time_p, p_model, name='ce', use_hard_labels=True, T=0.5):  # T: temperature for sharpening soft pseudo-labels (matches T in the config) 19 | assert name in ['ce', 'L2'] 20 | logits_w = logits_w.detach() 21 | if name == 'L2': 22 | assert logits_w.size() == logits_s.size() 23 | return F.mse_loss(logits_s, logits_w, reduction='mean') 24 | 25 | elif name == 'L2_mask': 26 | pass  # not implemented 27 | 28 | elif name == 'ce': 29 | pseudo_label = torch.softmax(logits_w, dim=-1) 30 | max_probs, max_idx = torch.max(pseudo_label, dim=-1) 31 | p_cutoff = time_p 32 | p_model_cutoff = p_model / torch.max(p_model, dim=-1)[0] 33 | threshold = p_cutoff * p_model_cutoff[max_idx] 34 | if dataset == 'svhn': 35 | threshold = torch.clamp(threshold, min=0.9, max=0.95) 36 | mask = max_probs.ge(threshold) 37 | if use_hard_labels: 38 | masked_loss = ce_loss(logits_s, max_idx, use_hard_labels, reduction='none') * mask.float() 39 | else: 40 | pseudo_label = torch.softmax(logits_w / T, dim=-1) 41 | masked_loss = ce_loss(logits_s, pseudo_label, use_hard_labels) * mask.float() 42 | return masked_loss.mean(), mask 43 | 44 | else: 45 | raise NotImplementedError('consistency_loss: ' + name) 46 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/datasets/DistributedProxySampler.py: -------------------------------------------------------------------------------- 1 | # copyright: https://github.com/pytorch/pytorch/issues/23430#issuecomment-562350407 2 | 3 | import math 4 | import torch 5 | from torch.utils.data.distributed import DistributedSampler 6 | 7 | 8 | class DistributedProxySampler(DistributedSampler): 9 | """Sampler that restricts data loading to a subset of input sampler indices. 10 | 11 | It is especially useful in conjunction with 12 | :class:`torch.nn.parallel.DistributedDataParallel`.
In such case, each 13 | process can pass a DistributedSampler instance as a DataLoader sampler, 14 | and load a subset of the original dataset that is exclusive to it. 15 | 16 | .. note:: 17 | Input sampler is assumed to be of constant size. 18 | 19 | Arguments: 20 | sampler: Input data sampler. 21 | num_replicas (optional): Number of processes participating in 22 | distributed training. 23 | rank (optional): Rank of the current process within num_replicas. 24 | """ 25 | 26 | def __init__(self, sampler, num_replicas=None, rank=None): 27 | super(DistributedProxySampler, self).__init__(sampler, num_replicas=num_replicas, rank=rank, shuffle=False) 28 | self.sampler = sampler 29 | 30 | def __iter__(self): 31 | # deterministically shuffle based on epoch 32 | torch.manual_seed(self.epoch) 33 | indices = list(self.sampler) 34 | 35 | # add extra samples to make it evenly divisible 36 | indices += indices[:(self.total_size - len(indices))] 37 | if len(indices) != self.total_size: 38 | raise RuntimeError("{} vs {}".format(len(indices), self.total_size)) 39 | 40 | # subsample 41 | indices = indices[self.rank:self.total_size:self.num_replicas] 42 | if len(indices) != self.num_samples: 43 | raise RuntimeError("{} vs {}".format(len(indices), self.num_samples)) 44 | 45 | return iter(indices) -------------------------------------------------------------------------------- /src/sgd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | from torch.optim import Optimizer 10 | 11 | 12 | class SGD(Optimizer): 13 | 14 | def __init__(self, params, lr, momentum=0, weight_decay=0, nesterov=False): 15 | if lr < 0.0: 16 | raise ValueError(f'Invalid learning rate: {lr}') 17 | if weight_decay < 0.0: 18 | raise ValueError(f'Invalid weight_decay value: {weight_decay}') 19 | 20 | defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, 21 | nesterov=nesterov) 22 | super(SGD, self).__init__(params, defaults) 23 | 24 | @torch.no_grad() 25 | def step(self): 26 | for group in self.param_groups: 27 | weight_decay = group['weight_decay'] 28 | momentum = group['momentum'] 29 | nesterov = group['nesterov'] 30 | 31 | for p in group['params']: 32 | if p.grad is None: 33 | continue 34 | 35 | d_p = p.grad 36 | if weight_decay != 0: 37 | d_p = d_p.add(p, alpha=weight_decay) 38 | d_p.mul_(-group['lr']) 39 | 40 | if momentum != 0: 41 | param_state = self.state[p] 42 | if 'momentum_buffer' not in param_state: 43 | buf = param_state['momentum_buffer'] = d_p.clone().detach() 44 | else: 45 | buf = param_state['momentum_buffer'] 46 | buf.mul_(momentum).add_(d_p) 47 | 48 | if nesterov: 49 | d_p.add_(buf, alpha=momentum) 50 | else: 51 | d_p = buf 52 | 53 | p.add_(d_p) 54 | 55 | return None 56 | -------------------------------------------------------------------------------- /scripts/extract_imagenet_feature.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch 4 | from datasets import ImageNet 5 | from torch.utils.data import DataLoader 6 | from libs.autoencoder import get_model 7 | import argparse 8 | from tqdm import tqdm 9 | torch.manual_seed(0) 10 | np.random.seed(0) 11 | 12 | 13 | def main(resolution=256): 14 | parser = argparse.ArgumentParser() 15 | 
parser.add_argument('path') 16 | args = parser.parse_args() 17 | 18 | dataset = ImageNet(path=args.path, resolution=resolution, random_flip=False) 19 | train_dataset = dataset.get_split(split='train', labeled=True) 20 | train_dataset_loader = DataLoader(train_dataset, batch_size=256, shuffle=False, drop_last=False, 21 | num_workers=8, pin_memory=True, persistent_workers=True) 22 | 23 | model = get_model('assets/stable-diffusion/autoencoder_kl.pth') 24 | model = nn.DataParallel(model) 25 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 26 | model.to(device) 27 | 28 | # features = [] 29 | # labels = [] 30 | 31 | idx = 0 32 | for batch in tqdm(train_dataset_loader): 33 | img, label = batch 34 | img = torch.cat([img, img.flip(dims=[-1])], dim=0) 35 | img = img.to(device) 36 | moments = model(img, fn='encode_moments') 37 | moments = moments.detach().cpu().numpy() 38 | 39 | label = torch.cat([label, label], dim=0) 40 | label = label.detach().cpu().numpy() 41 | 42 | for moment, lb in zip(moments, label): 43 | np.save(f'assets/datasets/imagenet{resolution}_features/{idx}.npy', (moment, lb)) 44 | idx += 1 45 | 46 | print(f'save {idx} files') 47 | 48 | # features = np.concatenate(features, axis=0) 49 | # labels = np.concatenate(labels, axis=0) 50 | # print(f'features.shape={features.shape}') 51 | # print(f'labels.shape={labels.shape}') 52 | # np.save(f'imagenet{resolution}_features.npy', features) 53 | # np.save(f'imagenet{resolution}_labels.npy', labels) 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /scripts/extract_imagenet_features.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch 4 | from datasets import ImageNet 5 | from torch.utils.data import DataLoader 6 | from libs.autoencoder import get_model 7 | import argparse 8 | from tqdm import tqdm 9 | torch.manual_seed(0) 10 | np.random.seed(0) 11 | 12 | 13 | def main(resolution=256): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('path') 16 | args = parser.parse_args() 17 | 18 | dataset = ImageNet(path=args.path, resolution=resolution, random_flip=False) 19 | train_dataset = dataset.get_split(split='train', labeled=True) 20 | train_dataset_loader = DataLoader(train_dataset, batch_size=256, shuffle=False, drop_last=False, 21 | num_workers=8, pin_memory=True, persistent_workers=True) 22 | 23 | model = get_model('assets/stable-diffusion/autoencoder_kl.pth') 24 | model = nn.DataParallel(model) 25 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 26 | model.to(device) 27 | 28 | # features = [] 29 | # labels = [] 30 | 31 | idx = 0 32 | for batch in tqdm(train_dataset_loader): 33 | img, label = batch 34 | img = torch.cat([img, img.flip(dims=[-1])], dim=0) 35 | img = img.to(device) 36 | moments = model(img, fn='encode_moments') 37 | moments = moments.detach().cpu().numpy() 38 | 39 | label = torch.cat([label, label], dim=0) 40 | label = label.detach().cpu().numpy() 41 | 42 | for moment, lb in zip(moments, label): 43 | np.save(f'assets/datasets/imagenet{resolution}_features/{idx}.npy', (moment, lb)) 44 | idx += 1 45 | 46 | print(f'save {idx} files') 47 | 48 | # features = np.concatenate(features, axis=0) 49 | # labels = np.concatenate(labels, axis=0) 50 | # print(f'features.shape={features.shape}') 51 | # print(f'labels.shape={labels.shape}') 52 | # 
np.save(f'imagenet{resolution}_features.npy', features) 53 | # np.save(f'imagenet{resolution}_labels.npy', labels) 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /scripts/extract_mscoco_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import libs.autoencoder 5 | import libs.clip 6 | from datasets import MSCOCODatabase 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | 11 | def main(resolution=256): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--split', default='train') 14 | args = parser.parse_args() 15 | print(args) 16 | 17 | 18 | if args.split == "train": 19 | datas = MSCOCODatabase(root='assets/datasets/coco/train2014', 20 | annFile='assets/datasets/coco/annotations/captions_train2014.json', 21 | size=resolution) 22 | save_dir = f'assets/datasets/coco{resolution}_features/train' 23 | elif args.split == "val": 24 | datas = MSCOCODatabase(root='assets/datasets/coco/val2014', 25 | annFile='assets/datasets/coco/annotations/captions_val2014.json', 26 | size=resolution) 27 | save_dir = f'assets/datasets/coco{resolution}_features/val' 28 | else: 29 | raise NotImplementedError("ERROR!") 30 | 31 | device = "cuda" 32 | os.makedirs(save_dir) 33 | 34 | autoencoder = libs.autoencoder.get_model('assets/stable-diffusion/autoencoder_kl.pth') 35 | autoencoder.to(device) 36 | clip = libs.clip.FrozenCLIPEmbedder() 37 | clip.eval() 38 | clip.to(device) 39 | 40 | with torch.no_grad(): 41 | for idx, data in tqdm(enumerate(datas)): 42 | x, captions = data 43 | 44 | if len(x.shape) == 3: 45 | x = x[None, ...] 46 | x = torch.tensor(x, device=device) 47 | moments = autoencoder(x, fn='encode_moments').squeeze(0) 48 | moments = moments.detach().cpu().numpy() 49 | np.save(os.path.join(save_dir, f'{idx}.npy'), moments) 50 | 51 | latent = clip.encode(captions) 52 | for i in range(len(latent)): 53 | c = latent[i].detach().cpu().numpy() 54 | np.save(os.path.join(save_dir, f'{idx}_{i}.npy'), c) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/pseudo_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torchvision 3 | import torch 4 | import os 5 | import pickle 6 | import sys 7 | 8 | 9 | from torch.utils.data import DataLoader 10 | # randomly shuffle targets 11 | import random 12 | import tarfile 13 | 14 | def unpickle(file): 15 | with open(file, 'rb') as fo: 16 | dict = pickle.load(fo, encoding='latin1') 17 | return dict 18 | 19 | if __name__ == "__main__": 20 | import argparse 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--data_path', type=str, default='./data/cifar-10-batches-py') 23 | parser.add_argument('--save_path', type=str, default='./saved_models/freematch_cifar10_40_1') 24 | 25 | args = parser.parse_args() 26 | 27 | labels = [] # save labels 28 | data_dict = {} # data_1, data_2, data_3, data_4, data_5 29 | 30 | for i in range(5): 31 | batch = unpickle(os.path.join(args.data_path, 'data_batch_{}'.format(i+1))) 32 | data_dict['data_{}'.format(i+1)] = batch['data'] 33 | labels.extend(batch['labels']) 34 | 35 | pseudo_path = os.path.join(args.save_path, 'pseudo_label.npy') 36 | pseudo_labels = np.load(pseudo_path) 37 | 38 | # print true labels distribution and pseudo
labels distribution 39 | true_distribution = torch.bincount(torch.tensor(labels)) 40 | pseudo_distribution = torch.bincount(torch.tensor(pseudo_labels)) 41 | print("true distribution:", true_distribution) 42 | print("pseudo distribution:", pseudo_distribution) 43 | # print accuracy 44 | print(np.sum(labels == pseudo_labels) / 50000) 45 | 46 | for i in range(5): 47 | data = data_dict['data_{}'.format(i+1)] 48 | targets = pseudo_labels[i*10000:(i+1)*10000] 49 | batch_path = os.path.join(args.save_path, 'cifar-10-batches-py/data_batch_{}'.format(i+1)) 50 | os.makedirs(os.path.dirname(batch_path), exist_ok=True) 51 | with open(batch_path, 'wb') as f: 52 | pickle.dump({'data': data, 'labels': targets}, f) 53 | 54 | with tarfile.open(os.path.join(args.save_path, 'cifar-10-python.tar.gz'), 'w:gz') as f: 55 | f.add(os.path.join(args.save_path, 'cifar-10-batches-py'), arcname='cifar-10-batches-py') 56 | 57 | 58 | -------------------------------------------------------------------------------- /cifar10_experiment/edm/torch_utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import os 9 | import torch 10 | from . import training_stats 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def init(): 15 | if 'MASTER_ADDR' not in os.environ: 16 | os.environ['MASTER_ADDR'] = 'localhost' 17 | if 'MASTER_PORT' not in os.environ: 18 | os.environ['MASTER_PORT'] = '29500' 19 | if 'RANK' not in os.environ: 20 | os.environ['RANK'] = '0' 21 | if 'LOCAL_RANK' not in os.environ: 22 | os.environ['LOCAL_RANK'] = '0' 23 | if 'WORLD_SIZE' not in os.environ: 24 | os.environ['WORLD_SIZE'] = '1' 25 | 26 | backend = 'gloo' if os.name == 'nt' else 'nccl' 27 | torch.distributed.init_process_group(backend=backend, init_method='env://') 28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 29 | 30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None 31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def get_rank(): 36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | def get_world_size(): 41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | def should_stop(): 46 | return False 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | def update_progress(cur, total): 51 | _ = cur, total 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def print0(*args, **kwargs): 56 | if get_rank() == 0: 57 | print(*args, **kwargs) 58 | 59 | #---------------------------------------------------------------------------- 60 | -------------------------------------------------------------------------------- /scripts/sweep_sample.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | import random 3 | import os 4 | import argparse 5 | import subprocess 6 | from pathlib import Path 7 | import datetime 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--workdir', required=True) 13 | parser.add_argument('--port', type=int) 14 | parser.add_argument('--ckpts', type=str) 15 | 16 | args, unknown = parser.parse_known_args() 17 | args.ckpt_root = os.path.join(args.workdir, 'ckpts') 18 | if args.port is None: 19 | args.port = random.randint(10000, 30000) 20 | 21 | for x in unknown: 22 | assert '=' in x 23 | 24 | return args, ' '.join(unknown) 25 | 26 | 27 | def valid_str(unknown): 28 | items = unknown.split(' ') 29 | res = [] 30 | for item in items: 31 | assert item.startswith('--') 32 | res.append(item[2:].replace('/', '_')) 33 | return '_'.join(res) 34 | 35 | 36 | def main(): 37 | args, unknown = parse_args() 38 | ckpts = [f'{int(ckpt)}.ckpt' for ckpt in args.ckpts.split(',')] 39 | n_devices = len(os.environ['CUDA_VISIBLE_DEVICES'].split(',')) 40 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 41 | output_path = os.path.join(args.workdir, f'{now}_sweep_sample.log') 42 | 43 | for ckpt in ckpts: 44 | print(f'try sample {os.path.join(args.ckpt_root, ckpt)}') 45 | 46 | while not os.path.exists(os.path.join(args.ckpt_root, ckpt)): 47 | time.sleep(5 + 5 * random.random()) 48 | time.sleep(5 + 5 * random.random()) 49 | 50 | if os.path.exists(os.path.join(args.ckpt_root, '.state', ckpt, valid_str(unknown)[:100])): 51 | print(f'{ckpt} already evaluated, skip') 52 | continue 53 | 54 | os.makedirs(os.path.join(args.ckpt_root, '.state', ckpt), exist_ok=True) # mark as running 55 | Path(os.path.join(args.ckpt_root, '.state', ckpt, valid_str(unknown)[:100])).touch() 56 | 57 | dct = dict() 58 | nnet_path = os.path.join(args.ckpt_root, ckpt, 'nnet_ema.pth') 59 | dct['nnet_path'] = nnet_path 60 | dct['output_path'] = output_path 61 | 62 | dct_str = ' '.join([f'--{key}={val}' for key, val in dct.items()]) 63 | 64 | accelerate_args = f'--multi_gpu --main_process_port {args.port} --num_processes {n_devices} --mixed_precision fp16' 65 | cmd = f'accelerate launch {accelerate_args} sample.py {dct_str} {unknown}' 66 | cmd = list(filter(lambda x: x != '', cmd.split(' '))) 67 | print(cmd) 68 | subprocess.Popen(cmd).wait() 69 | 70 | 71 | if __name__ == "__main__": 72 | main() 73 | -------------------------------------------------------------------------------- /configs/accelerate_b4_subset1_2img_k_128_large.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.lambd = 0.0025 13 | config.penalty = 'l2' 14 | config.mask = 0.0 15 | config.preload = True 16 | config.fname = 'vitb4_300ep.pth.tar' 17 | config.model_name = 'deit_base_p4' 18 | config.pretrained = 'pretrained/' 19 | 20 | config.normalize = True 21 | config.root_path = '/cache/datasets/ILSVRC/Data/' 22 | config.image_folder = 'CLS-LOC/' 23 | config.image_path = '/cache/datasets/ILSVRC/Data/CLS-LOC' 24 | 25 | config.subset_path = 'imagenet_subsets1/2imgs_class.txt' 26 | config.blocks = 1 27 | 28 | config.seed = 1234 29 | config.pred = 'noise_pred' 30 | config.ema_rate = 0.9999 31 | config.z_shape = (4, 32, 32) 
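# z_shape is the latent (C, H, W): the Stable Diffusion autoencoder referenced below downsamples 256x256 RGB images to 4-channel 32x32 latents, matching the U-ViT img_size/in_chans settings in config.nnet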
32 | config.resolution = 256 33 | 34 | config.autoencoder = d( 35 | pretrained_path='assets/stable-diffusion/autoencoder_kl.pth' 36 | ) 37 | 38 | config.dpm_path = 'assets/DPM' 39 | 40 | # augmentation 41 | config.augmentation_K = 128 # add 128 generated images per class 42 | config.using_true_label = True 43 | config.output_path = '' 44 | 45 | config.train = d( 46 | n_steps=300000, 47 | batch_size=1024, 48 | mode='cond', 49 | log_interval=10, 50 | eval_interval=5000, 51 | save_interval=50000, 52 | ) 53 | 54 | config.optimizer = d( 55 | name='adamw', 56 | lr=0.0002, 57 | weight_decay=0.03, 58 | betas=(0.99, 0.99), 59 | ) 60 | 61 | config.lr_scheduler = d( 62 | name='customized', 63 | warmup_steps=5000 64 | ) 65 | 66 | config.nnet = d( 67 | name='uvit', 68 | img_size=32, 69 | patch_size=2, 70 | in_chans=4, 71 | embed_dim=1024, 72 | depth=20, 73 | num_heads=16, 74 | mlp_ratio=4, 75 | qkv_bias=False, 76 | mlp_time_embed=False, 77 | num_classes=1001, 78 | use_checkpoint=True 79 | ) 80 | 81 | config.dataset = d( 82 | name='imagenet256_features', 83 | path='', 84 | cfg=True, 85 | p_uncond=0.15 86 | ) 87 | 88 | config.sample = d( 89 | sample_steps=50, 90 | n_samples=50000, 91 | mini_batch_size=20, # the decoder is large 92 | algorithm='dpm_solver', 93 | cfg=True, 94 | scale=0.4, 95 | path='' 96 | ) 97 | 98 | return config 99 | 100 | config = get_config() 101 | -------------------------------------------------------------------------------- /configs/accelerate_l7_subset2_zimg_k_128_large.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.lambd = 0.0025 13 | config.penalty = 'l2' 14 | config.mask = 0.0 15 | config.preload = True 16 | config.fname = 'vitl7_200ep.pth.tar' 17 | config.model_name = 'deit_large_p7' 18 | config.pretrained = 'pretrained/' 19 | 20 | config.normalize = True 21 | config.root_path = '/cache/datasets/ILSVRC/Data/' 22 | config.image_folder = 'CLS-LOC/' 23 | config.image_path = '/cache/datasets/ILSVRC/Data/CLS-LOC' 24 | 25 | config.subset_path = 'imagenet_subsets2/1imgs_class.txt' 26 | config.blocks = 1 27 | 28 | config.seed = 1234 29 | config.pred = 'noise_pred' 30 | config.ema_rate = 0.9999 31 | config.z_shape = (4, 32, 32) 32 | config.resolution = 256 33 | 34 | config.autoencoder = d( 35 | pretrained_path='assets/stable-diffusion/autoencoder_kl.pth' 36 | ) 37 | 38 | config.dpm_path = 'assets/DPM' 39 | 40 | # augmentation 41 | config.augmentation_K = 128 # add 128 generated images per class 42 | config.using_true_label = True 43 | config.output_path = '' 44 | 45 | config.train = d( 46 | n_steps=300000, 47 | batch_size=1024, 48 | mode='cond', 49 | log_interval=10, 50 | eval_interval=5000, 51 | save_interval=50000, 52 | ) 53 | 54 | config.optimizer = d( 55 | name='adamw', 56 | lr=0.0002, 57 | weight_decay=0.03, 58 | betas=(0.99, 0.99), 59 | ) 60 | 61 | config.lr_scheduler = d( 62 | name='customized', 63 | warmup_steps=5000 64 | ) 65 | 66 | config.nnet = d( 67 | name='uvit', 68 | img_size=32, 69 | patch_size=2, 70 | in_chans=4, 71 | embed_dim=1024, 72 | depth=20, 73 | num_heads=16, 74 | mlp_ratio=4, 75 | qkv_bias=False, 76 | mlp_time_embed=False, 77 | num_classes=1001, 78 | use_checkpoint=True 79 | ) 80 | 81 | config.dataset = d( 82 | name='imagenet256_features', 83 | path='', 84 | cfg=True, 85 | 
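# p_uncond below is the probability of dropping the class label (replacing it with the null class) during training; this is what enables classifier-free guidance (cfg) at sampling time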
p_uncond=0.15 86 | ) 87 | 88 | config.sample = d( 89 | sample_steps=50, 90 | n_samples=50000, 91 | mini_batch_size=20, # the decoder is large 92 | algorithm='dpm_solver', 93 | cfg=True, 94 | scale=0.4, 95 | path='' 96 | ) 97 | 98 | return config 99 | 100 | config = get_config() 101 | -------------------------------------------------------------------------------- /configs/accelerate_b4_subset1_2img_k_128_huge.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.lambd = 0.0025 13 | config.penalty = 'l2' 14 | config.mask = 0.0 15 | config.preload = True 16 | config.fname = 'vitb4_300ep.pth.tar' 17 | config.model_name = 'deit_base_p4' 18 | config.pretrained = 'pretrained/' 19 | 20 | config.normalize = True 21 | config.root_path = '/cache/datasets/ILSVRC/Data/' 22 | config.image_folder = 'CLS-LOC/' 23 | config.image_path = '/cache/datasets/ILSVRC/Data/CLS-LOC' 24 | 25 | config.subset_path = 'imagenet_subsets1/2imgs_class.txt' 26 | config.blocks = 1 27 | 28 | config.seed = 1234 29 | config.pred = 'noise_pred' 30 | config.ema_rate = 0.9999 31 | config.z_shape = (4, 32, 32) 32 | config.resolution = 256 33 | 34 | config.autoencoder = d( 35 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth' 36 | ) 37 | 38 | config.dpm_path = 'assets/DPM' 39 | 40 | # augmentation 41 | config.augmentation_K = 128 # add 128 generated images per class 42 | config.using_true_label = True 43 | config.output_path = '' 44 | 45 | config.train = d( 46 | n_steps=500000, 47 | batch_size=1024, 48 | mode='cond', 49 | log_interval=10, 50 | eval_interval=5000, 51 | save_interval=50000, 52 | ) 53 | 54 | config.optimizer = d( 55 | name='adamw', 56 | lr=0.0002, 57 | weight_decay=0.03, 58 | betas=(0.99, 0.99), 59 | ) 60 | 61 | config.lr_scheduler = d( 62 | name='customized', 63 | warmup_steps=5000 64 | ) 65 | 66 | config.nnet = d( 67 | name='uvit', 68 | img_size=32, 69 | patch_size=2, 70 | in_chans=4, 71 | embed_dim=1152, 72 | depth=28, 73 | num_heads=16, 74 | mlp_ratio=4, 75 | qkv_bias=False, 76 | mlp_time_embed=False, 77 | num_classes=1001, 78 | use_checkpoint=True, 79 | conv=False 80 | ) 81 | 82 | config.dataset = d( 83 | name='imagenet256_features', 84 | path='', 85 | cfg=True, 86 | p_uncond=0.1 87 | ) 88 | 89 | config.sample = d( 90 | sample_steps=50, 91 | n_samples=50000, 92 | mini_batch_size=20, # the decoder is large 93 | algorithm='dpm_solver', 94 | cfg=True, 95 | scale=0.4, 96 | path='' 97 | ) 98 | 99 | return config 100 | 101 | config = get_config() 102 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | 4 | import torch 5 | 6 | from utils import net_builder 7 | from datasets.ssl_dataset import SSL_Dataset 8 | from datasets.data_utils import get_data_loader 9 | import accelerate 10 | 11 | if __name__ == "__main__": 12 | import argparse 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--load_path', type=str, default='./saved_models/fixmatch/model_best.pth') 15 | parser.add_argument('--use_train_model', action='store_true') 16 | 17 | ''' 18 | Backbone Net Configurations 19 | ''' 20 | 
parser.add_argument('--net', type=str, default='WideResNet') 21 | parser.add_argument('--net_from_name', type=bool, default=False) 22 | parser.add_argument('--depth', type=int, default=28) 23 | parser.add_argument('--widen_factor', type=int, default=2) 24 | parser.add_argument('--leaky_slope', type=float, default=0.1) 25 | parser.add_argument('--dropout', type=float, default=0.0) 26 | 27 | ''' 28 | Data Configurations 29 | ''' 30 | parser.add_argument('--batch_size', type=int, default=256) 31 | parser.add_argument('--data_dir', type=str, default='./data') 32 | parser.add_argument('--dataset', type=str, default='cifar10') 33 | parser.add_argument('--num_classes', type=int, default=10) 34 | args = parser.parse_args() 35 | 36 | accelerator = accelerate.Accelerator() 37 | 38 | checkpoint_path = os.path.join(args.load_path) 39 | checkpoint = torch.load(checkpoint_path) 40 | load_model = checkpoint['train_model'] if args.use_train_model else checkpoint['ema_model'] 41 | 42 | _net_builder = net_builder(args.net, 43 | args.net_from_name, 44 | {'depth': args.depth, 45 | 'widen_factor': args.widen_factor, 46 | 'leaky_slope': args.leaky_slope, 47 | 'dropRate': args.dropout, 48 | 'use_embed': False}) 49 | 50 | net = _net_builder(num_classes=args.num_classes) 51 | 52 | _eval_dset = SSL_Dataset(args, name=args.dataset, train=False, 53 | num_classes=args.num_classes, data_dir=args.data_dir) 54 | eval_dset = _eval_dset.get_dset() 55 | 56 | eval_loader = get_data_loader(eval_dset, 57 | args.batch_size, 58 | num_workers=1, shuffle=False) 59 | 60 | #net, eval_loader = accelerator.prepare(net, eval_loader) 61 | 62 | #net.load_state_dict(load_model) 63 | # if torch.cuda.is_available(): 64 | # net.cuda() 65 | weights_dict = {} 66 | for k, v in load_model.items(): 67 | new_k = k.replace('module.', '') if 'module' in k else k 68 | weights_dict[new_k] = v 69 | 70 | net.load_state_dict(weights_dict) 71 | 72 | net.eval() 73 | 74 | 75 | acc = 0.0 76 | with torch.no_grad(): 77 | for (_, image, target) in eval_loader: 78 | print(_) 79 | image = image.type(torch.FloatTensor)#.cuda() 80 | logit = net(image) 81 | print(logit.shape) 82 | 83 | target, logit = accelerator.gather(target), accelerator.gather(logit) 84 | acc += logit.cpu().max(1)[1].eq(target.cpu()).sum().numpy() 85 | 86 | accelerator.print(f"Test Accuracy: {acc/len(eval_dset)}") 87 | 88 | -------------------------------------------------------------------------------- /extract_imagenet_features_semi.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch 4 | import os 5 | from datasets import ImageNet 6 | from torch.utils.data import DataLoader 7 | from libs.autoencoder import get_model 8 | import argparse 9 | from tqdm import tqdm 10 | torch.manual_seed(0) 11 | np.random.seed(0) 12 | 13 | def mprint(*args): 14 | print('\n-----------------------------') 15 | print(*args) 16 | print('-----------------------------\n') 17 | 18 | from absl import flags 19 | from absl import app 20 | from ml_collections import config_flags 21 | import sys 22 | from pathlib import Path 23 | 24 | FLAGS = flags.FLAGS 25 | config_flags.DEFINE_config_file( 26 | "config", None, "Training configuration.", lock_config=False) 27 | flags.mark_flags_as_required(["config"]) 28 | 29 | def get_config_name(): 30 | argv = sys.argv 31 | for i in range(1, len(argv)): 32 | if argv[i].startswith('--config='): 33 | return Path(argv[i].split('=')[-1]).stem 34 | 35 | 36 | def get_hparams(): 37 | argv = sys.argv 
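# build a compact run tag from the --config.* command-line overrides, skipping the dataset path; values of keys ending in 'path' are reduced to their filename stem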
38 | lst = [] 39 | for i in range(1, len(argv)): 40 | assert '=' in argv[i] 41 | if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'): 42 | hparam, val = argv[i].split('=') 43 | hparam = hparam.split('.')[-1] 44 | if hparam.endswith('path'): 45 | val = Path(val).stem 46 | lst.append(f'{hparam}={val}') 47 | hparams = '-'.join(lst) 48 | if hparams == '': 49 | hparams = 'default' 50 | return hparams 51 | 52 | def main(argv): 53 | config = FLAGS.config 54 | config.config_name = get_config_name() 55 | config.hparams = get_hparams() 56 | cluster_name = config.model_name + '-' + '-'.join(config.subset_path.split('/')).split('.txt')[0] 57 | cluster_path = f'pretrained/cluster/{cluster_name}/imagenet_features_preds.npy' 58 | fnames_path = f'pretrained/cluster/{cluster_name}/imagenet_features_fnames.pth' 59 | autoencoder_path = config.autoencoder.pretrained_path 60 | path = config.image_path 61 | 62 | dataset = ImageNet(path=path, resolution=config.resolution, random_flip=False, cluster_path=cluster_path, fnames_path=fnames_path) 63 | 64 | train_dataset = dataset.get_split(split='train', labeled=True) 65 | 66 | train_batch_size = 128 67 | if config.resolution == 512: 68 | train_batch_size = 64 69 | 70 | train_dataset_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=False, drop_last=False, 71 | num_workers=8, pin_memory=True, persistent_workers=True) 72 | 73 | model = get_model(autoencoder_path) 74 | model = nn.DataParallel(model) 75 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 76 | model.to(device) 77 | 78 | save_path = f'pretrained/datasets/{cluster_name}' 79 | save_features_path = os.path.join(save_path, f'imagenet{config.resolution}_features') 80 | os.system(f'mkdir -p {save_features_path}') 81 | 82 | idx = 0 83 | for batch in tqdm(train_dataset_loader): 84 | img, label = batch 85 | img = torch.cat([img, img.flip(dims=[-1])], dim=0) 86 | img = img.to(device) 87 | moments = model(img, fn='encode_moments') 88 | moments = moments.detach().cpu().numpy() 89 | 90 | label = torch.cat([label, label], dim=0) 91 | label = label.detach().cpu().numpy() 92 | 93 | for moment, lb in zip(moments, label): 94 | np.save(os.path.join(save_features_path, f'{idx}.npy'), (moment, lb)) 95 | idx += 1 96 | 97 | mprint(f'save {idx} files') 98 | 99 | if __name__ == "__main__": 100 | app.run(main) 101 | -------------------------------------------------------------------------------- /cifar10_experiment/edm/training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Loss functions used in the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import torch 12 | from torch_utils import persistence 13 | 14 | #---------------------------------------------------------------------------- 15 | # Loss function corresponding to the variance preserving (VP) formulation 16 | # from the paper "Score-Based Generative Modeling through Stochastic 17 | # Differential Equations". 
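# The schedule implemented below satisfies sigma(t)^2 = exp(0.5 * beta_d * t^2 + beta_min * t) - 1,
# with t drawn uniformly from [epsilon_t, 1]; the squared error is weighted by 1 / sigma^2 so that
# every noise level contributes to the loss on a comparable scale.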
18 | 19 | @persistence.persistent_class 20 | class VPLoss: 21 | def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5): 22 | self.beta_d = beta_d 23 | self.beta_min = beta_min 24 | self.epsilon_t = epsilon_t 25 | 26 | def __call__(self, net, images, labels, augment_pipe=None): 27 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 28 | sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1)) 29 | weight = 1 / sigma ** 2 30 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 31 | n = torch.randn_like(y) * sigma 32 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 33 | loss = weight * ((D_yn - y) ** 2) 34 | return loss 35 | 36 | def sigma(self, t): 37 | t = torch.as_tensor(t) 38 | return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt() 39 | 40 | #---------------------------------------------------------------------------- 41 | # Loss function corresponding to the variance exploding (VE) formulation 42 | # from the paper "Score-Based Generative Modeling through Stochastic 43 | # Differential Equations". 44 | 45 | @persistence.persistent_class 46 | class VELoss: 47 | def __init__(self, sigma_min=0.02, sigma_max=100): 48 | self.sigma_min = sigma_min 49 | self.sigma_max = sigma_max 50 | 51 | def __call__(self, net, images, labels, augment_pipe=None): 52 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 53 | sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform) 54 | weight = 1 / sigma ** 2 55 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 56 | n = torch.randn_like(y) * sigma 57 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 58 | loss = weight * ((D_yn - y) ** 2) 59 | return loss 60 | 61 | #---------------------------------------------------------------------------- 62 | # Improved loss function proposed in the paper "Elucidating the Design Space 63 | # of Diffusion-Based Generative Models" (EDM). 
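# Below, ln(sigma) ~ N(P_mean, P_std^2), and the weight
# (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2 keeps the effective loss
# magnitude roughly constant across noise levels. A minimal usage sketch,
# assuming `net` follows the EDM denoiser interface D(y + n, sigma, labels):
#   loss_fn = EDMLoss()
#   per_pixel = loss_fn(net, images, labels)  # weighted squared error, same shape as images
#   per_pixel.mean().backward()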
64 | 65 | @persistence.persistent_class 66 | class EDMLoss: 67 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5): 68 | self.P_mean = P_mean 69 | self.P_std = P_std 70 | self.sigma_data = sigma_data 71 | 72 | def __call__(self, net, images, labels=None, augment_pipe=None): 73 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) 74 | sigma = (rnd_normal * self.P_std + self.P_mean).exp() 75 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 76 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 77 | n = torch.randn_like(y) * sigma 78 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 79 | loss = weight * ((D_yn - y) ** 2) 80 | return loss 81 | 82 | #---------------------------------------------------------------------------- 83 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from torch.utils.tensorboard import SummaryWriter 4 | import logging 5 | import yaml 6 | 7 | 8 | def over_write_args_from_file(args, yml): 9 | if yml == '': 10 | return 11 | with open(yml, 'r', encoding='utf-8') as f: 12 | dic = yaml.load(f.read(), Loader=yaml.Loader) 13 | for k in dic: 14 | setattr(args, k, dic[k]) 15 | 16 | 17 | def setattr_cls_from_kwargs(cls, kwargs): 18 | # if default values are in the cls, 19 | # overwrite them with kwargs 20 | for key in kwargs.keys(): 21 | if hasattr(cls, key): 22 | print(f"{key} in {cls} is overwritten by kwargs: {getattr(cls, key)} -> {kwargs[key]}") 23 | setattr(cls, key, kwargs[key]) 24 | 25 | 26 | def test_setattr_cls_from_kwargs(): 27 | class _test_cls: 28 | def __init__(self): 29 | self.a = 1 30 | self.b = 'hello' 31 | 32 | test_cls = _test_cls() 33 | config = {'a': 3, 'b': 'change_hello', 'c': 5} 34 | setattr_cls_from_kwargs(test_cls, config) 35 | for key in config.keys(): 36 | print(f"{key}:\t {getattr(test_cls, key)}") 37 | 38 | 39 | def net_builder(net_name, from_name: bool, net_conf=None, is_remix=False): 40 | """ 41 | return **class** of backbone network (not instance). 42 | Args 43 | net_name: 'WideResNet' or network names in torchvision.models 44 | from_name: If True, net_builder takes models from torchvision.models. Then, net_conf is ignored. 45 | net_conf: When from_name is False, net_conf is the configuration of backbone network (now, only WRN is supported). 46 | """ 47 | if from_name: 48 | import torchvision.models as models 49 | model_name_list = sorted(name for name in models.__dict__ 50 | if name.islower() and not name.startswith("__") 51 | and callable(models.__dict__[name])) 52 | 53 | if net_name not in model_name_list: 54 | raise ValueError(f"[!] Network's name is wrong, check net config, \ 55 | expected: {model_name_list} \ 56 | received: {net_name}") 57 | else: 58 | return models.__dict__[net_name] 59 | 60 | else: 61 | if net_name == 'WideResNet': 62 | import models.nets.wrn as net 63 | builder = getattr(net, 'build_WideResNet')() 64 | elif net_name == 'WideResNetVar': 65 | import models.nets.wrn_var as net 66 | builder = getattr(net, 'build_WideResNetVar')() 67 | elif net_name == 'ResNet50': 68 | import models.nets.resnet50 as net 69 | builder = getattr(net, 'build_ResNet50')(is_remix) 70 | else: 71 | raise NotImplementedError(f"unsupported net_name: {net_name}") 72 | 73 | if net_name != 'ResNet50': 74 | setattr_cls_from_kwargs(builder, net_conf) 75 | return builder.build 76 | 77 | 78 | def test_net_builder(net_name, from_name, net_conf=None): 79 | builder = net_builder(net_name, from_name, net_conf) 80 | print(f"net_name: {net_name}, from_name: {from_name}, net_conf: {net_conf}") 81 | print(builder) 82 | 83 | 84 | def get_logger(name, save_path=None, level='INFO'): 85 | logger = logging.getLogger(name) 86 | logger.setLevel(getattr(logging, level)) 87 | 88 | log_format = logging.Formatter('[%(asctime)s %(levelname)s] %(message)s') 89 | streamHandler = logging.StreamHandler() 90 | streamHandler.setFormatter(log_format) 91 | logger.addHandler(streamHandler) 92 | 93 | if save_path is not None: 94 | os.makedirs(save_path, exist_ok=True) 95 | fileHandler = logging.FileHandler(os.path.join(save_path, 'log.txt')) 96 | fileHandler.setFormatter(log_format) 97 | logger.addHandler(fileHandler) 98 | 99 | return logger 100 | 101 | 102 | def count_parameters(model): 103 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 104 | -------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | from logging import getLogger 9 | import torch 10 | import math 11 | from src.utils import AllReduce 12 | 13 | 14 | logger = getLogger() 15 | 16 | 17 | def init_msn_loss( 18 | num_views=1, 19 | tau=0.1, 20 | me_max=True, 21 | return_preds=False 22 | ): 23 | """ 24 | Make unsupervised MSN loss 25 | 26 | :num_views: number of anchor views 27 | :param tau: cosine similarity temperature 28 | :param me_max: whether to perform me-max regularization 29 | :param return_preds: whether to return anchor predictions 30 | """ 31 | softmax = torch.nn.Softmax(dim=1) 32 | 33 | def sharpen(p, T): 34 | sharp_p = p**(1./T) 35 | sharp_p /= torch.sum(sharp_p, dim=1, keepdim=True) 36 | return sharp_p 37 | 38 | def snn(query, supports, support_labels, temp=tau): 39 | """ Soft Nearest Neighbours similarity classifier """ 40 | query = torch.nn.functional.normalize(query) 41 | supports = torch.nn.functional.normalize(supports) 42 | return softmax(query @ supports.T / temp) @ support_labels 43 | 44 | def loss( 45 | anchor_views, 46 | target_views, 47 | prototypes, 48 | proto_labels, 49 | T=0.25, 50 | use_entropy=False, 51 | use_sinkhorn=False, 52 | sharpen=sharpen, 53 | snn=snn 54 | ): 55 | # Step 1: compute anchor predictions 56 | probs = snn(anchor_views, prototypes, proto_labels) 57 | 58 | # Step 2: compute targets for anchor predictions 59 | with torch.no_grad(): 60 | targets = sharpen(snn(target_views, prototypes, proto_labels), T=T) 61 | if use_sinkhorn: 62 | targets = distributed_sinkhorn(targets) 63 | targets = torch.cat([targets for _ in range(num_views)], dim=0) 64 | 65 | # Step 3: compute cross-entropy loss H(targets, queries) 66 | loss = torch.mean(torch.sum(torch.log(probs**(-targets)), dim=1)) 67 | 68 | # Step 4: compute me-max regularizer 69 | rloss = 0. 70 | if me_max: 71 | avg_probs = AllReduce.apply(torch.mean(probs, dim=0)) 72 | rloss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs))) 73 | 74 | sloss = 0. 
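# optional entropy regularization on the anchor predictions; as with rloss above, log(p ** (-p)) is elementwise -p * log(p), so sloss below is the mean prediction entropy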
75 | if use_entropy: 76 | sloss = torch.mean(torch.sum(torch.log(probs**(-probs)), dim=1)) 77 | 78 | # -- logging 79 | with torch.no_grad(): 80 | num_ps = float(len(set(targets.argmax(dim=1).tolist()))) 81 | max_t = targets.max(dim=1).values.mean() 82 | min_t = targets.min(dim=1).values.mean() 83 | log_dct = {'np': num_ps, 'max_t': max_t, 'min_t': min_t} 84 | 85 | if return_preds: 86 | return loss, rloss, sloss, log_dct, targets 87 | 88 | return loss, rloss, sloss, log_dct 89 | 90 | return loss 91 | 92 | 93 | @torch.no_grad() 94 | def distributed_sinkhorn(Q, num_itr=3, use_dist=True): 95 | _got_dist = use_dist and torch.distributed.is_available() \ 96 | and torch.distributed.is_initialized() \ 97 | and (torch.distributed.get_world_size() > 1) 98 | 99 | if _got_dist: 100 | world_size = torch.distributed.get_world_size() 101 | else: 102 | world_size = 1 103 | 104 | Q = Q.T 105 | B = Q.shape[1] * world_size # number of samples to assign 106 | K = Q.shape[0] # how many prototypes 107 | 108 | # make the matrix sum to 1 109 | sum_Q = torch.sum(Q) 110 | if _got_dist: 111 | torch.distributed.all_reduce(sum_Q) 112 | Q /= sum_Q 113 | 114 | for it in range(num_itr): 115 | # normalize each row: total weight per prototype must be 1/K 116 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True) 117 | if _got_dist: 118 | torch.distributed.all_reduce(sum_of_rows) 119 | Q /= sum_of_rows 120 | Q /= K 121 | 122 | # normalize each column: total weight per sample must be 1/B 123 | Q /= torch.sum(Q, dim=0, keepdim=True) 124 | Q /= B 125 | 126 | Q *= B # the columns must sum to 1 so that Q is an assignment 127 | return Q.T 128 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/get_labels.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from utils import net_builder 8 | from datasets.ssl_dataset import SSL_Dataset 9 | from datasets.data_utils import get_data_loader 10 | import accelerate 11 | 12 | if __name__ == "__main__": 13 | import argparse 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--load_path', type=str, default='./saved_models/fixmatch/model_best.pth') 16 | parser.add_argument('--use_train_model', action='store_true') 17 | 18 | ''' 19 | Backbone Net Configurations 20 | ''' 21 | parser.add_argument('--net', type=str, default='WideResNet') 22 | parser.add_argument('--net_from_name', type=bool, default=False) 23 | parser.add_argument('--depth', type=int, default=28) 24 | parser.add_argument('--widen_factor', type=int, default=2) 25 | parser.add_argument('--leaky_slope', type=float, default=0.1) 26 | parser.add_argument('--dropout', type=float, default=0.0) 27 | 28 | ''' 29 | Data Configurations 30 | ''' 31 | parser.add_argument('--batch_size', type=int, default=256) 32 | parser.add_argument('--data_dir', type=str, default='./data') 33 | parser.add_argument('--dataset', type=str, default='cifar10') 34 | parser.add_argument('--num_classes', type=int, default=10) 35 | 36 | parser.add_argument('--save_path', type=str, default='./saved_models/fixmatch') 37 | 38 | args = parser.parse_args() 39 | 40 | accelerator = accelerate.Accelerator() 41 | 42 | checkpoint_path = os.path.join(args.load_path) 43 | checkpoint = torch.load(checkpoint_path) 44 | load_model = checkpoint['train_model'] if args.use_train_model else checkpoint['ema_model'] 45 | 46 | _net_builder = net_builder(args.net, 47 |
args.net_from_name, 48 | {'depth': args.depth, 49 | 'widen_factor': args.widen_factor, 50 | 'leaky_slope': args.leaky_slope, 51 | 'dropRate': args.dropout, 52 | 'use_embed': False}) 53 | 54 | net = _net_builder(num_classes=args.num_classes) 55 | 56 | _train_dset = SSL_Dataset(args, alg='fullysupervised', name=args.dataset, train=True, 57 | num_classes=args.num_classes, data_dir=args.data_dir) 58 | train_dset = _train_dset.get_label_dset() 59 | 60 | 61 | train_loader = get_data_loader(train_dset, 62 | args.batch_size, 63 | num_workers=1, shuffle=False) 64 | 65 | # net, train_loader = accelerator.prepare(net, train_loader) 66 | 67 | # net.load_state_dict(load_model) 68 | weights_dict = {} 69 | for k, v in load_model.items(): 70 | new_k = k.replace('module.', '') if 'module' in k else k 71 | weights_dict[new_k] = v 72 | net.load_state_dict(weights_dict) 73 | # if torch.cuda.is_available(): 74 | # net.cuda() 75 | net.eval() 76 | 77 | acc = 0.0 78 | target_list = [] 79 | pseudo_label_list = [] 80 | with torch.no_grad(): 81 | for (_, image, target) in train_loader: 82 | print(_) 83 | image = image.type(torch.FloatTensor)#.cuda() 84 | logit = net(image) 85 | print(logit.shape) 86 | 87 | target, logit = accelerator.gather(target), accelerator.gather(logit) 88 | for x, y in zip(target.cpu(), logit.cpu().max(1)[1]): 89 | target_list.append(x) 90 | pseudo_label_list.append(y) 91 | 92 | acc += logit.cpu().max(1)[1].eq(target.cpu()).sum().numpy() 93 | 94 | accelerator.print(len(train_dset)) 95 | 96 | target_list = torch.stack(target_list) 97 | pseudo_label_list = torch.stack(pseudo_label_list) 98 | 99 | np.save(os.path.join(args.save_path, 'pseudo_label.npy'), pseudo_label_list) 100 | 101 | #accelerator.print(pseudo_label_list[:100],'\n', target_list[:100]) 102 | 103 | print(pseudo_label_list[:100], '\n', target_list[:100]) 104 | print(pseudo_label_list[:100] == target_list[:100]) 105 | 106 | result1 = torch.bincount(target_list) 107 | result2 = torch.bincount(pseudo_label_list) 108 | 109 | print(result1) 110 | print(result2) 111 | 112 | accelerator.print(f"Test Accuracy: {acc/len(train_dset)}") 113 | 114 | 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torch.utils.data import Dataset 3 | from .data_utils import get_onehot 4 | from .augmentation.randaugment import RandAugment 5 | 6 | import torchvision 7 | from PIL import Image 8 | import numpy as np 9 | import copy 10 | 11 | 12 | class BasicDataset(Dataset): 13 | """ 14 | BasicDataset returns a pair of image and labels (targets). 15 | If targets are not given, BasicDataset returns None as the label. 16 | This class supports strong augmentation for FixMatch, 17 | and returns both weakly and strongly augmented images. 18 | """ 19 | 20 | def __init__(self, 21 | alg, 22 | data, 23 | targets=None, 24 | num_classes=None, 25 | transform=None, 26 | is_ulb=False, 27 | strong_transform=None, 28 | onehot=False, 29 | *args, **kwargs): 30 | """ 31 | Args 32 | data: x_data 33 | targets: y_data (if not exist, None) 34 | num_classes: number of label classes 35 | transform: basic transformation of data 36 | is_ulb: If True, the dataset is treated as unlabeled and returns both weakly and strongly augmented images.
37 | strong_transform: list of transformation functions for strong augmentation 38 | onehot: If True, the label is converted into a one-hot vector. 39 | """ 40 | super(BasicDataset, self).__init__() 41 | self.alg = alg 42 | self.data = data 43 | self.targets = targets 44 | 45 | self.num_classes = num_classes 46 | self.is_ulb = is_ulb 47 | self.onehot = onehot 48 | 49 | self.transform = transform 50 | if self.is_ulb: 51 | if strong_transform is None: 52 | self.strong_transform = copy.deepcopy(transform) 53 | self.strong_transform.transforms.insert(0, RandAugment(3, 5)) 54 | else: 55 | self.strong_transform = strong_transform 56 | 57 | def __getitem__(self, idx): 58 | """ 59 | If strong augmentation is not used, 60 | return idx, weak_augment_image, target 61 | else: 62 | return idx, weak_augment_image, strong_augment_image, target 63 | """ 64 | 65 | # set idx-th target 66 | if self.targets is None: 67 | target = None 68 | else: 69 | target_ = self.targets[idx] 70 | target = target_ if not self.onehot else get_onehot(self.num_classes, target_) 71 | 72 | # set augmented images 73 | 74 | img = self.data[idx] 75 | if self.transform is None: 76 | return transforms.ToTensor()(img), target 77 | else: 78 | if isinstance(img, np.ndarray): 79 | img = Image.fromarray(img) 80 | img_w = self.transform(img) 81 | if not self.is_ulb: 82 | return idx, img_w, target 83 | else: 84 | if self.alg == 'fixmatch': 85 | return idx, img_w, self.strong_transform(img) 86 | elif self.alg == 'flexmatch': 87 | return idx, img_w, self.strong_transform(img) 88 | elif self.alg == 'softmatch' or self.alg == 'freematch' or self.alg == 'freematch_entropy': 89 | return idx, img_w, self.strong_transform(img) 90 | elif self.alg == 'pimodel': 91 | return idx, img_w, self.transform(img) 92 | elif self.alg == 'pseudolabel': 93 | return idx, img_w 94 | elif self.alg == 'vat': 95 | return idx, img_w 96 | elif self.alg == 'meanteacher': 97 | return idx, img_w, self.transform(img) 98 | elif self.alg == 'uda': 99 | return idx, img_w, self.strong_transform(img) 100 | elif self.alg == 'mixmatch': 101 | return idx, img_w, self.transform(img) 102 | elif self.alg == 'remixmatch': 103 | rotate_v_list = [0, 90, 180, 270] 104 | rotate_v1 = np.random.choice(rotate_v_list, 1).item() 105 | img_s1 = self.strong_transform(img) 106 | img_s1_rot = torchvision.transforms.functional.rotate(img_s1, rotate_v1) 107 | img_s2 = self.strong_transform(img) 108 | return idx, img_w, img_s1, img_s2, img_s1_rot, rotate_v_list.index(rotate_v1) 109 | elif self.alg == 'fullysupervised': 110 | return idx 111 | 112 | def __len__(self): 113 | return len(self.data) 114 | -------------------------------------------------------------------------------- /libs/timm.py: -------------------------------------------------------------------------------- 1 | # code from timm 0.3.2 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import warnings 6 | 7 | 8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 11 | def norm_cdf(x): 12 | # Computes standard normal cumulative distribution function 13 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 14 | 15 | if (mean < a - 2 * std) or (mean > b + 2 * std): 16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" 17 | "The distribution of values may be incorrect.", 18 | stacklevel=2) 19 | 20 | with torch.no_grad(): 21 | # Values are generated by using a truncated uniform distribution and 22 | # then using the inverse CDF for the normal distribution. 23 | # Get upper and lower cdf values 24 | l = norm_cdf((a - mean) / std) 25 | u = norm_cdf((b - mean) / std) 26 | 27 | # Uniformly fill tensor with values from [l, u], then translate to 28 | # [2l-1, 2u-1]. 29 | tensor.uniform_(2 * l - 1, 2 * u - 1) 30 | 31 | # Use inverse cdf transform for normal distribution to get truncated 32 | # standard normal 33 | tensor.erfinv_() 34 | 35 | # Transform to proper mean, std 36 | tensor.mul_(std * math.sqrt(2.)) 37 | tensor.add_(mean) 38 | 39 | # Clamp to ensure it's in the proper range 40 | tensor.clamp_(min=a, max=b) 41 | return tensor 42 | 43 | 44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 45 | # type: (Tensor, float, float, float, float) -> Tensor 46 | r"""Fills the input Tensor with values drawn from a truncated 47 | normal distribution. The values are effectively drawn from the 48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 49 | with values outside :math:`[a, b]` redrawn until they are within 50 | the bounds. The method used for generating the random values works 51 | best when :math:`a \leq \text{mean} \leq b`. 52 | Args: 53 | tensor: an n-dimensional `torch.Tensor` 54 | mean: the mean of the normal distribution 55 | std: the standard deviation of the normal distribution 56 | a: the minimum cutoff value 57 | b: the maximum cutoff value 58 | Examples: 59 | >>> w = torch.empty(3, 5) 60 | >>> nn.init.trunc_normal_(w) 61 | """ 62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 63 | 64 | 65 | def drop_path(x, drop_prob: float = 0., training: bool = False): 66 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 67 | 68 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 69 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 70 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 71 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 72 | 'survival rate' as the argument. 73 | 74 | """ 75 | if drop_prob == 0. or not training: 76 | return x 77 | keep_prob = 1 - drop_prob 78 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 79 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 80 | random_tensor.floor_() # binarize 81 | output = x.div(keep_prob) * random_tensor 82 | return output 83 | 84 | 85 | class DropPath(nn.Module): 86 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
87 | """ 88 | def __init__(self, drop_prob=None): 89 | super(DropPath, self).__init__() 90 | self.drop_prob = drop_prob 91 | 92 | def forward(self, x): 93 | return drop_path(x, self.drop_prob, self.training) 94 | 95 | 96 | class Mlp(nn.Module): 97 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 98 | super().__init__() 99 | out_features = out_features or in_features 100 | hidden_features = hidden_features or in_features 101 | self.fc1 = nn.Linear(in_features, hidden_features) 102 | self.act = act_layer() 103 | self.fc2 = nn.Linear(hidden_features, out_features) 104 | self.drop = nn.Dropout(drop) 105 | 106 | def forward(self, x): 107 | x = self.fc1(x) 108 | x = self.act(x) 109 | x = self.drop(x) 110 | x = self.fc2(x) 111 | x = self.drop(x) 112 | return x 113 | -------------------------------------------------------------------------------- /grid_sample.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | import torch 3 | from torch import multiprocessing as mp 4 | import accelerate 5 | import utils 6 | import sde 7 | from datasets import get_dataset 8 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver 9 | from absl import logging 10 | import einops 11 | from torchvision.utils import save_image, make_grid 12 | 13 | 14 | def evaluate(config): 15 | if config.get('benchmark', False): 16 | torch.backends.cudnn.benchmark = True 17 | torch.backends.cudnn.deterministic = False 18 | 19 | mp.set_start_method('spawn') 20 | accelerator = accelerate.Accelerator() 21 | device = accelerator.device 22 | accelerate.utils.set_seed(config.seed, device_specific=True) 23 | 24 | config.mixed_precision = accelerator.mixed_precision 25 | config = ml_collections.FrozenConfigDict(config) 26 | utils.set_logger(log_level='info') 27 | 28 | dataset = get_dataset(**config.dataset) 29 | 30 | nnet = utils.get_nnet(**config.nnet) 31 | nnet = accelerator.prepare(nnet) 32 | logging.info(f'load nnet from {config.nnet_path}') 33 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu')) 34 | nnet.eval() 35 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance 36 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}') 37 | def cfg_nnet(x, timesteps, y): 38 | _cond = nnet(x, timesteps, y=y) 39 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device)) 40 | return _cond + config.sample.scale * (_cond - _uncond) 41 | score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE()) 42 | else: 43 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE()) 44 | 45 | logging.info(config.sample) 46 | logging.info(f'mode={config.train.mode}, mixed_precision={config.mixed_precision}') 47 | 48 | def sample_fn(_x_init, _kwargs): 49 | if config.sample.algorithm == 'euler_maruyama_sde': 50 | rsde = sde.ReverseSDE(score_model) 51 | _samples = sde.euler_maruyama(rsde, _x_init, config.sample.sample_steps, 52 | verbose=accelerator.is_main_process, **_kwargs) 53 | elif config.sample.algorithm == 'euler_maruyama_ode': 54 | rsde = sde.ODE(score_model) 55 | _samples = sde.euler_maruyama(rsde, _x_init, config.sample.sample_steps, 56 | verbose=accelerator.is_main_process, **_kwargs) 57 | elif config.sample.algorithm == 'dpm_solver': 58 | noise_schedule = NoiseScheduleVP(schedule='linear') 59 | model_fn = model_wrapper( 60 | score_model.noise_pred, 61 
| noise_schedule, 62 | time_input_type='0', 63 | model_kwargs=_kwargs 64 | ) 65 | dpm_solver = DPM_Solver(model_fn, noise_schedule) 66 | _samples = dpm_solver.sample( 67 | _x_init, 68 | steps=config.sample.sample_steps, 69 | eps=1e-4, 70 | adaptive_step_size=False, 71 | fast_version=True, 72 | ) 73 | else: 74 | raise NotImplementedError 75 | 76 | return _samples 77 | 78 | if config.train.mode == 'uncond': 79 | x_init = torch.randn(100, *dataset.data_shape, device=device) 80 | kwargs = dict() 81 | idx = 0 82 | samples = [] 83 | for _batch_size in utils.amortize(len(x_init), config.sample.mini_batch_size): 84 | samples.append(sample_fn(x_init[idx: idx + _batch_size], kwargs)) 85 | idx += _batch_size 86 | 87 | elif config.train.mode == 'cond': 88 | x_init = torch.randn(config.nnet.num_classes * 10, *dataset.data_shape, device=device) 89 | y = einops.repeat(torch.arange(config.nnet.num_classes, device=device), 'nrow -> (nrow ncol)', ncol=10) 90 | idx = 0 91 | samples = [] 92 | for _batch_size in utils.amortize(len(x_init), config.sample.mini_batch_size): 93 | samples.append(sample_fn(x_init[idx: idx + _batch_size], dict(y=y[idx: idx + _batch_size]))) 94 | idx += _batch_size 95 | 96 | else: 97 | raise NotImplementedError 98 | 99 | samples = torch.cat(samples, dim=0) 100 | samples = dataset.unpreprocess(samples) 101 | save_image(make_grid(samples, 10), config.output_path) 102 | 103 | 104 | 105 | from absl import flags 106 | from absl import app 107 | from ml_collections import config_flags 108 | 109 | 110 | FLAGS = flags.FLAGS 111 | config_flags.DEFINE_config_file("config", None, "Training configuration.", lock_config=False) 112 | flags.mark_flags_as_required(["config"]) 113 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.") 114 | flags.DEFINE_string("output_path", None, "The path to output samples.") 115 | 116 | 117 | def main(argv): 118 | config = FLAGS.config 119 | config.nnet_path = FLAGS.nnet_path 120 | config.output_path = FLAGS.output_path 121 | evaluate(config) 122 | 123 | 124 | if __name__ == "__main__": 125 | app.run(main) 126 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/datasets/augmentation/randaugment.py: -------------------------------------------------------------------------------- 1 | # copyright: https://github.com/ildoonet/pytorch-randaugment 2 | # code in this file is adapted from rpmcruz/autoaugment 3 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 4 | # This code is a modified version of ildoonet's, adapted for the random augmentation used in FixMatch.
5 | 6 | import random 7 | 8 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | from PIL import Image 13 | 14 | 15 | def AutoContrast(img, _): 16 | return PIL.ImageOps.autocontrast(img) 17 | 18 | 19 | def Brightness(img, v): 20 | assert v >= 0.0 21 | return PIL.ImageEnhance.Brightness(img).enhance(v) 22 | 23 | 24 | def Color(img, v): 25 | assert v >= 0.0 26 | return PIL.ImageEnhance.Color(img).enhance(v) 27 | 28 | 29 | def Contrast(img, v): 30 | assert v >= 0.0 31 | return PIL.ImageEnhance.Contrast(img).enhance(v) 32 | 33 | 34 | def Equalize(img, _): 35 | return PIL.ImageOps.equalize(img) 36 | 37 | 38 | def Invert(img, _): 39 | return PIL.ImageOps.invert(img) 40 | 41 | 42 | def Identity(img, v): 43 | return img 44 | 45 | 46 | def Posterize(img, v): # [4, 8] 47 | v = int(v) 48 | v = max(1, v) 49 | return PIL.ImageOps.posterize(img, v) 50 | 51 | 52 | def Rotate(img, v): # [-30, 30] 53 | #assert -30 <= v <= 30 54 | #if random.random() > 0.5: 55 | # v = -v 56 | return img.rotate(v) 57 | 58 | 59 | 60 | def Sharpness(img, v): # [0.1,1.9] 61 | assert v >= 0.0 62 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 63 | 64 | 65 | def ShearX(img, v): # [-0.3, 0.3] 66 | #assert -0.3 <= v <= 0.3 67 | #if random.random() > 0.5: 68 | # v = -v 69 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 70 | 71 | 72 | def ShearY(img, v): # [-0.3, 0.3] 73 | #assert -0.3 <= v <= 0.3 74 | #if random.random() > 0.5: 75 | # v = -v 76 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 77 | 78 | 79 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 80 | #assert -0.3 <= v <= 0.3 81 | #if random.random() > 0.5: 82 | # v = -v 83 | v = v * img.size[0] 84 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 85 | 86 | 87 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 88 | #assert v >= 0.0 89 | #if random.random() > 0.5: 90 | # v = -v 91 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 92 | 93 | 94 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 95 | #assert -0.3 <= v <= 0.3 96 | #if random.random() > 0.5: 97 | # v = -v 98 | v = v * img.size[1] 99 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 100 | 101 | 102 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 103 | #assert 0 <= v 104 | #if random.random() > 0.5: 105 | # v = -v 106 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 107 | 108 | 109 | def Solarize(img, v): # [0, 256] 110 | assert 0 <= v <= 256 111 | return PIL.ImageOps.solarize(img, v) 112 | 113 | 114 | def Cutout(img, v): #[0, 60] => percentage: [0, 0.2] => change to [0, 0.5] 115 | assert 0.0 <= v <= 0.5 116 | if v <= 0.: 117 | return img 118 | 119 | v = v * img.size[0] 120 | return CutoutAbs(img, v) 121 | 122 | 123 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 124 | # assert 0 <= v <= 20 125 | if v < 0: 126 | return img 127 | w, h = img.size 128 | x0 = np.random.uniform(w) 129 | y0 = np.random.uniform(h) 130 | 131 | x0 = int(max(0, x0 - v / 2.)) 132 | y0 = int(max(0, y0 - v / 2.)) 133 | x1 = min(w, x0 + v) 134 | y1 = min(h, y0 + v) 135 | 136 | xy = (x0, y0, x1, y1) 137 | color = (125, 123, 114) 138 | # color = (0, 0, 0) 139 | img = img.copy() 140 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 141 | return img 142 | 143 | 144 | def augment_list(): 145 | l = [ 146 | (AutoContrast, 0, 1), 147 | 
(Brightness, 0.05, 0.95), 148 | (Color, 0.05, 0.95), 149 | (Contrast, 0.05, 0.95), 150 | (Equalize, 0, 1), 151 | (Identity, 0, 1), 152 | (Posterize, 4, 8), 153 | (Rotate, -30, 30), 154 | (Sharpness, 0.05, 0.95), 155 | (ShearX, -0.3, 0.3), 156 | (ShearY, -0.3, 0.3), 157 | (Solarize, 0, 256), 158 | (TranslateX, -0.3, 0.3), 159 | (TranslateY, -0.3, 0.3) 160 | ] 161 | return l 162 | 163 | 164 | class RandAugment: 165 | def __init__(self, n, m): 166 | self.n = n 167 | self.m = m # [0, 30] in fixmatch, deprecated. 168 | self.augment_list = augment_list() 169 | 170 | 171 | def __call__(self, img): 172 | ops = random.choices(self.augment_list, k=self.n) 173 | for op, min_val, max_val in ops: 174 | val = min_val + float(max_val - min_val)*random.random() 175 | img = op(img, val) 176 | cutout_val = random.random() * 0.5 177 | img = Cutout(img, cutout_val) # for fixmatch 178 | return img 179 | 180 | 181 | if __name__ == '__main__': 182 | # randaug = RandAugment(3,5) 183 | # print(randaug) 184 | # for item in randaug.augment_list: 185 | # print(item) 186 | import os 187 | 188 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 189 | img = PIL.Image.open('./u.jpg') 190 | randaug = RandAugment(3,6) 191 | img = randaug(img) 192 | import matplotlib 193 | from matplotlib import pyplot as plt 194 | plt.imshow(img) 195 | plt.show() 196 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torchvision import datasets 4 | from torch.utils.data import sampler, DataLoader 5 | from torch.utils.data.sampler import BatchSampler 6 | import torch.distributed as dist 7 | import numpy as np 8 | import json 9 | import os 10 | 11 | from datasets.DistributedProxySampler import DistributedProxySampler 12 | 13 | 14 | def split_ssl_data(args, data, target, num_labels, num_classes, index=None, include_lb_to_ulb=True): 15 | """ 16 | data & target are split into labeled and unlabeled data. 17 | 18 | Args 19 | index: If np.array of index is given, select the data[index], target[index] as labeled samples.
20 | include_lb_to_ulb: If True, labeled data is also included in the unlabeled data 21 | """ 22 | data, target = np.array(data), np.array(target) 23 | lb_data, lbs, lb_idx = sample_labeled_data(args, data, target, num_labels, num_classes, index) 24 | ulb_idx = np.array(sorted(list(set(range(len(data))) - set(lb_idx)))) # unlabeled_data index of data 25 | if include_lb_to_ulb: 26 | return lb_data, lbs, data, target 27 | else: 28 | return lb_data, lbs, data[ulb_idx], target[ulb_idx] 29 | 30 | 31 | def sample_labeled_data(args, data, target, 32 | num_labels, num_classes, 33 | index=None, name=None): 34 | ''' 35 | samples labeled data 36 | (sampling with a balanced ratio over classes) 37 | ''' 38 | assert num_labels % num_classes == 0 39 | if index is not None: 40 | index = np.array(index, dtype=np.int32) 41 | return data[index], target[index], index 42 | 43 | dump_path = os.path.join(args.save_dir, args.save_name, 'sampled_label_idx.npy') 44 | 45 | if os.path.exists(dump_path): 46 | lb_idx = np.load(dump_path) 47 | lb_data = data[lb_idx] 48 | lbs = target[lb_idx] 49 | return lb_data, lbs, lb_idx 50 | 51 | samples_per_class = int(num_labels / num_classes) 52 | 53 | lb_data = [] 54 | lbs = [] 55 | lb_idx = [] 56 | for c in range(num_classes): 57 | idx = np.where(target == c)[0] 58 | idx = np.random.choice(idx, samples_per_class, False) 59 | lb_idx.extend(idx) 60 | 61 | lb_data.extend(data[idx]) 62 | lbs.extend(target[idx]) 63 | 64 | np.save(dump_path, np.array(lb_idx)) 65 | 66 | return np.array(lb_data), np.array(lbs), np.array(lb_idx) 67 | 68 | 69 | def get_sampler_by_name(name): 70 | ''' 71 | get sampler in torch.utils.data.sampler by name 72 | ''' 73 | sampler_name_list = sorted(name for name in torch.utils.data.sampler.__dict__ 74 | if not name.startswith('_') and callable(sampler.__dict__[name])) 75 | try: 76 | if name == 'DistributedSampler': 77 | return torch.utils.data.distributed.DistributedSampler 78 | else: 79 | return getattr(torch.utils.data.sampler, name) 80 | except Exception as e: 81 | print(repr(e)) 82 | print('[!] select sampler in:\t', sampler_name_list) 83 | 84 | 85 | def get_data_loader(dset, 86 | batch_size=None, 87 | shuffle=False, 88 | num_workers=4, 89 | pin_memory=False, 90 | data_sampler=None, 91 | replacement=True, 92 | num_epochs=None, 93 | num_iters=None, 94 | generator=None, 95 | drop_last=True, 96 | distributed=False): 97 | """ 98 | get_data_loader returns torch.utils.data.DataLoader for a Dataset. 99 | All arguments are comparable with those of the PyTorch DataLoader. 100 | However, if distributed, DistributedProxySampler, which is a wrapper of data_sampler, is used.
101 | 102 | Args 103 | num_epochs: total batch -> (# of batches in dset) * num_epochs 104 | num_iters: total batch -> num_iters 105 | """ 106 | 107 | assert batch_size is not None 108 | 109 | if data_sampler is None: 110 | return DataLoader(dset, batch_size=batch_size, shuffle=shuffle, 111 | num_workers=num_workers, pin_memory=pin_memory) 112 | 113 | else: 114 | if isinstance(data_sampler, str): 115 | data_sampler = get_sampler_by_name(data_sampler) 116 | 117 | if distributed: 118 | assert dist.is_available() 119 | num_replicas = dist.get_world_size() 120 | else: 121 | num_replicas = 1 122 | 123 | if (num_epochs is not None) and (num_iters is None): 124 | num_samples = len(dset) * num_epochs 125 | elif (num_epochs is None) and (num_iters is not None): 126 | num_samples = batch_size * num_iters * num_replicas 127 | else: 128 | num_samples = len(dset) 129 | 130 | if data_sampler.__name__ == 'RandomSampler': 131 | data_sampler = data_sampler(dset, replacement, num_samples, generator) 132 | else: 133 | raise RuntimeError(f"{data_sampler.__name__} is not implemented.") 134 | 135 | if distributed: 136 | ''' 137 | Different from DistributedSampler, 138 | the DistributedProxySampler does not shuffle the data (it is just a wrapper for dist). 139 | ''' 140 | data_sampler = DistributedProxySampler(data_sampler) 141 | 142 | batch_sampler = BatchSampler(data_sampler, batch_size, drop_last) 143 | return DataLoader(dset, batch_sampler=batch_sampler, 144 | num_workers=num_workers, pin_memory=pin_memory) 145 | 146 | 147 | def get_onehot(num_classes, idx): 148 | onehot = np.zeros([num_classes], dtype=np.float32) 149 | onehot[idx] += 1.0 150 | return onehot 151 | -------------------------------------------------------------------------------- /cifar10_experiment/README.md: -------------------------------------------------------------------------------- 1 |
Official PyTorch implementation of the NeurIPS 2023 paper 2 | 3 | **Diffusion Models and Semi-Supervised Learners Benefit Mutually with Few Labels**
4 | Zebin You, Yong Zhong, Fan Bao, Jiacheng Sun, Chongxuan Li, Jun Zhu 5 |
https://arxiv.org/abs/2302.10586
6 | 7 | Abstract: *In an effort to further advance semi-supervised generative and classification tasks, we propose a simple yet effective training strategy called dual pseudo training (DPT), built upon strong semi-supervised learners and diffusion models. DPT operates in three stages: training a classifier on partially labeled data to predict pseudo-labels; training a conditional generative model using these pseudo-labels to generate pseudo images; and retraining the classifier with a mix of real and pseudo images. Empirically, DPT consistently achieves SOTA performance of semi-supervised generation and classification across various settings. In particular, with one or two labels per class, DPT achieves a Fréchet Inception Distance (FID) score of 3.08 or 2.52 on ImageNet 256x256. Besides, DPT outperforms competitive semi-supervised baselines substantially on ImageNet classification tasks, achieving top-1 accuracies of 59.0 (+2.8), 69.5 (+3.0), and 74.4 (+2.0) with one, two, or five labels per class, respectively. Notably, our results demonstrate that diffusion can generate realistic images with only a few labels (e.g., <0.1%) and generative augmentation remains viable for semi-supervised classification.* 8 | 9 | ## Requirements 10 | ```.bash 11 | conda create -n dpt_cifar python==3.9 12 | conda activate dpt_cifar 13 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 14 | pip install tensorboard 15 | pip install accelerate 16 | pip install --upgrade scikit-learn 17 | pip install pandas 18 | pip install click 19 | pip install Pillow==9.3.0 20 | ``` 21 | 22 | ## Get Started 23 | ### Stage 1: Training the Classifier and Generating Pseudo-Labels 24 | To begin, navigate to the TorchSSL directory and run the following commands to train a classifier, generate pseudo-labels, and prepare the dataset required for EDM. 25 | ```.bash 26 | # Stage 1: Training the Classifier and Generating Pseudo-Labels 27 | # After training, a high-quality classifier will be saved in the directory saved_models/freematch_cifar10_40_1/. We can then use this classifier to obtain pseudo-labels 28 | python freematch.py --c config/freematch_cifar10_40_1.yaml 29 | 30 | # Generating Pseudo-Labels 31 | python get_labels.py --load_path ./saved_models/freematch_cifar10_40_1/model_best.pth --save_path ./saved_models/freematch_cifar10_40_1/ >> saved_models/freematch_cifar10_40_1/get_labels.log 32 | 33 | # Using the pseudo-labels, we create a new dataset, which will be used to train our generative model 34 | # Creating a Pseudo-Dataset 35 | python pseudo_dataset.py --data_path ./data/cifar-10-batches-py --save_path ./saved_models/freematch_cifar10_40_1 36 | ``` 37 | The purpose of the above steps is to generate a new dataset, similar to CIFAR-10 but with pseudo-labels instead of true labels. 38 | 39 | ```.bash 40 | # Preparing the Pseudo-Dataset for Training EDM 41 | python dataset_tool.py --source=./saved_models/freematch_cifar10_40_1/cifar-10-python.tar.gz --dest=./saved_models/freematch_cifar10_40_1/freematch-40-seed1-cifar10-32x32.zip 42 | ``` 43 | This step converts the dataset into the format required by EDM.
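Optionally, before training the generative model, you can re-check the quality of the saved pseudo-labels against the true CIFAR-10 training labels (get_labels.py already reports this accuracy in its log). A minimal sketch, assuming the default paths above and that `pseudo_label.npy` was written in dataset order (the loader above uses shuffle=False):

```.python
# Sanity check: pseudo-label accuracy on the CIFAR-10 training set.
# Paths assume the commands above were run from the TorchSSL directory.
import numpy as np
from torchvision.datasets import CIFAR10

pseudo = np.load('./saved_models/freematch_cifar10_40_1/pseudo_label.npy')
true = np.array(CIFAR10(root='./data', train=True, download=True).targets)
print(f'pseudo-label accuracy: {(pseudo == true).mean():.4f}')
```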
44 | 45 | ### Stage 2: Training the Generative Model and Generating Pseudo-Images 46 | Following the dataset preparation, navigate to the EDM directory and execute the following commands for the second stage: 47 | ```.bash 48 | # Stage 2: Training the Generative Model 49 | torchrun --standalone --nproc_per_node=4 train.py --outdir=training-runs \ 50 | --data=../TorchSSL/saved_models/freematch_cifar10_40_1/freematch-40-seed1-cifar10-32x32.zip --cond=1 --arch=ddpmpp --batch-gpu=32 51 | ``` 52 | In the above command: 53 | - The **`--data`** flag specifies the dataset created earlier 54 | - The **`--cond=1`** flag enables conditional generation 55 | - The **`--arch=ddpmpp`** flag selects the ddpmpp model 56 | - The **`--batch-gpu=32`** flag sets the batch size per GPU to 32 57 | This command trains the generative model in the second stage of the process. 58 | 59 | Assuming EDM training has completed and the latest generative model is saved as **'./training-runs/00000-freematch-40-seed1-cifar10-32x32-cond-ddpmpp-edm-gpus4-batch512-fp32/network-snapshot-latest.pkl'**, you can use the following command to generate pseudo-images: 60 | ```.bash 61 | bash generate.sh 1000 ../TorchSSL/data-aug/nums_1000 ./training-runs/00000-freematch-40-seed1-cifar10-32x32-cond-ddpmpp-edm-gpus4-batch512-fp32/network-snapshot-latest.pkl 4 62 | ``` 63 | In the above command: 64 | - The first argument, **`1000`**, generates 1001 samples for each class (indices 0 through 1000). 65 | - The second argument, **`../TorchSSL/data-aug/nums_1000`**, is the path where the generated samples will be saved. 66 | - The third argument, **`./training-runs/00000-freematch-40-seed1-cifar10-32x32-cond-ddpmpp-edm-gpus4-batch512-fp32/network-snapshot-latest.pkl`**, is the path to the generative model you want to use for generation. 67 | - The fourth argument, **`4`**, specifies the number of GPUs to use for the generation process. 68 | 69 | ### Stage 3: Training the Classifier with Pseudo-Images 70 | Finally, we can train the classifier with the pseudo-images generated in the previous step (a small sanity-check sketch is included at the end of this README). To do so, navigate to the TorchSSL directory and execute the following command: 71 | ```.bash 72 | # Stage 3: Retraining the Classifier with Real and Pseudo Images 73 | python freematch.py --c config/freematch_cifar10_40_1_aug1000.yaml --aug_path=./data-aug/nums_1000 74 | ``` 75 | 76 | ## Questions 77 | If you have any questions, please feel free to contact us via email: zebin@ruc.edu.cn 78 | 79 | ## Citation 80 | 81 | ``` 82 | @inproceedings{you2023diffusion, 83 | author = {You, Zebin and Zhong, Yong and Bao, Fan and Sun, Jiacheng and Li, Chongxuan and Zhu, Jun}, 84 | title = {Diffusion models and semi-supervised learners benefit mutually with few labels}, 85 | booktitle = {Proc. NeurIPS}, 86 | year = {2023} 87 | } 88 | ``` 89 | 90 | ## Acknowledgments 91 | 92 | We would like to express our gratitude to the remarkable projects EDM (https://github.com/NVlabs/edm) and TorchSSL (https://github.com/TorchSSL/TorchSSL). Our work is built upon their contributions.
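As a quick sanity check before stage 3, you can verify that generation produced the expected number of pseudo-images (10 classes x 1001 samples with the arguments above). A minimal sketch, assuming generate.sh writes PNG files somewhere under the output directory; the directory layout is an assumption, so adjust the extension and grouping to match your setup:

```.python
# Count generated pseudo-images, grouped by their parent directory.
# Only the output root is taken from the README; the layout is hypothetical.
from collections import Counter
from pathlib import Path

root = Path('../TorchSSL/data-aug/nums_1000')
counts = Counter(p.parent.name for p in root.rglob('*.png'))
print(f'{sum(counts.values())} images total')
for name, n in sorted(counts.items()):
    print(name, n)
```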
93 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | from tqdm import tqdm 6 | from torchvision.utils import save_image 7 | from absl import logging 8 | 9 | 10 | def set_logger(log_level='info', fname=None): 11 | import logging as _logging 12 | handler = logging.get_absl_handler() 13 | formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s') 14 | handler.setFormatter(formatter) 15 | logging.set_verbosity(log_level) 16 | if fname is not None: 17 | handler = _logging.FileHandler(fname) 18 | handler.setFormatter(formatter) 19 | logging.get_absl_logger().addHandler(handler) 20 | 21 | 22 | def dct2str(dct): 23 | return str({k: f'{v:.6g}' for k, v in dct.items()}) 24 | 25 | 26 | def get_nnet(name, **kwargs): 27 | if name == 'uvit': 28 | from libs.uvit import UViT 29 | return UViT(**kwargs) 30 | elif name == 'uvit_t2i': 31 | from libs.uvit_t2i import UViT 32 | return UViT(**kwargs) 33 | else: 34 | raise NotImplementedError(name) 35 | 36 | 37 | def set_seed(seed: int): 38 | if seed is not None: 39 | torch.manual_seed(seed) 40 | np.random.seed(seed) 41 | 42 | 43 | def get_optimizer(params, name, **kwargs): 44 | if name == 'adam': 45 | from torch.optim import Adam 46 | return Adam(params, **kwargs) 47 | elif name == 'adamw': 48 | from torch.optim import AdamW 49 | return AdamW(params, **kwargs) 50 | else: 51 | raise NotImplementedError(name) 52 | 53 | 54 | def customized_lr_scheduler(optimizer, warmup_steps=-1): 55 | from torch.optim.lr_scheduler import LambdaLR 56 | def fn(step): 57 | if warmup_steps > 0: 58 | return min(step / warmup_steps, 1) 59 | else: 60 | return 1 61 | return LambdaLR(optimizer, fn) 62 | 63 | 64 | def get_lr_scheduler(optimizer, name, **kwargs): 65 | if name == 'customized': 66 | return customized_lr_scheduler(optimizer, **kwargs) 67 | elif name == 'cosine': 68 | from torch.optim.lr_scheduler import CosineAnnealingLR 69 | return CosineAnnealingLR(optimizer, **kwargs) 70 | else: 71 | raise NotImplementedError(name) 72 | 73 | 74 | def ema(model_dest: nn.Module, model_src: nn.Module, rate): 75 | param_dict_src = dict(model_src.named_parameters()) 76 | for p_name, p_dest in model_dest.named_parameters(): 77 | p_src = param_dict_src[p_name] 78 | assert p_src is not p_dest 79 | p_dest.data.mul_(rate).add_((1 - rate) * p_src.data) 80 | 81 | 82 | class TrainState(object): 83 | def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None): 84 | self.optimizer = optimizer 85 | self.lr_scheduler = lr_scheduler 86 | self.step = step 87 | self.nnet = nnet 88 | self.nnet_ema = nnet_ema 89 | 90 | def ema_update(self, rate=0.9999): 91 | if self.nnet_ema is not None: 92 | ema(self.nnet_ema, self.nnet, rate) 93 | 94 | def save(self, path): 95 | os.makedirs(path, exist_ok=True) 96 | torch.save(self.step, os.path.join(path, 'step.pth')) 97 | for key, val in self.__dict__.items(): 98 | if key != 'step' and val is not None: 99 | torch.save(val.state_dict(), os.path.join(path, f'{key}.pth')) 100 | 101 | def load(self, path): 102 | logging.info(f'load from {path}') 103 | self.step = torch.load(os.path.join(path, 'step.pth')) 104 | for key, val in self.__dict__.items(): 105 | if key != 'step' and val is not None: 106 | val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu')) 107 | 108 | def resume(self, ckpt_root, step=None): 109 | if 
not os.path.exists(ckpt_root): 110 | return 111 | if step is None: 112 | ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root))) 113 | if not ckpts: 114 | return 115 | steps = map(lambda x: int(x.split(".")[0]), ckpts) 116 | step = max(steps) 117 | ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt') 118 | logging.info(f'resume from {ckpt_path}') 119 | self.load(ckpt_path) 120 | 121 | def to(self, device): 122 | for key, val in self.__dict__.items(): 123 | if isinstance(val, nn.Module): 124 | val.to(device) 125 | 126 | 127 | def cnt_params(model): 128 | return sum(param.numel() for param in model.parameters()) 129 | 130 | 131 | def initialize_train_state(config, device): 132 | params = [] 133 | 134 | nnet = get_nnet(**config.nnet) 135 | params += nnet.parameters() 136 | nnet_ema = get_nnet(**config.nnet) 137 | nnet_ema.eval() 138 | logging.info(f'nnet has {cnt_params(nnet)} parameters') 139 | 140 | optimizer = get_optimizer(params, **config.optimizer) 141 | lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler) 142 | 143 | train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0, 144 | nnet=nnet, nnet_ema=nnet_ema) 145 | train_state.ema_update(0) 146 | train_state.to(device) 147 | return train_state 148 | 149 | 150 | def amortize(n_samples, batch_size): 151 | k = n_samples // batch_size 152 | r = n_samples % batch_size 153 | return k * [batch_size] if r == 0 else k * [batch_size] + [r] 154 | 155 | 156 | def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None): 157 | os.makedirs(path, exist_ok=True) 158 | idx = 0 159 | batch_size = mini_batch_size * accelerator.num_processes 160 | 161 | for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'): 162 | samples = unpreprocess_fn(sample_fn(mini_batch_size)) 163 | samples = accelerator.gather(samples.contiguous())[:_batch_size] 164 | if accelerator.is_main_process: 165 | for sample in samples: 166 | save_image(sample, os.path.join(path, f"{idx}.png")) 167 | idx += 1 168 | 169 | 170 | def grad_norm(model): 171 | total_norm = 0. 172 | for p in model.parameters(): 173 | param_norm = p.grad.data.norm(2) 174 | total_norm += param_norm.item() ** 2 175 | total_norm = total_norm ** (1. / 2) 176 | return total_norm 177 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/custom_writer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from pathlib import Path 3 | from typing import Sequence, Union, Tuple 4 | import numpy as np 5 | import json 6 | import torch 7 | import os 8 | 9 | class CustomWriter(object): 10 | ''' 11 | Custom Writer for training record. 12 | Parameters: 13 | ----------- 14 | log_dir : pathlib.Path or str, path to save logs. 15 | enabled : bool, whether to enable tensorboard writer. 
16 | ''' 17 | def __init__(self, log_dir, enabled=True): 18 | self.writer = None 19 | self.selected_module = '' 20 | 21 | if enabled: 22 | self.log_dir = str(log_dir) 23 | self.stats = {} 24 | if not os.path.exists(self.log_dir): 25 | os.makedirs(self.log_dir, exist_ok=True) 26 | 27 | # Attributes to record 28 | self.epoch = 0 29 | self.mode = None 30 | self.timer = datetime.datetime.now() 31 | self.tb_writer_funcs = { 32 | 'add_scalar', 'add_scalars', 33 | 'add_image', 'add_images', 34 | 'add_figure', 35 | 'add_audio', 36 | 'add_text', 37 | 'add_histogram', 38 | 'add_pr_curve', 39 | #'add_embedding', # TODO: problem with add_embedding 40 | } 41 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} # TODO : Test these two funcs. 42 | 43 | def dump_stats(self): 44 | with open(f"{self.log_dir}/log", "w") as f: 45 | json.dump(self.stats, f, 46 | indent=4, 47 | ensure_ascii=False, 48 | separators=(",", ": "), 49 | ) 50 | 51 | def set_epoch(self, epoch, mode): 52 | ''' 53 | Execute this function to update the epoch attribute and compute the cost time of one epoch in seconds. 54 | It is recommended to run this function every epoch. 55 | This function MUST be executed before other custom writer functions. 56 | Parameters: 57 | ------------ 58 | epoch : int, epoch number. 59 | mode : str, 'train' or 'valid' 60 | ''' 61 | if epoch == 0: 62 | self.timer = datetime.datetime.now() 63 | elif epoch != self.epoch: 64 | duration = datetime.datetime.now() - self.timer 65 | second_per_epoch = duration.total_seconds() / (epoch - self.epoch) 66 | self.add_scalar(tag='second_per_epoch', data=second_per_epoch) 67 | self.epoch = epoch 68 | self.mode = mode 69 | 70 | def get_epoch(self) -> int: 71 | return self.epoch 72 | 73 | def get_keys(self, epoch: int = None) -> Tuple[str, ...]: 74 | """Returns keys1 e.g. train,eval.""" 75 | if epoch is None: 76 | epoch = self.get_epoch() 77 | return tuple(self.stats[epoch]) 78 | 79 | def get_keys2(self, key: str, epoch: int = None) -> Tuple[str, ...]: 80 | """Returns keys2 e.g. loss,acc.""" 81 | if epoch is None: 82 | epoch = self.get_epoch() 83 | d = self.stats[epoch][key] 84 | keys2 = tuple(k for k in d if k not in ("time", "total_count")) 85 | return keys2 86 | 87 | def plot_stats(self): 88 | self.matplotlib_plot(self.log_dir) 89 | 90 | def matplotlib_plot(self, output_dir: Union[str, Path]): 91 | """Plot stats using Matplotlib and save images.""" 92 | keys2 = set.union(*[set(self.get_keys2(k)) for k in self.get_keys()]) 93 | for key2 in keys2: 94 | keys = [k for k in self.get_keys() if key2 in self.get_keys2(k)] 95 | plt = self._plot_stats(keys, key2) 96 | p = Path(output_dir) / f"{key2}.png" 97 | p.parent.mkdir(parents=True, exist_ok=True) 98 | plt.savefig(p) 99 | 100 | def _plot_stats(self, keys: Sequence[str], key2: str): 101 | # str is also Sequence[str] 102 | if isinstance(keys, str): 103 | raise TypeError(f"Input as [{keys}]") 104 | 105 | import matplotlib 106 | 107 | matplotlib.use("agg") 108 | import matplotlib.pyplot as plt 109 | import matplotlib.ticker as ticker 110 | 111 | plt.clf() 112 | 113 | epochs = sorted(list(self.stats.keys())) 114 | for key in keys: 115 | y = [ 116 | self.stats[e][key][key2] 117 | if e in self.stats 118 | and key in self.stats[e] 119 | and key2 in self.stats[e][key] 120 | else np.nan 121 | for e in epochs 122 | ] 123 | assert len(epochs) == len(y), "Bug?"
124 | 125 | plt.plot(epochs, y, label=key2, marker="x") 126 | plt.legend() 127 | plt.title(f"iteration vs {key2}") 128 | # Force integer tick for x-axis 129 | plt.gca().get_xaxis().set_major_locator(ticker.MaxNLocator(integer=True)) 130 | plt.xlabel("iteration") 131 | plt.ylabel(key2) 132 | plt.grid() 133 | return plt 134 | 135 | def to_numpy(self, a): 136 | if isinstance(a, list): 137 | return np.array(a) 138 | for kind in [torch.Tensor, torch.nn.Parameter]: 139 | if isinstance(a, kind): 140 | if hasattr(a, 'detach'): 141 | a = a.detach() 142 | return a.cpu().numpy() 143 | return a 144 | 145 | def add_scalar(self, tag, data): 146 | data = self.to_numpy(data) 147 | data = float(data) 148 | self.stats.setdefault(self.epoch, {}).setdefault(self.mode, {})[tag] = data 149 | 150 | 151 | def __getattr__(self, name): 152 | if name in self.tb_writer_funcs: 153 | func = getattr(self.writer, name, None) # delegate to the underlying tensorboard writer, if any 154 | # Return a wrapper for all functions. 155 | def wrapper(tag, data, *args, **kwargs): 156 | if func is not None: 157 | if name not in self.tag_mode_exceptions: 158 | tag = f"{tag}/{self.mode}" 159 | func(tag, data, *args, global_step=self.epoch, **kwargs) 160 | 161 | return wrapper 162 | else: 163 | # default __getattr__ function to get other attributes. 164 | try: 165 | attr = object.__getattribute__(self, name) 166 | except AttributeError: 167 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) 168 | return attr 169 | 170 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/models/nets/wrn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | momentum = 0.001 7 | 8 | 9 | def mish(x): 10 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function (https://arxiv.org/abs/1908.08681)""" 11 | return x * torch.tanh(F.softplus(x)) 12 | 13 | 14 | class PSBatchNorm2d(nn.BatchNorm2d): 15 | """How Does BN Increase Collapsed Neural Network Filters? 
(https://arxiv.org/abs/2001.11216)""" 16 | 17 | def __init__(self, num_features, alpha=0.1, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True): 18 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 19 | self.alpha = alpha 20 | 21 | def forward(self, x): 22 | return super().forward(x) + self.alpha 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0, activate_before_residual=False): 27 | super(BasicBlock, self).__init__() 28 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001, eps=0.001) 29 | self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=False) 30 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 31 | padding=1, bias=True) 32 | self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001, eps=0.001) 33 | self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=False) 34 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 35 | padding=1, bias=True) 36 | self.drop_rate = drop_rate 37 | self.equalInOut = (in_planes == out_planes) 38 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 39 | padding=0, bias=True) or None 40 | self.activate_before_residual = activate_before_residual 41 | 42 | def forward(self, x): 43 | if not self.equalInOut and self.activate_before_residual == True: 44 | x = self.relu1(self.bn1(x)) 45 | else: 46 | out = self.relu1(self.bn1(x)) 47 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 48 | if self.drop_rate > 0: 49 | out = F.dropout(out, p=self.drop_rate, training=self.training) 50 | out = self.conv2(out) 51 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 52 | 53 | 54 | class NetworkBlock(nn.Module): 55 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, drop_rate=0.0, activate_before_residual=False): 56 | super(NetworkBlock, self).__init__() 57 | self.layer = self._make_layer( 58 | block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual) 59 | 60 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual): 61 | layers = [] 62 | for i in range(int(nb_layers)): 63 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, 64 | i == 0 and stride or 1, drop_rate, activate_before_residual)) 65 | return nn.Sequential(*layers) 66 | 67 | def forward(self, x): 68 | return self.layer(x) 69 | 70 | 71 | class WideResNet(nn.Module): 72 | def __init__(self, first_stride, num_classes, depth=28, widen_factor=2, drop_rate=0.0, is_remix=False): 73 | super(WideResNet, self).__init__() 74 | channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 75 | assert ((depth - 4) % 6 == 0) 76 | n = (depth - 4) / 6 77 | block = BasicBlock 78 | # 1st conv before any network block 79 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1, 80 | padding=1, bias=True) 81 | # 1st block 82 | self.block1 = NetworkBlock( 83 | n, channels[0], channels[1], block, first_stride, drop_rate, activate_before_residual=True) 84 | # 2nd block 85 | self.block2 = NetworkBlock( 86 | n, channels[1], channels[2], block, 2, drop_rate) 87 | # 3rd block 88 | self.block3 = NetworkBlock( 89 | n, channels[2], channels[3], block, 2, drop_rate) 90 | # global average pooling and classifier 91 | self.bn1 = nn.BatchNorm2d(channels[3], momentum=0.001, eps=0.001) 92 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=False) 93 | self.fc = 
nn.Linear(channels[3], num_classes) 94 | self.channels = channels[3] 95 | 96 | # rot_classifier for Remix Match 97 | self.is_remix = is_remix 98 | if is_remix: 99 | self.rot_classifier = nn.Linear(self.channels, 4) 100 | 101 | for m in self.modules(): 102 | if isinstance(m, nn.Conv2d): 103 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 104 | elif isinstance(m, nn.BatchNorm2d): 105 | m.weight.data.fill_(1) 106 | m.bias.data.zero_() 107 | elif isinstance(m, nn.Linear): 108 | nn.init.xavier_normal_(m.weight.data) 109 | m.bias.data.zero_() 110 | 111 | def forward(self, x, ood_test=False): 112 | out = self.conv1(x) 113 | out = self.block1(out) 114 | out = self.block2(out) 115 | out = self.block3(out) 116 | out = self.relu(self.bn1(out)) 117 | out = F.adaptive_avg_pool2d(out, 1) 118 | out = out.view(-1, self.channels) 119 | output = self.fc(out) 120 | 121 | if ood_test: 122 | return output, out 123 | else: 124 | if self.is_remix: 125 | rot_output = self.rot_classifier(out) 126 | return output, rot_output 127 | else: 128 | return output 129 | 130 | 131 | class build_WideResNet: 132 | def __init__(self, first_stride=1, depth=28, widen_factor=2, bn_momentum=0.01, leaky_slope=0.0, dropRate=0.0, 133 | use_embed=False, is_remix=False): 134 | self.first_stride = first_stride 135 | self.depth = depth 136 | self.widen_factor = widen_factor 137 | self.bn_momentum = bn_momentum 138 | self.dropRate = dropRate 139 | self.leaky_slope = leaky_slope 140 | self.use_embed = use_embed 141 | self.is_remix = is_remix 142 | 143 | def build(self, num_classes): 144 | return WideResNet( 145 | first_stride=self.first_stride, 146 | depth=self.depth, 147 | num_classes=num_classes, 148 | widen_factor=self.widen_factor, 149 | drop_rate=self.dropRate, 150 | is_remix=self.is_remix, 151 | ) 152 | 153 | 154 | if __name__ == '__main__': 155 | wrn_builder = build_WideResNet(1, 10, 2, 0.01, 0.1, 0.5) 156 | wrn = wrn_builder.build(10) 157 | print(wrn) 158 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/models/nets/wrn_var.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | momentum = 0.001 7 | 8 | 9 | def mish(x): 10 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function (https://arxiv.org/abs/1908.08681)""" 11 | return x * torch.tanh(F.softplus(x)) 12 | 13 | 14 | class PSBatchNorm2d(nn.BatchNorm2d): 15 | """How Does BN Increase Collapsed Neural Network Filters? 
(https://arxiv.org/abs/2001.11216)""" 16 | 17 | def __init__(self, num_features, alpha=0.1, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True): 18 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 19 | self.alpha = alpha 20 | 21 | def forward(self, x): 22 | return super().forward(x) + self.alpha 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0, activate_before_residual=False): 27 | super(BasicBlock, self).__init__() 28 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001, eps=0.001) 29 | self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=False) 30 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 31 | padding=1, bias=True) 32 | self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001, eps=0.001) 33 | self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=False) 34 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 35 | padding=1, bias=True) 36 | self.drop_rate = drop_rate 37 | self.equalInOut = (in_planes == out_planes) 38 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 39 | padding=0, bias=True) or None 40 | self.activate_before_residual = activate_before_residual 41 | 42 | def forward(self, x): 43 | if not self.equalInOut and self.activate_before_residual == True: 44 | x = self.relu1(self.bn1(x)) 45 | else: 46 | out = self.relu1(self.bn1(x)) 47 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 48 | if self.drop_rate > 0: 49 | out = F.dropout(out, p=self.drop_rate, training=self.training) 50 | out = self.conv2(out) 51 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 52 | 53 | 54 | class NetworkBlock(nn.Module): 55 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, drop_rate=0.0, activate_before_residual=False): 56 | super(NetworkBlock, self).__init__() 57 | self.layer = self._make_layer( 58 | block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual) 59 | 60 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual): 61 | layers = [] 62 | for i in range(int(nb_layers)): 63 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, 64 | i == 0 and stride or 1, drop_rate, activate_before_residual)) 65 | return nn.Sequential(*layers) 66 | 67 | def forward(self, x): 68 | return self.layer(x) 69 | 70 | 71 | class WideResNetVar(nn.Module): 72 | def __init__(self, first_stride, num_classes, depth=28, widen_factor=2, drop_rate=0.0, is_remix=False): 73 | super(WideResNetVar, self).__init__() 74 | channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor, 128 * widen_factor] 75 | assert ((depth - 4) % 6 == 0) 76 | n = (depth - 4) / 6 77 | block = BasicBlock 78 | # 1st conv before any network block 79 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1, 80 | padding=1, bias=True) 81 | # 1st block 82 | self.block1 = NetworkBlock( 83 | n, channels[0], channels[1], block, first_stride, drop_rate, activate_before_residual=True) 84 | # 2nd block 85 | self.block2 = NetworkBlock( 86 | n, channels[1], channels[2], block, 2, drop_rate) 87 | # 3rd block 88 | self.block3 = NetworkBlock( 89 | n, channels[2], channels[3], block, 2, drop_rate) 90 | # 4th block 91 | self.block4 = NetworkBlock( 92 | n, channels[3], channels[4], block, 2, drop_rate) 93 | # global average pooling and classifier 94 | self.bn1 = 
nn.BatchNorm2d(channels[4], momentum=0.001, eps=0.001)
 95 |         self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=False)
 96 |         self.fc = nn.Linear(channels[4], num_classes)
 97 |         self.channels = channels[4]
 98 | 
 99 |         # rot_classifier for Remix Match
100 |         self.is_remix = is_remix
101 |         if is_remix:
102 |             self.rot_classifier = nn.Linear(self.channels, 4)
103 | 
104 |         for m in self.modules():
105 |             if isinstance(m, nn.Conv2d):
106 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
107 |             elif isinstance(m, nn.BatchNorm2d):
108 |                 m.weight.data.fill_(1)
109 |                 m.bias.data.zero_()
110 |             elif isinstance(m, nn.Linear):
111 |                 nn.init.xavier_normal_(m.weight.data)
112 |                 m.bias.data.zero_()
113 | 
114 |     def forward(self, x, ood_test=False):
115 |         out = self.conv1(x)
116 |         out = self.block1(out)
117 |         out = self.block2(out)
118 |         out = self.block3(out)
119 |         out = self.block4(out)
120 |         out = self.relu(self.bn1(out))
121 |         out = F.adaptive_avg_pool2d(out, 1)
122 |         out = out.view(-1, self.channels)
123 |         output = self.fc(out)
124 | 
125 |         if ood_test:
126 |             return output, out
127 |         else:
128 |             if self.is_remix:
129 |                 rot_output = self.rot_classifier(out)
130 |                 return output, rot_output
131 |             else:
132 |                 return output
133 | 
134 | 
135 | class build_WideResNetVar:
136 |     def __init__(self, first_stride=1, depth=28, widen_factor=2, bn_momentum=0.01, leaky_slope=0.0, dropRate=0.0,
137 |                  use_embed=False, is_remix=False):
138 |         self.first_stride = first_stride
139 |         self.depth = depth
140 |         self.widen_factor = widen_factor
141 |         self.bn_momentum = bn_momentum
142 |         self.dropRate = dropRate
143 |         self.leaky_slope = leaky_slope
144 |         self.use_embed = use_embed
145 |         self.is_remix = is_remix
146 | 
147 |     def build(self, num_classes):
148 |         return WideResNetVar(
149 |             first_stride=self.first_stride,
150 |             depth=self.depth,
151 |             num_classes=num_classes,
152 |             widen_factor=self.widen_factor,
153 |             drop_rate=self.dropRate,
154 |             is_remix=self.is_remix,
155 |         )
156 | 
157 | 
158 | if __name__ == '__main__':
159 |     wrn_builder = build_WideResNetVar(1, 10, 2, 0.01, 0.1, 0.5)
160 |     wrn = wrn_builder.build(10)
161 |     print(wrn)
162 | 
--------------------------------------------------------------------------------
/cifar10_experiment/edm/fid.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2 | #
  3 | # This work is licensed under a Creative Commons
  4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
  5 | # You should have received a copy of the license along with this
  6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
  7 | 
  8 | """Script for calculating Frechet Inception Distance (FID)."""
  9 | 
 10 | import os
 11 | import click
 12 | import tqdm
 13 | import pickle
 14 | import numpy as np
 15 | import scipy.linalg
 16 | import torch
 17 | import dnnlib
 18 | from torch_utils import distributed as dist
 19 | from training import dataset
 20 | 
 21 | #----------------------------------------------------------------------------
 22 | 
 23 | def calculate_inception_stats(
 24 |     image_path, num_expected=None, seed=0, max_batch_size=64,
 25 |     num_workers=3, prefetch_factor=2, device=torch.device('cuda'),
 26 | ):
 27 |     # Rank 0 goes first.
 28 |     if dist.get_rank() != 0:
 29 |         torch.distributed.barrier()
 30 | 
 31 |     # Load Inception-v3 model.
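    # (With return_features=True below, the pickled detector returns the
    # 2048-dimensional pool features used for FID rather than class logits.)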
32 | # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 33 | dist.print0('Loading Inception-v3 model...') 34 | #detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 35 | detector_url = 'assets/inception-2015-12-05.pkl' 36 | detector_kwargs = dict(return_features=True) 37 | feature_dim = 2048 38 | with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f: 39 | detector_net = pickle.load(f).to(device) 40 | 41 | # List images. 42 | dist.print0(f'Loading images from "{image_path}"...') 43 | dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed) 44 | if num_expected is not None and len(dataset_obj) < num_expected: 45 | raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}') 46 | if len(dataset_obj) < 2: 47 | raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics') 48 | 49 | # Other ranks follow. 50 | if dist.get_rank() == 0: 51 | torch.distributed.barrier() 52 | 53 | # Divide images into batches. 54 | num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size() 55 | all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches) 56 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()] 57 | data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor) 58 | 59 | # Accumulate statistics. 60 | dist.print0(f'Calculating statistics for {len(dataset_obj)} images...') 61 | mu = torch.zeros([feature_dim], dtype=torch.float64, device=device) 62 | sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device) 63 | for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)): 64 | torch.distributed.barrier() 65 | if images.shape[0] == 0: 66 | continue 67 | if images.shape[1] == 1: 68 | images = images.repeat([1, 3, 1, 1]) 69 | features = detector_net(images.to(device), **detector_kwargs).to(torch.float64) 70 | mu += features.sum(0) 71 | sigma += features.T @ features 72 | 73 | # Calculate grand totals. 74 | torch.distributed.all_reduce(mu) 75 | torch.distributed.all_reduce(sigma) 76 | mu /= len(dataset_obj) 77 | sigma -= mu.ger(mu) * len(dataset_obj) 78 | sigma /= len(dataset_obj) - 1 79 | return mu.cpu().numpy(), sigma.cpu().numpy() 80 | 81 | #---------------------------------------------------------------------------- 82 | 83 | def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref): 84 | m = np.square(mu - mu_ref).sum() 85 | s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False) 86 | fid = m + np.trace(sigma + sigma_ref - s * 2) 87 | return float(np.real(fid)) 88 | 89 | #---------------------------------------------------------------------------- 90 | 91 | @click.group() 92 | def main(): 93 | """Calculate Frechet Inception Distance (FID). 
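    The `calc` subcommand compares a set of generated images against
    precomputed reference statistics; the `ref` subcommand computes those
    statistics from a dataset (both shown below).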
94 | 95 | Examples: 96 | 97 | \b 98 | # Generate 50000 images and save them as fid-tmp/*/*.png 99 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\ 100 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 101 | 102 | \b 103 | # Calculate FID 104 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \\ 105 | --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 106 | 107 | \b 108 | # Compute dataset reference statistics 109 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz 110 | """ 111 | 112 | #---------------------------------------------------------------------------- 113 | 114 | @main.command() 115 | @click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True) 116 | @click.option('--ref', 'ref_path', help='Dataset reference statistics ', metavar='NPZ|URL', type=str, required=True) 117 | @click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True) 118 | @click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True) 119 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 120 | 121 | def calc(image_path, ref_path, num_expected, seed, batch): 122 | """Calculate FID for a given set of images.""" 123 | torch.multiprocessing.set_start_method('spawn') 124 | dist.init() 125 | 126 | dist.print0(f'Loading dataset reference statistics from "{ref_path}"...') 127 | ref = None 128 | if dist.get_rank() == 0: 129 | with dnnlib.util.open_url(ref_path) as f: 130 | ref = dict(np.load(f)) 131 | 132 | mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch) 133 | dist.print0('Calculating FID...') 134 | if dist.get_rank() == 0: 135 | fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma']) 136 | print(f'{fid:g}') 137 | torch.distributed.barrier() 138 | 139 | #---------------------------------------------------------------------------- 140 | 141 | @main.command() 142 | @click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True) 143 | @click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True) 144 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 145 | 146 | def ref(dataset_path, dest_path, batch): 147 | """Calculate dataset reference statistics needed by 'calc'.""" 148 | torch.multiprocessing.set_start_method('spawn') 149 | dist.init() 150 | 151 | mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch) 152 | dist.print0(f'Saving dataset reference statistics to "{dest_path}"...') 153 | if dist.get_rank() == 0: 154 | if os.path.dirname(dest_path): 155 | os.makedirs(os.path.dirname(dest_path), exist_ok=True) 156 | np.savez(dest_path, mu=mu, sigma=sigma) 157 | 158 | torch.distributed.barrier() 159 | dist.print0('Done.') 160 | 161 | #---------------------------------------------------------------------------- 162 | 163 | if __name__ == "__main__": 164 | main() 165 | 166 | #---------------------------------------------------------------------------- 167 | 
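# --- Illustrative sketch (not part of the original EDM code) ----------------
# calculate_inception_stats() streams sum(f) into `mu` and f^T f into `sigma`,
# then rescales them into the sample mean and Bessel-corrected covariance:
#     mu    = sum(f) / n
#     sigma = (f^T f - n * mu mu^T) / (n - 1)
# calculate_fid_from_inception_stats() then evaluates the squared Frechet
# distance between the two Gaussians:
#     FID = ||mu - mu_ref||^2 + Tr(sigma + sigma_ref - 2 (sigma sigma_ref)^{1/2})
# A toy check (hypothetical helper, safe to delete): identical statistics give
# FID ~ 0, while shifting both feature means by 0.5 gives ~ 2 * 0.5^2 = 0.5.

def _fid_toy_example():
    rng = np.random.default_rng(0)
    a = rng.normal(size=(10000, 2))
    b = a + 0.5  # same covariance, mean shifted by (0.5, 0.5)
    mu_a, sigma_a = a.mean(0), np.cov(a, rowvar=False)
    mu_b, sigma_b = b.mean(0), np.cov(b, rowvar=False)
    print(calculate_fid_from_inception_stats(mu_a, sigma_a, mu_a, sigma_a))  # ~ 0
    print(calculate_fid_from_inception_stats(mu_b, sigma_b, mu_a, sigma_a))  # ~ 0.5

#----------------------------------------------------------------------------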
-------------------------------------------------------------------------------- /sample_ldm_all.py: -------------------------------------------------------------------------------- 1 | from tools.fid_score import calculate_fid_given_paths 2 | import ml_collections 3 | import torch 4 | from torch import multiprocessing as mp 5 | import accelerate 6 | import utils 7 | import sde 8 | import einops 9 | from datasets import get_dataset 10 | import tempfile 11 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver 12 | from absl import logging 13 | import builtins 14 | import libs.autoencoder 15 | import torch.nn as nn 16 | import numpy as np 17 | import os 18 | from tqdm import tqdm 19 | from torchvision.utils import make_grid, save_image 20 | from absl import logging 21 | import pickle 22 | 23 | def evaluate(config): 24 | if config.get('benchmark', False): 25 | torch.backends.cudnn.benchmark = True 26 | torch.backends.cudnn.deterministic = False 27 | 28 | mp.set_start_method('spawn') 29 | accelerator = accelerate.Accelerator() 30 | device = accelerator.device 31 | accelerate.utils.set_seed(config.seed, device_specific=True) 32 | logging.info(f'Process {accelerator.process_index} using device: {device}') 33 | 34 | config.mixed_precision = accelerator.mixed_precision 35 | config = ml_collections.FrozenConfigDict(config) 36 | if accelerator.is_main_process: 37 | utils.set_logger(log_level='info', fname=config.output_path) 38 | else: 39 | utils.set_logger(log_level='error') 40 | builtins.print = lambda *args: None 41 | 42 | dataset = get_dataset(**config.dataset) 43 | 44 | nnet = utils.get_nnet(**config.nnet) 45 | nnet = accelerator.prepare(nnet) 46 | cluster_name = config.model_name + '-' + '-'.join(config.subset_path.split('/')).split('.txt')[0] 47 | 48 | nnet_path = f'{config.dpm_path}/{cluster_name}/{config.resolution}/ckpts/{config.train.n_steps}.ckpt/nnet_ema.pth' 49 | logging.info(f'load nnet from {nnet_path}') 50 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(nnet_path, map_location='cpu')) 51 | nnet.eval() 52 | 53 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path) 54 | autoencoder.to(device) 55 | 56 | @torch.cuda.amp.autocast() 57 | def encode(_batch): 58 | return autoencoder.encode(_batch) 59 | 60 | @torch.cuda.amp.autocast() 61 | def decode(_batch): 62 | return autoencoder.decode(_batch) 63 | 64 | def decode_large_batch(_batch): 65 | decode_mini_batch_size = 20 # use a small batch size since the decoder is large 66 | xs = [] 67 | pt = 0 68 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size): 69 | x = decode(_batch[pt: pt + _decode_mini_batch_size]) 70 | pt += _decode_mini_batch_size 71 | xs.append(x) 72 | xs = torch.concat(xs, dim=0) 73 | assert xs.size(0) == _batch.size(0) 74 | return xs 75 | 76 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance 77 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}') 78 | def cfg_nnet(x, timesteps, y): 79 | _cond = nnet(x, timesteps, y=y) 80 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device)) 81 | return _cond + config.sample.scale * (_cond - _uncond) 82 | score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE()) 83 | else: 84 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE()) 85 | 86 | logging.info(config.sample) 87 | assert os.path.exists(dataset.fid_stat) 88 | logging.info(f'sample: each class sample 
n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}') 89 | 90 | def amortize(n_samples, batch_size): 91 | k = n_samples // batch_size 92 | r = n_samples % batch_size 93 | return k * [batch_size] if r == 0 else k * [batch_size] + [r] 94 | 95 | def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, class_num=0): 96 | os.makedirs(path, exist_ok=True) 97 | idx = 0 98 | batch_size = mini_batch_size * accelerator.num_processes 99 | 100 | for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'): 101 | samples = unpreprocess_fn(sample_fn(mini_batch_size, class_num)) 102 | samples = accelerator.gather(samples.contiguous())[:_batch_size] 103 | if accelerator.local_process_index == 0: 104 | for sample in samples: 105 | save_image(sample, os.path.join(path, f"{idx}.png")) 106 | idx += 1 107 | 108 | def sample_fn(_n_samples, class_num): 109 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device) 110 | if config.train.mode == 'uncond': 111 | kwargs = dict() 112 | elif config.train.mode == 'cond': 113 | torch_arr = torch.ones(_n_samples // 10, device=device, dtype=int) * torch.tensor(int(class_num)) 114 | kwargs = dict(y=einops.repeat(torch_arr % dataset.K, 'nrow -> (nrow ncol)', ncol=10)) 115 | else: 116 | raise NotImplementedError 117 | 118 | if config.sample.algorithm == 'euler_maruyama_sde': 119 | _z = sde.euler_maruyama(sde.ReverseSDE(score_model), _z_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs) 120 | elif config.sample.algorithm == 'euler_maruyama_ode': 121 | _z = sde.euler_maruyama(sde.ODE(score_model), _z_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs) 122 | elif config.sample.algorithm == 'dpm_solver': 123 | noise_schedule = NoiseScheduleVP(schedule='linear') 124 | model_fn = model_wrapper( 125 | score_model.noise_pred, 126 | noise_schedule, 127 | time_input_type='0', 128 | model_kwargs=kwargs 129 | ) 130 | dpm_solver = DPM_Solver(model_fn, noise_schedule) 131 | _z = dpm_solver.sample( 132 | _z_init, 133 | steps=config.sample.sample_steps, 134 | eps=1e-4, 135 | adaptive_step_size=False, 136 | fast_version=True, 137 | ) 138 | else: 139 | raise NotImplementedError 140 | return decode_large_batch(_z) 141 | f_read = open('idx_to_class.pkl', 'rb') 142 | dict2 = pickle.load(f_read) 143 | print(dict2) 144 | 145 | for i in range(1000): 146 | aug_samples_path = f'{config.dpm_path}/{cluster_name}/{config.resolution}/samples_for_classifier/aug_{config.augmentation_K}_samples' 147 | 148 | path = os.path.join(aug_samples_path, f'train/{dict2[i]}') 149 | if accelerator.is_main_process: 150 | os.makedirs(path, exist_ok=True) 151 | sample2dir(accelerator, path, config.augmentation_K, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess, class_num = i) 152 | 153 | 154 | from absl import flags 155 | from absl import app 156 | from ml_collections import config_flags 157 | import os 158 | import sys 159 | from pathlib import Path 160 | 161 | 162 | FLAGS = flags.FLAGS 163 | config_flags.DEFINE_config_file( 164 | "config", None, "Training configuration.", lock_config=False) 165 | flags.mark_flags_as_required(["config"]) 166 | flags.DEFINE_string("output_path", None, "The path to output log.") 167 | 168 | 169 | def get_config_name(): 170 | argv = sys.argv 171 | for i in range(1, len(argv)): 172 | if argv[i].startswith('--config='): 173 | return Path(argv[i].split('=')[-1]).stem 
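# --- Illustrative sketch (not part of the original script) ------------------
# cfg_nnet above implements classifier-free guidance: given the conditional
# prediction e_c and the unconditional prediction e_u, it returns
#     e_c + scale * (e_c - e_u),
# so scale=0 recovers the plain conditional prediction. A toy check with
# scalar tensors standing in for noise predictions (hypothetical helper):

def _cfg_toy_example():
    e_cond, e_uncond, scale = torch.tensor(1.0), torch.tensor(0.2), 2.0
    guided = e_cond + scale * (e_cond - e_uncond)
    assert torch.isclose(guided, torch.tensor(2.6))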
174 | 175 | 176 | def get_hparams(): 177 | argv = sys.argv 178 | lst = [] 179 | for i in range(1, len(argv)): 180 | assert '=' in argv[i] 181 | if argv[i].startswith('--config.'): 182 | hparam, val = argv[i].split('=') 183 | hparam = hparam.split('.')[-1] 184 | if hparam.endswith('path'): 185 | val = Path(val).stem 186 | lst.append(f'{hparam}={val}') 187 | hparams = '-'.join(lst) 188 | if hparams == '': 189 | hparams = 'default' 190 | return hparams 191 | 192 | 193 | def main(argv): 194 | config = FLAGS.config 195 | config_name = get_config_name() 196 | hparams = get_hparams() 197 | config.project = config_name 198 | config.notes = hparams 199 | config.output_path = FLAGS.output_path 200 | 201 | evaluate(config) 202 | 203 | 204 | if __name__ == "__main__": 205 | app.run(main) 206 | -------------------------------------------------------------------------------- /sample_ldm_discrete_all.py: -------------------------------------------------------------------------------- 1 | from tools.fid_score import calculate_fid_given_paths 2 | import ml_collections 3 | import torch 4 | from torch import multiprocessing as mp 5 | import accelerate 6 | import utils 7 | import sde 8 | import einops 9 | from datasets import get_dataset 10 | import tempfile 11 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver 12 | from absl import logging 13 | import builtins 14 | import libs.autoencoder 15 | import torch.nn as nn 16 | import numpy as np 17 | import os 18 | from tqdm import tqdm 19 | from torchvision.utils import make_grid, save_image 20 | from absl import logging 21 | import pickle 22 | 23 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000): 24 | _betas = ( 25 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 26 | ) 27 | return _betas.numpy() 28 | 29 | def evaluate(config): 30 | if config.get('benchmark', False): 31 | torch.backends.cudnn.benchmark = True 32 | torch.backends.cudnn.deterministic = False 33 | 34 | mp.set_start_method('spawn') 35 | accelerator = accelerate.Accelerator() 36 | device = accelerator.device 37 | accelerate.utils.set_seed(config.seed, device_specific=True) 38 | logging.info(f'Process {accelerator.process_index} using device: {device}') 39 | 40 | config.mixed_precision = accelerator.mixed_precision 41 | config = ml_collections.FrozenConfigDict(config) 42 | if accelerator.is_main_process: 43 | utils.set_logger(log_level='info', fname=config.output_path) 44 | else: 45 | utils.set_logger(log_level='error') 46 | builtins.print = lambda *args: None 47 | 48 | dataset = get_dataset(**config.dataset) 49 | 50 | nnet = utils.get_nnet(**config.nnet) 51 | nnet = accelerator.prepare(nnet) 52 | if config.nnet_path == '': 53 | cluster_name = config.model_name + '-' + '-'.join(config.subset_path.split('/')).split('.txt')[0] 54 | 55 | nnet_path = f'{config.dpm_path}/{cluster_name}/{config.resolution}/ckpts/{config.train.n_steps}.ckpt/nnet_ema.pth' 56 | 57 | else: 58 | nnet_path = config.nnet_path 59 | 60 | logging.info(f'load nnet from {nnet_path}') 61 | 62 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(nnet_path, map_location='cpu')) 63 | nnet.eval() 64 | 65 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path) 66 | autoencoder.to(device) 67 | 68 | @torch.cuda.amp.autocast() 69 | def encode(_batch): 70 | return autoencoder.encode(_batch) 71 | 72 | @torch.cuda.amp.autocast() 73 | def decode(_batch): 74 | return autoencoder.decode(_batch) 75 | 76 | def 
decode_large_batch(_batch): 77 | decode_mini_batch_size = 20 # use a small batch size since the decoder is large 78 | xs = [] 79 | pt = 0 80 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size): 81 | x = decode(_batch[pt: pt + _decode_mini_batch_size]) 82 | pt += _decode_mini_batch_size 83 | xs.append(x) 84 | xs = torch.concat(xs, dim=0) 85 | assert xs.size(0) == _batch.size(0) 86 | return xs 87 | 88 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance 89 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}') 90 | def cfg_nnet(x, timesteps, y): 91 | _cond = nnet(x, timesteps, y=y) 92 | _uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device)) 93 | return _cond + config.sample.scale * (_cond - _uncond) 94 | else: 95 | def cfg_nnet(x, timesteps, y): 96 | _cond = nnet(x, timesteps, y=y) 97 | return _cond 98 | 99 | logging.info(config.sample) 100 | assert os.path.exists(dataset.fid_stat) 101 | logging.info(f'sample: each class sample n_samples={config.augmentation_K}, mode={config.train.mode}, mixed_precision={config.mixed_precision}') 102 | 103 | _betas = stable_diffusion_beta_schedule() 104 | N = len(_betas) 105 | 106 | def amortize(n_samples, batch_size): 107 | k = n_samples // batch_size 108 | r = n_samples % batch_size 109 | return k * [batch_size] if r == 0 else k * [batch_size] + [r] 110 | 111 | def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, class_num=0): 112 | os.makedirs(path, exist_ok=True) 113 | idx = 0 114 | batch_size = mini_batch_size * accelerator.num_processes 115 | 116 | for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'): 117 | samples = unpreprocess_fn(sample_fn(mini_batch_size, class_num)) 118 | samples = accelerator.gather(samples.contiguous())[:_batch_size] 119 | if accelerator.local_process_index == 0: 120 | for sample in samples: 121 | save_image(sample, os.path.join(path, f"{idx}.png")) 122 | idx += 1 123 | 124 | def sample_z(_n_samples, _sample_steps, **kwargs): 125 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device) 126 | 127 | if config.sample.algorithm == 'dpm_solver': 128 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float()) 129 | 130 | def model_fn(x, t_continuous): 131 | t = t_continuous * N 132 | eps_pre = cfg_nnet(x, t, **kwargs) 133 | return eps_pre 134 | 135 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False) 136 | _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / N, T=1.) 
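            # `eps=1./N` and `T=1.` make the solver integrate over continuous
            # time t in [1/N, 1]; model_fn above rescales by N, so the network
            # is queried at the discrete steps 1..N of the 1000-step
            # stable-diffusion beta schedule defined at the top of this file.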
137 | 138 | else: 139 | raise NotImplementedError 140 | 141 | return _z 142 | 143 | def sample_fn(_n_samples, class_num): 144 | if config.train.mode == 'uncond': 145 | kwargs = dict() 146 | elif config.train.mode == 'cond': 147 | torch_arr = torch.ones(_n_samples // 10, device=device, dtype=int) * torch.tensor(int(class_num)) 148 | kwargs = dict(y=einops.repeat(torch_arr % dataset.K, 'nrow -> (nrow ncol)', ncol=10)) 149 | else: 150 | raise NotImplementedError 151 | _z = sample_z(_n_samples, _sample_steps=config.sample.sample_steps, **kwargs) 152 | return decode_large_batch(_z) 153 | 154 | f_read = open('idx_to_class.pkl', 'rb') 155 | dict2 = pickle.load(f_read) 156 | print(dict2) 157 | 158 | for i in range(1000): 159 | if config.sample.path == '': 160 | aug_samples_path = f'{config.dpm_path}/{cluster_name}/{config.resolution}/samples_for_classifier/aug_{config.augmentation_K}_samples' 161 | 162 | path = os.path.join(aug_samples_path, f'train/{dict2[i]}') 163 | else: 164 | path = os.path.join(config.sample.path, f'{dict2[i]}') 165 | 166 | if accelerator.is_main_process: 167 | os.makedirs(path, exist_ok=True) 168 | sample2dir(accelerator, path, config.augmentation_K, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess, class_num = i) 169 | 170 | if config.sample.path != '' and accelerator.is_main_process: 171 | os.system(f"cd {config.sample.path} && cd .. && tar -zcf samples.tar.gz samples") 172 | 173 | from absl import flags 174 | from absl import app 175 | from ml_collections import config_flags 176 | import os 177 | import sys 178 | from pathlib import Path 179 | 180 | 181 | FLAGS = flags.FLAGS 182 | config_flags.DEFINE_config_file( 183 | "config", None, "Training configuration.", lock_config=False) 184 | flags.mark_flags_as_required(["config"]) 185 | flags.DEFINE_string("nnet_path", '', "The nnet to evaluate.") 186 | flags.DEFINE_string("output_path", None, "The path to output log.") 187 | 188 | 189 | def get_config_name(): 190 | argv = sys.argv 191 | for i in range(1, len(argv)): 192 | if argv[i].startswith('--config='): 193 | return Path(argv[i].split('=')[-1]).stem 194 | 195 | 196 | def get_hparams(): 197 | argv = sys.argv 198 | lst = [] 199 | for i in range(1, len(argv)): 200 | assert '=' in argv[i] 201 | if argv[i].startswith('--config.'): 202 | hparam, val = argv[i].split('=') 203 | hparam = hparam.split('.')[-1] 204 | if hparam.endswith('path'): 205 | val = Path(val).stem 206 | lst.append(f'{hparam}={val}') 207 | hparams = '-'.join(lst) 208 | if hparams == '': 209 | hparams = 'default' 210 | return hparams 211 | 212 | 213 | def main(argv): 214 | config = FLAGS.config 215 | config_name = get_config_name() 216 | hparams = get_hparams() 217 | config.project = config_name 218 | config.notes = hparams 219 | config.nnet_path = FLAGS.nnet_path 220 | config.output_path = FLAGS.output_path 221 | 222 | evaluate(config) 223 | 224 | 225 | if __name__ == "__main__": 226 | app.run(main) 227 | -------------------------------------------------------------------------------- /libs/uvit_t2i.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from .timm import trunc_normal_, Mlp 5 | import einops 6 | import torch.utils.checkpoint 7 | 8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): 9 | ATTENTION_MODE = 'flash' 10 | else: 11 | try: 12 | import xformers 13 | import xformers.ops 14 | ATTENTION_MODE = 'xformers' 15 | except: 16 | ATTENTION_MODE = 'math' 17 | print(f'attention 
mode is {ATTENTION_MODE}')
 18 | 
 19 | 
 20 | def timestep_embedding(timesteps, dim, max_period=10000):
 21 |     """
 22 |     Create sinusoidal timestep embeddings.
 23 | 
 24 |     :param timesteps: a 1-D Tensor of N indices, one per batch element.
 25 |                       These may be fractional.
 26 |     :param dim: the dimension of the output.
 27 |     :param max_period: controls the minimum frequency of the embeddings.
 28 |     :return: an [N x dim] Tensor of positional embeddings.
 29 |     """
 30 |     half = dim // 2
 31 |     freqs = torch.exp(
 32 |         -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
 33 |     ).to(device=timesteps.device)
 34 |     args = timesteps[:, None].float() * freqs[None]
 35 |     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
 36 |     if dim % 2:
 37 |         embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
 38 |     return embedding
 39 | 
 40 | 
 41 | def patchify(imgs, patch_size):
 42 |     x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
 43 |     return x
 44 | 
 45 | 
 46 | def unpatchify(x, channels=3):
 47 |     patch_size = int((x.shape[2] // channels) ** 0.5)
 48 |     h = w = int(x.shape[1] ** .5)
 49 |     assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
 50 |     x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
 51 |     return x
 52 | 
 53 | 
 54 | class Attention(nn.Module):
 55 |     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
 56 |         super().__init__()
 57 |         self.num_heads = num_heads
 58 |         head_dim = dim // num_heads
 59 |         self.scale = qk_scale or head_dim ** -0.5
 60 | 
 61 |         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
 62 |         self.attn_drop = nn.Dropout(attn_drop)
 63 |         self.proj = nn.Linear(dim, dim)
 64 |         self.proj_drop = nn.Dropout(proj_drop)
 65 | 
 66 |     def forward(self, x):
 67 |         B, L, C = x.shape
 68 | 
 69 |         qkv = self.qkv(x)
 70 |         if ATTENTION_MODE == 'flash':
 71 |             qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
 72 |             q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
 73 |             x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
 74 |             x = einops.rearrange(x, 'B H L D -> B L (H D)')
 75 |         elif ATTENTION_MODE == 'xformers':
 76 |             qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
 77 |             q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
 78 |             x = xformers.ops.memory_efficient_attention(q, k, v)
 79 |             x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
 80 |         elif ATTENTION_MODE == 'math':
 81 |             qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
 82 |             q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
 83 |             attn = (q @ k.transpose(-2, -1)) * self.scale
 84 |             attn = attn.softmax(dim=-1)
 85 |             attn = self.attn_drop(attn)
 86 |             x = (attn @ v).transpose(1, 2).reshape(B, L, C)
 87 |         else:
 88 |             raise NotImplementedError(f'unknown attention mode: {ATTENTION_MODE}')
 89 | 
 90 |         x = self.proj(x)
 91 |         x = self.proj_drop(x)
 92 |         return x
 93 | 
 94 | 
 95 | class Block(nn.Module):
 96 | 
 97 |     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
 98 |                  act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
 99 |         super().__init__()
100 |         self.norm1 = norm_layer(dim)
101 |         self.attn = Attention(
102 |             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
103 |         self.norm2 = norm_layer(dim)
104 |         mlp_hidden_dim = int(dim * mlp_ratio)
105 |         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
106 |         self.skip_linear = nn.Linear(2 * dim, dim) if skip else
None 107 | self.use_checkpoint = use_checkpoint 108 | 109 | def forward(self, x, skip=None): 110 | if self.use_checkpoint: 111 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip) 112 | else: 113 | return self._forward(x, skip) 114 | 115 | def _forward(self, x, skip=None): 116 | if self.skip_linear is not None: 117 | x = self.skip_linear(torch.cat([x, skip], dim=-1)) 118 | x = x + self.attn(self.norm1(x)) 119 | x = x + self.mlp(self.norm2(x)) 120 | return x 121 | 122 | 123 | class PatchEmbed(nn.Module): 124 | """ Image to Patch Embedding 125 | """ 126 | def __init__(self, patch_size, in_chans=3, embed_dim=768): 127 | super().__init__() 128 | self.patch_size = patch_size 129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 130 | 131 | def forward(self, x): 132 | B, C, H, W = x.shape 133 | assert H % self.patch_size == 0 and W % self.patch_size == 0 134 | x = self.proj(x).flatten(2).transpose(1, 2) 135 | return x 136 | 137 | 138 | class UViT(nn.Module): 139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False, 141 | clip_dim=768, num_clip_token=77, conv=True, skip=True): 142 | super().__init__() 143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 144 | self.in_chans = in_chans 145 | 146 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 147 | num_patches = (img_size // patch_size) ** 2 148 | 149 | self.time_embed = nn.Sequential( 150 | nn.Linear(embed_dim, 4 * embed_dim), 151 | nn.SiLU(), 152 | nn.Linear(4 * embed_dim, embed_dim), 153 | ) if mlp_time_embed else nn.Identity() 154 | 155 | self.context_embed = nn.Linear(clip_dim, embed_dim) 156 | 157 | self.extras = 1 + num_clip_token 158 | 159 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim)) 160 | 161 | self.in_blocks = nn.ModuleList([ 162 | Block( 163 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 164 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 165 | for _ in range(depth // 2)]) 166 | 167 | self.mid_block = Block( 168 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 169 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 170 | 171 | self.out_blocks = nn.ModuleList([ 172 | Block( 173 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 174 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint) 175 | for _ in range(depth // 2)]) 176 | 177 | self.norm = norm_layer(embed_dim) 178 | self.patch_dim = patch_size ** 2 * in_chans 179 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True) 180 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity() 181 | 182 | trunc_normal_(self.pos_embed, std=.02) 183 | self.apply(self._init_weights) 184 | 185 | def _init_weights(self, m): 186 | if isinstance(m, nn.Linear): 187 | trunc_normal_(m.weight, std=.02) 188 | if isinstance(m, nn.Linear) and m.bias is not None: 189 | nn.init.constant_(m.bias, 0) 190 | elif isinstance(m, nn.LayerNorm): 191 | nn.init.constant_(m.bias, 0) 192 | nn.init.constant_(m.weight, 1.0) 193 | 194 | @torch.jit.ignore 195 | def no_weight_decay(self): 196 | return {'pos_embed'} 197 | 198 | def forward(self, x, timesteps, 
context): 199 | x = self.patch_embed(x) 200 | B, L, D = x.shape 201 | 202 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) 203 | time_token = time_token.unsqueeze(dim=1) 204 | context_token = self.context_embed(context) 205 | x = torch.cat((time_token, context_token, x), dim=1) 206 | x = x + self.pos_embed 207 | 208 | skips = [] 209 | for blk in self.in_blocks: 210 | x = blk(x) 211 | skips.append(x) 212 | 213 | x = self.mid_block(x) 214 | 215 | for blk in self.out_blocks: 216 | x = blk(x, skips.pop()) 217 | 218 | x = self.norm(x) 219 | x = self.decoder_pred(x) 220 | assert x.size(1) == self.extras + L 221 | x = x[:, self.extras:, :] 222 | x = unpatchify(x, self.in_chans) 223 | x = self.final_layer(x) 224 | return x 225 | -------------------------------------------------------------------------------- /libs/uvit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from .timm import trunc_normal_, Mlp 5 | import einops 6 | import torch.utils.checkpoint 7 | 8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): 9 | ATTENTION_MODE = 'flash' 10 | else: 11 | try: 12 | import xformers 13 | import xformers.ops 14 | ATTENTION_MODE = 'xformers' 15 | except: 16 | ATTENTION_MODE = 'math' 17 | print(f'attention mode is {ATTENTION_MODE}') 18 | 19 | 20 | def timestep_embedding(timesteps, dim, max_period=10000): 21 | """ 22 | Create sinusoidal timestep embeddings. 23 | 24 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 25 | These may be fractional. 26 | :param dim: the dimension of the output. 27 | :param max_period: controls the minimum frequency of the embeddings. 28 | :return: an [N x dim] Tensor of positional embeddings. 
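
    For example, with dim=4 the frequencies are [1, 1e-2], so timesteps=[1.]
    yields [cos(1), cos(0.01), sin(1), sin(0.01)] ~ [0.5403, 1.0000, 0.8415, 0.0100].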
 29 |     """
 30 |     half = dim // 2
 31 |     freqs = torch.exp(
 32 |         -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
 33 |     ).to(device=timesteps.device)
 34 |     args = timesteps[:, None].float() * freqs[None]
 35 |     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
 36 |     if dim % 2:
 37 |         embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
 38 |     return embedding
 39 | 
 40 | 
 41 | def patchify(imgs, patch_size):
 42 |     x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
 43 |     return x
 44 | 
 45 | 
 46 | def unpatchify(x, channels=3):
 47 |     patch_size = int((x.shape[2] // channels) ** 0.5)
 48 |     h = w = int(x.shape[1] ** .5)
 49 |     assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
 50 |     x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
 51 |     return x
 52 | 
 53 | 
 54 | class Attention(nn.Module):
 55 |     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
 56 |         super().__init__()
 57 |         self.num_heads = num_heads
 58 |         head_dim = dim // num_heads
 59 |         self.scale = qk_scale or head_dim ** -0.5
 60 | 
 61 |         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
 62 |         self.attn_drop = nn.Dropout(attn_drop)
 63 |         self.proj = nn.Linear(dim, dim)
 64 |         self.proj_drop = nn.Dropout(proj_drop)
 65 | 
 66 |     def forward(self, x):
 67 |         B, L, C = x.shape
 68 | 
 69 |         qkv = self.qkv(x)
 70 |         if ATTENTION_MODE == 'flash':
 71 |             qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
 72 |             q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
 73 |             x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
 74 |             x = einops.rearrange(x, 'B H L D -> B L (H D)')
 75 |         elif ATTENTION_MODE == 'xformers':
 76 |             qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
 77 |             q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
 78 |             x = xformers.ops.memory_efficient_attention(q, k, v)
 79 |             x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
 80 |         elif ATTENTION_MODE == 'math':
 81 |             qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
 82 |             q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
 83 |             attn = (q @ k.transpose(-2, -1)) * self.scale
 84 |             attn = attn.softmax(dim=-1)
 85 |             attn = self.attn_drop(attn)
 86 |             x = (attn @ v).transpose(1, 2).reshape(B, L, C)
 87 |         else:
 88 |             raise NotImplementedError(f'unknown attention mode: {ATTENTION_MODE}')
 89 | 
 90 |         x = self.proj(x)
 91 |         x = self.proj_drop(x)
 92 |         return x
 93 | 
 94 | 
 95 | class Block(nn.Module):
 96 | 
 97 |     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
 98 |                  act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
 99 |         super().__init__()
100 |         self.norm1 = norm_layer(dim)
101 |         self.attn = Attention(
102 |             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
103 |         self.norm2 = norm_layer(dim)
104 |         mlp_hidden_dim = int(dim * mlp_ratio)
105 |         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
106 |         self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
107 |         self.use_checkpoint = use_checkpoint
108 | 
109 |     def forward(self, x, skip=None):
110 |         if self.use_checkpoint:
111 |             return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
112 |         else:
113 |             return self._forward(x, skip)
114 | 
115 |     def _forward(self, x, skip=None):
116 |         if self.skip_linear is not None:
117 |             x = self.skip_linear(torch.cat([x, skip], dim=-1))
118 |         x = x + self.attn(self.norm1(x))
119 |         x = x +
self.mlp(self.norm2(x)) 120 | return x 121 | 122 | 123 | class PatchEmbed(nn.Module): 124 | """ Image to Patch Embedding 125 | """ 126 | def __init__(self, patch_size, in_chans=3, embed_dim=768): 127 | super().__init__() 128 | self.patch_size = patch_size 129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 130 | 131 | def forward(self, x): 132 | B, C, H, W = x.shape 133 | assert H % self.patch_size == 0 and W % self.patch_size == 0 134 | x = self.proj(x).flatten(2).transpose(1, 2) 135 | return x 136 | 137 | 138 | class UViT(nn.Module): 139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1, 141 | use_checkpoint=False, conv=True, skip=True): 142 | super().__init__() 143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 144 | self.num_classes = num_classes 145 | self.in_chans = in_chans 146 | 147 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 148 | num_patches = (img_size // patch_size) ** 2 149 | 150 | self.time_embed = nn.Sequential( 151 | nn.Linear(embed_dim, 4 * embed_dim), 152 | nn.SiLU(), 153 | nn.Linear(4 * embed_dim, embed_dim), 154 | ) if mlp_time_embed else nn.Identity() 155 | 156 | if self.num_classes > 0: 157 | self.label_emb = nn.Embedding(self.num_classes, embed_dim) 158 | self.extras = 2 159 | else: 160 | self.extras = 1 161 | 162 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim)) 163 | 164 | self.in_blocks = nn.ModuleList([ 165 | Block( 166 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 167 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 168 | for _ in range(depth // 2)]) 169 | 170 | self.mid_block = Block( 171 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 172 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 173 | 174 | self.out_blocks = nn.ModuleList([ 175 | Block( 176 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 177 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint) 178 | for _ in range(depth // 2)]) 179 | 180 | self.norm = norm_layer(embed_dim) 181 | self.patch_dim = patch_size ** 2 * in_chans 182 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True) 183 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity() 184 | 185 | trunc_normal_(self.pos_embed, std=.02) 186 | self.apply(self._init_weights) 187 | 188 | def _init_weights(self, m): 189 | if isinstance(m, nn.Linear): 190 | trunc_normal_(m.weight, std=.02) 191 | if isinstance(m, nn.Linear) and m.bias is not None: 192 | nn.init.constant_(m.bias, 0) 193 | elif isinstance(m, nn.LayerNorm): 194 | nn.init.constant_(m.bias, 0) 195 | nn.init.constant_(m.weight, 1.0) 196 | 197 | @torch.jit.ignore 198 | def no_weight_decay(self): 199 | return {'pos_embed'} 200 | 201 | def forward(self, x, timesteps, y=None): 202 | x = self.patch_embed(x) 203 | B, L, D = x.shape 204 | 205 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) 206 | time_token = time_token.unsqueeze(dim=1) 207 | x = torch.cat((time_token, x), dim=1) 208 | if y is not None: 209 | label_emb = self.label_emb(y) 210 | label_emb = label_emb.unsqueeze(dim=1) 211 | x = 
torch.cat((label_emb, x), dim=1) 212 | x = x + self.pos_embed 213 | 214 | skips = [] 215 | for blk in self.in_blocks: 216 | x = blk(x) 217 | skips.append(x) 218 | 219 | x = self.mid_block(x) 220 | 221 | for blk in self.out_blocks: 222 | x = blk(x, skips.pop()) 223 | 224 | x = self.norm(x) 225 | x = self.decoder_pred(x) 226 | assert x.size(1) == self.extras + L 227 | x = x[:, self.extras:, :] 228 | x = unpatchify(x, self.in_chans) 229 | x = self.final_layer(x) 230 | return x 231 | -------------------------------------------------------------------------------- /sde.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from absl import logging 4 | import numpy as np 5 | import math 6 | from tqdm import tqdm 7 | 8 | 9 | def get_sde(name, **kwargs): 10 | if name == 'vpsde': 11 | return VPSDE(**kwargs) 12 | elif name == 'vpsde_cosine': 13 | return VPSDECosine(**kwargs) 14 | else: 15 | raise NotImplementedError 16 | 17 | 18 | def stp(s, ts: torch.Tensor): # scalar tensor product 19 | if isinstance(s, np.ndarray): 20 | s = torch.from_numpy(s).type_as(ts) 21 | extra_dims = (1,) * (ts.dim() - 1) 22 | return s.view(-1, *extra_dims) * ts 23 | 24 | 25 | def mos(a, start_dim=1): # mean of square 26 | return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1) 27 | 28 | 29 | def duplicate(tensor, *size): 30 | return tensor.unsqueeze(dim=0).expand(*size, *tensor.shape) 31 | 32 | 33 | class SDE(object): 34 | r""" 35 | dx = f(x, t)dt + g(t) dw with 0 <= t <= 1 36 | f(x, t) is the drift 37 | g(t) is the diffusion 38 | """ 39 | def drift(self, x, t): 40 | raise NotImplementedError 41 | 42 | def diffusion(self, t): 43 | raise NotImplementedError 44 | 45 | def cum_beta(self, t): # the variance of xt|x0 46 | raise NotImplementedError 47 | 48 | def cum_alpha(self, t): 49 | raise NotImplementedError 50 | 51 | def snr(self, t): # signal noise ratio 52 | raise NotImplementedError 53 | 54 | def nsr(self, t): # noise signal ratio 55 | raise NotImplementedError 56 | 57 | def marginal_prob(self, x0, t): # the mean and std of q(xt|x0) 58 | alpha = self.cum_alpha(t) 59 | beta = self.cum_beta(t) 60 | mean = stp(alpha ** 0.5, x0) # E[xt|x0] 61 | std = beta ** 0.5 # Cov[xt|x0] ** 0.5 62 | return mean, std 63 | 64 | def sample(self, x0, t_init=0): # sample from q(xn|x0), where n is uniform 65 | t = torch.rand(x0.shape[0], device=x0.device) * (1. - t_init) + t_init 66 | mean, std = self.marginal_prob(x0, t) 67 | eps = torch.randn_like(x0) 68 | xt = mean + stp(std, eps) 69 | return t, eps, xt 70 | 71 | 72 | class VPSDE(SDE): 73 | def __init__(self, beta_min=0.1, beta_max=20): 74 | # 0 <= t <= 1 75 | self.beta_0 = beta_min 76 | self.beta_1 = beta_max 77 | 78 | def drift(self, x, t): 79 | return -0.5 * stp(self.squared_diffusion(t), x) 80 | 81 | def diffusion(self, t): 82 | return self.squared_diffusion(t) ** 0.5 83 | 84 | def squared_diffusion(self, t): # beta(t) 85 | return self.beta_0 + t * (self.beta_1 - self.beta_0) 86 | 87 | def squared_diffusion_integral(self, s, t): # \int_s^t beta(tau) d tau 88 | return self.beta_0 * (t - s) + (self.beta_1 - self.beta_0) * (t ** 2 - s ** 2) * 0.5 89 | 90 | def skip_beta(self, s, t): # beta_{t|s}, Cov[xt|xs]=beta_{t|s} I 91 | return 1. 
- self.skip_alpha(s, t) 92 | 93 | def skip_alpha(self, s, t): # alpha_{t|s}, E[xt|xs]=alpha_{t|s}**0.5 xs 94 | x = -self.squared_diffusion_integral(s, t) 95 | return x.exp() 96 | 97 | def cum_beta(self, t): 98 | return self.skip_beta(0, t) 99 | 100 | def cum_alpha(self, t): 101 | return self.skip_alpha(0, t) 102 | 103 | def nsr(self, t): 104 | return self.squared_diffusion_integral(0, t).expm1() 105 | 106 | def snr(self, t): 107 | return 1. / self.nsr(t) 108 | 109 | def __str__(self): 110 | return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}' 111 | 112 | def __repr__(self): 113 | return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}' 114 | 115 | 116 | class VPSDECosine(SDE): 117 | r""" 118 | dx = f(x, t)dt + g(t) dw with 0 <= t <= 1 119 | f(x, t) is the drift 120 | g(t) is the diffusion 121 | """ 122 | def __init__(self, s=0.008): 123 | self.s = s 124 | self.F = lambda t: torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2 125 | self.F0 = math.cos(s / (1 + s) * math.pi / 2) ** 2 126 | 127 | def drift(self, x, t): 128 | ft = - torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi / 2 129 | return stp(ft, x) 130 | 131 | def diffusion(self, t): 132 | return (torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi) ** 0.5 133 | 134 | def cum_beta(self, t): # the variance of xt|x0 135 | return 1 - self.cum_alpha(t) 136 | 137 | def cum_alpha(self, t): 138 | return self.F(t) / self.F0 139 | 140 | def snr(self, t): # signal noise ratio 141 | Ft = self.F(t) 142 | return Ft / (self.F0 - Ft) 143 | 144 | def nsr(self, t): # noise signal ratio 145 | Ft = self.F(t) 146 | return self.F0 / Ft - 1 147 | 148 | def __str__(self): 149 | return 'vpsde_cosine' 150 | 151 | def __repr__(self): 152 | return 'vpsde_cosine' 153 | 154 | 155 | class ScoreModel(object): 156 | r""" 157 | The forward process is q(x_[0,T]) 158 | """ 159 | 160 | def __init__(self, nnet: nn.Module, pred: str, sde: SDE, T=1): 161 | assert T == 1 162 | self.nnet = nnet 163 | self.pred = pred 164 | self.sde = sde 165 | self.T = T 166 | print(f'ScoreModel with pred={pred}, sde={sde}, T={T}') 167 | 168 | def predict(self, xt, t, **kwargs): 169 | if not isinstance(t, torch.Tensor): 170 | t = torch.tensor(t) 171 | t = t.to(xt.device) 172 | if t.dim() == 0: 173 | t = duplicate(t, xt.size(0)) 174 | return self.nnet(xt, t * 999, **kwargs) # follow SDE 175 | 176 | def noise_pred(self, xt, t, **kwargs): 177 | pred = self.predict(xt, t, **kwargs) 178 | if self.pred == 'noise_pred': 179 | noise_pred = pred 180 | elif self.pred == 'x0_pred': 181 | noise_pred = - stp(self.sde.snr(t).sqrt(), pred) + stp(self.sde.cum_beta(t).rsqrt(), xt) 182 | else: 183 | raise NotImplementedError 184 | return noise_pred 185 | 186 | def x0_pred(self, xt, t, **kwargs): 187 | pred = self.predict(xt, t, **kwargs) 188 | if self.pred == 'noise_pred': 189 | x0_pred = stp(self.sde.cum_alpha(t).rsqrt(), xt) - stp(self.sde.nsr(t).sqrt(), pred) 190 | elif self.pred == 'x0_pred': 191 | x0_pred = pred 192 | else: 193 | raise NotImplementedError 194 | return x0_pred 195 | 196 | def score(self, xt, t, **kwargs): 197 | cum_beta = self.sde.cum_beta(t) 198 | noise_pred = self.noise_pred(xt, t, **kwargs) 199 | return stp(-cum_beta.rsqrt(), noise_pred) 200 | 201 | 202 | class ReverseSDE(object): 203 | r""" 204 | dx = [f(x, t) - g(t)^2 s(x, t)] dt + g(t) dw 205 | """ 206 | def __init__(self, score_model): 207 | self.sde = score_model.sde # the forward sde 208 | self.score_model = score_model 209 | 210 | def drift(self, x, t, **kwargs): 211 | 
drift = self.sde.drift(x, t) # f(x, t) 212 | diffusion = self.sde.diffusion(t) # g(t) 213 | score = self.score_model.score(x, t, **kwargs) 214 | return drift - stp(diffusion ** 2, score) 215 | 216 | def diffusion(self, t): 217 | return self.sde.diffusion(t) 218 | 219 | 220 | class ODE(object): 221 | r""" 222 | dx = [f(x, t) - g(t)^2 s(x, t)] dt 223 | """ 224 | 225 | def __init__(self, score_model): 226 | self.sde = score_model.sde # the forward sde 227 | self.score_model = score_model 228 | 229 | def drift(self, x, t, **kwargs): 230 | drift = self.sde.drift(x, t) # f(x, t) 231 | diffusion = self.sde.diffusion(t) # g(t) 232 | score = self.score_model.score(x, t, **kwargs) 233 | return drift - 0.5 * stp(diffusion ** 2, score) 234 | 235 | def diffusion(self, t): 236 | return 0 237 | 238 | 239 | def dct2str(dct): 240 | return str({k: f'{v:.6g}' for k, v in dct.items()}) 241 | 242 | 243 | @ torch.no_grad() 244 | def euler_maruyama(rsde, x_init, sample_steps, eps=1e-3, T=1, trace=None, verbose=False, **kwargs): 245 | r""" 246 | The Euler Maruyama sampler for reverse SDE / ODE 247 | See `Score-Based Generative Modeling through Stochastic Differential Equations` 248 | """ 249 | assert isinstance(rsde, ReverseSDE) or isinstance(rsde, ODE) 250 | print(f"euler_maruyama with sample_steps={sample_steps}") 251 | timesteps = np.append(0., np.linspace(eps, T, sample_steps)) 252 | timesteps = torch.tensor(timesteps).to(x_init) 253 | x = x_init 254 | if trace is not None: 255 | trace.append(x) 256 | for s, t in tqdm(list(zip(timesteps, timesteps[1:]))[::-1], disable=not verbose, desc='euler_maruyama'): 257 | drift = rsde.drift(x, t, **kwargs) 258 | diffusion = rsde.diffusion(t) 259 | dt = s - t 260 | mean = x + drift * dt 261 | sigma = diffusion * (-dt).sqrt() 262 | x = mean + stp(sigma, torch.randn_like(x)) if s != 0 else mean 263 | if trace is not None: 264 | trace.append(x) 265 | statistics = dict(s=s, t=t, sigma=sigma.item()) 266 | logging.debug(dct2str(statistics)) 267 | return x 268 | 269 | 270 | def LSimple(score_model: ScoreModel, x0, pred='noise_pred', **kwargs): 271 | t, noise, xt = score_model.sde.sample(x0) 272 | if pred == 'noise_pred': 273 | noise_pred = score_model.noise_pred(xt, t, **kwargs) 274 | return mos(noise - noise_pred) 275 | elif pred == 'x0_pred': 276 | x0_pred = score_model.x0_pred(xt, t, **kwargs) 277 | return mos(x0 - x0_pred) 278 | else: 279 | raise NotImplementedError(pred) 280 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import math 10 | import torch 11 | import torch.distributed as dist 12 | 13 | from logging import getLogger 14 | 15 | logger = getLogger() 16 | 17 | 18 | def gpu_timer(closure, log_timings=True): 19 | """ Helper to time gpu-time to execute closure() """ 20 | elapsed_time = -1. 
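    # -1. is a sentinel returned when log_timings is False; the CUDA events
    # below measure device-side time, hence the synchronize() before reading
    # elapsed_time().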
21 | if log_timings: 22 | start = torch.cuda.Event(enable_timing=True) 23 | end = torch.cuda.Event(enable_timing=True) 24 | start.record() 25 | 26 | result = closure() 27 | 28 | if log_timings: 29 | end.record() 30 | torch.cuda.synchronize() 31 | elapsed_time = start.elapsed_time(end) 32 | 33 | return result, elapsed_time 34 | 35 | 36 | def init_distributed(port=40111, rank_and_world_size=(None, None)): 37 | 38 | if dist.is_available() and dist.is_initialized(): 39 | return dist.get_world_size(), dist.get_rank() 40 | 41 | rank, world_size = rank_and_world_size 42 | os.environ['MASTER_ADDR'] = 'localhost' 43 | 44 | if (rank is None) or (world_size is None): 45 | try: 46 | world_size = int(os.environ['SLURM_NTASKS']) 47 | rank = int(os.environ['SLURM_PROCID']) 48 | os.environ['MASTER_ADDR'] = os.environ['HOSTNAME'] 49 | except Exception: 50 | logger.info('SLURM vars not set (distributed training not available)') 51 | world_size, rank = 1, 0 52 | return world_size, rank 53 | 54 | try: 55 | os.environ['MASTER_PORT'] = str(port) 56 | torch.distributed.init_process_group( 57 | backend='nccl', 58 | world_size=world_size, 59 | rank=rank) 60 | except Exception: 61 | world_size, rank = 1, 0 62 | logger.info('distributed training not available') 63 | 64 | return world_size, rank 65 | 66 | 67 | class WarmupCosineSchedule(object): 68 | 69 | def __init__( 70 | self, 71 | optimizer, 72 | warmup_steps, 73 | start_lr, 74 | ref_lr, 75 | T_max, 76 | last_epoch=-1, 77 | final_lr=0. 78 | ): 79 | self.optimizer = optimizer 80 | self.start_lr = start_lr 81 | self.ref_lr = ref_lr 82 | self.final_lr = final_lr 83 | self.warmup_steps = warmup_steps 84 | self.T_max = T_max - warmup_steps 85 | self._step = 0. 86 | 87 | def step(self): 88 | self._step += 1 89 | if self._step < self.warmup_steps: 90 | progress = float(self._step) / float(max(1, self.warmup_steps)) 91 | new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr) 92 | else: 93 | # -- progress after warmup 94 | progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max)) 95 | new_lr = max(self.final_lr, 96 | self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1. + math.cos(math.pi * progress))) 97 | 98 | for group in self.optimizer.param_groups: 99 | group['lr'] = new_lr 100 | 101 | return new_lr 102 | 103 | 104 | class CosineWDSchedule(object): 105 | 106 | def __init__( 107 | self, 108 | optimizer, 109 | ref_wd, 110 | T_max, 111 | final_wd=0. 112 | ): 113 | self.optimizer = optimizer 114 | self.ref_wd = ref_wd 115 | self.final_wd = final_wd 116 | self.T_max = T_max 117 | self._step = 0. 118 | 119 | def step(self): 120 | self._step += 1 121 | progress = self._step / self.T_max 122 | new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1. 
+ math.cos(math.pi * progress)) 123 | 124 | if self.final_wd <= self.ref_wd: 125 | new_wd = max(self.final_wd, new_wd) 126 | else: 127 | new_wd = min(self.final_wd, new_wd) 128 | 129 | for group in self.optimizer.param_groups: 130 | if ('WD_exclude' not in group) or not group['WD_exclude']: 131 | group['weight_decay'] = new_wd 132 | return new_wd 133 | 134 | 135 | class CSVLogger(object): 136 | 137 | def __init__(self, fname, *argv): 138 | self.fname = fname 139 | self.types = [] 140 | # -- print headers 141 | with open(self.fname, '+a') as f: 142 | for i, v in enumerate(argv, 1): 143 | self.types.append(v[0]) 144 | if i < len(argv): 145 | print(v[1], end=',', file=f) 146 | else: 147 | print(v[1], end='\n', file=f) 148 | 149 | def log(self, *argv): 150 | with open(self.fname, '+a') as f: 151 | for i, tv in enumerate(zip(self.types, argv), 1): 152 | end = ',' if i < len(argv) else '\n' 153 | print(tv[0] % tv[1], end=end, file=f) 154 | 155 | 156 | class AverageMeter(object): 157 | """computes and stores the average and current value""" 158 | 159 | def __init__(self): 160 | self.reset() 161 | 162 | def reset(self): 163 | self.val = 0 164 | self.avg = 0 165 | self.max = float('-inf') 166 | self.min = float('inf') 167 | self.sum = 0 168 | self.count = 0 169 | 170 | def update(self, val, n=1): 171 | self.val = val 172 | self.max = max(val, self.max) 173 | self.min = min(val, self.min) 174 | self.sum += val * n 175 | self.count += n 176 | self.avg = self.sum / self.count 177 | 178 | 179 | class AllGather(torch.autograd.Function): 180 | 181 | @staticmethod 182 | def forward(ctx, x): 183 | if ( 184 | dist.is_available() 185 | and dist.is_initialized() 186 | and (dist.get_world_size() > 1) 187 | ): 188 | outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 189 | dist.all_gather(outputs, x) 190 | return torch.cat(outputs, 0) 191 | return x 192 | 193 | @staticmethod 194 | def backward(ctx, grads): 195 | if ( 196 | dist.is_available() 197 | and dist.is_initialized() 198 | and (dist.get_world_size() > 1) 199 | ): 200 | s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank() 201 | e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1) 202 | grads = grads.contiguous() 203 | dist.all_reduce(grads) 204 | return grads[s:e] 205 | return grads 206 | 207 | 208 | class AllReduceSum(torch.autograd.Function): 209 | 210 | @staticmethod 211 | def forward(ctx, x): 212 | if ( 213 | dist.is_available() 214 | and dist.is_initialized() 215 | and (dist.get_world_size() > 1) 216 | ): 217 | x = x.contiguous() 218 | dist.all_reduce(x) 219 | return x 220 | 221 | @staticmethod 222 | def backward(ctx, grads): 223 | return grads 224 | 225 | 226 | class AllReduce(torch.autograd.Function): 227 | 228 | @staticmethod 229 | def forward(ctx, x): 230 | if ( 231 | dist.is_available() 232 | and dist.is_initialized() 233 | and (dist.get_world_size() > 1) 234 | ): 235 | x = x.contiguous() / dist.get_world_size() 236 | dist.all_reduce(x) 237 | return x 238 | 239 | @staticmethod 240 | def backward(ctx, grads): 241 | return grads 242 | 243 | 244 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 245 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 246 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 247 | def norm_cdf(x): 248 | # Computes standard normal cumulative distribution function 249 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 
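    # E.g. with the defaults (mean=0., std=1., a=-2., b=2.): l = Phi(-2) ~ 0.0228
    # and u = Phi(2) ~ 0.9772, so the uniform draw below covers the middle
    # ~95.4% of the normal distribution before the inverse-CDF transform.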
250 | 251 | with torch.no_grad(): 252 | # Values are generated by using a truncated uniform distribution and 253 | # then using the inverse CDF for the normal distribution. 254 | # Get upper and lower cdf values 255 | l = norm_cdf((a - mean) / std) 256 | u = norm_cdf((b - mean) / std) 257 | 258 | # Uniformly fill tensor with values from [l, u], then translate to 259 | # [2l-1, 2u-1]. 260 | tensor.uniform_(2 * l - 1, 2 * u - 1) 261 | 262 | # Use inverse cdf transform for normal distribution to get truncated 263 | # standard normal 264 | tensor.erfinv_() 265 | 266 | # Transform to proper mean, std 267 | tensor.mul_(std * math.sqrt(2.)) 268 | tensor.add_(mean) 269 | 270 | # Clamp to ensure it's in the proper range 271 | tensor.clamp_(min=a, max=b) 272 | return tensor 273 | 274 | 275 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 276 | # type: (Tensor, float, float, float, float) -> Tensor 277 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 278 | 279 | 280 | def grad_logger(named_params): 281 | stats = AverageMeter() 282 | stats.first_layer = None 283 | stats.last_layer = None 284 | for n, p in named_params: 285 | if (p.grad is not None) and not (n.endswith('.bias') or len(p.shape) == 1): 286 | grad_norm = float(torch.norm(p.grad.data)) 287 | stats.update(grad_norm) 288 | if 'qkv' in n: 289 | stats.last_layer = grad_norm 290 | if stats.first_layer is None: 291 | stats.first_layer = grad_norm 292 | if stats.first_layer is None or stats.last_layer is None: 293 | stats.first_layer = stats.last_layer = 0. 294 | return stats 295 | -------------------------------------------------------------------------------- /cifar10_experiment/edm/training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Streaming images and labels from datasets created with dataset_tool.py.""" 9 | 10 | import os 11 | import numpy as np 12 | import zipfile 13 | import PIL.Image 14 | import json 15 | import torch 16 | import dnnlib 17 | 18 | try: 19 | import pyspng 20 | except ImportError: 21 | pyspng = None 22 | 23 | #---------------------------------------------------------------------------- 24 | # Abstract base class for datasets. 25 | 26 | class Dataset(torch.utils.data.Dataset): 27 | def __init__(self, 28 | name, # Name of the dataset. 29 | raw_shape, # Shape of the raw image data (NCHW). 30 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 31 | use_labels = False, # Enable conditioning labels? False = label dimension is zero. 32 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. 33 | random_seed = 0, # Random seed to use when applying max_size. 34 | cache = False, # Cache images in CPU memory? 35 | ): 36 | self._name = name 37 | self._raw_shape = list(raw_shape) 38 | self._use_labels = use_labels 39 | self._cache = cache 40 | self._cached_images = dict() # {raw_idx: np.ndarray, ...} 41 | self._raw_labels = None 42 | self._label_shape = None 43 | 44 | # Apply max_size. 
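        # Note: the subset selection below is deterministic for a fixed
        # random_seed: indices are shuffled, truncated to max_size, then
        # re-sorted so the retained images keep their original on-disk order.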
45 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) 46 | if (max_size is not None) and (self._raw_idx.size > max_size): 47 | np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx) 48 | self._raw_idx = np.sort(self._raw_idx[:max_size]) 49 | 50 | # Apply xflip. 51 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 52 | if xflip: 53 | self._raw_idx = np.tile(self._raw_idx, 2) 54 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 55 | 56 | def _get_raw_labels(self): 57 | if self._raw_labels is None: 58 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 59 | if self._raw_labels is None: 60 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 61 | assert isinstance(self._raw_labels, np.ndarray) 62 | assert self._raw_labels.shape[0] == self._raw_shape[0] 63 | assert self._raw_labels.dtype in [np.float32, np.int64] 64 | if self._raw_labels.dtype == np.int64: 65 | assert self._raw_labels.ndim == 1 66 | assert np.all(self._raw_labels >= 0) 67 | return self._raw_labels 68 | 69 | def close(self): # to be overridden by subclass 70 | pass 71 | 72 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 73 | raise NotImplementedError 74 | 75 | def _load_raw_labels(self): # to be overridden by subclass 76 | raise NotImplementedError 77 | 78 | def __getstate__(self): 79 | return dict(self.__dict__, _raw_labels=None) 80 | 81 | def __del__(self): 82 | try: 83 | self.close() 84 | except: 85 | pass 86 | 87 | def __len__(self): 88 | return self._raw_idx.size 89 | 90 | def __getitem__(self, idx): 91 | raw_idx = self._raw_idx[idx] 92 | image = self._cached_images.get(raw_idx, None) 93 | if image is None: 94 | image = self._load_raw_image(raw_idx) 95 | if self._cache: 96 | self._cached_images[raw_idx] = image 97 | assert isinstance(image, np.ndarray) 98 | assert list(image.shape) == self.image_shape 99 | assert image.dtype == np.uint8 100 | if self._xflip[idx]: 101 | assert image.ndim == 3 # CHW 102 | image = image[:, :, ::-1] 103 | return image.copy(), self.get_label(idx) 104 | 105 | def get_label(self, idx): 106 | label = self._get_raw_labels()[self._raw_idx[idx]] 107 | if label.dtype == np.int64: 108 | onehot = np.zeros(self.label_shape, dtype=np.float32) 109 | onehot[label] = 1 110 | label = onehot 111 | return label.copy() 112 | 113 | def get_details(self, idx): 114 | d = dnnlib.EasyDict() 115 | d.raw_idx = int(self._raw_idx[idx]) 116 | d.xflip = (int(self._xflip[idx]) != 0) 117 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 118 | return d 119 | 120 | @property 121 | def name(self): 122 | return self._name 123 | 124 | @property 125 | def image_shape(self): 126 | return list(self._raw_shape[1:]) 127 | 128 | @property 129 | def num_channels(self): 130 | assert len(self.image_shape) == 3 # CHW 131 | return self.image_shape[0] 132 | 133 | @property 134 | def resolution(self): 135 | assert len(self.image_shape) == 3 # CHW 136 | assert self.image_shape[1] == self.image_shape[2] 137 | return self.image_shape[1] 138 | 139 | @property 140 | def label_shape(self): 141 | if self._label_shape is None: 142 | raw_labels = self._get_raw_labels() 143 | if raw_labels.dtype == np.int64: 144 | self._label_shape = [int(np.max(raw_labels)) + 1] 145 | else: 146 | self._label_shape = raw_labels.shape[1:] 147 | return list(self._label_shape) 148 | 149 | @property 150 | def label_dim(self): 151 | assert len(self.label_shape) == 1 152 | return self.label_shape[0] 153 | 154 | @property 155 | def 
has_labels(self): 156 | return any(x != 0 for x in self.label_shape) 157 | 158 | @property 159 | def has_onehot_labels(self): 160 | return self._get_raw_labels().dtype == np.int64 161 | 162 | #---------------------------------------------------------------------------- 163 | # Dataset subclass that loads images recursively from the specified directory 164 | # or ZIP file. 165 | 166 | class ImageFolderDataset(Dataset): 167 | def __init__(self, 168 | path, # Path to directory or zip. 169 | resolution = None, # Ensure specific resolution, None = highest available. 170 | use_pyspng = True, # Use pyspng if available? 171 | **super_kwargs, # Additional arguments for the Dataset base class. 172 | ): 173 | self._path = path 174 | self._use_pyspng = use_pyspng 175 | self._zipfile = None 176 | 177 | if os.path.isdir(self._path): 178 | self._type = 'dir' 179 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 180 | elif self._file_ext(self._path) == '.zip': 181 | self._type = 'zip' 182 | self._all_fnames = set(self._get_zipfile().namelist()) 183 | else: 184 | raise IOError('Path must point to a directory or zip') 185 | 186 | PIL.Image.init() 187 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 188 | if len(self._image_fnames) == 0: 189 | raise IOError('No image files found in the specified path') 190 | 191 | name = os.path.splitext(os.path.basename(self._path))[0] 192 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 193 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): 194 | raise IOError('Image files do not match the specified resolution') 195 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 196 | 197 | @staticmethod 198 | def _file_ext(fname): 199 | return os.path.splitext(fname)[1].lower() 200 | 201 | def _get_zipfile(self): 202 | assert self._type == 'zip' 203 | if self._zipfile is None: 204 | self._zipfile = zipfile.ZipFile(self._path) 205 | return self._zipfile 206 | 207 | def _open_file(self, fname): 208 | if self._type == 'dir': 209 | return open(os.path.join(self._path, fname), 'rb') 210 | if self._type == 'zip': 211 | return self._get_zipfile().open(fname, 'r') 212 | return None 213 | 214 | def close(self): 215 | try: 216 | if self._zipfile is not None: 217 | self._zipfile.close() 218 | finally: 219 | self._zipfile = None 220 | 221 | def __getstate__(self): 222 | return dict(super().__getstate__(), _zipfile=None) 223 | 224 | def _load_raw_image(self, raw_idx): 225 | fname = self._image_fnames[raw_idx] 226 | with self._open_file(fname) as f: 227 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png': 228 | image = pyspng.load(f.read()) 229 | else: 230 | image = np.array(PIL.Image.open(f)) 231 | if image.ndim == 2: 232 | image = image[:, :, np.newaxis] # HW => HWC 233 | image = image.transpose(2, 0, 1) # HWC => CHW 234 | return image 235 | 236 | def _load_raw_labels(self): 237 | fname = 'dataset.json' 238 | if fname not in self._all_fnames: 239 | return None 240 | with self._open_file(fname) as f: 241 | labels = json.load(f)['labels'] 242 | if labels is None: 243 | return None 244 | labels = dict(labels) 245 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 246 | labels = np.array(labels) 247 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 248 | return labels 249 | 250 | 
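# A minimal usage sketch (the zip path below is hypothetical, not part of this repo):
#
#   dataset = ImageFolderDataset(path='datasets/cifar10.zip', use_labels=True, xflip=True)
#   image, label = dataset[0]   # uint8 CHW array and float32 one-hot label
#   print(len(dataset), dataset.resolution, dataset.label_dim)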
#---------------------------------------------------------------------------- 251 | -------------------------------------------------------------------------------- /cifar10_experiment/TorchSSL/models/nets/resnet50.py: -------------------------------------------------------------------------------- 1 | """ 2 | from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 3 | """ 4 | import torch 5 | from torch import Tensor 6 | import torch.nn as nn 7 | from typing import Type, Any, Callable, Union, List, Optional 8 | 9 | 10 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=dilation, groups=groups, bias=False, dilation=dilation) 14 | 15 | 16 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 17 | """1x1 convolution""" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 19 | 20 | 21 | class BasicBlock(nn.Module): 22 | expansion: int = 1 23 | 24 | def __init__( 25 | self, 26 | inplanes: int, 27 | planes: int, 28 | stride: int = 1, 29 | downsample: Optional[nn.Module] = None, 30 | groups: int = 1, 31 | base_width: int = 64, 32 | dilation: int = 1, 33 | norm_layer: Optional[Callable[..., nn.Module]] = None 34 | ) -> None: 35 | super(BasicBlock, self).__init__() 36 | if norm_layer is None: 37 | norm_layer = nn.BatchNorm2d 38 | if groups != 1 or base_width != 64: 39 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 40 | if dilation > 1: 41 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 42 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 43 | self.conv1 = conv3x3(inplanes, planes, stride) 44 | self.bn1 = norm_layer(planes) 45 | self.relu = nn.ReLU(inplace=True) 46 | self.conv2 = conv3x3(planes, planes) 47 | self.bn2 = norm_layer(planes) 48 | self.downsample = downsample 49 | self.stride = stride 50 | 51 | def forward(self, x: Tensor) -> Tensor: 52 | identity = x 53 | 54 | out = self.conv1(x) 55 | out = self.bn1(out) 56 | out = self.relu(out) 57 | 58 | out = self.conv2(out) 59 | out = self.bn2(out) 60 | 61 | if self.downsample is not None: 62 | identity = self.downsample(x) 63 | 64 | out += identity 65 | out = self.relu(out) 66 | 67 | return out 68 | 69 | 70 | class Bottleneck(nn.Module): 71 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 72 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 73 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 74 | # This variant is also known as ResNet V1.5 and improves accuracy according to 75 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
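    # Structure: conv1 (1x1) reduces channels to `width`, conv2 (3x3) carries the
    # stride (the V1.5 placement described above), and conv3 (1x1) expands the
    # channels by `expansion` (4x) before the residual addition.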
76 | 77 | expansion: int = 4 78 | 79 | def __init__( 80 | self, 81 | inplanes: int, 82 | planes: int, 83 | stride: int = 1, 84 | downsample: Optional[nn.Module] = None, 85 | groups: int = 1, 86 | base_width: int = 64, 87 | dilation: int = 1, 88 | norm_layer: Optional[Callable[..., nn.Module]] = None 89 | ) -> None: 90 | super(Bottleneck, self).__init__() 91 | if norm_layer is None: 92 | norm_layer = nn.BatchNorm2d 93 | width = int(planes * (base_width / 64.)) * groups 94 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 95 | self.conv1 = conv1x1(inplanes, width) 96 | self.bn1 = norm_layer(width) 97 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 98 | self.bn2 = norm_layer(width) 99 | self.conv3 = conv1x1(width, planes * self.expansion) 100 | self.bn3 = norm_layer(planes * self.expansion) 101 | self.relu = nn.ReLU(inplace=True) 102 | self.downsample = downsample 103 | self.stride = stride 104 | 105 | def forward(self, x: Tensor) -> Tensor: 106 | identity = x 107 | 108 | out = self.conv1(x) 109 | out = self.bn1(out) 110 | out = self.relu(out) 111 | 112 | out = self.conv2(out) 113 | out = self.bn2(out) 114 | out = self.relu(out) 115 | 116 | out = self.conv3(out) 117 | out = self.bn3(out) 118 | 119 | if self.downsample is not None: 120 | identity = self.downsample(x) 121 | 122 | out += identity 123 | out = self.relu(out) 124 | 125 | return out 126 | 127 | 128 | class ResNet50(nn.Module): 129 | 130 | def __init__( 131 | self, 132 | block: Type[Union[BasicBlock, Bottleneck]] = Bottleneck, 133 | layers: List[int] = [3, 4, 6, 3], 134 | n_class: int = 1000, 135 | zero_init_residual: bool = False, 136 | groups: int = 1, 137 | width_per_group: int = 64, 138 | replace_stride_with_dilation: Optional[List[bool]] = None, 139 | norm_layer: Optional[Callable[..., nn.Module]] = None, 140 | is_remix=False 141 | ) -> None: 142 | super(ResNet50, self).__init__() 143 | if norm_layer is None: 144 | norm_layer = nn.BatchNorm2d 145 | self._norm_layer = norm_layer 146 | 147 | self.inplanes = 64 148 | self.dilation = 1 149 | if replace_stride_with_dilation is None: 150 | # each element in the tuple indicates if we should replace 151 | # the 2x2 stride with a dilated convolution instead 152 | replace_stride_with_dilation = [False, False, False] 153 | if len(replace_stride_with_dilation) != 3: 154 | raise ValueError("replace_stride_with_dilation should be None " 155 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 156 | self.groups = groups 157 | self.base_width = width_per_group 158 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 159 | bias=False) 160 | self.bn1 = norm_layer(self.inplanes) 161 | self.relu = nn.ReLU(inplace=True) 162 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 163 | self.layer1 = self._make_layer(block, 64, layers[0]) 164 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 165 | dilate=replace_stride_with_dilation[0]) 166 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 167 | dilate=replace_stride_with_dilation[1]) 168 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 169 | dilate=replace_stride_with_dilation[2]) 170 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 171 | self.fc = nn.Linear(512 * block.expansion, n_class) 172 | 173 | # rot_classifier for Remix Match 174 | self.is_remix = is_remix 175 | if is_remix: 176 | self.rot_classifier = nn.Linear(2048, 4) 177 | 178 | for m in self.modules(): 179 | if isinstance(m, nn.Conv2d): 180 
| nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
181 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
182 | nn.init.constant_(m.weight, 1)
183 | nn.init.constant_(m.bias, 0)
184 |
185 | # Zero-initialize the last BN in each residual branch,
186 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
187 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
188 | if zero_init_residual:
189 | for m in self.modules():
190 | if isinstance(m, Bottleneck):
191 | nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
192 | elif isinstance(m, BasicBlock):
193 | nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
194 |
195 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
196 | stride: int = 1, dilate: bool = False) -> nn.Sequential:
197 | norm_layer = self._norm_layer
198 | downsample = None
199 | previous_dilation = self.dilation
200 | if dilate:
201 | self.dilation *= stride
202 | stride = 1
203 | if stride != 1 or self.inplanes != planes * block.expansion:
204 | downsample = nn.Sequential(
205 | conv1x1(self.inplanes, planes * block.expansion, stride),
206 | norm_layer(planes * block.expansion),
207 | )
208 |
209 | layers = []
210 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
211 | self.base_width, previous_dilation, norm_layer))
212 | self.inplanes = planes * block.expansion
213 | for _ in range(1, blocks):
214 | layers.append(block(self.inplanes, planes, groups=self.groups,
215 | base_width=self.base_width, dilation=self.dilation,
216 | norm_layer=norm_layer))
217 |
218 | return nn.Sequential(*layers)
219 |
220 | def _forward_impl(self, x):
221 | # See note [TorchScript super()]
222 | x = self.conv1(x)
223 | x = self.bn1(x)
224 | x = self.relu(x)
225 | x = self.maxpool(x)
226 |
227 | x = self.layer1(x)
228 | x = self.layer2(x)
229 | x = self.layer3(x)
230 | x = self.layer4(x)
231 |
232 | x = self.avgpool(x)
233 | x = torch.flatten(x, 1)
234 | out = self.fc(x)
235 | if self.is_remix:
236 | rot_output = self.rot_classifier(x)
237 | return out, rot_output
238 | else:
239 | return out
240 |
241 | def forward(self, x):
242 | return self._forward_impl(x)
243 |
244 |
245 | class build_ResNet50:
246 | def __init__(self, is_remix=False):
247 | self.is_remix = is_remix
248 |
249 | def build(self, num_classes):
250 | return ResNet50(n_class=num_classes, is_remix=self.is_remix)
251 |
252 |
253 | if __name__ == '__main__':
254 | a = torch.rand(16, 3, 224, 224)
255 | net = ResNet50(is_remix=True)
256 | x, y = net(a)
257 | print(x.shape)
258 | print(y.shape)
259 |
--------------------------------------------------------------------------------
/tools/fid_score.py:
--------------------------------------------------------------------------------
1 | """Calculates the Frechet Inception Distance (FID) to evaluate GANs
2 |
3 | The FID metric calculates the distance between two distributions of images.
4 | Typically, we have summary statistics (mean & covariance matrix) of one
5 | of these distributions, while the 2nd distribution is given by a GAN.
6 |
7 | When run as a stand-alone program, it compares the distribution of
8 | images that are stored as PNG/JPEG at a specified location with a
9 | distribution given by summary statistics (in pickle format).
10 |
11 | The FID is calculated by assuming that X_1 and X_2 are the activations of
12 | the pool_3 layer of the inception net for generated samples and real world
13 | samples respectively.
14 |
15 | See --help to see further details.
16 |
17 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18 | of TensorFlow
19 |
20 | Copyright 2018 Institute of Bioinformatics, JKU Linz
21 |
22 | Licensed under the Apache License, Version 2.0 (the "License");
23 | you may not use this file except in compliance with the License.
24 | You may obtain a copy of the License at
25 |
26 | http://www.apache.org/licenses/LICENSE-2.0
27 |
28 | Unless required by applicable law or agreed to in writing, software
29 | distributed under the License is distributed on an "AS IS" BASIS,
30 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31 | See the License for the specific language governing permissions and
32 | limitations under the License.
33 | """
34 | import os
35 | import pathlib
36 |
37 | import numpy as np
38 | import torch
39 | import torchvision.transforms as TF
40 | from PIL import Image
41 | from scipy import linalg
42 | from torch.nn.functional import adaptive_avg_pool2d
43 |
44 | try:
45 | from tqdm import tqdm
46 | except ImportError:
47 | # If tqdm is not available, provide a mock version of it
48 | def tqdm(x):
49 | return x
50 |
51 | from .inception import InceptionV3
52 |
53 |
54 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
55 | 'tif', 'tiff', 'webp'}
56 |
57 |
58 | class ImagePathDataset(torch.utils.data.Dataset):
59 | def __init__(self, files, transforms=None):
60 | self.files = files
61 | self.transforms = transforms
62 |
63 | def __len__(self):
64 | return len(self.files)
65 |
66 | def __getitem__(self, i):
67 | path = self.files[i]
68 | img = Image.open(path).convert('RGB')
69 | if self.transforms is not None:
70 | img = self.transforms(img)
71 | return img
72 |
73 |
74 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8):
75 | """Calculates the activations of the pool_3 layer for all images.
76 |
77 | Params:
78 | -- files : List of image files paths
79 | -- model : Instance of inception model
80 | -- batch_size : Batch size of images for the model to process at once.
81 | Make sure that the number of samples is a multiple of
82 | the batch size, otherwise some samples are ignored. This
83 | behavior is retained to match the original FID score
84 | implementation.
85 | -- dims : Dimensionality of features returned by Inception
86 | -- device : Device to run calculations
87 | -- num_workers : Number of parallel dataloader workers
88 |
89 | Returns:
90 | -- A numpy array of dimension (num images, dims) that contains the
91 | activations of the given tensor when feeding inception with the
92 | query tensor.
93 | """
94 | model.eval()
95 |
96 | if batch_size > len(files):
97 | print(('Warning: batch size is bigger than the data size.
' 98 | 'Setting batch size to data size'))
99 | batch_size = len(files)
100 |
101 | dataset = ImagePathDataset(files, transforms=TF.ToTensor())
102 | dataloader = torch.utils.data.DataLoader(dataset,
103 | batch_size=batch_size,
104 | shuffle=False,
105 | drop_last=False,
106 | num_workers=num_workers)
107 |
108 | pred_arr = np.empty((len(files), dims))
109 |
110 | start_idx = 0
111 |
112 | for batch in tqdm(dataloader):
113 | batch = batch.to(device)
114 |
115 | with torch.no_grad():
116 | pred = model(batch)[0]
117 |
118 | # If model output is not scalar, apply global spatial average pooling.
119 | # This happens if you choose a dimensionality not equal to 2048.
120 | if pred.size(2) != 1 or pred.size(3) != 1:
121 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
122 |
123 | pred = pred.squeeze(3).squeeze(2).cpu().numpy()
124 |
125 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred
126 |
127 | start_idx = start_idx + pred.shape[0]
128 |
129 | return pred_arr
130 |
131 |
132 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
133 | """Numpy implementation of the Frechet Distance.
134 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
135 | and X_2 ~ N(mu_2, C_2) is
136 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
137 |
138 | Stable version by Dougal J. Sutherland.
139 |
140 | Params:
141 | -- mu1 : Numpy array containing the activations of a layer of the
142 | inception net (like returned by the function 'get_predictions')
143 | for generated samples.
144 | -- mu2 : The sample mean over activations, precalculated on a
145 | representative data set.
146 | -- sigma1: The covariance matrix over activations for generated samples.
147 | -- sigma2: The covariance matrix over activations, precalculated on a
148 | representative data set.
149 |
150 | Returns:
151 | -- : The Frechet Distance.
152 | """
153 |
154 | mu1 = np.atleast_1d(mu1)
155 | mu2 = np.atleast_1d(mu2)
156 |
157 | sigma1 = np.atleast_2d(sigma1)
158 | sigma2 = np.atleast_2d(sigma2)
159 |
160 | assert mu1.shape == mu2.shape, \
161 | 'Training and test mean vectors have different lengths'
162 | assert sigma1.shape == sigma2.shape, \
163 | 'Training and test covariances have different dimensions'
164 |
165 | diff = mu1 - mu2
166 |
167 | # Product might be almost singular
168 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
169 | if not np.isfinite(covmean).all():
170 | msg = ('fid calculation produces singular product; '
171 | 'adding %s to diagonal of cov estimates') % eps
172 | print(msg)
173 | offset = np.eye(sigma1.shape[0]) * eps
174 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
175 |
176 | # Numerical error might give slight imaginary component
177 | if np.iscomplexobj(covmean):
178 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
179 | m = np.max(np.abs(covmean.imag))
180 | raise ValueError('Imaginary component {}'.format(m))
181 | covmean = covmean.real
182 |
183 | tr_covmean = np.trace(covmean)
184 |
185 | return (diff.dot(diff) + np.trace(sigma1)
186 | + np.trace(sigma2) - 2 * tr_covmean)
187 |
188 |
189 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
190 | device='cpu', num_workers=8):
191 | """Calculation of the statistics used by the FID.
192 | Params:
193 | -- files : List of image files paths
194 | -- model : Instance of inception model
195 | -- batch_size : The images numpy array is split into batches with
196 | batch size batch_size.
A reasonable batch size 197 | depends on the hardware. 198 | -- dims : Dimensionality of features returned by Inception 199 | -- device : Device to run calculations 200 | -- num_workers : Number of parallel dataloader workers 201 | 202 | Returns: 203 | -- mu : The mean over samples of the activations of the pool_3 layer of 204 | the inception model. 205 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 206 | the inception model. 207 | """ 208 | act = get_activations(files, model, batch_size, dims, device, num_workers) 209 | mu = np.mean(act, axis=0) 210 | sigma = np.cov(act, rowvar=False) 211 | return mu, sigma 212 | 213 | 214 | def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8): 215 | if path.endswith('.npz'): 216 | with np.load(path) as f: 217 | m, s = f['mu'][:], f['sigma'][:] 218 | else: 219 | path = pathlib.Path(path) 220 | files = sorted([file for ext in IMAGE_EXTENSIONS 221 | for file in path.glob('*.{}'.format(ext))]) 222 | m, s = calculate_activation_statistics(files, model, batch_size, 223 | dims, device, num_workers) 224 | 225 | return m, s 226 | 227 | 228 | def save_statistics_of_path(path, out_path, device=None, batch_size=50, dims=2048, num_workers=8): 229 | if device is None: 230 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 231 | else: 232 | device = torch.device(device) 233 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 234 | model = InceptionV3([block_idx]).to(device) 235 | m1, s1 = compute_statistics_of_path(path, model, batch_size, dims, device, num_workers) 236 | np.savez(out_path, mu=m1, sigma=s1) 237 | 238 | 239 | def calculate_fid_given_paths(paths, device=None, batch_size=50, dims=2048, num_workers=8): 240 | """Calculates the FID of two paths""" 241 | if device is None: 242 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 243 | else: 244 | device = torch.device(device) 245 | 246 | for p in paths: 247 | if not os.path.exists(p): 248 | raise RuntimeError('Invalid path: %s' % p) 249 | 250 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 251 | 252 | model = InceptionV3([block_idx]).to(device) 253 | 254 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, 255 | dims, device, num_workers) 256 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, 257 | dims, device, num_workers) 258 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 259 | 260 | return fid_value 261 | -------------------------------------------------------------------------------- /src/data_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import os 9 | import subprocess 10 | import time 11 | 12 | from logging import getLogger 13 | 14 | from PIL import ImageFilter 15 | 16 | import torch 17 | import torchvision.transforms as transforms 18 | import torchvision 19 | 20 | _GLOBAL_SEED = 0 21 | logger = getLogger() 22 | 23 | 24 | def init_data( 25 | transform, 26 | batch_size, 27 | pin_mem=True, 28 | num_workers=8, 29 | world_size=1, 30 | rank=0, 31 | root_path=None, 32 | image_folder=None, 33 | training=True, 34 | copy_data=False, 35 | drop_last=True, 36 | subset_file=None 37 | ): 38 | 39 | dataset = ImageNet( 40 | root=root_path, 41 | image_folder=image_folder, 42 | transform=transform, 43 | train=training, 44 | copy_data=copy_data) 45 | if subset_file is not None: 46 | dataset = ImageNetSubset(dataset, subset_file) 47 | logger.info('ImageNet dataset created') 48 | dist_sampler = torch.utils.data.distributed.DistributedSampler( 49 | dataset=dataset, 50 | num_replicas=world_size, 51 | rank=rank) 52 | data_loader = torch.utils.data.DataLoader( 53 | dataset, 54 | sampler=dist_sampler, 55 | batch_size=batch_size, 56 | drop_last=drop_last, 57 | pin_memory=pin_mem, 58 | num_workers=num_workers) 59 | logger.info('ImageNet unsupervised data loader created') 60 | 61 | return (data_loader, dist_sampler) 62 | 63 | 64 | def make_transforms( 65 | rand_size=224, 66 | focal_size=96, 67 | rand_crop_scale=(0.3, 1.0), 68 | focal_crop_scale=(0.05, 0.3), 69 | color_jitter=1.0, 70 | rand_views=2, 71 | focal_views=10, 72 | ): 73 | logger.info('making imagenet data transforms') 74 | 75 | def get_color_distortion(s=1.0): 76 | # s is the strength of color distortion. 77 | color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s) 78 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) 79 | rnd_gray = transforms.RandomGrayscale(p=0.2) 80 | color_distort = transforms.Compose([ 81 | rnd_color_jitter, 82 | rnd_gray]) 83 | return color_distort 84 | 85 | rand_transform = transforms.Compose([ 86 | transforms.RandomResizedCrop(rand_size, scale=rand_crop_scale), 87 | transforms.RandomHorizontalFlip(), 88 | get_color_distortion(s=color_jitter), 89 | GaussianBlur(p=0.5), 90 | transforms.ToTensor(), 91 | transforms.Normalize( 92 | (0.485, 0.456, 0.406), 93 | (0.229, 0.224, 0.225)) 94 | ]) 95 | focal_transform = transforms.Compose([ 96 | transforms.RandomResizedCrop(focal_size, scale=focal_crop_scale), 97 | transforms.RandomHorizontalFlip(), 98 | get_color_distortion(s=color_jitter), 99 | GaussianBlur(p=0.5), 100 | transforms.ToTensor(), 101 | transforms.Normalize( 102 | (0.485, 0.456, 0.406), 103 | (0.229, 0.224, 0.225)) 104 | ]) 105 | 106 | transform = MultiViewTransform( 107 | rand_transform=rand_transform, 108 | focal_transform=focal_transform, 109 | rand_views=rand_views, 110 | focal_views=focal_views 111 | ) 112 | return transform 113 | 114 | 115 | class MultiViewTransform(object): 116 | 117 | def __init__( 118 | self, 119 | rand_transform=None, 120 | focal_transform=None, 121 | rand_views=1, 122 | focal_views=1, 123 | ): 124 | self.rand_views = rand_views 125 | self.focal_views = focal_views 126 | self.rand_transform = rand_transform 127 | self.focal_transform = focal_transform 128 | 129 | def __call__(self, img): 130 | img_views = [] 131 | 132 | # -- generate random views 133 | if self.rand_views > 0: 134 | img_views += [self.rand_transform(img) for i in range(self.rand_views)] 135 | 136 | # -- generate focal views 137 | if self.focal_views > 0: 138 | img_views += [self.focal_transform(img) for i in 
range(self.focal_views)] 139 | 140 | return img_views 141 | 142 | 143 | class ImageNet(torchvision.datasets.ImageFolder): 144 | 145 | def __init__( 146 | self, 147 | root, 148 | image_folder='imagenet_full_size/061417/', 149 | tar_folder='imagenet_full_size/', 150 | tar_file='imagenet_full_size-061417.tar', 151 | transform=None, 152 | train=True, 153 | job_id=None, 154 | local_rank=None, 155 | copy_data=True 156 | ): 157 | """ 158 | ImageNet 159 | 160 | Dataset wrapper (can copy data locally to machine) 161 | 162 | :param root: root network directory for ImageNet data 163 | :param image_folder: path to images inside root network directory 164 | :param tar_file: zipped image_folder inside root network directory 165 | :param train: whether to load train data (or validation) 166 | :param job_id: scheduler job-id used to create dir on local machine 167 | :param copy_data: whether to copy data from network file locally 168 | """ 169 | 170 | suffix = 'train/' if train else 'val/' 171 | data_path = None 172 | if copy_data: 173 | logger.info('copying data locally') 174 | data_path = copy_imgnt_locally( 175 | root=root, 176 | suffix=suffix, 177 | image_folder=image_folder, 178 | tar_folder=tar_folder, 179 | tar_file=tar_file, 180 | job_id=job_id, 181 | local_rank=local_rank) 182 | if (not copy_data) or (data_path is None): 183 | data_path = os.path.join(root, image_folder, suffix) 184 | logger.info(f'data-path {data_path}') 185 | 186 | super(ImageNet, self).__init__(root=data_path, transform=transform) 187 | logger.info('Initialized ImageNet') 188 | 189 | def __getitem__(self, index): 190 | path, target = self.samples[index] 191 | sample = self.loader(path) 192 | if self.transform is not None: 193 | sample = self.transform(sample) 194 | if self.target_transform is not None: 195 | target = self.target_transform(target) 196 | return sample, target, path.split('/')[-1] 197 | 198 | 199 | class ImageNetSubset(object): 200 | 201 | def __init__(self, dataset, subset_file): 202 | """ 203 | ImageNetSubset 204 | 205 | :param dataset: ImageNet dataset object 206 | :param subset_file: '.txt' file containing IDs of IN1K images to keep 207 | """ 208 | self.dataset = dataset 209 | self.subset_file = subset_file 210 | self.filter_dataset_(subset_file) 211 | 212 | def filter_dataset_(self, subset_file): 213 | """ Filter self.dataset to a subset """ 214 | root = self.dataset.root 215 | class_to_idx = self.dataset.class_to_idx 216 | # -- update samples to subset of IN1k targets/samples 217 | new_samples = [] 218 | logger.info(f'Using {subset_file}') 219 | with open(subset_file, 'r') as rfile: 220 | for line in rfile: 221 | class_name = line.split('_')[0] 222 | target = class_to_idx[class_name] 223 | img = line.split('\n')[0] 224 | new_samples.append( 225 | (os.path.join(root, class_name, img), target) 226 | ) 227 | self.samples = new_samples 228 | 229 | @property 230 | def classes(self): 231 | return self.dataset.classes 232 | 233 | def __len__(self): 234 | return len(self.samples) 235 | 236 | def __getitem__(self, index): 237 | path, target = self.samples[index] 238 | img = self.dataset.loader(path) 239 | if self.dataset.transform is not None: 240 | img = self.dataset.transform(img) 241 | if self.dataset.target_transform is not None: 242 | target = self.dataset.target_transform(target) 243 | return img, target, path.split('/')[-1] 244 | 245 | 246 | class GaussianBlur(object): 247 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): 248 | self.prob = p 249 | self.radius_min = radius_min 250 | self.radius_max = 
radius_max
251 |
252 | def __call__(self, img):
253 | if torch.bernoulli(torch.tensor(self.prob)) == 0:
254 | return img
255 |
256 | radius = self.radius_min + torch.rand(1) * (self.radius_max - self.radius_min)
257 | return img.filter(ImageFilter.GaussianBlur(radius=radius))
258 |
259 |
260 | def copy_imgnt_locally(
261 | root,
262 | suffix,
263 | image_folder='imagenet_full_size/061417/',
264 | tar_folder='imagenet_full_size/',
265 | tar_file='imagenet_full_size-061417.tar',
266 | job_id=None,
267 | local_rank=None
268 | ):
269 | if job_id is None:
270 | try:
271 | job_id = os.environ['SLURM_JOBID']
272 | except Exception:
273 | logger.info('No job-id, will load directly from network file')
274 | return None
275 |
276 | if local_rank is None:
277 | try:
278 | local_rank = int(os.environ['SLURM_LOCALID'])
279 | except Exception:
280 | logger.info('No local-rank, will load directly from network file')
281 | return None
282 |
283 | source_file = os.path.join(root, tar_folder, tar_file)
284 | target = f'/scratch/slurm_tmpdir/{job_id}/'
285 | target_file = os.path.join(target, tar_file)
286 | data_path = os.path.join(target, image_folder, suffix)
287 | logger.info(f'{source_file}\n{target}\n{target_file}\n{data_path}')
288 |
289 | tmp_sgnl_file = os.path.join(target, 'copy_signal.txt')
290 |
291 | if not os.path.exists(data_path):
292 | if local_rank == 0:  # rank 0 on each node extracts the archive; other local ranks wait for the signal file
293 | commands = [
294 | ['tar', '-xf', source_file, '-C', target]]
295 | for cmnd in commands:
296 | start_time = time.time()
297 | logger.info(f'Executing {cmnd}')
298 | subprocess.run(cmnd)
299 | logger.info(f'Cmnd took {(time.time()-start_time)/60.} min.')
300 | with open(tmp_sgnl_file, '+w') as f:
301 | print('Done copying locally.', file=f)
302 | else:
303 | while not os.path.exists(tmp_sgnl_file):
304 | time.sleep(60)
305 | logger.info(f'{local_rank}: Checking {tmp_sgnl_file}')
306 |
307 | return data_path
308 |
--------------------------------------------------------------------------------
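A minimal sketch of how the schedulers and logger defined in src/utils.py are typically stepped once per iteration (the model, optimizer, and step counts below are illustrative assumptions, not taken from this repository):

    import torch
    # WarmupCosineSchedule, CosineWDSchedule and CSVLogger are defined in src/utils.py above.
    model = torch.nn.Linear(10, 10)  # hypothetical model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0, momentum=0.9, weight_decay=5e-4)

    total_steps = 1000
    lr_scheduler = WarmupCosineSchedule(
        optimizer, warmup_steps=100, start_lr=1e-4, ref_lr=0.03,
        T_max=total_steps, final_lr=1e-5)
    wd_scheduler = CosineWDSchedule(optimizer, ref_wd=5e-4, T_max=total_steps, final_wd=1e-6)
    csv_logger = CSVLogger('train_log.csv', ('%d', 'step'), ('%.6f', 'lr'), ('%.6f', 'wd'))

    for step in range(total_steps):
        lr = lr_scheduler.step()   # linear warmup to ref_lr, then cosine decay to final_lr
        wd = wd_scheduler.step()   # cosine anneal of weight decay; skips 'WD_exclude' param groups
        # ... forward pass, loss.backward(), optimizer.step() ...
        csv_logger.log(step, lr, wd)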