├── ACKNOWLEDGEMENTS ├── AM ├── .DS_Store ├── diffusion.py ├── dynamics.py ├── runner.py ├── samplers.py └── util.py ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── asset ├── AGMvsFM.gif ├── cond_gen.001.png ├── cond_gen.jpg.001.jpeg ├── cond_gen.pdf ├── cond_gen2.001.png └── sampling_hop.png.001.png ├── configs ├── .DS_Store ├── afhqv2_config.py ├── cifar10_config.py ├── imagenet64_config.py └── toy_config.py ├── dataset ├── AFHQv2.py ├── StrokeData │ ├── testFig0.png │ ├── testFig0_impainting.png │ ├── testFig1.png │ └── testFig1_impainting.png ├── __init__.py ├── cifar10.py ├── imagenet64.py └── spiral.py ├── edm ├── augment.py ├── dataset.py ├── dataset_tool.py ├── distributed_util.py ├── dnnlib │ ├── __init__.py │ └── util.py ├── fid.py ├── logger.py └── torch_utils │ ├── __init__.py │ ├── distributed.py │ ├── misc.py │ ├── persistence.py │ └── training_stats.py ├── networks ├── .DS_Store ├── edm │ └── ncsnpp.py ├── get_network.py └── network.py ├── plot_util.py ├── sampling.py ├── scripts ├── AFHQv2.sh ├── cifar10.sh ├── example.sh ├── imagenet64.sh ├── release.sh └── toy.sh ├── setup ├── conda_install.sh ├── download_datasets.py ├── environments.yml ├── requirement.txt └── setup.sh └── train.py /AM/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/AM/.DS_Store -------------------------------------------------------------------------------- /AM/diffusion.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
class Diffusion():
    """Bundle the dynamics, the sampler and the training-target construction.

    Attributes set by ``__init__``:
        opt:     the option namespace passed in.
        device:  the ``device`` argument (stored only; the dynamics and the
                 sampler are built with ``opt.device``).
        varx, varv: copies of ``opt.varx`` / ``opt.varv``.
        dyn:     the ``TVgMPC`` dynamics instance.
        sampler: a ``SamplerWrapper`` configured from ``opt``.
    """

    def __init__(self, opt, device):
        self.opt = opt
        self.device = device
        self.varx = opt.varx
        self.varv = opt.varv

        dyn_kargs = {
            'p': opt.p,            # diffusion coefficient
            'k': opt.k,            # covariance of prior
            'varx': opt.varx,
            'varv': opt.varv,
            'x_dim': opt.data_dim,
            'device': opt.device,  # using sampling device
            'DE_type': opt.DE_type,
        }
        dynamics = TVgMPC(**dyn_kargs)

        # Set up the dynamics solver.
        solver_kargs = {
            'solver_name': opt.solver,
            'diz': opt.diz,
            't0': opt.t0,
            'T': opt.T,
            'interval': opt.nfe,
            'dyn': dynamics,
            'device': opt.device,  # using sampling device
            'snap': 10,
            'local_rank': opt.local_rank,
            'diz_order': opt.diz_order,
            'gDDIM_r': opt.gDDIM_r,
            'cond_opt': None,
        }
        self.sampler = SamplerWrapper(**solver_kargs)

        # The original built a second TVgMPC from the same options, passed
        # positionally. Reusing the instance keeps one source of truth for
        # the constructor arguments.
        self.dyn = dynamics

        if self.opt.local_rank == 0:
            print('----------using sampling method as {}'.format(opt.solver))

    def reweights(self, t):
        """Return the per-time loss weight selected by ``opt.reweight_type``.

        Args:
            t: tensor of time values.

        Returns:
            ``ones_like(t)`` for ``'ones'``, ``sqrt(1/(1-t))`` for
            ``'reciprocal'`` and ``sqrt(1/t)`` for ``'reciprocalinv'``.

        Raises:
            RuntimeError: if ``opt.reweight_type`` is none of the above.
        """
        reweight_type = self.opt.reweight_type
        if reweight_type == 'ones':
            return torch.ones_like(t)
        if reweight_type == 'reciprocal':
            return torch.sqrt(1/(1-t))
        if reweight_type == 'reciprocalinv':
            return torch.sqrt(1/(t))
        raise RuntimeError(
            'unknown reweight_type: {!r}'.format(reweight_type))

    def mt_sample(self, x1, ts):
        """Draw a training pair for target data ``x1`` at times ``ts``.

        ``ts`` is reshaped to ``[-1, 1, ..., 1]`` (one trailing singleton per
        entry of ``opt.data_dim``) before being handed to
        ``dyn.get_xt_vt_fv``.

        Returns:
            ``(label, inputs)``: ``label`` is the third value returned by
            ``get_xt_vt_fv``; ``inputs`` is its first two return values
            concatenated along dim 1.
        """
        opt = self.opt
        t = ts.reshape(-1, *([1, ]*len(opt.data_dim)))
        analytic_xs, analytic_vs, analytic_fv = self.dyn.get_xt_vt_fv(
            t, x1, opt.DE_type, opt.device)
        inputs = torch.cat([analytic_xs, analytic_vs], dim=1)
        return analytic_fv, inputs
def get_cov(self, t, damp=0):
    '''
    Return the Cholesky components (Lxx, Lxv, Lvv) of the 2x2 covariance
    [[sigmaxx, sigmavx], [sigmavx, sigmavv]] at time t, together with
    ell = -sqrt(sigmaxx / det), where det = sigmaxx*sigmavv - sigmavx**2.

    The computation runs in float64; all four results are cast back to
    float32. `damp` is added to the two diagonal entries only.
    '''
    dims = self.x_dim
    t64 = t.double()

    var_xx = cast_shape(self.sigmaxx(t64), dims) + damp
    cov_xv = cast_shape(self.sigmavx(t64), dims)
    var_vv = cast_shape(self.sigmavv(t64), dims) + damp

    det = var_xx*var_vv - cov_xv**2
    ell = cast_shape(-torch.sqrt(var_xx/det), dims)

    chol_xx = torch.sqrt(var_xx)
    chol_xv = cov_xv/chol_xx

    # Schur complement; values that are negative only by round-off
    # (negative AND close to zero) are clamped to exactly zero so the
    # square root below does not produce NaN for them.
    schur = var_vv - chol_xv**2
    zeros = torch.zeros_like(schur)
    roundoff = (schur < 0) & torch.isclose(schur, zeros)
    schur = torch.where(roundoff, zeros, schur)
    chol_vv = torch.sqrt(schur)

    return chol_xx.float(), chol_xv.float(), chol_vv.float(), ell.float()
class TVgMPC(BaseDynamics):
    '''
    TVg: dynamics with a time-variant diffusion coefficient g(t) = p*(1-t).

    All covariance entries, means and Lyapunov terms below are closed-form
    expressions in t. Several of them contain log(1-t) or 1/(1-t), so they
    are only finite for t < 1.
    '''
    def __init__(self, p,k,varx,varv,x_dim,device,DE_type):
        # NOTE: the base constructor calls self.get_normalizer(), which only
        # returns the bound method, so DE_type may safely be set afterwards.
        super(TVgMPC,self).__init__(p,k,varx,varv,x_dim,device)
        self.DE_type = DE_type
    def g2P10(self, t):
        # 1-0 entry of the Lyapunov solution (see the base-class docstring): 6/(t-1)^2
        return 6/(-1+t)**2

    def g2P11(self, t):
        # 1-1 entry of the Lyapunov solution: -4/(t-1) == 4/(1-t)
        return -4/(-1+t)

    def g(self, t):
        # Diffusion coefficient p*(tt - t) with terminal time tt fixed to 1.
        tt = 1
        return self.p*(tt-t)

    def sigmaxx(self, t):
        # xx component of the covariance at time t.
        # m, n are the initial position / velocity variances; k enters only
        # through the k*sqrt(m*n) cross term; tt is the terminal time (1).
        # t must be a torch tensor because torch.log is applied to (1 - t).
        m,n = self.varx, self.varv
        k = self.k
        p = self.p
        tt = 1
        val =(t - 1)**2*(30*m*(t**3 - 3*t**2 + 3*t + 3)**2 - 60*p**2*(t - 1)**3*torch.log(1 - t) - t*(60*k*np.sqrt(m*n)*(t**5 - 6*t**4 + 15*t**3 - 15*t**2 + 9) - 30*n*t*(t**2 - 3*t + 3)**2 + p**2*(t**5*(6*tt**2 + 3*tt + 1) - 6*t**4*(6*tt**2 + 3*tt + 1) + 15*t**3*(6*tt**2 + 3*tt + 1) - 10*t**2*(9*tt**2 + 11) + 150*t - 60)))/270
        return val

    def sigmavx(self, t):
        # Cross (x-v) component of the covariance at time t; same symbols
        # as in sigmaxx.
        m,n = self.varx, self.varv
        p = self.p
        k = self.k
        tt = 1
        val =(1/270 - t/270)*(30*k*np.sqrt(m*n)*(8*t**6 - 48*t**5 + 120*t**4 - 135*t**3 + 45*t**2 + 27*t - 9) + 150*p**2*(t - 1)**3*torch.log(1 - t) + t*(-120*m*(t**5 - 6*t**4 + 15*t**3 - 15*t**2 + 9) - 30*n*(4*t**5 - 24*t**4 + 60*t**3 - 75*t**2 + 45*t - 9) + p**2*(4*t**5*(6*tt**2 + 3*tt + 1) - 24*t**4*(6*tt**2 + 3*tt + 1) + 60*t**3*(6*tt**2 + 3*tt + 1) - 5*t**2*(81*tt**2 + 18*tt + 55) + 15*t*(9*tt**2 + 25) - 150)))
        return val



    def sigmavv(self, t):
        # vv component of the covariance at time t; same symbols as in
        # sigmaxx.
        m,n = self.varx, self.varv
        p = self.p
        k = self.k
        tt = 1
        val= n*(-4*t**3 + 12*t**2 - 12*t + 3)**2/9 - 8*p**2*(t - 1)**3*torch.log(1 - t)/9 + t*(-120*k*np.sqrt(m*n)*(4*t**5 - 24*t**4 + 60*t**3 - 75*t**2 + 45*t - 9) + 240*m*t*(t**2 - 3*t + 3)**2 + p**2*(-8*t**5*(6*tt**2 + 3*tt + 1) + 48*t**4*(6*tt**2 + 3*tt + 1) - 120*t**3*(6*tt**2 + 3*tt + 1) + 5*t**2*(180*tt**2 + 72*tt + 53) - 15*t*(36*tt**2 + 9*tt + 20) + 135*tt**2 + 120))/135
        return val

    def get_normalizer(self):
        '''
        Return the callable used to normalize the label, i.e. the bound
        method self.normalizer (it is not evaluated here).
        '''
        return self.normalizer

    def normalizer(self,t):
        '''
        Scale factor applied to the regression label at time t.

        Both branches divide by (1-t). For DE_type == 'probODE' the second
        squared term additionally subtracts 0.5*g(t)**2*ell, matching the
        term returned by score(); every other DE_type uses the plain form.
        '''
        _g2P10 = self.g2P10(t)  # computed but not used below
        _g2P11 = self.g2P11(t)
        _g = self.g(t)
        Lxx,Lxv,Lvv,ell = self.get_cov(t)
        # get_cov passes its outputs through cast_shape; bring them back to
        # t's shape so the arithmetic below is elementwise in t.
        Lxx = Lxx.reshape_as(t)
        Lxv = Lxv.reshape_as(t)
        Lvv = Lvv.reshape_as(t)
        ell = ell.reshape_as(t)

        if self.DE_type=='probODE':
            norm =torch.sqrt(((_g2P11*((-1/(1-t)*Lxx-Lxv)))**2+(_g2P11*Lvv-0.5*_g**2*ell)**2))/(1-t)
            return norm
        else:
            damp=0  # placeholder damping term; currently has no effect
            return torch.sqrt(((_g2P11*((-1/(1-t)*Lxx-Lxv)))**2+(_g2P11*Lvv)**2)+damp)/(1-t)


    def get_analytic_mux_muv(self, t,x0,v0,x1):
        '''
        Mean of position and velocity at time t given the initial means
        (x0, v0) and the target x1. When x1 is None the initial means from
        mux0_muv0 are returned instead (t is then ignored).

        Returns (mux, muv).
        '''
        bs = x0.shape[0]
        if x1 is None:
            return self.mux0_muv0(bs)
        else:
            muv =v0*(-4*t**3/3 + 4*t**2 - 4*t + 1) + t*(-4*x0/3 + 4*x1/3)*(t**2 - 3*t + 3)
            mux =-x0*(t**4 - 4*t**3 + 6*t**2 - 3)/3 + t*(-v0*(t**3 - 4*t**2 + 6*t - 3) + x1*t*(t**2 - 4*t + 6))/3
            return mux,muv

    def mux0_muv0(self,bs):
        # Initial means: zero tensors of shape (bs, *x_dim) on self.device.
        # Returned in the order (position mean, velocity mean).
        muv0 = torch.zeros(bs, *self.x_dim,device=self.device)
        mux0 = torch.zeros(bs, *self.x_dim,device=self.device)
        return mux0, muv0

    def score(self,t,ell,epsvv):
        # Extra label term used by get_xt_vt_fv when DE_type == 'probODE':
        # -0.5 * g(t)^2 * ell * epsvv.
        _g = self.g(t)
        return -0.5*_g**2*ell*epsvv

    def get_m0(self,bs):
        '''
        Sample bs initial joint states at t = 0.

        Returns the position and velocity samples concatenated along dim 1.
        '''
        # The joint (x, v) tensor doubles the first entry of x_dim.
        joint_dim = [value*2 if idx==0 else value for idx,value in enumerate(self.x_dim)]

        mux0,muv0 = self.mux0_muv0(bs)
        t = torch.zeros(bs,1,device=self.device)
        Lxx,Lxv,Lvv,ell = self.get_cov(t)
        noise = torch.randn(bs, *joint_dim,device=self.device)
        assert noise.shape[0] == t.shape[0]
        epsxx,epsvv = torch.chunk(noise,2,dim=1)
        # Correlated Gaussian sample via the Cholesky factors of the covariance.
        analytic_x0 = mux0+Lxx*epsxx
        analytic_v0 = muv0+(Lxv*epsxx+Lvv*epsvv)
        return torch.cat([analytic_x0, analytic_v0],dim=1)
-------------------------------------------------------------------------------- /AM/runner.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import numpy as np 7 | import pickle 8 | import torch 9 | import pytorch_warmup as warmup 10 | import torch.nn.functional as F 11 | from torch.optim import AdamW, lr_scheduler 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | from edm import dnnlib 14 | from torch_ema import ExponentialMovingAverage 15 | import torchvision.utils as tu 16 | # from .get_network import get_nn 17 | from networks.get_network import get_nn 18 | from . import util 19 | import plot_util 20 | from .diffusion import Diffusion 21 | from plot_util import norm_data, plot_plt,plot_scatter 22 | from .util import all_cat_cpu 23 | from edm.fid import calculate_fid_from_inception_stats, calculate_inception_stats 24 | from sampling import loop_saving_png 25 | 26 | 27 | def build_optimizer_sched(opt, net, log): 28 | optim_dict = {"lr": opt.lr, 'weight_decay': opt.l2_norm} 29 | optimizer = AdamW(net.parameters(), **optim_dict) 30 | log.info(f"[Opt] Built AdamW optimizer {optim_dict=}!") 31 | 32 | sched = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.num_itr) 33 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer) 34 | log.info(f"[Opt] Built lr _step scheduler Cosine") 35 | 36 | 37 | if opt.load: 38 | checkpoint = torch.load(opt.load, map_location="cpu") 39 | if "optimizer" in checkpoint.keys(): 40 | optimizer.load_state_dict(checkpoint["optimizer"]) 41 | log.info(f"[Opt] Loaded optimizer ckpt {opt.load}!") 42 | else: 43 | log.warning(f"[Opt] Ckpt {opt.load} has no optimizer!") 44 | if sched is not None and "sched" in checkpoint.keys() and checkpoint["sched"] is not None: 45 | sched.load_state_dict(checkpoint["sched"]) 46 | log.info(f"[Opt] Loaded lr sched ckpt 
def __init__(self, opt, log, save_opt=True):
    """Build the diffusion helper, network and EMA, optionally restoring them.

    Args:
        opt: option namespace; read fields include ckpt_path, device, exp,
            ema, t_samp and load.
        log: logger used for progress messages.
        save_opt: when true, pickle ``opt`` to ``opt.ckpt_path/options.pkl``.
    """
    super(Runner, self).__init__()
    self.opt = opt
    self.log = log
    self.best_res = np.inf

    # Persist the options next to the checkpoints.
    if save_opt:
        pkl_path = opt.ckpt_path / "options.pkl"
        with open(pkl_path, "wb") as handle:
            pickle.dump(opt, handle)
        log.info("Saved options pickle to {}!".format(pkl_path))

    self.diffusion = Diffusion(opt, opt.device)
    self.reweight = self.diffusion.reweights

    # Image experiments need the reference statistics used for FID.
    if opt.exp != 'toy':
        fid_refs = {
            'cifar10': 'https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz',
            'AFHQv2': 'https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz',
            'imagenet64': 'https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/imagenet-64x64.npz',
        }
        ref_url = fid_refs.get(opt.exp)
        with dnnlib.util.open_url(ref_url) as handle:
            self.ref = dict(np.load(handle))

    self.net = get_nn(opt, self.diffusion.dyn)
    log.info('network size [{}]'.format(util.count_parameters(self.net)))
    self.ema = ExponentialMovingAverage(self.net.parameters(), decay=opt.ema)

    # Time-step sampler; stays None for an unrecognised opt.t_samp.
    ts_samplers = {
        'uniform': util.uniform_ts,
        'debug': util.debug_ts,
    }
    self.ts_sampler = ts_samplers.get(opt.t_samp)

    if opt.load:
        state = torch.load(opt.load, map_location="cpu")
        self.net.load_state_dict(state['net'])
        log.info(f"[Net] Loaded network ckpt: {opt.load}!")
        self.ema.load_state_dict(state["ema"])
        log.info(f"[Ema] Loaded ema ckpt: {opt.load}!")

    self.net.to(opt.device)
    self.ema.to(opt.device)
