├── .gitignore ├── core ├── __init__.py ├── LICENSE ├── classifier │ └── densenet.py ├── losses.py ├── dist_util.py ├── respace.py ├── nn.py ├── resample.py ├── resnet_vggface2.py ├── fp16_util.py ├── train_util.py ├── image_datasets.py ├── script_util.py ├── logger.py └── sample_utils.py ├── env.yaml ├── utils └── create-mini-val.py ├── celeba-train.sh ├── compute-fid.sh ├── LICENSE ├── test.sh ├── celeba-train-diffusion.py ├── compute_FVA.py ├── compute_MNAC.py ├── compute_LPIPS.py ├── eval_utils ├── resnet50_facevgg2_FVA.py ├── fid_metrics.py └── fid_inception.py ├── README.md ├── compute_CD.py └── main.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | models -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 3 | """ 4 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: dime 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | dependencies: 7 | - python=3.8.11 8 | - pip 9 | - numpy 10 | - cudatoolkit=10.2 11 | - pytorch=1.9.1 12 | - torchvision 13 | - matplotlib 14 | - mpi4py 15 | - pyyaml 16 | - pandas 17 | - h5py 18 | - scipy 19 | - pip: 20 | - tqdm 21 | - blobfile -------------------------------------------------------------------------------- /utils/create-mini-val.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import pandas as pd 4 | 5 | import os 6 | import os.path as osp 7 | 8 | data_dir = '' 9 | 10 | random.seed(5) 11 | 12 | data = pd.read_csv(osp.join(data_dir, 'list_attr_celeba.csv')) 13 | partition_df = pd.read_csv(osp.join(data_dir, 'list_eval_partition.csv')) 14 | data = data[partition_df['partition'] == 1] 15 | data.reset_index(inplace=True) 16 | data.replace(-1, 0, inplace=True) 17 | 18 | indexes = random.choices(list(range(len(data))), k=1000) 19 | indexes.sort() 20 | 21 | minival = data.iloc[indexes] 22 | 23 | minival.to_csv('minival.csv') 24 | 25 | -------------------------------------------------------------------------------- /celeba-train.sh: -------------------------------------------------------------------------------- 1 | TRAIN_FLAGS="--batch_size 15 --lr 1e-4 --save_interval 10000 --weight_decay 0.05 --dropout 0.0" 2 | MODEL_FLAGS="--image_size 128 --attention_resolutions 32,16,8 --class_cond False --diffusion_steps 500 --learn_sigma True --noise_schedule linear --num_channels 128 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 3 | 4 | echo "Training diffusor" 5 | export NCCL_P2P_DISABLE=1 6 | mpiexec -n 4 python celeba-train-diffusion.py $TRAIN_FLAGS $MODEL_FLAGS \ 7 | --output_path 'ddpm-celeba' \ 8 | --gpus '0,1,2,3' 9 | -------------------------------------------------------------------------------- /compute-fid.sh: -------------------------------------------------------------------------------- 1 | TEMPPATH=./.temp # temp folder to store the data 2 | OUTPATH=/tmp/to/results 3 | EXPPATH=expname 4 | 5 | mkdir -p ${TEMPPATH}/real 6 | mkdir -p ${TEMPPATH}/cf 7 | mkdir -p ${TEMPPATH}/cfmin 8 | 9 | echo 'Copying CF images ' 10 | 11 | cp -r ${OUTPATH}/Results/${EXPPATH}/CC/CCF/CF/* ${TEMPPATH}/cf 12 | cp -r 
${OUTPATH}/Results/${EXPPATH}/IC/CCF/CF/* ${TEMPPATH}/cf 13 | 14 | echo 'Copying real images' 15 | 16 | cp -r ${OUTPATH}/Original/Correct/* ${TEMPPATH}/real 17 | cp -r ${OUTPATH}/Original/Incorrect/* ${TEMPPATH}/real 18 | 19 | echo 'Computing FID' 20 | 21 | python -m pytorch_fid ${TEMPPATH}/real ${TEMPPATH}/cf --device cuda:0 22 | 23 | rm -rf ${TEMPPATH} 24 | -------------------------------------------------------------------------------- /core/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Guillaume Jeanneret 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 500 --learn_sigma True --noise_schedule linear --num_channels 128 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 2 | SAMPLE_FLAGS="--batch_size 50 --timestep_respacing 200" 3 | DATAPATH=/path/to/celeba 4 | OUTPUT_PATH=/path/to/results 5 | EXPNAME=expname 6 | 7 | # parameters of the sampling 8 | GPU=2 9 | S=60 10 | SEED=4 11 | USE_LOGITS=True 12 | CLASS_SCALES='8,10,15' 13 | LAYER=18 14 | PERC=30 15 | L1=0.05 16 | QUERYLABEL=31 17 | TARGETLABEL=-1 18 | IMAGESIZE=128 # dataset shape 19 | 20 | python -W ignore main.py $MODEL_FLAGS $SAMPLE_FLAGS \ 21 | --query_label $QUERYLABEL --target_label $TARGETLABEL \ 22 | --output_path $OUTPUT_PATH \ 23 | --start_step $S --dataset 'CelebA' \ 24 | --exp_name $EXPNAME --gpu $GPU \ 25 | --classifier_scales $CLASS_SCALES \ 26 | --seed $SEED --data_dir $DATAPATH\ 27 | --l1_loss $L1 --use_logits $USE_LOGITS \ 28 | --l_perc $PERC --l_perc_layer $LAYER \ 29 | --save_x_t True --save_z_t True \ 30 | --use_sampling_on_x_t True --num_batches 1 \ 31 | --save_images True --image_size $IMAGESIZE 32 | -------------------------------------------------------------------------------- /core/classifier/densenet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Script taken from https://github.com/ServiceNow/beyond-trivial-explanations 3 | ''' 4 | 5 | 6 | import torch 7 | import torchvision 8 | 9 | 10 | class Identity(torch.nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | def forward(self, x): 14 | return x 15 | 16 | 17 | class DenseNet121(torch.nn.Module): 18 | def __init__(self): 19 | super().__init__() 20 | self.feat_extract = torchvision.models.densenet121(pretrained=False) 21 | self.feat_extract.classifier = Identity() 22 | self.output_size = 1024 23 | 24 | def forward(self, x): 25 | return self.feat_extract(x) 26 | 27 | 28 | class ClassificationModel(torch.nn.Module): 29 | def __init__(self, path_to_weights, query_label): 30 | 31 | super().__init__() 32 | self.feat_extract = DenseNet121() 33 | self.classifier = torch.nn.Linear(self.feat_extract.output_size, 40) 34 | self.query_label = query_label 35 | 36 | # load the model from the checkpoint 37 | state_dict = torch.load(path_to_weights, map_location='cpu') 38 | self.feat_extract.load_state_dict(state_dict['feat_extract']) 39 | self.classifier.load_state_dict(state_dict['classifier']) 40 | 41 | def forward(self, x, get_other_attrs=False): 42 | x = self.feat_extract(x) 43 | x = self.classifier(x) 44 | 45 | if get_other_attrs: 46 | return x[:, self.query_label], x 47 | else: 48 | return x[:, self.query_label] 49 | 50 | -------------------------------------------------------------------------------- /celeba-train-diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 
3 | """ 4 | 5 | import argparse 6 | 7 | from core import dist_util, logger 8 | from core.image_datasets import load_data_celeba 9 | from core.resample import create_named_schedule_sampler 10 | from core.script_util import ( 11 | model_and_diffusion_defaults, 12 | create_model_and_diffusion, 13 | args_to_dict, 14 | add_dict_to_argparser, 15 | ) 16 | from core.train_util import TrainLoop 17 | 18 | 19 | def main(): 20 | args = create_argparser().parse_args() 21 | 22 | dist_util.setup_dist(args.gpus) 23 | logger.configure(dir=args.output_path) 24 | 25 | logger.log("creating model and diffusion...") 26 | model, diffusion = create_model_and_diffusion( 27 | num_classes=40, 28 | multiclass=True, 29 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 30 | ) 31 | model.to(dist_util.dev()) 32 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 33 | 34 | logger.log("creating data loader...") 35 | data = load_data_celeba( 36 | data_dir=args.data_dir, 37 | batch_size=args.batch_size, 38 | image_size=args.image_size, 39 | class_cond=args.class_cond 40 | ) 41 | 42 | logger.log("training...") 43 | TrainLoop( 44 | model=model, 45 | diffusion=diffusion, 46 | data=data, 47 | batch_size=args.batch_size, 48 | microbatch=args.microbatch, 49 | lr=args.lr, 50 | ema_rate=args.ema_rate, 51 | log_interval=args.log_interval, 52 | save_interval=args.save_interval, 53 | resume_checkpoint=args.resume_checkpoint, 54 | use_fp16=args.use_fp16, 55 | fp16_scale_growth=args.fp16_scale_growth, 56 | schedule_sampler=schedule_sampler, 57 | weight_decay=args.weight_decay, 58 | lr_anneal_steps=args.lr_anneal_steps, 59 | ).run_loop() 60 | 61 | 62 | def create_argparser(): 63 | defaults = dict( 64 | data_dir="", 65 | schedule_sampler="uniform", 66 | lr=1e-4, 67 | weight_decay=0.0, 68 | lr_anneal_steps=0, 69 | batch_size=1, 70 | microbatch=-1, # -1 disables microbatches 71 | ema_rate="0.9999", # comma-separated list of EMA values 72 | log_interval=10, 73 | save_interval=10000, 74 | resume_checkpoint="", 75 | use_fp16=False, 76 | fp16_scale_growth=1e-3, 77 | output_path='', 78 | gpus='' 79 | ) 80 | defaults.update(model_and_diffusion_defaults()) 81 | parser = argparse.ArgumentParser() 82 | add_dict_to_argparser(parser, defaults) 83 | return parser 84 | 85 | 86 | if __name__ == "__main__": 87 | main() 88 | -------------------------------------------------------------------------------- /core/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 
28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /core/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | 15 | def setup_dist(devices=''): 16 | """ 17 | Setup a distributed process group. 18 | """ 19 | if dist.is_initialized(): 20 | return 21 | if devices != '': 22 | # set only one of the devices 23 | device = devices.split(',')[MPI.COMM_WORLD.Get_rank()] 24 | print(f'RANK/GPU ({MPI.COMM_WORLD.Get_rank()}/{device}) with a world size:', MPI.COMM_WORLD.size) 25 | os.environ["CUDA_VISIBLE_DEVICES"] = device 26 | 27 | comm = MPI.COMM_WORLD 28 | backend = "gloo" if not th.cuda.is_available() else "nccl" 29 | 30 | if backend == "gloo": 31 | hostname = "localhost" 32 | else: 33 | hostname = socket.gethostbyname(socket.getfqdn()) 34 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 35 | os.environ["RANK"] = str(comm.rank) 36 | os.environ["WORLD_SIZE"] = str(comm.size) 37 | 38 | port = comm.bcast(_find_free_port(), root=0) 39 | os.environ["MASTER_PORT"] = str(port) 40 | dist.init_process_group(backend=backend, init_method="env://") 41 | 42 | 43 | def dev(): 44 | """ 45 | Get the device to use for torch.distributed. 46 | """ 47 | if th.cuda.is_available(): 48 | return th.device(f"cuda") 49 | return th.device("cpu") 50 | 51 | 52 | def load_state_dict(path, **kwargs): 53 | """ 54 | Load a PyTorch file without redundant fetches across MPI ranks. 
55 | """ 56 | chunk_size = 2 ** 30 # MPI has a relatively small size limit 57 | if MPI.COMM_WORLD.Get_rank() == 0: 58 | with bf.BlobFile(path, "rb") as f: 59 | data = f.read() 60 | num_chunks = len(data) // chunk_size 61 | if len(data) % chunk_size: 62 | num_chunks += 1 63 | MPI.COMM_WORLD.bcast(num_chunks) 64 | for i in range(0, len(data), chunk_size): 65 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 66 | else: 67 | num_chunks = MPI.COMM_WORLD.bcast(None) 68 | data = bytes() 69 | for _ in range(num_chunks): 70 | data += MPI.COMM_WORLD.bcast(None) 71 | 72 | return th.load(io.BytesIO(data), **kwargs) 73 | 74 | 75 | def sync_params(params): 76 | """ 77 | Synchronize a sequence of Tensors across ranks from rank 0. 78 | """ 79 | with th.no_grad(): 80 | for p in params: 81 | dist.broadcast(p, 0) 82 | # dist.broadcast(p, 0) 83 | 84 | 85 | def _find_free_port(): 86 | try: 87 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 88 | s.bind(("", 0)) 89 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 90 | return s.getsockname()[1] 91 | finally: 92 | s.close() 93 | -------------------------------------------------------------------------------- /compute_FVA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import itertools 5 | import numpy as np 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | import os.path as osp 9 | 10 | from PIL import Image 11 | from tqdm import tqdm 12 | from torch.utils import data 13 | from torchvision import transforms 14 | 15 | from eval_utils.resnet50_facevgg2_FVA import resnet50, load_state_dict 16 | 17 | 18 | # create dataset to read the counterfactual results images 19 | class CFDataset(): 20 | mean_bgr = np.array([91.4953, 103.8827, 131.0912]) 21 | def __init__(self, path, exp_name): 22 | 23 | self.images = [] 24 | self.path = path 25 | self.exp_name = exp_name 26 | for CL, CF in itertools.product(['CC', 'IC'], ['CCF']): 27 | self.images += [(CL, CF, I) for I in os.listdir(osp.join(path, 'Results', self.exp_name, CL, CF, 'CF'))] 28 | 29 | def __len__(self): 30 | return len(self.images) 31 | 32 | def switch(self, partition): 33 | if partition == 'C': 34 | LCF = ['CCF'] 35 | elif partition == 'I': 36 | LCF = ['ICF'] 37 | else: 38 | LCF = ['CCF', 'ICF'] 39 | 40 | self.images = [] 41 | 42 | for CL, CF in itertools.product(['CC', 'IC'], LCF): 43 | self.images += [(CL, CF, I) for I in os.listdir(osp.join(self.path, 'Results', self.exp_name, CL, CF, 'CF'))] 44 | 45 | def __getitem__(self, idx): 46 | CL, CF, I = self.images[idx] 47 | # get paths 48 | cl_path = osp.join(self.path, 'Original', 'Correct' if CL == 'CC' else 'Incorrect', I) 49 | cf_path = osp.join(self.path, 'Results', self.exp_name, CL, CF, 'CF', I) 50 | 51 | cl = self.load_img(cl_path) 52 | cf = self.load_img(cf_path) 53 | 54 | return cl, cf 55 | 56 | def load_img(self, path): 57 | img = Image.open(os.path.join(path)) 58 | img = transforms.Resize(224)(img) 59 | img = np.array(img, dtype=np.uint8) 60 | return self.transform(img) 61 | 62 | def transform(self, img): 63 | img = img[:, :, ::-1] # RGB -> BGR 64 | img = img.astype(np.float32) 65 | img -= self.mean_bgr 66 | img = img.transpose(2, 0, 1) # C x H x W 67 | img = torch.from_numpy(img).float() 68 | return img 69 | 70 | 71 | 72 | @torch.no_grad() 73 | def compute_FVA(oracle, 74 | path, 75 | exp_name): 76 | 77 | dataset = CFDataset(path, exp_name) 78 | 79 | cosine_similarity = torch.nn.CosineSimilarity() 80 | 81 | FVAS = [] 82 | dists = [] 83 
| loader = data.DataLoader(dataset, batch_size=15, 84 | shuffle=False, 85 | num_workers=4, pin_memory=True) 86 | 87 | for cl, cf in tqdm(loader): 88 | cl = cl.to(device, dtype=torch.float) 89 | cf = cf.to(device, dtype=torch.float) 90 | cl_feat = oracle(cl) 91 | cf_feat = oracle(cf) 92 | dist = cosine_similarity(cl_feat, cf_feat) 93 | FVAS.append((dist > 0.5).cpu().numpy()) 94 | dists.append(dist.cpu().numpy()) 95 | 96 | return np.concatenate(FVAS), np.concatenate(dists) 97 | 98 | 99 | def arguments(): 100 | parser = argparse.ArgumentParser(description='FVA arguments.') 101 | parser.add_argument('--gpu', default='0', type=str, 102 | help='GPU id') 103 | parser.add_argument('--exp-name', required=True, type=str, 104 | help='Experiment Name') 105 | parser.add_argument('--output-path', required=True, type=str, 106 | help='Results Path') 107 | parser.add_argument('--weights-path', default='models/resnet50_vggface2_model.pkl', type=str, 108 | help='ResNet50 VGGFace2 model weights') 109 | 110 | return parser.parse_args() 111 | 112 | 113 | if __name__ == '__main__': 114 | args = arguments() 115 | device = torch.device('cuda:' + args.gpu) 116 | oracle = resnet50(num_classes=8631, include_top=False).to(device) 117 | load_state_dict(oracle, args.weights_path) 118 | oracle.eval() 119 | 120 | results = compute_FVA(oracle, 121 | args.output_path, 122 | args.exp_name) 123 | 124 | print('FVA', np.mean(results[0])) 125 | print('mean dist', np.mean(results[1])) 126 | -------------------------------------------------------------------------------- /compute_MNAC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import itertools 5 | import numpy as np 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | import os.path as osp 9 | 10 | from PIL import Image 11 | from tqdm import tqdm 12 | from torch.utils import data 13 | from torchvision import transforms 14 | 15 | from eval_utils.oracle_metrics import OracleMetrics 16 | 17 | 18 | def arguments(): 19 | parser = argparse.ArgumentParser(description='MNAC arguments.') 20 | parser.add_argument('--gpu', default='0', type=str, 21 | help='GPU id') 22 | parser.add_argument('--oracle-path', default='models/oracle.pth', type=str, 23 | help='Oracle path') 24 | parser.add_argument('--exp-name', required=True, type=str, 25 | help='Experiment Name') 26 | parser.add_argument('--output-path', required=True, type=str, 27 | help='Results Path') 28 | 29 | return parser.parse_args() 30 | 31 | 32 | # create dataset to read the counterfactual results images 33 | class CFDataset(): 34 | mean_bgr = np.array([91.4953, 103.8827, 131.0912]) 35 | def __init__(self, path, exp_name): 36 | 37 | self.images = [] 38 | self.path = path 39 | self.exp_name = exp_name 40 | self.transform = transforms.Compose([ 41 | transforms.ToTensor(), 42 | transforms.Normalize([0.5, 0.5, 0.5], 43 | [0.5, 0.5, 0.5]) 44 | ]) 45 | for CL, CF in itertools.product(['CC', 'IC'], ['CCF']): 46 | self.images += [(CL, CF, I) for I in os.listdir(osp.join(path, 'Results', self.exp_name, CL, CF, 'CF'))] 47 | 48 | def __len__(self): 49 | return len(self.images) 50 | 51 | def switch(self, partition): 52 | if partition == 'C': 53 | LCF = ['CCF'] 54 | elif partition == 'I': 55 | LCF = ['ICF'] 56 | else: 57 | LCF = ['CCF', 'ICF'] 58 | 59 | self.images = [] 60 | 61 | for CL, CF in itertools.product(['CC', 'IC'], LCF): 62 | self.images += [(CL, CF, I) for I in os.listdir(osp.join(self.path, 'Results', self.exp_name, CL, CF, 'CF'))]
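    # Note: the folder codes above mirror the results layout described in the README:
    # CC / IC = correctly / incorrectly classified inputs, CCF / ICF = correct / incorrect
    # counterfactuals, and the 'CF' subfolder contains the counterfactual images themselves.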
63 | 64 | def __getitem__(self, idx): 65 | CL, CF, I = self.images[idx] 66 | # get paths 67 | cl_path = osp.join(self.path, 'Original', 'Correct' if CL == 'CC' else 'Incorrect', I) 68 | cf_path = osp.join(self.path, 'Results', self.exp_name, CL, CF, 'CF', I) 69 | 70 | cl = self.load_img(cl_path) 71 | cf = self.load_img(cf_path) 72 | 73 | return cl, cf 74 | 75 | def load_img(self, path): 76 | with open(path, "rb") as f: 77 | img = Image.open(f) 78 | img = img.convert('RGB') 79 | return self.transform(img) 80 | 81 | 82 | @torch.no_grad() 83 | def compute_MNAC(oracle, 84 | path, 85 | exp_name): 86 | 87 | dataset = CFDataset(path, exp_name) 88 | 89 | cosine_similarity = torch.nn.CosineSimilarity() 90 | 91 | MNACS = [] 92 | dists = [] 93 | loader = data.DataLoader(dataset, batch_size=15, 94 | shuffle=False, 95 | num_workers=4, pin_memory=True) 96 | 97 | for cl, cf in tqdm(loader): 98 | cl = cl.to(device, dtype=torch.float) 99 | cf = cf.to(device, dtype=torch.float) 100 | _, cl_feat = oracle.oracle(cl) 101 | _, cf_feat = oracle.oracle(cf) 102 | d_cl = torch.sigmoid(cl_feat) 103 | d_cf = torch.sigmoid(cf_feat) 104 | MNACS.append(((d_cl > 0.5) != (d_cf > 0.5)).sum(dim=1).cpu().numpy()) 105 | dists.append([d_cl.cpu().numpy(), d_cf.cpu().numpy()]) 106 | 107 | return np.concatenate(MNACS), np.concatenate([d[0] for d in dists]), np.concatenate([d[1] for d in dists]) 108 | 109 | 110 | if __name__ == '__main__': 111 | 112 | args = arguments() 113 | 114 | # load oracle trained on vggface2 and fine-tuned on CelebA 115 | ORACLEPATH = args.oracle_path 116 | device = torch.device('cuda:' + args.gpu) 117 | oracle = OracleMetrics(weights_path=ORACLEPATH, 118 | device=device) 119 | oracle.eval() 120 | A = 0 121 | 122 | results = compute_MNAC(oracle, 123 | args.output_path, 124 | args.exp_name) 125 | 126 | print('MNAC:', np.mean(results[0])) 127 | -------------------------------------------------------------------------------- /compute_LPIPS.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lpips # from pip install lpips 3 | import torch 4 | import argparse 5 | import itertools 6 | import numpy as np 7 | import pandas as pd 8 | import matplotlib.pyplot as plt 9 | import os.path as osp 10 | 11 | from PIL import Image 12 | from tqdm import tqdm 13 | from torch.utils import data 14 | from torchvision import transforms 15 | 16 | from eval_utils.oracle_metrics import OracleMetrics 17 | 18 | 19 | # create dataset to read the counterfactual results images 20 | class CFDataset(): 21 | def __init__(self, path, exp_name_format, values): 22 | 23 | self.images = [] 24 | self.path = path 25 | self.exp_name_format = exp_name_format 26 | self.transform = transforms.Compose([ 27 | transforms.ToTensor(), 28 | transforms.Normalize([0.5, 0.5, 0.5], 29 | [0.5, 0.5, 0.5]) 30 | ]) 31 | 32 | self.exp_names = [exp_name_format.replace('*', v) for v in values] 33 | 34 | for CL, CF in itertools.product(['CC', 'IC'], ['CCF', 'ICF']): 35 | 36 | c_images = [] 37 | 38 | files = [os.listdir(osp.join(path, 'Results', en, CL, CF, 'CF')) for en in self.exp_names] 39 | 40 | for I in files[0]: 41 | 42 | # search for all images with the same name 43 | in_files = [I in f for f in files[1:]] 44 | 45 | if all(in_files): 46 | c_images.append((CL, CF, I)) 47 | 48 | self.images += c_images 49 | 50 | def __len__(self): 51 | return len(self.images) 52 | 53 | def switch(self, partition): 54 | if partition == 'C': 55 | LCF = ['CCF'] 56 | elif partition == 'I': 57 | LCF = ['ICF'] 58 | else: 59 | LCF = 
['CCF', 'ICF'] 60 | 61 | self.images = [] 62 | 63 | for CL, CF in itertools.product(['CC', 'IC'], LCF): 64 | self.images += [(CL, CF, I) for I in os.listdir(osp.join(self.path, 'Results', self.exp_name, CL, CF, 'CF'))] 65 | 66 | def __getitem__(self, idx): 67 | CL, CF, I = self.images[idx] 68 | 69 | # get paths 70 | images = [] 71 | for exp_name in self.exp_names: 72 | cf_path = osp.join(self.path, 'Results', exp_name, CL, CF, 'CF', I) 73 | images.append(self.load_img(cf_path)) 74 | 75 | return images 76 | 77 | def load_img(self, path): 78 | with open(path, "rb") as f: 79 | img = Image.open(f) 80 | img = img.convert('RGB') 81 | return self.transform(img) 82 | 83 | 84 | @torch.no_grad() 85 | def compute_LPIPS(LPIPS, 86 | path, 87 | exp_name_format, 88 | values, 89 | device): 90 | 91 | dataset = CFDataset(path, exp_name_format, values) 92 | loader = data.DataLoader(dataset, batch_size=15, 93 | shuffle=False, 94 | num_workers=4, pin_memory=True) 95 | 96 | dists = [] 97 | 98 | for cfs in tqdm(loader): 99 | 100 | dist = [] 101 | 102 | for i in range(len(values)): 103 | cf1 = cfs[i].to(device, dtype=torch.float) 104 | 105 | for j in range(i + 1, len(values)): 106 | cf2 = cfs[j].to(device, dtype=torch.float) 107 | # import pdb; pdb.set_trace() 108 | dist.append(LPIPS.forward(cf1, cf2, normalize=False).squeeze()) # data is already in the [-1,1] range 109 | 110 | dists.append(sum(dist) / len(dist)) 111 | 112 | return torch.cat(dists).cpu().detach().numpy() 113 | 114 | 115 | def arguments(): 116 | parser = argparse.ArgumentParser(description='FVA arguments.') 117 | parser.add_argument('--gpu', default='0', type=str, 118 | help='GPU id') 119 | parser.add_argument('--exp-pattern', required=True, type=str, 120 | help='Experiment pattern. Must contain a *.') 121 | parser.add_argument('--exp-values', nargs='+', type=str, 122 | help='Values to be replaced by the * on the --exp-pattern flag.') 123 | parser.add_argument('--output-path', required=True, type=str, 124 | help='Results Path') 125 | 126 | return parser.parse_args() 127 | 128 | 129 | if __name__ == '__main__': 130 | 131 | args = arguments() 132 | device = torch.device('cuda:' + args.gpu) 133 | LPIPS = lpips.LPIPS(net='vgg', spatial=False).to(device) 134 | 135 | res = compute_LPIPS(LPIPS, 136 | args.output_path, 137 | args.exp_pattern, 138 | args.exp_values, 139 | device) 140 | 141 | print('sigma_L result:', np.mean(res)) 142 | -------------------------------------------------------------------------------- /core/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 
22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def condition_mean(self, cond_fn, *args, **kwargs): 99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 100 | 101 | def condition_score(self, cond_fn, *args, **kwargs): 102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def _wrap_model(self, model): 105 | if isinstance(model, _WrappedModel): 106 | return model 107 | return _WrappedModel( 108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 109 | ) 110 | 111 | def _scale_timesteps(self, t): 112 | # Scaling is done by the wrapped model. 
113 | return t 114 | 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | if self.rescale_timesteps: 127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /core/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 
106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /core/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 
27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 
121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /eval_utils/resnet50_facevgg2_FVA.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | import pickle 5 | import math 6 | 7 | 8 | __all__ = ['ResNet', 'resnet50'] 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = conv3x3(inplanes, planes, stride) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.conv2 = conv3x3(planes, planes) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | self.downsample = downsample 27 | self.stride = stride 28 | 29 | def forward(self, x): 30 | residual = x 31 | 32 | out = self.conv1(x) 33 | out = self.bn1(out) 34 | out = self.relu(out) 35 | 36 | out = self.conv2(out) 37 | out = self.bn2(out) 38 | 39 | if self.downsample is not None: 40 | residual = self.downsample(x) 41 | 42 | out += residual 43 | out = self.relu(out) 44 | 45 | return out 46 | 47 | 48 | class Bottleneck(nn.Module): 49 | expansion = 4 50 | 51 | def __init__(self, inplanes, planes, stride=1, downsample=None): 52 | super(Bottleneck, self).__init__() 53 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) 54 | self.bn1 = nn.BatchNorm2d(planes) 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 58 | self.bn3 = nn.BatchNorm2d(planes * 4) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | residual = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv3(out) 75 | out = self.bn3(out) 76 | 77 | if 
self.downsample is not None: 78 | residual = self.downsample(x) 79 | 80 | out += residual 81 | out = self.relu(out) 82 | 83 | return out 84 | 85 | 86 | class ResNet(nn.Module): 87 | 88 | def __init__(self, block, layers, num_classes=1000, include_top=True): 89 | self.inplanes = 64 90 | super(ResNet, self).__init__() 91 | self.include_top = include_top 92 | 93 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 94 | self.bn1 = nn.BatchNorm2d(64) 95 | self.relu = nn.ReLU(inplace=True) 96 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) 97 | 98 | self.layer1 = self._make_layer(block, 64, layers[0]) 99 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 100 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 101 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 102 | self.avgpool = nn.AvgPool2d(7, stride=1) 103 | self.fc = nn.Linear(512 * block.expansion, num_classes) 104 | 105 | for m in self.modules(): 106 | if isinstance(m, nn.Conv2d): 107 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 108 | m.weight.data.normal_(0, math.sqrt(2. / n)) 109 | elif isinstance(m, nn.BatchNorm2d): 110 | m.weight.data.fill_(1) 111 | m.bias.data.zero_() 112 | 113 | def _make_layer(self, block, planes, blocks, stride=1): 114 | downsample = None 115 | if stride != 1 or self.inplanes != planes * block.expansion: 116 | downsample = nn.Sequential( 117 | nn.Conv2d(self.inplanes, planes * block.expansion, 118 | kernel_size=1, stride=stride, bias=False), 119 | nn.BatchNorm2d(planes * block.expansion), 120 | ) 121 | 122 | layers = [] 123 | layers.append(block(self.inplanes, planes, stride, downsample)) 124 | self.inplanes = planes * block.expansion 125 | for i in range(1, blocks): 126 | layers.append(block(self.inplanes, planes)) 127 | 128 | return nn.Sequential(*layers) 129 | 130 | def forward(self, x): 131 | x = self.conv1(x) 132 | x = self.bn1(x) 133 | x = self.relu(x) 134 | x = self.maxpool(x) 135 | 136 | x = self.layer1(x) 137 | x = self.layer2(x) 138 | x = self.layer3(x) 139 | x = self.layer4(x) 140 | 141 | x = self.avgpool(x) 142 | 143 | if not self.include_top: 144 | return x 145 | 146 | x = x.view(x.size(0), -1) 147 | x = self.fc(x) 148 | return x 149 | 150 | def resnet50(**kwargs): 151 | """Constructs a ResNet-50 model. 152 | """ 153 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 154 | return model 155 | 156 | 157 | def load_state_dict(model, fname): 158 | """ 159 | Set parameters converted from Caffe models authors of VGGFace2 provide. 160 | See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/. 161 | Arguments: 162 | model: model 163 | fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle. 
164 | """ 165 | with open(fname, 'rb') as f: 166 | weights = pickle.load(f, encoding='latin1') 167 | 168 | own_state = model.state_dict() 169 | for name, param in weights.items(): 170 | if name in own_state: 171 | try: 172 | own_state[name].copy_(torch.from_numpy(param)) 173 | except Exception: 174 | raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\ 175 | 'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size())) 176 | else: 177 | print(name, 'not in the state dict') 178 | 179 | model.load_state_dict(own_state) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiME's official code 2 | 3 | This is the codebase for the ACCV 2022 paper [Diffusion Models for Counterfactual Explanations](https://arxiv.org/abs/2203.15636). 4 | 5 | ### UPDATE!!!! 6 | Please check our CVIU extension paper [here](https://www.sciencedirect.com/science/article/abs/pii/S1077314224002881) :)! 7 | 8 | ## Environment 9 | 10 | Through anaconda, install our environment: 11 | 12 | ```bash 13 | conda env create -f env.yaml 14 | conda activate dime 15 | ``` 16 | 17 | ## Data preparation 18 | 19 | Please download and uncompress the CelebA dataset [here](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). There is no need for any post-processing. The final folder structure should be: 20 | 21 | ``` 22 | PATH ---- img_align_celeba ---- xxxxxx.jpg 23 | | 24 | --- list_attr_celeba.csv 25 | | 26 | --- list_eval_partition.csv 27 | ``` 28 | 29 | ## Downloading pre-trained models 30 | 31 | To use our trained models, you must download them first from this [link](https://huggingface.co/guillaumejs2403/DiME). Please extract them to the folder `models`. We provides the CelebA diffusion model, the classifier under observation, and the trained oracle. Finally, download the VGGFace2 model throught this [github repo](https://github.com/cydonia999/VGGFace2-pytorch). Download the `resnet50_ft` model. 
32 | 33 | ## Extracting Counterfactual Explanations 34 | 35 | To create the counterfactual explanations, please use the main.py script as follows: 36 | 37 | ```bash 38 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 500 --learn_sigma True --noise_schedule linear --num_channels 128 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 39 | SAMPLE_FLAGS="--batch_size 50 --timestep_respacing 200" 40 | DATAPATH=/path/to/dataset 41 | MODELPATH=/path/to/model.pt 42 | CLASSIFIERPATH=/path/to/classifier.pt 43 | ORACLEPATH=/path/to/oracle.pt 44 | OUTPUT_PATH=/path/to/output 45 | EXPNAME=exp/name 46 | 47 | # parameters of the sampling 48 | GPU=0 49 | S=60 50 | SEED=4 51 | USE_LOGITS=True 52 | CLASS_SCALES='8,10,15' 53 | LAYER=18 54 | PERC=30 55 | L1=0.05 56 | QUERYLABEL=31 57 | TARGETLABEL=-1 58 | IMAGESIZE=128 # dataset shape 59 | 60 | python -W ignore main.py $MODEL_FLAGS $SAMPLE_FLAGS \ 61 | --query_label $QUERYLABEL --target_label $TARGETLABEL \ 62 | --output_path $OUTPUT_PATH --num_batches 1 \ 63 | --start_step $S --dataset 'CelebAMV' \ 64 | --exp_name $EXPNAME --gpu $GPU \ 65 | --model_path $MODELPATH --classifier_scales $CLASS_SCALES \ 66 | --classifier_path $CLASSIFIERPATH --seed $SEED \ 67 | --oracle_path $ORACLEPATH \ 68 | --l1_loss $L1 --use_logits $USE_LOGITS \ 69 | --l_perc $PERC --l_perc_layer $LAYER \ 70 | --save_x_t True --save_z_t True \ 71 | --use_sampling_on_x_t True \ 72 | --save_images True --image_size $IMAGESIZE 73 | ``` 74 | 75 | Given that the sampling process may take a long time, we've included a way to split the sampling into multiple processes. To use this feature, include the flag `--num_chunks C`, where `C` is the number of chunks to split the dataset into. Then, run the code `C` times with the flag `--chunk c`, where `c` is the index of the chunk to process (hence, `c \in {0, 1, ..., C - 1}`). 76 | 77 | The results will be stored in `OUTPUT_PATH`. This folder has the following structure: 78 | 79 | ``` 80 | OUTPUT_PATH ----- Original ---- Correct 81 | | | 82 | | --- Incorrect 83 | | 84 | | 85 | | 86 | --- Results ---- EXPNAME ---- (I/C)C ---- (I/C)CF ---- CF 87 | | 88 | --- Info 89 | | 90 | --- Noise 91 | | 92 | --- SM 93 | ``` 94 | 95 | We found this structure useful for experimentation since we can change only the `EXPNAME` to refer to another experiment without changing the original images. The folder `Original` contains the correctly classified (misclassified) images in `Correct` (`Incorrect`). We summarize the structure of the counterfactual explanations (`Results/EXPNAME`) as: `(I/C)C`: (In/correct) classification. `(I/C)CF`: (In/correct) counterfactual. `CF`: counterfactual images. `Info`: Useful information per instance. `Noise`: Noisy instance at timestep $\tau$ of the input data. `SM`: Difference between the input and its counterfactual. All files in all folders will have the same identifier. 96 | 97 | 98 | ## Evaluation 99 | 100 | We provide our evaluation protocol scripts to assess the performance of our method. All our evaluation scripts use the folder structure presented above. Please check each script's `--help` flag for more information about its inputs; example invocations are shown after the list. 101 | - FVA: `compute_FVA.py`. 102 | - MNAC: `compute_MNAC.py`. 103 | - $\sigma_L$: `compute_LPIPS.py`. Computes the variability metric. 104 | - CD: `compute_CD.py`. Computes our proposed metric, Correlation Difference. 105 | - FID: `compute-fid.sh`. Set the `OUTPATH` and `EXPPATH` variables at the top of the script to the results path and the experiment name.
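As a reference, here is a minimal sketch of how the metric scripts can be invoked, using only the flags defined in their argument parsers; the output path, the experiment names, and the LPIPS pattern/values are placeholders to adapt to your own runs:

```bash
OUTPUT_PATH=/path/to/output   # same --output_path given to main.py
EXPNAME=exp/name              # same --exp_name given to main.py

# FVA
python compute_FVA.py --gpu 0 --output-path $OUTPUT_PATH --exp-name $EXPNAME \
    --weights-path models/resnet50_vggface2_model.pkl

# MNAC
python compute_MNAC.py --gpu 0 --output-path $OUTPUT_PATH --exp-name $EXPNAME \
    --oracle-path models/oracle.pth

# sigma_L: compares counterfactuals from several experiments whose names differ
# only where the '*' is placed (hypothetical pattern and values below)
python compute_LPIPS.py --gpu 0 --output-path $OUTPUT_PATH \
    --exp-pattern 'exp/name-scale-*' --exp-values 8 10 15
```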
106 | 107 | 108 | ## Training the DDPM model from scratch 109 | 110 | We provided a bash script to train the DDPM to generate the counterfactual explanations: `celeba-train.sh`. Nevertheless, the syntax to run the code base is: 111 | 112 | ```bash 113 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 500 --image_size 128 --learn_sigma True --noise_schedule linear --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 114 | TRAIN_FLAGS="--batch_size 15 --lr 1e-4 --save_interval 30000 --weight_decay 0.05 --dropout 0.0" 115 | mpiexec -n N python celeba-train-diffusion.py $TRAIN_FLAGS \ 116 | $MODEL_FLAGS \ 117 | --output_path OUTPUT_FOLDER \ 118 | --gpus GPUS 119 | ``` 120 | 121 | ## Citation 122 | 123 | If you found useful our code, please cite our work. 124 | 125 | ``` 126 | @inproceedings{Jeanneret_2022_ACCV, 127 | author = {Jeanneret, Guillaume and Simon, Lo\"ic and Fr\'ed\'eric Jurie}, 128 | title = {Diffusion Models for Counterfactual Explanations}, 129 | booktitle = {Proceedings of the Asian Conference on Computer Vision (ACCV)}, 130 | month = {December}, 131 | year = {2022} 132 | } 133 | ``` 134 | 135 | ## Code Base 136 | 137 | We based our repository on [openai/guided-diffusion](https://github.com/openai/guided-diffusion). 138 | -------------------------------------------------------------------------------- /eval_utils/fid_metrics.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Script created direclty from the pytorch fid repository on github 3 | https://github.com/mseitzer/pytorch-fid 4 | ''' 5 | 6 | import os 7 | import itertools 8 | import numpy as np 9 | 10 | from PIL import Image 11 | from tqdm import tqdm 12 | from scipy import linalg 13 | from os import path as osp 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | 20 | from .fid_inception import InceptionV3 21 | 22 | 23 | class Normalizer(nn.Module): 24 | def __init__(self, classifier): 25 | super().__init__() 26 | self.classifier = classifier 27 | # self.register_buffer('mu', torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)) 28 | # self.register_buffer('sigma', torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)) 29 | 30 | def forward(self, x): 31 | x = (torch.clamp(x, -1, 1) + 1) / 2 32 | # x = (x - self.mu) / self.sigma 33 | return self.classifier(x) 34 | 35 | 36 | class FIDMachine(): 37 | def __init__(self, dims=2048, device='cpu', 38 | num_samples=500): 39 | self.dims = dims 40 | self.device = device 41 | # self.cl_feat = np.empty((num_samples, dims)) 42 | # self.cf_feat = np.empty((num_samples, dims)) 43 | self.cl_feat = [] 44 | self.cf_feat = [] 45 | self.idx = 0 46 | 47 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 48 | # self.model = Normalizer(InceptionV3([block_idx])).to(device) 49 | self.model = InceptionV3([block_idx], 50 | resize_input=True, 51 | normalize_input=False, # our images are already in the [-1, 1] range 52 | ).to(device) 53 | self.model.eval() 54 | 55 | # def compute_and_store_activations(self, cl, cf): 56 | # ''' 57 | # :param cl: clean images of shape Bx3x128x128 58 | # :param cf: counterfactual images of shape Bx3x128x128 59 | # ''' 60 | # B = cl.size(0) 61 | # self.cl_feat[self.idx:self.idx + B] = self.get_activations(cl) 62 | # self.cf_feat[self.idx:self.idx + B] = self.get_activations(cf) 63 | # self.idx += B 64 | 65 | def compute_and_store_activations(self, cl, cf): 66 | ''' 67 | 
:param cl: clean images of shape B1x3x128x128 68 | :param cf: counterfactual images of shape B2x3x128x128 69 | ''' 70 | if cl.size(0) != 0: 71 | self.cl_feat.append(self.get_activations(cl)) 72 | 73 | if cf.size(0) != 0: 74 | self.cf_feat.append(self.get_activations(cf)) 75 | 76 | @torch.no_grad() 77 | def get_activations(self, imgs): 78 | pred = self.model(imgs)[0] 79 | if pred.size(2) != 1 or pred.size(3) != 1: 80 | pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1)) 81 | return pred.squeeze(3).squeeze(2).cpu().numpy() 82 | 83 | @staticmethod 84 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 85 | 86 | mu1 = np.atleast_1d(mu1) 87 | mu2 = np.atleast_1d(mu2) 88 | 89 | sigma1 = np.atleast_2d(sigma1) 90 | sigma2 = np.atleast_2d(sigma2) 91 | 92 | assert mu1.shape == mu2.shape, \ 93 | 'Training and test mean vectors have different lengths' 94 | assert sigma1.shape == sigma2.shape, \ 95 | 'Training and test covariances have different dimensions' 96 | 97 | diff = mu1 - mu2 98 | 99 | # Product might be almost singular 100 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 101 | if not np.isfinite(covmean).all(): 102 | msg = ('fid calculation produces singular product; ' 103 | 'adding %s to diagonal of cov estimates') % eps 104 | print(msg) 105 | offset = np.eye(sigma1.shape[0]) * eps 106 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 107 | 108 | # Numerical error might give slight imaginary component 109 | if np.iscomplexobj(covmean): 110 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 111 | m = np.max(np.abs(covmean.imag)) 112 | raise ValueError('Imaginary component {}'.format(m)) 113 | covmean = covmean.real 114 | 115 | tr_covmean = np.trace(covmean) 116 | 117 | return (diff.dot(diff) + np.trace(sigma1) 118 | + np.trace(sigma2) - 2 * tr_covmean) 119 | 120 | def compute_fid(self): 121 | if isinstance(self.cl_feat, list): 122 | self.cl_feat = np.concatenate(self.cl_feat, axis=0) 123 | self.cf_feat = np.concatenate(self.cf_feat, axis=0) 124 | cl_mu = np.mean(self.cl_feat, axis=0) 125 | cf_mu = np.mean(self.cf_feat, axis=0) 126 | cl_sigma = np.cov(self.cl_feat, rowvar=False) 127 | cf_sigma = np.cov(self.cf_feat, rowvar=False) 128 | return self.calculate_frechet_distance(cl_mu, 129 | cl_sigma, 130 | cf_mu, 131 | cf_sigma).item() 132 | 133 | def save_chunk_feature(self, output_path, exp_name, chunk, num_chunks): 134 | os.makedirs(osp.join(output_path, 'Results', exp_name, 'chunk-data'), exist_ok=True) 135 | output_path_clean = osp.join(output_path, 'Results', exp_name, 'chunk-data', 136 | f'clean_chunk-{chunk}_num-chunks-{num_chunks}.npy') 137 | output_path_cf = osp.join(output_path, 'Results', exp_name, 'chunk-data', 138 | f'cf_chunk-{chunk}_num-chunks-{num_chunks}.npy') 139 | self.cl_feat = np.concatenate(self.cl_feat, axis=0) 140 | self.cf_feat = np.concatenate(self.cf_feat, axis=0) 141 | np.save(output_path_clean, self.cl_feat) 142 | np.save(output_path_cf, self.cf_feat) 143 | 144 | def load_and_compute_fid(self, output_path, exp_name, num_chunks): 145 | # load clean and cf features 146 | cl_feat = np.empty((0, self.dims)) 147 | cf_feat = np.empty((0, self.dims)) 148 | 149 | for chunk in range(num_chunks): 150 | path_clean = osp.join(output_path, 'Results', exp_name, 'chunk-data', 151 | f'clean_chunk-{chunk}_num-chunks-{num_chunks}.npy') 152 | path_cf = osp.join(output_path, 'Results', exp_name, 'chunk-data', 153 | f'cf_chunk-{chunk}_num-chunks-{num_chunks}.npy') 154 | cl_feat = np.concatenate((cl_feat, np.load(path_clean)), 
axis=0) 155 | cf_feat = np.concatenate((cf_feat, np.load(path_cf)), axis=0) 156 | 157 | self.cl_feat = cl_feat 158 | self.cf_feat = cf_feat 159 | 160 | return self.compute_fid() 161 | -------------------------------------------------------------------------------- /core/resnet_vggface2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import pickle 6 | import math 7 | 8 | 9 | __all__ = ['ResNet', 'resnet50'] 10 | 11 | def conv3x3(in_planes, out_planes, stride=1): 12 | """3x3 convolution with padding""" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 14 | padding=1, bias=False) 15 | 16 | 17 | class BasicBlock(nn.Module): 18 | expansion = 1 19 | 20 | def __init__(self, inplanes, planes, stride=1, downsample=None): 21 | super(BasicBlock, self).__init__() 22 | self.conv1 = conv3x3(inplanes, planes, stride) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.relu = nn.ReLU(inplace=True) 25 | self.conv2 = conv3x3(planes, planes) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | self.downsample = downsample 28 | self.stride = stride 29 | 30 | def forward(self, x): 31 | residual = x 32 | 33 | out = self.conv1(x) 34 | out = self.bn1(out) 35 | out = self.relu(out) 36 | 37 | out = self.conv2(out) 38 | out = self.bn2(out) 39 | 40 | if self.downsample is not None: 41 | residual = self.downsample(x) 42 | 43 | out += residual 44 | out = self.relu(out) 45 | 46 | return out 47 | 48 | 49 | class Bottleneck(nn.Module): 50 | expansion = 4 51 | 52 | def __init__(self, inplanes, planes, stride=1, downsample=None): 53 | super(Bottleneck, self).__init__() 54 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) 55 | self.bn1 = nn.BatchNorm2d(planes) 56 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 59 | self.bn3 = nn.BatchNorm2d(planes * 4) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.downsample = downsample 62 | self.stride = stride 63 | 64 | def forward(self, x): 65 | residual = x 66 | 67 | out = self.conv1(x) 68 | out = self.bn1(out) 69 | out = self.relu(out) 70 | 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv3(out) 76 | out = self.bn3(out) 77 | 78 | if self.downsample is not None: 79 | residual = self.downsample(x) 80 | 81 | out += residual 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | 87 | class ResNet(nn.Module): 88 | 89 | def __init__(self, block, layers, num_classes=1000, 90 | layer=0): 91 | 92 | self.inplanes = 64 93 | super(ResNet, self).__init__() 94 | self.layer = layer 95 | 96 | self.register_buffer('mu', torch.tensor((91.4953, 103.8827, 131.0912)).view(1, -1, 1, 1)) # to normalize the images just like VGGFace2 97 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 98 | if layer == 0: # a little brute force stopping 99 | return 100 | self.bn1 = nn.BatchNorm2d(64) 101 | self.relu = nn.ReLU(inplace=True) 102 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) 103 | 104 | self.layer1 = self._make_layer(block, 64, layers[0]) 105 | if layer == 1: 106 | return 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | if layer == 2: 109 | return 110 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 111 | if layer == 3: 112 | 
return 113 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 114 | if layer == 4: 115 | return 116 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 117 | if layer == 5: 118 | return 119 | self.fc = nn.Linear(512 * block.expansion, num_classes) 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | nn.Conv2d(self.inplanes, planes * block.expansion, 126 | kernel_size=1, stride=stride, bias=False), 127 | nn.BatchNorm2d(planes * block.expansion), 128 | ) 129 | 130 | layers = [] 131 | layers.append(block(self.inplanes, planes, stride, downsample)) 132 | self.inplanes = planes * block.expansion 133 | for i in range(1, blocks): 134 | layers.append(block(self.inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = x[:, [2, 1, 0], :, :] # RGB -> BGR 140 | x = (x + 1) * 127.5 # make it to the 255 range 141 | x = x - self.mu # remove mean 142 | x = F.interpolate(x, size=(224, 224), mode='bilinear', 143 | align_corners=True) 144 | x = self.conv1(x) 145 | if self.layer == 0: 146 | return x 147 | x = self.bn1(x) 148 | x = self.relu(x) 149 | x = self.maxpool(x) 150 | 151 | x = self.layer1(x) 152 | if self.layer == 1: 153 | return x 154 | x = self.layer2(x) 155 | if self.layer == 2: 156 | return x 157 | x = self.layer3(x) 158 | if self.layer == 3: 159 | return x 160 | x = self.layer4(x) 161 | if self.layer == 4: 162 | return x 163 | 164 | x = self.avgpool(x) 165 | x = x.view(x.size(0), -1) 166 | 167 | if self.layer == 5: 168 | return x 169 | 170 | x = self.fc(x) 171 | return x 172 | 173 | def resnet50(**kwargs): 174 | """Constructs a ResNet-50 model. 175 | """ 176 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 177 | return model 178 | 179 | def load_state_dict(model, fname): 180 | """ 181 | Set parameters converted from Caffe models authors of VGGFace2 provide. 182 | See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/. 183 | Arguments: 184 | model: model 185 | fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle. 186 | """ 187 | with open(fname, 'rb') as f: 188 | weights = pickle.load(f, encoding='latin1') 189 | 190 | own_state = model.state_dict() 191 | for name, param in weights.items(): 192 | if name in own_state: 193 | try: 194 | own_state[name].copy_(torch.from_numpy(param)) 195 | except Exception: 196 | raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\ 197 | 'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size())) 198 | 199 | model.load_state_dict(own_state) 200 | -------------------------------------------------------------------------------- /core/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 
18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = 
get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 203 | opt.step() 204 | zero_master_grads(self.master_params) 205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 206 | self.lg_loss_scale += self.fp16_scale_growth 207 | return True 208 | 209 | def _optimize_normal(self, opt: th.optim.Optimizer): 210 | grad_norm, param_norm = self._compute_norms() 211 | logger.logkv_mean("grad_norm", grad_norm) 212 | logger.logkv_mean("param_norm", param_norm) 213 | opt.step() 214 | return True 215 | 216 | def _compute_norms(self, grad_scale=1.0): 217 | grad_norm = 0.0 218 | param_norm = 0.0 219 | for p in self.master_params: 220 | with th.no_grad(): 221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 222 | if p.grad is not None: 223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 225 | 226 | def 
master_params_to_state_dict(self, master_params): 227 | return master_params_to_state_dict( 228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 229 | ) 230 | 231 | def state_dict_to_master_params(self, state_dict): 232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 233 | 234 | 235 | def check_overflow(value): 236 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 237 | -------------------------------------------------------------------------------- /compute_CD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import itertools 5 | import numpy as np 6 | import pandas as pd 7 | import matplotlib.cm as cm 8 | import matplotlib.pyplot as plt 9 | import os.path as osp 10 | 11 | from PIL import Image 12 | from tqdm import tqdm 13 | from torch.utils import data 14 | from torchvision import transforms 15 | 16 | from eval_utils.oracle_metrics import OracleMetrics 17 | 18 | 19 | def arguments(): 20 | parser = argparse.ArgumentParser(description='FVA arguments.') 21 | parser.add_argument('--gpu', default='0', type=str, 22 | help='GPU id') 23 | parser.add_argument('--oracle-path', default='models/oracle.pth', type=str, 24 | help='Oracle path') 25 | parser.add_argument('--exp-name', required=True, type=str, 26 | help='Experiment Name') 27 | parser.add_argument('--output-path', required=True, type=str, 28 | help='Results Path') 29 | parser.add_argument('--celeba-path', required=True, type=str, 30 | help='CelebA path') 31 | 32 | return parser.parse_args() 33 | 34 | 35 | args = arguments() 36 | 37 | device = torch.device('cuda:' + args.gpu) 38 | 39 | # # load oracle 40 | oracle_metrics = OracleMetrics(weights_path=args.oracle_path, 41 | device=device) 42 | oracle_metrics.eval() 43 | 44 | # create dataset to read the counterfactual results images 45 | class CFDataset(): 46 | def __init__(self, path, exp_name): 47 | 48 | self.images = [] 49 | self.path = path 50 | self.exp_name = exp_name 51 | self.transform = transforms.Compose([ 52 | transforms.ToTensor(), 53 | transforms.Normalize([0.5, 0.5, 0.5], 54 | [0.5, 0.5, 0.5]) 55 | ]) 56 | 57 | for CL, CF in itertools.product(['CC', 'IC'], ['CCF']): 58 | self.images += [(CL, CF, I) for I in os.listdir(osp.join(path, 'Results', self.exp_name, CL, CF, 'CF'))] 59 | 60 | def __len__(self): 61 | return len(self.images) 62 | 63 | def switch(self, partition): 64 | if partition == 'C': 65 | LCF = ['CCF'] 66 | elif partition == 'I': 67 | LCF = ['ICF'] 68 | else: 69 | LCF = ['CCF', 'ICF'] 70 | 71 | self.images = [] 72 | 73 | for CL, CF in itertools.product(['CC', 'IC'], LCF): 74 | self.images += [(CL, CF, I) for I in os.listdir(osp.join(self.path, 'Results', self.exp_name, CL, CF, 'CF'))] 75 | 76 | def __getitem__(self, idx): 77 | CL, CF, I = self.images[idx] 78 | # get paths 79 | cl_path = osp.join(self.path, 'Original', 'Correct' if CL == 'CC' else 'Incorrect', I) 80 | cf_path = osp.join(self.path, 'Results', self.exp_name, CL, CF, 'CF', I) 81 | 82 | cl = self.load_img(cl_path) 83 | cf = self.load_img(cf_path) 84 | 85 | return cl, cf 86 | 87 | def load_img(self, path): 88 | with open(path, "rb") as f: 89 | img = Image.open(f) 90 | img = img.convert('RGB') 91 | return self.transform(img) 92 | 93 | 94 | 95 | CELEBAPATH = os.path.join(args.celeba_path, 'list_attr_celeba.csv') 96 | CELEBAPATHP = os.path.join(args.celeba_path, 'list_eval_partition.csv') 97 | # extract the names of the labels 98 | 99 | df 
= pd.read_csv(CELEBAPATH) 100 | p = pd.read_csv(CELEBAPATHP) 101 | labels = list(df.columns[1:]) 102 | 103 | df = df[p['partition'] == 0] # 1 is val, 0 train 104 | df.replace(-1, 0, inplace=True) 105 | 106 | corrs = np.zeros(40) 107 | 108 | for i in range(40): 109 | corrs[i] = np.corrcoef(df['Smiling'].to_numpy(), df.iloc[:, i + 1].to_numpy())[0, 1] 110 | 111 | df = pd.read_csv(CELEBAPATH) 112 | p = pd.read_csv(CELEBAPATHP) 113 | labels = list(df.columns[1:]) 114 | 115 | df = df[p['partition'] == 1] # 1 is val, 0 train 116 | df.replace(-1, 0, inplace=True) 117 | 118 | corrs2 = np.zeros(40) 119 | 120 | for i in range(40): 121 | corrs2[i] = np.corrcoef(df['Smiling'].to_numpy(), df.iloc[:, i + 1].to_numpy())[0, 1] 122 | 123 | diffs = np.zeros((2, 40)) 124 | 125 | for i in range(40): 126 | diffs[0, i] = df[df['Smiling'] == 1].iloc[:, i + 1].mean() 127 | diffs[1, i] = df[df['Smiling'] == 0].iloc[:, i + 1].mean() 128 | 129 | maindiff = np.abs(diffs[0] - diffs[1]) 130 | 131 | 132 | @torch.no_grad() 133 | def get_attrs_and_target_from_ds(path, exp_name, 134 | oracle, 135 | device): 136 | 137 | print('=' * 70) 138 | print('Evaluating data from:', path) 139 | print(' Experiment:', exp_name) 140 | dataset = CFDataset(path, exp_name) 141 | loader = data.DataLoader(dataset, batch_size=15, 142 | shuffle=False, 143 | num_workers=4, pin_memory=True) 144 | 145 | oracle_preds = {'cf': {'dist': [], 146 | 'pred': []}, 147 | 'cl': {'dist': [], 148 | 'pred': []}} 149 | 150 | for cl, cf in tqdm(loader): 151 | cl = cl.to(device, dtype=torch.float) 152 | cf = cf.to(device, dtype=torch.float) 153 | 154 | cl_o_dist = torch.sigmoid(oracle.oracle(cl)[1]) 155 | cf_o_dist = torch.sigmoid(oracle.oracle(cf)[1]) 156 | 157 | oracle_preds['cl']['dist'].append(cl_o_dist.cpu().numpy()) 158 | oracle_preds['cl']['pred'].append((cl_o_dist > 0.5).cpu().numpy()) 159 | oracle_preds['cf']['dist'].append(cf_o_dist.cpu().numpy()) 160 | oracle_preds['cf']['pred'].append((cf_o_dist > 0.5).cpu().numpy()) 161 | 162 | oracle_preds['cl']['dist'] = np.concatenate(oracle_preds['cl']['dist']) 163 | oracle_preds['cf']['dist'] = np.concatenate(oracle_preds['cf']['dist']) 164 | oracle_preds['cl']['pred'] = np.concatenate(oracle_preds['cl']['pred']) 165 | oracle_preds['cf']['pred'] = np.concatenate(oracle_preds['cf']['pred']) 166 | 167 | return oracle_preds 168 | 169 | 170 | def compute_CorrMetric(path, 171 | exp_name, 172 | oracle, 173 | device, 174 | query_label, 175 | corr, 176 | top=40, 177 | sorted=None, 178 | show=False, 179 | diff=True, 180 | remove_unchanged_oracle=False): 181 | 182 | oracle_preds = get_attrs_and_target_from_ds(path, exp_name, oracle, device) 183 | 184 | cf_pred = oracle_preds['cf']['pred'].astype('float') 185 | cl_pred = oracle_preds['cl']['pred'].astype('float') 186 | 187 | if diff: 188 | delta_query = cf_pred[:, query_label] - cl_pred[:, query_label] 189 | deltas = cf_pred - cl_pred 190 | else: 191 | delta_query = cf_pred[:, query_label] 192 | deltas = cf_pred 193 | 194 | if remove_unchanged_oracle: 195 | to_remove = cf_pred[:, query_label] != cl_pred[:, query_label] 196 | deltas = deltas[to_remove, :] 197 | delta_query = delta_query[to_remove] 198 | del to_remove 199 | 200 | print('Lenght:', len(deltas)) 201 | 202 | our_corrs = np.zeros(40) 203 | 204 | for i in range(40): 205 | our_corrs[i] = np.corrcoef(deltas[:, i], delta_query)[0, 1] 206 | 207 | if show: 208 | if sorted is None: 209 | plt.bar(np.arange(len(our_corrs))[:top] - 0.15, corr[:top], width=0.3, label='Correlations') 210 | 
plt.bar(np.arange(len(our_corrs))[:top] + 0.15, our_corrs[:top], width=0.3, label='Metric')
211 | plt.xticks(np.arange(len(our_corrs))[:top], our_corrs[:top], rotation=90)
212 | else:
213 | plt.bar(np.arange(len(our_corrs))[:top] - 0.15, corr[sorted][:top], width=0.3, label='Correlations')
214 | plt.bar(np.arange(len(our_corrs))[:top] + 0.15, our_corrs[sorted][:top], width=0.3, label='Metric')
215 | plt.xticks(np.arange(len(our_corrs))[:top], [our_corrs[i] for i in sorted][:top], rotation=90)
216 | 
217 | plt.legend()
218 | plt.show()
219 | 
220 | return our_corrs
221 | 
222 | 
223 | def plot_bar(data, labs, top, sorted):
224 | 
225 | r = 90
226 | f = 15
227 | n_items = len(data)
228 | eps = 1e-1
229 | x_base = np.arange(40)
230 | step = (1 - 2 * eps) / (2 * n_items + 1)
231 | width = 2 * step
232 | cmap = cm.get_cmap('viridis', 512)(np.linspace(0, 1, n_items))
233 | 
234 | def plot(x, d, l, c):
235 | plt.bar(x, d, width=width, label=l, color=c)
236 | 
237 | for i, (d, l) in enumerate(zip(data, labs)):
238 | c_x = x_base - 0.5 + eps + step * (2 * i + 1)
239 | c = [p.item() for p in cmap[i]]
240 | 
241 | if sorted is not None:
242 | d = d[sorted]
243 | 
244 | plot(c_x[:top], d[:top], l, c[:top])
245 | 
246 | plt.legend()
247 | plt.tight_layout()
248 | 
249 | if sorted is None:
250 | plt.xticks(x_base[:top], labels[:top], rotation=r, fontsize=f)
251 | else:
252 | plt.xticks(x_base[:top], [labels[i] for i in sorted][:top], rotation=r, fontsize=f)
253 | 
254 | plt.show()
255 | 
256 | 
257 | # get results from dataset
258 | 
259 | sorted = np.argsort(np.abs(corrs))[::-1]
260 | 
261 | results = compute_CorrMetric(args.output_path,
262 | args.exp_name,
263 | oracle_metrics,
264 | device,
265 | 31, # smile attribute
266 | corrs,
267 | top=40,
268 | sorted=sorted,
269 | show=False,
270 | diff=True,
271 | remove_unchanged_oracle=False)
272 | 
273 | print('CD Result:', np.sum(np.abs(results[sorted] - corrs[sorted])))
274 | 
275 | plot_bar([corrs, results],
276 | ['Correlation', 'Method'],
277 | 40, sorted)
--------------------------------------------------------------------------------
/core/train_util.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import os
4 | import glob
5 | 
6 | import blobfile as bf
7 | import torch as th
8 | import torch.distributed as dist
9 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
10 | from torch.optim import AdamW
11 | 
12 | from . import dist_util, logger
13 | from .fp16_util import MixedPrecisionTrainer
14 | from .nn import update_ema
15 | from .resample import LossAwareSampler, UniformSampler
16 | from .sample_utils import load_from_DDP_model
17 | 
18 | # For ImageNet experiments, this was a good default value.
19 | # We found that the lg_loss_scale quickly climbed to
20 | # 20-21 within the first ~1K steps of training.
21 | INITIAL_LOG_LOSS_SCALE = 20.0 22 | 23 | 24 | class TrainLoop: 25 | def __init__( 26 | self, 27 | *, 28 | model, 29 | diffusion, 30 | data, 31 | batch_size, 32 | microbatch, 33 | lr, 34 | ema_rate, 35 | log_interval, 36 | save_interval, 37 | resume_checkpoint, 38 | use_fp16=False, 39 | fp16_scale_growth=1e-3, 40 | schedule_sampler=None, 41 | weight_decay=0.0, 42 | lr_anneal_steps=0, 43 | ): 44 | self.model = model 45 | self.diffusion = diffusion 46 | self.data = data 47 | self.batch_size = batch_size 48 | self.microbatch = microbatch if microbatch > 0 else batch_size 49 | self.lr = lr 50 | self.ema_rate = ( 51 | [ema_rate] 52 | if isinstance(ema_rate, float) 53 | else [float(x) for x in ema_rate.split(",")] 54 | ) 55 | self.log_interval = log_interval 56 | self.save_interval = save_interval 57 | self.resume_checkpoint = resume_checkpoint 58 | self.use_fp16 = use_fp16 59 | self.fp16_scale_growth = fp16_scale_growth 60 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 61 | self.weight_decay = weight_decay 62 | self.lr_anneal_steps = lr_anneal_steps 63 | 64 | self.step = 0 65 | self.resume_step = 0 66 | self.global_batch = self.batch_size * dist.get_world_size() 67 | 68 | self.sync_cuda = th.cuda.is_available() 69 | 70 | self._load_and_sync_parameters() 71 | self.mp_trainer = MixedPrecisionTrainer( 72 | model=self.model, 73 | use_fp16=self.use_fp16, 74 | fp16_scale_growth=fp16_scale_growth, 75 | ) 76 | 77 | self.opt = AdamW( 78 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay 79 | ) 80 | if self.resume_step: 81 | self._load_optimizer_state() 82 | # Model was resumed, either due to a restart or a checkpoint 83 | # being specified at the command line. 84 | self.ema_params = [ 85 | self._load_ema_parameters(rate) for rate in self.ema_rate 86 | ] 87 | else: 88 | self.ema_params = [ 89 | copy.deepcopy(self.mp_trainer.master_params) 90 | for _ in range(len(self.ema_rate)) 91 | ] 92 | 93 | if th.cuda.is_available(): 94 | self.use_ddp = True 95 | self.ddp_model = DDP( 96 | self.model, 97 | device_ids=[dist_util.dev()], 98 | output_device=dist_util.dev(), 99 | broadcast_buffers=False, 100 | bucket_cap_mb=128, 101 | find_unused_parameters=False, 102 | ) 103 | else: 104 | if dist.get_world_size() > 1: 105 | logger.warn( 106 | "Distributed training requires CUDA. " 107 | "Gradients will not be synchronized properly!" 
108 | ) 109 | self.use_ddp = False 110 | self.ddp_model = self.model 111 | 112 | def _load_and_sync_parameters(self): 113 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 114 | 115 | if resume_checkpoint: 116 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 117 | if dist.get_rank() == 0: 118 | print(f"loading model from checkpoint: {resume_checkpoint}...") 119 | # self.model.load_state_dict( 120 | # dist_util.load_state_dict( 121 | # resume_checkpoint, map_location=dist_util.dev() 122 | # ) 123 | # ) 124 | self.model.load_state_dict( 125 | load_from_DDP_model( 126 | th.load(resume_checkpoint, map_location=dist_util.dev()) 127 | ) 128 | ) 129 | print('done') 130 | 131 | dist_util.sync_params(self.model.parameters()) 132 | 133 | def _load_ema_parameters(self, rate): 134 | ema_params = copy.deepcopy(self.mp_trainer.master_params) 135 | 136 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 137 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 138 | if ema_checkpoint: 139 | if dist.get_rank() == 0: 140 | print(f"loading EMA from checkpoint: {ema_checkpoint}...") 141 | # state_dict = dist_util.load_state_dict( 142 | # ema_checkpoint, map_location=dist_util.dev() 143 | # ) 144 | state_dict = load_from_DDP_model( 145 | th.load(ema_checkpoint, map_location=dist_util.dev()) 146 | ) 147 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) 148 | print('done') 149 | 150 | dist_util.sync_params(ema_params) 151 | return ema_params 152 | 153 | def _load_optimizer_state(self): 154 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 155 | opt_checkpoint = bf.join( 156 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 157 | ) 158 | if bf.exists(opt_checkpoint): 159 | print(f"loading optimizer state from checkpoint: {opt_checkpoint}") 160 | state_dict = dist_util.load_state_dict( 161 | opt_checkpoint, map_location=dist_util.dev() 162 | ) 163 | self.opt.load_state_dict(state_dict) 164 | 165 | def run_loop(self): 166 | print('Running training loop') 167 | while ( 168 | not self.lr_anneal_steps 169 | or self.step + self.resume_step < self.lr_anneal_steps 170 | ): 171 | batch, cond = next(self.data) 172 | self.run_step(batch, cond) 173 | if self.step % self.log_interval == 0: 174 | # logger.dumpkvs() 175 | print('Step', self.step) 176 | if self.step % self.save_interval == 0: 177 | self.save() 178 | # Run for a finite amount of time in integration tests. 179 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 180 | return 181 | self.step += 1 182 | # Save the last checkpoint if it wasn't already saved. 
183 | if (self.step - 1) % self.save_interval != 0: 184 | self.save() 185 | 186 | def run_step(self, batch, cond): 187 | self.forward_backward(batch, cond) 188 | took_step = self.mp_trainer.optimize(self.opt) 189 | if took_step: 190 | self._update_ema() 191 | self._anneal_lr() 192 | # self.log_step() 193 | 194 | def forward_backward(self, batch, cond): 195 | self.mp_trainer.zero_grad() 196 | for i in range(0, batch.shape[0], self.microbatch): 197 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 198 | micro_cond = { 199 | k: v[i : i + self.microbatch].to(dist_util.dev()) 200 | for k, v in cond.items() 201 | } 202 | last_batch = (i + self.microbatch) >= batch.shape[0] 203 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 204 | 205 | compute_losses = functools.partial( 206 | self.diffusion.training_losses, 207 | self.ddp_model, 208 | micro, 209 | t, 210 | model_kwargs=micro_cond, 211 | ) 212 | 213 | if last_batch or not self.use_ddp: 214 | losses = compute_losses() 215 | else: 216 | with self.ddp_model.no_sync(): 217 | losses = compute_losses() 218 | 219 | if isinstance(self.schedule_sampler, LossAwareSampler): 220 | self.schedule_sampler.update_with_local_losses( 221 | t, losses["loss"].detach() 222 | ) 223 | 224 | loss = (losses["loss"] * weights).mean() 225 | # log_loss_dict( 226 | # self.diffusion, t, {k: v * weights for k, v in losses.items()} 227 | # ) 228 | self.mp_trainer.backward(loss) 229 | 230 | def _update_ema(self): 231 | for rate, params in zip(self.ema_rate, self.ema_params): 232 | update_ema(params, self.mp_trainer.master_params, rate=rate) 233 | 234 | def _anneal_lr(self): 235 | if not self.lr_anneal_steps: 236 | return 237 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 238 | lr = self.lr * (1 - frac_done) 239 | for param_group in self.opt.param_groups: 240 | param_group["lr"] = lr 241 | 242 | def log_step(self): 243 | logger.logkv("step", self.step + self.resume_step) 244 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 245 | 246 | def save(self): 247 | def save_checkpoint(rate, params): 248 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 249 | if dist.get_rank() == 0: 250 | print(f"saving model {rate}...") 251 | if not rate: 252 | filename = f"model{(self.step+self.resume_step):06d}.pt" 253 | else: 254 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 255 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 256 | th.save(state_dict, f) 257 | 258 | # delete old checkpoints 259 | if dist.get_rank() == 0: 260 | for f in glob.glob(os.path.join(get_blob_logdir(), '*.pt')): 261 | os.remove(os.path.join(get_blob_logdir(), f)) 262 | 263 | save_checkpoint(0, self.mp_trainer.master_params) 264 | for rate, params in zip(self.ema_rate, self.ema_params): 265 | save_checkpoint(rate, params) 266 | 267 | if dist.get_rank() == 0: 268 | with bf.BlobFile( 269 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 270 | "wb", 271 | ) as f: 272 | th.save(self.opt.state_dict(), f) 273 | 274 | dist.barrier() 275 | 276 | 277 | def parse_resume_step_from_filename(filename): 278 | """ 279 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 280 | checkpoint's number of steps. 
281 | """ 282 | split = filename.split("model") 283 | if len(split) < 2: 284 | return 0 285 | split1 = split[-1].split(".")[0] 286 | try: 287 | return int(split1) 288 | except ValueError: 289 | return 0 290 | 291 | 292 | def get_blob_logdir(): 293 | # You can change this to be a separate path to save checkpoints to 294 | # a blobstore or some external drive. 295 | return logger.get_dir() 296 | 297 | 298 | def find_resume_checkpoint(): 299 | # On your infrastructure, you may want to override this to automatically 300 | # discover the latest checkpoint on your blob storage, etc. 301 | return None 302 | 303 | 304 | def find_ema_checkpoint(main_checkpoint, step, rate): 305 | if main_checkpoint is None: 306 | return None 307 | filename = f"ema_{rate}_{(step):06d}.pt" 308 | path = bf.join(bf.dirname(main_checkpoint), filename) 309 | if bf.exists(path): 310 | return path 311 | return None 312 | 313 | 314 | def log_loss_dict(diffusion, ts, losses): 315 | for key, values in losses.items(): 316 | logger.logkv_mean(key, values.mean().item()) 317 | # Log the quantiles (four quartiles, in particular). 318 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 319 | quartile = int(4 * sub_t / diffusion.num_timesteps) 320 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 321 | -------------------------------------------------------------------------------- /core/image_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import h5py 4 | import torch 5 | import random 6 | import numpy as np 7 | import pandas as pd 8 | import blobfile as bf 9 | 10 | from os import path as osp 11 | from PIL import Image 12 | from mpi4py import MPI 13 | from torchvision import transforms, datasets 14 | from torch.utils.data import DataLoader, Dataset 15 | 16 | 17 | # ============================================================================ 18 | # ImageFolder dataloader 19 | # ============================================================================ 20 | 21 | 22 | def load_data( 23 | *, 24 | data_dir, 25 | batch_size, 26 | image_size, 27 | class_cond=False, 28 | deterministic=False, 29 | random_crop=False, 30 | random_flip=True, 31 | ): 32 | """ 33 | For a dataset, create a generator over (images, kwargs) pairs. 34 | 35 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 36 | more keys, each of which map to a batched Tensor of their own. 37 | The kwargs dict can be used for class labels, in which case the key is "y" 38 | and the values are integer tensors of class labels. 39 | 40 | :param data_dir: a dataset directory. 41 | :param batch_size: the batch size of each returned pair. 42 | :param image_size: the size to which images are resized. 43 | :param class_cond: if True, include a "y" key in returned dicts for class 44 | label. If classes are not available and this is true, an 45 | exception will be raised. 46 | :param deterministic: if True, yield results in a deterministic order. 47 | :param random_crop: if True, randomly crop the images for augmentation. 48 | :param random_flip: if True, randomly flip the images for augmentation. 49 | """ 50 | if not data_dir: 51 | raise ValueError("unspecified data directory") 52 | all_files = _list_image_files_recursively(data_dir) 53 | classes = None 54 | if class_cond: 55 | # Assume classes are the first part of the filename, 56 | # before an underscore. 
57 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 58 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 59 | classes = [sorted_classes[x] for x in class_names] 60 | dataset = ImageDataset( 61 | image_size, 62 | all_files, 63 | classes=classes, 64 | shard=MPI.COMM_WORLD.Get_rank(), 65 | num_shards=MPI.COMM_WORLD.Get_size(), 66 | random_crop=random_crop, 67 | random_flip=random_flip, 68 | ) 69 | if deterministic: 70 | loader = DataLoader( 71 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 72 | ) 73 | else: 74 | loader = DataLoader( 75 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 76 | ) 77 | while True: 78 | yield from loader 79 | 80 | 81 | def _list_image_files_recursively(data_dir): 82 | results = [] 83 | for entry in sorted(bf.listdir(data_dir)): 84 | full_path = bf.join(data_dir, entry) 85 | ext = entry.split(".")[-1] 86 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 87 | results.append(full_path) 88 | elif bf.isdir(full_path): 89 | results.extend(_list_image_files_recursively(full_path)) 90 | return results 91 | 92 | 93 | class ImageDataset(Dataset): 94 | def __init__( 95 | self, 96 | resolution, 97 | image_paths, 98 | classes=None, 99 | shard=0, 100 | num_shards=1, 101 | random_crop=False, 102 | random_flip=True, 103 | ): 104 | super().__init__() 105 | self.resolution = resolution 106 | self.local_images = image_paths[shard:][::num_shards] 107 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 108 | self.random_crop = random_crop 109 | self.random_flip = random_flip 110 | 111 | def __len__(self): 112 | return len(self.local_images) 113 | 114 | def __getitem__(self, idx): 115 | path = self.local_images[idx] 116 | with bf.BlobFile(path, "rb") as f: 117 | pil_image = Image.open(f) 118 | pil_image.load() 119 | pil_image = pil_image.convert("RGB") 120 | 121 | if self.random_crop: 122 | arr = random_crop_arr(pil_image, self.resolution) 123 | else: 124 | arr = center_crop_arr(pil_image, self.resolution) 125 | 126 | if self.random_flip and random.random() < 0.5: 127 | arr = arr[:, ::-1] 128 | 129 | arr = arr.astype(np.float32) / 127.5 - 1 130 | 131 | out_dict = {} 132 | if self.local_classes is not None: 133 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 134 | return np.transpose(arr, [2, 0, 1]), out_dict 135 | 136 | 137 | def center_crop_arr(pil_image, image_size): 138 | # We are not on a new enough PIL to support the `reducing_gap` 139 | # argument, which uses BOX downsampling at powers of two first. 140 | # Thus, we do it by hand to improve downsample quality. 
141 | while min(*pil_image.size) >= 2 * image_size: 142 | pil_image = pil_image.resize( 143 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 144 | ) 145 | 146 | scale = image_size / min(*pil_image.size) 147 | pil_image = pil_image.resize( 148 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 149 | ) 150 | 151 | arr = np.array(pil_image) 152 | crop_y = (arr.shape[0] - image_size) // 2 153 | crop_x = (arr.shape[1] - image_size) // 2 154 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 155 | 156 | 157 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 158 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 159 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 160 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 161 | 162 | # We are not on a new enough PIL to support the `reducing_gap` 163 | # argument, which uses BOX downsampling at powers of two first. 164 | # Thus, we do it by hand to improve downsample quality. 165 | while min(*pil_image.size) >= 2 * smaller_dim_size: 166 | pil_image = pil_image.resize( 167 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 168 | ) 169 | 170 | scale = smaller_dim_size / min(*pil_image.size) 171 | pil_image = pil_image.resize( 172 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 173 | ) 174 | 175 | arr = np.array(pil_image) 176 | crop_y = random.randrange(arr.shape[0] - image_size + 1) 177 | crop_x = random.randrange(arr.shape[1] - image_size + 1) 178 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 179 | 180 | 181 | # ============================================================================ 182 | # CelebA dataloader 183 | # ============================================================================ 184 | 185 | 186 | def load_data_celeba( 187 | *, 188 | data_dir, 189 | batch_size, 190 | image_size, 191 | partition='train', 192 | class_cond=False, 193 | deterministic=False, 194 | random_crop=False, 195 | random_flip=True 196 | ): 197 | """ 198 | For a dataset, create a generator over (images, kwargs) pairs. 199 | 200 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 201 | more keys, each of which map to a batched Tensor of their own. 202 | The kwargs dict can be used for class labels, in which case the key is "y" 203 | and the values are integer tensors of class labels. 204 | 205 | :param data_dir: a dataset directory. 206 | :param batch_size: the batch size of each returned pair. 207 | :param image_size: the size to which images are resized. 208 | :param class_cond: if True, include a "y" key in returned dicts for class 209 | label. If classes are not available and this is true, an 210 | exception will be raised. 211 | :param deterministic: if True, yield results in a deterministic order. 212 | :param random_crop: if True, randomly crop the images for augmentation. 213 | :param random_flip: if True, randomly flip the images for augmentation. 
214 | """ 215 | if not data_dir: 216 | raise ValueError("unspecified data directory") 217 | 218 | dataset = CelebADataset( 219 | image_size, 220 | data_dir, 221 | partition, 222 | shard=MPI.COMM_WORLD.Get_rank(), 223 | num_shards=MPI.COMM_WORLD.Get_size(), 224 | class_cond=class_cond, 225 | random_crop=random_crop, 226 | random_flip=random_flip, 227 | ) 228 | if deterministic: 229 | loader = DataLoader( 230 | dataset, batch_size=batch_size, shuffle=False, num_workers=5, drop_last=True 231 | ) 232 | else: 233 | loader = DataLoader( 234 | dataset, batch_size=batch_size, shuffle=True, num_workers=5, drop_last=True 235 | ) 236 | while True: 237 | yield from loader 238 | 239 | 240 | class CelebADataset(Dataset): 241 | def __init__( 242 | self, 243 | image_size, 244 | data_dir, 245 | partition, 246 | shard=0, 247 | num_shards=1, 248 | class_cond=False, 249 | random_crop=True, 250 | random_flip=True, 251 | query_label=-1, 252 | normalize=True, 253 | ): 254 | partition_df = pd.read_csv(osp.join(data_dir, 'list_eval_partition.csv')) 255 | self.data_dir = data_dir 256 | data = pd.read_csv(osp.join(data_dir, 'list_attr_celeba.csv')) 257 | 258 | if partition == 'train': 259 | partition = 0 260 | elif partition == 'val': 261 | partition = 1 262 | elif partition == 'test': 263 | partition = 2 264 | else: 265 | raise ValueError(f'Unkown partition {partition}') 266 | 267 | self.data = data[partition_df['partition'] == partition] 268 | self.data = self.data[shard::num_shards] 269 | self.data.reset_index(inplace=True) 270 | self.data.replace(-1, 0, inplace=True) 271 | 272 | self.transform = transforms.Compose([ 273 | transforms.Resize(image_size), 274 | transforms.RandomHorizontalFlip() if random_flip else lambda x: x, 275 | transforms.CenterCrop(image_size), 276 | transforms.RandomResizedCrop(image_size, (0.95, 1.0)) if random_crop else lambda x: x, 277 | transforms.ToTensor(), 278 | transforms.Normalize([0.5, 0.5, 0.5], 279 | [0.5, 0.5, 0.5]) if normalize else lambda x: x 280 | ]) 281 | 282 | self.query = query_label 283 | self.class_cond = class_cond 284 | 285 | def __len__(self): 286 | return len(self.data) 287 | 288 | def __getitem__(self, idx): 289 | sample = self.data.iloc[idx, :] 290 | labels = sample[2:].to_numpy() 291 | if self.query != -1: 292 | labels = int(labels[self.query]) 293 | else: 294 | labels = torch.from_numpy(labels.astype('float32')) 295 | img_file = sample['image_id'] 296 | 297 | with open(osp.join(self.data_dir, 'img_align_celeba', img_file), "rb") as f: 298 | img = Image.open(f) 299 | img = img.convert('RGB') 300 | 301 | img = self.transform(img) 302 | 303 | if self.query != -1: 304 | return img, labels 305 | 306 | if self.class_cond: 307 | return img, {'y': labels} 308 | else: 309 | return img, {} 310 | 311 | 312 | class CelebAMiniVal(CelebADataset): 313 | def __init__( 314 | self, 315 | image_size, 316 | data_dir, 317 | partition=None, 318 | shard=0, 319 | num_shards=1, 320 | class_cond=False, 321 | random_crop=True, 322 | random_flip=True, 323 | query_label=-1, 324 | normalize=True, 325 | ): 326 | self.data = pd.read_csv('utils/minival.csv').iloc[:, 1:] 327 | self.data = self.data[shard::num_shards] 328 | self.image_size = image_size 329 | self.transform = transforms.Compose([ 330 | transforms.Resize(image_size), 331 | transforms.RandomHorizontalFlip() if random_flip else lambda x: x, 332 | transforms.CenterCrop(image_size), 333 | transforms.RandomResizedCrop(image_size, (0.95, 1.0)) if random_crop else lambda x: x, 334 | transforms.ToTensor(), 335 | transforms.Normalize([0.5, 
0.5, 0.5], 336 | [0.5, 0.5, 0.5]) if normalize else lambda x: x, 337 | ]) 338 | self.data_dir = data_dir 339 | self.class_cond = class_cond 340 | self.query = query_label 341 | -------------------------------------------------------------------------------- /eval_utils/fid_inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=(DEFAULT_BLOCK_INDEX,), 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | 39 | Parameters 40 | ---------- 41 | output_blocks : list of int 42 | Indices of blocks to return features of. Possible values are: 43 | - 0: corresponds to output of first max pooling 44 | - 1: corresponds to output of second max pooling 45 | - 2: corresponds to output which is fed to aux classifier 46 | - 3: corresponds to output of final average pooling 47 | resize_input : bool 48 | If true, bilinearly resizes input to width and height 299 before 49 | feeding input to model. As the network without fully connected 50 | layers is fully convolutional, it should be able to handle inputs 51 | of arbitrary size, so resizing might not be strictly needed 52 | normalize_input : bool 53 | If true, scales the input from range (0, 1) to the range the 54 | pretrained Inception network expects, namely (-1, 1) 55 | requires_grad : bool 56 | If true, parameters of the model require gradients. Possibly useful 57 | for finetuning the network 58 | use_fid_inception : bool 59 | If true, uses the pretrained Inception model used in Tensorflow's 60 | FID implementation. If false, uses the pretrained Inception model 61 | available in torchvision. The FID Inception model has different 62 | weights and a slightly different structure from torchvision's 63 | Inception model. If you want to compute FID scores, you are 64 | strongly advised to set this parameter to true to get comparable 65 | results. 
66 | """ 67 | super(InceptionV3, self).__init__() 68 | 69 | self.resize_input = resize_input 70 | self.normalize_input = normalize_input 71 | self.output_blocks = sorted(output_blocks) 72 | self.last_needed_block = max(output_blocks) 73 | 74 | assert self.last_needed_block <= 3, \ 75 | 'Last possible output block index is 3' 76 | 77 | self.blocks = nn.ModuleList() 78 | 79 | if use_fid_inception: 80 | inception = fid_inception_v3() 81 | else: 82 | inception = _inception_v3(pretrained=True) 83 | 84 | # Block 0: input to maxpool1 85 | block0 = [ 86 | inception.Conv2d_1a_3x3, 87 | inception.Conv2d_2a_3x3, 88 | inception.Conv2d_2b_3x3, 89 | nn.MaxPool2d(kernel_size=3, stride=2) 90 | ] 91 | self.blocks.append(nn.Sequential(*block0)) 92 | 93 | # Block 1: maxpool1 to maxpool2 94 | if self.last_needed_block >= 1: 95 | block1 = [ 96 | inception.Conv2d_3b_1x1, 97 | inception.Conv2d_4a_3x3, 98 | nn.MaxPool2d(kernel_size=3, stride=2) 99 | ] 100 | self.blocks.append(nn.Sequential(*block1)) 101 | 102 | # Block 2: maxpool2 to aux classifier 103 | if self.last_needed_block >= 2: 104 | block2 = [ 105 | inception.Mixed_5b, 106 | inception.Mixed_5c, 107 | inception.Mixed_5d, 108 | inception.Mixed_6a, 109 | inception.Mixed_6b, 110 | inception.Mixed_6c, 111 | inception.Mixed_6d, 112 | inception.Mixed_6e, 113 | ] 114 | self.blocks.append(nn.Sequential(*block2)) 115 | 116 | # Block 3: aux classifier to final avgpool 117 | if self.last_needed_block >= 3: 118 | block3 = [ 119 | inception.Mixed_7a, 120 | inception.Mixed_7b, 121 | inception.Mixed_7c, 122 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 123 | ] 124 | self.blocks.append(nn.Sequential(*block3)) 125 | 126 | for param in self.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | def forward(self, inp): 130 | """Get Inception feature maps 131 | 132 | Parameters 133 | ---------- 134 | inp : torch.autograd.Variable 135 | Input tensor of shape Bx3xHxW. Values are expected to be in 136 | range (0, 1) 137 | 138 | Returns 139 | ------- 140 | List of torch.autograd.Variable, corresponding to the selected output 141 | block, sorted ascending by index 142 | """ 143 | outp = [] 144 | x = inp 145 | 146 | if self.resize_input: 147 | x = F.interpolate(x, 148 | size=(299, 299), 149 | mode='bilinear', 150 | align_corners=False) 151 | 152 | if self.normalize_input: 153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 154 | 155 | for idx, block in enumerate(self.blocks): 156 | x = block(x) 157 | if idx in self.output_blocks: 158 | outp.append(x) 159 | 160 | if idx == self.last_needed_block: 161 | break 162 | 163 | return outp 164 | 165 | 166 | def _inception_v3(*args, **kwargs): 167 | """Wraps `torchvision.models.inception_v3` 168 | 169 | Skips default weight inititialization if supported by torchvision version. 170 | See https://github.com/mseitzer/pytorch-fid/issues/28. 171 | """ 172 | try: 173 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 174 | except ValueError: 175 | # Just a caution against weird version strings 176 | version = (0,) 177 | 178 | if version >= (0, 6): 179 | kwargs['init_weights'] = False 180 | 181 | return torchvision.models.inception_v3(*args, **kwargs) 182 | 183 | 184 | def fid_inception_v3(): 185 | """Build pretrained Inception model for FID computation 186 | 187 | The Inception model for FID computation uses a different set of weights 188 | and has a slightly different structure than torchvision's Inception. 
189 | 190 | This method first constructs torchvision's Inception and then patches the 191 | necessary parts that are different in the FID Inception model. 192 | """ 193 | inception = _inception_v3(num_classes=1008, 194 | aux_logits=False, 195 | pretrained=False) 196 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 197 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 198 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 199 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 200 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 201 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 202 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 203 | inception.Mixed_7b = FIDInceptionE_1(1280) 204 | inception.Mixed_7c = FIDInceptionE_2(2048) 205 | 206 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 207 | inception.load_state_dict(state_dict) 208 | return inception 209 | 210 | 211 | class FIDInceptionA(torchvision.models.inception.InceptionA): 212 | """InceptionA block patched for FID computation""" 213 | def __init__(self, in_channels, pool_features): 214 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 215 | 216 | def forward(self, x): 217 | branch1x1 = self.branch1x1(x) 218 | 219 | branch5x5 = self.branch5x5_1(x) 220 | branch5x5 = self.branch5x5_2(branch5x5) 221 | 222 | branch3x3dbl = self.branch3x3dbl_1(x) 223 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 224 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 225 | 226 | # Patch: Tensorflow's average pool does not use the padded zero's in 227 | # its average calculation 228 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 229 | count_include_pad=False) 230 | branch_pool = self.branch_pool(branch_pool) 231 | 232 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 233 | return torch.cat(outputs, 1) 234 | 235 | 236 | class FIDInceptionC(torchvision.models.inception.InceptionC): 237 | """InceptionC block patched for FID computation""" 238 | def __init__(self, in_channels, channels_7x7): 239 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 240 | 241 | def forward(self, x): 242 | branch1x1 = self.branch1x1(x) 243 | 244 | branch7x7 = self.branch7x7_1(x) 245 | branch7x7 = self.branch7x7_2(branch7x7) 246 | branch7x7 = self.branch7x7_3(branch7x7) 247 | 248 | branch7x7dbl = self.branch7x7dbl_1(x) 249 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 250 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 251 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 252 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 253 | 254 | # Patch: Tensorflow's average pool does not use the padded zero's in 255 | # its average calculation 256 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 257 | count_include_pad=False) 258 | branch_pool = self.branch_pool(branch_pool) 259 | 260 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 261 | return torch.cat(outputs, 1) 262 | 263 | 264 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 265 | """First InceptionE block patched for FID computation""" 266 | def __init__(self, in_channels): 267 | super(FIDInceptionE_1, self).__init__(in_channels) 268 | 269 | def forward(self, x): 270 | branch1x1 = self.branch1x1(x) 271 | 272 | branch3x3 = self.branch3x3_1(x) 273 | branch3x3 = [ 274 | self.branch3x3_2a(branch3x3), 275 | self.branch3x3_2b(branch3x3), 276 | ] 277 | branch3x3 = torch.cat(branch3x3, 1) 278 | 279 | branch3x3dbl = 
self.branch3x3dbl_1(x) 280 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 281 | branch3x3dbl = [ 282 | self.branch3x3dbl_3a(branch3x3dbl), 283 | self.branch3x3dbl_3b(branch3x3dbl), 284 | ] 285 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 286 | 287 | # Patch: Tensorflow's average pool does not use the padded zero's in 288 | # its average calculation 289 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 290 | count_include_pad=False) 291 | branch_pool = self.branch_pool(branch_pool) 292 | 293 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 294 | return torch.cat(outputs, 1) 295 | 296 | 297 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 298 | """Second InceptionE block patched for FID computation""" 299 | def __init__(self, in_channels): 300 | super(FIDInceptionE_2, self).__init__(in_channels) 301 | 302 | def forward(self, x): 303 | branch1x1 = self.branch1x1(x) 304 | 305 | branch3x3 = self.branch3x3_1(x) 306 | branch3x3 = [ 307 | self.branch3x3_2a(branch3x3), 308 | self.branch3x3_2b(branch3x3), 309 | ] 310 | branch3x3 = torch.cat(branch3x3, 1) 311 | 312 | branch3x3dbl = self.branch3x3dbl_1(x) 313 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 314 | branch3x3dbl = [ 315 | self.branch3x3dbl_3a(branch3x3dbl), 316 | self.branch3x3dbl_3b(branch3x3dbl), 317 | ] 318 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 319 | 320 | # Patch: The FID Inception model uses max pooling instead of average 321 | # pooling. This is likely an error in this specific Inception 322 | # implementation, as other Inception models use average pooling here 323 | # (which matches the description in the paper). 324 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 325 | branch_pool = self.branch_pool(branch_pool) 326 | 327 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 328 | return torch.cat(outputs, 1) 329 | -------------------------------------------------------------------------------- /core/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . import gaussian_diffusion as gd 5 | from .respace import SpacedDiffusion, space_timesteps 6 | from .unet import SuperResModel, UNetModel, EncoderUNetModel 7 | 8 | NUM_CLASSES = 1000 9 | 10 | 11 | def diffusion_defaults(): 12 | """ 13 | Defaults for image and classifier training. 14 | """ 15 | return dict( 16 | learn_sigma=False, 17 | diffusion_steps=1000, 18 | noise_schedule="linear", 19 | timestep_respacing="", 20 | use_kl=False, 21 | predict_xstart=False, 22 | rescale_timesteps=False, 23 | rescale_learned_sigmas=False, 24 | ) 25 | 26 | 27 | def classifier_defaults(): 28 | """ 29 | Defaults for classifier models. 30 | """ 31 | return dict( 32 | image_size=64, 33 | out_channels=1000, 34 | classifier_use_fp16=False, 35 | classifier_width=128, 36 | classifier_depth=2, 37 | classifier_attention_resolutions="32,16,8", # 16 38 | classifier_use_scale_shift_norm=True, # False 39 | classifier_resblock_updown=True, # False 40 | classifier_pool="attention", 41 | ) 42 | 43 | 44 | def model_and_diffusion_defaults(): 45 | """ 46 | Defaults for image training. 
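    These defaults are meant to be exposed as command-line flags and fed back into
    create_model_and_diffusion; roughly (this mirrors how main.py wires it up):

        parser = argparse.ArgumentParser()
        add_dict_to_argparser(parser, model_and_diffusion_defaults())
        args = parser.parse_args()
        model, diffusion = create_model_and_diffusion(
            **args_to_dict(args, model_and_diffusion_defaults().keys()))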
47 | """ 48 | res = dict( 49 | image_size=64, 50 | num_channels=128, 51 | num_res_blocks=2, 52 | num_heads=4, 53 | num_heads_upsample=-1, 54 | num_head_channels=-1, 55 | attention_resolutions="16,8", 56 | channel_mult="", 57 | dropout=0.0, 58 | class_cond=False, 59 | use_checkpoint=False, 60 | use_scale_shift_norm=True, 61 | resblock_updown=False, 62 | use_fp16=False, 63 | use_new_attention_order=False, 64 | ) 65 | res.update(diffusion_defaults()) 66 | return res 67 | 68 | 69 | def classifier_and_diffusion_defaults(): 70 | res = classifier_defaults() 71 | res.update(diffusion_defaults()) 72 | return res 73 | 74 | 75 | def create_model_and_diffusion( 76 | image_size, 77 | class_cond, 78 | learn_sigma, 79 | num_channels, 80 | num_res_blocks, 81 | channel_mult, 82 | num_heads, 83 | num_head_channels, 84 | num_heads_upsample, 85 | attention_resolutions, 86 | dropout, 87 | diffusion_steps, 88 | noise_schedule, 89 | timestep_respacing, 90 | use_kl, 91 | predict_xstart, 92 | rescale_timesteps, 93 | rescale_learned_sigmas, 94 | use_checkpoint, 95 | use_scale_shift_norm, 96 | resblock_updown, 97 | use_fp16, 98 | use_new_attention_order, 99 | num_classes=1000, 100 | multiclass=False 101 | ): 102 | model = create_model( 103 | image_size, 104 | num_channels, 105 | num_res_blocks, 106 | channel_mult=channel_mult, 107 | learn_sigma=learn_sigma, 108 | class_cond=class_cond, 109 | use_checkpoint=use_checkpoint, 110 | attention_resolutions=attention_resolutions, 111 | num_heads=num_heads, 112 | num_head_channels=num_head_channels, 113 | num_heads_upsample=num_heads_upsample, 114 | use_scale_shift_norm=use_scale_shift_norm, 115 | dropout=dropout, 116 | resblock_updown=resblock_updown, 117 | use_fp16=use_fp16, 118 | use_new_attention_order=use_new_attention_order, 119 | num_classes=num_classes, 120 | multiclass=multiclass 121 | ) 122 | diffusion = create_gaussian_diffusion( 123 | steps=diffusion_steps, 124 | learn_sigma=learn_sigma, 125 | noise_schedule=noise_schedule, 126 | use_kl=use_kl, 127 | predict_xstart=predict_xstart, 128 | rescale_timesteps=rescale_timesteps, 129 | rescale_learned_sigmas=rescale_learned_sigmas, 130 | timestep_respacing=timestep_respacing, 131 | ) 132 | return model, diffusion 133 | 134 | 135 | def create_model( 136 | image_size, 137 | num_channels, 138 | num_res_blocks, 139 | channel_mult="", 140 | learn_sigma=False, 141 | class_cond=False, 142 | use_checkpoint=False, 143 | attention_resolutions="16", 144 | num_heads=1, 145 | num_head_channels=-1, 146 | num_heads_upsample=-1, 147 | use_scale_shift_norm=False, 148 | dropout=0, 149 | resblock_updown=False, 150 | use_fp16=False, 151 | use_new_attention_order=False, 152 | num_classes=1000, 153 | multiclass=False 154 | ): 155 | if channel_mult == "": 156 | if image_size == 512: 157 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 158 | elif image_size == 256: 159 | channel_mult = (1, 1, 2, 2, 4, 4) 160 | elif image_size == 224: 161 | channel_mult = (1, 1, 2, 3, 4) 162 | elif image_size == 128: 163 | channel_mult = (1, 1, 2, 3, 4) 164 | elif image_size == 64: 165 | channel_mult = (1, 2, 3, 4) 166 | elif image_size == 28: # for mnist 167 | channel_mult = (1, 1, 2) 168 | elif isinstance(image_size, list): 169 | channel_mult = (1, 1, 2, 3, 4) 170 | else: 171 | raise ValueError(f"unsupported image size: {image_size}") 172 | else: 173 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 174 | 175 | attention_ds = [] 176 | for res in attention_resolutions.split(","): 177 | attention_ds.append(image_size // int(res)) 178 | 179 | 
return UNetModel( 180 | image_size=image_size, 181 | in_channels=3, 182 | model_channels=num_channels, 183 | out_channels=(3 if not learn_sigma else 6), 184 | num_res_blocks=num_res_blocks, 185 | attention_resolutions=tuple(attention_ds), 186 | dropout=dropout, 187 | channel_mult=channel_mult, 188 | num_classes=(num_classes if class_cond else None), 189 | use_checkpoint=use_checkpoint, 190 | use_fp16=use_fp16, 191 | num_heads=num_heads, 192 | num_head_channels=num_head_channels, 193 | num_heads_upsample=num_heads_upsample, 194 | use_scale_shift_norm=use_scale_shift_norm, 195 | resblock_updown=resblock_updown, 196 | use_new_attention_order=use_new_attention_order, 197 | multiclass=multiclass 198 | ) 199 | 200 | 201 | def create_classifier_and_diffusion( 202 | image_size, 203 | classifier_use_fp16, 204 | classifier_width, 205 | classifier_depth, 206 | classifier_attention_resolutions, 207 | classifier_use_scale_shift_norm, 208 | classifier_resblock_updown, 209 | classifier_pool, 210 | out_channels, 211 | learn_sigma, 212 | diffusion_steps, 213 | noise_schedule, 214 | timestep_respacing, 215 | use_kl, 216 | predict_xstart, 217 | rescale_timesteps, 218 | rescale_learned_sigmas, 219 | ): 220 | classifier = create_classifier( 221 | image_size, 222 | classifier_use_fp16, 223 | classifier_width, 224 | classifier_depth, 225 | classifier_attention_resolutions, 226 | classifier_use_scale_shift_norm, 227 | classifier_resblock_updown, 228 | classifier_pool, 229 | out_channels, 230 | ) 231 | diffusion = create_gaussian_diffusion( 232 | steps=diffusion_steps, 233 | learn_sigma=learn_sigma, 234 | noise_schedule=noise_schedule, 235 | use_kl=use_kl, 236 | predict_xstart=predict_xstart, 237 | rescale_timesteps=rescale_timesteps, 238 | rescale_learned_sigmas=rescale_learned_sigmas, 239 | timestep_respacing=timestep_respacing, 240 | ) 241 | return classifier, diffusion 242 | 243 | 244 | def create_classifier( 245 | image_size, 246 | classifier_use_fp16, 247 | classifier_width, 248 | classifier_depth, 249 | classifier_attention_resolutions, 250 | classifier_use_scale_shift_norm, 251 | classifier_resblock_updown, 252 | classifier_pool, 253 | out_channels, 254 | ): 255 | if image_size == 512: 256 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 257 | elif image_size == 256: 258 | channel_mult = (1, 1, 2, 2, 4, 4) 259 | elif image_size == 128: 260 | channel_mult = (1, 1, 2, 3, 4) 261 | elif image_size == 64: 262 | channel_mult = (1, 2, 3, 4) 263 | elif image_size == 28: # for mnist 264 | channel_mult = (1, 1, 2) 265 | else: 266 | raise ValueError(f"unsupported image size: {image_size}") 267 | 268 | attention_ds = [] 269 | for res in classifier_attention_resolutions.split(","): 270 | attention_ds.append(image_size // int(res)) 271 | 272 | return EncoderUNetModel( 273 | image_size=image_size, 274 | in_channels=3, 275 | model_channels=classifier_width, 276 | out_channels=out_channels, 277 | num_res_blocks=classifier_depth, 278 | attention_resolutions=tuple(attention_ds), 279 | channel_mult=channel_mult, 280 | use_fp16=classifier_use_fp16, 281 | num_head_channels=64, 282 | use_scale_shift_norm=classifier_use_scale_shift_norm, 283 | resblock_updown=classifier_resblock_updown, 284 | pool=classifier_pool, 285 | ) 286 | 287 | 288 | def sr_model_and_diffusion_defaults(): 289 | res = model_and_diffusion_defaults() 290 | res["large_size"] = 256 291 | res["small_size"] = 64 292 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 293 | for k in res.copy().keys(): 294 | if k not in arg_names: 295 | del res[k] 296 | 
return res 297 | 298 | 299 | def sr_create_model_and_diffusion( 300 | large_size, 301 | small_size, 302 | class_cond, 303 | learn_sigma, 304 | num_channels, 305 | num_res_blocks, 306 | num_heads, 307 | num_head_channels, 308 | num_heads_upsample, 309 | attention_resolutions, 310 | dropout, 311 | diffusion_steps, 312 | noise_schedule, 313 | timestep_respacing, 314 | use_kl, 315 | predict_xstart, 316 | rescale_timesteps, 317 | rescale_learned_sigmas, 318 | use_checkpoint, 319 | use_scale_shift_norm, 320 | resblock_updown, 321 | use_fp16, 322 | ): 323 | model = sr_create_model( 324 | large_size, 325 | small_size, 326 | num_channels, 327 | num_res_blocks, 328 | learn_sigma=learn_sigma, 329 | class_cond=class_cond, 330 | use_checkpoint=use_checkpoint, 331 | attention_resolutions=attention_resolutions, 332 | num_heads=num_heads, 333 | num_head_channels=num_head_channels, 334 | num_heads_upsample=num_heads_upsample, 335 | use_scale_shift_norm=use_scale_shift_norm, 336 | dropout=dropout, 337 | resblock_updown=resblock_updown, 338 | use_fp16=use_fp16, 339 | ) 340 | diffusion = create_gaussian_diffusion( 341 | steps=diffusion_steps, 342 | learn_sigma=learn_sigma, 343 | noise_schedule=noise_schedule, 344 | use_kl=use_kl, 345 | predict_xstart=predict_xstart, 346 | rescale_timesteps=rescale_timesteps, 347 | rescale_learned_sigmas=rescale_learned_sigmas, 348 | timestep_respacing=timestep_respacing, 349 | ) 350 | return model, diffusion 351 | 352 | 353 | def sr_create_model( 354 | large_size, 355 | small_size, 356 | num_channels, 357 | num_res_blocks, 358 | learn_sigma, 359 | class_cond, 360 | use_checkpoint, 361 | attention_resolutions, 362 | num_heads, 363 | num_head_channels, 364 | num_heads_upsample, 365 | use_scale_shift_norm, 366 | dropout, 367 | resblock_updown, 368 | use_fp16, 369 | ): 370 | _ = small_size # hack to prevent unused variable 371 | 372 | if large_size == 512: 373 | channel_mult = (1, 1, 2, 2, 4, 4) 374 | elif large_size == 256: 375 | channel_mult = (1, 1, 2, 2, 4, 4) 376 | elif large_size == 64: 377 | channel_mult = (1, 2, 3, 4) 378 | else: 379 | raise ValueError(f"unsupported large size: {large_size}") 380 | 381 | attention_ds = [] 382 | for res in attention_resolutions.split(","): 383 | attention_ds.append(large_size // int(res)) 384 | 385 | return SuperResModel( 386 | image_size=large_size, 387 | in_channels=3, 388 | model_channels=num_channels, 389 | out_channels=(3 if not learn_sigma else 6), 390 | num_res_blocks=num_res_blocks, 391 | attention_resolutions=tuple(attention_ds), 392 | dropout=dropout, 393 | channel_mult=channel_mult, 394 | num_classes=(NUM_CLASSES if class_cond else None), 395 | use_checkpoint=use_checkpoint, 396 | num_heads=num_heads, 397 | num_head_channels=num_head_channels, 398 | num_heads_upsample=num_heads_upsample, 399 | use_scale_shift_norm=use_scale_shift_norm, 400 | resblock_updown=resblock_updown, 401 | use_fp16=use_fp16, 402 | ) 403 | 404 | 405 | def create_gaussian_diffusion( 406 | *, 407 | steps=1000, 408 | learn_sigma=False, 409 | sigma_small=False, 410 | noise_schedule="linear", 411 | use_kl=False, 412 | predict_xstart=False, 413 | rescale_timesteps=False, 414 | rescale_learned_sigmas=False, 415 | timestep_respacing="", 416 | ): 417 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 418 | if use_kl: 419 | loss_type = gd.LossType.RESCALED_KL 420 | elif rescale_learned_sigmas: 421 | loss_type = gd.LossType.RESCALED_MSE 422 | else: 423 | loss_type = gd.LossType.MSE 424 | if not timestep_respacing: 425 | timestep_respacing = [steps] 426 | 
return SpacedDiffusion( 427 | use_timesteps=space_timesteps(steps, timestep_respacing), 428 | betas=betas, 429 | model_mean_type=( 430 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 431 | ), 432 | model_var_type=( 433 | ( 434 | gd.ModelVarType.FIXED_LARGE 435 | if not sigma_small 436 | else gd.ModelVarType.FIXED_SMALL 437 | ) 438 | if not learn_sigma 439 | else gd.ModelVarType.LEARNED_RANGE 440 | ), 441 | loss_type=loss_type, 442 | rescale_timesteps=rescale_timesteps, 443 | ) 444 | 445 | 446 | def add_dict_to_argparser(parser, default_dict): 447 | for k, v in default_dict.items(): 448 | v_type = type(v) 449 | if v is None: 450 | v_type = str 451 | elif isinstance(v, bool): 452 | v_type = str2bool 453 | parser.add_argument(f"--{k}", default=v, type=v_type) 454 | 455 | 456 | def args_to_dict(args, keys): 457 | return {k: getattr(args, k) for k in keys} 458 | 459 | 460 | def str2bool(v): 461 | """ 462 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 463 | """ 464 | if isinstance(v, bool): 465 | return v 466 | if v.lower() in ("yes", "true", "t", "y", "1"): 467 | return True 468 | elif v.lower() in ("no", "false", "f", "n", "0"): 469 | return False 470 | else: 471 | raise argparse.ArgumentTypeError("boolean value expected") 472 | -------------------------------------------------------------------------------- /core/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # 
Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
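        # the event is written with the current step, and the counter is advanced
        # afterwards so successive dumpkvs() calls show up as consecutive points
        # in TensorBoard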
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 
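    # neither PMI_RANK nor OMPI_COMM_WORLD_RANK is set, so the process was not
    # launched by an MPI runner; treat it as rank 0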
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | print("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger 495 | 496 | -------------------------------------------------------------------------------- /core/sample_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import itertools 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from scipy import linalg 8 | from os import path as osp 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from torch.nn import functional as F 14 | from torchvision.datasets import ImageFolder 15 | from torchvision.models import vgg19 16 | 17 | from .gaussian_diffusion import _extract_into_tensor 18 | 19 | 20 | # ======================================================= 21 | # Functions 22 | # ======================================================= 23 | 24 | 25 | def 
load_from_DDP_model(state_dict): 26 | 27 | new_state_dict = {} 28 | for k, v in state_dict.items(): 29 | if k[:7] == 'module.': 30 | k = k[7:] 31 | new_state_dict[k] = v 32 | return new_state_dict 33 | 34 | 35 | # ======================================================= 36 | # Gradient Extraction Functions 37 | # ======================================================= 38 | 39 | 40 | @torch.enable_grad() 41 | def clean_class_cond_fn(x_t, y, classifier, 42 | s, use_logits): 43 | ''' 44 | Computes the classifier gradients for the guidance 45 | 46 | :param x_t: clean instance 47 | :param y: target 48 | :param classifier: classification model 49 | :param s: scaling classifier gradients parameter 50 | :param use_logits: compute the loss over the logits 51 | ''' 52 | 53 | x_in = x_t.detach().requires_grad_(True) 54 | logits = classifier(x_in) 55 | 56 | y = y.to(logits.device).float() 57 | # Select the target logits, 58 | # for those of target 1, we take the logits as they are (sigmoid(logits) = p(y=1 | x)) 59 | # for those of target 0, we take the negative of the logits (sigmoid(-logits) = p(y=0 | x)) 60 | selected = y * logits - (1 - y) * logits 61 | if use_logits: 62 | selected = -selected 63 | else: 64 | selected = -F.logsigmoid(selected) 65 | 66 | selected = selected * s 67 | grads = torch.autograd.grad(selected.sum(), x_in)[0] 68 | 69 | return grads 70 | 71 | 72 | 73 | @torch.enable_grad() 74 | def clean_multiclass_cond_fn(x_t, y, classifier, 75 | s, use_logits): 76 | 77 | x_in = x_t.detach().requires_grad_(True) 78 | selected = classifier(x_in) 79 | 80 | # Select the target logits 81 | if not use_logits: 82 | selected = F.log_softmax(selected, dim=1) 83 | selected = -selected[range(len(y)), y] 84 | selected = selected * s 85 | grads = torch.autograd.grad(selected.sum(), x_in)[0] 86 | 87 | return grads 88 | 89 | 90 | @torch.enable_grad() 91 | def dist_cond_fn(x_tau, z_t, x_t, alpha_t, 92 | l1_loss, l2_loss, 93 | l_perc): 94 | 95 | ''' 96 | Computes the distance loss between x_t, z_t and x_tau 97 | :x_tau: initial image 98 | :z_t: current noisy instance 99 | :x_t: current clean instance 100 | :alpha_t: time dependant constant 101 | ''' 102 | 103 | z_in = z_t.detach().requires_grad_(True) 104 | x_in = x_t.detach().requires_grad_(True) 105 | 106 | m1 = l1_loss * torch.norm(z_in - x_tau, p=1, dim=1).sum() if l1_loss != 0 else 0 107 | m2 = l2_loss * torch.norm(z_in - x_tau, p=2, dim=1).sum() if l2_loss != 0 else 0 108 | mv = l_perc(x_in, x_tau) if l_perc is not None else 0 109 | 110 | if isinstance(m1 + m2 + mv, int): 111 | return 0 112 | 113 | if isinstance(m1 + m2, int): 114 | grads = 0 115 | else: 116 | grads = torch.autograd.grad(m1 + m2, z_in)[0] 117 | 118 | if isinstance(mv, int): 119 | return grads 120 | else: 121 | return grads + torch.autograd.grad(mv, x_in)[0] / alpha_t 122 | 123 | 124 | # ======================================================= 125 | # Sampling Function 126 | # ======================================================= 127 | 128 | 129 | def get_DiME_iterative_sampling(use_sampling=False): 130 | ''' 131 | Returns DiME's main algorithm to construct counterfactuals. 132 | The returned function computes x_t in a recursive way. 133 | Easy way to set the optional parameters into the sampling 134 | function such as the use_sampling flag. 
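    A rough usage sketch (it mirrors the call in main.py; the concrete values are
    illustrative only):

        sample_fn = get_DiME_iterative_sampling(use_sampling=True)
        cf, x_t_steps, z_t_steps = sample_fn(
            diffusion, model_fn, img.shape, num_timesteps, img, t,
            z_t=noise_img,
            model_kwargs={'y': target},
            class_grad_fn=clean_class_cond_fn,
            class_grad_kwargs={'y': target, 'classifier': classifier,
                               's': 8.0, 'use_logits': False},
            dist_grad_fn=dist_cond_fn,
            dist_grad_kargs={'l1_loss': 0.0, 'l2_loss': 0.0, 'l_perc': None},
            device=device)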
135 | 136 | :param use_sampling: use mu + sigma * N(0,1) when computing 137 | the next iteration when estimating x_t 138 | ''' 139 | @torch.no_grad() 140 | def p_sample_loop(diffusion, 141 | model, 142 | shape, 143 | num_timesteps, 144 | img, 145 | t, 146 | z_t=None, 147 | clip_denoised=True, 148 | model_kwargs=None, 149 | device=None, 150 | class_grad_fn=None, 151 | class_grad_kwargs=None, 152 | dist_grad_fn=None, 153 | dist_grad_kargs=None, 154 | x_t_sampling=True, 155 | is_x_t_sampling=False, 156 | guided_iterations=9999999): 157 | 158 | ''' 159 | :param : 160 | :param diffusion: diffusion algorithm 161 | :param model: DDPM model 162 | :param num_timesteps: tau, or the depth of the noise chain 163 | :param img: instance to be explained 164 | :param t: time variable 165 | :param z_t: noisy instance. If z_t is instantiated then the model 166 | will denoise z_t 167 | :param clip_denoised: clip the noised data to [-1, 1] 168 | :param model_kwargs: useful when the model is conditioned 169 | :param device: torch device 170 | :param class_grad_fn: class function to compute the gradients of the classifier 171 | has at least an input, x_t. 172 | :param class_grad_kwargs: Additional arguments for class_grad_fn 173 | :param dist_grad_fn: Similar as class_grad_fn, uses z_t, x_t, x_tau, and alpha_t as inputs 174 | :param dist_grad_kwargs: Additional args fot dist_grad_fn 175 | :param x_t_sampling: use sampling when computing x_t 176 | :param is_x_t_sampling: useful flag to distinguish when x_t is been generated 177 | :param guided_iterations: Early stop the guided iterations 178 | ''' 179 | 180 | x_t = img.clone() 181 | z_t = diffusion.q_sample(img, t) if z_t is None else z_t 182 | 183 | x_t_steps = [] 184 | z_t_steps = [] 185 | indices = list(range(num_timesteps))[::-1] 186 | 187 | for jdx, i in enumerate(indices): 188 | 189 | t = torch.tensor([i] * shape[0], device=device) 190 | x_t_steps.append(x_t.detach()) 191 | z_t_steps.append(z_t.detach()) 192 | 193 | # out is a dictionary with the following (self-explanatory) keys: 194 | # 'mean', 'variance', 'log_variance' 195 | out = diffusion.p_mean_variance( 196 | model, 197 | z_t, 198 | t, 199 | clip_denoised=clip_denoised, 200 | denoised_fn=None, 201 | model_kwargs=model_kwargs, 202 | ) 203 | 204 | # extract sqrtalphacum 205 | alpha_t = _extract_into_tensor(diffusion.sqrt_alphas_cumprod, 206 | t, shape) 207 | 208 | nonzero_mask = ( 209 | (t != 0).float().view(-1, *([1] * (len(shape) - 1))) 210 | ) # no noise when t == 0 211 | 212 | grads = 0 213 | 214 | if (class_grad_fn is not None) and (guided_iterations > jdx): 215 | grads = grads + class_grad_fn(x_t=x_t, 216 | **class_grad_kwargs) / alpha_t 217 | 218 | if (dist_grad_fn is not None) and (guided_iterations > jdx): 219 | grads = grads + dist_grad_fn(z_t=z_t, 220 | x_tau=img, 221 | x_t=x_t, 222 | alpha_t=alpha_t, 223 | **dist_grad_kargs) 224 | 225 | out["mean"] = ( 226 | out["mean"].float() - 227 | out["variance"] * grads 228 | ) 229 | 230 | if not x_t_sampling: 231 | z_t = out["mean"] 232 | 233 | else: 234 | z_t = ( 235 | out["mean"] + 236 | nonzero_mask * torch.exp(0.5 * out["log_variance"]) * torch.randn_like(img) 237 | ) 238 | 239 | # produce x_t in a brute force manner 240 | if (num_timesteps - (jdx + 1) > 0) and (class_grad_fn is not None) and (dist_grad_fn is not None) and (guided_iterations > jdx): 241 | x_t = p_sample_loop( 242 | diffusion=diffusion, 243 | model=model, 244 | model_kwargs=model_kwargs, 245 | shape=shape, 246 | num_timesteps=num_timesteps - (jdx + 1), 247 | img=img, 248 | 
t=None, 249 | z_t=z_t, 250 | clip_denoised=True, 251 | device=device, 252 | x_t_sampling=use_sampling, 253 | is_x_t_sampling=True, 254 | )[0] 255 | 256 | return z_t, x_t_steps, z_t_steps 257 | 258 | return p_sample_loop 259 | 260 | 261 | # ======================================================= 262 | # Classes 263 | # ======================================================= 264 | 265 | 266 | class ChunkedDataset: 267 | def __init__(self, dataset, chunk=0, num_chunks=1): 268 | self.dataset = dataset 269 | self.indexes = [i for i in range(len(dataset)) if (i % num_chunks) == chunk] 270 | 271 | def __len__(self): 272 | return len(self.indexes) 273 | 274 | def __getitem__(self, idx): 275 | i = [self.indexes[idx]] 276 | i += list(self.dataset[i[0]]) 277 | return i 278 | 279 | 280 | class ImageSaver(): 281 | def __init__(self, output_path, exp_name, extention='.jpg'): 282 | self.output_path = output_path 283 | self.exp_name = exp_name 284 | self.idx = 0 285 | self.extention = extention 286 | self.construct_directory() 287 | 288 | def construct_directory(self): 289 | 290 | os.makedirs(osp.join(self.output_path, 'Original', 'Correct'), exist_ok=True) 291 | os.makedirs(osp.join(self.output_path, 'Original', 'Incorrect'), exist_ok=True) 292 | 293 | for clst, cf, subf in itertools.product(['CC', 'IC'], 294 | ['CCF', 'ICF'], 295 | ['CF', 'Noise', 'Info', 'SM']): 296 | os.makedirs(osp.join(self.output_path, 'Results', 297 | self.exp_name, clst, 298 | cf, subf), 299 | exist_ok=True) 300 | 301 | def __call__(self, imgs, cfs, noises, target, label, 302 | pred, pred_cf, bkl, l_1, indexes=None, masks=None): 303 | 304 | for idx in range(len(imgs)): 305 | current_idx = indexes[idx].item() if indexes is not None else idx + self.idx 306 | mask = None if masks is None else masks[idx] 307 | self.save_img(img=imgs[idx], 308 | cf=cfs[idx], 309 | noise=noises[idx], 310 | idx=current_idx, 311 | target=target[idx].item(), 312 | label=label[idx].item(), 313 | pred=pred[idx].item(), 314 | pred_cf=pred_cf[idx].item(), 315 | bkl=bkl[idx].item(), 316 | l_1=l_1[idx].item(), 317 | mask=mask) 318 | 319 | self.idx += len(imgs) 320 | 321 | @staticmethod 322 | def select_folder(label, target, pred, pred_cf): 323 | folder = osp.join('CC' if label == pred else 'IC', 324 | 'CCF' if target == pred_cf else 'ICF') 325 | return folder 326 | 327 | @staticmethod 328 | def preprocess(img): 329 | ''' 330 | remove last dimension if it is 1 331 | ''' 332 | if img.shape[2] > 1: 333 | return img 334 | else: 335 | return np.squeeze(img, 2) 336 | 337 | def save_img(self, img, cf, noise, idx, target, label, 338 | pred, pred_cf, bkl, l_1, mask): 339 | folder = self.select_folder(label, target, pred, pred_cf) 340 | output_path = osp.join(self.output_path, 'Results', 341 | self.exp_name, folder) 342 | img_name = f'{idx}'.zfill(7) 343 | orig_path = osp.join(self.output_path, 'Original', 344 | 'Correct' if label == pred else 'Incorrect', 345 | img_name + self.extention) 346 | 347 | if mask is None: 348 | l0 = np.abs(img.astype('float') - cf.astype('float')) 349 | l0 = l0.sum(2, keepdims=True) 350 | l0 = 255 * l0 / l0.max() 351 | l0 = np.concatenate([l0] * img.shape[2], axis=2).astype('uint8') 352 | l0 = Image.fromarray(self.preprocess(l0)) 353 | l0.save(osp.join(output_path, 'SM', img_name + self.extention)) 354 | else: 355 | mask = mask.astype('uint8') * 255 356 | mask = Image.fromarray(mask) 357 | mask.save(osp.join(output_path, 'SM', img_name + self.extention)) 358 | 359 | img = Image.fromarray(self.preprocess(img)) 360 | img.save(orig_path) 361 | 
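        # the counterfactual, the noised input and the info file written below share
        # this zero-padded index, inside the CC/IC x CCF/ICF folder selected above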
362 | cf = Image.fromarray(self.preprocess(cf)) 363 | cf.save(osp.join(output_path, 'CF', img_name + self.extention)) 364 | 365 | noise = Image.fromarray(self.preprocess(noise)) 366 | noise.save(osp.join(output_path, 'Noise', img_name + self.extention)) 367 | 368 | 369 | to_write = (f'label: {label}' + 370 | f'\npred: {pred}' + 371 | f'\ntarget: {target}' + 372 | f'\ncf pred: {pred_cf}' + 373 | f'\nBKL: {bkl}' + 374 | f'\nl_1: {l_1}') 375 | with open(osp.join(output_path, 'Info', img_name + '.txt'), 'w') as f: 376 | f.write(to_write) 377 | 378 | 379 | class Normalizer(nn.Module): 380 | def __init__(self, classifier): 381 | super().__init__() 382 | self.classifier = classifier 383 | self.register_buffer('mu', torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)) 384 | self.register_buffer('sigma', torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)) 385 | 386 | def forward(self, x): 387 | x = (torch.clamp(x, -1, 1) + 1) / 2 388 | x = (x - self.mu) / self.sigma 389 | return self.classifier(x) 390 | 391 | 392 | class SingleLabel(ImageFolder): 393 | def __init__(self, query_label, **kwargs): 394 | super().__init__(**kwargs) 395 | self.query_label = query_label 396 | 397 | # remove those instances that do no have the 398 | # query label 399 | 400 | old_len = len(self) 401 | instances = [self.targets[i] == query_label 402 | for i in range(old_len)] 403 | self.samples = [self.samples[i] 404 | for i in range(old_len) if instances[i]] 405 | self.targets = [self.targets[i] 406 | for i in range(old_len) if instances[i]] 407 | self.imgs = [self.imgs[i] 408 | for i in range(old_len) if instances[i]] 409 | 410 | 411 | class SlowSingleLabel(): 412 | def __init__(self, query_label, dataset, maxlen=float('inf')): 413 | self.dataset = dataset 414 | self.indexes = [] 415 | if isinstance(dataset, ImageFolder): 416 | self.indexes = np.where(np.array(dataset.targets) == query_label)[0] 417 | self.indexes = self.indexes[:maxlen] 418 | else: 419 | print('Slow route. This may take some time!') 420 | if query_label != -1: 421 | for idx, (_, l) in enumerate(tqdm(dataset)): 422 | 423 | l = l['y'] if isinstance(l, dict) else l 424 | if l == query_label: 425 | self.indexes.append(idx) 426 | 427 | if len(self.indexes) == maxlen: 428 | break 429 | else: 430 | self.indexes = list(range(min(maxlen, len(dataset)))) 431 | 432 | def __len__(self): 433 | return len(self.indexes) 434 | 435 | def __getitem__(self, idx): 436 | return self.dataset[self.indexes[idx]] 437 | 438 | 439 | class PerceptualLoss(nn.Module): 440 | def __init__(self, layer, c): 441 | super().__init__() 442 | self.c = c 443 | vgg19_model = vgg19(pretrained=True) 444 | vgg19_model = nn.Sequential(*list(vgg19_model.features.children())[:layer]) 445 | self.model = Normalizer(vgg19_model) 446 | self.model.eval() 447 | 448 | def forward(self, x0, x1): 449 | B = x0.size(0) 450 | 451 | l = F.mse_loss(self.model(x0).view(B, -1), self.model(x1).view(B, -1), 452 | reduction='none').mean(dim=1) 453 | return self.c * l.sum() 454 | 455 | 456 | class extra_data_saver(): 457 | def __init__(self, output_path, exp_name): 458 | self.idx = 0 459 | self.exp_name = exp_name 460 | 461 | def __call__(self, x_ts, indexes=None): 462 | n_images = x_ts[0].size(0) 463 | n_steps = len(x_ts) 464 | 465 | for i in range(n_images): 466 | current_idx = indexes[i].item() if indexes is not None else i + self.idx 467 | os.makedirs(osp.join(self.output_path, self.exp_name, str(current_idx).zfill(6)), exist_ok=True) 468 | 469 | for j in range(n_steps): 470 | cf = x_ts[j][i, ...] 
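                # x_ts[j] holds the whole batch at denoising step j; sample i of
                # that step is written out below as frame j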
471 | 
472 |                 # renormalize the image
473 |                 cf = ((cf + 1) * 127.5).clamp(0, 255).to(torch.uint8)
474 |                 cf = cf.permute(1, 2, 0)
475 |                 cf = cf.contiguous().cpu().numpy()
476 |                 cf = Image.fromarray(cf)
477 |                 cf.save(osp.join(self.output_path, self.exp_name, str(current_idx).zfill(6), str(j).zfill(4) + '.jpg'))
478 | 
479 |         self.idx += n_images
480 | 
481 | 
482 | class X_T_Saver(extra_data_saver):
483 |     def __init__(self, output_path, exp_path, extention='.jpg'):
484 |         super().__init__(output_path, exp_path)
485 |         self.output_path = osp.join(output_path, 'x_t')
486 | 
487 | 
488 | class Z_T_Saver(extra_data_saver):
489 |     def __init__(self, output_path, exp_path, extention='.jpg'):
490 |         super().__init__(output_path, exp_path)
491 |         self.output_path = osp.join(output_path, 'z_t')
492 | 
493 | 
494 | class Mask_Saver(extra_data_saver):
495 |     def __init__(self, output_path, exp_path, extention='.jpg'):
496 |         super().__init__(output_path, exp_path)
497 |         self.output_path = osp.join(output_path, 'masks')
498 |         self.extention = extention  # file extension used by __call__ when saving the masks
499 | 
500 |     def __call__(self, masks, indexes=None):
501 |         '''
502 |         Masks are non-binarized
503 |         '''
504 |         n_images = masks[0].size(0)
505 |         n_steps = len(masks)
506 | 
507 |         for i in range(n_images):
508 |             current_idx = indexes[i].item() if indexes is not None else i + self.idx
509 |             os.makedirs(osp.join(self.output_path, self.exp_name, str(current_idx).zfill(6)), exist_ok=True)
510 | 
511 |             for j in range(n_steps):
512 |                 cf = masks[j][i, ...]
513 |                 cf = torch.cat((cf, (cf > 0.5).to(cf.dtype)), dim=-1)
514 | 
515 |                 # renormalize the image
516 |                 cf = (cf * 255).clamp(0, 255).to(torch.uint8)
517 |                 cf = cf.permute(1, 2, 0)
518 |                 cf = cf.squeeze(dim=-1)
519 |                 cf = cf.contiguous().cpu().numpy()
520 |                 cf = Image.fromarray(cf)
521 |                 cf.save(osp.join(self.output_path, self.exp_name, str(current_idx).zfill(6), str(j).zfill(4) + self.extention))
522 | 
523 |         self.idx += n_images
524 | 
-------------------------------------------------------------------------------- /main.py: --------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import math
4 | import random
5 | import argparse
6 | import itertools
7 | import numpy as np
8 | import os.path as osp
9 | import matplotlib.pyplot as plt
10 | 
11 | from PIL import Image
12 | from time import time
13 | from os import path as osp
14 | from multiprocessing import Pool
15 | 
16 | import torch
17 | from torch.utils import data
18 | 
19 | from torchvision import transforms
20 | from torchvision import datasets
21 | 
22 | from core import dist_util
23 | from core.script_util import (
24 |     model_and_diffusion_defaults,
25 |     create_model_and_diffusion,
26 |     create_classifier,
27 |     args_to_dict,
28 |     add_dict_to_argparser,
29 | )
30 | from core.sample_utils import (
31 |     get_DiME_iterative_sampling,
32 |     clean_class_cond_fn,
33 |     dist_cond_fn,
34 |     ImageSaver,
35 |     SlowSingleLabel,
36 |     Normalizer,
37 |     load_from_DDP_model,
38 |     PerceptualLoss,
39 |     X_T_Saver,
40 |     Z_T_Saver,
41 |     ChunkedDataset,
42 | )
43 | from core.image_datasets import CelebADataset, CelebAMiniVal
44 | from core.gaussian_diffusion import _extract_into_tensor
45 | from core.classifier.densenet import ClassificationModel
46 | 
47 | import matplotlib
48 | matplotlib.use('Agg') # to disable display
49 | 
50 | # =======================================================
51 | # =======================================================
52 | # Functions
53 | # =======================================================
54 | #
======================================================= 55 | 56 | 57 | def create_args(): 58 | defaults = dict( 59 | clip_denoised=True, 60 | batch_size=16, 61 | gpu='0', 62 | num_batches=50, 63 | use_train=False, 64 | dataset='CelebA', 65 | 66 | # path args 67 | output_path='', 68 | classifier_path='models/classifier.pth', 69 | oracle_path='models/oracle.pth', 70 | model_path="models/ddpm-celeba.pt", 71 | data_dir="", 72 | exp_name='', 73 | 74 | # sampling args 75 | classifier_scales='8,10,15', 76 | seed=4, 77 | query_label=-1, 78 | target_label=-1, 79 | use_ddim=False, 80 | start_step=60, 81 | use_logits=False, 82 | l1_loss=0.0, 83 | l2_loss=0.0, 84 | l_perc=0.0, 85 | l_perc_layer=1, 86 | use_sampling_on_x_t=True, 87 | sampling_scale=1., # use this flag to rescale the variance of the noise 88 | guided_iterations=9999999, # set a high number to do all iteration in a guided way 89 | 90 | # evaluation args 91 | merge_and_eval=False, # when all chunks have finished, run it with this flag 92 | 93 | # misc args 94 | num_chunks=1, 95 | chunk=0, 96 | save_x_t=False, 97 | save_z_t=False, 98 | save_images=True, 99 | ) 100 | defaults.update(model_and_diffusion_defaults()) 101 | parser = argparse.ArgumentParser() 102 | add_dict_to_argparser(parser, defaults) 103 | return parser.parse_args() 104 | 105 | 106 | # ======================================================= 107 | # ======================================================= 108 | # Merge all chunks' information and compute the 109 | # overall metrics 110 | # ======================================================= 111 | # ======================================================= 112 | 113 | 114 | def mean(array): 115 | m = np.mean(array).item() 116 | return 0 if math.isnan(m) else m 117 | 118 | 119 | def merge_and_compute_overall_metrics(args, device): 120 | 121 | def div(q, p): 122 | if p == 0: 123 | return 0 124 | return q / p 125 | 126 | print('Merging all results ...') 127 | 128 | # read all yaml files containing the info to add them together 129 | summary = { 130 | 'class-cor': {'cf-cor': {'bkl': 0, 'l_1': 0, 'n': 0, 'FVA': 0, 'MNAC': 0}, 131 | 'cf-inc': {'bkl': 0, 'l_1': 0, 'n': 0, 'FVA': 0, 'MNAC': 0}, 132 | 'bkl': 0, 'l_1': 0, 'n': 0, 'FVA': 0, 'MNAC': 0}, 133 | 'class-inc': {'cf-cor': {'bkl': 0, 'l_1': 0, 'n': 0, 'FVA': 0, 'MNAC': 0}, 134 | 'cf-inc': {'bkl': 0, 'l_1': 0, 'n': 0, 'FVA': 0, 'MNAC': 0}, 135 | 'bkl': 0, 'l_1': 0, 'n': 0, 'FVA': 0, 'MNAC': 0}, 136 | 'cf-cor': {'bkl': 0, 'l_1': 0, 'n': 0, 'FVA': 0, 'MNAC': 0}, 137 | 'cf-inc': {'bkl': 0, 'l_1': 0, 'n': 0, 'FVA': 0, 'MNAC': 0}, 138 | 'clean acc': 0, 139 | 'cf acc': 0, 140 | 'bkl': 0, 'l_1': 0, 'n': 0, 'FVA': 0, 'MNAC': 0, 141 | } 142 | 143 | for chunk in range(args.num_chunks): 144 | yaml_path = osp.join(args.output_path, 'Results', args.exp_name, 145 | f'chunk-{chunk}_num-chunks-{args.num_chunks}_summary.yaml') 146 | 147 | with open(yaml_path, 'r') as f: 148 | chunk_summary = yaml.load(f, Loader=yaml.FullLoader) 149 | 150 | summary['clean acc'] += chunk_summary['clean acc'] * chunk_summary['n'] 151 | summary['cf acc'] += chunk_summary['cf acc'] * chunk_summary['n'] 152 | 153 | summary['n'] += chunk_summary['n'] 154 | 155 | summary['class-cor']['n'] += chunk_summary['class-cor']['n'] 156 | summary['class-inc']['n'] += chunk_summary['class-inc']['n'] 157 | 158 | summary['cf-cor']['n'] += chunk_summary['cf-cor']['n'] 159 | summary['cf-inc']['n'] += chunk_summary['cf-inc']['n'] 160 | 161 | summary['class-cor']['cf-cor']['n'] += chunk_summary['class-cor']['cf-cor']['n'] 162 | 
summary['class-cor']['cf-inc']['n'] += chunk_summary['class-cor']['cf-inc']['n'] 163 | summary['class-inc']['cf-cor']['n'] += chunk_summary['class-inc']['cf-cor']['n'] 164 | summary['class-inc']['cf-inc']['n'] += chunk_summary['class-inc']['cf-inc']['n'] 165 | 166 | 167 | for k in ['bkl', 'l_1', 'FVA', 'MNAC']: 168 | summary[k] += chunk_summary[k] * chunk_summary['n'] 169 | 170 | summary['class-cor'][k] += chunk_summary['class-cor'][k] * chunk_summary['class-cor']['n'] 171 | summary['class-inc'][k] += chunk_summary['class-inc'][k] * chunk_summary['class-inc']['n'] 172 | 173 | summary['cf-cor'][k] += chunk_summary['cf-cor'][k] * chunk_summary['cf-cor']['n'] 174 | summary['cf-inc'][k] += chunk_summary['cf-inc'][k] * chunk_summary['cf-inc']['n'] 175 | 176 | summary['class-cor']['cf-cor'][k] += chunk_summary['class-cor']['cf-cor'][k] * chunk_summary['class-cor']['cf-cor']['n'] 177 | summary['class-cor']['cf-inc'][k] += chunk_summary['class-cor']['cf-inc'][k] * chunk_summary['class-cor']['cf-inc']['n'] 178 | summary['class-inc']['cf-cor'][k] += chunk_summary['class-inc']['cf-cor'][k] * chunk_summary['class-inc']['cf-cor']['n'] 179 | summary['class-inc']['cf-inc'][k] += chunk_summary['class-inc']['cf-inc'][k] * chunk_summary['class-inc']['cf-inc']['n'] 180 | 181 | for k in ['cf acc', 'clean acc']: 182 | summary[k] = div(summary[k], summary['n']) 183 | 184 | for k in ['bkl', 'l_1', 'FVA', 'MNAC']: 185 | summary[k] = div(summary[k], summary['n']) 186 | 187 | summary['class-cor'][k] = div(summary['class-cor'][k], summary['class-cor']['n']) 188 | summary['class-inc'][k] = div(summary['class-inc'][k], summary['class-inc']['n']) 189 | 190 | summary['cf-cor'][k] = div(summary['cf-cor'][k], summary['cf-cor']['n']) 191 | summary['cf-inc'][k] = div(summary['cf-inc'][k], summary['cf-inc']['n']) 192 | 193 | summary['class-cor']['cf-cor'][k] = div(summary['class-cor']['cf-cor'][k], summary['class-cor']['cf-cor']['n']) 194 | summary['class-cor']['cf-inc'][k] = div(summary['class-cor']['cf-inc'][k], summary['class-cor']['cf-inc']['n']) 195 | summary['class-inc']['cf-cor'][k] = div(summary['class-inc']['cf-cor'][k], summary['class-inc']['cf-cor']['n']) 196 | summary['class-inc']['cf-inc'][k] = div(summary['class-inc']['cf-inc'][k], summary['class-inc']['cf-inc']['n']) 197 | 198 | # summary is ready to save 199 | print('done') 200 | print('Acc on the set:', summary['clean acc']) 201 | print('CF Acc on the set:', summary['cf acc']) 202 | 203 | with open(osp.join(args.output_path, 'Results', args.exp_name, 'summary.yaml'), 'w') as f: 204 | yaml.dump(summary, f) 205 | 206 | 207 | # ======================================================= 208 | # ======================================================= 209 | # Main 210 | # ======================================================= 211 | # ======================================================= 212 | 213 | 214 | def main(): 215 | 216 | args = create_args() 217 | print(args) 218 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 219 | os.makedirs(osp.join(args.output_path, 'Results', args.exp_name), 220 | exist_ok=True) 221 | 222 | # ======================================== 223 | # Evaluate all feature in case of 224 | if args.merge_and_eval: 225 | merge_and_compute_overall_metrics(args, dist_util.dev()) 226 | return # finish the script 227 | 228 | # ======================================== 229 | # Set seeds 230 | 231 | torch.manual_seed(args.seed) 232 | random.seed(args.seed) 233 | np.random.seed(args.seed) 234 | 235 | # ======================================== 236 | # Load 
Dataset 237 | 238 | if args.dataset == 'CelebA': 239 | dataset = CelebADataset(image_size=args.image_size, 240 | data_dir=args.data_dir, 241 | partition='train' if args.use_train else 'val', 242 | random_crop=False, 243 | random_flip=False, 244 | query_label=args.query_label) 245 | 246 | elif args.dataset == 'CelebAMV': 247 | dataset = CelebAMiniVal(image_size=args.image_size, 248 | data_dir=args.data_dir, 249 | random_crop=False, 250 | random_flip=False, 251 | query_label=args.query_label) 252 | 253 | 254 | if len(dataset) - args.batch_size * args.num_batches > 0: 255 | dataset = SlowSingleLabel(query_label=1 - args.target_label if args.target_label != -1 else -1, 256 | dataset=dataset, 257 | maxlen=args.batch_size * args.num_batches) 258 | 259 | # breaks the dataset into chunks 260 | dataset = ChunkedDataset(dataset=dataset, 261 | chunk=args.chunk, 262 | num_chunks=args.num_chunks) 263 | 264 | print('Images on the dataset:', len(dataset)) 265 | 266 | loader = data.DataLoader(dataset, batch_size=args.batch_size, 267 | shuffle=False, 268 | num_workers=4, pin_memory=True) 269 | 270 | # ======================================== 271 | # load models 272 | 273 | print('Loading Model and diffusion model') 274 | model, diffusion = create_model_and_diffusion( 275 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 276 | ) 277 | model.load_state_dict( 278 | dist_util.load_state_dict(args.model_path, map_location="cpu") 279 | ) 280 | model.to(dist_util.dev()) 281 | if args.use_fp16: 282 | model.convert_to_fp16() 283 | model.eval() 284 | 285 | def model_fn(x, t, y=None): 286 | assert y is not None 287 | return model(x, t, y if args.class_cond else None) 288 | 289 | print('Loading Classifier') 290 | 291 | classifier = ClassificationModel(args.classifier_path, args.query_label).to(dist_util.dev()) 292 | classifier.eval() 293 | 294 | # ======================================== 295 | # Distance losses 296 | 297 | if args.l_perc != 0: 298 | print('Loading Perceptual Loss') 299 | vggloss = PerceptualLoss(layer=args.l_perc_layer, 300 | c=args.l_perc).to(dist_util.dev()) 301 | vggloss.eval() 302 | else: 303 | vggloss = None 304 | 305 | # ======================================== 306 | # get custom function for the forward phase 307 | # and other variables of interest 308 | 309 | sample_fn = get_DiME_iterative_sampling(use_sampling=args.use_sampling_on_x_t) 310 | 311 | x_t_saver = X_T_Saver(args.output_path, args.exp_name) if args.save_x_t else None 312 | z_t_saver = Z_T_Saver(args.output_path, args.exp_name) if args.save_z_t else None 313 | save_imgs = ImageSaver(args.output_path, args.exp_name, extention='.jpg') if args.save_images else None 314 | 315 | current_idx = 0 316 | start_time = time() 317 | 318 | stats = { 319 | 'n': 0, 320 | 'flipped': 0, 321 | 'bkl': [], 322 | 'l_1': [], 323 | 'pred': [], 324 | 'cf pred': [], 325 | 'target': [], 326 | 'label': [], 327 | } 328 | 329 | acc = 0 330 | n = 0 331 | classifier_scales = [float(x) for x in args.classifier_scales.split(',')] 332 | 333 | print('Starting Image Generation') 334 | for idx, (indexes, img, lab) in enumerate(loader): 335 | print(f'[Chunk {args.chunk + 1} / {args.num_chunks}] {idx} / {min(args.num_batches, len(loader))} | Time: {int(time() - start_time)}s') 336 | 337 | img = img.to(dist_util.dev()) 338 | I = (img / 2) + 0.5 339 | lab = lab.to(dist_util.dev(), dtype=torch.long) 340 | t = torch.zeros(img.size(0), device=dist_util.dev(), 341 | dtype=torch.long) 342 | 343 | # Initial Classification, no noise included 344 | with 
345 |             logits = classifier(img)
346 |             pred = (logits > 0).long()
347 | 
348 |         acc += (pred == lab).sum().item()
349 |         n += lab.size(0)
350 | 
351 |         # as the model is binary, the target will always be the inverse of the prediction
352 |         target = 1 - pred
353 | 
354 |         t = torch.ones_like(t) * args.start_step
355 | 
356 |         # add noise to the input image
357 |         noise_img = diffusion.q_sample(img, t)
358 | 
359 |         transformed = torch.zeros_like(lab).bool()
360 | 
361 |         for jdx, classifier_scale in enumerate(classifier_scales):
362 | 
363 |             # set the target label for the samples that have not flipped yet
364 |             model_kwargs = {}
365 |             model_kwargs['y'] = target[~transformed]
366 | 
367 |             # sample the counterfactual starting from the noised image
368 |             cfs, xs_t_s, zs_t_s = sample_fn(
369 |                 diffusion,
370 |                 model_fn,
371 |                 img[~transformed, ...].shape,
372 |                 args.start_step,
373 |                 img[~transformed, ...],
374 |                 t,
375 |                 z_t=noise_img[~transformed, ...],
376 |                 clip_denoised=args.clip_denoised,
377 |                 model_kwargs=model_kwargs,
378 |                 device=dist_util.dev(),
379 |                 class_grad_fn=clean_class_cond_fn,
380 |                 class_grad_kwargs={'y': target[~transformed],
381 |                                    'classifier': classifier,
382 |                                    's': classifier_scale,
383 |                                    'use_logits': args.use_logits},
384 |                 dist_grad_fn=dist_cond_fn,
385 |                 dist_grad_kargs={'l1_loss': args.l1_loss,
386 |                                  'l2_loss': args.l2_loss,
387 |                                  'l_perc': vggloss},
388 |                 guided_iterations=args.guided_iterations,
389 |                 is_x_t_sampling=False
390 |             )
391 | 
392 |             # evaluate the cf and check whether the model flipped the prediction
393 |             with torch.no_grad():
394 |                 cfsl = classifier(cfs)
395 |                 cfsp = cfsl > 0
396 | 
397 |             if jdx == 0:
398 |                 cf = cfs.clone().detach()
399 |                 x_t_s = [xp.clone().detach() for xp in xs_t_s]
400 |                 z_t_s = [zp.clone().detach() for zp in zs_t_s]
401 | 
402 |             cf[~transformed] = cfs
403 |             for kdx in range(len(x_t_s)):
404 |                 x_t_s[kdx][~transformed] = xs_t_s[kdx]
405 |                 z_t_s[kdx][~transformed] = zs_t_s[kdx]
406 |             transformed[~transformed] = target[~transformed] == cfsp
407 | 
408 |             if transformed.float().sum().item() == transformed.size(0):
409 |                 break
410 | 
411 |         if args.save_x_t:
412 |             x_t_saver(x_t_s, indexes=indexes)
413 | 
414 |         if args.save_z_t:
415 |             z_t_saver(z_t_s, indexes=indexes)
416 | 
417 |         with torch.no_grad():
418 |             logits_cf = classifier(cf)
419 |             pred_cf = (logits_cf > 0).long()
420 | 
421 |         # process images
422 |         cf = ((cf + 1) * 127.5).clamp(0, 255).to(torch.uint8)
423 |         cf = cf.permute(0, 2, 3, 1)
424 |         cf = cf.contiguous().cpu()
425 | 
426 |         I = (I * 255).to(torch.uint8)
427 |         I = I.permute(0, 2, 3, 1)
428 |         I = I.contiguous().cpu()
429 | 
430 |         noise_img = ((noise_img + 1) * 127.5).clamp(0, 255).to(torch.uint8)
431 |         noise_img = noise_img.permute(0, 2, 3, 1)
432 |         noise_img = noise_img.contiguous().cpu()
433 | 
434 |         # add metrics: bkl is one minus the classifier's probability of the target class
435 |         dist_cf = torch.sigmoid(logits_cf)
436 |         dist_cf[target == 0] = 1 - dist_cf[target == 0]
437 |         bkl = (1 - dist_cf).detach().cpu()
438 | 
439 |         # l_1 distance between input and counterfactual in [0, 1] pixel space
440 |         I_f = (I.to(dtype=torch.float) / 255).view(I.size(0), -1)
441 |         cf_f = (cf.to(dtype=torch.float) / 255).view(I.size(0), -1)
442 |         l_1 = (I_f - cf_f).abs().mean(dim=1).detach().cpu()
443 | 
444 |         stats['l_1'].append(l_1)
445 |         stats['n'] += I.size(0)
446 |         stats['bkl'].append(bkl)
447 |         stats['flipped'] += (pred_cf == target).sum().item()
448 |         stats['cf pred'].append(pred_cf.detach().cpu())
449 |         stats['target'].append(target.detach().cpu())
450 |         stats['label'].append(lab.detach().cpu())
451 |         stats['pred'].append(pred.detach().cpu())
452 | 
453 |         if args.save_images:
454 |             save_imgs(I.numpy(), cf.numpy(), noise_img.numpy(),
455 |                       target, lab, pred, pred_cf,
456 |                       bkl.numpy(),
457 |                       l_1, indexes=indexes.numpy())
458 | 
459 |         if (idx + 1) == min(args.num_batches, len(loader)):
460 |             print(f'[Chunk {args.chunk + 1} / {args.num_chunks}] {idx + 1} / {min(args.num_batches, len(loader))} | Time: {int(time() - start_time)}s')
461 |             print('\nDone')
462 |             break
463 | 
464 |         current_idx += I.size(0)
465 | 
466 |     # write summary for all four combinations
467 |     summary = {
468 |         'class-cor': {'cf-cor': {'bkl': 0, 'l_1': 0, 'n': 0},
469 |                       'cf-inc': {'bkl': 0, 'l_1': 0, 'n': 0},
470 |                       'bkl': 0, 'l_1': 0, 'n': 0},
471 |         'class-inc': {'cf-cor': {'bkl': 0, 'l_1': 0, 'n': 0},
472 |                       'cf-inc': {'bkl': 0, 'l_1': 0, 'n': 0},
473 |                       'bkl': 0, 'l_1': 0, 'n': 0},
474 |         'cf-cor': {'bkl': 0, 'l_1': 0, 'n': 0},
475 |         'cf-inc': {'bkl': 0, 'l_1': 0, 'n': 0},
476 |         'clean acc': 100 * acc / n,
477 |         'cf acc': stats['flipped'] / n,
478 |         'bkl': 0, 'l_1': 0, 'n': 0, 'FVA': 0, 'MNAC': 0,
479 |     }
480 | 
481 |     for k in stats.keys():
482 |         if k in ['flipped', 'n']:
483 |             continue
484 |         stats[k] = torch.cat(stats[k]).numpy()
485 | 
486 |     for k in ['bkl', 'l_1']:
487 | 
488 |         summary['class-cor']['cf-cor'][k] = mean(stats[k][(stats['label'] == stats['pred']) & (stats['target'] == stats['cf pred'])])
489 |         summary['class-inc']['cf-cor'][k] = mean(stats[k][(stats['label'] != stats['pred']) & (stats['target'] == stats['cf pred'])])
490 |         summary['class-cor']['cf-inc'][k] = mean(stats[k][(stats['label'] == stats['pred']) & (stats['target'] != stats['cf pred'])])
491 |         summary['class-inc']['cf-inc'][k] = mean(stats[k][(stats['label'] != stats['pred']) & (stats['target'] != stats['cf pred'])])
492 | 
493 |         summary['class-cor'][k] = mean(stats[k][stats['label'] == stats['pred']])
494 |         summary['class-inc'][k] = mean(stats[k][stats['label'] != stats['pred']])
495 | 
496 |         summary['cf-cor'][k] = mean(stats[k][stats['target'] == stats['cf pred']])
497 |         summary['cf-inc'][k] = mean(stats[k][stats['target'] != stats['cf pred']])
498 | 
499 |         summary[k] = mean(stats[k])
500 | 
501 |         summary['class-cor']['cf-cor']['n'] = len(stats[k][(stats['label'] == stats['pred']) & (stats['target'] == stats['cf pred'])])
502 |         summary['class-inc']['cf-cor']['n'] = len(stats[k][(stats['label'] != stats['pred']) & (stats['target'] == stats['cf pred'])])
503 |         summary['class-cor']['cf-inc']['n'] = len(stats[k][(stats['label'] == stats['pred']) & (stats['target'] != stats['cf pred'])])
504 |         summary['class-inc']['cf-inc']['n'] = len(stats[k][(stats['label'] != stats['pred']) & (stats['target'] != stats['cf pred'])])
505 | 
506 |         summary['class-cor']['n'] = len(stats[k][stats['label'] == stats['pred']])
507 |         summary['class-inc']['n'] = len(stats[k][stats['label'] != stats['pred']])
508 | 
509 |         summary['cf-cor']['n'] = len(stats[k][stats['target'] == stats['cf pred']])
510 |         summary['cf-inc']['n'] = len(stats[k][stats['target'] != stats['cf pred']])
511 | 
512 |     summary['n'] = n
513 | 
514 |     print('ACC ON THIS SET:', 100 * acc / n)
515 |     stats['acc'] = 100 * acc / n
516 | 
517 |     prefix = f'chunk-{args.chunk}_num-chunks-{args.num_chunks}_' if args.num_chunks != 1 else ''
518 |     torch.save(stats, osp.join(args.output_path, 'Results', args.exp_name, prefix + 'stats.pth'))
519 | 
520 |     # save summary
521 |     with open(osp.join(args.output_path, 'Results', args.exp_name, prefix + 'summary.yaml'), 'w') as f:
522 |         yaml.dump(summary, f)
523 | 
524 | 
525 | if __name__ == '__main__':
526 |     main()
--------------------------------------------------------------------------------
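A note on the chunked evaluation above: each chunk's summary stores subgroup means of 'bkl' and 'l_1' together with the subgroup sizes 'n', and merge_and_compute_overall_metrics recombines them as n-weighted means: each chunk mean is multiplied back by its 'n', summed, and divided by the total count (presumably the reason a div() helper is used instead of a plain division is to survive empty subgroups). A minimal standalone sketch of that recombination; the names merge_means, chunk_summaries and metric_keys are hypothetical and not part of the repository:

    def merge_means(chunk_summaries, metric_keys=('bkl', 'l_1')):
        # Accumulate sum(metric * n) and sum(n) over the chunks, then divide.
        totals = {k: 0.0 for k in metric_keys}
        n = 0
        for cs in chunk_summaries:
            n += cs['n']
            for k in metric_keys:
                totals[k] += cs[k] * cs['n']  # undo the per-chunk mean
        # Zero-safe division for empty subgroups, in the spirit of div() in main.py.
        return {k: (totals[k] / n if n > 0 else 0.0) for k in totals}

    # Example: chunks of 600 and 400 images give a 0.6 / 0.4 weighted mean.
    print(merge_means([{'n': 600, 'bkl': 0.20, 'l_1': 0.050},
                       {'n': 400, 'bkl': 0.30, 'l_1': 0.070}]))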