else: 105 | net=self.net 106 | 107 | ema = self.ema 108 | optimizer, sched, warmup = build_optimizer_sched(opt, net, log) 109 | ts_sampler, reweight = self.ts_sampler, self.reweight 110 | t0,T,device = opt.t0, opt.T, opt.device 111 | 112 | net.train() 113 | 114 | for it in range(opt.num_itr): 115 | optimizer.zero_grad(set_to_none=True) 116 | # ===== sample boundary pair ===== 117 | x1,class_cond\ 118 | = train_loader.sample() 119 | # ===== compute loss ===== 120 | _ts = ts_sampler(t0,T,x1.shape[0],device) 121 | label,mt= self.diffusion.mt_sample(x1,ts=_ts) 122 | lambdat = reweight(_ts)[:,None] 123 | lambdat = lambdat.reshape(-1,*([1,]*(len(x1.shape)-1))) 124 | pred = net(mt, _ts,cond=class_cond) 125 | label = label.reshape_as(pred) 126 | _pred = pred.detach().cpu() #for rendering loss over time 127 | _label = label.detach().cpu() #for rendering loss over time 128 | loss = F.mse_loss(lambdat*pred, lambdat*label) 129 | loss.backward() 130 | 131 | if opt.clip_grad is not None: 132 | torch.nn.utils.clip_grad_norm_(net.parameters(), opt.clip_grad) 133 | optimizer.step() 134 | ema.update() 135 | 136 | if sched is not None: 137 | with warmup.dampening(): 138 | sched.step() 139 | 140 | # # -------- logging -------- 141 | log.info("train_it {}/{} | lr:{} | loss:{}".format( 142 | 1+it, 143 | opt.num_itr, 144 | "{:.2e}".format(optimizer.param_groups[0]['lr']), 145 | "{:+.4f}".format(loss.item()), 146 | )) 147 | if it % 10 == 0: 148 | self.writer.add_scalar(it, 'loss', loss.detach()) 149 | 150 | #============monitoring the loss distribution over time======= 151 | if it% 1000 ==0 and self.writer is not None: 152 | total_idxs = torch.arange(0,opt.nfe).float() 153 | total_vals = torch.zeros_like(total_idxs).float() 154 | _lambdat = lambdat.detach().cpu() 155 | if self.opt.exp=='toy': 156 | _loss = (((_pred*_lambdat-_label*_lambdat)**2).sum(-1)).float() 157 | else: 158 | _loss = (((_pred*_lambdat-_label*_lambdat)**2).sum(-1).sum(-1).sum(-1)).float() 159 | 
total_vals[(_ts*opt.nfe).long()] = _loss/(_pred.numel()/_pred.shape[0]) 160 | 161 | 162 | self.writer.add_bar(it,'ts_loss',[total_idxs,total_vals]) 163 | #============monitoring the loss distribution over time======= 164 | 165 | if it == 1000 or it % 10000 == 0: 166 | # if it >=9999 and it % 5000 == 0: 167 | net.eval() 168 | results=self.evaluation(opt, it,train_loader) 169 | net.train() 170 | 171 | if results<=self.best_res: 172 | self.best_res=results 173 | name = "latest.pt".format(it) 174 | util.save_ckpt(opt,log,it,self.net,ema,optimizer,sched,name) #Using Self.net to handle DDP 175 | 176 | if it%10000==0: 177 | if opt.global_rank == 0: 178 | name = "latest_it{}.pt".format(it) 179 | util.save_ckpt(opt,log,it,self.net,ema,optimizer,sched,name) 180 | 181 | if opt.distributed: 182 | torch.distributed.barrier() 183 | 184 | self.writer.close() 185 | 186 | 187 | @torch.no_grad() 188 | def evaluation(self, opt, it, loader): 189 | log = self.log 190 | log.info(f"========== Evaluation started: iter={it} ==========") 191 | 192 | def log_image(tag, img, nrow=10): 193 | self.writer.add_image(it, tag, tu.make_grid((img+1)/2, nrow=nrow,scale_each=True)) # [1,1] -> [0,1] 194 | 195 | x1,class_cond = loader.sample() 196 | if opt.exp=='toy': 197 | sampler = self.diffusion.sampler 198 | m0 = sampler.dyn.get_m0(opt.sampling_batch).to(opt.device) 199 | ms, pred_m1,est_x1s,snap_ts = sampler.solve(self.ema,self.net,m0,cond=None) 200 | plot_util.plot_toy(opt,ms,it,pred_m1,x1) 201 | pos_traj = ms[:,:-1,0:2,...] 202 | vel_traj = ms[:,:-1,2:,...] 
203 | est_x1s = est_x1s.detach().cpu().numpy() 204 | plot_util.save_toy_npy_traj(opt,'itr_{}_est_x1s'.format(it),est_x1s,n_snapshot=10) 205 | else: 206 | image_dir = os.path.join(opt.ckpt_path , 'fid_train_folder') 207 | num_loop=int((opt.train_fid_sample-1)/opt.num_proc_node/opt.n_gpu_per_node/opt.sampling_batch)+1 208 | log.info('num of loop {}, batch {}, number of gpu{}, sampling number {}'.format(num_loop, opt.sampling_batch, opt.n_gpu_per_node,opt.train_fid_sample)) 209 | 210 | ms, pred_m1, est_x1s,snap_ts = loop_saving_png( opt,\ 211 | self.diffusion.sampler,\ 212 | self.ema,\ 213 | self.net,\ 214 | log,\ 215 | image_dir,\ 216 | num_loop=num_loop,\ 217 | return_last=True) 218 | 219 | mu,sigma=calculate_inception_stats( image_path=image_dir,\ 220 | num_expected=opt.train_fid_sample,\ 221 | seed=42,\ 222 | max_batch_size=128) 223 | 224 | fid = calculate_fid_from_inception_stats(mu,sigma,self.ref['mu'], self.ref['sigma']) 225 | log.info(f"========== FID is: iter={fid} ==========") 226 | log.info(f"========== FID folder is at : {image_dir} ==========") 227 | self.writer.add_scalar(it, 'fid', fid) 228 | ########Visualizing resulting data######## 229 | num_samp=40 230 | pred_x1 = pred_m1[:,0:3,...] 231 | pred_v1 = pred_m1[:,3:,...] 232 | est_x1s = est_x1s 233 | ms = ms 234 | gt_x1 = x1 235 | gt_ts = snap_ts 236 | 237 | if opt.log_writer is not None: 238 | _pos_traj = ms[0:5,:,0:3,...] 239 | pos_traj = _pos_traj.reshape(-1,*opt.data_dim) 240 | vel_traj = ms[0:5,:,3:,...] 241 | est_x1s = est_x1s[0:5,:,:,...] 
242 | vel_traj = vel_traj.reshape(-1,*opt.data_dim) 243 | est_x1s = est_x1s.reshape(-1,*opt.data_dim) 244 | log_image("image/position", (pred_x1[0:num_samp])) 245 | log_image("image/velocity", (pred_v1[0:num_samp])) 246 | log_image("image/gt", (gt_x1[0:num_samp])) 247 | log_image("image/position_traj",pos_traj,nrow=11) 248 | log_image("image/velocity_traj",vel_traj,nrow=11) 249 | log_image("image/est_x1s",est_x1s,nrow=11) 250 | else: 251 | fn_pdf = os.path.join(opt.ckpt_path, 'itr_{}_x.png'.format(it)) 252 | tu.save_image(norm_data((pred_x1+1)/2), fn_pdf, nrow = 6) 253 | fn_pdf = os.path.join(opt.ckpt_path, 'itr_{}_v.png'.format(it)) 254 | tu.save_image(norm_data((pred_v1+1)/2), fn_pdf, nrow = 6) 255 | if it==0: 256 | fn_pdf = os.path.join(opt.ckpt_path, 'itr_{}_ground_truth.png'.format(it)) 257 | tu.save_image(norm_data((gt_x1+1)/2), fn_pdf, nrow = 6) 258 | ########Visualizing resulting data######## 259 | log.info(f"========== Evaluation finished: iter={it} ==========") 260 | torch.cuda.empty_cache() 261 | results = self.best_res if opt.exp=='toy' else fid 262 | return results 263 | -------------------------------------------------------------------------------- /AM/samplers.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from tqdm import tqdm 6 | import torch 7 | from . 
def get_est_x1(dyn,t,_g,_fv,x,v,DE_type):
    """Estimate the terminal data point x1 from the current state and drift.

    Args:
        dyn: dynamics object providing ``g2P11(t)`` and, for the
            ``'probODE'`` branch, ``get_cov(t)``.
        t: current time.
        _g: diffusion coefficient at ``t`` (used only for ``'probODE'``).
        _fv: drift on the velocity, already multiplied by the normalizer.
        x, v: current position and velocity.
        DE_type: ``'probODE'`` selects the covariance-corrected formula;
            any other value uses ``(_fv/g2P11 + v)*(1-t) + x``.
    """
    gain = dyn.g2P11(t)

    if DE_type != 'probODE':
        return (_fv/gain+v)*(1-t)+x

    Lxx, Lxv, Lvv, ell = dyn.get_cov(t)
    # Coefficients are promoted to float64 before the divisions below.
    coef_x = (Lxx/(1-t)+Lxv).to(torch.float64)
    coef_v = (Lvv+(0.5*_g**2*ell)/gain).to(torch.float64)
    bb = coef_v/Lvv
    aa = (coef_x-bb*Lxv)/Lxx
    # Polynomial factors multiplying x1 in the velocity / position terms.
    cv = 4/3*t*(3+(-3+t)*t)
    cx = 1/3*t**2*(6+(-4+t)*t)
    numerator = _fv+gain*(aa*x+bb*v)
    denominator = 4*(t-1)**2+gain*(aa*cx+bb*cv)
    return numerator/denominator
def get_discretizer(diz,t0,T,interval,device,diz_order=2):
    """Build the time grid ``ts`` and the step sizes ``dts`` for a sampler.

    Args:
        diz: ``'Euler'`` (uniform grid), ``'quad'`` (uniform in t**2) or
            ``'rev-quad'`` (uniform in t**(1/diz_order), float64).
        t0, T: start and end time.
        interval: number of grid intervals; ``'rev-quad'`` uses it as the
            number of grid points instead.
        device: device the returned tensors live on.
        diz_order: exponent used by ``'rev-quad'`` only.

    Returns:
        ``(ts, dts)``. ``'Euler'`` and ``'rev-quad'`` append one extra step
        of size ``0.999 - T`` to ``dts``; ``'quad'`` does not.

    Raises:
        NotImplementedError: for any other ``diz``.
    """
    if diz not in ('Euler', 'quad', 'rev-quad'):
        raise NotImplementedError(
            'discretizer %s is not implemened.' % diz)

    if diz == 'quad':
        # Built on the default device and moved afterwards.
        grid = torch.sqrt(torch.linspace(t0**2, T**2, interval+1))
        steps = grid[1:]-grid[0:-1]
        return grid[0:-1].to(device), steps.to(device)

    if diz == 'Euler':
        grid = torch.linspace(t0, T, interval+1, device=device)
        starts = grid[0:-1]
    else:  # 'rev-quad'
        grid = torch.linspace(t0**(1/diz_order), T**(1/diz_order), interval,
                              dtype=torch.float64, device=device)
        grid = grid**diz_order
        starts = grid
    steps = grid[1:]-grid[0:-1]

    # Trailing step so the full set of time steps can be evaluated.
    tail = torch.Tensor([0.999-T]).to(device)
    return starts, torch.cat([steps, tail], dim=0)
def em_sampler(m0,drift,cond,ts_dts,dyn,snap,local_rank,return_est_x1=True):
    """Propagate the joint state with an Euler-Maruyama scheme.

    Args:
        m0: initial joint state; position and velocity are the two halves
            along dim 1.
        drift: network called as ``drift(m, t, cond=cond)``.
        cond: conditioning passed through to ``drift``.
        ts_dts: pair ``(ts, dts)`` of time points and step sizes; ``ts``
            must be non-empty.
        dyn: dynamics providing ``DE_type``, ``g(t)`` and
            ``get_normalizer()``.
        snap: number of snapshots recorded along the trajectory.
        local_rank: rank 0 shows a progress bar.
        return_est_x1: also record the estimated x1 at every snapshot.

    Returns:
        ``(ms, mT, est_x1s, snapts)``: the snapshots stacked on dim 1 with
        the last recorded state appended once more, that last state, the
        stacked x1 estimates (``None`` when ``return_est_x1`` is false) and
        the snapshot times.
    """
    DE_type = dyn.DE_type
    bs = m0.shape[0]
    ts, dts = ts_dts
    interval = ts.shape[0]
    if local_rank == 0:
        times_horizon = zip(
            tqdm(ts, desc=util.blue("Propagating Dynamics..."),
                 position=0, leave=False, colour='blue'),
            dts)
    else:
        times_horizon = zip(ts, dts)

    m = m0
    x, v = torch.chunk(m0, 2, dim=1)
    snaps = np.linspace(0, interval-1, snap).astype(int)
    normalizer = dyn.get_normalizer()
    ms = []
    snapts = []
    est_x1s = [] if return_est_x1 else None
    est_x1 = None

    for idx, (t, dt) in enumerate(times_horizon):
        _t = t.repeat(bs)
        _g = dyn.g(t)

        # Rebuild the joint state BEFORE querying the network. The original
        # rebuilt it after the call, so from the second step on the network
        # saw the previous step's state paired with the current time.
        m = torch.cat([x, v], dim=1)
        fv = drift(m, _t, cond=cond)
        fv = fv*(normalizer(t)).squeeze()

        #=========dyn propagation===========
        x = x+v*dt
        # Must not be named `dw`: assigning to that name makes it local to
        # this function and the call to the module-level dw() helper then
        # raises UnboundLocalError on the SDE path.
        noise = dw(v, dt) if DE_type == 'SDE' else torch.zeros_like(v)
        v = v+fv*dt+_g*noise
        #=========dyn propagation===========

        if idx in snaps:
            ms.append(m[:, None, ...])
            snapts.append(t[None])
            if return_est_x1:
                est_x1 = get_est_x1(dyn, t, _g, fv, x, v, DE_type)
                est_x1s.append(est_x1[:, None, ...])

    # NOTE(review): as in the original, mT is the state recorded at the last
    # time point, i.e. before the final update of x and v -- confirm that
    # dropping the last step is intended.
    mT = m
    ms.append(mT[:, None, ...])
    snapts.append(t[None])

    est_traj = None
    if return_est_x1:
        # Only touch the estimates when they were requested; the original
        # appended unconditionally and failed with NameError otherwise.
        est_x1s.append(est_x1[:, None, ...])
        est_traj = torch.cat(est_x1s, dim=1)

    return torch.cat(ms, dim=1), mT, est_traj, torch.cat(snapts, dim=0)
| dyn.get_normalizer() 268 | intgral_norm = coef 269 | 270 | if rank == 0: 271 | times_horizon = zip(tqdm(ts,desc=util.blue("Propagating Dynamics..."),position=0,leave=False,colour='blue'),dts) 272 | else: 273 | times_horizon=zip(ts,dts) 274 | 275 | if return_est_x1: est_x1s = [] 276 | prev_fv = [] 277 | 278 | 279 | if conf_flag: 280 | stroke = cond_opt.stroke 281 | stroke_type = cond_opt.stroke_type 282 | impainting = cond_opt.impainting 283 | cond_strength\ 284 | = 1.0 if impainting else 0.25 285 | cond_fn = impaint_stroke if impainting else dyn_stroke 286 | if stroke_type=='dyn-v': stroke_idx = int(cond_strength*ts.shape[0]) 287 | if stroke_type=='init-v': v = 0.9*v+0.1*stroke 288 | 289 | for idx,(t,dt) in enumerate(times_horizon): 290 | _t = t.repeat(bs) 291 | _g = dyn.g(t) 292 | normt = (normalizer(t)).squeeze() 293 | 294 | # =======Conditional generation ========== 295 | if conf_flag and idx==0: 296 | fv = drift(m.to(torch.float32),_t.to(torch.float32),cond=cond).to(torch.float32) 297 | _fv = fv*normt 298 | est_x1 = get_est_x1(dyn,t,_g,_fv,x,v,DE_type) 299 | 300 | if conf_flag and stroke_type=='dyn-v' and idx=0 and len(prev_fv)==max_order 331 | 332 | if DE_type=='ODE': 333 | coef_fv = prev_fv[jj-1]/(1-t)*coef 334 | else: 335 | coef_fx = prev_fv[jj-1]*coef[0,1] 336 | coef_fv = prev_fv[jj-1]*coef[1,1] 337 | 338 | accumulated_fv += coef_fv 339 | accumulated_fx += coef_fx 340 | 341 | 342 | phit = phi_fn((t+dt)[None,None],t[None,None])[0] 343 | x = phit[0,0]*x+phit[0,1]*v+accumulated_fx 344 | v = phit[1,0]*x+phit[1,1]*v+accumulated_fv 345 | m = torch.cat([x,v],dim=1) 346 | if len(prev_fv)=0 and i-j>=0 389 | if k!=j: 390 | prod= prod* ((t-ts[i-k])/(ts[i-j]-ts[i-k]))[...,None] 391 | #=====time coef========= 392 | phi_matrix = phi_fn(ts[i]+dts[i],t).to(device) 393 | if DE_type=='ode': 394 | return phi_matrix@z@prod*(1-t) 395 | else: 396 | return phi_matrix@z@prod 397 | return _fn_r 398 | 399 | def AB_fn(normalizer,DE_type,ts,dts,r=0): 400 | intgral_norm = {} 401 | 
num_monte_carlo_sample = 50000 402 | for idx,(t,dt) in enumerate(zip(ts,dts)): 403 | max_order = min(idx,r) 404 | intgral_norm[idx] = {} 405 | for jj in range(max_order+1): 406 | coef_fn = extrapolate_fn(normalizer,DE_type,ts,dts,idx,j=jj,r=max_order) 407 | coef = monte_carlo_integral(coef_fn,t,t+dt,num_monte_carlo_sample) 408 | intgral_norm[idx][jj]=coef 409 | return intgral_norm 410 | 411 | 412 | def dyn_stroke(t,dyn,stroke,m,est_x1=None): 413 | 414 | currx,currv = torch.chunk(m,2,dim=1) 415 | Sigxx,Sigxv,Sigvv\ 416 | = dyn.sigmaxx(t),dyn.sigmavx(t),dyn.sigmavv(t) 417 | noise = torch.randn(*m.shape,device=m.device) 418 | epsxx,epsvv = torch.chunk(noise,2,dim=1) 419 | muv = 0*(-4*t**3/3 + 4*t**2 - 4*t + 1) + t*(-4*0/3 + 4*stroke/3)*(t**2 - 3*t + 3) 420 | mux = -0*(t**4 - 4*t**3 + 6*t**2 - 3)/3 + t*(-0*(t**3 - 4*t**2 + 6*t - 3) + stroke*t*(t**2 - 4*t + 6))/3 421 | muv = muv+Sigxv/Sigxx*(currx-mux) 422 | Sigv = Sigvv-Sigxv**2/Sigxx 423 | v = muv+Sigv*epsvv 424 | return v 425 | 426 | def impaint_stroke(t,dyn,stroke,m,est_x1): 427 | mask = torch.ones_like(stroke) 428 | mask[stroke==-1]\ 429 | = 0 430 | invmask = 1-mask 431 | stroke = stroke*mask+invmask*est_x1 432 | return dyn_stroke(t,dyn,stroke,m,est_x1) 433 | 434 | -------------------------------------------------------------------------------- /AM/util.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | from prefetch_generator import BackgroundGenerator 6 | import os 7 | import warnings 8 | try: 9 | from torch.utils.tensorboard import SummaryWriter 10 | except: 11 | warnings.warn("install your favorite tensorboard version") 12 | import wandb 13 | import termcolor 14 | import torch 15 | from torch.utils.data import DataLoader 16 | import matplotlib.pyplot as plt 17 | import edm.distributed_util as dist_util 18 | import abc 19 | import numpy as np 20 | 21 | class DataLoaderX(DataLoader): 22 | def __iter__(self): 23 | return BackgroundGenerator(super().__iter__()) 24 | def setup_loader(dataset, batch_size, num_workers=4): 25 | loader = DataLoaderX( 26 | dataset, 27 | batch_size=batch_size, 28 | pin_memory=True, 29 | shuffle=True, 30 | persistent_workers=True, 31 | num_workers=num_workers, 32 | multiprocessing_context='spawn', 33 | drop_last=True, 34 | prefetch_factor=4, 35 | ) 36 | # return loader 37 | while True: 38 | yield from loader 39 | 40 | 41 | 42 | 43 | def save_ckpt(opt,log,it,net,ema,optimizer,sched,name): 44 | torch.save({ 45 | "net": net.state_dict(), 46 | "ema": ema.state_dict(), 47 | "optimizer": optimizer.state_dict(), 48 | "sched": sched.state_dict() if sched is not None else sched, 49 | }, opt.ckpt_path / name) 50 | log.info(f"Saved latest({it=}) checkpoint to {opt.ckpt_path=}!") 51 | 52 | class BaseWriter(object): 53 | def __init__(self, opt): 54 | self.rank = opt.global_rank 55 | def add_scalar(self, step, key, val): 56 | pass # do nothing 57 | def add_image(self, step, key, image): 58 | pass # do nothing 59 | def add_bar(self, step, key, image): 60 | pass 61 | def close(self): pass 62 | 63 | class WandBWriter(BaseWriter): 64 | def __init__(self, opt): 65 | super(WandBWriter,self).__init__(opt) 66 | if self.rank == 0: 67 | assert wandb.login(key=opt.wandb_api_key) 68 | wandb.init(dir=str(opt.log_dir), project="VM", entity=opt.wandb_user, name=opt.name, config=vars(opt)) 69 | 70 | def add_scalar(self, step, key, val): 71 | if self.rank == 
0: wandb.log({key: val}, step=step) 72 | 73 | def add_image(self, step, key, image): 74 | if self.rank == 0: 75 | # adopt from torchvision.utils.save_image 76 | image = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() 77 | wandb.log({key: wandb.Image(image)}, step=step) 78 | 79 | def add_bar(self,step,key,data): 80 | if self.rank == 0: 81 | fig, ax = plt.subplots() 82 | # plt.ylim([0,5]) 83 | ts,loss=data 84 | ax.bar(ts, loss) 85 | wandb.log({"plot": wandb.Image(fig)},step=step) 86 | 87 | 88 | class TensorBoardWriter(BaseWriter): 89 | def __init__(self, opt): 90 | super(TensorBoardWriter,self).__init__(opt) 91 | if self.rank == 0: 92 | run_dir = str(opt.log_dir / opt.name) 93 | os.makedirs(run_dir, exist_ok=True) 94 | self.writer=SummaryWriter(log_dir=run_dir, flush_secs=20) 95 | 96 | def add_scalar(self, global_step, key, val): 97 | if self.rank == 0: self.writer.add_scalar(key, val, global_step=global_step) 98 | 99 | def add_image(self, global_step, key, image): 100 | if self.rank == 0: 101 | image = image.mul(255).add_(0.5).clamp_(0, 255).to("cpu", torch.uint8) 102 | self.writer.add_image(key, image, global_step=global_step) 103 | 104 | def close(self): 105 | if self.rank == 0: self.writer.close() 106 | 107 | def build_log_writer(opt): 108 | if opt.log_writer == 'wandb': return WandBWriter(opt) 109 | elif opt.log_writer == 'tensorboard': return TensorBoardWriter(opt) 110 | else: return BaseWriter(opt) # do nothing 111 | 112 | def count_parameters(model): 113 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 114 | 115 | def space_indices(num_steps, count): 116 | assert count <= num_steps 117 | 118 | if count <= 1: 119 | frac_stride = 1 120 | else: 121 | frac_stride = (num_steps - 1) / (count - 1) 122 | 123 | cur_idx = 0.0 124 | taken_steps = [] 125 | for _ in range(count): 126 | taken_steps.append(round(cur_idx)) 127 | cur_idx += frac_stride 128 | 129 | return taken_steps 130 | 131 | def 
unsqueeze_xdim(z, xdim): 132 | bc_dim = (...,) + (None,) * len(xdim) 133 | return z[bc_dim] 134 | 135 | def merge(opt,x,v): 136 | dim=-1 if opt.exp=='toy' else -3 137 | return torch.cat([x,v], dim=-1) 138 | 139 | def flatten_dim01(x): 140 | # (dim0, dim1, *dim2) --> (dim0x1, *dim2) 141 | return x.reshape(-1, *x.shape[2:]) 142 | 143 | # convert to colored strings 144 | def red(content): return termcolor.colored(str(content),"red",attrs=["bold"]) 145 | def green(content): return termcolor.colored(str(content),"green",attrs=["bold"]) 146 | def blue(content): return termcolor.colored(str(content),"blue",attrs=["bold"]) 147 | def cyan(content): return termcolor.colored(str(content),"cyan",attrs=["bold"]) 148 | def yellow(content): return termcolor.colored(str(content),"yellow",attrs=["bold"]) 149 | def magenta(content): return termcolor.colored(str(content),"magenta",attrs=["bold"]) 150 | 151 | 152 | def count_parameters(model): 153 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 154 | 155 | def reshape_as(x,y): 156 | len_y = len(y.shape)-1 157 | return x.reshape(-1,*([1,]*len_y)) 158 | 159 | def all_cat_cpu(opt, log, t): 160 | if not opt.distributed: return t.detach().cpu() 161 | log_flag = log if opt.local_rank == 0 else None 162 | gathered_t = dist_util.all_gather(t.to(opt.device), log=log_flag) 163 | return torch.cat(gathered_t).detach().cpu() 164 | 165 | def cast_shape(x,dims): 166 | return x.reshape(-1,*([1,]*len(dims))) 167 | 168 | def uniform_ts(t0,T,n,device): 169 | _ts = uniform(n,t0,T,device) 170 | return _ts 171 | 172 | def debug_ts(t0,T,n,device): 173 | return torch.linspace(t0, T, n, device=device) 174 | 175 | def heuristic2_ts(t0,T,n,device): 176 | _ts = torch.randn(n,device=device)*0.1+T 177 | _ts = _ts.abs() 178 | invalid_item=torch.logical_or(_ts>T, _tsT, _tsT, _ts 25 | First Image 26 | Second Image 27 |

28 | 34 | 35 | --- 36 | 37 | ## Stroke-based Generative Modeling 38 | We can achieve stroke based Generative Modeling without further fine tuning and training given pretrained AGM model. 39 |
40 | 41 |
42 | 43 | --- 44 | ## Requirement 45 | For `pip` installation: 46 | ``` 47 | bash setup/setup.sh 48 | ``` 49 | For `Conda` installation: 50 | ``` 51 | conda env create -f setup/environments.yml 52 | conda activate agm 53 | bash setup/conda_install.sh 54 | ``` 55 | --- 56 | ## **Training** 57 | * **Download dataset**: You need to download the dataset and put the file under `/dataset/`. `CIFAR-10` is downloaded automatically. For `AFHQv2` and `Imagenet` we follow the same pipeline as EDM. The example command line code for downloading the AFHQv2 can be found in the comment line in `setup/setup.sh`. You may have to download the `Imagenet` dataset yourself by following the [EDM](https://github.com/NVlabs/edm) repo. 58 | 59 | * **Training**: Here we provide the command lines for reproducing the training used in our paper. You can add the argument `--log-writer wandb --wandb-user [Your-User-Name] --wandb-api-key [Your-Wandb-Key]` for monitoring the training process. 60 | 61 | **Toy** 62 | ``` 63 | python train.py --name train-toy/sde-spiral --exp toy --toy-exp spiral #SDE 64 | python train.py --name train-toy/ode-spiral --exp toy --toy-exp spiral --DE-type probODE --solver gDDIM #ODE 65 | ``` 66 | **Cifar10**: 67 | ``` 68 | python train.py --name cifar10-repo --exp cifar10 --n-gpu-per-node 8 69 | ``` 70 | **AFHQv2**: 71 | ``` 72 | python train.py --name AFHQv2-repo --exp AFHQv2 --n-gpu-per-node 8 73 | ``` 74 | **ImageNet64**: 75 | ``` 76 | python train.py --name imagenet-repo --exp imagenet64 --n-gpu-per-node 8 --num-itr 5000000 # Unconditional Generation 77 | ``` 78 | --- 79 | ## **Sampling** 80 | Before sampling, make sure you download the checkpoints and store them in the `results/Cifar10-ODE/`,`results/AFHQv2-ODE/` and `results/uncond-ImageNet64-ODE/` folders.
81 | 82 | Here we provide a short example command for generating 64 images on a single RTX 3090 for CIFAR10, AFHQv2 and Imagenet: 83 | ``` 84 | bash scripts/example.sh 85 | ``` 86 | The corresponding generation time on a single RTX 3090 is as follows: 87 | | |NFE | ETA time | 88 | |----------|----------|----------| 89 | | CIFAR-10 | 20 | ~6 sec | 90 | | AFHQv2 | 20 | ~10 sec | 91 | | ImageNet | 20 | ~15 sec | 92 | 93 | ### **Toy dataset Generation**: 94 | When you have trained the model, you can load it for fast sampling: 95 | ``` 96 | #SDE with 10 NFEs 97 | python train.py --name train-toy/sde-spiral-eval --exp toy --toy-exp spiral --eval --nfe 10 --ckpt train-toy/sde-spiral 98 | #ODE with 10 NFEs 99 | python train.py --name train-toy/ode-spiral-eval --exp toy --toy-exp spiral --eval --nfe 10 --DE-type probODE --solver gDDIM --ckpt train-toy/ode-spiral 100 | ``` 101 | 102 | ### **CIFAR-10 Generation and Evaluation** 103 | Use the following command lines to generate data. The generated images will be saved in `EVAL/cifar10-nfe[x]` 104 | ``` 105 | #NFE=5 FID=11.88 #4.5mins for sampling 50k images 106 | python sampling.py --n-gpu-per-node 1 --ckpt Cifar10-ODE/latest.pt --pred-x1 --solver gDDIM --T 0.4 --nfe 5 --fid-save-name cifar10-nfe5 --num-sample 50000 --batch-size 1000 107 | 108 | #NFE=10 FID=4.54 109 | python sampling.py --n-gpu-per-node 1 --ckpt Cifar10-ODE/latest.pt --pred-x1 --solver gDDIM --T 0.7 --nfe 10 --fid-save-name cifar10-nfe10 --num-sample 50000 --batch-size 1000 110 | 111 | #NFE=20 FID=2.58 112 | python sampling.py --n-gpu-per-node 1 --ckpt Cifar10-ODE/latest.pt --pred-x1 --solver gDDIM --T 0.9 --nfe 20 --fid-save-name cifar10-nfe20 --num-sample 50000 --batch-size 1000 113 | ``` 114 | Evaluate the FID using EDM evaluation: 115 | ``` 116 | # x can be in {5,10,20} 117 | torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe[x]
--ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 118 | ``` 119 | 120 | ### **AFHQv2 Generation and Evaluation** 121 | Using following command line to generate data. The generated images will be saved in `EVAL/AFHQv2-nfe[x]` 122 | ``` 123 | #Stroke-based generation 124 | python sampling.py --n-gpu-per-node 1 --ckpt AFHQv2-ODE/latest.pt --pred-x1 --solver gDDIM --save-img --img-save-name stroke-AFHQv2 --nfe 100 --T 0.999 --num-sample 64 --batch-size 64 --stroke-path dataset/StrokeData/testFig0.png --stroke-type dyn-v # [you can also replace --stroke-type by init-v] 125 | 126 | #Impainting generation 127 | python sampling.py --n-gpu-per-node 1 --ckpt AFHQv2-ODE/latest.pt --pred-x1 --solver gDDIM --save-img --img-save-name impainting-AFHQv2 --nfe 100 --T 0.999 --num-sample 64 --batch-size 64 --stroke-path dataset/StrokeData/testFig0_impainting.png --stroke-type dyn-v --impainting 128 | 129 | #NFE 20 FID=3.72 130 | python sampling.py --n-gpu-per-node 1 --ckpt AFHQv2-ODE/latest.pt --pred-x1 --solver gDDIM --fid-save-name AFHQv2-nfe20 --nfe 20 --T 0.9 --num-sample 50000 --batch-size 250 131 | ``` 132 | Evaluate the FID using EDM evaluation: 133 | ``` 134 | torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/AFHQv2-nfe20 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz 135 | ``` 136 | 137 | ### **Imagenet64 Generation and Evaluation** 138 | Using following command line to generate data. 
The generated images will be saved in `EVAL/ImageNet64-nfe[x]` 139 | ``` 140 | #NFE 20 FID=10.55 141 | python sampling.py --n-gpu-per-node 1 --ckpt uncond-ImageNet64-ODE/latest.pt --pred-x1 --solver gDDIM --fid-save-name ImageNet64-nfe20 --nfe 20 --T 0.99 --num-sample 50000 --batch-size 100 142 | #NFE 30 FID=10.07 143 | python sampling.py --n-gpu-per-node 1 --ckpt uncond-ImageNet64-ODE/latest.pt --pred-x1 --solver gDDIM --fid-save-name ImageNet64-nfe30 --nfe 30 --T 0.99 --num-sample 50000 --batch-size 100 144 | #NFE 40 FID=10.10 145 | python sampling.py --n-gpu-per-node 1 --ckpt uncond-ImageNet64-ODE/latest.pt --pred-x1 --solver gDDIM --fid-save-name ImageNet64-nfe40 --nfe 40 --T 0.9 --num-sample 50000 --batch-size 100 146 | 147 | ``` 148 | Evaluate the FID using EDM evaluation (x is in {20,30,40}): 149 | ``` 150 | torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/ImageNet64-nfe[x] --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/imagenet-64x64.npz 151 | ``` -------------------------------------------------------------------------------- /asset/AGMvsFM.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/asset/AGMvsFM.gif -------------------------------------------------------------------------------- /asset/cond_gen.001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/asset/cond_gen.001.png -------------------------------------------------------------------------------- /asset/cond_gen.jpg.001.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/asset/cond_gen.jpg.001.jpeg -------------------------------------------------------------------------------- /asset/cond_gen.pdf:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/asset/cond_gen.pdf -------------------------------------------------------------------------------- /asset/cond_gen2.001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/asset/cond_gen2.001.png -------------------------------------------------------------------------------- /asset/sampling_hop.png.001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/asset/sampling_hop.png.001.png -------------------------------------------------------------------------------- /configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/configs/.DS_Store -------------------------------------------------------------------------------- /configs/afhqv2_config.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
#
import ml_collections


def get_afhqv2_default_configs():
    """Build the default AFHQv2 experiment settings.

    Returns:
        A ``(config, model_configs)`` tuple: the experiment-level
        ``ConfigDict`` and the network ``ConfigDict`` produced by
        ``get_edm_NCSNpp_config``.
    """
    config = ml_collections.ConfigDict()
    # Empty sub-config, created before every other field.
    config.training = ml_collections.ConfigDict()

    settings = {
        # optimisation
        'seed': 42,
        'microbatch': 64,
        'n_gpu_per_node': 8,
        'lr': 1e-3,
        'precond': True,
        'reweight_type': 'reciprocal',
        't_samp': 'uniform',
        'num_itr': 600000,
        # tensor shapes: data and joint (6-channel) state
        'data_dim': [3, 64, 64],
        'joint_dim': [6, 64, 64],
        # data
        'xflip': True,
        'exp': 'AFHQv2',
        # dynamics
        't0': 1e-5,
        'T': 0.999,
        'dyn_type': 'TVgMPC',
        'clip_grad': 1,
        'damp_t': 1,
        'p': 3,
        'k': 0.2,
        'varx': 1,
        'varv': 1,
        'DE_type': 'probODE',
        # evaluation during training
        'nfe': 100,  # evaluation interval during training, can be replaced
        'solver': 'gDDIM',
        'diz_order': 2,
        'diz': 'rev-quad',
        'train_fid_sample': 4096,  # number of samples used for FID during training
    }
    for field, value in settings.items():
        setattr(config, field, value)

    return config, get_edm_NCSNpp_config()


def get_edm_NCSNpp_config():
    """Return the ``SongUNet`` network settings used for 64x64 AFHQv2 images."""
    config = ml_collections.ConfigDict()
    settings = {
        'image_size': 64,
        'name': "SongUNet",
        'embedding_type': "fourier",
        'encoder_type': "residual",
        'decoder_type': "standard",
        'resample_filter': [1, 3, 3, 1],
        'model_channels': 128,
        'channel_mult': [1, 2, 2, 2],
        'channel_mult_noise': 2,
        'dropout': 0.25,
        'label_dropout': 0,
        'channel_mult_emb': 4,
        'num_blocks': 4,
        'attn_resolutions': [16],
        'in_channels': 6,
        'out_channels': 3,
        'label_dim': 0,
        'augment_dim': 0,
    }
    for field, value in settings.items():
        setattr(config, field, value)
    return config
# --------------------------------------------------------------------------------
# /configs/cifar10_config.py:
# --------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import ml_collections


def get_cifar10_default_configs():
    """Build the default CIFAR-10 experiment settings.

    Returns:
        A ``(config, model_configs)`` tuple: the experiment-level
        ``ConfigDict`` and the network ``ConfigDict`` produced by
        ``get_edm_NCSNpp_config``.
    """
    config = ml_collections.ConfigDict()
    # Empty sub-config, created before every other field.
    config.training = ml_collections.ConfigDict()

    settings = {
        # optimisation
        'seed': 42,
        'microbatch': 64,
        'n_gpu_per_node': 8,
        'lr': 1e-3,
        'precond': True,
        'reweight_type': 'reciprocal',
        't_samp': 'uniform',
        'num_itr': 600000,
        # tensor shapes: data and joint (6-channel) state
        'data_dim': [3, 32, 32],
        'joint_dim': [6, 32, 32],
        # data
        'xflip': True,
        'exp': 'cifar10',
        # dynamics
        't0': 1e-5,
        'T': 0.999,
        'dyn_type': 'TVgMPC',
        'clip_grad': 1,
        'damp_t': 1,
        'p': 3,
        'k': 0.2,
        'varx': 1,
        'varv': 1,
        'DE_type': 'probODE',
        # evaluation during training
        'nfe': 100,  # evaluation interval during training, can be replaced
        'solver': 'gDDIM',
        'diz_order': 2,
        'diz': 'rev-quad',
        'train_fid_sample': 4096,  # number of samples used for FID during training
    }
    for field, value in settings.items():
        setattr(config, field, value)

    return config, get_edm_NCSNpp_config()


def get_Unet_config():
    """Return the alternative ``Unet`` network settings (not used by the default config)."""
    config = ml_collections.ConfigDict()
    settings = {
        'name': 'Unet',
        'attention_resolutions': '16,8',
        'in_channels': 6,
        'out_channel': 3,
        'num_head': 4,
        'num_res_blocks': 2,
        'num_channels': 128,
        'dropout': 0.1,
        'channel_mult': (1, 2, 2, 2),
        'image_size': 32,
    }
    for field, value in settings.items():
        setattr(config, field, value)
    return config


def get_NCSNpp_config():
    """Return the alternative ``ncsnpp`` network settings (not used by the default config)."""
    config = ml_collections.ConfigDict()
    settings = {
        'name': "ncsnpp",
        'normalization': "GroupNorm",
        'image_size': 32,
        'image_channels': 3,
        'nonlinearity': "swish",
        'n_channels': 128,
        'ch_mult': '1,2,2,2',
        'attn_resolutions': '16',
        'resamp_with_conv': True,
        'use_fir': True,
        'fir_kernel': '1,3,3,1',
        'skip_rescale': True,
        'resblock_type': "biggan",
        'progressive': 'none',
        'progressive_input': "residual",
        'progressive_combine': "sum",
        'attention_type': "ddpm",
        'init_scale': 0.0,
        'fourier_scale': 16,
        'conv_size': '3',
        'embedding_type': "fourier",
        'n_resblocks': 8,
        'dropout': 0.1,
    }
    for field, value in settings.items():
        setattr(config, field, value)
    return config


def get_edm_NCSNpp_config():
    """Return the ``SongUNet`` network settings used for 32x32 CIFAR-10 images."""
    config = ml_collections.ConfigDict()
    settings = {
        'image_size': 32,
        'name': "SongUNet",
        'embedding_type': "fourier",
        'encoder_type': "residual",
        'decoder_type': "standard",
        'resample_filter': [1, 3, 3, 1],
        'model_channels': 128,
        'channel_mult': [2, 2, 2],
        'channel_mult_noise': 2,
        'dropout': 0.13,
        'label_dropout': 0,
        'channel_mult_emb': 4,
        'num_blocks': 4,
        'attn_resolutions': [16],
        'in_channels': 6,
        'out_channels': 3,
        'augment_dim': 0,
    }
    for field, value in settings.items():
        setattr(config, field, value)
    return config
# --------------------------------------------------------------------------------
# /configs/imagenet64_config.py:
# --------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import ml_collections


def get_imagenet64_default_configs():
    """Build the default ImageNet-64 experiment settings.

    Returns:
        A ``(config, model_configs)`` tuple: the experiment-level
        ``ConfigDict`` and the network ``ConfigDict`` produced by
        ``get_edm_ADM_config``.
    """
    config = ml_collections.ConfigDict()
    # Empty sub-config, created before every other field.
    config.training = ml_collections.ConfigDict()

    settings = {
        # basic
        'seed': 42,
        'microbatch': 64,
        'n_gpu_per_node': 8,
        'lr': 2e-4,
        'precond': True,
        'reweight_type': 'reciprocal',
        't_samp': 'uniform',
        'num_itr': 10000000,
        'data_dim': [3, 64, 64],
        'joint_dim': [6, 64, 64],
        # data
        'xflip': True,
        'exp': 'imagenet64',
        # dynamics
        't0': 1e-5,
        'T': 0.999,
        'dyn_type': 'TVgMPC',
        'clip_grad': 1,
        'damp_t': 1,
        'p': 3,
        'k': 0.2,
        'varx': 1,
        'varv': 1,
        'DE_type': 'probODE',
        # evaluation during training
        'nfe': 100,  # evaluation interval during training, can be replaced
        'solver': 'gDDIM',
        'diz_order': 2,
        'diz': 'rev-quad',
        'train_fid_sample': 4096,  # number of samples used for FID during training
    }
    for field, value in settings.items():
        setattr(config, field, value)

    return config, get_edm_ADM_config()


def get_edm_ADM_config():
    """Return the ``ADM`` network settings used for 64x64 ImageNet images."""
    config = ml_collections.ConfigDict()
    settings = {
        'image_size': 64,
        'name': "ADM",
        'in_channels': 6,
        'out_channels': 3,
    }
    for field, value in settings.items():
        setattr(config, field, value)
    return config
# --------------------------------------------------------------------------------
# /configs/toy_config.py:
# --------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import ml_collections


def get_toy_default_configs():
    """Build the default settings for the 2-D toy experiments.

    Returns:
        A ``(config, model_configs)`` tuple; ``model_configs`` is always
        ``None`` for the toy setup.
    """
    config = ml_collections.ConfigDict()
    settings = {
        'seed': 42,
        'num_itr': 40001,
        't0': 1e-5,
        'debug': True,
        'microbatch': 2048,
        'nfe': 200,
        'DE_type': 'SDE',
        't_samp': 'uniform',
        'diz': 'Euler',
        'solver': 'sscs',
        'exp': 'toy',
        'lr': 1e-3,
        'dyn_type': 'TVgMPC',
        'T': 0.999,
        'p': 3,
        'k': 0.2,
        'varx': 1,
        'varv': 1,
        'data_dim': [2],
        'joint_dim': [4],
        'reweight_type': 'reciprocal',
    }
    for field, value in settings.items():
        setattr(config, field, value)

    return config, None
# --------------------------------------------------------------------------------
# /dataset/AFHQv2.py:
# --------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#

"""Streaming images and labels from datasets created with dataset_tool.py."""
from edm.dataset import ImageFolderDataset
from AM import util
try:
    import pyspng
except ImportError:
    pyspng = None


class AFHQv2_data():
    """Unconditional AFHQv2 (64x64) batch source.

    Wraps ``dataset/afhqv2-64x64.zip`` in the endless loader returned by
    ``util.setup_loader`` and hands out rescaled image batches.
    """

    def __init__(self, opt):
        """Create the loader.

        Args:
            opt: run options; reads ``microbatch``, ``n_gpu_per_node`` and
                ``cond`` here, and ``device`` later in ``sample``.

        Raises:
            AssertionError: if ``opt.cond`` is truthy (labels are never
                returned by this dataset).
        """
        self.opt = opt
        dataset = ImageFolderDataset(path='dataset/afhqv2-64x64.zip')
        self.loader = util.setup_loader(
            dataset,
            opt.microbatch,
            num_workers=opt.n_gpu_per_node,
        )
        assert not opt.cond

    def sample(self):
        """Return ``(images, None)`` for the next batch.

        Images are divided by 127.5 and shifted by -1 (presumably mapping
        uint8 pixel values to [-1, 1] — confirm against ImageFolderDataset),
        then moved to ``opt.device``.
        """
        batch = next(self.loader)
        images = batch[0] / 127.5 - 1
        return images.to(self.opt.device), None
# --------------------------------------------------------------------------------
# /dataset/StrokeData/testFig0.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/dataset/StrokeData/testFig0.png
# --------------------------------------------------------------------------------
# /dataset/StrokeData/testFig0_impainting.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/dataset/StrokeData/testFig0_impainting.png
# --------------------------------------------------------------------------------
# /dataset/StrokeData/testFig1.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/dataset/StrokeData/testFig1.png
# --------------------------------------------------------------------------------
# /dataset/StrokeData/testFig1_impainting.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/dataset/StrokeData/testFig1_impainting.png -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # -------------------------------------------------------------------------------- /dataset/cifar10.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import torch 6 | from torchvision import transforms 7 | import torchvision.datasets as datasets 8 | import sys 9 | sys.path.append("..") # Adds higher directory to python modules path. 
from AM import util


def tmp_fnc(t):
    """Map a tensor from [0, 1] to [-1, 1].

    Original note: known issue for transforms in the DDP setting. Kept as
    a named module-level function (presumably so it can be pickled for
    spawned loader workers — confirm before replacing it with a lambda).
    """
    return (t * 2) - 1


class cifar10_data():
    """CIFAR-10 training-set batch source built on ``util.setup_loader``."""

    def __init__(self, opt):
        """Build the dataset and its endless loader.

        Args:
            opt: run options; reads ``microbatch`` and ``xflip`` here, and
                ``device`` / ``cond`` later in ``sample``.
        """
        self.opt = opt
        bs = opt.microbatch

        x1 = self.generate_x1()
        self.loader = util.setup_loader(x1, bs)

    def generate_x1(self):
        """Return the torchvision CIFAR-10 training set with its transforms.

        A random horizontal flip is prepended when ``opt.xflip`` is set;
        images are then converted to tensors and rescaled to [-1, 1].
        The data is downloaded into ``./dataset`` if missing.
        """
        transforms_list = [transforms.RandomHorizontalFlip(p=0.5)] if self.opt.xflip else []
        transforms_list += [
            transforms.ToTensor(),       # convert to [0, 1]
            transforms.Lambda(tmp_fnc),  # convert to [-1, 1]
        ]
        return datasets.CIFAR10(
            './dataset',
            train=True,
            download=True,
            transform=transforms.Compose(transforms_list),
        )

    def sample(self):
        """Return ``(images, labels)`` for the next batch.

        Returns:
            Images on ``opt.device`` and, when ``opt.cond`` is truthy, the
            matching one-hot labels (10 classes) on the same device;
            otherwise the second element is ``None``.
        """
        # Draw ONE batch. Calling next() separately for images and labels
        # would take them from two different batches, so the labels would
        # not belong to the returned images and every other batch would
        # be thrown away.
        batch = next(self.loader)
        x1, label = batch[0], batch[1]
        x1 = x1.to(self.opt.device)
        if not self.opt.cond:
            return x1, None
        label = torch.nn.functional.one_hot(label, num_classes=10)
        return x1, label.to(self.opt.device)
# --------------------------------------------------------------------------------
# /dataset/imagenet64.py:
# --------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | # 5 | """Streaming images and labels from datasets created with dataset_tool.py.""" 6 | from edm.dataset import ImageFolderDataset 7 | from AM import util 8 | try: 9 | import pyspng 10 | except ImportError: 11 | pyspng = None 12 | 13 | class imagenet64_data(): 14 | """cifar10 dataset.""" 15 | def __init__(self, opt): 16 | self.opt = opt 17 | bs = opt.microbatch 18 | x1 = ImageFolderDataset(path='dataset/imagenet-64x64.zip',use_labels=opt.cond,xflip=opt.xflip) 19 | self.loader = util.setup_loader(x1,bs,num_workers=opt.n_gpu_per_node) 20 | 21 | def sample(self): 22 | x1,label = next(self.loader)[0],next(self.loader)[1] #[bs, dims],[bs,one_hot] 23 | x1 = x1/ 127.5 - 1 24 | return x1.to(self.opt.device),label.to(self.opt.device) if self.opt.cond else None 25 | #---------------------------------------------------------------------------- 26 | # Abstract base class for datasets. 27 | -------------------------------------------------------------------------------- /dataset/spiral.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import torch 6 | import numpy as np 7 | from torch.utils.data import Dataset 8 | from sklearn.datasets import make_swiss_roll 9 | 10 | class spiral_data(Dataset): 11 | """Toy Spiral dataset.""" 12 | def __init__(self, opt): 13 | n_train = opt.n_train 14 | self.opt = opt 15 | self.x1 = self.generate_x1(n_train) 16 | self.bs = opt.microbatch 17 | def generate_x1(self,n): 18 | ''' 19 | n: number of total samples 20 | ''' 21 | if self.opt.toy_exp=='gmm': 22 | WIDTH = 3 23 | BOUND = 0.5 24 | NOISE = 0.04 25 | ROTATION_MATRIX = np.array([[1., -1.], [1., 1.]]) / np.sqrt(2.) 
class Dataset(torch.utils.data.Dataset):
    """Abstract base class for image datasets with optional labels.

    Subclasses provide the raw data by overriding ``_load_raw_image`` and
    ``_load_raw_labels`` (and ``close`` if they hold resources). On top of
    that, this class maintains an index table (``_raw_idx``) implementing
    ``max_size`` subsetting and x-flip augmentation, optionally caches
    decoded images in memory, and converts integer class labels into
    one-hot vectors.
    """

    def __init__(self,
        name,                   # Name of the dataset.
        raw_shape,              # Shape of the raw image data (NCHW).
        max_size    = None,     # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
        use_labels  = False,    # Enable conditioning labels? False = label dimension is zero.
        xflip       = False,    # Artificially double the size of the dataset via x-flips. Applied after max_size.
        random_seed = 0,        # Random seed to use when applying max_size.
        cache       = False,    # Cache images in CPU memory?
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._cache = cache
        self._cached_images = dict() # {raw_idx: np.ndarray, ...}
        self._raw_labels = None      # Loaded lazily by _get_raw_labels().
        self._label_shape = None     # Computed lazily by the label_shape property.

        # Apply max_size: keep a seeded random subset, stored in sorted order.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Apply xflip: the second half of the index table repeats the first
        # half with the flip flag set.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            self._raw_idx = np.tile(self._raw_idx, 2)
            self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])

    def _get_raw_labels(self):
        """Return the per-raw-image label array, loading it on first use.

        When labels are disabled or unavailable, a ``[N, 0]`` float32 array
        is substituted so downstream code always sees an ndarray.
        """
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels() if self._use_labels else None
            if self._raw_labels is None:
                self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
            assert isinstance(self._raw_labels, np.ndarray)
            assert self._raw_labels.shape[0] == self._raw_shape[0]
            assert self._raw_labels.dtype in [np.float32, np.int64]
            if self._raw_labels.dtype == np.int64:
                # Integer labels are class indices: 1-D and non-negative.
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
        return self._raw_labels

    def close(self): # to be overridden by subclass
        pass

    def _load_raw_image(self, raw_idx): # to be overridden by subclass
        raise NotImplementedError

    def _load_raw_labels(self): # to be overridden by subclass
        raise NotImplementedError

    def __getstate__(self):
        # Drop the label array from the pickled state; it is reloaded lazily.
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        # Best-effort cleanup: ordinary errors from close() are ignored, but
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        try:
            self.close()
        except Exception:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        """Return ``(image, label)`` for dataset index ``idx``.

        The image is a uint8 array matching ``image_shape`` and is
        horizontally flipped when the index falls in the x-flip half.
        """
        raw_idx = self._raw_idx[idx]
        image = self._cached_images.get(raw_idx, None)
        if image is None:
            image = self._load_raw_image(raw_idx)
            if self._cache:
                self._cached_images[raw_idx] = image
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if self._xflip[idx]:
            assert image.ndim == 3 # CHW
            image = image[:, :, ::-1]
        # Copy so callers never alias the cached array (or a flipped view).
        return image.copy(), self.get_label(idx)

    def get_label(self, idx):
        """Return the label for dataset index ``idx``.

        Integer labels are expanded into a float32 one-hot vector; other
        labels are returned as a copy of the stored row.
        """
        label = self._get_raw_labels()[self._raw_idx[idx]]
        if label.dtype == np.int64:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        return label.copy()

    def get_details(self, idx):
        """Return an EasyDict with ``raw_idx``, ``xflip`` and ``raw_label`` for ``idx``."""
        d = dnnlib.EasyDict()
        d.raw_idx = int(self._raw_idx[idx])
        d.xflip = (int(self._xflip[idx]) != 0)
        d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
        return d

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        # Per-image shape, i.e. raw_shape without the leading count.
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 3 # CHW
        return self.image_shape[0]

    @property
    def resolution(self):
        assert len(self.image_shape) == 3 # CHW
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        # For integer labels the one-hot width is max label + 1; otherwise
        # it is the shape of one stored label row.
        if self._label_shape is None:
            raw_labels = self._get_raw_labels()
            if raw_labels.dtype == np.int64:
                self._label_shape = [int(np.max(raw_labels)) + 1]
            else:
                self._label_shape = raw_labels.shape[1:]
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return self._get_raw_labels().dtype == np.int64
**super_kwargs) 184 | 185 | @staticmethod 186 | def _file_ext(fname): 187 | return os.path.splitext(fname)[1].lower() 188 | 189 | def _get_zipfile(self): 190 | assert self._type == 'zip' 191 | if self._zipfile is None: 192 | self._zipfile = zipfile.ZipFile(self._path) 193 | return self._zipfile 194 | 195 | def _open_file(self, fname): 196 | if self._type == 'dir': 197 | return open(os.path.join(self._path, fname), 'rb') 198 | if self._type == 'zip': 199 | return self._get_zipfile().open(fname, 'r') 200 | return None 201 | 202 | def close(self): 203 | try: 204 | if self._zipfile is not None: 205 | self._zipfile.close() 206 | finally: 207 | self._zipfile = None 208 | 209 | def __getstate__(self): 210 | return dict(super().__getstate__(), _zipfile=None) 211 | 212 | def _load_raw_image(self, raw_idx): 213 | fname = self._image_fnames[raw_idx] 214 | with self._open_file(fname) as f: 215 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png': 216 | image = pyspng.load(f.read()) 217 | else: 218 | image = np.array(PIL.Image.open(f)) 219 | if image.ndim == 2: 220 | image = image[:, :, np.newaxis] # HW => HWC 221 | image = image.transpose(2, 0, 1) # HWC => CHW 222 | return image 223 | 224 | def _load_raw_labels(self): 225 | fname = 'dataset.json' 226 | if fname not in self._all_fnames: 227 | return None 228 | with self._open_file(fname) as f: 229 | labels = json.load(f)['labels'] 230 | if labels is None: 231 | return None 232 | labels = dict(labels) 233 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 234 | labels = np.array(labels) 235 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 236 | return labels 237 | 238 | #---------------------------------------------------------------------------- 239 | -------------------------------------------------------------------------------- /edm/dataset_tool.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
#----------------------------------------------------------------------------
# Parse a 'M,N' or 'MxN' integer tuple.
# Example: '4x2' returns (4,2)

def parse_tuple(s: str) -> Tuple[int, int]:
    """Parse ``'MxN'`` or ``'M,N'`` into ``(M, N)``.

    Raises click.ClickException when ``s`` is not two non-negative integers
    separated by ``x`` or ``,``.
    """
    matched = re.match(r'^(\d+)[x,](\d+)$', s)
    if matched is None:
        raise click.ClickException(f'cannot parse tuple {s}')
    first, second = matched.group(1), matched.group(2)
    return int(first), int(second)

#----------------------------------------------------------------------------

def maybe_min(a: int, b: Optional[int]) -> int:
    """Return ``min(a, b)``, treating ``b=None`` as "no limit"."""
    return a if b is None else min(a, b)

#----------------------------------------------------------------------------

def file_ext(name: Union[str, Path]) -> str:
    """Return the text after the last ``.`` in ``name``, without the dot.

    A name containing no dot is returned unchanged.
    """
    text = str(name)
    return text.rsplit('.', 1)[-1]

#----------------------------------------------------------------------------

def is_image_ext(fname: Union[str, Path]) -> bool:
    """Tell whether ``fname`` has an extension registered in PIL.Image.EXTENSION."""
    ext = file_ext(fname).lower()
    known_extensions = PIL.Image.EXTENSION
    return f'.{ext}' in known_extensions
os.path.isfile(f)] 59 | arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images} 60 | max_idx = maybe_min(len(input_images), max_images) 61 | 62 | # Load labels. 63 | labels = dict() 64 | meta_fname = os.path.join(source_dir, 'dataset.json') 65 | if os.path.isfile(meta_fname): 66 | with open(meta_fname, 'r') as file: 67 | data = json.load(file)['labels'] 68 | if data is not None: 69 | labels = {x[0]: x[1] for x in data} 70 | 71 | # No labels available => determine from top-level directory names. 72 | if len(labels) == 0: 73 | toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()} 74 | toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))} 75 | if len(toplevel_indices) > 1: 76 | labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()} 77 | 78 | def iterate_images(): 79 | for idx, fname in enumerate(input_images): 80 | img = np.array(PIL.Image.open(fname)) 81 | yield dict(img=img, label=labels.get(arch_fnames.get(fname))) 82 | if idx >= max_idx - 1: 83 | break 84 | return max_idx, iterate_images() 85 | 86 | #---------------------------------------------------------------------------- 87 | 88 | def open_image_zip(source, *, max_images: Optional[int]): 89 | with zipfile.ZipFile(source, mode='r') as z: 90 | input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)] 91 | max_idx = maybe_min(len(input_images), max_images) 92 | 93 | # Load labels. 
def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
    """Open an LMDB image database for reading.

    Returns ``(max_idx, iterator)`` where ``max_idx`` is the number of
    entries capped by ``max_images`` and the iterator yields
    ``dict(img=<RGB array>, label=None)`` per entry. Entries that fail to
    decode are reported on stdout and skipped.
    """
    import cv2  # pyright: ignore [reportMissingImports] # pip install opencv-python
    import lmdb  # pyright: ignore [reportMissingImports] # pip install lmdb

    # First pass only reads the entry count.
    with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
        max_idx = maybe_min(txn.stat()['entries'], max_images)

    def iterate_images():
        with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
            for idx, (_key, value) in enumerate(txn.cursor()):
                try:
                    try:
                        img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
                        if img is None:
                            raise IOError('cv2.imdecode failed')
                        img = img[:, :, ::-1] # BGR => RGB
                    except IOError:
                        # Fall back to PIL when OpenCV cannot decode the entry.
                        img = np.array(PIL.Image.open(io.BytesIO(value)))
                    yield dict(img=img, label=None)
                    if idx >= max_idx - 1:
                        break
                except Exception as err:
                    # Only ordinary errors are reported and skipped. A bare
                    # `except:` here would also catch the GeneratorExit that
                    # is raised at the `yield` when the consumer closes the
                    # generator early, as well as KeyboardInterrupt.
                    print(err)

    return max_idx, iterate_images()
tar.getmember(f'cifar-10-batches-py/data_batch_{batch}') 148 | with tar.extractfile(member) as file: 149 | data = pickle.load(file, encoding='latin1') 150 | images.append(data['data'].reshape(-1, 3, 32, 32)) 151 | labels.append(data['labels']) 152 | 153 | images = np.concatenate(images) 154 | labels = np.concatenate(labels) 155 | images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC 156 | assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8 157 | assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64] 158 | assert np.min(images) == 0 and np.max(images) == 255 159 | assert np.min(labels) == 0 and np.max(labels) == 9 160 | 161 | max_idx = maybe_min(len(images), max_images) 162 | 163 | def iterate_images(): 164 | for idx, img in enumerate(images): 165 | yield dict(img=img, label=int(labels[idx])) 166 | if idx >= max_idx - 1: 167 | break 168 | 169 | return max_idx, iterate_images() 170 | 171 | #---------------------------------------------------------------------------- 172 | 173 | def open_mnist(images_gz: str, *, max_images: Optional[int]): 174 | labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz') 175 | assert labels_gz != images_gz 176 | images = [] 177 | labels = [] 178 | 179 | with gzip.open(images_gz, 'rb') as f: 180 | images = np.frombuffer(f.read(), np.uint8, offset=16) 181 | with gzip.open(labels_gz, 'rb') as f: 182 | labels = np.frombuffer(f.read(), np.uint8, offset=8) 183 | 184 | images = images.reshape(-1, 28, 28) 185 | images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0) 186 | assert images.shape == (60000, 32, 32) and images.dtype == np.uint8 187 | assert labels.shape == (60000,) and labels.dtype == np.uint8 188 | assert np.min(images) == 0 and np.max(images) == 255 189 | assert np.min(labels) == 0 and np.max(labels) == 9 190 | 191 | max_idx = maybe_min(len(images), max_images) 192 | 193 | def iterate_images(): 194 | for idx, img in enumerate(images): 195 | yield 
def make_transform(
    transform: Optional[str],
    output_width: Optional[int],
    output_height: Optional[int]
) -> Callable[[np.ndarray], Optional[np.ndarray]]:
    """Build the per-image crop/resize function selected by ``transform``.

    ``None`` selects a plain rescale, ``'center-crop'`` a square center crop
    followed by a resize, and ``'center-crop-wide'`` a wide crop that is
    resized and pasted into a square canvas. The two crop modes require both
    output dimensions.

    The returned callable takes an image array and returns the transformed
    array; the wide-crop variant returns ``None`` for images that are too
    small, which the caller is expected to skip.

    Raises click.ClickException when a crop mode is requested without a
    resolution, and AssertionError for an unknown ``transform``.
    """
    def scale(width, height, img):
        w = img.shape[1]
        h = img.shape[0]
        if width == w and height == h:
            # Already at the requested size: hand back the input untouched.
            return img
        img = PIL.Image.fromarray(img)
        # A missing target dimension keeps the image's own size on that axis.
        ww = width if width is not None else w
        hh = height if height is not None else h
        img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop(width, height, img):
        # Largest centered square, then resize.
        crop = np.min(img.shape[:2])
        img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
        if img.ndim == 2:
            img = img[:, :, np.newaxis].repeat(3, axis=2) # grayscale => 3 channels
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop_wide(width, height, img):
        # Height the image would have once scaled to the target width.
        ch = int(np.round(width * img.shape[0] / img.shape[1]))
        if img.shape[1] < width or ch < height:
            return None

        img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
        if img.ndim == 2:
            img = img[:, :, np.newaxis].repeat(3, axis=2) # grayscale => 3 channels
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        img = np.array(img)

        # Paste the wide strip into the vertical middle of a square canvas.
        canvas = np.zeros([width, width, 3], dtype=np.uint8)
        canvas[(width - height) // 2 : (width + height) // 2, :] = img
        return canvas

    if transform is None:
        return functools.partial(scale, output_width, output_height)
    if transform == 'center-crop':
        if output_width is None or output_height is None:
            # The original message lacked the space before "transform".
            raise click.ClickException(f'must specify --resolution=WxH when using {transform} transform')
        return functools.partial(center_crop, output_width, output_height)
    if transform == 'center-crop-wide':
        if output_width is None or output_height is None:
            raise click.ClickException(f'must specify --resolution=WxH when using {transform} transform')
        return functools.partial(center_crop_wide, output_width, output_height)
    # Raised explicitly so the check survives `python -O`.
    raise AssertionError('unknown transform')

#----------------------------------------------------------------------------

def open_dataset(source, *, max_images: Optional[int]):
    """Pick the reader matching ``source`` and return its ``(count, iterator)``.

    Directories ending in ``_lmdb`` go to the LMDB reader and other
    directories to the image-folder reader. Files are dispatched on their
    basename (CIFAR-10 tarball, MNIST images) or on a ``zip`` extension.

    Raises click.ClickException when ``source`` does not exist, and
    AssertionError for a file of an unrecognised type.
    """
    if os.path.isdir(source):
        if source.rstrip('/').endswith('_lmdb'):
            return open_lmdb(source, max_images=max_images)
        else:
            return open_image_folder(source, max_images=max_images)
    elif os.path.isfile(source):
        if os.path.basename(source) == 'cifar-10-python.tar.gz':
            return open_cifar10(source, max_images=max_images)
        elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
            return open_mnist(source, max_images=max_images)
        elif file_ext(source) == 'zip':
            return open_image_zip(source, max_images=max_images)
        else:
            # Raised explicitly so the check survives `python -O`.
            raise AssertionError('unknown archive type')
    else:
        raise click.ClickException(f'Missing input file or directory: {source}')
# Command-line entry point: click parses the options below and passes them to
# main() as keyword arguments. The function docstring doubles as the --help
# text, so its wording is left as-is.
@click.command()
@click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
@click.option('--transform', help='Input crop/resize mode', metavar='MODE', type=click.Choice(['center-crop', 'center-crop-wide']))
@click.option('--resolution', help='Output resolution (e.g., 512x512)', metavar='WxH', type=parse_tuple)

def main(
    source: str,
    dest: str,
    max_images: Optional[int],
    transform: Optional[str],
    resolution: Optional[Tuple[int, int]]
):
    """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.

    The input dataset format is guessed from the --source argument:

    \b
    --source *_lmdb/ Load LSUN dataset
    --source cifar-10-python.tar.gz Load CIFAR-10 dataset
    --source train-images-idx3-ubyte.gz Load MNIST dataset
    --source path/ Recursively load all images from path/
    --source dataset.zip Recursively load all images from dataset.zip

    Specifying the output format and path:

    \b
    --dest /path/to/dir Save output files under /path/to/dir
    --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip

    The output dataset format can be either an image folder or an uncompressed zip archive.
    Zip archives makes it easier to move datasets around file servers and clusters, and may
    offer better training performance on network file systems.

    Images within the dataset archive will be stored as uncompressed PNG.
    Uncompresed PNGs can be efficiently decoded in the training loop.

    Class labels are stored in a file called 'dataset.json' that is stored at the
    dataset root folder. This file has the following structure:

    \b
    {
        "labels": [
            ["00000/img00000000.png",6],
            ["00000/img00000001.png",9],
            ... repeated for every image in the datase
            ["00049/img00049999.png",1]
        ]
    }

    If the 'dataset.json' file cannot be found, class labels are determined from
    top-level directory names.

    Image scale/crop and resolution requirements:

    Output images must be square-shaped and they must all have the same power-of-two
    dimensions.

    To scale arbitrary input image size to a specific width and height, use the
    --resolution option. Output resolution will be either the original
    input resolution (if resolution was not specified) or the one specified with
    --resolution option.

    Use the --transform=center-crop or --transform=center-crop-wide options to apply a
    center crop transform on the input image. These options should be used with the
    --resolution option. For example:

    \b
    python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
        --transform=center-crop-wide --resolution=512x384
    """

    PIL.Image.init()

    if dest == '':
        raise click.ClickException('--dest output filename or directory must not be an empty string')

    # Open the input reader and the output writer; save_bytes(fname, data)
    # writes one file and close_dest() finalizes the destination.
    num_files, input_iter = open_dataset(source, max_images=max_images)
    archive_root_dir, save_bytes, close_dest = open_dest(dest)

    # No --resolution means "keep each axis as it is".
    if resolution is None: resolution = (None, None)
    transform_image = make_transform(transform, *resolution)

    # Attributes of the first written image; every later image must match.
    dataset_attrs = None

    labels = []
    for idx, image in tqdm(enumerate(input_iter), total=num_files):
        # Output name is derived from the input position, e.g.
        # index 12 => '00000/img00000012.png'.
        idx_str = f'{idx:08d}'
        archive_fname = f'{idx_str[:5]}/img{idx_str}.png'

        # Apply crop and resize.
        img = transform_image(image['img'])
        if img is None:
            # The transform rejected this image; skip it (its index is not reused).
            continue

        # Error check to require uniform image attributes across
        # the whole dataset.
        channels = img.shape[2] if img.ndim == 3 else 1
        cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0], 'channels': channels}
        if dataset_attrs is None:
            # First image: validate square, RGB/grayscale, power-of-two size.
            dataset_attrs = cur_image_attrs
            width = dataset_attrs['width']
            height = dataset_attrs['height']
            if width != height:
                raise click.ClickException(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
            if dataset_attrs['channels'] not in [1, 3]:
                raise click.ClickException('Input images must be stored as RGB or grayscale')
            if width != 2 ** int(np.floor(np.log2(width))):
                raise click.ClickException('Image width/height after scale and crop are required to be power-of-two')
        elif dataset_attrs != cur_image_attrs:
            err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
            raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))

        # Save the image as an uncompressed PNG.
        img = PIL.Image.fromarray(img, {1: 'L', 3: 'RGB'}[channels])
        image_bits = io.BytesIO()
        img.save(image_bits, format='png', compress_level=0, optimize=False)
        save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
        # A None entry marks an image without a label.
        labels.append([archive_fname, image['label']] if image['label'] is not None else None)

    # Labels are written only when every saved image has one; otherwise null.
    metadata = {'labels': labels if all(x is not None for x in labels) else None}
    save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
    close_dest()

#----------------------------------------------------------------------------

if __name__ == "__main__":
    main()
def cleanup():
    """Destroy the default process group."""
    dist.destroy_process_group()


def average_grads(params):
    """Average the gradients of trainable tensors across all ranks, in place.

    Each gradient is sum-reduced over the process group and then divided by
    the world size. Tensors with ``requires_grad`` false are left alone.
    """
    world_size = float(dist.get_world_size())
    for tensor in params:
        if not tensor.requires_grad:
            continue
        with torch.no_grad():
            dist.all_reduce(tensor.grad.data, op=dist.ReduceOp.SUM)
            tensor.grad.data /= world_size


def average_params(params):
    """Average the tensor values themselves across all ranks, in place."""
    world_size = float(dist.get_world_size())
    for tensor in params:
        with torch.no_grad():
            dist.all_reduce(tensor.data, op=dist.ReduceOp.SUM)
            tensor.data /= world_size


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    with torch.no_grad():
        for tensor in params:
            dist.broadcast(tensor, 0)


def all_gather(tensor, log=None):
    """Collect ``tensor`` from every rank and return the list of copies.

    The list has one entry per rank. When ``log`` is truthy, its ``info``
    method is called with a progress message first.
    """
    if log:
        log.info("Gathering tensor across {} devices... ".format(dist.get_world_size()))
    n_ranks = dist.get_world_size()
    received = []
    for _ in range(n_ranks):
        received.append(torch.zeros_like(tensor))
    with torch.no_grad():
        dist.all_gather(received, tensor)
    return received
class EasyDict(dict):
    """A dict whose items can also be read, written and deleted as attributes."""

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails; a missing key is
        # reported as AttributeError so getattr()/hasattr() behave as usual.
        try:
            found = self[name]
        except KeyError:
            raise AttributeError(name)
        return found

    def __setattr__(self, name: str, value: Any) -> None:
        # Attribute assignment stores an item rather than an instance attribute.
        self.__setitem__(name, value)

    def __delattr__(self, name: str) -> None:
        # Attribute deletion removes the item (KeyError if it is absent).
        self.__delitem__(name)
file_mode: str = "w", should_flush: bool = True): 59 | self.file = None 60 | 61 | if file_name is not None: 62 | self.file = open(file_name, file_mode) 63 | 64 | self.should_flush = should_flush 65 | self.stdout = sys.stdout 66 | self.stderr = sys.stderr 67 | 68 | sys.stdout = self 69 | sys.stderr = self 70 | 71 | def __enter__(self) -> "Logger": 72 | return self 73 | 74 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 75 | self.close() 76 | 77 | def write(self, text: Union[str, bytes]) -> None: 78 | """Write text to stdout (and a file) and optionally flush.""" 79 | if isinstance(text, bytes): 80 | text = text.decode() 81 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 82 | return 83 | 84 | if self.file is not None: 85 | self.file.write(text) 86 | 87 | self.stdout.write(text) 88 | 89 | if self.should_flush: 90 | self.flush() 91 | 92 | def flush(self) -> None: 93 | """Flush written text to both stdout and a file, if open.""" 94 | if self.file is not None: 95 | self.file.flush() 96 | 97 | self.stdout.flush() 98 | 99 | def close(self) -> None: 100 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 101 | self.flush() 102 | 103 | # if using multiple loggers, prevent closing in wrong order 104 | if sys.stdout is self: 105 | sys.stdout = self.stdout 106 | if sys.stderr is self: 107 | sys.stderr = self.stderr 108 | 109 | if self.file is not None: 110 | self.file.close() 111 | self.file = None 112 | 113 | 114 | # Cache directories 115 | # ------------------------------------------------------------------------------------------ 116 | 117 | _dnnlib_cache_dir = None 118 | 119 | def set_cache_dir(path: str) -> None: 120 | global _dnnlib_cache_dir 121 | _dnnlib_cache_dir = path 122 | 123 | def make_cache_dir_path(*paths: str) -> str: 124 | if _dnnlib_cache_dir is not None: 125 | return os.path.join(_dnnlib_cache_dir, *paths) 126 | if 'DNNLIB_CACHE_DIR' in 
def format_time(seconds: Union[int, float]) -> str:
    """Render a duration in seconds as a short human-readable string.

    The value is rounded to the nearest whole second first. The largest unit shown depends on the
    magnitude: ``"42s"``, ``"3m 07s"``, ``"2h 05m 09s"`` or ``"1d 03h 12m"`` (seconds are dropped
    once the duration reaches a full day).
    """
    total = int(np.rint(seconds))
    if total < 60:
        return "{0}s".format(total)

    whole_minutes, secs = divmod(total, 60)
    if total < 60 * 60:
        return "{0}m {1:02}s".format(whole_minutes, secs)

    whole_hours, mins = divmod(whole_minutes, 60)
    if total < 24 * 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(whole_hours, mins, secs)

    days, hours = divmod(whole_hours, 24)
    return "{0}d {1:02}h {2:02}m".format(days, hours, mins)
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Find the module that provides the python object called ``obj_name``.

    The dotted name is split at every possible position, longest module prefix first. The first
    split whose module part can be imported and whose remainder resolves through attribute
    access is returned.

    Args:
        obj_name: Dotted name such as ``"os.path.join"``. A leading ``np.`` or ``tf.`` is
            expanded to ``numpy.`` / ``tensorflow.`` before the lookup.

    Returns:
        ``(module, local_obj_name)``, where ``local_obj_name`` is the name with the module part
        removed (the empty string when ``obj_name`` names a module).

    Raises:
        ImportError: if no split can be resolved, or if importing one of the candidate modules
            fails for a reason other than the module not existing.
        AttributeError: if a candidate module imports but the remaining attribute path is missing.
    """
    # Allow convenience shorthands, substitute them by full names. The dot is escaped so that
    # only the exact prefixes "np." / "tf." are rewritten; an unescaped "." would also match
    # unrelated names such as "npz_mod.x" or "tfoo.bar".
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # List alternatives for (module_name, local_obj_name).
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # Try each alternative in turn.
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except Exception:
            # Deliberate best effort: failures are diagnosed by the loops below. Catching
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
            pass

    # Maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError as err:
            # "No module named X" just means this split was wrong; anything else is a real
            # import failure inside the module and must be surfaced.
            if not str(err).startswith("No module named '" + module_name + "'"):
                raise

    # Maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # We are out of luck, but we have no idea why.
    raise ImportError(obj_name)


def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object."""
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """Walk ``dir_path`` and collect every file that no ignore pattern matches.

    Args:
        dir_path: Existing directory to walk.
        ignores: ``fnmatch`` patterns. A directory whose name matches is not descended into;
            a file whose name matches is left out. ``None`` means no patterns.
        add_base_to_relative: When true, the relative paths are prefixed with the final
            component of ``dir_path``.

    Returns:
        List of ``(absolute_path, relative_path)`` tuples, with ``relative_path`` taken
        relative to ``dir_path``.
    """
    assert os.path.isdir(dir_path)
    patterns = [] if ignores is None else ignores
    base_name = os.path.basename(os.path.normpath(dir_path))

    def _ignored(entry: str) -> bool:
        return any(fnmatch.fnmatch(entry, pattern) for pattern in patterns)

    listing = []
    for root, dirs, files in os.walk(dir_path, topdown=True):
        # Prune in place: with topdown=True, os.walk only descends into what is left in dirs.
        dirs[:] = [d for d in dirs if not _ignored(d)]

        for fname in files:
            if _ignored(fname):
                continue
            abs_path = os.path.join(root, fname)
            rel_path = os.path.relpath(abs_path, dir_path)
            if add_base_to_relative:
                rel_path = os.path.join(base_name, rel_path)
            listing.append((abs_path, rel_path))

    return listing
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string.

    Args:
        obj: Candidate value; anything that is not a ``str`` containing ``"://"`` is rejected.
        allow_file_urls: When true, any string starting with ``file://`` is accepted as is.

    Returns:
        True if ``obj`` parses with a scheme and a network location that contains a dot, both
        for the URL itself and for its root (``urljoin(obj, "/")``); False otherwise.
    """
    # urllib.parse provides the same urlparse/urljoin that requests.compat re-exports.
    from urllib.parse import urljoin, urlparse

    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        res = urlparse(urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except ValueError:
        # urlparse/urljoin signal malformed input (e.g. an unbalanced IPv6 bracket) with
        # ValueError. Catching only that keeps unrelated errors visible.
        return False
    return True
414 | # 415 | # Some internet resources suggest using urllib.request.url2pathname() but 416 | # but that converts forward slashes to backslashes and this causes 417 | # its own set of problems. 418 | if url.startswith('file://'): 419 | filename = urllib.parse.urlparse(url).path 420 | if re.match(r'^/[a-zA-Z]:', filename): 421 | filename = filename[1:] 422 | return filename if return_filename else open(filename, "rb") 423 | 424 | assert is_url(url) 425 | 426 | # Lookup from cache. 427 | if cache_dir is None: 428 | cache_dir = make_cache_dir_path('downloads') 429 | 430 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 431 | if cache: 432 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 433 | if len(cache_files) == 1: 434 | filename = cache_files[0] 435 | return filename if return_filename else open(filename, "rb") 436 | 437 | # Download. 438 | url_name = None 439 | url_data = None 440 | with requests.Session() as session: 441 | if verbose: 442 | print("Downloading %s ..." 
% url, end="", flush=True) 443 | for attempts_left in reversed(range(num_attempts)): 444 | try: 445 | with session.get(url) as res: 446 | res.raise_for_status() 447 | if len(res.content) == 0: 448 | raise IOError("No data received") 449 | 450 | if len(res.content) < 8192: 451 | content_str = res.content.decode("utf-8") 452 | if "download_warning" in res.headers.get("Set-Cookie", ""): 453 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 454 | if len(links) == 1: 455 | url = requests.compat.urljoin(url, links[0]) 456 | raise IOError("Google Drive virus checker nag") 457 | if "Google Drive - Quota exceeded" in content_str: 458 | raise IOError("Google Drive download quota exceeded -- please try again later") 459 | 460 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 461 | url_name = match[1] if match else url 462 | url_data = res.content 463 | if verbose: 464 | print(" done") 465 | break 466 | except KeyboardInterrupt: 467 | raise 468 | except: 469 | if not attempts_left: 470 | if verbose: 471 | print(" failed") 472 | raise 473 | if verbose: 474 | print(".", end="", flush=True) 475 | 476 | # Save to cache. 477 | if cache: 478 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 479 | safe_name = safe_name[:min(len(safe_name), 128)] 480 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 481 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 482 | os.makedirs(cache_dir, exist_ok=True) 483 | with open(temp_file, "wb") as f: 484 | f.write(url_data) 485 | os.replace(temp_file, cache_file) # atomic 486 | if return_filename: 487 | return cache_file 488 | 489 | # Return data as file object. 
490 | assert not return_filename 491 | return io.BytesIO(url_data) 492 | -------------------------------------------------------------------------------- /edm/fid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Script for calculating Frechet Inception Distance (FID).""" 9 | import os 10 | import click 11 | import tqdm 12 | import pickle 13 | import numpy as np 14 | import scipy.linalg 15 | import torch 16 | import sys 17 | import os 18 | sys.path.append(os.path.abspath('edm')) 19 | import dnnlib 20 | from torch_utils import distributed as dist 21 | import zipfile 22 | import PIL.Image 23 | try: 24 | import pyspng 25 | except ImportError: 26 | pyspng = None 27 | 28 | #---------------------------------------------------------------------------- 29 | # Abstract base class for datasets. 30 | 31 | class Dataset(torch.utils.data.Dataset): 32 | def __init__(self, 33 | name, # Name of the dataset. 34 | raw_shape, # Shape of the raw image data (NCHW). 35 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 36 | use_labels = False, # Enable conditioning labels? False = label dimension is zero. 37 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. 38 | random_seed = 0, # Random seed to use when applying max_size. 39 | cache = False, # Cache images in CPU memory? 
40 | ): 41 | self._name = name 42 | self._raw_shape = list(raw_shape) 43 | self._use_labels = use_labels 44 | self._cache = cache 45 | self._cached_images = dict() # {raw_idx: np.ndarray, ...} 46 | self._raw_labels = None 47 | self._label_shape = None 48 | 49 | # Apply max_size. 50 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) 51 | if (max_size is not None) and (self._raw_idx.size > max_size): 52 | np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx) 53 | self._raw_idx = np.sort(self._raw_idx[:max_size]) 54 | 55 | # Apply xflip. 56 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 57 | if xflip: 58 | self._raw_idx = np.tile(self._raw_idx, 2) 59 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 60 | 61 | def _get_raw_labels(self): 62 | if self._raw_labels is None: 63 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 64 | if self._raw_labels is None: 65 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 66 | assert isinstance(self._raw_labels, np.ndarray) 67 | assert self._raw_labels.shape[0] == self._raw_shape[0] 68 | assert self._raw_labels.dtype in [np.float32, np.int64] 69 | if self._raw_labels.dtype == np.int64: 70 | assert self._raw_labels.ndim == 1 71 | assert np.all(self._raw_labels >= 0) 72 | return self._raw_labels 73 | 74 | def close(self): # to be overridden by subclass 75 | pass 76 | 77 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 78 | raise NotImplementedError 79 | 80 | def _load_raw_labels(self): # to be overridden by subclass 81 | raise NotImplementedError 82 | 83 | def __getstate__(self): 84 | return dict(self.__dict__, _raw_labels=None) 85 | 86 | def __del__(self): 87 | try: 88 | self.close() 89 | except: 90 | pass 91 | 92 | def __len__(self): 93 | return self._raw_idx.size 94 | 95 | def __getitem__(self, idx): 96 | raw_idx = self._raw_idx[idx] 97 | image = self._cached_images.get(raw_idx, None) 98 | if image is 
None: 99 | image = self._load_raw_image(raw_idx) 100 | if self._cache: 101 | self._cached_images[raw_idx] = image 102 | assert isinstance(image, np.ndarray) 103 | assert list(image.shape) == self.image_shape 104 | assert image.dtype == np.uint8 105 | if self._xflip[idx]: 106 | assert image.ndim == 3 # CHW 107 | image = image[:, :, ::-1] 108 | return image.copy(), self.get_label(idx) 109 | 110 | def get_label(self, idx): 111 | label = self._get_raw_labels()[self._raw_idx[idx]] 112 | if label.dtype == np.int64: 113 | onehot = np.zeros(self.label_shape, dtype=np.float32) 114 | onehot[label] = 1 115 | label = onehot 116 | return label.copy() 117 | 118 | def get_details(self, idx): 119 | d = dnnlib.EasyDict() 120 | d.raw_idx = int(self._raw_idx[idx]) 121 | d.xflip = (int(self._xflip[idx]) != 0) 122 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 123 | return d 124 | 125 | @property 126 | def name(self): 127 | return self._name 128 | 129 | @property 130 | def image_shape(self): 131 | return list(self._raw_shape[1:]) 132 | 133 | @property 134 | def num_channels(self): 135 | assert len(self.image_shape) == 3 # CHW 136 | return self.image_shape[0] 137 | 138 | @property 139 | def resolution(self): 140 | assert len(self.image_shape) == 3 # CHW 141 | assert self.image_shape[1] == self.image_shape[2] 142 | return self.image_shape[1] 143 | 144 | @property 145 | def label_shape(self): 146 | if self._label_shape is None: 147 | raw_labels = self._get_raw_labels() 148 | if raw_labels.dtype == np.int64: 149 | self._label_shape = [int(np.max(raw_labels)) + 1] 150 | else: 151 | self._label_shape = raw_labels.shape[1:] 152 | return list(self._label_shape) 153 | 154 | @property 155 | def label_dim(self): 156 | assert len(self.label_shape) == 1 157 | return self.label_shape[0] 158 | 159 | @property 160 | def has_labels(self): 161 | return any(x != 0 for x in self.label_shape) 162 | 163 | @property 164 | def has_onehot_labels(self): 165 | return self._get_raw_labels().dtype == 
class ImageFolderDataset(Dataset):
    """Dataset that reads images, and optional labels, from a directory tree or a ZIP archive.

    Files whose extension is registered in ``PIL.Image.EXTENSION`` are treated as images and
    served in sorted file-name order. Labels come from a ``dataset.json`` entry when present.
    """

    def __init__(self,
        path,                   # Path to directory or zip.
        resolution      = None, # Required image height and width; None disables the check.
        use_pyspng      = True, # Use pyspng if available?
        **super_kwargs,         # Additional arguments for the Dataset base class.
    ):
        """Index the image files under ``path`` and initialise the base class.

        Raises:
            IOError: if ``path`` is neither a directory nor a ``.zip`` file, if no image files
                are found, or if ``resolution`` is given and the height or width of the first
                image differs from it.
        """
        self._path = path
        self._use_pyspng = use_pyspng
        self._zipfile = None

        if os.path.isdir(self._path):
            self._type = 'dir'
            self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        PIL.Image.init()
        self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
        if len(self._image_fnames) == 0:
            raise IOError('No image files found in the specified path')

        name = os.path.splitext(os.path.basename(self._path))[0]
        # The first image is decoded once to learn the CHW shape shared by the whole dataset.
        raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
        if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
            raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        """Return the lower-cased extension of ``fname``, including the leading dot."""
        return os.path.splitext(fname)[1].lower()

    def _get_zipfile(self):
        """Return the ZIP archive handle, opening it on first use."""
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        """Open ``fname`` (relative to the dataset root) for binary reading."""
        if self._type == 'dir':
            return open(os.path.join(self._path, fname), 'rb')
        if self._type == 'zip':
            return self._get_zipfile().open(fname, 'r')
        return None

    def close(self):
        """Close the ZIP archive handle, if one is open."""
        try:
            if self._zipfile is not None:
                self._zipfile.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        # The open archive handle is dropped from the pickled state; _get_zipfile() reopens it.
        return dict(super().__getstate__(), _zipfile=None)

    def _load_raw_image(self, raw_idx):
        """Decode image number ``raw_idx`` and return it as a channel-first (CHW) array."""
        fname = self._image_fnames[raw_idx]
        with self._open_file(fname) as f:
            if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png':
                image = pyspng.load(f.read())
            else:
                image = np.array(PIL.Image.open(f))
        if image.ndim == 2:
            image = image[:, :, np.newaxis] # HW => HWC
        image = image.transpose(2, 0, 1) # HWC => CHW
        return image

    def _load_raw_labels(self):
        """Load labels from ``dataset.json``.

        Returns:
            None when the dataset has no ``dataset.json`` or its ``'labels'`` entry is null;
            otherwise an array ordered like the image file names, cast to int64 when it is
            one-dimensional and to float32 when it is two-dimensional.
        """
        # Imported here because the module's import block does not import json; without this
        # the json.load call below raises NameError.
        import json

        fname = 'dataset.json'
        if fname not in self._all_fnames:
            return None
        with self._open_file(fname) as f:
            labels = json.load(f)['labels']
        if labels is None:
            return None
        labels = dict(labels)
        # Label keys use forward slashes, so Windows-style separators are normalised first.
        labels = [labels[image_fname.replace('\\', '/')] for image_fname in self._image_fnames]
        labels = np.array(labels)
        labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
        return labels
269 | # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 270 | dist.print0('Loading Inception-v3 model...') 271 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 272 | detector_kwargs = dict(return_features=True) 273 | feature_dim = 2048 274 | with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f: 275 | detector_net = pickle.load(f).to(device) 276 | 277 | # List images. 278 | dist.print0(f'Loading images from "{image_path}"...') 279 | dataset_obj = ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed) 280 | if num_expected is not None and len(dataset_obj) < num_expected: 281 | raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}') 282 | if len(dataset_obj) < 2: 283 | raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics') 284 | 285 | # Other ranks follow. 286 | if dist.get_rank() == 0: 287 | torch.distributed.barrier() 288 | 289 | # Divide images into batches. 290 | num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size() 291 | all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches) 292 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()] 293 | data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor) 294 | 295 | # Accumulate statistics. 
def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
    """Compute the Frechet distance between two sets of (mean, covariance) statistics.

    The result is ``|mu - mu_ref|^2 + trace(sigma + sigma_ref - 2 * sqrtm(sigma @ sigma_ref))``,
    returned as a Python float. Only the real part is kept, because the matrix square root may
    be complex-valued.
    """
    mean_term = np.square(mu - mu_ref).sum()

    # disp=False makes sqrtm return (sqrt, error_estimate) instead of printing a warning.
    cov_product = np.dot(sigma, sigma_ref)
    cov_sqrt, _err = scipy.linalg.sqrtm(cov_product, disp=False)

    distance = mean_term + np.trace(sigma + sigma_ref - cov_sqrt * 2)
    return float(np.real(distance))
@main.command()
@click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True)
@click.option('--ref', 'ref_path', help='Dataset reference statistics ', metavar='NPZ|URL', type=str, required=True)
@click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True)
@click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True)
@click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)

def calc(image_path, ref_path, num_expected, seed, batch):
    """Calculate FID for a given set of images."""
    torch.multiprocessing.set_start_method('spawn')
    dist.init()

    # Only rank 0 loads the reference statistics; on every other rank they stay None.
    dist.print0(f'Loading dataset reference statistics from "{ref_path}"...')
    ref_stats = None
    if dist.get_rank() == 0:
        with dnnlib.util.open_url(ref_path) as ref_file:
            ref_stats = dict(np.load(ref_file))

    # Every rank takes part in computing the Inception statistics of the images.
    mu, sigma = calculate_inception_stats(
        image_path=image_path,
        num_expected=num_expected,
        seed=seed,
        max_batch_size=batch,
    )

    dist.print0('Calculating FID...')
    if dist.get_rank() == 0:
        score = calculate_fid_from_inception_stats(mu, sigma, ref_stats['mu'], ref_stats['sigma'])
        print(f'{score:g}')

    # All ranks wait here, so none exits before rank 0 has printed the result.
    torch.distributed.barrier()
class Logger(object):
    """Logger that emits records only on rank 0, to the console and to ``<log_dir>/log.txt``."""

    def __init__(self, rank=0, log_dir=".log"):
        # other libraries may set logging before arriving at this line.
        # by reloading logging, we can get rid of previous configs set by other libraries.
        from importlib import reload
        reload(logging)

        self.rank = rank
        if self.rank != 0:
            # Non-zero ranks never log, so they need no handlers or log file.
            return

        os.makedirs(log_dir, exist_ok=True)
        log_path = os.path.join(log_dir, "log.txt")
        log_file = open(log_path, "w")

        # One handler writes to the terminal, the other to the log file.
        terminal_handler = RichHandler(show_path=False)
        file_handler = RichHandler(console=Console(file=log_file, width=150), show_path=False)
        logging.basicConfig(
            level=logging.INFO,
            format="(%(relative)s) %(message)s",
            datefmt="[%X]",
            force=True,
            handlers=[terminal_handler, file_handler],
        )

        # Each handler gets its own TimeFilter, which supplies the "relative" field used in
        # the format string above.
        # https://stackoverflow.com/questions/31521859/python-logging-module-time-since-last-log
        root_logger = logging.getLogger()
        for handler in root_logger.handlers:
            handler.addFilter(TimeFilter())

    def info(self, string, *args):
        if self.rank != 0:
            return
        logging.info(string, *args)

    def warning(self, string, *args):
        if self.rank != 0:
            return
        logging.warning(string, *args)

    def error(self, string, *args):
        if self.rank != 0:
            return
        logging.error(string, *args)
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

import os
import torch
from . import training_stats

#----------------------------------------------------------------------------

def init():
    """Set up ``torch.distributed`` from the environment.

    Rendezvous variables that are not already set receive single-process
    defaults. The process group is then created, the CUDA device is selected
    from ``LOCAL_RANK``, and ``training_stats`` is told which device to use
    for cross-process synchronisation (none when there is only one process).
    """
    env_defaults = {
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '29500',
        'RANK': '0',
        'LOCAL_RANK': '0',
        'WORLD_SIZE': '1',
    }
    for key, fallback in env_defaults.items():
        os.environ.setdefault(key, fallback)

    if os.name == 'nt':
        backend = 'gloo'
    else:
        backend = 'nccl'
    torch.distributed.init_process_group(backend=backend, init_method='env://')
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    torch.cuda.set_device(local_rank)

    sync_device = None
    if get_world_size() > 1:
        sync_device = torch.device('cuda')
    training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device)

#----------------------------------------------------------------------------

def get_rank():
    """Return this process's rank, or 0 when no process group is initialised."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0

#----------------------------------------------------------------------------

def get_world_size():
    """Return the number of processes, or 1 when no process group is initialised."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return 1

#----------------------------------------------------------------------------

def should_stop():
    """Always report that the run should continue."""
    return False

#----------------------------------------------------------------------------

def update_progress(cur, total):
    """Accept a progress report and discard it."""
    _ = cur, total

#----------------------------------------------------------------------------

def print0(*args, **kwargs):
    """``print`` that only produces output on rank 0."""
    if get_rank() != 0:
        return
    print(*args, **kwargs)

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /edm/torch_utils/misc.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

import re
import contextlib
import numpy as np
import torch
import warnings
try:
    import dnnlib
except ImportError:
    # `import edm.dnnlib` alone would only bind the name `edm`; alias it so
    # the rest of this module can keep referring to `dnnlib`.
    import edm.dnnlib as dnnlib

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached tensor holding ``value``.

    Args:
        value: Anything ``np.asarray`` accepts.
        shape: Optional target shape; the value is broadcast to it.
        dtype: Tensor dtype. Defaults to ``torch.get_default_dtype()``.
        device: Tensor device. Defaults to the CPU.
        memory_format: Defaults to ``torch.contiguous_format``.

    Returns:
        A tensor shared between all calls made with the same arguments;
        the cache key includes the raw bytes of ``value``.
    """
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            # The empty tensor is only used as a shape carrier for broadcasting.
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor

#----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.

try:
    nan_to_num = torch.nan_to_num # 1.8.0a0
except AttributeError:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        """Fallback used when ``torch.nan_to_num`` is missing.

        NaNs become 0 (the only supported ``nan`` value) and the result is
        clamped to ``[neginf, posinf]``, which default to the finite range
        of the input dtype.
        """
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)

#----------------------------------------------------------------------------
# Symbolic assert.

try:
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert # 1.7.0

#----------------------------------------------------------------------------
# Context manager to temporarily suppress known warnings in torch.jit.trace().
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672

@contextlib.contextmanager
def suppress_tracer_warnings():
    """Ignore ``torch.jit.TracerWarning`` for the duration of the ``with`` block.

    An "ignore" entry is pushed to the front of ``warnings.filters`` on entry
    and removed on exit, including when the body raises.
    """
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    try:
        yield
    finally:
        # Without the finally, an exception in the body would leave the
        # filter installed for the rest of the process.
        warnings.filters.remove(flt)

#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().

def assert_shape(tensor, ref_shape):
    """Raise ``AssertionError`` unless ``tensor.shape`` matches ``ref_shape``.

    Args:
        tensor: Object with ``ndim`` and ``shape``.
        ref_shape: Expected size per dimension. ``None`` accepts any size.
            When either the expected or the actual size is a tensor, the
            comparison goes through ``symbolic_assert`` instead of a plain
            Python comparison.
    """
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            continue
        if isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')

#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().

def profiled_function(fn):
    """Wrap ``fn`` so each call runs inside a profiler range named after it.

    The wrapper keeps the metadata of ``fn`` (name, docstring, module, ...),
    so decorated functions still show their own documentation.
    """
    import functools  # local import: keeps this block self-contained

    name = fn.__name__

    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(name):
            return fn(*args, **kwargs)
    return decorator

#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.

class InfiniteSampler(torch.utils.data.Sampler):
    """Endless index sampler with optional shuffling and rank-based sharding.

    Args:
        dataset: Sized collection to draw indices for (must be non-empty).
        rank: Index of this replica; it yields every position whose running
            counter is congruent to ``rank`` modulo ``num_replicas``.
        num_replicas: Total number of replicas sharing the stream.
        shuffle: Shuffle the initial order with ``seed`` and keep swapping
            entries while iterating.
        seed: Seed for the ``numpy`` random state used when shuffling.
        window_size: Fraction (0..1) of the dataset used as the look-back
            window for the running swaps.
    """

    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        order = np.arange(len(self.dataset))
        count = order.size
        rng, window = None, 0
        if self.shuffle:
            rng = np.random.RandomState(self.seed)
            rng.shuffle(order)
            window = int(np.rint(count * self.window_size))
        swapping = window >= 2

        step = 0
        while True:
            pos = step % count
            if step % self.num_replicas == self.rank:
                yield order[pos]
            if swapping:
                # Swap the current slot with one up to `window - 1` places
                # behind it (wrapping around); this runs on every step, not
                # only on the steps this rank yields.
                other = (pos - rng.randint(window)) % count
                order[pos], order[other] = order[other], order[pos]
            step += 1

#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.

def params_and_buffers(module):
    """Return the module's parameters followed by its buffers, as one list."""
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())

def named_params_and_buffers(module):
    """Return ``(name, tensor)`` pairs for all parameters, then all buffers."""
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())

@torch.no_grad()
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy same-named parameters and buffers from ``src_module`` into ``dst_module``.

    Tensors of ``dst_module`` that have no counterpart in ``src_module`` are
    left untouched, unless ``require_all`` is set, in which case a missing
    name trips an assertion.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            tensor.copy_(src_tensors[name])

#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.

@contextlib.contextmanager
def ddp_sync(module, sync):
    """Run the body under ``module.no_sync()`` when ``sync`` is false.

    Modules that are not ``DistributedDataParallel`` are passed through
    unchanged regardless of ``sync``.
    """
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield

#----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.

def check_ddp_consistency(module, ignore_regex=None):
    """Assert that every parameter/buffer equals the copy broadcast from rank 0.

    Args:
        module: Module to check.
        ignore_regex: Optional pattern; tensors whose ``ClassName.name``
            fully matches it are skipped.
    """
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        if tensor.is_floating_point():
            # NaN != NaN, so normalise before comparing.
            tensor = nan_to_num(tensor)
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (tensor == other).all(), fullname

#----------------------------------------------------------------------------
# Print summary table of module hierarchy.

def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    """Run ``module(*inputs)`` once and print a per-submodule summary table.

    Args:
        module: Module to summarise (must not be a ``ScriptModule``).
        inputs: Tuple or list of positional arguments for the forward call.
        max_nesting: Submodules nested deeper than this are not listed.
        skip_redundant: Drop rows that contribute no parameters, buffers or
            outputs that were not already attributed to an earlier row.

    Returns:
        Whatever ``module(*inputs)`` returned.

    The temporary forward hooks are removed even if the forward call raises.
    """
    import types  # local import: plain attribute bags for the table entries

    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            # A stdlib namespace is enough here and avoids depending on the
            # optional dnnlib import for a purely local record.
            entries.append(types.SimpleNamespace(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    try:
        outputs = module(*inputs)
    finally:
        # Always detach the hooks, otherwise a failing forward pass would
        # leave them attached to the caller's module.
        for hook in hooks:
            hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table.
    # NOTE(review): the column separator is reproduced as it appears in this
    # copy of the file, where runs of whitespace were collapsed -- confirm
    # its width against the original.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /edm/torch_utils/persistence.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Facilities for pickling Python code alongside other data.

The pickled code is automatically imported into a separate Python module
during unpickling. This way, any previously exported pickles will remain
usable even if the original code is no longer available, or if the current
version of the code is not consistent with what was originally pickled."""

import sys
import pickle
import io
import inspect
import copy
import uuid
import types
import sys
import os
sys.path.append(os.path.abspath('edm'))
import dnnlib

#----------------------------------------------------------------------------

_version = 6 # internal version number
_decorators = set() # {decorator_class, ...}
_import_hooks = [] # [hook_function, ...]
_module_to_src_dict = dict() # {module: src, ...}
_src_to_module_dict = dict() # {src: module, ...}

#----------------------------------------------------------------------------

def persistent_class(orig_class):
    r"""Class decorator that extends a given class to save its source code
    when pickled.

    Example:

        from torch_utils import persistence

        @persistence.persistent_class
        class MyNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self.fc = MyLayer(num_inputs, num_outputs)
                ...

        @persistence.persistent_class
        class MyLayer(torch.nn.Module):
            ...

    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
    source code alongside other internal state (e.g., parameters, buffers,
    and submodules). This way, any previously exported pickle will remain
    usable even if the class definitions have been modified or are no
    longer available.

    The decorator saves the source code of the entire Python module
    containing the decorated class. It does *not* save the source code of
    any imported modules. Thus, the imported modules must be available
    during unpickling, also including `torch_utils.persistence` itself.

    It is ok to call functions defined in the same module from the
    decorated class. However, if the decorated class depends on other
    classes defined in the same module, they must be decorated as well.
    This is illustrated in the above example in the case of `MyLayer`.

    It is also possible to employ the decorator just-in-time before
    calling the constructor. For example:

        cls = MyLayer
        if want_to_make_it_persistent:
            cls = persistence.persistent_class(cls)
        layer = cls(num_inputs, num_outputs)

    As an additional feature, the decorator also keeps track of the
    arguments that were used to construct each instance of the decorated
    class. The arguments can be queried via `obj.init_args` and
    `obj.init_kwargs`, and they are automatically pickled alongside other
    object state. This feature can be disabled on a per-instance basis
    by setting `self._record_init_args = False` in the constructor.

    A typical use case is to first unpickle a previous instance of a
    persistent class, and then upgrade it to use the latest version of
    the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)
    """
    assert isinstance(orig_class, type)
    if is_persistent(orig_class):
        return orig_class

    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module)

    class Decorator(orig_class):
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            record_init_args = getattr(self, '_record_init_args', True)
            self._init_args = copy.deepcopy(args) if record_init_args else None
            self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
            assert orig_class.__name__ in orig_module.__dict__
            # Fail at construction time rather than at the first pickle.
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            assert self._init_args is not None
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            assert self._init_kwargs is not None
            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            fields = list(super().__reduce__())
            fields += [None] * max(3 - len(fields), 0)
            if fields[0] is not _reconstruct_persistent_obj:
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict
            return tuple(fields)

    Decorator.__name__ = orig_class.__name__
    Decorator.__module__ = orig_class.__module__
    _decorators.add(Decorator)
    return Decorator

#----------------------------------------------------------------------------

def is_persistent(obj):
    r"""Test whether the given object or class is persistent, i.e.,
    whether it will save its source code when pickled.
    """
    try:
        if obj in _decorators:
            return True
    except TypeError:
        # Unhashable objects cannot be members of the set; fall through to
        # the type-based check.
        pass
    return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck

#----------------------------------------------------------------------------

def import_hook(hook):
    r"""Register an import hook that is called whenever a persistent object
    is being unpickled. A typical use case is to patch the pickled source
    code to avoid errors and inconsistencies when the API of some imported
    module has changed.

    The hook should have the following signature:

        hook(meta) -> modified meta

    `meta` is an instance of `dnnlib.EasyDict` with the following fields:

        type:       Type of the persistent object, e.g. `'class'`.
        version:    Internal version number of `torch_utils.persistence`.
        module_src  Original source code of the Python module.
        class_name: Class name in the original Python module.
        state:      Internal state of the object.

    Example:

        @persistence.import_hook
        def wreck_my_network(meta):
            if meta.class_name == 'MyNetwork':
                print('MyNetwork is being imported. I will wreck it!')
                meta.module_src = meta.module_src.replace("True", "False")
            return meta
    """
    assert callable(hook)
    _import_hooks.append(hook)

#----------------------------------------------------------------------------

def _reconstruct_persistent_obj(meta):
    r"""Hook that is called internally by the `pickle` module to unpickle
    a persistent object.

    Raises `AttributeError` when the pickled module source does not match
    any module that `_src_to_module()` can resolve.
    """
    meta = dnnlib.EasyDict(meta)
    meta.state = dnnlib.EasyDict(meta.state)
    for hook in _import_hooks:
        meta = hook(meta)
        assert meta is not None

    assert meta.version == _version
    module = _src_to_module(meta.module_src)
    if module is None:
        # Same exception type that the `module.__dict__` lookup on None used
        # to raise, but with a message that names the actual problem.
        raise AttributeError(
            f'Cannot unpickle persistent class {meta.class_name!r}: its pickled '
            'module source does not match any module registered in this process'
        )

    assert meta.type == 'class'
    orig_class = module.__dict__[meta.class_name]
    decorator_class = persistent_class(orig_class)
    obj = decorator_class.__new__(decorator_class)

    setstate = getattr(obj, '__setstate__', None)
    if callable(setstate):
        setstate(meta.state) # pylint: disable=not-callable
    else:
        obj.__dict__.update(meta.state)
    return obj

#----------------------------------------------------------------------------

def _module_to_src(module):
    r"""Query the source code of a given Python module.
    """
    src = _module_to_src_dict.get(module, None)
    if src is None:
        src = inspect.getsource(module)
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
    return src

def _src_to_module(src):
    r"""Look up the Python module registered for the given source code.

    Returns `None` when the source is unknown: creating a fresh module by
    executing the pickled source is disabled (see the commented-out code).
    """
    module = _src_to_module_dict.get(src, None)
    # if module is None:
    #     module_name = "_imported_module_" + uuid.uuid4().hex
    #     module = types.ModuleType(module_name)
    #     sys.modules[module_name] = module
    #     _module_to_src_dict[module] = src
    #     _src_to_module_dict[src] = module
    #     exec(src, module.__dict__) # pylint: disable=exec-used
    return module

#----------------------------------------------------------------------------

def _check_pickleable(obj):
    r"""Check that the given object is pickleable, raising an exception if
    it is not. This function is expected to be considerably more efficient
    than actually pickling the object.
    """
    def recurse(obj):
        if isinstance(obj, (list, tuple, set)):
            return [recurse(x) for x in obj]
        if isinstance(obj, dict):
            return [[recurse(x), recurse(y)] for x, y in obj.items()]
        if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
            return None # Python primitive types are pickleable.
        if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
            return None # NumPy arrays and PyTorch tensors are pickleable.
        if is_persistent(obj):
            return None # Persistent objects are pickleable, by virtue of the constructor check.
        return obj
    with io.BytesIO() as f:
        pickle.dump(recurse(obj), f)

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /edm/torch_utils/training_stats.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for reporting and collecting training statistics across 9 | multiple processes and devices. The interface is designed to minimize 10 | synchronization overhead as well as the amount of boilerplate in user 11 | code.""" 12 | 13 | import re 14 | import numpy as np 15 | import torch 16 | try: 17 | import edm.dnnlib 18 | except: 19 | import dnnlib 20 | 21 | from . import misc 22 | 23 | #---------------------------------------------------------------------------- 24 | 25 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 26 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 27 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 28 | _rank = 0 # Rank of the current process. 29 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 30 | _sync_called = False # Has _sync() been called yet? 31 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 32 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def init_multiprocessing(rank, sync_device): 37 | r"""Initializes `torch_utils.training_stats` for collecting statistics 38 | across multiple processes. 39 | 40 | This function must be called after 41 | `torch.distributed.init_process_group()` and before `Collector.update()`. 42 | The call is not necessary if multi-process collection is not needed. 43 | 44 | Args: 45 | rank: Rank of the current process. 46 | sync_device: PyTorch device to use for inter-process 47 | communication, or None to disable multi-process 48 | collection. Typically `torch.device('cuda', rank)`. 
49 | """ 50 | global _rank, _sync_device 51 | assert not _sync_called 52 | _rank = rank 53 | _sync_device = sync_device 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | @misc.profiled_function 58 | def report(name, value): 59 | r"""Broadcasts the given set of scalars to all interested instances of 60 | `Collector`, across device and process boundaries. 61 | 62 | This function is expected to be extremely cheap and can be safely 63 | called from anywhere in the training loop, loss function, or inside a 64 | `torch.nn.Module`. 65 | 66 | Warning: The current implementation expects the set of unique names to 67 | be consistent across processes. Please make sure that `report()` is 68 | called at least once for each unique name by each process, and in the 69 | same order. If a given process has no scalars to broadcast, it can do 70 | `report(name, [])` (empty list). 71 | 72 | Args: 73 | name: Arbitrary string specifying the name of the statistic. 74 | Averages are accumulated separately for each unique name. 75 | value: Arbitrary set of scalars. Can be a list, tuple, 76 | NumPy array, PyTorch tensor, or Python scalar. 77 | 78 | Returns: 79 | The same `value` that was passed in. 
80 | """ 81 | if name not in _counters: 82 | _counters[name] = dict() 83 | 84 | elems = torch.as_tensor(value) 85 | if elems.numel() == 0: 86 | return value 87 | 88 | elems = elems.detach().flatten().to(_reduce_dtype) 89 | moments = torch.stack([ 90 | torch.ones_like(elems).sum(), 91 | elems.sum(), 92 | elems.square().sum(), 93 | ]) 94 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 95 | moments = moments.to(_counter_dtype) 96 | 97 | device = moments.device 98 | if device not in _counters[name]: 99 | _counters[name][device] = torch.zeros_like(moments) 100 | _counters[name][device].add_(moments) 101 | return value 102 | 103 | #---------------------------------------------------------------------------- 104 | 105 | def report0(name, value): 106 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 107 | but ignores any scalars provided by the other processes. 108 | See `report()` for further details. 109 | """ 110 | report(name, value if _rank == 0 else []) 111 | return value 112 | 113 | #---------------------------------------------------------------------------- 114 | 115 | class Collector: 116 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 117 | computes their long-term averages (mean and standard deviation) over 118 | user-defined periods of time. 119 | 120 | The averages are first collected into internal counters that are not 121 | directly visible to the user. They are then copied to the user-visible 122 | state as a result of calling `update()` and can then be queried using 123 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 124 | internal counters for the next round, so that the user-visible state 125 | effectively reflects averages collected between the last two calls to 126 | `update()`. 127 | 128 | Args: 129 | regex: Regular expression defining which statistics to 130 | collect. The default is to collect everything. 
131 | keep_previous: Whether to retain the previous averages if no 132 | scalars were collected on a given round 133 | (default: True). 134 | """ 135 | def __init__(self, regex='.*', keep_previous=True): 136 | self._regex = re.compile(regex) 137 | self._keep_previous = keep_previous 138 | self._cumulative = dict() 139 | self._moments = dict() 140 | self.update() 141 | self._moments.clear() 142 | 143 | def names(self): 144 | r"""Returns the names of all statistics broadcasted so far that 145 | match the regular expression specified at construction time. 146 | """ 147 | return [name for name in _counters if self._regex.fullmatch(name)] 148 | 149 | def update(self): 150 | r"""Copies current values of the internal counters to the 151 | user-visible state and resets them for the next round. 152 | 153 | If `keep_previous=True` was specified at construction time, the 154 | operation is skipped for statistics that have received no scalars 155 | since the last update, retaining their previous averages. 156 | 157 | This method performs a number of GPU-to-CPU transfers and one 158 | `torch.distributed.all_reduce()`. It is intended to be called 159 | periodically in the main training loop, typically once every 160 | N training steps. 161 | """ 162 | if not self._keep_previous: 163 | self._moments.clear() 164 | for name, cumulative in _sync(self.names()): 165 | if name not in self._cumulative: 166 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 167 | delta = cumulative - self._cumulative[name] 168 | self._cumulative[name].copy_(cumulative) 169 | if float(delta[0]) != 0: 170 | self._moments[name] = delta 171 | 172 | def _get_delta(self, name): 173 | r"""Returns the raw moments that were accumulated for the given 174 | statistic between the last two calls to `update()`, or zero if 175 | no scalars were collected. 
176 | """ 177 | assert self._regex.fullmatch(name) 178 | if name not in self._moments: 179 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 180 | return self._moments[name] 181 | 182 | def num(self, name): 183 | r"""Returns the number of scalars that were accumulated for the given 184 | statistic between the last two calls to `update()`, or zero if 185 | no scalars were collected. 186 | """ 187 | delta = self._get_delta(name) 188 | return int(delta[0]) 189 | 190 | def mean(self, name): 191 | r"""Returns the mean of the scalars that were accumulated for the 192 | given statistic between the last two calls to `update()`, or NaN if 193 | no scalars were collected. 194 | """ 195 | delta = self._get_delta(name) 196 | if int(delta[0]) == 0: 197 | return float('nan') 198 | return float(delta[1] / delta[0]) 199 | 200 | def std(self, name): 201 | r"""Returns the standard deviation of the scalars that were 202 | accumulated for the given statistic between the last two calls to 203 | `update()`, or NaN if no scalars were collected. 204 | """ 205 | delta = self._get_delta(name) 206 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 207 | return float('nan') 208 | if int(delta[0]) == 1: 209 | return float(0) 210 | mean = float(delta[1] / delta[0]) 211 | raw_var = float(delta[2] / delta[0]) 212 | return np.sqrt(max(raw_var - np.square(mean), 0)) 213 | 214 | def as_dict(self): 215 | r"""Returns the averages accumulated between the last two calls to 216 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 217 | 218 | dnnlib.EasyDict( 219 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 220 | ... 221 | ) 222 | """ 223 | stats = dnnlib.EasyDict() 224 | for name in self.names(): 225 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 226 | return stats 227 | 228 | def __getitem__(self, name): 229 | r"""Convenience getter. 
def _sync(names):
    r"""Fold the pending per-device counters of `names` into the global
    cumulative totals, reducing across ranks when a sync device has been
    configured. Called internally by `Collector.update()`.

    Returns a list of `(name, cumulative_tensor)` pairs in the order of
    `names`, or an empty list when `names` is empty.
    """
    global _sync_called
    if len(names) == 0:
        return []
    _sync_called = True

    # Drain every counter of each statistic into one row per name; the
    # counters themselves are reset to zero as they are read.
    device = torch.device('cpu') if _sync_device is None else _sync_device
    rows = []
    for name in names:
        row = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
        for counter in _counters[name].values():
            row.add_(counter.to(device))
            counter.copy_(torch.zeros_like(counter))
        rows.append(row)
    pending = torch.stack(rows)

    # The cross-rank reduction only runs when a sync device was set.
    if _sync_device is not None:
        torch.distributed.all_reduce(pending)

    # Add this round's contribution to the running totals (kept on the CPU).
    pending = pending.cpu()
    for name, row in zip(names, pending):
        if name not in _cumulative:
            _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        _cumulative[name].add_(row)

    return [(name, _cumulative[name]) for name in names]
def get_nn(opt,dyn):
    """Build the backbone selected by ``opt.exp`` and wrap it in ``network_wrapper``.

    Args:
        opt: run options; ``opt.exp`` picks the backbone class and ``opt``
            itself is passed to that class's constructor.
        dyn: dynamics object handed through to ``network_wrapper``.

    Returns:
        A ``network_wrapper`` around the freshly constructed backbone.

    Raises:
        ValueError: if ``opt.exp`` has no backbone registered here. Without
            this check an unmapped experiment name fails later with an
            unhelpful "'NoneType' object is not callable".
    """
    backbones = {
        'toy':          ResNet,
        'cifar10':      SongUNet,
        'AFHQv2':       SongUNet,
        'imagenet64':   DhariwalUNet,
    }
    net = backbones.get(opt.exp)
    if net is None:
        raise ValueError(
            'Unsupported experiment {!r}; expected one of {}'.format(opt.exp, sorted(backbones)))
    return network_wrapper(opt, net(opt), dyn)
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class ResNet(nn.Module):
    """Fully connected residual network that is re-fed the time at every layer.

    Each linear layer receives its input concatenated with ``log(t)``, so
    ``t`` must be strictly positive. The last layer is zero-initialised,
    which makes a freshly built network output exactly zero.

    Args:
        opt: accepted for a uniform constructor signature; not used here.
        input_dim: dimensionality of the position (and of the output).
        index_dim: number of time columns appended to every layer input.
            ``_append_time`` appends a single column, so this is expected
            to stay at 1.
        hidden_dim: width of the hidden layers.
        n_hidden_layers: number of residual hidden layers.
        x_input: when True (default) the input is the concatenation
            ``[x, v]`` of width ``2 * input_dim``; when False it is a single
            vector of width ``input_dim``.
    """

    def __init__(self,
                 opt,
                 input_dim=2,
                 index_dim=1,
                 hidden_dim=128,
                 n_hidden_layers=20,
                 x_input=True):

        super().__init__()

        self.act = nn.SiLU()
        self.n_hidden_layers = n_hidden_layers

        self.x_input = x_input  # input is concat [x,v] or just x
        in_dim = (input_dim * 2 if self.x_input else input_dim) + index_dim
        out_dim = input_dim

        # Layers are created in input -> hidden -> output order so that the
        # random initialisation under a fixed seed is unchanged.
        layers = [nn.Linear(in_dim, hidden_dim)]
        layers.extend(nn.Linear(hidden_dim + index_dim, hidden_dim)
                      for _ in range(n_hidden_layers))
        layers.append(nn.Linear(hidden_dim + index_dim, out_dim))

        self.layers = nn.ModuleList(layers)
        self.layers[-1] = zero_module(self.layers[-1])

    def _append_time(self, h, t):
        """Return ``h`` with a ``log(t)`` column appended.

        ``t`` may hold one value per row of ``h``, or a single value (0-d or
        one-element tensor) that is shared by the whole batch.
        """
        time_embedding = torch.log(t).reshape(-1, 1)
        if time_embedding.shape[0] == 1 and h.shape[0] != 1:
            # A single time is broadcast over the batch instead of failing in cat.
            time_embedding = time_embedding.expand(h.shape[0], -1)
        return torch.cat([h, time_embedding], dim=1)

    def forward(self, u, t,class_labels=None):
        """Map state ``u`` at time ``t`` to an ``input_dim``-wide output.

        ``class_labels`` is accepted for interface parity with the other
        backbones and is ignored.
        """
        h0 = self.layers[0](self._append_time(u, t))
        h = self.act(h0)

        for i in range(self.n_hidden_layers):
            h_new = self.layers[i + 1](self._append_time(h, t))
            h = self.act(h + h_new)  # residual connection

        return self.layers[-1](self._append_time(h, t))
def save_toy_traj(opt, fn, traj):
    """Save a two-panel PDF (position | velocity) of toy trajectories.

    The left panel plots ``traj[:, step, 0:2]`` and the right panel
    ``traj[:, step, 2:4]`` as scatter clouds at the first and last step,
    overlaid with the full paths of up to 10 randomly chosen trajectories.

    Args:
        opt: options object; ``opt.ckpt_path`` is the output directory and
            ``opt.nfe`` is shown in the figure title.
        fn: output file name without the ``.pdf`` extension.
        traj: array indexed as ``[sample, step, channel]`` with at least
            four channels.
    """
    fn_pdf = os.path.join(opt.ckpt_path, fn+'.pdf')
    n_snapshot = 2
    lims = [[-4, 4], [-6, 6]]
    cmap = ['Blues', 'Reds']
    colors = ['b', 'r']
    titles = ['position', 'velocity']

    total_steps = traj.shape[1]
    sample_steps = np.linspace(0, total_steps-1, n_snapshot).astype(int)
    _colors = np.linspace(0.5, 1, len(sample_steps))

    # Never ask for more distinct trajectories than exist; sampling without
    # replacement would raise otherwise.
    num_samp_lines = min(10, traj.shape[0])
    random_idx = np.random.choice(traj.shape[0], num_samp_lines, replace=False)
    means = traj[random_idx, ...]

    fig, axss = plt.subplots(1, 2)
    try:
        fig.set_size_inches(20, 10)
        for i, ax in enumerate(axss):
            ax.grid(True)
            ax.xaxis.set_ticklabels([])
            ax.yaxis.set_ticklabels([])
            for idx, step in enumerate(sample_steps):
                ax.scatter(traj[:,step,2*i], traj[:,step,2*i+1], s=10,
                           c=_colors[idx].repeat(traj.shape[0]), alpha=0.6,
                           vmin=0, vmax=1, cmap=cmap[i])
            ax.set_xlim(*lims[i])
            ax.set_ylim(*lims[i])

            for ii in range(num_samp_lines):
                ax.plot(means[ii,:,2*i], means[ii,:,2*i+1], color=colors[i], linewidth=4, alpha=0.5)
            ax.set_title(titles[i], size=40)

        fig.suptitle('NFE = {}'.format(opt.nfe-1), size=40)
        fig.tight_layout()
        fig.savefig(fn_pdf)
    finally:
        # Close (not just clear) the figure so repeated calls during training
        # do not accumulate open figures.
        plt.close(fig)
def norm_data(x):
    """Min-max normalise ``x`` over its last two dimensions.

    Every slice ``x[..., :, :]`` is shifted and scaled independently so its
    minimum maps to 0 and its maximum to 1. A constant slice (max == min)
    maps to all zeros instead of NaN.
    """
    _max = torch.max(torch.max(x, dim=-1)[0], dim=-1)[0][..., None, None]
    _min = torch.min(torch.min(x, dim=-1)[0], dim=-1)[0][..., None, None]
    span = _max - _min
    # Avoid 0/0 for constant slices; non-degenerate slices are unaffected.
    span = torch.where(span > 0, span, torch.ones_like(span))
    return (x - _min) / span


def plot_toy(opt,ms,it,pred_m1,x1):
    """Write the trajectory and the snapshot figure for training iteration ``it``."""
    save_toy_traj(opt, 'itr_{}'.format(it), ms.detach().cpu().numpy())
    # Only the first two channels of the prediction are compared with x1.
    save_snapshot_traj(opt, 'itr_x{}'.format(it), pred_m1[:,0:2],x1)


def plot_scatter(x,ts,ax):
    '''
    Scatter the values of every sample against time on ``ax``.

    x:bs,t,dim
    ts: one time value per step, indexable by step
    '''
    bs,interval,dim = x.shape
    for ii in range(interval):
        ax.scatter(ts[ii].repeat(bs),x[:,ii,:],s=2,color='b',alpha=0.1)


def plot_plt(x,ts,ax):
    '''
    Draw channel 0 of every sample as a line over ``ts`` on ``ax``.

    x:bs,t,dim
    '''
    bs,interval,dim = x.shape

    for ii in range(bs):
        ax.plot(ts,x[ii,:,0],color='r',alpha=0.1)
| ax.set_title('time = {:.2f}'.format(step/(total_steps-1)*opt.T)) 123 | fig.tight_layout() 124 | 125 | plt.savefig(fn_pdf) 126 | plt.clf() -------------------------------------------------------------------------------- /scripts/AFHQv2.sh: -------------------------------------------------------------------------------- 1 | # python sampling.py --n-gpu-per-node 8 --batch-size 1250 --ckpt AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/latest.pt --num-sample 50000 --sampling sscs --pred-x1 --nfe 20 2 | # torchrun --standalone --nproc_per_node=1 fid.py calc --images=results/AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/fid_sample_folder/ --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz 3 | 4 | python sampling.py --n-gpu-per-node 8 --batch-size 1250 --ckpt AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/latest.pt --num-sample 50000 --sampling sscs --pred-x1 --nfe 50 5 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=results/AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/fid_sample_folder/ --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz 6 | 7 | python sampling.py --n-gpu-per-node 8 --batch-size 1250 --ckpt AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/latest.pt --num-sample 50000 --sampling sscs --pred-x1 --nfe 100 8 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=results/AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/fid_sample_folder/ --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz 9 | 10 | python sampling.py --n-gpu-per-node 8 --batch-size 1250 --ckpt AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/latest.pt --num-sample 50000 --sampling sscs --pred-x1 --nfe 150 11 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=results/AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/fid_sample_folder/ 
--ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz 12 | 13 | python sampling.py --n-gpu-per-node 8 --batch-size 1250 --ckpt AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/latest.pt --num-sample 50000 --sampling sscs --pred-x1 --nfe 500 14 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=results/AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/fid_sample_folder/ --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz 15 | -------------------------------------------------------------------------------- /scripts/cifar10.sh: -------------------------------------------------------------------------------- 1 | python train.py --name ode/varv1-0.99 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 1 --varv 1 --k 0.99 --sampling vanillaODE --reweight reciprocal --ode 2 | -------------------------------------------------------------------------------- /scripts/example.sh: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | #Cifar10 6 | python sampling.py --n-gpu-per-node 1 --ckpt Cifar10-ODE/latest.pt --pred-x1 --solver gDDIM --T 0.9 --nfe 20 --fid-save-name cifar10-nfe20 --num-sample 64 --batch-size 64 --save-img --img-save-name cifar10-nfe20 7 | #AFHQv2 8 | python sampling.py --n-gpu-per-node 1 --ckpt AFHQv2-ODE/latest.pt --pred-x1 --solver gDDIM --T 0.9 --nfe 20 --fid-save-name AFHQv2-nfe20 --num-sample 64 --batch-size 64 --save-img --img-save-name AFHQv2-nfe20 9 | 10 | #AFHQv2 Conditional Generation 11 | python sampling.py --n-gpu-per-node 1 --ckpt AFHQv2-ODE/latest.pt --pred-x1 --solver gDDIM --save-img --img-save-name cond-AFHQv2 --nfe 100 --T 0.999 --num-sample 64 --batch-size 64 --stroke-path dataset/StrokeData/testFig0.png --stroke-type dyn-v # you can also replace it by init-v 12 | 13 | #Imagenet64 14 | #NFE 20 FID=10.55 15 | python sampling.py --n-gpu-per-node 1 --ckpt uncond-ImageNet64-ODE/latest.pt --pred-x1 --solver gDDIM --save-img --img-save-name imagenet-nfe20 --fid-save-name ImageNet64-nfe20 --nfe 20 --T 0.99 --num-sample 64 --batch-size 64 16 | 17 | -------------------------------------------------------------------------------- /scripts/imagenet64.sh: -------------------------------------------------------------------------------- 1 | python train.py --name imagenet-uniform-recip-precond-bs512-varv1-varx1-k0.2-p3-lr1e3-ADM-newlabel-probode --varv 1 --varx 1 --k 0.2 --p 3 --microbatch 64 --n-gpu-per-node 8 --lr 1e-3 --exp imagenet64 --t-samp uniform --precond --reweight-type reciprocal --probablistic-ode --sampling vanillaODE 2 | -------------------------------------------------------------------------------- /scripts/release.sh: -------------------------------------------------------------------------------- 1 | #**********Toy****************** 2 | #SSS Stochastic dynamics 3 | #train 4 | python train.py --name train-toy/sde-spiral --exp toy --toy-exp spiral 5 | #sampling 6 | python train.py --name train-toy/sde-spiral-eval --exp toy --toy-exp spiral 
--eval --ckpt train-toy/sde-spiral --nfe 10 7 | #Exponential Integrator ode dynamics 8 | # train 9 | python train.py --name train-toy/ode-spiral --exp toy --toy-exp spiral --DE-type probODE --solver gDDIM 10 | #Sampling 11 | python train.py --name train-toy/ode-spiral-eval --exp toy --toy-exp spiral --DE-type probODE --solver gDDIM --nfe 10 --eval --ckpt train-toy/ode-spiral/latest.pt 12 | #**************Cifar10************* 13 | #NFE=5 FID=11.88 14 | CUDA_VISIBLE_DEVICES=1 python sampling.py --n-gpu-per-node 1 --ckpt Remote_Cifar10_ODE --pred-x1 --solver gDDIM --T 0.4 --nfe 5 --fid-save-name cifar10-nfe5 --port 6024 --num-sample 50000 --batch-size 1000 15 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe5 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 16 | 17 | #NFE=10 FID=4.54 18 | CUDA_VISIBLE_DEVICES=1 python sampling.py --n-gpu-per-node 1 --ckpt Remote_Cifar10_ODE --pred-x1 --solver gDDIM --T 0.7 --nfe 10 --fid-save-name cifar10-nfe10 --port 6024 --num-sample 50000 --batch-size 1000 19 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe10 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 20 | 21 | #NFE=20 FID=2.58 22 | CUDA_VISIBLE_DEVICES=0 python sampling.py --n-gpu-per-node 1 --ckpt Remote_Cifar10_ODE --pred-x1 --solver gDDIM --T 0.9 --nfe 20 --fid-save-name cifar10-nfe20 --port 6024 --num-sample 50000 --batch-size 1000 23 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe20 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 24 | 25 | #**************AFHQv2************* 26 | #Conditional Generation 27 | python sampling.py --n-gpu-per-node 1 --ckpt AFHQ-cond/uuysqcynde/latest.pt --solver gDDIM --save-img --img-save-name cond-test0129 --nfe 100 --T 0.999 --num-sample 8 --batch-size 8 --stroke-path 
dataset/StrokeData/testFig0.png --stroke-type dyn-v 28 | python sampling.py --n-gpu-per-node 1 --ckpt AFHQ-cond/uuysqcynde/latest.pt --solver gDDIM --save-img --img-save-name cond-test0129 --nfe 100 --T 0.999 --num-sample 8 --batch-size 8 --stroke-path dataset/StrokeData/testFig1.png --stroke-type dyn-v 29 | #Impainting Generation 30 | python sampling.py --n-gpu-per-node 1 --ckpt Remote_AFHQv2_ODE --pred-x1 --solver gDDIM --save-img --img-save-name impaint-AFHQv2 --nfe 100 --T 0.999 --num-sample 64 --batch-size 64 --stroke-type dyn-v --impainting --stroke-path dataset/StrokeData/testFig0_impainting.png 31 | python sampling.py --n-gpu-per-node 1 --ckpt Remote_AFHQv2_ODE --pred-x1 --solver gDDIM --save-img --img-save-name impaint-AFHQv2 --nfe 100 --T 0.999 --num-sample 64 --batch-size 64 --stroke-type dyn-v --impainting --stroke-path dataset/StrokeData/testFig1_impainting.png 32 | 33 | #NFE=5 34 | CUDA_VISIBLE_DEVICES=1 python sampling.py --n-gpu-per-node 1 --ckpt Remote_Cifar10_ODE --pred-x1 --solver gDDIM --T 0.4 --nfe 5 --fid-save-name cifar10-nfe5 --port 6024 --num-sample 50000 --batch-size 1000 35 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=FID_EVAL/cifar10-nfe5 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 36 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe5 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 37 | 38 | #NFE=10 39 | CUDA_VISIBLE_DEVICES=1 python sampling.py --n-gpu-per-node 1 --ckpt cifar10_ODE/latest.pt --pred-x1 --solver gDDIM --T 0.7 --nfe 10 --fid-save-name cifar10-nfe10 --port 6024 --num-sample 50000 --batch-size 1000 40 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe10 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 41 | 42 | #NFE=20 43 | CUDA_VISIBLE_DEVICES=0 python sampling.py --n-gpu-per-node 1 --ckpt cifar10_ODE/latest.pt 
--pred-x1 --solver gDDIM --T 0.9 --nfe 20 --fid-save-name cifar10-nfe20 --port 6024 --num-sample 50000 --batch-size 1000 44 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe20 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 45 | 46 | #**************ImageNet************* 47 | #NFE=5 48 | CUDA_VISIBLE_DEVICES=1 python sampling.py --n-gpu-per-node 1 --ckpt Remote_Cifar10_ODE --pred-x1 --solver gDDIM --T 0.4 --nfe 5 --fid-save-name cifar10-nfe5 --port 6024 --num-sample 50000 --batch-size 1000 49 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=FID_EVAL/cifar10-nfe5 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 50 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe5 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 51 | 52 | #NFE=10 53 | CUDA_VISIBLE_DEVICES=1 python sampling.py --n-gpu-per-node 1 --ckpt cifar10_ODE/latest.pt --pred-x1 --solver gDDIM --T 0.7 --nfe 10 --fid-save-name cifar10-nfe10 --port 6024 --num-sample 50000 --batch-size 1000 54 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe10 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 55 | 56 | #NFE=20 57 | CUDA_VISIBLE_DEVICES=0 python sampling.py --n-gpu-per-node 1 --ckpt cifar10_ODE/latest.pt --pred-x1 --solver gDDIM --T 0.9 --nfe 20 --fid-save-name cifar10-nfe20 --port 6024 --num-sample 50000 --batch-size 1000 58 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe20 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 59 | # CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe20 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 60 | 61 | # 
https://drive.google.com/file/d/1H92Bgz26hLajYNtcY7zPtI7y9HFLeYtD/view?usp=sharing 62 | # https://drive.google.com/file/d/1u26_iWWaBSW8hXMnAudJB90Awolabta4/view?usp=sharing -------------------------------------------------------------------------------- /scripts/toy.sh: -------------------------------------------------------------------------------- 1 | python train.py --name prob-ode/varv10-p1 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 1 --varv 3 --p 1 --k 0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode 2 | 3 | # python train.py --name prob-ode/varv7 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 1 --varv 7 --k 0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode 4 | 5 | # python train.py --name prob-ode/varv5 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 1 --varv 5 --k 0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode 6 | 7 | # python train.py --name prob-ode/varv3 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 1 --varv 3 --k 0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode 8 | 9 | # python train.py --name prob-ode/varv1 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 1 --varv 1 --k 0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode 10 | 11 | 12 | # python train.py --name prob-ode/varx2-varv10 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 2 13 | # --varv 10 --k 0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode 14 | 15 | # python train.py --name prob-ode/varx4-varv10 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 4 --varv 10 --k 0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode 16 | 17 | # python train.py --name prob-ode/varx4-varv6 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 4 18 | # --varv 6 --k 
0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode 19 | 20 | # #Training 21 | # python train.py --name train-toy/ode-spiral --exp toy --toy-exp spiral --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --ode --sampling vanillaODE --reweight reciprocal 22 | 23 | # python train.py --name train-toy/ode-gmm --exp toy --toy-exp gmm --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --ode --sampling vanillaODE --reweight reciprocal 24 | 25 | # #Sampling 20 NFE 26 | # python train.py --name eval-toy/ode-spiral --exp toy --toy-exp spiral --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --ode --sampling vanillaODE --reweight reciprocal --ckpt train-toy/ode-spiral --eval --interval 20 27 | 28 | # python train.py --name eval-toy/ode-gmm --exp toy --toy-exp gmm --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --ode --sampling vanillaODE --reweight reciprocal --ckpt train-toy/ode-gmm --eval --interval 20 29 | 30 | # #Training 31 | # python train.py --name train-toy/sde-spiral --exp toy --toy-exp spiral --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --sampling sscs --reweight reciprocal 32 | 33 | # python train.py --name train-toy/sde-gmm --exp toy --toy-exp gmm --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --sampling sscs --reweight reciprocal 34 | 35 | # #Sampling 20 NFE 36 | # python train.py --name eval-toy/sde-spiral --exp toy --toy-exp spiral --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --sampling sscs --reweight reciprocal --ckpt train-toy/sde-spiral --eval --interval 20 37 | 38 | # python train.py --name eval-toy/sde-gmm --exp toy --toy-exp gmm --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --sampling sscs --reweight reciprocal --ckpt train-toy/sde-gmm --eval --interval 20 39 | 40 | 41 | 42 | 43 | #Source Command 44 | # python train.py --name test-ode --exp toy --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 
0.8 --ode --sampling vanillaODE --reweight reciprocal 45 | -------------------------------------------------------------------------------- /setup/conda_install.sh: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | pip install --upgrade pip 6 | pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu118 7 | pip install torch_ema 8 | pip install pytorch_warmup 9 | 10 | python setup/download_datasets.py 11 | pip install -U pytorch_warmup 12 | pip install --upgrade wandb 13 | # AFHQ download 14 | URL=https://www.dropbox.com/s/vkzjokiwof5h8w6/afhq_v2.zip?dl=0 15 | ZIP_FILE=./dataset/afhq_v2.zip 16 | mkdir -p ./dataset/afhqv2 17 | # wget -N $URL -O $ZIP_FILE 18 | wget --wait 10 --random-wait --continue -N $URL -O $ZIP_FILE 19 | unzip $ZIP_FILE -d ./dataset/afhqv2 20 | rm $ZIP_FILE 21 | # AFHQ download 22 | python edm/dataset_tool.py --source=dataset/afhqv2 --dest=dataset/afhqv2-64x64.zip --resolution=64x64 -------------------------------------------------------------------------------- /setup/download_datasets.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import torchvision.datasets as datasets 6 | datasets.CIFAR10( 7 | './dataset', 8 | train= True, 9 | download=True, 10 | ) -------------------------------------------------------------------------------- /setup/environments.yml: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
# Environment + dataset bootstrap. Run from the repository root: every path
# below (setup/, dataset/, edm/) is relative to it.
pip install --upgrade pip
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install -r setup/requirement.txt
# The downloader lives in setup/, next to this script.
python setup/download_datasets.py
pip install -U pytorch_warmup
pip install --upgrade wandb
# AFHQv2 download
URL="https://www.dropbox.com/s/vkzjokiwof5h8w6/afhq_v2.zip?dl=0"
ZIP_FILE=./dataset/afhq_v2.zip
mkdir -p ./dataset/afhqv2
# Quoted so the '?' in the URL is never treated as a glob character.
wget --wait 10 --random-wait --continue -N "$URL" -O "$ZIP_FILE"
unzip "$ZIP_FILE" -d ./dataset/afhqv2
rm "$ZIP_FILE"
# Convert the extracted AFHQv2 images into the 64x64 zip archive
python edm/dataset_tool.py --source=dataset/afhqv2 --dest=dataset/afhqv2-64x64.zip --resolution=64x64
def create_training_options():
    """Parse the command line and derive the run options.

    The command line is parsed twice: once to learn ``--exp`` so that the
    matching per-experiment defaults can be installed, and a second time so
    that explicit flags override those defaults.

    Returns:
        argparse.Namespace with all flags plus the derived fields
        ``model_config``, ``device``, ``distributed``, ``ckpt_path``,
        ``load`` and, when the ``NGC_JOB_ID`` environment variable is set,
        ``ngc_job_id``.

    Raises:
        SystemExit: via ``parser.error`` for an experiment without a
            registered default config, a missing ``--name``, or the
            ``sscs`` solver combined with a non-SDE ``--DE-type``.
        FileNotFoundError: when ``--ckpt`` does not resolve to a file.

    Side effects:
        Creates ``opt.log_dir`` and ``RESULT_DIR / opt.name``.
    """
    # --------------- basic ---------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed",           type=int,   default=42)
    parser.add_argument("--name",           type=str,   default=None,       help="experiment ID")
    parser.add_argument("--exp",            type=str,   default='toy',      choices=['toy','cifar10','AFHQv2','imagenet64','cond-imagenet64'], help="experiment type")
    parser.add_argument("--toy-exp",        type=str,   default='gmm',      choices=['gmm','spiral'], help="experiment type")
    parser.add_argument("--ckpt",           type=str,   default=None,       help="resumed checkpoint name")
    parser.add_argument("--cond",           action="store_true",            help="whether or not use class cond")
    parser.add_argument("--gpu",            type=int,   default=None,       help="set only if you wish to run on a particular device")
    parser.add_argument("--n-gpu-per-node", type=int,   default=1,          help="number of gpu on each node")
    parser.add_argument("--master-address", type=str,   default='localhost',help="address for master")
    parser.add_argument("--node-rank",      type=int,   default=0,          help="the index of node")
    parser.add_argument("--num-proc-node",  type=int,   default=1,          help="The number of nodes in multi node env")
    parser.add_argument("--port",           type=str,   default='6022',     help="localhost port")

    # --------------- Dynamics Hyperparameters ---------------
    parser.add_argument("--n-train",        type=int,   default=5000)
    parser.add_argument("--t0",             type=float, default=1e-4,       help="Initial Time for the dynamics")
    parser.add_argument("--T",              type=float, default=0.999,      help="Terminal Time for the dynamics")
    parser.add_argument("--nfe",            type=int,   default=1000,       help="number of interval")
    parser.add_argument("--varx",           type=float, default=1.0,        help="variance of position for prior")
    parser.add_argument("--varv",           type=float, default=1.0,        help="variance of velocity for prior")
    parser.add_argument("--k",              type=float, default=0.0,        help="Correlation/Covariance of position of velocity for prior distribution")
    parser.add_argument("--p",              type=float, default=3,          help="diffusion coefficient for Time Variant value g(t)=p*(damp_t-t)")
    parser.add_argument("--damp-t",         type=float, default=1,          help="diffusion coefficient for Time Variant value g(t)=p*(damp_t-t)")
    parser.add_argument("--DE-type",        type=str,   default='probODE',  choices=['probODE','SDE'],
                        help="Choose the type of SDE, which includes Time Varing g Model Predictive Control (TVgMPC) , "
                             "Time Invariant g Model Predictive Control (TIVgMPC),Time Invariant g Flow Matching (TIVgFM)")
    # --------------- optimizer and loss ---------------
    parser.add_argument("--microbatch",     type=int,   default=512,        help="mini batch size for gradient descent")
    parser.add_argument("--num-itr",        type=int,   default=50000,      help="number of training iteration")
    parser.add_argument("--lr",             type=float, default=1e-3,       help="learning rate")
    parser.add_argument("--ema",            type=float, default=0.9999,     help='ema decay rate')
    parser.add_argument("--l2-norm",        type=float, default=0,          help='L2 norm for optimizer')
    parser.add_argument("--t-samp",         type=str,   default='uniform',  choices=['uniform','debug'],
                        help="the way to sample t during sampling")
    parser.add_argument("--precond",        action="store_true",            help="preconditioning for the network output")
    parser.add_argument("--clip-grad",      type=float, default=None,       help="whether to clip the gradient.")
    parser.add_argument("--xflip",          action="store_true",            help="Whether flip the dataset in x-horizon")
    parser.add_argument("--reweight-type",  type=str,   default='ones',     choices=['ones','reciprocal','reciprocalinv'], help="How to reweight the training")
    # --------------- sampling and evaluating ---------------
    parser.add_argument("--train-fid-sample",type=int,  default=None,       help="number of samples used for evaluating FID")
    parser.add_argument("--sampling-batch", type=int,   default=512,        help="mini batch size for gradient descent")
    parser.add_argument("--eval",           action="store_true",            help="evaluating mode. Wont save ckpt")
    parser.add_argument("--clip-x1",        action="store_true",            help="similar to DDPM, clip the estimiated data to [-1,1]")
    parser.add_argument("--debug",          action="store_true",            help="Using single GPU to evaluate. raise this flag for fast testing")
    parser.add_argument("--pred-x1",        action="store_true",            help="Using predict x1 as the sampling output")
    parser.add_argument("--gDDIM-r",        type=int,   default=2,          help="gDDIM solver parameter r (passed to the sampler as gDDIM_r)")
    parser.add_argument("--diz-order",      type=int,   default=2,          help="discretization order (passed to the sampler as diz_order)")
    parser.add_argument("--solver",         type=str,   default='em',       help="sampler")
    parser.add_argument("--diz",            type=str,   default='Euler',    choices=['Euler','sigmoid'], help="The discretization scheme")
    parser.add_argument("--sanity",         action="store_true",            help="quick sanity check for the proposed dyanmics")
    # --------------- path and logging ---------------
    parser.add_argument("--log-dir",        type=Path,  default=".log",     help="path to log std outputs and writer data")
    parser.add_argument("--log-writer",     type=str,   default=None,       help="log writer: can be tensorbard, wandb, or None")
    parser.add_argument("--wandb-api-key",  type=str,   default=None,       help="unique API key of your W&B account; see https://wandb.ai/authorize")
    parser.add_argument("--wandb-user",     type=str,   default=None,       help="user name of your W&B account")

    # First pass: only needed to find out which experiment's defaults apply.
    config_getters = {
        'toy':          toy_config.get_toy_default_configs,
        'cifar10':      cifar10_config.get_cifar10_default_configs,
        'AFHQv2':       afhqv2_config.get_afhqv2_default_configs,
        'imagenet64':   imagenet64_config.get_imagenet64_default_configs,
    }
    exp = parser.parse_args().exp
    get_config = config_getters.get(exp)
    if get_config is None:
        # Accepted by --exp's choices but without a config entry; fail with a
        # readable message rather than "'NoneType' object is not callable".
        parser.error("no default config is registered for --exp {}".format(exp))
    default_config, model_configs = get_config()
    parser.set_defaults(**default_config)

    # Second pass: explicit flags now override the experiment defaults.
    opt = parser.parse_args()
    opt.model_config=model_configs

    # ========= auto setup =========
    opt.device ='cuda' if opt.gpu is None else f'cuda:{opt.gpu}'
    opt.distributed = opt.n_gpu_per_node > 1

    if opt.solver=='sscs' and opt.DE_type!='SDE':
        parser.error("--solver sscs requires --DE-type SDE")

    # log ngc meta data
    if "NGC_JOB_ID" in os.environ:
        opt.ngc_job_id = os.environ["NGC_JOB_ID"]

    # ========= path handle =========
    if opt.name is None:
        parser.error("--name is required (it names the results directory)")
    os.makedirs(opt.log_dir, exist_ok=True)
    opt.ckpt_path = RESULT_DIR / opt.name
    os.makedirs(opt.ckpt_path, exist_ok=True)

    if opt.train_fid_sample is None:
        opt.train_fid_sample = opt.n_gpu_per_node*opt.microbatch

    if opt.ckpt is not None:
        # --ckpt may name the run directory or point straight at a
        # checkpoint file inside RESULT_DIR.
        ckpt_file = RESULT_DIR / opt.ckpt
        if not ckpt_file.is_file():
            ckpt_file = ckpt_file / "latest.pt"
        if not ckpt_file.exists():
            raise FileNotFoundError("checkpoint not found: {}".format(ckpt_file))
        opt.load = ckpt_file
    else:
        opt.load = None

    return opt
150 | if opt.seed is not None: 151 | set_seed(opt.seed + opt.global_rank) 152 | 153 | # build dataset 154 | if opt.exp=='toy': 155 | train_loader = spiral.spiral_data(opt) 156 | elif opt.exp=='cifar10': 157 | train_loader = cifar10.cifar10_data(opt) 158 | if opt.cond: opt.cond_dim = 10 159 | elif opt.exp=='AFHQv2': 160 | train_loader = AFHQv2.AFHQv2_data(opt) 161 | elif opt.exp=='imagenet64': 162 | train_loader = imagenet64.imagenet64_data(opt) 163 | if opt.cond: opt.cond_dim = 1000 164 | 165 | run = runner.Runner(opt, log) 166 | if opt.eval: 167 | run.ema.copy_to() 168 | run.evaluation(opt,0,train_loader) 169 | else: 170 | run.train(opt, train_loader) 171 | log.info("Finish!") 172 | 173 | if __name__ == '__main__': 174 | opt = create_training_options() 175 | if opt.debug: 176 | opt.distributed =False 177 | opt.global_rank = 0 178 | opt.local_rank = 0 179 | opt.global_size = 1 180 | with torch.cuda.device(opt.gpu): 181 | main(opt) 182 | else: 183 | torch.multiprocessing.set_start_method('spawn') 184 | if opt.distributed: 185 | size = opt.n_gpu_per_node 186 | processes = [] 187 | 188 | for rank in range(size): 189 | opt = copy.deepcopy(opt) 190 | opt.local_rank = rank 191 | global_rank = rank + opt.node_rank * opt.n_gpu_per_node 192 | global_size = opt.num_proc_node * opt.n_gpu_per_node 193 | opt.global_rank = global_rank 194 | opt.global_size = global_size 195 | print('Node rank %d, local proc %d, global proc %d, global_size %d' % (opt.node_rank, rank, global_rank, global_size)) 196 | 197 | p = Process(target=init_processes, args=(global_rank, global_size, main, opt)) 198 | p.start() 199 | processes.append(p) 200 | 201 | for p in processes: 202 | p.join() 203 | else: 204 | torch.cuda.set_device(0) 205 | opt.global_rank = 0 206 | opt.local_rank = 0 207 | opt.global_size = 1 208 | init_processes(0, opt.n_gpu_per_node, main, opt) 209 | --------------------------------------------------------------------------------