├── ACKNOWLEDGEMENTS
├── AM
│   ├── .DS_Store
│   ├── diffusion.py
│   ├── dynamics.py
│   ├── runner.py
│   ├── samplers.py
│   └── util.py
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── asset
│   ├── AGMvsFM.gif
│   ├── cond_gen.001.png
│   ├── cond_gen.jpg.001.jpeg
│   ├── cond_gen.pdf
│   ├── cond_gen2.001.png
│   └── sampling_hop.png.001.png
├── configs
│   ├── .DS_Store
│   ├── afhqv2_config.py
│   ├── cifar10_config.py
│   ├── imagenet64_config.py
│   └── toy_config.py
├── dataset
│   ├── AFHQv2.py
│   ├── StrokeData
│   │   ├── testFig0.png
│   │   ├── testFig0_impainting.png
│   │   ├── testFig1.png
│   │   └── testFig1_impainting.png
│   ├── __init__.py
│   ├── cifar10.py
│   ├── imagenet64.py
│   └── spiral.py
├── edm
│   ├── augment.py
│   ├── dataset.py
│   ├── dataset_tool.py
│   ├── distributed_util.py
│   ├── dnnlib
│   │   ├── __init__.py
│   │   └── util.py
│   ├── fid.py
│   ├── logger.py
│   └── torch_utils
│       ├── __init__.py
│       ├── distributed.py
│       ├── misc.py
│       ├── persistence.py
│       └── training_stats.py
├── networks
│   ├── .DS_Store
│   ├── edm
│   │   └── ncsnpp.py
│   ├── get_network.py
│   └── network.py
├── plot_util.py
├── sampling.py
├── scripts
│   ├── AFHQv2.sh
│   ├── cifar10.sh
│   ├── example.sh
│   ├── imagenet64.sh
│   ├── release.sh
│   └── toy.sh
├── setup
│   ├── conda_install.sh
│   ├── download_datasets.py
│   ├── environments.yml
│   ├── requirement.txt
│   └── setup.sh
└── train.py
/AM/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/AM/.DS_Store
--------------------------------------------------------------------------------
/AM/diffusion.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import torch
7 | from .samplers import SamplerWrapper
8 | from .dynamics import TVgMPC
9 | import torch.nn.functional as F
10 |
11 | class Diffusion():
12 | def __init__(self, opt, device):
13 | self.opt = opt
14 | self.device = device
15 | self.varx = opt.varx
16 | self.varv = opt.varv
17 | dyn_kargs = {
18 |             "p":opt.p, # diffusion coefficient
19 | 'k':opt.k, # covariance of prior
20 | 'varx':opt.varx,
21 | 'varv':opt.varv,
22 | 'x_dim':opt.data_dim,
23 | 'device':opt.device, #Using sampling device
24 | 'DE_type':opt.DE_type
25 | }
26 | dynamics = TVgMPC(**dyn_kargs)
27 |
28 | '''
29 | set up dynamics solver
30 | '''
31 | solver_kargs = {
32 | "solver_name":opt.solver, #updated solver
33 | 'diz':opt.diz, #updated diz
34 | 't0':opt.t0, # original t0
35 | 'T':opt.T, # updated T
36 | 'interval':opt.nfe, #updated NFE
37 | 'dyn': dynamics, #updated dynamics
38 | 'device':opt.device, #Using sampling device
39 | 'snap': 10,
40 | 'local_rank':opt.local_rank,
41 | 'diz_order': opt.diz_order,
42 | 'gDDIM_r':opt.gDDIM_r,
43 | 'cond_opt':None
44 | }
45 |
46 | self.sampler = SamplerWrapper(**solver_kargs)
47 |         self.dyn = dynamics # reuse the dynamics instance constructed above
48 |
49 | if self.opt.local_rank ==0:
50 |             print('---------- using sampling method: {}'.format(opt.solver))
51 |
52 | def reweights(self, t):
53 | reweight_type = self.opt.reweight_type
54 | dyn = self.dyn
55 | x_dim = self.opt.data_dim
56 | if reweight_type =='ones':
57 | return torch.ones_like(t)
58 | elif reweight_type=='reciprocal':
59 | weight = 1/(1-t)
60 | return torch.sqrt(weight)
61 | elif reweight_type=='reciprocalinv':
62 | weight = 1/(t)
63 | return torch.sqrt(weight)
64 | else:
65 | raise RuntimeError
66 |
67 |
68 |
69 | def mt_sample(self,x1,ts):
70 |         """Return (label, inputs): the regression target fv and the joint state m_t = [x_t, v_t]."""
71 | opt = self.opt
72 | dyn = self.dyn
73 | x_dim = self.opt.data_dim
74 | joint_dim = self.opt.joint_dim
75 | t = ts.reshape(-1,*([1,]*len(x_dim)))
76 | analytic_xs, analytic_vs, analytic_fv =dyn.get_xt_vt_fv(t,x1,opt.DE_type,opt.device)
77 |
78 | analytic_ms = torch.cat([analytic_xs,analytic_vs],dim=1)
79 | label = analytic_fv
80 | inputs = analytic_ms
81 | return label, inputs
82 |
83 |
--------------------------------------------------------------------------------
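A minimal standalone sketch (not part of the repo) of the `reciprocal` reweighting defined in `Diffusion.reweights` above and of how `Runner.train` applies it inside the MSE loss; the batch size and 2-D data are arbitrary stand-ins:

```python
import torch
import torch.nn.functional as F

def reciprocal_weight(t):
    # matches the 'reciprocal' branch of Diffusion.reweights: sqrt(1/(1-t))
    return torch.sqrt(1.0 / (1.0 - t))

t = torch.rand(8) * 0.999              # sampled times in [0, 0.999); T=0.999 in the configs
pred = torch.randn(8, 2)               # stand-in network output (toy data_dim=[2])
label = torch.randn(8, 2)              # stand-in regression target
lam = reciprocal_weight(t)[:, None]    # broadcast the per-time weight over the data dims
loss = F.mse_loss(lam * pred, lam * label)  # same weighting pattern as Runner.train
print(float(loss))
```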
/AM/dynamics.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import abc
7 | import torch
8 | import numpy as np
9 | from .util import cast_shape
10 | class BaseDynamics(metaclass=abc.ABCMeta):
11 | def __init__(self, p,k,varx,varv,x_dim,device):
12 | self.p = p
13 | self.k = k
14 | self.varx = varx
15 | self.varv = varv
16 | self.device = device
17 | self.x_dim = x_dim
18 | self.normalizer = self.get_normalizer()
19 | @abc.abstractmethod
20 | def g2P10(self, t):
21 | '''
22 | 1-0 element of the Solution of Lyapunov function P matrix.
23 | '''
24 | raise NotImplementedError
25 |
26 | @abc.abstractmethod
27 | def g2P11(self, t):
28 | '''
29 | 1-1 element of the Solution of Lyapunov function P matrix.
30 | '''
31 | raise NotImplementedError
32 |
33 | @abc.abstractmethod
34 | def sigmaxx(self, t):
35 | '''
36 | Covariance matrix xx component
37 | '''
38 | raise NotImplementedError
39 |
40 | @abc.abstractmethod
41 | def sigmavx(self, t):
42 | '''
43 | Covariance matrix xv component
44 | '''
45 | raise NotImplementedError
46 |
47 | @abc.abstractmethod
48 | def sigmavv(self, t):
49 | '''
50 | Covariance matrix vv component
51 | '''
52 | raise NotImplementedError
53 |
54 | @abc.abstractmethod
55 | def g(self, t):
56 | '''
57 | diffusion coefficient, can be time variant or time invariant
58 | '''
59 | raise NotImplementedError
60 |
61 | @abc.abstractmethod
62 | def score(self, t):
63 | raise NotImplementedError
64 |
65 | @abc.abstractmethod
66 | def get_analytic_mux_muv(self, t,x0,v0,x1):
67 | '''
68 |         compute the analytic means of x and v at time t given initial x0, v0 and target x1
69 | '''
70 | raise NotImplementedError
71 |
72 | @abc.abstractmethod
73 | def mux0_muv0(self):
74 | '''
75 | mean of position and velocity at initial boundary
76 | '''
77 | raise NotImplementedError
78 |
79 | @abc.abstractmethod
80 | def get_normalizer(self):
81 | raise NotImplementedError
82 |
83 | def get_cov(self,t,damp=0):
84 | '''
85 |         Compute the Cholesky factors Lxx, Lxv, Lvv of the covariance and the coefficient ell
86 | '''
87 | x_dim = self.x_dim
88 | t = t.double()
89 | sigxx = cast_shape(self.sigmaxx(t),x_dim)+damp
90 | sigxv = cast_shape(self.sigmavx(t),x_dim)
91 | sigvv = cast_shape(self.sigmavv(t),x_dim)+damp
92 | ellt = cast_shape(-torch.sqrt(sigxx/(sigxx*sigvv-sigxv**2)),x_dim)
93 | Lxx = torch.sqrt(sigxx)
94 | Lxv = sigxv/Lxx
95 | tmp = sigvv-Lxv**2
96 | invalid_idx = torch.logical_and(tmp<0, torch.isclose(tmp,torch.zeros_like(tmp)))
97 |         tmp[invalid_idx] = 0
98 |
99 | Lvv = torch.sqrt(tmp)
100 |
101 | return Lxx.float(),Lxv.float(),Lvv.float(), ellt.float()
102 |
103 | def get_xt_vt_fv(self,t,x1,DE_type,device,return_fv=True):
104 | '''
105 | Compute the input and label for the network
106 | '''
107 | # opt = self.opt
108 | joint_dim = [value*2 if idx==0 else value for idx,value in enumerate(self.x_dim)]
109 | batch_x = t.shape[0]
110 | mux0,muv0 = self.mux0_muv0(batch_x)
111 | muxt,muvt = self.get_analytic_mux_muv(t,mux0,muv0,x1=x1)
112 | Lxx,Lxv,Lvv,ell = self.get_cov(t)
113 | noise = torch.randn(batch_x, *joint_dim,device=device)
114 | assert noise.shape[0] == t.shape[0]
115 | epsxx,epsvv = torch.chunk(noise,2,dim=1)
116 | _g2P10 = self.g2P10(t)
117 | _g2P11 = self.g2P11(t)
118 | analytic_xs = muxt+Lxx*epsxx
119 | analytic_vs = muvt+(Lxv*epsxx+Lvv*epsvv)
120 | if return_fv:
121 | normalization = self.normalizer(t)
122 | analytic_fv = 4*x1*(t-1)**2- _g2P11*((Lxx/(1-t)+Lxv)*epsxx+Lvv*epsvv)
123 |
124 | score = self.score(t,ell,epsvv) if DE_type=='probODE' else 0
125 |
126 | analytic_fv = (analytic_fv+score)/normalization
127 |             # ========= normalize the label to a standard gaussian =================
128 |
129 | return analytic_xs, analytic_vs, analytic_fv
130 |
131 | else:
132 | return analytic_xs, analytic_vs
133 |
134 |
135 |
136 |
137 | class TVgMPC(BaseDynamics):
138 | '''
139 |     TVgMPC: dynamics with a time-variant diffusion coefficient g(t)
140 | '''
141 | def __init__(self, p,k,varx,varv,x_dim,device,DE_type):
142 | super(TVgMPC,self).__init__(p,k,varx,varv,x_dim,device)
143 | self.DE_type = DE_type
144 | def g2P10(self, t):
145 | return 6/(-1+t)**2
146 |
147 | def g2P11(self, t):
148 | return -4/(-1+t)
149 |
150 | def g(self, t):
151 | tt = 1
152 | return self.p*(tt-t)
153 |
154 | def sigmaxx(self, t):
155 | m,n = self.varx, self.varv
156 | k = self.k
157 | p = self.p
158 | tt = 1
159 | val =(t - 1)**2*(30*m*(t**3 - 3*t**2 + 3*t + 3)**2 - 60*p**2*(t - 1)**3*torch.log(1 - t) - t*(60*k*np.sqrt(m*n)*(t**5 - 6*t**4 + 15*t**3 - 15*t**2 + 9) - 30*n*t*(t**2 - 3*t + 3)**2 + p**2*(t**5*(6*tt**2 + 3*tt + 1) - 6*t**4*(6*tt**2 + 3*tt + 1) + 15*t**3*(6*tt**2 + 3*tt + 1) - 10*t**2*(9*tt**2 + 11) + 150*t - 60)))/270
160 | return val
161 |
162 | def sigmavx(self, t):
163 | m,n = self.varx, self.varv
164 | p = self.p
165 | k = self.k
166 | tt = 1
167 | val =(1/270 - t/270)*(30*k*np.sqrt(m*n)*(8*t**6 - 48*t**5 + 120*t**4 - 135*t**3 + 45*t**2 + 27*t - 9) + 150*p**2*(t - 1)**3*torch.log(1 - t) + t*(-120*m*(t**5 - 6*t**4 + 15*t**3 - 15*t**2 + 9) - 30*n*(4*t**5 - 24*t**4 + 60*t**3 - 75*t**2 + 45*t - 9) + p**2*(4*t**5*(6*tt**2 + 3*tt + 1) - 24*t**4*(6*tt**2 + 3*tt + 1) + 60*t**3*(6*tt**2 + 3*tt + 1) - 5*t**2*(81*tt**2 + 18*tt + 55) + 15*t*(9*tt**2 + 25) - 150)))
168 | return val
169 |
170 |
171 |
172 | def sigmavv(self, t):
173 | m,n = self.varx, self.varv
174 | p = self.p
175 | k = self.k
176 | tt = 1
177 | val= n*(-4*t**3 + 12*t**2 - 12*t + 3)**2/9 - 8*p**2*(t - 1)**3*torch.log(1 - t)/9 + t*(-120*k*np.sqrt(m*n)*(4*t**5 - 24*t**4 + 60*t**3 - 75*t**2 + 45*t - 9) + 240*m*t*(t**2 - 3*t + 3)**2 + p**2*(-8*t**5*(6*tt**2 + 3*tt + 1) + 48*t**4*(6*tt**2 + 3*tt + 1) - 120*t**3*(6*tt**2 + 3*tt + 1) + 5*t**2*(180*tt**2 + 72*tt + 53) - 15*t*(36*tt**2 + 9*tt + 20) + 135*tt**2 + 120))/135
178 | return val
179 |
180 | def get_normalizer(self):
181 | '''
182 |         Return the normalizer used to scale the training label into a fixed range.
183 | '''
184 | return self.normalizer
185 |
186 | def normalizer(self,t):
187 | '''
188 |         Normalize the label: to [1 -> 0] when predicting x1, or to a standard Gaussian when predicting the noise
189 | '''
190 | _g2P10 = self.g2P10(t)
191 | _g2P11 = self.g2P11(t)
192 | _g = self.g(t)
193 | Lxx,Lxv,Lvv,ell = self.get_cov(t)
194 | Lxx = Lxx.reshape_as(t)
195 | Lxv = Lxv.reshape_as(t)
196 | Lvv = Lvv.reshape_as(t)
197 | ell = ell.reshape_as(t)
198 | # if len(t.shape)==4:
199 | # debug()
200 |
201 | if self.DE_type=='probODE':
202 | norm =torch.sqrt(((_g2P11*((-1/(1-t)*Lxx-Lxv)))**2+(_g2P11*Lvv-0.5*_g**2*ell)**2))/(1-t)
203 | return norm
204 | else:
205 | damp=0
206 | return torch.sqrt(((_g2P11*((-1/(1-t)*Lxx-Lxv)))**2+(_g2P11*Lvv)**2)+damp)/(1-t)
207 |
208 |
209 | def get_analytic_mux_muv(self, t,x0,v0,x1):
210 | bs = x0.shape[0]
211 | if x1 is None:
212 | return self.mux0_muv0(bs)
213 | else:
214 | muv =v0*(-4*t**3/3 + 4*t**2 - 4*t + 1) + t*(-4*x0/3 + 4*x1/3)*(t**2 - 3*t + 3)
215 | mux =-x0*(t**4 - 4*t**3 + 6*t**2 - 3)/3 + t*(-v0*(t**3 - 4*t**2 + 6*t - 3) + x1*t*(t**2 - 4*t + 6))/3
216 | return mux,muv
217 |
218 | def mux0_muv0(self,bs):
219 | muv0 = torch.zeros(bs, *self.x_dim,device=self.device)
220 | mux0 = torch.zeros(bs, *self.x_dim,device=self.device)
221 | return mux0, muv0
222 |
223 | def score(self,t,ell,epsvv):
224 | _g = self.g(t)
225 | return -0.5*_g**2*ell*epsvv
226 |
227 | def get_m0(self,bs):
228 | joint_dim = [value*2 if idx==0 else value for idx,value in enumerate(self.x_dim)]
229 |
230 | mux0,muv0 = self.mux0_muv0(bs)
231 | t = torch.zeros(bs,1,device=self.device)
232 | Lxx,Lxv,Lvv,ell = self.get_cov(t)
233 | noise = torch.randn(bs, *joint_dim,device=self.device)
234 | assert noise.shape[0] == t.shape[0]
235 | epsxx,epsvv = torch.chunk(noise,2,dim=1)
236 | analytic_x0 = mux0+Lxx*epsxx
237 | analytic_v0 = muv0+(Lxv*epsxx+Lvv*epsvv)
238 | return torch.cat([analytic_x0, analytic_v0],dim=1)
--------------------------------------------------------------------------------
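A hedged usage sketch of the `TVgMPC` dynamics in the 2-D toy setting, assuming the repo root is on `PYTHONPATH` and its dependencies are installed; it mirrors how `Diffusion.mt_sample` builds the network input and the regression label:

```python
import torch
from AM.dynamics import TVgMPC

device = 'cpu'
# parameters taken from configs/toy_config.py
dyn = TVgMPC(p=3, k=0.2, varx=1, varv=1, x_dim=[2], device=device, DE_type='SDE')

bs = 4
x1 = torch.randn(bs, 2)            # target samples (the spiral data in the toy experiment)
t = torch.rand(bs, 1) * 0.999      # times in [0, T) with T=0.999 as in the configs
xt, vt, fv = dyn.get_xt_vt_fv(t, x1, dyn.DE_type, device)
mt = torch.cat([xt, vt], dim=1)    # joint network input, as built in Diffusion.mt_sample
print(mt.shape, fv.shape)          # torch.Size([4, 4]) torch.Size([4, 2])
```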
/AM/runner.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | import os
6 | import numpy as np
7 | import pickle
8 | import torch
9 | import pytorch_warmup as warmup
10 | import torch.nn.functional as F
11 | from torch.optim import AdamW, lr_scheduler
12 | from torch.nn.parallel import DistributedDataParallel as DDP
13 | from edm import dnnlib
14 | from torch_ema import ExponentialMovingAverage
15 | import torchvision.utils as tu
16 | # from .get_network import get_nn
17 | from networks.get_network import get_nn
18 | from . import util
19 | import plot_util
20 | from .diffusion import Diffusion
21 | from plot_util import norm_data, plot_plt,plot_scatter
22 | from .util import all_cat_cpu
23 | from edm.fid import calculate_fid_from_inception_stats, calculate_inception_stats
24 | from sampling import loop_saving_png
25 |
26 |
27 | def build_optimizer_sched(opt, net, log):
28 | optim_dict = {"lr": opt.lr, 'weight_decay': opt.l2_norm}
29 | optimizer = AdamW(net.parameters(), **optim_dict)
30 | log.info(f"[Opt] Built AdamW optimizer {optim_dict=}!")
31 |
32 | sched = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.num_itr)
33 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
34 |     log.info("[Opt] Built cosine annealing lr scheduler with linear warmup!")
35 |
36 |
37 | if opt.load:
38 | checkpoint = torch.load(opt.load, map_location="cpu")
39 | if "optimizer" in checkpoint.keys():
40 | optimizer.load_state_dict(checkpoint["optimizer"])
41 | log.info(f"[Opt] Loaded optimizer ckpt {opt.load}!")
42 | else:
43 | log.warning(f"[Opt] Ckpt {opt.load} has no optimizer!")
44 | if sched is not None and "sched" in checkpoint.keys() and checkpoint["sched"] is not None:
45 | sched.load_state_dict(checkpoint["sched"])
46 | log.info(f"[Opt] Loaded lr sched ckpt {opt.load}!")
47 | else:
48 | log.warning(f"[Opt] Ckpt {opt.load} has no lr sched!")
49 | for g in optimizer.param_groups:
50 | g['lr'] = opt.lr
51 |
52 | return optimizer, sched,warmup_scheduler
53 |
54 |
55 |
56 | class Runner(object):
57 | def __init__(self, opt, log, save_opt=True):
58 | super(Runner,self).__init__()
59 |
60 | # ===========Save opt. ===========
61 | if save_opt:
62 | opt_pkl_path = opt.ckpt_path / "options.pkl"
63 | with open(opt_pkl_path, "wb") as f:
64 | pickle.dump(opt, f)
65 | log.info("Saved options pickle to {}!".format(opt_pkl_path))
66 |
67 | self.diffusion = Diffusion(opt, opt.device)
68 | if opt.exp!='toy':
69 | ref_file_name ={'cifar10':'https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz',
70 | 'AFHQv2': 'https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz',
71 | 'imagenet64': 'https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/imagenet-64x64.npz',
72 | }.get(opt.exp)
73 |
74 | with dnnlib.util.open_url(ref_file_name) as f:
75 | self.ref = dict(np.load(f))
76 |
77 | self.net = get_nn(opt,self.diffusion.dyn)
78 | log.info('network size [{}]'.format(util.count_parameters(self.net)))
79 | self.ema = ExponentialMovingAverage(self.net.parameters(), decay=opt.ema)
80 | self.opt = opt
81 | self.reweight = self.diffusion.reweights
82 | self.ts_sampler = {'uniform':util.uniform_ts,
83 | 'debug':util.debug_ts,
84 | }.get(opt.t_samp)
85 | if opt.load:
86 | checkpoint = torch.load(opt.load, map_location="cpu")
87 | self.net.load_state_dict(checkpoint['net'])
88 | log.info(f"[Net] Loaded network ckpt: {opt.load}!")
89 | self.ema.load_state_dict(checkpoint["ema"])
90 | log.info(f"[Ema] Loaded ema ckpt: {opt.load}!")
91 |
92 | self.net.to(opt.device)
93 | self.ema.to(opt.device)
94 | self.log = log
95 | self.best_res = np.inf
96 |
97 |
98 | def train(self, opt, train_loader):
99 | self.writer = util.build_log_writer(opt)
100 | log = self.log
101 |
102 | if opt.distributed:
103 | net = DDP(self.net, device_ids=[opt.device])
104 | else:
105 | net=self.net
106 |
107 | ema = self.ema
108 | optimizer, sched, warmup = build_optimizer_sched(opt, net, log)
109 | ts_sampler, reweight = self.ts_sampler, self.reweight
110 | t0,T,device = opt.t0, opt.T, opt.device
111 |
112 | net.train()
113 |
114 | for it in range(opt.num_itr):
115 | optimizer.zero_grad(set_to_none=True)
116 | # ===== sample boundary pair =====
117 |             x1, class_cond = train_loader.sample()
118 |
119 | # ===== compute loss =====
120 | _ts = ts_sampler(t0,T,x1.shape[0],device)
121 | label,mt= self.diffusion.mt_sample(x1,ts=_ts)
122 | lambdat = reweight(_ts)[:,None]
123 | lambdat = lambdat.reshape(-1,*([1,]*(len(x1.shape)-1)))
124 | pred = net(mt, _ts,cond=class_cond)
125 | label = label.reshape_as(pred)
126 | _pred = pred.detach().cpu() #for rendering loss over time
127 | _label = label.detach().cpu() #for rendering loss over time
128 | loss = F.mse_loss(lambdat*pred, lambdat*label)
129 | loss.backward()
130 |
131 | if opt.clip_grad is not None:
132 | torch.nn.utils.clip_grad_norm_(net.parameters(), opt.clip_grad)
133 | optimizer.step()
134 | ema.update()
135 |
136 | if sched is not None:
137 | with warmup.dampening():
138 | sched.step()
139 |
140 | # # -------- logging --------
141 | log.info("train_it {}/{} | lr:{} | loss:{}".format(
142 | 1+it,
143 | opt.num_itr,
144 | "{:.2e}".format(optimizer.param_groups[0]['lr']),
145 | "{:+.4f}".format(loss.item()),
146 | ))
147 | if it % 10 == 0:
148 | self.writer.add_scalar(it, 'loss', loss.detach())
149 |
150 | #============monitoring the loss distribution over time=======
151 | if it% 1000 ==0 and self.writer is not None:
152 | total_idxs = torch.arange(0,opt.nfe).float()
153 | total_vals = torch.zeros_like(total_idxs).float()
154 | _lambdat = lambdat.detach().cpu()
155 | if self.opt.exp=='toy':
156 | _loss = (((_pred*_lambdat-_label*_lambdat)**2).sum(-1)).float()
157 | else:
158 | _loss = (((_pred*_lambdat-_label*_lambdat)**2).sum(-1).sum(-1).sum(-1)).float()
159 | total_vals[(_ts*opt.nfe).long()] = _loss/(_pred.numel()/_pred.shape[0])
160 |
161 |
162 | self.writer.add_bar(it,'ts_loss',[total_idxs,total_vals])
163 | #============monitoring the loss distribution over time=======
164 |
165 | if it == 1000 or it % 10000 == 0:
166 | # if it >=9999 and it % 5000 == 0:
167 | net.eval()
168 | results=self.evaluation(opt, it,train_loader)
169 | net.train()
170 |
171 | if results<=self.best_res:
172 | self.best_res=results
173 |                     name = "latest.pt"
174 | util.save_ckpt(opt,log,it,self.net,ema,optimizer,sched,name) #Using Self.net to handle DDP
175 |
176 | if it%10000==0:
177 | if opt.global_rank == 0:
178 | name = "latest_it{}.pt".format(it)
179 | util.save_ckpt(opt,log,it,self.net,ema,optimizer,sched,name)
180 |
181 | if opt.distributed:
182 | torch.distributed.barrier()
183 |
184 | self.writer.close()
185 |
186 |
187 | @torch.no_grad()
188 | def evaluation(self, opt, it, loader):
189 | log = self.log
190 | log.info(f"========== Evaluation started: iter={it} ==========")
191 |
192 | def log_image(tag, img, nrow=10):
193 | self.writer.add_image(it, tag, tu.make_grid((img+1)/2, nrow=nrow,scale_each=True)) # [1,1] -> [0,1]
194 |
195 | x1,class_cond = loader.sample()
196 | if opt.exp=='toy':
197 | sampler = self.diffusion.sampler
198 | m0 = sampler.dyn.get_m0(opt.sampling_batch).to(opt.device)
199 | ms, pred_m1,est_x1s,snap_ts = sampler.solve(self.ema,self.net,m0,cond=None)
200 | plot_util.plot_toy(opt,ms,it,pred_m1,x1)
201 | pos_traj = ms[:,:-1,0:2,...]
202 | vel_traj = ms[:,:-1,2:,...]
203 | est_x1s = est_x1s.detach().cpu().numpy()
204 | plot_util.save_toy_npy_traj(opt,'itr_{}_est_x1s'.format(it),est_x1s,n_snapshot=10)
205 | else:
206 | image_dir = os.path.join(opt.ckpt_path , 'fid_train_folder')
207 | num_loop=int((opt.train_fid_sample-1)/opt.num_proc_node/opt.n_gpu_per_node/opt.sampling_batch)+1
208 |             log.info('num loops {}, batch size {}, num GPUs {}, num samples {}'.format(num_loop, opt.sampling_batch, opt.n_gpu_per_node, opt.train_fid_sample))
209 |
210 | ms, pred_m1, est_x1s,snap_ts = loop_saving_png( opt,\
211 | self.diffusion.sampler,\
212 | self.ema,\
213 | self.net,\
214 | log,\
215 | image_dir,\
216 | num_loop=num_loop,\
217 | return_last=True)
218 |
219 | mu,sigma=calculate_inception_stats( image_path=image_dir,\
220 | num_expected=opt.train_fid_sample,\
221 | seed=42,\
222 | max_batch_size=128)
223 |
224 | fid = calculate_fid_from_inception_stats(mu,sigma,self.ref['mu'], self.ref['sigma'])
225 |             log.info(f"========== FID: {fid} ==========")
226 | log.info(f"========== FID folder is at : {image_dir} ==========")
227 | self.writer.add_scalar(it, 'fid', fid)
228 | ########Visualizing resulting data########
229 | num_samp=40
230 | pred_x1 = pred_m1[:,0:3,...]
231 | pred_v1 = pred_m1[:,3:,...]
232 |
233 |
234 | gt_x1 = x1
235 | gt_ts = snap_ts
236 |
237 | if opt.log_writer is not None:
238 | _pos_traj = ms[0:5,:,0:3,...]
239 | pos_traj = _pos_traj.reshape(-1,*opt.data_dim)
240 | vel_traj = ms[0:5,:,3:,...]
241 | est_x1s = est_x1s[0:5,:,:,...]
242 | vel_traj = vel_traj.reshape(-1,*opt.data_dim)
243 | est_x1s = est_x1s.reshape(-1,*opt.data_dim)
244 | log_image("image/position", (pred_x1[0:num_samp]))
245 | log_image("image/velocity", (pred_v1[0:num_samp]))
246 | log_image("image/gt", (gt_x1[0:num_samp]))
247 | log_image("image/position_traj",pos_traj,nrow=11)
248 | log_image("image/velocity_traj",vel_traj,nrow=11)
249 | log_image("image/est_x1s",est_x1s,nrow=11)
250 | else:
251 | fn_pdf = os.path.join(opt.ckpt_path, 'itr_{}_x.png'.format(it))
252 | tu.save_image(norm_data((pred_x1+1)/2), fn_pdf, nrow = 6)
253 | fn_pdf = os.path.join(opt.ckpt_path, 'itr_{}_v.png'.format(it))
254 | tu.save_image(norm_data((pred_v1+1)/2), fn_pdf, nrow = 6)
255 | if it==0:
256 | fn_pdf = os.path.join(opt.ckpt_path, 'itr_{}_ground_truth.png'.format(it))
257 | tu.save_image(norm_data((gt_x1+1)/2), fn_pdf, nrow = 6)
258 | ########Visualizing resulting data########
259 | log.info(f"========== Evaluation finished: iter={it} ==========")
260 | torch.cuda.empty_cache()
261 | results = self.best_res if opt.exp=='toy' else fid
262 | return results
263 |
--------------------------------------------------------------------------------
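A sketch of the FID path used in `Runner.evaluation`, assuming a folder of generated PNGs already exists at a hypothetical `image_dir`; the function signatures follow the calls in the file above:

```python
import numpy as np
from edm import dnnlib
from edm.fid import calculate_inception_stats, calculate_fid_from_inception_stats

image_dir = 'results/example/fid_train_folder'  # hypothetical folder of generated PNGs
ref_url = 'https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz'

with dnnlib.util.open_url(ref_url) as f:
    ref = dict(np.load(f))                      # reference 'mu' and 'sigma' from EDM

mu, sigma = calculate_inception_stats(image_path=image_dir,
                                      num_expected=4096,  # train_fid_sample in the configs
                                      seed=42,
                                      max_batch_size=128)
fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma'])
print(f'FID: {fid:.2f}')
```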
/AM/samplers.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | from tqdm import tqdm
6 | import torch
7 | from . import util
8 | import numpy as np
9 |
10 |
11 | class SamplerWrapper:
12 | def __init__(self,**kwargs):
13 | solver_name = kwargs['solver_name']
14 | diz = kwargs['diz']
15 | t0 = kwargs['t0']
16 | T = kwargs['T']
17 | interval = kwargs['interval']
18 | device = kwargs['device']
19 | dyn = kwargs['dyn']
20 | diz_order = kwargs['diz_order']
21 | cond_opt = kwargs['cond_opt']
22 | self.kwargs = kwargs
23 |
24 | self.sampler = get_solver_fn(solver_name)
25 | tsdts = get_discretizer(diz,t0,T,interval,device,diz_order)
26 | self.kwargs['ts_dts'] = tsdts
27 | DE_type = 'probODE'
28 | self.dyn = dyn
29 |
30 | remove_keys=['diz','solver_name','t0','T','interval','device','diz_order']
31 | if solver_name=='gDDIM':
32 | self.r = self.kwargs['gDDIM_r']
33 | ts,dts = tsdts
34 | coef = AB_fn(dyn.normalizer,DE_type,ts,dts,self.r)
35 | self.kwargs['coef'] = coef
36 | self.kwargs['cond_opt'] = cond_opt
37 | else:
38 | remove_keys+=['gDDIM_r']
39 | remove_keys+=['cond_opt']
40 |
41 | for key in remove_keys:
42 | del self.kwargs[key]
43 |
44 |
45 | def solve(self,ema,net,m0,cond=None):
46 | with ema.average_parameters():
47 | return self.sampler(m0,net,cond,**self.kwargs)
48 |
49 |
50 | def get_est_x1(dyn,t,_g,_fv,x,v,DE_type):
51 | _g2P11=dyn.g2P11(t)
52 | if DE_type=='probODE':
53 | Lxx,Lxv,Lvv, ell\
54 | = dyn.get_cov(t)
55 | AA = (Lxx/(1-t)+Lxv).to(torch.float64)
56 | BB = (Lvv+(0.5*_g**2*ell)/_g2P11).to(torch.float64)
57 | bb = BB/Lvv
58 | aa = (AA-bb*Lxv)/Lxx
59 | cv = 4/3*t*(3+(-3+t)*t)
60 | cx = 1/3*t**2*(6+(-4+t)*t)
61 | est_x1 = (_fv+_g2P11*(aa*x+bb*v))/(4*(t-1)**2+_g2P11*(aa*cx+bb*cv))
62 | else:
63 | est_x1 = (_fv/_g2P11+v)*(1-t)+x
64 | return est_x1
65 |
66 | def get_solver_fn(solver_name):
67 | if solver_name == 'sscs':
68 | return sscs_sampler
69 | elif solver_name == 'em':
70 | return em_sampler
71 | elif solver_name == 'gDDIM':
72 | return gDDIM_sampler
73 | else:
74 | raise NotImplementedError(
75 |             'Sampler %s is not implemented.' % solver_name)
76 |
77 |
78 | def get_discretizer(diz,t0,T,interval,device,diz_order=2):
79 | if diz =='Euler':
80 | ts = torch.linspace(t0, T, interval+1, device=device)
81 | dts = ts[1:]-ts[0:-1]
82 | ts = ts[0:-1]
83 |         last_dt=torch.Tensor([0.999-T]).to(device) # final step to reach t=0.999 and cover the full horizon
84 | dts = torch.cat([dts,last_dt],dim=0)
85 |
86 | elif diz =='quad':
87 | ts= torch.linspace(t0**2,T**2,interval+1)
88 | ts= torch.sqrt(ts)
89 | dts= ts[1:]-ts[0:-1]
90 | ts = ts[0:-1]
91 | ts = ts.to(device)
92 | dts = dts.to(device)
93 |
94 | elif diz =='rev-quad':
95 | order = diz_order
96 | ts= torch.linspace(t0**(1/order),T**(1/order),interval,dtype=torch.float64,device=device)
97 | ts= ts**order
98 | dts= ts[1:]-ts[0:-1]
99 | ts = ts
100 |         last_dt=torch.Tensor([0.999-T]).to(device) # final step to reach t=0.999 and cover the full horizon
101 | dts = torch.cat([dts,last_dt],dim=0)
102 | else:
103 | raise NotImplementedError(
104 |             'discretizer %s is not implemented.' % diz)
105 | return ts,dts
106 |
107 |
108 | def dw(x,dt):
109 | return torch.randn_like(x)*torch.sqrt(dt)
110 |
111 |
112 | def sscs_sampler(m0,drift,cond,ts_dts,dyn,snap,local_rank,return_est_x1=True):
113 |     # Analytic transition statistics for the half-steps (equivalent, variance-reduced form)
114 | def sigmaxx(p,t):
115 | return p**2*t**3*(t*(t - 5) + 10)/30
116 | def sigmavx(p,t):
117 | return p**2*t**2*(t*(t - 4) + 6)/12
118 | def sigmavv(p,t):
119 | return p**2*t*(t*(t - 3) + 3)/3
120 |
121 | def analytic_dynamics(m,t,dt):
122 | dt=dt/2
123 | delta_varxx = (sigmaxx(dyn.p*(1-t),dt)).reshape(-1,*([1,]*(len(m.shape)-1)))
124 | delta_varxv = (sigmavx(dyn.p*(1-t),dt)).reshape(-1,*([1,]*(len(m.shape)-1)))
125 | delta_varvv = (sigmavv(dyn.p*(1-t),dt)).reshape(-1,*([1,]*(len(m.shape)-1)))
126 | cholesky11 = torch.sqrt(delta_varxx)
127 | cholesky21 = (delta_varxv / cholesky11)
128 | cholesky22 = (torch.sqrt(delta_varvv - cholesky21 ** 2.))
129 | batch_randn = torch.randn_like(m, device=m.device)
130 | batch_randn_x, batch_randn_v = torch.chunk(batch_randn, 2, dim=1)
131 | noise_x = cholesky11 * batch_randn_x
132 | noise_v = cholesky21 * batch_randn_x + cholesky22 * batch_randn_v
133 | noise = torch.cat((noise_x, noise_v), dim=1)
134 | x,v=torch.chunk(m,2,dim=1)
135 | x = x+v*dt
136 | m = torch.cat([x,v],dim=1)
137 | perturbed_data = m +noise
138 | return perturbed_data
139 |
140 | def EM_dynamics(v,dyn,fv,t,normalizer):
141 | norm = (normalizer(t)).squeeze()
142 | fv = fv*norm
143 | v = v+fv*dt
144 | return v,fv
145 |
146 | assert dyn.DE_type == 'SDE'
147 | bs = m0.shape[0]
148 | ts,dts = ts_dts
149 | m0 = m0.to(torch.float64)
150 | m = m0
151 | ms = []
152 | x,v = torch.chunk(m0,2,dim=1)
153 | interval = ts.shape[0]
154 | snaps = np.linspace(0, interval-1, snap).astype(int)
155 | snapts = []
156 | if local_rank == 0:
157 | _ts = tqdm(ts,desc=util.yellow("Propagating Dynamics..."))
158 | else:
159 | _ts = ts
160 |
161 |
162 |
163 | est_x1s = []
164 |     normalizer = dyn.get_normalizer()
165 |
166 | x,v = torch.chunk(m0,2,dim=1)
167 | for idx,(t,dt) in enumerate(zip(_ts,dts)):
168 | _t= t.repeat(bs)
169 | m = analytic_dynamics(m,_t,dt)
170 | x,v = torch.chunk(m,2,dim=1)
171 |
172 | fv = drift(m.to(torch.float32),_t.to(torch.float32),cond=cond).to(torch.float32)
173 | #============EM step=============
174 | v,_fv = EM_dynamics(v,dyn,fv,t,normalizer)
175 | #============EM step=============
176 | m = torch.cat([x,v],dim=1)
177 | m = analytic_dynamics(m,_t,dt)
178 |
179 | if idx in snaps:
180 | ms.append(m[:,None,...])
181 | snapts.append(t[None])
182 | if return_est_x1:
183 | _g2P11=dyn.g2P11(t)
184 | est_x1 = (_fv/_g2P11+v)*(1-t)+x
185 | est_x1s.append(est_x1[:,None,...])
186 |
187 | xT = x+v*dt
188 | mT = torch.cat([xT,v],dim=1)
189 | ms.append(mT[:,None,...])
190 | if return_est_x1: est_x1s.append(est_x1[:,None,...])
191 | snapts.append(t[None])
192 | return torch.cat(ms,dim=1),mT, torch.cat(est_x1s,dim=1), torch.cat(snapts,dim=0)
193 |
194 |
195 | def em_sampler(m0,drift,cond,ts_dts,dyn,snap,local_rank,return_est_x1=True):
196 | DE_type = dyn.DE_type
197 | bs = m0.shape[0]
198 | rank = local_rank
199 | ts,dts = ts_dts
200 | interval= ts.shape[0]
201 | if rank == 0:
202 | times_horizon = zip(tqdm(ts,desc=util.blue("Propagating Dynamics..."),position=0,leave=False,colour='blue'),dts)
203 | else:
204 | times_horizon=zip(ts,dts)
205 | m = m0
206 | ms = []
207 | x,v = torch.chunk(m0,2,dim=1)
208 | snaps = np.linspace(0, interval-1, snap).astype(int)
209 | snapts = []
210 |     normalizer = dyn.get_normalizer()
211 |
212 | if return_est_x1: est_x1s = []
213 |
214 | for idx,(t,dt) in enumerate(times_horizon):
215 | _t = t.repeat(bs)
216 | _g = dyn.g(t)
217 | fv = drift(m,_t,cond)
218 | _g2P11 = dyn.g2P11(t)
219 |
220 | norm = (normalizer(t)).squeeze()
221 | fv = fv*norm
222 | m = torch.cat([x,v],dim=1)
223 |
224 | #=========dyn propagation===========
225 | x = x+v*dt
226 |         _dw = dw(v,dt) if DE_type == 'SDE' else torch.zeros_like(v) # avoid shadowing the dw() helper
227 |         v = v+fv*dt+_g*_dw
228 | #=========dyn propagation===========
229 |
230 | if idx in snaps:
231 | ms.append(m[:,None,...])
232 | snapts.append(t[None])
233 | if return_est_x1:
234 | est_x1 = get_est_x1(dyn,t,_g,fv,x,v,DE_type)
235 | est_x1s.append(est_x1[:,None,...])
236 |
237 | mT = m
238 |
239 | ms.append(mT[:,None,...])
240 |
241 |     if return_est_x1: est_x1s.append(est_x1[:,None,...])
242 | snapts.append(t[None])
243 | return torch.cat(ms,dim=1),mT, torch.cat(est_x1s,dim=1), torch.cat(snapts,dim=0)
244 |
245 |
246 |
247 |
248 | @torch.no_grad()
249 | def gDDIM_sampler(m0,drift,cond,ts_dts,dyn,coef,snap,gDDIM_r,local_rank,return_est_x1=True,cond_opt=None):
250 | ts,dts = ts_dts
251 |     conf_flag = cond_opt is not None
252 |
253 |
254 | DE_type = dyn.DE_type
255 | assert DE_type == 'probODE'
256 | r = gDDIM_r
257 | rank = local_rank
258 | bs = m0.shape[0]
259 | m0 = m0.to(torch.float64)
260 | m = m0
261 | ms = []
262 | x,v = torch.chunk(m0,2,dim=1)
263 | interval = ts.shape[0]
264 | snaps = np.linspace(0, interval-1, snap).astype(int)
265 | snapts = []
266 |     normalizer = dyn.get_normalizer()
267 |
268 | intgral_norm = coef
269 |
270 | if rank == 0:
271 | times_horizon = zip(tqdm(ts,desc=util.blue("Propagating Dynamics..."),position=0,leave=False,colour='blue'),dts)
272 | else:
273 | times_horizon=zip(ts,dts)
274 |
275 | if return_est_x1: est_x1s = []
276 | prev_fv = []
277 |
278 |
279 | if conf_flag:
280 | stroke = cond_opt.stroke
281 | stroke_type = cond_opt.stroke_type
282 | impainting = cond_opt.impainting
283 |         cond_strength = 1.0 if impainting else 0.25
284 |
285 | cond_fn = impaint_stroke if impainting else dyn_stroke
286 | if stroke_type=='dyn-v': stroke_idx = int(cond_strength*ts.shape[0])
287 | if stroke_type=='init-v': v = 0.9*v+0.1*stroke
288 |
289 | for idx,(t,dt) in enumerate(times_horizon):
290 | _t = t.repeat(bs)
291 | _g = dyn.g(t)
292 | normt = (normalizer(t)).squeeze()
293 |
294 | # =======Conditional generation ==========
295 | if conf_flag and idx==0:
296 | fv = drift(m.to(torch.float32),_t.to(torch.float32),cond=cond).to(torch.float32)
297 | _fv = fv*normt
298 | est_x1 = get_est_x1(dyn,t,_g,_fv,x,v,DE_type)
299 |
300 |         if conf_flag and stroke_type=='dyn-v' and idx<stroke_idx: v = cond_fn(t,dyn,stroke,m,est_x1)
331 |
332 | if DE_type=='ODE':
333 | coef_fv = prev_fv[jj-1]/(1-t)*coef
334 | else:
335 | coef_fx = prev_fv[jj-1]*coef[0,1]
336 | coef_fv = prev_fv[jj-1]*coef[1,1]
337 |
338 | accumulated_fv += coef_fv
339 | accumulated_fx += coef_fx
340 |
341 |
342 | phit = phi_fn((t+dt)[None,None],t[None,None])[0]
343 | x = phit[0,0]*x+phit[0,1]*v+accumulated_fx
344 | v = phit[1,0]*x+phit[1,1]*v+accumulated_fv
345 | m = torch.cat([x,v],dim=1)
346 |         if len(prev_fv)<max_order+1: prev_fv.insert(0,fv)
388 |             # Lagrange extrapolation below assumes i-k>=0 and i-j>=0
389 | if k!=j:
390 | prod= prod* ((t-ts[i-k])/(ts[i-j]-ts[i-k]))[...,None]
391 | #=====time coef=========
392 | phi_matrix = phi_fn(ts[i]+dts[i],t).to(device)
393 | if DE_type=='ode':
394 | return phi_matrix@z@prod*(1-t)
395 | else:
396 | return phi_matrix@z@prod
397 | return _fn_r
398 |
399 | def AB_fn(normalizer,DE_type,ts,dts,r=0):
400 | intgral_norm = {}
401 | num_monte_carlo_sample = 50000
402 | for idx,(t,dt) in enumerate(zip(ts,dts)):
403 | max_order = min(idx,r)
404 | intgral_norm[idx] = {}
405 | for jj in range(max_order+1):
406 | coef_fn = extrapolate_fn(normalizer,DE_type,ts,dts,idx,j=jj,r=max_order)
407 | coef = monte_carlo_integral(coef_fn,t,t+dt,num_monte_carlo_sample)
408 | intgral_norm[idx][jj]=coef
409 | return intgral_norm
410 |
411 |
412 | def dyn_stroke(t,dyn,stroke,m,est_x1=None):
413 |
414 | currx,currv = torch.chunk(m,2,dim=1)
415 | Sigxx,Sigxv,Sigvv\
416 | = dyn.sigmaxx(t),dyn.sigmavx(t),dyn.sigmavv(t)
417 | noise = torch.randn(*m.shape,device=m.device)
418 | epsxx,epsvv = torch.chunk(noise,2,dim=1)
419 | muv = 0*(-4*t**3/3 + 4*t**2 - 4*t + 1) + t*(-4*0/3 + 4*stroke/3)*(t**2 - 3*t + 3)
420 | mux = -0*(t**4 - 4*t**3 + 6*t**2 - 3)/3 + t*(-0*(t**3 - 4*t**2 + 6*t - 3) + stroke*t*(t**2 - 4*t + 6))/3
421 | muv = muv+Sigxv/Sigxx*(currx-mux)
422 | Sigv = Sigvv-Sigxv**2/Sigxx
423 | v = muv+Sigv*epsvv
424 | return v
425 |
426 | def impaint_stroke(t,dyn,stroke,m,est_x1):
427 | mask = torch.ones_like(stroke)
428 | mask[stroke==-1]\
429 | = 0
430 | invmask = 1-mask
431 | stroke = stroke*mask+invmask*est_x1
432 | return dyn_stroke(t,dyn,stroke,m,est_x1)
433 |
434 |
--------------------------------------------------------------------------------
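A self-contained sketch of the `rev-quad` discretization from `get_discretizer` above: linear spacing in t**(1/order) mapped back by **order, plus one final step to t=0.999 so the trajectory covers the full horizon:

```python
import torch

def rev_quad_grid(t0, T, nfe, order=2, device='cpu'):
    # linear in t**(1/order), then mapped back: small steps early, larger steps late
    ts = torch.linspace(t0 ** (1 / order), T ** (1 / order), nfe,
                        dtype=torch.float64, device=device) ** order
    dts = ts[1:] - ts[:-1]
    # final step to t=0.999, as in get_discretizer
    last_dt = torch.tensor([0.999 - T], dtype=torch.float64, device=device)
    return ts, torch.cat([dts, last_dt], dim=0)

ts, dts = rev_quad_grid(t0=1e-5, T=0.9, nfe=20)   # e.g. the NFE-20 CIFAR-10 setting
print(ts[:3], dts[-1])   # small early steps, one final hop to t=0.999
```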
/AM/util.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | from prefetch_generator import BackgroundGenerator
6 | import os
7 | import warnings
8 | try:
9 | from torch.utils.tensorboard import SummaryWriter
10 | except ImportError:
11 |     warnings.warn("tensorboard is not installed; TensorBoardWriter will be unavailable")
12 | import wandb
13 | import termcolor
14 | import torch
15 | from torch.utils.data import DataLoader
16 | import matplotlib.pyplot as plt
17 | import edm.distributed_util as dist_util
18 | import abc
19 | import numpy as np
20 |
21 | class DataLoaderX(DataLoader):
22 | def __iter__(self):
23 | return BackgroundGenerator(super().__iter__())
24 | def setup_loader(dataset, batch_size, num_workers=4):
25 | loader = DataLoaderX(
26 | dataset,
27 | batch_size=batch_size,
28 | pin_memory=True,
29 | shuffle=True,
30 | persistent_workers=True,
31 | num_workers=num_workers,
32 | multiprocessing_context='spawn',
33 | drop_last=True,
34 | prefetch_factor=4,
35 | )
36 |     # loop forever so training can keep drawing batches from a fresh iterator
37 | while True:
38 | yield from loader
39 |
40 |
41 |
42 |
43 | def save_ckpt(opt,log,it,net,ema,optimizer,sched,name):
44 | torch.save({
45 | "net": net.state_dict(),
46 | "ema": ema.state_dict(),
47 | "optimizer": optimizer.state_dict(),
48 | "sched": sched.state_dict() if sched is not None else sched,
49 | }, opt.ckpt_path / name)
50 | log.info(f"Saved latest({it=}) checkpoint to {opt.ckpt_path=}!")
51 |
52 | class BaseWriter(object):
53 | def __init__(self, opt):
54 | self.rank = opt.global_rank
55 | def add_scalar(self, step, key, val):
56 | pass # do nothing
57 | def add_image(self, step, key, image):
58 | pass # do nothing
59 | def add_bar(self, step, key, image):
60 | pass
61 | def close(self): pass
62 |
63 | class WandBWriter(BaseWriter):
64 | def __init__(self, opt):
65 | super(WandBWriter,self).__init__(opt)
66 | if self.rank == 0:
67 | assert wandb.login(key=opt.wandb_api_key)
68 | wandb.init(dir=str(opt.log_dir), project="VM", entity=opt.wandb_user, name=opt.name, config=vars(opt))
69 |
70 | def add_scalar(self, step, key, val):
71 | if self.rank == 0: wandb.log({key: val}, step=step)
72 |
73 | def add_image(self, step, key, image):
74 | if self.rank == 0:
75 | # adopt from torchvision.utils.save_image
76 | image = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
77 | wandb.log({key: wandb.Image(image)}, step=step)
78 |
79 | def add_bar(self,step,key,data):
80 | if self.rank == 0:
81 | fig, ax = plt.subplots()
82 | # plt.ylim([0,5])
83 | ts,loss=data
84 | ax.bar(ts, loss)
85 | wandb.log({"plot": wandb.Image(fig)},step=step)
86 |
87 |
88 | class TensorBoardWriter(BaseWriter):
89 | def __init__(self, opt):
90 | super(TensorBoardWriter,self).__init__(opt)
91 | if self.rank == 0:
92 | run_dir = str(opt.log_dir / opt.name)
93 | os.makedirs(run_dir, exist_ok=True)
94 | self.writer=SummaryWriter(log_dir=run_dir, flush_secs=20)
95 |
96 | def add_scalar(self, global_step, key, val):
97 | if self.rank == 0: self.writer.add_scalar(key, val, global_step=global_step)
98 |
99 | def add_image(self, global_step, key, image):
100 | if self.rank == 0:
101 | image = image.mul(255).add_(0.5).clamp_(0, 255).to("cpu", torch.uint8)
102 | self.writer.add_image(key, image, global_step=global_step)
103 |
104 | def close(self):
105 | if self.rank == 0: self.writer.close()
106 |
107 | def build_log_writer(opt):
108 | if opt.log_writer == 'wandb': return WandBWriter(opt)
109 | elif opt.log_writer == 'tensorboard': return TensorBoardWriter(opt)
110 | else: return BaseWriter(opt) # do nothing
111 |
112 | def count_parameters(model):
113 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
114 |
115 | def space_indices(num_steps, count):
116 | assert count <= num_steps
117 |
118 | if count <= 1:
119 | frac_stride = 1
120 | else:
121 | frac_stride = (num_steps - 1) / (count - 1)
122 |
123 | cur_idx = 0.0
124 | taken_steps = []
125 | for _ in range(count):
126 | taken_steps.append(round(cur_idx))
127 | cur_idx += frac_stride
128 |
129 | return taken_steps
130 |
131 | def unsqueeze_xdim(z, xdim):
132 | bc_dim = (...,) + (None,) * len(xdim)
133 | return z[bc_dim]
134 |
135 | def merge(opt,x,v):
136 | dim=-1 if opt.exp=='toy' else -3
137 |     return torch.cat([x,v], dim=dim)
138 |
139 | def flatten_dim01(x):
140 | # (dim0, dim1, *dim2) --> (dim0x1, *dim2)
141 | return x.reshape(-1, *x.shape[2:])
142 |
143 | # convert to colored strings
144 | def red(content): return termcolor.colored(str(content),"red",attrs=["bold"])
145 | def green(content): return termcolor.colored(str(content),"green",attrs=["bold"])
146 | def blue(content): return termcolor.colored(str(content),"blue",attrs=["bold"])
147 | def cyan(content): return termcolor.colored(str(content),"cyan",attrs=["bold"])
148 | def yellow(content): return termcolor.colored(str(content),"yellow",attrs=["bold"])
149 | def magenta(content): return termcolor.colored(str(content),"magenta",attrs=["bold"])
150 |
151 |
152 | def count_parameters(model):
153 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
154 |
155 | def reshape_as(x,y):
156 | len_y = len(y.shape)-1
157 | return x.reshape(-1,*([1,]*len_y))
158 |
159 | def all_cat_cpu(opt, log, t):
160 | if not opt.distributed: return t.detach().cpu()
161 | log_flag = log if opt.local_rank == 0 else None
162 | gathered_t = dist_util.all_gather(t.to(opt.device), log=log_flag)
163 | return torch.cat(gathered_t).detach().cpu()
164 |
165 | def cast_shape(x,dims):
166 | return x.reshape(-1,*([1,]*len(dims)))
167 |
168 | def uniform_ts(t0,T,n,device):
169 | _ts = uniform(n,t0,T,device)
170 | return _ts
171 |
172 | def debug_ts(t0,T,n,device):
173 | return torch.linspace(t0, T, n, device=device)
174 |
175 | def heuristic2_ts(t0,T,n,device):
176 |     _ts = torch.randn(n,device=device)*0.1+T
177 |     _ts = _ts.abs()
178 |     invalid_item = torch.logical_or(_ts>T, _ts<t0)
179 |     while invalid_item.sum()>0:
180 |         _ts[invalid_item] = (torch.randn(int(invalid_item.sum()),device=device)*0.1+T).abs()
181 |         invalid_item = torch.logical_or(_ts>T, _ts<t0)
182 |     return _ts
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
35 | ---
36 |
37 | ## Stroke-based Generative Modeling
38 | Given a pretrained AGM model, stroke-based generative modeling can be achieved without any further fine-tuning or training.
39 |
40 |
41 |
42 |
43 | ---
44 | ## Requirement
45 | For `pip` installation:
46 | ```
47 | bash setup/setup.sh
48 | ```
49 | For `Conda` installation:
50 | ```
51 | conda env create -f setup/environments.yml
52 | conda activate agm
53 | bash setup/conda_install.sh
54 | ```
55 | ---
56 | ## **Training**
57 | * **Download dataset**: You need to download the dataset and put the files under `dataset/`. `CIFAR-10` is downloaded automatically. For `AFHQv2` and `ImageNet` we follow the same pipeline as EDM. An example command for downloading AFHQv2 can be found in the comments in `setup/setup.sh`. You may have to download the `ImageNet` dataset yourself by following the [EDM](https://github.com/NVlabs/edm) repo.
58 |
59 | * **Training**: Here we provide the commands for reproducing the training runs used in our paper. You can add the arguments `--log-writer wandb --wandb-user [Your-User-Name] --wandb-api-key [Your-Wandb-Key]` to monitor the training process.
60 |
61 | **Toy**
62 | ```
63 | python train.py --name train-toy/sde-spiral --exp toy --toy-exp spiral #SDE
64 | python train.py --name train-toy/ode-spiral --exp toy --toy-exp spiral --DE-type probODE --solver gDDIM #ODE
65 | ```
66 | **Cifar10**:
67 | ```
68 | python train.py --name cifar10-repo --exp cifar10 --n-gpu-per-node 8
69 | ```
70 | **AFHQv2**:
71 | ```
72 | python train.py --name AFHQv2-repo --exp AFHQv2 --n-gpu-per-node 8
73 | ```
74 | **ImageNet64**:
75 | ```
76 | python train.py --name imagenet-repo --exp imagenet64 --n-gpu-per-node 8 --num-itr 5000000 # Unconditional Generation
77 | ```
78 | ---
79 | ## **Sampling**
80 | Before sampling, make sure you have downloaded the checkpoints and stored them in the `results/Cifar10-ODE/`, `results/AFHQv2-ODE/` and `results/uncond-ImageNet64-ODE/` folders.
81 |
82 | Here is a short example command for generating 64 images on a single RTX 3090 for CIFAR-10, AFHQv2 and ImageNet:
83 | ```
84 | bash scripts/example.sh
85 | ```
86 | The corresponding generation time on a single RTX 3090 is as follows:
87 | | Dataset | NFE | Time |
88 | |----------|----------|----------|
89 | | CIFAR-10 | 20 | ~6 sec |
90 | | AFHQv2 | 20 | ~10 sec |
91 | | ImageNet | 20 | ~15 sec |
92 |
93 | ### **Toy Dataset Generation**
94 | Once you have trained a model, you can load it for fast sampling:
95 | ```
96 | #SDE with 10 NFEs
97 | python train.py --name train-toy/sde-spiral-eval --exp toy --toy-exp spiral --eval --nfe 10 --ckpt train-toy/sde-spiral
98 | #ODE with 10 NFEs
99 | python train.py --name train-toy/ode-spiral-eval --exp toy --toy-exp spiral --eval --nfe 10 --DE-type probODE --solver gDDIM --ckpt train-toy/ode-spiral
100 | ```
101 |
102 | ### **CIFAR-10 Generation and Evaluation**
103 | Use the following commands to generate data. The generated images will be saved in `FID_EVAL/cifar10-nfe[x]`.
104 | ```
105 | #NFE=5 FID=11.88 (~4.5 min to sample 50k images)
106 | python sampling.py --n-gpu-per-node 1 --ckpt Cifar10-ODE/latest.pt --pred-x1 --solver gDDIM --T 0.4 --nfe 5 --fid-save-name cifar10-nfe5 --num-sample 50000 --batch-size 1000
107 |
108 | #NFE=10 FID=4.54
109 | python sampling.py --n-gpu-per-node 1 --ckpt Cifar10-ODE/latest.pt --pred-x1 --solver gDDIM --T 0.7 --nfe 10 --fid-save-name cifar10-nfe10 --num-sample 50000 --batch-size 1000
110 |
111 | #NFE=20 FID=2.58
112 | python sampling.py --n-gpu-per-node 1 --ckpt Cifar10-ODE/latest.pt --pred-x1 --solver gDDIM --T 0.9 --nfe 20 --fid-save-name cifar10-nfe20 --num-sample 50000 --batch-size 1000
113 | ```
114 | Evaluate the FID using EDM evaluation:
115 | ```
116 | # x can be in {5,10,20}
117 | torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe[x] --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
118 | ```
119 |
120 | ### **AFHQv2 Generation and Evaluation**
121 | Use the following commands to generate data. The generated images will be saved in `FID_EVAL/AFHQv2-nfe[x]`.
122 | ```
123 | #Stroke-based generation
124 | python sampling.py --n-gpu-per-node 1 --ckpt AFHQv2-ODE/latest.pt --pred-x1 --solver gDDIM --save-img --img-save-name stroke-AFHQv2 --nfe 100 --T 0.999 --num-sample 64 --batch-size 64 --stroke-path dataset/StrokeData/testFig0.png --stroke-type dyn-v # [you can also replace --stroke-type by init-v]
125 |
126 | #Inpainting generation (note the flag spelling: --impainting)
127 | python sampling.py --n-gpu-per-node 1 --ckpt AFHQv2-ODE/latest.pt --pred-x1 --solver gDDIM --save-img --img-save-name impainting-AFHQv2 --nfe 100 --T 0.999 --num-sample 64 --batch-size 64 --stroke-path dataset/StrokeData/testFig0_impainting.png --stroke-type dyn-v --impainting
128 |
129 | #NFE 20 FID=3.72
130 | python sampling.py --n-gpu-per-node 1 --ckpt AFHQv2-ODE/latest.pt --pred-x1 --solver gDDIM --fid-save-name AFHQv2-nfe20 --nfe 20 --T 0.9 --num-sample 50000 --batch-size 250
131 | ```
132 | Evaluate the FID using EDM evaluation:
133 | ```
134 | torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/AFHQv2-nfe20 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz
135 | ```
136 |
137 | ### **Imagenet64 Generation and Evaluation**
138 | Use the following commands to generate data. The generated images will be saved in `FID_EVAL/ImageNet64-nfe[x]`.
139 | ```
140 | #NFE 20 FID=10.55
141 | python sampling.py --n-gpu-per-node 1 --ckpt uncond-ImageNet64-ODE/latest.pt --pred-x1 --solver gDDIM --fid-save-name ImageNet64-nfe20 --nfe 20 --T 0.99 --num-sample 50000 --batch-size 100
142 | #NFE 30 FID=10.07
143 | python sampling.py --n-gpu-per-node 1 --ckpt uncond-ImageNet64-ODE/latest.pt --pred-x1 --solver gDDIM --fid-save-name ImageNet64-nfe30 --nfe 30 --T 0.99 --num-sample 50000 --batch-size 100
144 | #NFE 40 FID=10.10
145 | python sampling.py --n-gpu-per-node 1 --ckpt uncond-ImageNet64-ODE/latest.pt --pred-x1 --solver gDDIM --fid-save-name ImageNet64-nfe40 --nfe 40 --T 0.9 --num-sample 50000 --batch-size 100
146 |
147 | ```
148 | Evaluate the FID using the EDM evaluation (x ∈ {20, 30, 40}):
149 | ```
150 | torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/ImageNet64-nfe[x] --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/imagenet-64x64.npz
151 | ```
--------------------------------------------------------------------------------
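A minimal sketch of the `init-v` stroke conditioning referenced in the README's stroke-based generation section (cf. `gDDIM_sampler` in `AM/samplers.py`): the stroke image is blended into the initial velocity with the 0.9/0.1 weights used by the sampler; the tensors here are stand-ins for loaded data:

```python
import torch

x0 = torch.randn(1, 3, 64, 64)        # initial position sample
v0 = torch.randn(1, 3, 64, 64)        # initial velocity sample
stroke = torch.zeros(1, 3, 64, 64)    # stand-in for a stroke image scaled to [-1, 1]

v0 = 0.9 * v0 + 0.1 * stroke          # same blend as the sampler's 'init-v' branch
m0 = torch.cat([x0, v0], dim=1)       # joint state fed to the solver
print(m0.shape)                       # torch.Size([1, 6, 64, 64])
```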
/asset/AGMvsFM.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/asset/AGMvsFM.gif
--------------------------------------------------------------------------------
/asset/cond_gen.001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/asset/cond_gen.001.png
--------------------------------------------------------------------------------
/asset/cond_gen.jpg.001.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/asset/cond_gen.jpg.001.jpeg
--------------------------------------------------------------------------------
/asset/cond_gen.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/asset/cond_gen.pdf
--------------------------------------------------------------------------------
/asset/cond_gen2.001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/asset/cond_gen2.001.png
--------------------------------------------------------------------------------
/asset/sampling_hop.png.001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/asset/sampling_hop.png.001.png
--------------------------------------------------------------------------------
/configs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/configs/.DS_Store
--------------------------------------------------------------------------------
/configs/afhqv2_config.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | import ml_collections
6 | def get_afhqv2_default_configs():
7 | config = ml_collections.ConfigDict()
8 | # training
9 | config.training = ml_collections.ConfigDict()
10 | config.seed = 42
11 | config.microbatch = 64
12 | config.n_gpu_per_node = 8
13 | config.lr = 1e-3
14 | config.precond = True
15 | config.reweight_type = 'reciprocal'
16 | config.t_samp = 'uniform'
17 | config.num_itr = 600000
18 |
19 | config.data_dim = [3,64,64]
20 | config.joint_dim = [6,64,64]
21 | #data
22 | config.xflip = True
23 | config.exp = 'AFHQv2'
24 |
25 | #Dynamics
26 | config.t0 = 1e-5
27 | config.T = 0.999
28 | config.dyn_type = 'TVgMPC'
29 | config.clip_grad = 1
30 | config.damp_t = 1
31 | config.p = 3
32 | config.k = 0.2
33 | config.varx = 1
34 | config.varv = 1
35 | config.DE_type = 'probODE'
36 |
37 | #Evaluation during training
38 |     config.nfe = 100 # NFE used for sampling during training-time evaluation; can be overridden
39 | config.solver = 'gDDIM'
40 | config.diz_order = 2
41 | config.diz = 'rev-quad'
42 |     config.train_fid_sample = 4096 # Number of samples used for FID evaluation during training
43 |
44 | model_configs = get_edm_NCSNpp_config()
45 | # model_configs=get_Unet_config()
46 | return config, model_configs
47 |
48 | def get_edm_NCSNpp_config():
49 | config = ml_collections.ConfigDict()
50 | config.image_size = 64
51 | config.name = "SongUNet"
52 | config.embedding_type = "fourier"
53 | config.encoder_type = "residual"
54 | config.decoder_type = "standard"
55 | config.resample_filter = [1,3,3,1]
56 | config.model_channels = 128
57 | config.channel_mult = [1,2,2,2]
58 | config.channel_mult_noise = 2
59 | config.dropout = 0.25
60 | config.label_dropout = 0
61 | config.channel_mult_emb = 4
62 | config.num_blocks = 4
63 | config.attn_resolutions = [16]
64 | config.in_channels = 6
65 | config.out_channels = 3
66 | config.label_dim = 0
67 | config.augment_dim = 0
68 | return config
--------------------------------------------------------------------------------
/configs/cifar10_config.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | import ml_collections
6 | def get_cifar10_default_configs():
7 | config = ml_collections.ConfigDict()
8 | # training
9 | config.training = ml_collections.ConfigDict()
10 | config.seed = 42
11 | config.microbatch = 64
12 | config.n_gpu_per_node = 8
13 | config.lr = 1e-3
14 | config.precond = True
15 | config.reweight_type = 'reciprocal'
16 | config.t_samp = 'uniform'
17 | config.num_itr = 600000
18 |
19 | config.data_dim = [3,32,32]
20 | config.joint_dim = [6,32,32]
21 | #data
22 | config.xflip = True
23 | config.exp = 'cifar10'
24 |
25 | #Dynamics
26 | config.t0 = 1e-5
27 | config.T = 0.999
28 | config.dyn_type = 'TVgMPC'
29 | # config.algo = 'DM'
30 | config.clip_grad = 1
31 | config.damp_t = 1
32 | config.p = 3
33 | config.k = 0.2
34 | config.varx = 1
35 | config.varv = 1
36 | config.DE_type = 'probODE'
37 |
38 | #Evaluation during training
39 |     config.nfe = 100 # NFE used for sampling during training-time evaluation; can be overridden
40 | config.solver = 'gDDIM'
41 | config.diz_order = 2
42 | config.diz = 'rev-quad'
43 |     config.train_fid_sample = 4096 # Number of samples used for FID evaluation during training
44 |
45 | model_configs = get_edm_NCSNpp_config()
46 | # model_configs=get_Unet_config()
47 | return config, model_configs
48 |
49 | def get_Unet_config():
50 | config = ml_collections.ConfigDict()
51 | config.name = 'Unet'
52 | config.attention_resolutions = '16,8'
53 | config.in_channels = 6
54 | config.out_channel = 3
55 | config.num_head = 4
56 | config.num_res_blocks = 2
57 | config.num_channels = 128
58 | config.dropout = 0.1
59 | config.channel_mult = (1, 2, 2, 2)
60 | config.image_size = 32
61 | return config
62 |
63 |
64 | def get_NCSNpp_config():
65 | config = ml_collections.ConfigDict()
66 | config.name = "ncsnpp"
67 | config.normalization = "GroupNorm"
68 | config.image_size = 32
69 | config.image_channels = 3
70 | config.nonlinearity = "swish"
71 | config.n_channels = 128
72 | config.ch_mult = '1,2,2,2'
73 | config.attn_resolutions = '16'
74 | config.resamp_with_conv = True
75 | config.use_fir = True
76 | config.fir_kernel = '1,3,3,1'
77 | config.skip_rescale = True
78 | config.resblock_type = "biggan"
79 | config.progressive = 'none'
80 | config.progressive_input = "residual"
81 | config.progressive_combine = "sum"
82 | config.attention_type = "ddpm"
83 | config.init_scale = 0.0
84 | config.fourier_scale = 16
85 | config.conv_size = '3'
86 | config.embedding_type = "fourier"
87 | config.n_resblocks = 8
88 | config.dropout = 0.1
89 | # config.dropout = 0.2
90 | return config
91 |
92 | def get_edm_NCSNpp_config():
93 | config = ml_collections.ConfigDict()
94 | config.image_size = 32
95 | config.name = "SongUNet"
96 | config.embedding_type = "fourier"
97 | config.encoder_type = "residual"
98 | config.decoder_type = "standard"
99 | config.resample_filter = [1,3,3,1]
100 | config.model_channels = 128
101 | config.channel_mult = [2,2,2]
102 | config.channel_mult_noise = 2
103 | config.dropout = 0.13
104 | config.label_dropout = 0
105 | config.channel_mult_emb = 4
106 | config.num_blocks = 4
107 | config.attn_resolutions = [16]
108 | config.in_channels = 6
109 | config.out_channels = 3
110 | config.augment_dim = 0
111 | return config
--------------------------------------------------------------------------------
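A hedged usage sketch of the config entry point above: `ml_collections.ConfigDict` fields are plain attributes, so experiment settings can be loaded and overridden before training (the overrides shown are hypothetical):

```python
from configs.cifar10_config import get_cifar10_default_configs

config, model_configs = get_cifar10_default_configs()
config.lr = 5e-4      # hypothetical override of the default learning rate
config.nfe = 20       # fewer sampling steps during training-time evaluation
print(config.exp, config.solver, model_configs.name)  # cifar10 gDDIM SongUNet
```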
/configs/imagenet64_config.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | import ml_collections
6 | def get_imagenet64_default_configs():
7 | config = ml_collections.ConfigDict()
8 | # basic
9 | config.training = ml_collections.ConfigDict()
10 | config.seed = 42
11 | config.microbatch = 64
12 | config.n_gpu_per_node = 8
13 | config.lr = 2e-4
14 | config.precond = True
15 | config.reweight_type = 'reciprocal'
16 | config.t_samp = 'uniform'
17 | config.num_itr = 10000000
18 | config.data_dim = [3,64,64]
19 | config.joint_dim = [6,64,64]
20 | #data
21 | config.xflip = True
22 | config.exp = 'imagenet64'
23 |
24 | #Dynamics
25 | config.t0 = 1e-5
26 | config.T = 0.999
27 | config.dyn_type = 'TVgMPC'
28 | # config.algo = 'DM'
29 | config.clip_grad = 1
30 | config.damp_t = 1
31 | config.p = 3
32 | config.k = 0.2
33 | config.varx = 1
34 | config.varv = 1
35 | config.DE_type = 'probODE'
36 |
37 | #Evaluation during training
38 |     config.nfe = 100 # NFE used for sampling during training-time evaluation; can be overridden
39 | config.solver = 'gDDIM'
40 | config.diz_order = 2
41 | config.diz = 'rev-quad'
42 |     config.train_fid_sample = 4096 # Number of samples used for FID evaluation during training
43 |
44 |
45 | model_configs = get_edm_ADM_config()
46 | # model_configs=get_Unet_config()
47 | return config, model_configs
48 |
49 | def get_edm_ADM_config():
50 | config = ml_collections.ConfigDict()
51 | config.image_size = 64
52 | config.name = "ADM"
53 | config.in_channels = 6
54 | config.out_channels = 3
55 | return config
56 |
--------------------------------------------------------------------------------
/configs/toy_config.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | import ml_collections
6 |
7 | def get_toy_default_configs():
8 | config = ml_collections.ConfigDict()
9 | # training
10 | # config.training = ml_collections.ConfigDict()
11 | config.seed = 42
12 | config.num_itr = 40001
13 | config.t0 = 1e-5
14 | config.debug = True
15 | config.microbatch = 2048
16 | config.nfe = 200
17 | config.DE_type = 'SDE'
18 | config.t_samp = 'uniform'
19 | config.diz = 'Euler'
20 | config.solver = 'sscs'
21 | config.exp = 'toy'
22 | config.lr = 1e-3
23 | config.dyn_type = 'TVgMPC'
24 | config.T = 0.999
25 | config.p = 3
26 | config.k = 0.2
27 | config.varx = 1
28 | config.varv = 1
29 | config.data_dim = [2]
30 | config.joint_dim = [4]
31 | config.reweight_type = 'reciprocal'
32 |
33 | model_configs=None
34 | return config, model_configs
35 |
--------------------------------------------------------------------------------
/dataset/AFHQv2.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | """Streaming images and labels from datasets created with dataset_tool.py."""
7 |
8 | from edm.dataset import ImageFolderDataset
9 |
10 | from AM import util
11 | try:
12 | import pyspng
13 | except ImportError:
14 | pyspng = None
15 |
16 | class AFHQv2_data():
17 | """cifar10 dataset."""
18 | def __init__(self, opt):
19 | self.opt = opt
20 | bs = opt.microbatch
21 | x1 = ImageFolderDataset(path='dataset/afhqv2-64x64.zip')
22 | self.loader = util.setup_loader(x1,bs,num_workers=opt.n_gpu_per_node)
23 | assert not opt.cond
24 |
25 | def sample(self):
26 | x1 = next(self.loader)[0]
27 | x1 = x1/ 127.5 - 1
28 | label = None
29 |
30 | return x1.to(self.opt.device),label
31 |
--------------------------------------------------------------------------------
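The `x1 / 127.5 - 1` step in sample() maps uint8 pixels in [0, 255] linearly onto [-1, 1]; a minimal check:

import torch

x_uint8 = torch.tensor([0.0, 127.5, 255.0])
print(x_uint8 / 127.5 - 1)  # tensor([-1., 0., 1.])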
/dataset/StrokeData/testFig0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/dataset/StrokeData/testFig0.png
--------------------------------------------------------------------------------
/dataset/StrokeData/testFig0_impainting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/dataset/StrokeData/testFig0_impainting.png
--------------------------------------------------------------------------------
/dataset/StrokeData/testFig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/dataset/StrokeData/testFig1.png
--------------------------------------------------------------------------------
/dataset/StrokeData/testFig1_impainting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-agm/de57e5fcb9b0526de5ffebf67d6932601c1db054/dataset/StrokeData/testFig1_impainting.png
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
--------------------------------------------------------------------------------
/dataset/cifar10.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | import torch
6 | from torchvision import transforms
7 | import torchvision.datasets as datasets
8 | import sys
9 | sys.path.append("..") # Adds higher directory to python modules path.
10 | from AM import util
11 | def tmp_fnc(t):
12 |     '''Module-level function: avoids a known pickling issue with transforms in the DDP setting'''
13 | return (t * 2) - 1
14 |
15 | class cifar10_data():
16 | """cifar10 dataset."""
17 | def __init__(self, opt):
18 | self.opt = opt
19 | bs = opt.microbatch
20 |
21 | x1 = self.generate_x1()
22 | self.loader = util.setup_loader(x1,bs)
23 |
24 | def generate_x1(self):
25 | transforms_list = [transforms.RandomHorizontalFlip(p=0.5)] if self.opt.xflip else []
26 | transforms_list+=[
27 | transforms.ToTensor(), #Convert to [0,1]
28 | transforms.Lambda(tmp_fnc) #Convert to [-1,1]
29 | ]
30 | x1=datasets.CIFAR10(
31 | './dataset',
32 | train= True,
33 | download=True,
34 | transform=transforms.Compose(transforms_list)
35 | )
36 | return x1
37 |
38 | def sample(self):
39 |         x1, label = next(self.loader)[:2] # single next() call so images and labels stay paired
40 | label =torch.nn.functional.one_hot(label, num_classes=10)
41 | return x1.to(self.opt.device),label.to(self.opt.device) if self.opt.cond else None
42 |
--------------------------------------------------------------------------------
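A hedged usage sketch; the `opt` fields below are only the ones cifar10_data visibly reads (microbatch, xflip, cond, device), while the real runner builds a fuller options object:

from types import SimpleNamespace

opt = SimpleNamespace(microbatch=128, xflip=True, cond=True, device='cpu')
data = cifar10_data(opt)
x1, label = data.sample()  # x1 in [-1, 1], shape [128, 3, 32, 32]; label one-hot [128, 10]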
/dataset/imagenet64.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | """Streaming images and labels from datasets created with dataset_tool.py."""
6 | from edm.dataset import ImageFolderDataset
7 | from AM import util
8 | try:
9 | import pyspng
10 | except ImportError:
11 | pyspng = None
12 |
13 | class imagenet64_data():
14 | """cifar10 dataset."""
15 | def __init__(self, opt):
16 | self.opt = opt
17 | bs = opt.microbatch
18 | x1 = ImageFolderDataset(path='dataset/imagenet-64x64.zip',use_labels=opt.cond,xflip=opt.xflip)
19 | self.loader = util.setup_loader(x1,bs,num_workers=opt.n_gpu_per_node)
20 |
21 | def sample(self):
22 |         x1, label = next(self.loader)[:2] #[bs, dims],[bs, one_hot]; single next() keeps pairs aligned
23 | x1 = x1/ 127.5 - 1
24 | return x1.to(self.opt.device),label.to(self.opt.device) if self.opt.cond else None
25 |
--------------------------------------------------------------------------------
/dataset/spiral.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | import torch
6 | import numpy as np
7 | from torch.utils.data import Dataset
8 | from sklearn.datasets import make_swiss_roll
9 |
10 | class spiral_data(Dataset):
11 | """Toy Spiral dataset."""
12 | def __init__(self, opt):
13 | n_train = opt.n_train
14 | self.opt = opt
15 | self.x1 = self.generate_x1(n_train)
16 | self.bs = opt.microbatch
17 | def generate_x1(self,n):
18 | '''
19 |         n: total number of samples
20 | '''
21 | if self.opt.toy_exp=='gmm':
22 | WIDTH = 3
23 | BOUND = 0.5
24 | NOISE = 0.04
25 | ROTATION_MATRIX = np.array([[1., -1.], [1., 1.]]) / np.sqrt(2.)
26 |
27 | means = np.array([(x, y) for x in np.linspace(-BOUND, BOUND, WIDTH)
28 | for y in np.linspace(-BOUND, BOUND, WIDTH)])
29 | means = means @ ROTATION_MATRIX
30 | covariance_factor = NOISE * np.eye(2)
31 |
32 | index = np.random.choice(
33 | range(WIDTH ** 2), size=n, replace=True)
34 | noise = np.random.randn(n, 2)
35 | data = means[index] + noise @ covariance_factor
36 | data=torch.from_numpy(data.astype('float32'))
37 | data=data.to(self.opt.device)
38 | elif self.opt.toy_exp=='spiral':
39 | NOISE = 0.3
40 | MULTIPLIER = 0.05
41 | OFFSETS = [[1.2, 1.2], [1.2, -1.2], [-1.2, -1.2], [-1.2, 1.2]]
42 |
43 | idx = np.random.multinomial(n, [0.2] * 5, size=1)[0]
44 |
45 | sr = []
46 | for k in range(5):
47 | sr.append(make_swiss_roll(int(idx[k]), noise=NOISE)[
48 | 0][:, [0, 2]].astype('float32') * MULTIPLIER)
49 |
50 | if k > 0:
51 | sr[k] += np.array(OFFSETS[k - 1]).reshape(-1, 2)
52 |
53 | data = np.concatenate(sr, axis=0)[np.random.permutation(n)]
54 | data = torch.from_numpy(data.astype('float32'))
55 | data=data.to(self.opt.device)
56 | else:
57 | raise RuntimeError
58 |
59 | return data
60 |
61 | def __len__(self):
62 | return self.x1.shape[0]
63 |
64 | def sample(self):
65 | bs = self.bs
66 |         length = self.x1.shape[0]
67 |         idx = torch.randint(length, (bs,), device=self.opt.device)
68 | x1 = self.x1[idx]
69 | return x1, None
70 |
--------------------------------------------------------------------------------
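A similarly hedged sketch for the toy data; the `opt` fields are inferred from what spiral_data reads (n_train, microbatch, toy_exp, device):

from types import SimpleNamespace

opt = SimpleNamespace(n_train=100_000, microbatch=2048, toy_exp='spiral', device='cpu')
data = spiral_data(opt)
x1, _ = data.sample()  # [2048, 2] points from the five-arm swiss-roll mixture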
/edm/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import zipfile
4 | import PIL.Image
5 | import json
6 | import torch
7 | from edm import dnnlib
8 |
9 | try:
10 | import pyspng
11 | except ImportError:
12 | pyspng = None
13 |
14 | class Dataset(torch.utils.data.Dataset):
15 | def __init__(self,
16 | name, # Name of the dataset.
17 | raw_shape, # Shape of the raw image data (NCHW).
18 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
19 | use_labels = False, # Enable conditioning labels? False = label dimension is zero.
20 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
21 | random_seed = 0, # Random seed to use when applying max_size.
22 | cache = False, # Cache images in CPU memory?
23 | ):
24 | self._name = name
25 | self._raw_shape = list(raw_shape)
26 | self._use_labels = use_labels
27 | self._cache = cache
28 | self._cached_images = dict() # {raw_idx: np.ndarray, ...}
29 | self._raw_labels = None
30 | self._label_shape = None
31 |
32 | # Apply max_size.
33 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
34 | if (max_size is not None) and (self._raw_idx.size > max_size):
35 | np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
36 | self._raw_idx = np.sort(self._raw_idx[:max_size])
37 |
38 | # Apply xflip.
39 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
40 | if xflip:
41 | self._raw_idx = np.tile(self._raw_idx, 2)
42 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
43 |
44 | def _get_raw_labels(self):
45 | if self._raw_labels is None:
46 | self._raw_labels = self._load_raw_labels() if self._use_labels else None
47 | if self._raw_labels is None:
48 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
49 | assert isinstance(self._raw_labels, np.ndarray)
50 | assert self._raw_labels.shape[0] == self._raw_shape[0]
51 | assert self._raw_labels.dtype in [np.float32, np.int64]
52 | if self._raw_labels.dtype == np.int64:
53 | assert self._raw_labels.ndim == 1
54 | assert np.all(self._raw_labels >= 0)
55 | return self._raw_labels
56 |
57 | def close(self): # to be overridden by subclass
58 | pass
59 |
60 | def _load_raw_image(self, raw_idx): # to be overridden by subclass
61 | raise NotImplementedError
62 |
63 | def _load_raw_labels(self): # to be overridden by subclass
64 | raise NotImplementedError
65 |
66 | def __getstate__(self):
67 | return dict(self.__dict__, _raw_labels=None)
68 |
69 | def __del__(self):
70 | try:
71 | self.close()
72 | except:
73 | pass
74 |
75 | def __len__(self):
76 | return self._raw_idx.size
77 |
78 | def __getitem__(self, idx):
79 | raw_idx = self._raw_idx[idx]
80 | image = self._cached_images.get(raw_idx, None)
81 | if image is None:
82 | image = self._load_raw_image(raw_idx)
83 | if self._cache:
84 | self._cached_images[raw_idx] = image
85 | assert isinstance(image, np.ndarray)
86 | assert list(image.shape) == self.image_shape
87 | assert image.dtype == np.uint8
88 | if self._xflip[idx]:
89 | assert image.ndim == 3 # CHW
90 | image = image[:, :, ::-1]
91 | return image.copy(), self.get_label(idx)
92 |
93 | def get_label(self, idx):
94 | label = self._get_raw_labels()[self._raw_idx[idx]]
95 | if label.dtype == np.int64:
96 | onehot = np.zeros(self.label_shape, dtype=np.float32)
97 | onehot[label] = 1
98 | label = onehot
99 | return label.copy()
100 |
101 | def get_details(self, idx):
102 | d = dnnlib.EasyDict()
103 | d.raw_idx = int(self._raw_idx[idx])
104 | d.xflip = (int(self._xflip[idx]) != 0)
105 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
106 | return d
107 |
108 | @property
109 | def name(self):
110 | return self._name
111 |
112 | @property
113 | def image_shape(self):
114 | return list(self._raw_shape[1:])
115 |
116 | @property
117 | def num_channels(self):
118 | assert len(self.image_shape) == 3 # CHW
119 | return self.image_shape[0]
120 |
121 | @property
122 | def resolution(self):
123 | assert len(self.image_shape) == 3 # CHW
124 | assert self.image_shape[1] == self.image_shape[2]
125 | return self.image_shape[1]
126 |
127 | @property
128 | def label_shape(self):
129 | if self._label_shape is None:
130 | raw_labels = self._get_raw_labels()
131 | if raw_labels.dtype == np.int64:
132 | self._label_shape = [int(np.max(raw_labels)) + 1]
133 | else:
134 | self._label_shape = raw_labels.shape[1:]
135 | return list(self._label_shape)
136 |
137 | @property
138 | def label_dim(self):
139 | assert len(self.label_shape) == 1
140 | return self.label_shape[0]
141 |
142 | @property
143 | def has_labels(self):
144 | return any(x != 0 for x in self.label_shape)
145 |
146 | @property
147 | def has_onehot_labels(self):
148 | return self._get_raw_labels().dtype == np.int64
149 |
150 | #----------------------------------------------------------------------------
151 | # Dataset subclass that loads images recursively from the specified directory
152 | # or ZIP file.
153 |
154 | class ImageFolderDataset(Dataset):
155 | def __init__(self,
156 | path, # Path to directory or zip.
157 | resolution = None, # Ensure specific resolution, None = highest available.
158 | use_pyspng = True, # Use pyspng if available?
159 | **super_kwargs, # Additional arguments for the Dataset base class.
160 | ):
161 | self._path = path
162 | self._use_pyspng = use_pyspng
163 | self._zipfile = None
164 |
165 | if os.path.isdir(self._path):
166 | self._type = 'dir'
167 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
168 | elif self._file_ext(self._path) == '.zip':
169 | self._type = 'zip'
170 | self._all_fnames = set(self._get_zipfile().namelist())
171 | else:
172 | raise IOError('Path must point to a directory or zip')
173 |
174 | PIL.Image.init()
175 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
176 | if len(self._image_fnames) == 0:
177 | raise IOError('No image files found in the specified path')
178 |
179 | name = os.path.splitext(os.path.basename(self._path))[0]
180 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
181 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
182 | raise IOError('Image files do not match the specified resolution')
183 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
184 |
185 | @staticmethod
186 | def _file_ext(fname):
187 | return os.path.splitext(fname)[1].lower()
188 |
189 | def _get_zipfile(self):
190 | assert self._type == 'zip'
191 | if self._zipfile is None:
192 | self._zipfile = zipfile.ZipFile(self._path)
193 | return self._zipfile
194 |
195 | def _open_file(self, fname):
196 | if self._type == 'dir':
197 | return open(os.path.join(self._path, fname), 'rb')
198 | if self._type == 'zip':
199 | return self._get_zipfile().open(fname, 'r')
200 | return None
201 |
202 | def close(self):
203 | try:
204 | if self._zipfile is not None:
205 | self._zipfile.close()
206 | finally:
207 | self._zipfile = None
208 |
209 | def __getstate__(self):
210 | return dict(super().__getstate__(), _zipfile=None)
211 |
212 | def _load_raw_image(self, raw_idx):
213 | fname = self._image_fnames[raw_idx]
214 | with self._open_file(fname) as f:
215 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png':
216 | image = pyspng.load(f.read())
217 | else:
218 | image = np.array(PIL.Image.open(f))
219 | if image.ndim == 2:
220 | image = image[:, :, np.newaxis] # HW => HWC
221 | image = image.transpose(2, 0, 1) # HWC => CHW
222 | return image
223 |
224 | def _load_raw_labels(self):
225 | fname = 'dataset.json'
226 | if fname not in self._all_fnames:
227 | return None
228 | with self._open_file(fname) as f:
229 | labels = json.load(f)['labels']
230 | if labels is None:
231 | return None
232 | labels = dict(labels)
233 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
234 | labels = np.array(labels)
235 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
236 | return labels
237 |
238 | #----------------------------------------------------------------------------
239 |
--------------------------------------------------------------------------------
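ImageFolderDataset is a plain torch Dataset, so it composes with a standard DataLoader; a minimal sketch with a placeholder zip path:

import torch

ds = ImageFolderDataset(path='dataset/cifar10-32x32.zip', use_labels=True)  # placeholder path
loader = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True)
images, labels = next(iter(loader))  # images: uint8 NCHW in [0, 255]; labels one-hot if dataset.json has int labels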
/edm/dataset_tool.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Tool for creating ZIP/PNG based datasets."""
9 |
10 | import functools
11 | import gzip
12 | import io
13 | import json
14 | import os
15 | import pickle
16 | import re
17 | import sys
18 | import tarfile
19 | import zipfile
20 | from pathlib import Path
21 | from typing import Callable, Optional, Tuple, Union
22 | import click
23 | import numpy as np
24 | import PIL.Image
25 | from tqdm import tqdm
26 |
27 | #----------------------------------------------------------------------------
28 | # Parse a 'M,N' or 'MxN' integer tuple.
29 | # Example: '4x2' returns (4,2)
30 |
31 | def parse_tuple(s: str) -> Tuple[int, int]:
32 | m = re.match(r'^(\d+)[x,](\d+)$', s)
33 | if m:
34 | return int(m.group(1)), int(m.group(2))
35 | raise click.ClickException(f'cannot parse tuple {s}')
36 |
37 | #----------------------------------------------------------------------------
38 |
39 | def maybe_min(a: int, b: Optional[int]) -> int:
40 | if b is not None:
41 | return min(a, b)
42 | return a
43 |
44 | #----------------------------------------------------------------------------
45 |
46 | def file_ext(name: Union[str, Path]) -> str:
47 | return str(name).split('.')[-1]
48 |
49 | #----------------------------------------------------------------------------
50 |
51 | def is_image_ext(fname: Union[str, Path]) -> bool:
52 | ext = file_ext(fname).lower()
53 | return f'.{ext}' in PIL.Image.EXTENSION
54 |
55 | #----------------------------------------------------------------------------
56 |
57 | def open_image_folder(source_dir, *, max_images: Optional[int]):
58 | input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
59 | arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images}
60 | max_idx = maybe_min(len(input_images), max_images)
61 |
62 | # Load labels.
63 | labels = dict()
64 | meta_fname = os.path.join(source_dir, 'dataset.json')
65 | if os.path.isfile(meta_fname):
66 | with open(meta_fname, 'r') as file:
67 | data = json.load(file)['labels']
68 | if data is not None:
69 | labels = {x[0]: x[1] for x in data}
70 |
71 | # No labels available => determine from top-level directory names.
72 | if len(labels) == 0:
73 | toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()}
74 | toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))}
75 | if len(toplevel_indices) > 1:
76 | labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()}
77 |
78 | def iterate_images():
79 | for idx, fname in enumerate(input_images):
80 | img = np.array(PIL.Image.open(fname))
81 | yield dict(img=img, label=labels.get(arch_fnames.get(fname)))
82 | if idx >= max_idx - 1:
83 | break
84 | return max_idx, iterate_images()
85 |
86 | #----------------------------------------------------------------------------
87 |
88 | def open_image_zip(source, *, max_images: Optional[int]):
89 | with zipfile.ZipFile(source, mode='r') as z:
90 | input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
91 | max_idx = maybe_min(len(input_images), max_images)
92 |
93 | # Load labels.
94 | labels = dict()
95 | if 'dataset.json' in z.namelist():
96 | with z.open('dataset.json', 'r') as file:
97 | data = json.load(file)['labels']
98 | if data is not None:
99 | labels = {x[0]: x[1] for x in data}
100 |
101 | def iterate_images():
102 | with zipfile.ZipFile(source, mode='r') as z:
103 | for idx, fname in enumerate(input_images):
104 | with z.open(fname, 'r') as file:
105 | img = np.array(PIL.Image.open(file))
106 | yield dict(img=img, label=labels.get(fname))
107 | if idx >= max_idx - 1:
108 | break
109 | return max_idx, iterate_images()
110 |
111 | #----------------------------------------------------------------------------
112 |
113 | def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
114 | import cv2 # pyright: ignore [reportMissingImports] # pip install opencv-python
115 | import lmdb # pyright: ignore [reportMissingImports] # pip install lmdb
116 |
117 | with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
118 | max_idx = maybe_min(txn.stat()['entries'], max_images)
119 |
120 | def iterate_images():
121 | with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
122 | for idx, (_key, value) in enumerate(txn.cursor()):
123 | try:
124 | try:
125 | img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
126 | if img is None:
127 | raise IOError('cv2.imdecode failed')
128 | img = img[:, :, ::-1] # BGR => RGB
129 | except IOError:
130 | img = np.array(PIL.Image.open(io.BytesIO(value)))
131 | yield dict(img=img, label=None)
132 | if idx >= max_idx - 1:
133 | break
134 | except:
135 | print(sys.exc_info()[1])
136 |
137 | return max_idx, iterate_images()
138 |
139 | #----------------------------------------------------------------------------
140 |
141 | def open_cifar10(tarball: str, *, max_images: Optional[int]):
142 | images = []
143 | labels = []
144 |
145 | with tarfile.open(tarball, 'r:gz') as tar:
146 | for batch in range(1, 6):
147 | member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
148 | with tar.extractfile(member) as file:
149 | data = pickle.load(file, encoding='latin1')
150 | images.append(data['data'].reshape(-1, 3, 32, 32))
151 | labels.append(data['labels'])
152 |
153 | images = np.concatenate(images)
154 | labels = np.concatenate(labels)
155 | images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
156 | assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
157 | assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
158 | assert np.min(images) == 0 and np.max(images) == 255
159 | assert np.min(labels) == 0 and np.max(labels) == 9
160 |
161 | max_idx = maybe_min(len(images), max_images)
162 |
163 | def iterate_images():
164 | for idx, img in enumerate(images):
165 | yield dict(img=img, label=int(labels[idx]))
166 | if idx >= max_idx - 1:
167 | break
168 |
169 | return max_idx, iterate_images()
170 |
171 | #----------------------------------------------------------------------------
172 |
173 | def open_mnist(images_gz: str, *, max_images: Optional[int]):
174 | labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
175 | assert labels_gz != images_gz
176 | images = []
177 | labels = []
178 |
179 | with gzip.open(images_gz, 'rb') as f:
180 | images = np.frombuffer(f.read(), np.uint8, offset=16)
181 | with gzip.open(labels_gz, 'rb') as f:
182 | labels = np.frombuffer(f.read(), np.uint8, offset=8)
183 |
184 | images = images.reshape(-1, 28, 28)
185 | images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
186 | assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
187 | assert labels.shape == (60000,) and labels.dtype == np.uint8
188 | assert np.min(images) == 0 and np.max(images) == 255
189 | assert np.min(labels) == 0 and np.max(labels) == 9
190 |
191 | max_idx = maybe_min(len(images), max_images)
192 |
193 | def iterate_images():
194 | for idx, img in enumerate(images):
195 | yield dict(img=img, label=int(labels[idx]))
196 | if idx >= max_idx - 1:
197 | break
198 |
199 | return max_idx, iterate_images()
200 |
201 | #----------------------------------------------------------------------------
202 |
203 | def make_transform(
204 | transform: Optional[str],
205 | output_width: Optional[int],
206 | output_height: Optional[int]
207 | ) -> Callable[[np.ndarray], Optional[np.ndarray]]:
208 | def scale(width, height, img):
209 | w = img.shape[1]
210 | h = img.shape[0]
211 | if width == w and height == h:
212 | return img
213 | img = PIL.Image.fromarray(img)
214 | ww = width if width is not None else w
215 | hh = height if height is not None else h
216 | img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS)
217 | return np.array(img)
218 |
219 | def center_crop(width, height, img):
220 | crop = np.min(img.shape[:2])
221 | img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
222 | if img.ndim == 2:
223 | img = img[:, :, np.newaxis].repeat(3, axis=2)
224 | img = PIL.Image.fromarray(img, 'RGB')
225 | img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
226 | return np.array(img)
227 |
228 | def center_crop_wide(width, height, img):
229 | ch = int(np.round(width * img.shape[0] / img.shape[1]))
230 | if img.shape[1] < width or ch < height:
231 | return None
232 |
233 | img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
234 | if img.ndim == 2:
235 | img = img[:, :, np.newaxis].repeat(3, axis=2)
236 | img = PIL.Image.fromarray(img, 'RGB')
237 | img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
238 | img = np.array(img)
239 |
240 | canvas = np.zeros([width, width, 3], dtype=np.uint8)
241 | canvas[(width - height) // 2 : (width + height) // 2, :] = img
242 | return canvas
243 |
244 | if transform is None:
245 | return functools.partial(scale, output_width, output_height)
246 | if transform == 'center-crop':
247 | if output_width is None or output_height is None:
248 |             raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
249 | return functools.partial(center_crop, output_width, output_height)
250 | if transform == 'center-crop-wide':
251 | if output_width is None or output_height is None:
252 | raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
253 | return functools.partial(center_crop_wide, output_width, output_height)
254 | assert False, 'unknown transform'
255 |
256 | #----------------------------------------------------------------------------
257 |
258 | def open_dataset(source, *, max_images: Optional[int]):
259 | if os.path.isdir(source):
260 | if source.rstrip('/').endswith('_lmdb'):
261 | return open_lmdb(source, max_images=max_images)
262 | else:
263 | return open_image_folder(source, max_images=max_images)
264 | elif os.path.isfile(source):
265 | if os.path.basename(source) == 'cifar-10-python.tar.gz':
266 | return open_cifar10(source, max_images=max_images)
267 | elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
268 | return open_mnist(source, max_images=max_images)
269 | elif file_ext(source) == 'zip':
270 | return open_image_zip(source, max_images=max_images)
271 | else:
272 | assert False, 'unknown archive type'
273 | else:
274 | raise click.ClickException(f'Missing input file or directory: {source}')
275 |
276 | #----------------------------------------------------------------------------
277 |
278 | def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
279 | dest_ext = file_ext(dest)
280 |
281 | if dest_ext == 'zip':
282 | if os.path.dirname(dest) != '':
283 | os.makedirs(os.path.dirname(dest), exist_ok=True)
284 | zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
285 | def zip_write_bytes(fname: str, data: Union[bytes, str]):
286 | zf.writestr(fname, data)
287 | return '', zip_write_bytes, zf.close
288 | else:
289 |         # If the output folder already exists, check that it is
290 |         # empty.
291 | #
292 | # Note: creating the output directory is not strictly
293 | # necessary as folder_write_bytes() also mkdirs, but it's better
294 | # to give an error message earlier in case the dest folder
295 | # somehow cannot be created.
296 | if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
297 | raise click.ClickException('--dest folder must be empty')
298 | os.makedirs(dest, exist_ok=True)
299 |
300 | def folder_write_bytes(fname: str, data: Union[bytes, str]):
301 | os.makedirs(os.path.dirname(fname), exist_ok=True)
302 | with open(fname, 'wb') as fout:
303 | if isinstance(data, str):
304 | data = data.encode('utf8')
305 | fout.write(data)
306 | return dest, folder_write_bytes, lambda: None
307 |
308 | #----------------------------------------------------------------------------
309 |
310 | @click.command()
311 | @click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
312 | @click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
313 | @click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
314 | @click.option('--transform', help='Input crop/resize mode', metavar='MODE', type=click.Choice(['center-crop', 'center-crop-wide']))
315 | @click.option('--resolution', help='Output resolution (e.g., 512x512)', metavar='WxH', type=parse_tuple)
316 |
317 | def main(
318 | source: str,
319 | dest: str,
320 | max_images: Optional[int],
321 | transform: Optional[str],
322 | resolution: Optional[Tuple[int, int]]
323 | ):
324 | """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
325 |
326 | The input dataset format is guessed from the --source argument:
327 |
328 | \b
329 | --source *_lmdb/ Load LSUN dataset
330 | --source cifar-10-python.tar.gz Load CIFAR-10 dataset
331 | --source train-images-idx3-ubyte.gz Load MNIST dataset
332 | --source path/ Recursively load all images from path/
333 | --source dataset.zip Recursively load all images from dataset.zip
334 |
335 | Specifying the output format and path:
336 |
337 | \b
338 | --dest /path/to/dir Save output files under /path/to/dir
339 | --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
340 |
341 | The output dataset format can be either an image folder or an uncompressed zip archive.
342 | Zip archives makes it easier to move datasets around file servers and clusters, and may
343 | offer better training performance on network file systems.
344 |
345 | Images within the dataset archive will be stored as uncompressed PNG.
346 |     Uncompressed PNGs can be decoded efficiently in the training loop.
347 |
348 | Class labels are stored in a file called 'dataset.json' that is stored at the
349 | dataset root folder. This file has the following structure:
350 |
351 | \b
352 | {
353 | "labels": [
354 | ["00000/img00000000.png",6],
355 | ["00000/img00000001.png",9],
356 |             ... repeated for every image in the dataset
357 | ["00049/img00049999.png",1]
358 | ]
359 | }
360 |
361 | If the 'dataset.json' file cannot be found, class labels are determined from
362 | top-level directory names.
363 |
364 | Image scale/crop and resolution requirements:
365 |
366 | Output images must be square-shaped and they must all have the same power-of-two
367 | dimensions.
368 |
369 |     To scale input images of arbitrary size to a specific width and height, use the
370 |     --resolution option. The output resolution will be the original input
371 |     resolution if --resolution is not specified, or the resolution given with
372 |     --resolution otherwise.
373 |
374 | Use the --transform=center-crop or --transform=center-crop-wide options to apply a
375 | center crop transform on the input image. These options should be used with the
376 | --resolution option. For example:
377 |
378 | \b
379 | python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
380 | --transform=center-crop-wide --resolution=512x384
381 | """
382 |
383 | PIL.Image.init()
384 |
385 | if dest == '':
386 | raise click.ClickException('--dest output filename or directory must not be an empty string')
387 |
388 | num_files, input_iter = open_dataset(source, max_images=max_images)
389 | archive_root_dir, save_bytes, close_dest = open_dest(dest)
390 |
391 | if resolution is None: resolution = (None, None)
392 | transform_image = make_transform(transform, *resolution)
393 |
394 | dataset_attrs = None
395 |
396 | labels = []
397 | for idx, image in tqdm(enumerate(input_iter), total=num_files):
398 | idx_str = f'{idx:08d}'
399 | archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
400 |
401 | # Apply crop and resize.
402 | img = transform_image(image['img'])
403 | if img is None:
404 | continue
405 |
406 | # Error check to require uniform image attributes across
407 | # the whole dataset.
408 | channels = img.shape[2] if img.ndim == 3 else 1
409 | cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0], 'channels': channels}
410 | if dataset_attrs is None:
411 | dataset_attrs = cur_image_attrs
412 | width = dataset_attrs['width']
413 | height = dataset_attrs['height']
414 | if width != height:
415 | raise click.ClickException(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
416 | if dataset_attrs['channels'] not in [1, 3]:
417 | raise click.ClickException('Input images must be stored as RGB or grayscale')
418 | if width != 2 ** int(np.floor(np.log2(width))):
419 | raise click.ClickException('Image width/height after scale and crop are required to be power-of-two')
420 | elif dataset_attrs != cur_image_attrs:
421 | err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
422 | raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
423 |
424 | # Save the image as an uncompressed PNG.
425 | img = PIL.Image.fromarray(img, {1: 'L', 3: 'RGB'}[channels])
426 | image_bits = io.BytesIO()
427 | img.save(image_bits, format='png', compress_level=0, optimize=False)
428 | save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
429 | labels.append([archive_fname, image['label']] if image['label'] is not None else None)
430 |
431 | metadata = {'labels': labels if all(x is not None for x in labels) else None}
432 | save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
433 | close_dest()
434 |
435 | #----------------------------------------------------------------------------
436 |
437 | if __name__ == "__main__":
438 | main()
439 |
440 | #----------------------------------------------------------------------------
441 |
--------------------------------------------------------------------------------
/edm/distributed_util.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 |
7 | import torch.distributed as dist
8 | from torch.multiprocessing import Process
9 | def init_processes(rank, size, fn, args):
10 | """ Initialize the distributed environment. """
11 | os.environ['MASTER_ADDR'] = args.master_address
12 | os.environ['MASTER_PORT'] = args.port
13 | torch.cuda.set_device(args.local_rank)
14 | dist.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=size)
15 | fn(args)
16 | dist.barrier()
17 | cleanup()
18 |
19 | def cleanup():
20 | dist.destroy_process_group()
21 |
22 | def average_grads(params):
23 | size = float(dist.get_world_size())
24 | for param in params:
25 | if param.requires_grad:
26 | # _average_tensor(param.grad, size)
27 | with torch.no_grad():
28 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
29 | param.grad.data /= size
30 |
31 | def average_params(params):
32 | size = float(dist.get_world_size())
33 | for param in params:
34 | # _average_tensor(param, size)
35 | with torch.no_grad():
36 | dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
37 | param.data /= size
38 |
39 | def sync_params(params):
40 | """
41 | Synchronize a sequence of Tensors across ranks from rank 0.
42 | """
43 | for p in params:
44 | with torch.no_grad():
45 | dist.broadcast(p, 0)
46 |
47 | def all_gather(tensor, log=None):
48 | if log: log.info("Gathering tensor across {} devices... ".format(dist.get_world_size()))
49 | gathered_tensors = [
50 | torch.zeros_like(tensor) for _ in range(dist.get_world_size())
51 | ]
52 | with torch.no_grad():
53 | dist.all_gather(gathered_tensors, tensor)
54 | return gathered_tensors
--------------------------------------------------------------------------------
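A hedged sketch of launching workers through init_processes: it assumes one GPU per rank (torch.cuda.set_device is called inside) and that `args` carries the master_address, port, and local_rank fields the helper reads; `run` is a placeholder entry point:

from types import SimpleNamespace
import torch.multiprocessing as mp

def run(args):
    pass  # placeholder for the per-rank training entry point

if __name__ == '__main__':
    world_size = 2
    ctx = mp.get_context('spawn')
    procs = []
    for rank in range(world_size):
        args = SimpleNamespace(master_address='127.0.0.1', port='6020', local_rank=rank)
        p = ctx.Process(target=init_processes, args=(rank, world_size, run, args))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()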
/edm/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | from .util import EasyDict, make_cache_dir_path
9 |
--------------------------------------------------------------------------------
/edm/dnnlib/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Miscellaneous utility classes and functions."""
9 |
10 | import ctypes
11 | import fnmatch
12 | import importlib
13 | import inspect
14 | import numpy as np
15 | import os
16 | import shutil
17 | import sys
18 | import types
19 | import io
20 | import pickle
21 | import re
22 | import requests
23 | import html
24 | import hashlib
25 | import glob
26 | import tempfile
27 | import urllib
28 | import urllib.request
29 | import uuid
30 |
31 | from distutils.util import strtobool
32 | from typing import Any, List, Tuple, Union, Optional
33 |
34 |
35 | # Util classes
36 | # ------------------------------------------------------------------------------------------
37 |
38 |
39 | class EasyDict(dict):
40 | """Convenience class that behaves like a dict but allows access with the attribute syntax."""
41 |
42 | def __getattr__(self, name: str) -> Any:
43 | try:
44 | return self[name]
45 | except KeyError:
46 | raise AttributeError(name)
47 |
48 | def __setattr__(self, name: str, value: Any) -> None:
49 | self[name] = value
50 |
51 | def __delattr__(self, name: str) -> None:
52 | del self[name]
53 |
54 |
55 | class Logger(object):
56 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
57 |
58 | def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
59 | self.file = None
60 |
61 | if file_name is not None:
62 | self.file = open(file_name, file_mode)
63 |
64 | self.should_flush = should_flush
65 | self.stdout = sys.stdout
66 | self.stderr = sys.stderr
67 |
68 | sys.stdout = self
69 | sys.stderr = self
70 |
71 | def __enter__(self) -> "Logger":
72 | return self
73 |
74 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
75 | self.close()
76 |
77 | def write(self, text: Union[str, bytes]) -> None:
78 | """Write text to stdout (and a file) and optionally flush."""
79 | if isinstance(text, bytes):
80 | text = text.decode()
81 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
82 | return
83 |
84 | if self.file is not None:
85 | self.file.write(text)
86 |
87 | self.stdout.write(text)
88 |
89 | if self.should_flush:
90 | self.flush()
91 |
92 | def flush(self) -> None:
93 | """Flush written text to both stdout and a file, if open."""
94 | if self.file is not None:
95 | self.file.flush()
96 |
97 | self.stdout.flush()
98 |
99 | def close(self) -> None:
100 | """Flush, close possible files, and remove stdout/stderr mirroring."""
101 | self.flush()
102 |
103 | # if using multiple loggers, prevent closing in wrong order
104 | if sys.stdout is self:
105 | sys.stdout = self.stdout
106 | if sys.stderr is self:
107 | sys.stderr = self.stderr
108 |
109 | if self.file is not None:
110 | self.file.close()
111 | self.file = None
112 |
113 |
114 | # Cache directories
115 | # ------------------------------------------------------------------------------------------
116 |
117 | _dnnlib_cache_dir = None
118 |
119 | def set_cache_dir(path: str) -> None:
120 | global _dnnlib_cache_dir
121 | _dnnlib_cache_dir = path
122 |
123 | def make_cache_dir_path(*paths: str) -> str:
124 | if _dnnlib_cache_dir is not None:
125 | return os.path.join(_dnnlib_cache_dir, *paths)
126 | if 'DNNLIB_CACHE_DIR' in os.environ:
127 | return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
128 | if 'HOME' in os.environ:
129 | return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
130 | if 'USERPROFILE' in os.environ:
131 | return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
132 | return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
133 |
134 | # Small util functions
135 | # ------------------------------------------------------------------------------------------
136 |
137 |
138 | def format_time(seconds: Union[int, float]) -> str:
139 | """Convert the seconds to human readable string with days, hours, minutes and seconds."""
140 | s = int(np.rint(seconds))
141 |
142 | if s < 60:
143 | return "{0}s".format(s)
144 | elif s < 60 * 60:
145 | return "{0}m {1:02}s".format(s // 60, s % 60)
146 | elif s < 24 * 60 * 60:
147 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
148 | else:
149 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
150 |
151 |
152 | def format_time_brief(seconds: Union[int, float]) -> str:
153 | """Convert the seconds to human readable string with days, hours, minutes and seconds."""
154 | s = int(np.rint(seconds))
155 |
156 | if s < 60:
157 | return "{0}s".format(s)
158 | elif s < 60 * 60:
159 | return "{0}m {1:02}s".format(s // 60, s % 60)
160 | elif s < 24 * 60 * 60:
161 | return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
162 | else:
163 | return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
164 |
165 |
166 | def ask_yes_no(question: str) -> bool:
167 | """Ask the user the question until the user inputs a valid answer."""
168 | while True:
169 | try:
170 | print("{0} [y/n]".format(question))
171 | return strtobool(input().lower())
172 | except ValueError:
173 | pass
174 |
175 |
176 | def tuple_product(t: Tuple) -> Any:
177 | """Calculate the product of the tuple elements."""
178 | result = 1
179 |
180 | for v in t:
181 | result *= v
182 |
183 | return result
184 |
185 |
186 | _str_to_ctype = {
187 | "uint8": ctypes.c_ubyte,
188 | "uint16": ctypes.c_uint16,
189 | "uint32": ctypes.c_uint32,
190 | "uint64": ctypes.c_uint64,
191 | "int8": ctypes.c_byte,
192 | "int16": ctypes.c_int16,
193 | "int32": ctypes.c_int32,
194 | "int64": ctypes.c_int64,
195 | "float32": ctypes.c_float,
196 | "float64": ctypes.c_double
197 | }
198 |
199 |
200 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
201 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
202 | type_str = None
203 |
204 | if isinstance(type_obj, str):
205 | type_str = type_obj
206 | elif hasattr(type_obj, "__name__"):
207 | type_str = type_obj.__name__
208 | elif hasattr(type_obj, "name"):
209 | type_str = type_obj.name
210 | else:
211 | raise RuntimeError("Cannot infer type name from input")
212 |
213 | assert type_str in _str_to_ctype.keys()
214 |
215 | my_dtype = np.dtype(type_str)
216 | my_ctype = _str_to_ctype[type_str]
217 |
218 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
219 |
220 | return my_dtype, my_ctype
221 |
222 |
223 | def is_pickleable(obj: Any) -> bool:
224 | try:
225 | with io.BytesIO() as stream:
226 | pickle.dump(obj, stream)
227 | return True
228 | except:
229 | return False
230 |
231 |
232 | # Functionality to import modules/objects by name, and call functions by name
233 | # ------------------------------------------------------------------------------------------
234 |
235 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
236 | """Searches for the underlying module behind the name to some python object.
237 | Returns the module and the object name (original name with module part removed)."""
238 |
239 | # allow convenience shorthands, substitute them by full names
240 | obj_name = re.sub("^np.", "numpy.", obj_name)
241 | obj_name = re.sub("^tf.", "tensorflow.", obj_name)
242 |
243 | # list alternatives for (module_name, local_obj_name)
244 | parts = obj_name.split(".")
245 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
246 |
247 | # try each alternative in turn
248 | for module_name, local_obj_name in name_pairs:
249 | try:
250 | module = importlib.import_module(module_name) # may raise ImportError
251 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
252 | return module, local_obj_name
253 | except:
254 | pass
255 |
256 | # maybe some of the modules themselves contain errors?
257 | for module_name, _local_obj_name in name_pairs:
258 | try:
259 | importlib.import_module(module_name) # may raise ImportError
260 | except ImportError:
261 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
262 | raise
263 |
264 | # maybe the requested attribute is missing?
265 | for module_name, local_obj_name in name_pairs:
266 | try:
267 | module = importlib.import_module(module_name) # may raise ImportError
268 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
269 | except ImportError:
270 | pass
271 |
272 | # we are out of luck, but we have no idea why
273 | raise ImportError(obj_name)
274 |
275 |
276 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
277 | """Traverses the object name and returns the last (rightmost) python object."""
278 | if obj_name == '':
279 | return module
280 | obj = module
281 | for part in obj_name.split("."):
282 | obj = getattr(obj, part)
283 | return obj
284 |
285 |
286 | def get_obj_by_name(name: str) -> Any:
287 | """Finds the python object with the given name."""
288 | module, obj_name = get_module_from_obj_name(name)
289 | return get_obj_from_module(module, obj_name)
290 |
291 |
292 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
293 | """Finds the python object with the given name and calls it as a function."""
294 | assert func_name is not None
295 | func_obj = get_obj_by_name(func_name)
296 | assert callable(func_obj)
297 | return func_obj(*args, **kwargs)
298 |
299 |
300 | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
301 | """Finds the python class with the given name and constructs it with the given arguments."""
302 | return call_func_by_name(*args, func_name=class_name, **kwargs)
303 |
304 |
305 | def get_module_dir_by_obj_name(obj_name: str) -> str:
306 | """Get the directory path of the module containing the given object name."""
307 | module, _ = get_module_from_obj_name(obj_name)
308 | return os.path.dirname(inspect.getfile(module))
309 |
310 |
311 | def is_top_level_function(obj: Any) -> bool:
312 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
313 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
314 |
315 |
316 | def get_top_level_function_name(obj: Any) -> str:
317 | """Return the fully-qualified name of a top-level function."""
318 | assert is_top_level_function(obj)
319 | module = obj.__module__
320 | if module == '__main__':
321 | module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
322 | return module + "." + obj.__name__
323 |
324 |
325 | # File system helpers
326 | # ------------------------------------------------------------------------------------------
327 |
328 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
329 | """List all files recursively in a given directory while ignoring given file and directory names.
330 | Returns list of tuples containing both absolute and relative paths."""
331 | assert os.path.isdir(dir_path)
332 | base_name = os.path.basename(os.path.normpath(dir_path))
333 |
334 | if ignores is None:
335 | ignores = []
336 |
337 | result = []
338 |
339 | for root, dirs, files in os.walk(dir_path, topdown=True):
340 | for ignore_ in ignores:
341 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
342 |
343 | # dirs need to be edited in-place
344 | for d in dirs_to_remove:
345 | dirs.remove(d)
346 |
347 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
348 |
349 | absolute_paths = [os.path.join(root, f) for f in files]
350 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
351 |
352 | if add_base_to_relative:
353 | relative_paths = [os.path.join(base_name, p) for p in relative_paths]
354 |
355 | assert len(absolute_paths) == len(relative_paths)
356 | result += zip(absolute_paths, relative_paths)
357 |
358 | return result
359 |
360 |
361 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
362 | """Takes in a list of tuples of (src, dst) paths and copies files.
363 | Will create all necessary directories."""
364 | for file in files:
365 | target_dir_name = os.path.dirname(file[1])
366 |
367 | # will create all intermediate-level directories
368 | if not os.path.exists(target_dir_name):
369 | os.makedirs(target_dir_name)
370 |
371 | shutil.copyfile(file[0], file[1])
372 |
373 |
374 | # URL helpers
375 | # ------------------------------------------------------------------------------------------
376 |
377 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
378 | """Determine whether the given object is a valid URL string."""
379 | if not isinstance(obj, str) or not "://" in obj:
380 | return False
381 | if allow_file_urls and obj.startswith('file://'):
382 | return True
383 | try:
384 | res = requests.compat.urlparse(obj)
385 | if not res.scheme or not res.netloc or not "." in res.netloc:
386 | return False
387 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
388 | if not res.scheme or not res.netloc or not "." in res.netloc:
389 | return False
390 | except:
391 | return False
392 | return True
393 |
394 |
395 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
396 | """Download the given URL and return a binary-mode file object to access the data."""
397 | assert num_attempts >= 1
398 | assert not (return_filename and (not cache))
399 |
400 |     # Doesn't look like a URL scheme, so interpret it as a local filename.
401 | if not re.match('^[a-z]+://', url):
402 | return url if return_filename else open(url, "rb")
403 |
404 | # Handle file URLs. This code handles unusual file:// patterns that
405 | # arise on Windows:
406 | #
407 | # file:///c:/foo.txt
408 | #
409 | # which would translate to a local '/c:/foo.txt' filename that's
410 | # invalid. Drop the forward slash for such pathnames.
411 | #
412 | # If you touch this code path, you should test it on both Linux and
413 | # Windows.
414 | #
415 | # Some internet resources suggest using urllib.request.url2pathname() but
416 |     # that converts forward slashes to backslashes and this causes
417 | # its own set of problems.
418 | if url.startswith('file://'):
419 | filename = urllib.parse.urlparse(url).path
420 | if re.match(r'^/[a-zA-Z]:', filename):
421 | filename = filename[1:]
422 | return filename if return_filename else open(filename, "rb")
423 |
424 | assert is_url(url)
425 |
426 | # Lookup from cache.
427 | if cache_dir is None:
428 | cache_dir = make_cache_dir_path('downloads')
429 |
430 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
431 | if cache:
432 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
433 | if len(cache_files) == 1:
434 | filename = cache_files[0]
435 | return filename if return_filename else open(filename, "rb")
436 |
437 | # Download.
438 | url_name = None
439 | url_data = None
440 | with requests.Session() as session:
441 | if verbose:
442 | print("Downloading %s ..." % url, end="", flush=True)
443 | for attempts_left in reversed(range(num_attempts)):
444 | try:
445 | with session.get(url) as res:
446 | res.raise_for_status()
447 | if len(res.content) == 0:
448 | raise IOError("No data received")
449 |
450 | if len(res.content) < 8192:
451 | content_str = res.content.decode("utf-8")
452 | if "download_warning" in res.headers.get("Set-Cookie", ""):
453 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
454 | if len(links) == 1:
455 | url = requests.compat.urljoin(url, links[0])
456 | raise IOError("Google Drive virus checker nag")
457 | if "Google Drive - Quota exceeded" in content_str:
458 | raise IOError("Google Drive download quota exceeded -- please try again later")
459 |
460 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
461 | url_name = match[1] if match else url
462 | url_data = res.content
463 | if verbose:
464 | print(" done")
465 | break
466 | except KeyboardInterrupt:
467 | raise
468 | except:
469 | if not attempts_left:
470 | if verbose:
471 | print(" failed")
472 | raise
473 | if verbose:
474 | print(".", end="", flush=True)
475 |
476 | # Save to cache.
477 | if cache:
478 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
479 | safe_name = safe_name[:min(len(safe_name), 128)]
480 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
481 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
482 | os.makedirs(cache_dir, exist_ok=True)
483 | with open(temp_file, "wb") as f:
484 | f.write(url_data)
485 | os.replace(temp_file, cache_file) # atomic
486 | if return_filename:
487 | return cache_file
488 |
489 | # Return data as file object.
490 | assert not return_filename
491 | return io.BytesIO(url_data)
492 |
--------------------------------------------------------------------------------
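Minimal usage sketches for two of the helpers above, using only behavior defined in this file:

d = EasyDict(lr=2e-4, seed=42)
print(d.lr, d['seed'])            # attribute and key access are interchangeable
d.epochs = 10                     # attribute assignment writes through to the dict

print(format_time(3661))          # '1h 01m 01s'
print(format_time_brief(90_000))  # '1d 01h'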
/edm/fid.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Script for calculating Frechet Inception Distance (FID)."""
9 | import os
10 | import click
11 | import tqdm
12 | import pickle
13 | import numpy as np
14 | import scipy.linalg
15 | import torch
16 | import sys
17 | import os
18 | sys.path.append(os.path.abspath('edm'))
19 | import dnnlib
20 | from torch_utils import distributed as dist
21 | import zipfile
22 | import PIL.Image
23 | try:
24 | import pyspng
25 | except ImportError:
26 | pyspng = None
27 |
28 | #----------------------------------------------------------------------------
29 | # Abstract base class for datasets.
30 |
31 | class Dataset(torch.utils.data.Dataset):
32 | def __init__(self,
33 | name, # Name of the dataset.
34 | raw_shape, # Shape of the raw image data (NCHW).
35 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
36 | use_labels = False, # Enable conditioning labels? False = label dimension is zero.
37 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
38 | random_seed = 0, # Random seed to use when applying max_size.
39 | cache = False, # Cache images in CPU memory?
40 | ):
41 | self._name = name
42 | self._raw_shape = list(raw_shape)
43 | self._use_labels = use_labels
44 | self._cache = cache
45 | self._cached_images = dict() # {raw_idx: np.ndarray, ...}
46 | self._raw_labels = None
47 | self._label_shape = None
48 |
49 | # Apply max_size.
50 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
51 | if (max_size is not None) and (self._raw_idx.size > max_size):
52 | np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
53 | self._raw_idx = np.sort(self._raw_idx[:max_size])
54 |
55 | # Apply xflip.
56 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
57 | if xflip:
58 | self._raw_idx = np.tile(self._raw_idx, 2)
59 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
60 |
61 | def _get_raw_labels(self):
62 | if self._raw_labels is None:
63 | self._raw_labels = self._load_raw_labels() if self._use_labels else None
64 | if self._raw_labels is None:
65 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
66 | assert isinstance(self._raw_labels, np.ndarray)
67 | assert self._raw_labels.shape[0] == self._raw_shape[0]
68 | assert self._raw_labels.dtype in [np.float32, np.int64]
69 | if self._raw_labels.dtype == np.int64:
70 | assert self._raw_labels.ndim == 1
71 | assert np.all(self._raw_labels >= 0)
72 | return self._raw_labels
73 |
74 | def close(self): # to be overridden by subclass
75 | pass
76 |
77 | def _load_raw_image(self, raw_idx): # to be overridden by subclass
78 | raise NotImplementedError
79 |
80 | def _load_raw_labels(self): # to be overridden by subclass
81 | raise NotImplementedError
82 |
83 | def __getstate__(self):
84 | return dict(self.__dict__, _raw_labels=None)
85 |
86 | def __del__(self):
87 | try:
88 | self.close()
89 | except:
90 | pass
91 |
92 | def __len__(self):
93 | return self._raw_idx.size
94 |
95 | def __getitem__(self, idx):
96 | raw_idx = self._raw_idx[idx]
97 | image = self._cached_images.get(raw_idx, None)
98 | if image is None:
99 | image = self._load_raw_image(raw_idx)
100 | if self._cache:
101 | self._cached_images[raw_idx] = image
102 | assert isinstance(image, np.ndarray)
103 | assert list(image.shape) == self.image_shape
104 | assert image.dtype == np.uint8
105 | if self._xflip[idx]:
106 | assert image.ndim == 3 # CHW
107 | image = image[:, :, ::-1]
108 | return image.copy(), self.get_label(idx)
109 |
110 | def get_label(self, idx):
111 | label = self._get_raw_labels()[self._raw_idx[idx]]
112 | if label.dtype == np.int64:
113 | onehot = np.zeros(self.label_shape, dtype=np.float32)
114 | onehot[label] = 1
115 | label = onehot
116 | return label.copy()
117 |
118 | def get_details(self, idx):
119 | d = dnnlib.EasyDict()
120 | d.raw_idx = int(self._raw_idx[idx])
121 | d.xflip = (int(self._xflip[idx]) != 0)
122 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
123 | return d
124 |
125 | @property
126 | def name(self):
127 | return self._name
128 |
129 | @property
130 | def image_shape(self):
131 | return list(self._raw_shape[1:])
132 |
133 | @property
134 | def num_channels(self):
135 | assert len(self.image_shape) == 3 # CHW
136 | return self.image_shape[0]
137 |
138 | @property
139 | def resolution(self):
140 | assert len(self.image_shape) == 3 # CHW
141 | assert self.image_shape[1] == self.image_shape[2]
142 | return self.image_shape[1]
143 |
144 | @property
145 | def label_shape(self):
146 | if self._label_shape is None:
147 | raw_labels = self._get_raw_labels()
148 | if raw_labels.dtype == np.int64:
149 | self._label_shape = [int(np.max(raw_labels)) + 1]
150 | else:
151 | self._label_shape = raw_labels.shape[1:]
152 | return list(self._label_shape)
153 |
154 | @property
155 | def label_dim(self):
156 | assert len(self.label_shape) == 1
157 | return self.label_shape[0]
158 |
159 | @property
160 | def has_labels(self):
161 | return any(x != 0 for x in self.label_shape)
162 |
163 | @property
164 | def has_onehot_labels(self):
165 | return self._get_raw_labels().dtype == np.int64
166 |
167 | #----------------------------------------------------------------------------
168 | # Dataset subclass that loads images recursively from the specified directory
169 | # or ZIP file.
170 |
171 | class ImageFolderDataset(Dataset):
172 | def __init__(self,
173 | path, # Path to directory or zip.
174 | resolution = None, # Ensure specific resolution, None = highest available.
175 | use_pyspng = True, # Use pyspng if available?
176 | **super_kwargs, # Additional arguments for the Dataset base class.
177 | ):
178 | self._path = path
179 | self._use_pyspng = use_pyspng
180 | self._zipfile = None
181 |
182 | if os.path.isdir(self._path):
183 | self._type = 'dir'
184 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
185 | elif self._file_ext(self._path) == '.zip':
186 | self._type = 'zip'
187 | self._all_fnames = set(self._get_zipfile().namelist())
188 | else:
189 | raise IOError('Path must point to a directory or zip')
190 |
191 | PIL.Image.init()
192 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
193 | if len(self._image_fnames) == 0:
194 | raise IOError('No image files found in the specified path')
195 |
196 | name = os.path.splitext(os.path.basename(self._path))[0]
197 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
198 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
199 | raise IOError('Image files do not match the specified resolution')
200 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
201 |
202 | @staticmethod
203 | def _file_ext(fname):
204 | return os.path.splitext(fname)[1].lower()
205 |
206 | def _get_zipfile(self):
207 | assert self._type == 'zip'
208 | if self._zipfile is None:
209 | self._zipfile = zipfile.ZipFile(self._path)
210 | return self._zipfile
211 |
212 | def _open_file(self, fname):
213 | if self._type == 'dir':
214 | return open(os.path.join(self._path, fname), 'rb')
215 | if self._type == 'zip':
216 | return self._get_zipfile().open(fname, 'r')
217 | return None
218 |
219 | def close(self):
220 | try:
221 | if self._zipfile is not None:
222 | self._zipfile.close()
223 | finally:
224 | self._zipfile = None
225 |
226 | def __getstate__(self):
227 | return dict(super().__getstate__(), _zipfile=None)
228 |
229 | def _load_raw_image(self, raw_idx):
230 | fname = self._image_fnames[raw_idx]
231 | with self._open_file(fname) as f:
232 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png':
233 | image = pyspng.load(f.read())
234 | else:
235 | image = np.array(PIL.Image.open(f))
236 | if image.ndim == 2:
237 | image = image[:, :, np.newaxis] # HW => HWC
238 | image = image.transpose(2, 0, 1) # HWC => CHW
239 | return image
240 |
241 | def _load_raw_labels(self):
242 | fname = 'dataset.json'
243 | if fname not in self._all_fnames:
244 | return None
245 | with self._open_file(fname) as f:
246 | labels = json.load(f)['labels']
247 | if labels is None:
248 | return None
249 | labels = dict(labels)
250 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
251 | labels = np.array(labels)
252 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
253 | return labels
254 |
255 | #----------------------------------------------------------------------------
256 |
257 |
258 | #----------------------------------------------------------------------------
259 |
260 | def calculate_inception_stats(
261 | image_path, num_expected=None, seed=0, max_batch_size=64,
262 | num_workers=4, prefetch_factor=2, device=torch.device('cuda'),
263 | ):
264 | # Rank 0 goes first.
265 | if dist.get_rank() != 0:
266 | torch.distributed.barrier()
267 |
268 | # Load Inception-v3 model.
269 | # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
270 | dist.print0('Loading Inception-v3 model...')
271 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
272 | detector_kwargs = dict(return_features=True)
273 | feature_dim = 2048
274 | with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f:
275 | detector_net = pickle.load(f).to(device)
276 |
277 | # List images.
278 | dist.print0(f'Loading images from "{image_path}"...')
279 | dataset_obj = ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)
280 | if num_expected is not None and len(dataset_obj) < num_expected:
281 | raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}')
282 | if len(dataset_obj) < 2:
283 | raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics')
284 |
285 | # Other ranks follow.
286 | if dist.get_rank() == 0:
287 | torch.distributed.barrier()
288 |
289 | # Divide images into batches.
290 | num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
291 | all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches)
292 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
293 | data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor)
294 |
295 | # Accumulate statistics.
296 | dist.print0(f'Calculating statistics for {len(dataset_obj)} images...')
297 | mu = torch.zeros([feature_dim], dtype=torch.float64, device=device)
298 | sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
299 | for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)):
300 | torch.distributed.barrier()
301 | if images.shape[0] == 0:
302 | continue
303 | if images.shape[1] == 1:
304 | images = images.repeat([1, 3, 1, 1])
305 | features = detector_net(images.to(device), **detector_kwargs).to(torch.float64)
306 | mu += features.sum(0)
307 | sigma += features.T @ features
308 |
309 | # Calculate grand totals.
310 | torch.distributed.all_reduce(mu)
311 | torch.distributed.all_reduce(sigma)
312 | mu /= len(dataset_obj)
313 | sigma -= mu.ger(mu) * len(dataset_obj)
314 | sigma /= len(dataset_obj) - 1
315 | return mu.cpu().numpy(), sigma.cpu().numpy()
316 |
317 | #----------------------------------------------------------------------------
318 |
319 | def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
320 | m = np.square(mu - mu_ref).sum()
321 | s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
322 | fid = m + np.trace(sigma + sigma_ref - s * 2)
323 | return float(np.real(fid))
324 |
325 | #----------------------------------------------------------------------------
326 |
327 | @click.group()
328 | def main():
329 | """Calculate Frechet Inception Distance (FID).
330 |
331 | Examples:
332 |
333 | \b
334 | # Generate 50000 images and save them as fid-tmp/*/*.png
335 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\
336 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
337 |
338 | \b
339 | # Calculate FID
340 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
341 |
342 | \b
343 | # Compute dataset reference statistics
344 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz
345 | """
346 |
347 | #----------------------------------------------------------------------------
348 |
349 | @main.command()
350 | @click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True)
351 | @click.option('--ref', 'ref_path', help='Dataset reference statistics', metavar='NPZ|URL', type=str, required=True)
352 | @click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True)
353 | @click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True)
354 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
355 |
356 | def calc(image_path, ref_path, num_expected, seed, batch):
357 | """Calculate FID for a given set of images."""
358 | torch.multiprocessing.set_start_method('spawn')
359 | dist.init()
360 |
361 | dist.print0(f'Loading dataset reference statistics from "{ref_path}"...')
362 | ref = None
363 | if dist.get_rank() == 0:
364 | with dnnlib.util.open_url(ref_path) as f:
365 | ref = dict(np.load(f))
366 |
367 | mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch)
368 | dist.print0('Calculating FID...')
369 | if dist.get_rank() == 0:
370 | fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma'])
371 | print(f'{fid:g}')
372 | torch.distributed.barrier()
373 |
374 | #----------------------------------------------------------------------------
375 |
376 | @main.command()
377 | @click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True)
378 | @click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True)
379 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
380 |
381 | def ref(dataset_path, dest_path, batch):
382 | """Calculate dataset reference statistics needed by 'calc'."""
383 | torch.multiprocessing.set_start_method('spawn')
384 | dist.init()
385 |
386 | mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch)
387 | dist.print0(f'Saving dataset reference statistics to "{dest_path}"...')
388 | if dist.get_rank() == 0:
389 | if os.path.dirname(dest_path):
390 | os.makedirs(os.path.dirname(dest_path), exist_ok=True)
391 | np.savez(dest_path, mu=mu, sigma=sigma)
392 |
393 | torch.distributed.barrier()
394 | dist.print0('Done.')
395 |
396 | #----------------------------------------------------------------------------
397 |
398 | if __name__ == "__main__":
399 | main()
400 |
401 | # python dataset_tool.py --source=dataset/cifar-10-python.tar.gz --dest=dataset/cifar10-32x32.zip
402 | # python fid.py ref --data=datasets/cifar10-32x32.zip --dest=fid-refs/cifar10-32x32.npz
403 | #----------------------------------------------------------------------------
--------------------------------------------------------------------------------
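Note: `calculate_fid_from_inception_stats` implements the closed-form Frechet distance between two Gaussians, FID = ||mu - mu_ref||^2 + Tr(sigma + sigma_ref - 2 (sigma sigma_ref)^(1/2)). A minimal sanity-check sketch using synthetic statistics rather than Inception features; the distance should vanish when the two sets of statistics coincide:

    import numpy as np
    import scipy.linalg

    def fid_from_stats(mu, sigma, mu_ref, sigma_ref):
        # Same computation as calculate_fid_from_inception_stats above.
        m = np.square(mu - mu_ref).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
        return float(np.real(m + np.trace(sigma + sigma_ref - s * 2)))

    rng = np.random.default_rng(0)
    mu = rng.normal(size=4)
    a = rng.normal(size=(4, 4))
    sigma = a @ a.T + np.eye(4)                  # symmetric positive definite
    print(fid_from_stats(mu, sigma, mu, sigma))  # ~0 up to numerical error
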
/edm/logger.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for I2SB. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import os
9 | import time
10 | import logging
11 | from rich.console import Console
12 | from rich.logging import RichHandler
13 |
14 | def get_time(sec):
15 | h = int(sec//3600)
16 | m = int((sec//60)%60)
17 | s = int(sec%60)
18 | return h,m,s
19 |
20 | class TimeFilter(logging.Filter):
21 |
22 | def filter(self, record):
23 | try:
24 | start = self.start
25 | except AttributeError:
26 | start = self.start = time.time()
27 |
28 | time_elapsed = get_time(time.time() - start)
29 |
30 | record.relative = "{0}:{1:02d}:{2:02d}".format(*time_elapsed)
31 |
32 | # self.last = record.relativeCreated/1000.0
33 | return True
34 |
35 | class Logger(object):
36 | def __init__(self, rank=0, log_dir=".log"):
37 | # Other libraries may configure logging before this point.
38 | # Reloading the logging module discards any configuration they have set.
39 | from importlib import reload
40 | reload(logging)
41 | self.rank = rank
42 | if self.rank == 0:
43 | os.makedirs(log_dir, exist_ok=True)
44 |
45 | log_file = open(os.path.join(log_dir, "log.txt"), "w")
46 | file_console = Console(file=log_file, width=150)
47 | logging.basicConfig(
48 | level=logging.INFO,
49 | format="(%(relative)s) %(message)s",
50 | datefmt="[%X]",
51 | force=True,
52 | handlers=[
53 | RichHandler(show_path=False),
54 | RichHandler(console=file_console, show_path=False)
55 | ],
56 | )
57 | # https://stackoverflow.com/questions/31521859/python-logging-module-time-since-last-log
58 | log = logging.getLogger()
59 | [hndl.addFilter(TimeFilter()) for hndl in log.handlers]
60 |
61 | def info(self, string, *args):
62 | if self.rank == 0:
63 | logging.info(string, *args)
64 |
65 | def warning(self, string, *args):
66 | if self.rank == 0:
67 | logging.warning(string, *args)
68 |
69 | def error(self, string, *args):
70 | if self.rank == 0:
71 | logging.error(string, *args)
72 |
--------------------------------------------------------------------------------
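Note: a minimal usage sketch of the logger (the log directory is illustrative). On rank 0 the messages go to both the console and `<log_dir>/log.txt`, each prefixed with the elapsed time injected by `TimeFilter`; on other ranks the calls are no-ops:

    from edm.logger import Logger  # assuming the repository root is on sys.path

    logger = Logger(rank=0, log_dir=".log/example")
    logger.info("training started")
    logger.warning("checkpoint %s not found, starting from scratch", "latest.pt")
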
/edm/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/edm/torch_utils/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import os
9 | import torch
10 | from . import training_stats
11 |
12 | #----------------------------------------------------------------------------
13 |
14 | def init():
15 | if 'MASTER_ADDR' not in os.environ:
16 | os.environ['MASTER_ADDR'] = 'localhost'
17 | if 'MASTER_PORT' not in os.environ:
18 | os.environ['MASTER_PORT'] = '29500'
19 | if 'RANK' not in os.environ:
20 | os.environ['RANK'] = '0'
21 | if 'LOCAL_RANK' not in os.environ:
22 | os.environ['LOCAL_RANK'] = '0'
23 | if 'WORLD_SIZE' not in os.environ:
24 | os.environ['WORLD_SIZE'] = '1'
25 |
26 | backend = 'gloo' if os.name == 'nt' else 'nccl'
27 | torch.distributed.init_process_group(backend=backend, init_method='env://')
28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
29 |
30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None
31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device)
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def get_rank():
36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
37 |
38 | #----------------------------------------------------------------------------
39 |
40 | def get_world_size():
41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
42 |
43 | #----------------------------------------------------------------------------
44 |
45 | def should_stop():
46 | return False
47 |
48 | #----------------------------------------------------------------------------
49 |
50 | def update_progress(cur, total):
51 | _ = cur, total
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | def print0(*args, **kwargs):
56 | if get_rank() == 0:
57 | print(*args, **kwargs)
58 |
59 | #----------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
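Note: a minimal sketch of how these helpers are typically combined (launch command illustrative; the default `nccl` backend requires CUDA GPUs):

    # run with: torchrun --standalone --nproc_per_node=2 example.py
    import torch
    from torch_utils import distributed as dist  # assuming edm/ is on sys.path

    dist.init()  # fills in env defaults, initializes the process group, pins the GPU
    dist.print0(f"world size = {dist.get_world_size()}")  # printed by rank 0 only
    x = torch.ones(1, device="cuda") * dist.get_rank()
    torch.distributed.all_reduce(x)  # sums the rank ids across processes
    dist.print0(f"sum of ranks = {x.item()}")
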
/edm/torch_utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import re
9 | import contextlib
10 | import numpy as np
11 | import torch
12 | import warnings
13 | try:
14 |     import dnnlib
15 | except ImportError:
16 |     import edm.dnnlib as dnnlib
17 |
18 | #----------------------------------------------------------------------------
19 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
20 | # same constant is used multiple times.
21 |
22 | _constant_cache = dict()
23 |
24 | def constant(value, shape=None, dtype=None, device=None, memory_format=None):
25 | value = np.asarray(value)
26 | if shape is not None:
27 | shape = tuple(shape)
28 | if dtype is None:
29 | dtype = torch.get_default_dtype()
30 | if device is None:
31 | device = torch.device('cpu')
32 | if memory_format is None:
33 | memory_format = torch.contiguous_format
34 |
35 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
36 | tensor = _constant_cache.get(key, None)
37 | if tensor is None:
38 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
39 | if shape is not None:
40 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
41 | tensor = tensor.contiguous(memory_format=memory_format)
42 | _constant_cache[key] = tensor
43 | return tensor
44 |
45 | #----------------------------------------------------------------------------
46 | # Replace NaN/Inf with specified numerical values.
47 |
48 | try:
49 | nan_to_num = torch.nan_to_num # 1.8.0a0
50 | except AttributeError:
51 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
52 | assert isinstance(input, torch.Tensor)
53 | if posinf is None:
54 | posinf = torch.finfo(input.dtype).max
55 | if neginf is None:
56 | neginf = torch.finfo(input.dtype).min
57 | assert nan == 0
58 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
59 |
60 | #----------------------------------------------------------------------------
61 | # Symbolic assert.
62 |
63 | try:
64 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
65 | except AttributeError:
66 | symbolic_assert = torch.Assert # 1.7.0
67 |
68 | #----------------------------------------------------------------------------
69 | # Context manager to temporarily suppress known warnings in torch.jit.trace().
70 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
71 |
72 | @contextlib.contextmanager
73 | def suppress_tracer_warnings():
74 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
75 | warnings.filters.insert(0, flt)
76 | yield
77 | warnings.filters.remove(flt)
78 |
79 | #----------------------------------------------------------------------------
80 | # Assert that the shape of a tensor matches the given list of integers.
81 | # None indicates that the size of a dimension is allowed to vary.
82 | # Performs symbolic assertion when used in torch.jit.trace().
83 |
84 | def assert_shape(tensor, ref_shape):
85 | if tensor.ndim != len(ref_shape):
86 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
87 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
88 | if ref_size is None:
89 | pass
90 | elif isinstance(ref_size, torch.Tensor):
91 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
92 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
93 | elif isinstance(size, torch.Tensor):
94 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
95 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
96 | elif size != ref_size:
97 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
98 |
99 | #----------------------------------------------------------------------------
100 | # Function decorator that calls torch.autograd.profiler.record_function().
101 |
102 | def profiled_function(fn):
103 | def decorator(*args, **kwargs):
104 | with torch.autograd.profiler.record_function(fn.__name__):
105 | return fn(*args, **kwargs)
106 | decorator.__name__ = fn.__name__
107 | return decorator
108 |
109 | #----------------------------------------------------------------------------
110 | # Sampler for torch.utils.data.DataLoader that loops over the dataset
111 | # indefinitely, shuffling items as it goes.
112 |
113 | class InfiniteSampler(torch.utils.data.Sampler):
114 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
115 | assert len(dataset) > 0
116 | assert num_replicas > 0
117 | assert 0 <= rank < num_replicas
118 | assert 0 <= window_size <= 1
119 | super().__init__(dataset)
120 | self.dataset = dataset
121 | self.rank = rank
122 | self.num_replicas = num_replicas
123 | self.shuffle = shuffle
124 | self.seed = seed
125 | self.window_size = window_size
126 |
127 | def __iter__(self):
128 | order = np.arange(len(self.dataset))
129 | rnd = None
130 | window = 0
131 | if self.shuffle:
132 | rnd = np.random.RandomState(self.seed)
133 | rnd.shuffle(order)
134 | window = int(np.rint(order.size * self.window_size))
135 |
136 | idx = 0
137 | while True:
138 | i = idx % order.size
139 | if idx % self.num_replicas == self.rank:
140 | yield order[i]
141 | if window >= 2:
142 | j = (i - rnd.randint(window)) % order.size
143 | order[i], order[j] = order[j], order[i]
144 | idx += 1
145 |
146 | #----------------------------------------------------------------------------
147 | # Utilities for operating with torch.nn.Module parameters and buffers.
148 |
149 | def params_and_buffers(module):
150 | assert isinstance(module, torch.nn.Module)
151 | return list(module.parameters()) + list(module.buffers())
152 |
153 | def named_params_and_buffers(module):
154 | assert isinstance(module, torch.nn.Module)
155 | return list(module.named_parameters()) + list(module.named_buffers())
156 |
157 | @torch.no_grad()
158 | def copy_params_and_buffers(src_module, dst_module, require_all=False):
159 | assert isinstance(src_module, torch.nn.Module)
160 | assert isinstance(dst_module, torch.nn.Module)
161 | src_tensors = dict(named_params_and_buffers(src_module))
162 | for name, tensor in named_params_and_buffers(dst_module):
163 | assert (name in src_tensors) or (not require_all)
164 | if name in src_tensors:
165 | tensor.copy_(src_tensors[name])
166 |
167 | #----------------------------------------------------------------------------
168 | # Context manager for easily enabling/disabling DistributedDataParallel
169 | # synchronization.
170 |
171 | @contextlib.contextmanager
172 | def ddp_sync(module, sync):
173 | assert isinstance(module, torch.nn.Module)
174 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
175 | yield
176 | else:
177 | with module.no_sync():
178 | yield
179 |
180 | #----------------------------------------------------------------------------
181 | # Check DistributedDataParallel consistency across processes.
182 |
183 | def check_ddp_consistency(module, ignore_regex=None):
184 | assert isinstance(module, torch.nn.Module)
185 | for name, tensor in named_params_and_buffers(module):
186 | fullname = type(module).__name__ + '.' + name
187 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
188 | continue
189 | tensor = tensor.detach()
190 | if tensor.is_floating_point():
191 | tensor = nan_to_num(tensor)
192 | other = tensor.clone()
193 | torch.distributed.broadcast(tensor=other, src=0)
194 | assert (tensor == other).all(), fullname
195 |
196 | #----------------------------------------------------------------------------
197 | # Print summary table of module hierarchy.
198 |
199 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
200 | assert isinstance(module, torch.nn.Module)
201 | assert not isinstance(module, torch.jit.ScriptModule)
202 | assert isinstance(inputs, (tuple, list))
203 |
204 | # Register hooks.
205 | entries = []
206 | nesting = [0]
207 | def pre_hook(_mod, _inputs):
208 | nesting[0] += 1
209 | def post_hook(mod, _inputs, outputs):
210 | nesting[0] -= 1
211 | if nesting[0] <= max_nesting:
212 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
213 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
214 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
215 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
216 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
217 |
218 | # Run module.
219 | outputs = module(*inputs)
220 | for hook in hooks:
221 | hook.remove()
222 |
223 | # Identify unique outputs, parameters, and buffers.
224 | tensors_seen = set()
225 | for e in entries:
226 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
227 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
228 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
229 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
230 |
231 | # Filter out redundant entries.
232 | if skip_redundant:
233 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
234 |
235 | # Construct table.
236 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
237 | rows += [['---'] * len(rows[0])]
238 | param_total = 0
239 | buffer_total = 0
240 | submodule_names = {mod: name for name, mod in module.named_modules()}
241 | for e in entries:
242 | name = '' if e.mod is module else submodule_names[e.mod]
243 | param_size = sum(t.numel() for t in e.unique_params)
244 | buffer_size = sum(t.numel() for t in e.unique_buffers)
245 | output_shapes = [str(list(t.shape)) for t in e.outputs]
246 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
247 | rows += [[
248 | name + (':0' if len(e.outputs) >= 2 else ''),
249 | str(param_size) if param_size else '-',
250 | str(buffer_size) if buffer_size else '-',
251 | (output_shapes + ['-'])[0],
252 | (output_dtypes + ['-'])[0],
253 | ]]
254 | for idx in range(1, len(e.outputs)):
255 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
256 | param_total += param_size
257 | buffer_total += buffer_size
258 | rows += [['---'] * len(rows[0])]
259 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
260 |
261 | # Print table.
262 | widths = [max(len(cell) for cell in column) for column in zip(*rows)]
263 | print()
264 | for row in rows:
265 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
266 | print()
267 | return outputs
268 |
269 | #----------------------------------------------------------------------------
270 |
--------------------------------------------------------------------------------
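Note: `InfiniteSampler` yields dataset indices forever, interleaving ranks and reshuffling within a sliding window, so the training loop never has to restart an epoch. A minimal single-process sketch of pairing it with a `DataLoader` (dataset and batch size are illustrative):

    import torch
    from torch_utils.misc import InfiniteSampler  # assuming edm/ is on sys.path

    dataset = torch.utils.data.TensorDataset(torch.arange(100).float())
    sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
    loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=8))
    for _ in range(3):
        (batch,) = next(loader)  # never raises StopIteration; loops indefinitely
        print(batch.shape)       # torch.Size([8])
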
/edm/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Facilities for pickling Python code alongside other data.
9 |
10 | The pickled code is automatically imported into a separate Python module
11 | during unpickling. This way, any previously exported pickles will remain
12 | usable even if the original code is no longer available, or if the current
13 | version of the code is not consistent with what was originally pickled."""
14 |
15 | import sys
16 | import pickle
17 | import io
18 | import inspect
19 | import copy
20 | import uuid
21 | import types
22 | import sys
23 | import os
24 | sys.path.append(os.path.abspath('edm'))
25 | import dnnlib
26 |
27 | #----------------------------------------------------------------------------
28 |
29 | _version = 6 # internal version number
30 | _decorators = set() # {decorator_class, ...}
31 | _import_hooks = [] # [hook_function, ...]
32 | _module_to_src_dict = dict() # {module: src, ...}
33 | _src_to_module_dict = dict() # {src: module, ...}
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | def persistent_class(orig_class):
38 | r"""Class decorator that extends a given class to save its source code
39 | when pickled.
40 |
41 | Example:
42 |
43 | from torch_utils import persistence
44 |
45 | @persistence.persistent_class
46 | class MyNetwork(torch.nn.Module):
47 | def __init__(self, num_inputs, num_outputs):
48 | super().__init__()
49 | self.fc = MyLayer(num_inputs, num_outputs)
50 | ...
51 |
52 | @persistence.persistent_class
53 | class MyLayer(torch.nn.Module):
54 | ...
55 |
56 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its
57 | source code alongside other internal state (e.g., parameters, buffers,
58 | and submodules). This way, any previously exported pickle will remain
59 | usable even if the class definitions have been modified or are no
60 | longer available.
61 |
62 | The decorator saves the source code of the entire Python module
63 | containing the decorated class. It does *not* save the source code of
64 | any imported modules. Thus, the imported modules must be available
65 | during unpickling, including `torch_utils.persistence` itself.
66 |
67 | It is ok to call functions defined in the same module from the
68 | decorated class. However, if the decorated class depends on other
69 | classes defined in the same module, they must be decorated as well.
70 | This is illustrated in the above example in the case of `MyLayer`.
71 |
72 | It is also possible to employ the decorator just-in-time before
73 | calling the constructor. For example:
74 |
75 | cls = MyLayer
76 | if want_to_make_it_persistent:
77 | cls = persistence.persistent_class(cls)
78 | layer = cls(num_inputs, num_outputs)
79 |
80 | As an additional feature, the decorator also keeps track of the
81 | arguments that were used to construct each instance of the decorated
82 | class. The arguments can be queried via `obj.init_args` and
83 | `obj.init_kwargs`, and they are automatically pickled alongside other
84 | object state. This feature can be disabled on a per-instance basis
85 | by setting `self._record_init_args = False` in the constructor.
86 |
87 | A typical use case is to first unpickle a previous instance of a
88 | persistent class, and then upgrade it to use the latest version of
89 | the source code:
90 |
91 | with open('old_pickle.pkl', 'rb') as f:
92 | old_net = pickle.load(f)
93 | new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
94 | misc.copy_params_and_buffers(old_net, new_net, require_all=True)
95 | """
96 | assert isinstance(orig_class, type)
97 | if is_persistent(orig_class):
98 | return orig_class
99 |
100 | assert orig_class.__module__ in sys.modules
101 | orig_module = sys.modules[orig_class.__module__]
102 | orig_module_src = _module_to_src(orig_module)
103 |
104 | class Decorator(orig_class):
105 | _orig_module_src = orig_module_src
106 | _orig_class_name = orig_class.__name__
107 |
108 | def __init__(self, *args, **kwargs):
109 | super().__init__(*args, **kwargs)
110 | record_init_args = getattr(self, '_record_init_args', True)
111 | self._init_args = copy.deepcopy(args) if record_init_args else None
112 | self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
113 | assert orig_class.__name__ in orig_module.__dict__
114 | _check_pickleable(self.__reduce__())
115 |
116 | @property
117 | def init_args(self):
118 | assert self._init_args is not None
119 | return copy.deepcopy(self._init_args)
120 |
121 | @property
122 | def init_kwargs(self):
123 | assert self._init_kwargs is not None
124 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
125 |
126 | def __reduce__(self):
127 | fields = list(super().__reduce__())
128 | fields += [None] * max(3 - len(fields), 0)
129 | if fields[0] is not _reconstruct_persistent_obj:
130 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
131 | fields[0] = _reconstruct_persistent_obj # reconstruct func
132 | fields[1] = (meta,) # reconstruct args
133 | fields[2] = None # state dict
134 | return tuple(fields)
135 |
136 | Decorator.__name__ = orig_class.__name__
137 | Decorator.__module__ = orig_class.__module__
138 | _decorators.add(Decorator)
139 | return Decorator
140 |
141 | #----------------------------------------------------------------------------
142 |
143 | def is_persistent(obj):
144 | r"""Test whether the given object or class is persistent, i.e.,
145 | whether it will save its source code when pickled.
146 | """
147 | try:
148 | if obj in _decorators:
149 | return True
150 | except TypeError:
151 | pass
152 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
153 |
154 | #----------------------------------------------------------------------------
155 |
156 | def import_hook(hook):
157 | r"""Register an import hook that is called whenever a persistent object
158 | is being unpickled. A typical use case is to patch the pickled source
159 | code to avoid errors and inconsistencies when the API of some imported
160 | module has changed.
161 |
162 | The hook should have the following signature:
163 |
164 | hook(meta) -> modified meta
165 |
166 | `meta` is an instance of `dnnlib.EasyDict` with the following fields:
167 |
168 | type: Type of the persistent object, e.g. `'class'`.
169 | version: Internal version number of `torch_utils.persistence`.
170 | module_src: Original source code of the Python module.
171 | class_name: Class name in the original Python module.
172 | state: Internal state of the object.
173 |
174 | Example:
175 |
176 | @persistence.import_hook
177 | def wreck_my_network(meta):
178 | if meta.class_name == 'MyNetwork':
179 | print('MyNetwork is being imported. I will wreck it!')
180 | meta.module_src = meta.module_src.replace("True", "False")
181 | return meta
182 | """
183 | assert callable(hook)
184 | _import_hooks.append(hook)
185 |
186 | #----------------------------------------------------------------------------
187 |
188 | def _reconstruct_persistent_obj(meta):
189 | r"""Hook that is called internally by the `pickle` module to unpickle
190 | a persistent object.
191 | """
192 | meta = dnnlib.EasyDict(meta)
193 | meta.state = dnnlib.EasyDict(meta.state)
194 | for hook in _import_hooks:
195 | meta = hook(meta)
196 | assert meta is not None
197 |
198 | assert meta.version == _version
199 | module = _src_to_module(meta.module_src)
200 |
201 | assert meta.type == 'class'
202 | orig_class = module.__dict__[meta.class_name]
203 | decorator_class = persistent_class(orig_class)
204 | obj = decorator_class.__new__(decorator_class)
205 |
206 | setstate = getattr(obj, '__setstate__', None)
207 | if callable(setstate):
208 | setstate(meta.state) # pylint: disable=not-callable
209 | else:
210 | obj.__dict__.update(meta.state)
211 | return obj
212 |
213 | #----------------------------------------------------------------------------
214 |
215 | def _module_to_src(module):
216 | r"""Query the source code of a given Python module.
217 | """
218 | src = _module_to_src_dict.get(module, None)
219 | if src is None:
220 | src = inspect.getsource(module)
221 | _module_to_src_dict[module] = src
222 | _src_to_module_dict[src] = module
223 | return src
224 |
225 | def _src_to_module(src):
226 | r"""Get or create a Python module for the given source code.
227 | """
228 | module = _src_to_module_dict.get(src, None)
229 | # if module is None:
230 | # module_name = "_imported_module_" + uuid.uuid4().hex
231 | # module = types.ModuleType(module_name)
232 | # sys.modules[module_name] = module
233 | # _module_to_src_dict[module] = src
234 | # _src_to_module_dict[src] = module
235 | # exec(src, module.__dict__) # pylint: disable=exec-used
236 | return module
237 |
238 | #----------------------------------------------------------------------------
239 |
240 | def _check_pickleable(obj):
241 | r"""Check that the given object is pickleable, raising an exception if
242 | it is not. This function is expected to be considerably more efficient
243 | than actually pickling the object.
244 | """
245 | def recurse(obj):
246 | if isinstance(obj, (list, tuple, set)):
247 | return [recurse(x) for x in obj]
248 | if isinstance(obj, dict):
249 | return [[recurse(x), recurse(y)] for x, y in obj.items()]
250 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
251 | return None # Python primitive types are pickleable.
252 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
253 | return None # NumPy arrays and PyTorch tensors are pickleable.
254 | if is_persistent(obj):
255 | return None # Persistent objects are pickleable, by virtue of the constructor check.
256 | return obj
257 | with io.BytesIO() as f:
258 | pickle.dump(recurse(obj), f)
259 |
260 | #----------------------------------------------------------------------------
261 |
--------------------------------------------------------------------------------
/edm/torch_utils/training_stats.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Facilities for reporting and collecting training statistics across
9 | multiple processes and devices. The interface is designed to minimize
10 | synchronization overhead as well as the amount of boilerplate in user
11 | code."""
12 |
13 | import re
14 | import numpy as np
15 | import torch
16 | try:
17 |     import edm.dnnlib as dnnlib
18 | except ImportError:
19 |     import dnnlib
20 |
21 | from . import misc
22 |
23 | #----------------------------------------------------------------------------
24 |
25 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
26 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
27 | _counter_dtype = torch.float64 # Data type to use for the internal counters.
28 | _rank = 0 # Rank of the current process.
29 | _sync_device = None # Device to use for multiprocess communication. None = single-process.
30 | _sync_called = False # Has _sync() been called yet?
31 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
32 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
33 |
34 | #----------------------------------------------------------------------------
35 |
36 | def init_multiprocessing(rank, sync_device):
37 | r"""Initializes `torch_utils.training_stats` for collecting statistics
38 | across multiple processes.
39 |
40 | This function must be called after
41 | `torch.distributed.init_process_group()` and before `Collector.update()`.
42 | The call is not necessary if multi-process collection is not needed.
43 |
44 | Args:
45 | rank: Rank of the current process.
46 | sync_device: PyTorch device to use for inter-process
47 | communication, or None to disable multi-process
48 | collection. Typically `torch.device('cuda', rank)`.
49 | """
50 | global _rank, _sync_device
51 | assert not _sync_called
52 | _rank = rank
53 | _sync_device = sync_device
54 |
55 | #----------------------------------------------------------------------------
56 |
57 | @misc.profiled_function
58 | def report(name, value):
59 | r"""Broadcasts the given set of scalars to all interested instances of
60 | `Collector`, across device and process boundaries.
61 |
62 | This function is expected to be extremely cheap and can be safely
63 | called from anywhere in the training loop, loss function, or inside a
64 | `torch.nn.Module`.
65 |
66 | Warning: The current implementation expects the set of unique names to
67 | be consistent across processes. Please make sure that `report()` is
68 | called at least once for each unique name by each process, and in the
69 | same order. If a given process has no scalars to broadcast, it can do
70 | `report(name, [])` (empty list).
71 |
72 | Args:
73 | name: Arbitrary string specifying the name of the statistic.
74 | Averages are accumulated separately for each unique name.
75 | value: Arbitrary set of scalars. Can be a list, tuple,
76 | NumPy array, PyTorch tensor, or Python scalar.
77 |
78 | Returns:
79 | The same `value` that was passed in.
80 | """
81 | if name not in _counters:
82 | _counters[name] = dict()
83 |
84 | elems = torch.as_tensor(value)
85 | if elems.numel() == 0:
86 | return value
87 |
88 | elems = elems.detach().flatten().to(_reduce_dtype)
89 | moments = torch.stack([
90 | torch.ones_like(elems).sum(),
91 | elems.sum(),
92 | elems.square().sum(),
93 | ])
94 | assert moments.ndim == 1 and moments.shape[0] == _num_moments
95 | moments = moments.to(_counter_dtype)
96 |
97 | device = moments.device
98 | if device not in _counters[name]:
99 | _counters[name][device] = torch.zeros_like(moments)
100 | _counters[name][device].add_(moments)
101 | return value
102 |
103 | #----------------------------------------------------------------------------
104 |
105 | def report0(name, value):
106 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
107 | but ignores any scalars provided by the other processes.
108 | See `report()` for further details.
109 | """
110 | report(name, value if _rank == 0 else [])
111 | return value
112 |
113 | #----------------------------------------------------------------------------
114 |
115 | class Collector:
116 | r"""Collects the scalars broadcasted by `report()` and `report0()` and
117 | computes their long-term averages (mean and standard deviation) over
118 | user-defined periods of time.
119 |
120 | The averages are first collected into internal counters that are not
121 | directly visible to the user. They are then copied to the user-visible
122 | state as a result of calling `update()` and can then be queried using
123 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
124 | internal counters for the next round, so that the user-visible state
125 | effectively reflects averages collected between the last two calls to
126 | `update()`.
127 |
128 | Args:
129 | regex: Regular expression defining which statistics to
130 | collect. The default is to collect everything.
131 | keep_previous: Whether to retain the previous averages if no
132 | scalars were collected on a given round
133 | (default: True).
134 | """
135 | def __init__(self, regex='.*', keep_previous=True):
136 | self._regex = re.compile(regex)
137 | self._keep_previous = keep_previous
138 | self._cumulative = dict()
139 | self._moments = dict()
140 | self.update()
141 | self._moments.clear()
142 |
143 | def names(self):
144 | r"""Returns the names of all statistics broadcasted so far that
145 | match the regular expression specified at construction time.
146 | """
147 | return [name for name in _counters if self._regex.fullmatch(name)]
148 |
149 | def update(self):
150 | r"""Copies current values of the internal counters to the
151 | user-visible state and resets them for the next round.
152 |
153 | If `keep_previous=True` was specified at construction time, the
154 | operation is skipped for statistics that have received no scalars
155 | since the last update, retaining their previous averages.
156 |
157 | This method performs a number of GPU-to-CPU transfers and one
158 | `torch.distributed.all_reduce()`. It is intended to be called
159 | periodically in the main training loop, typically once every
160 | N training steps.
161 | """
162 | if not self._keep_previous:
163 | self._moments.clear()
164 | for name, cumulative in _sync(self.names()):
165 | if name not in self._cumulative:
166 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
167 | delta = cumulative - self._cumulative[name]
168 | self._cumulative[name].copy_(cumulative)
169 | if float(delta[0]) != 0:
170 | self._moments[name] = delta
171 |
172 | def _get_delta(self, name):
173 | r"""Returns the raw moments that were accumulated for the given
174 | statistic between the last two calls to `update()`, or zero if
175 | no scalars were collected.
176 | """
177 | assert self._regex.fullmatch(name)
178 | if name not in self._moments:
179 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
180 | return self._moments[name]
181 |
182 | def num(self, name):
183 | r"""Returns the number of scalars that were accumulated for the given
184 | statistic between the last two calls to `update()`, or zero if
185 | no scalars were collected.
186 | """
187 | delta = self._get_delta(name)
188 | return int(delta[0])
189 |
190 | def mean(self, name):
191 | r"""Returns the mean of the scalars that were accumulated for the
192 | given statistic between the last two calls to `update()`, or NaN if
193 | no scalars were collected.
194 | """
195 | delta = self._get_delta(name)
196 | if int(delta[0]) == 0:
197 | return float('nan')
198 | return float(delta[1] / delta[0])
199 |
200 | def std(self, name):
201 | r"""Returns the standard deviation of the scalars that were
202 | accumulated for the given statistic between the last two calls to
203 | `update()`, or NaN if no scalars were collected.
204 | """
205 | delta = self._get_delta(name)
206 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
207 | return float('nan')
208 | if int(delta[0]) == 1:
209 | return float(0)
210 | mean = float(delta[1] / delta[0])
211 | raw_var = float(delta[2] / delta[0])
212 | return np.sqrt(max(raw_var - np.square(mean), 0))
213 |
214 | def as_dict(self):
215 | r"""Returns the averages accumulated between the last two calls to
216 | `update()` as a `dnnlib.EasyDict`. The contents are as follows:
217 |
218 | dnnlib.EasyDict(
219 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
220 | ...
221 | )
222 | """
223 | stats = dnnlib.EasyDict()
224 | for name in self.names():
225 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
226 | return stats
227 |
228 | def __getitem__(self, name):
229 | r"""Convenience getter.
230 | `collector[name]` is a synonym for `collector.mean(name)`.
231 | """
232 | return self.mean(name)
233 |
234 | #----------------------------------------------------------------------------
235 |
236 | def _sync(names):
237 | r"""Synchronize the global cumulative counters across devices and
238 | processes. Called internally by `Collector.update()`.
239 | """
240 | if len(names) == 0:
241 | return []
242 | global _sync_called
243 | _sync_called = True
244 |
245 | # Collect deltas within current rank.
246 | deltas = []
247 | device = _sync_device if _sync_device is not None else torch.device('cpu')
248 | for name in names:
249 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
250 | for counter in _counters[name].values():
251 | delta.add_(counter.to(device))
252 | counter.copy_(torch.zeros_like(counter))
253 | deltas.append(delta)
254 | deltas = torch.stack(deltas)
255 |
256 | # Sum deltas across ranks.
257 | if _sync_device is not None:
258 | torch.distributed.all_reduce(deltas)
259 |
260 | # Update cumulative values.
261 | deltas = deltas.cpu()
262 | for idx, name in enumerate(names):
263 | if name not in _cumulative:
264 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
265 | _cumulative[name].add_(deltas[idx])
266 |
267 | # Return name-value pairs.
268 | return [(name, _cumulative[name]) for name in names]
269 |
270 | #----------------------------------------------------------------------------
271 | # Convenience.
272 |
273 | default_collector = Collector()
274 |
275 | #----------------------------------------------------------------------------
276 |
--------------------------------------------------------------------------------
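Note: a minimal single-process sketch of the report/Collector flow (the statistic name is illustrative). `report()` only accumulates cheap per-device counters; `update()` synchronizes them and exposes the aggregates:

    from torch_utils import training_stats  # assuming edm/ is on sys.path

    collector = training_stats.Collector(regex='loss/.*')
    for step in range(10):
        training_stats.report('loss/train', 0.5 / (step + 1))  # callable from anywhere
    collector.update()                   # copy counters to the user-visible state
    print(collector.num('loss/train'))   # 10 scalars since the last update()
    print(collector.mean('loss/train'))  # their mean; collector['loss/train'] works too
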
/networks/get_network.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | from networks.network import *
6 | from networks.edm.ncsnpp import SongUNet,DhariwalUNet
7 | import torch
8 | # from .util import Ltvv, Ltxx,Ltxv,reshape_as
9 | def get_nn(opt,dyn):
10 | net={
11 | 'toy': ResNet,
12 | 'cifar10': SongUNet,
13 | 'AFHQv2': SongUNet,
14 | 'imagenet64':DhariwalUNet
15 | }.get(opt.exp)
16 | return network_wrapper(opt,net(opt),dyn)
17 |
18 | class network_wrapper(torch.nn.Module):
19 | # note: scale_by_g matters only for pre-trained model
20 | def __init__(self, opt, net,dyn):
21 | super(network_wrapper,self).__init__()
22 | self.opt = opt
23 | self.net = net
24 | self.dim = self.opt.data_dim
25 | self.varx= opt.varx
26 | self.varv= opt.varv
27 | self.dyn= dyn
28 | self.p = opt.p
29 |
30 | def get_precond(self,m,t):
31 | precond = (1-t)
32 | precond =precond.reshape(-1,*([1,]*(len(m.shape)-1)))
33 | return precond
34 |
35 |
36 | def forward(self, m, t,cond=None):
37 | t = t.squeeze()
38 | if t.dim()==0: t = t.repeat(m.shape[0])
39 | assert t.dim()==1 and t.shape[0] == m.shape[0]
40 | precond = 1
41 | if self.opt.precond: precond =self.get_precond(m,t)
42 | out = precond*self.net(m, 1-t, class_labels=cond)  # Flip the time: the image is generated at t=1, which affects the time conditioning (log t).
43 | return out
44 |
--------------------------------------------------------------------------------
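Note: the reshape in `get_precond` turns a per-sample scalar weight into a tensor that broadcasts over all remaining dimensions of `m`. A minimal sketch of that pattern on an image-shaped batch (shapes illustrative):

    import torch

    m = torch.randn(4, 3, 32, 32)  # batch of images
    t = torch.rand(4)              # one time value per sample
    precond = (1 - t).reshape(-1, *([1] * (m.dim() - 1)))  # shape [4, 1, 1, 1]
    out = precond * m              # scales each sample by its own (1 - t)
    print(precond.shape, out.shape)
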
/networks/network.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | import torch
6 | import torch.nn as nn
7 |
8 | def zero_module(module):
9 | """
10 | Zero out the parameters of a module and return it.
11 | """
12 | for p in module.parameters():
13 | p.detach().zero_()
14 | return module
15 |
16 | class ResNet(nn.Module):
17 | def __init__(self,
18 | opt,
19 | input_dim=2,
20 | index_dim=1,
21 | hidden_dim=128,
22 | n_hidden_layers=20):
23 |
24 | super().__init__()
25 |
26 | self.act = nn.SiLU()
27 | self.n_hidden_layers = n_hidden_layers
28 |
29 | self.x_input = True # input is concat [x,v] or just x
30 | if self.x_input:
31 | in_dim = input_dim * 2 + index_dim
32 | else:
33 | in_dim = input_dim + index_dim
34 | out_dim = input_dim
35 |
36 | layers = []
37 | layers.append(nn.Linear(in_dim, hidden_dim))
38 | for _ in range(n_hidden_layers):
39 | layers.append(nn.Linear(hidden_dim + index_dim, hidden_dim))
40 | layers.append(nn.Linear(hidden_dim + index_dim, out_dim))
41 |
42 | self.layers = nn.ModuleList(layers)
43 | self.layers[-1] = zero_module(self.layers[-1])
44 | def _append_time(self, h, t):
45 | time_embedding = torch.log(t)
46 | return torch.cat([h, time_embedding.reshape(-1, 1)], dim=1)
47 |
48 | def forward(self, u, t,class_labels=None):
49 | h0 = self.layers[0](self._append_time(u, t))
50 | h = self.act(h0)
51 |
52 | for i in range(self.n_hidden_layers):
53 | h_new = self.layers[i + 1](self._append_time(h, t))
54 | h = self.act(h + h_new)
55 |
56 | return self.layers[-1](self._append_time(h, t))
57 |
--------------------------------------------------------------------------------
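Note: a minimal forward-pass sketch of the toy `ResNet`. With `x_input=True` and the default `input_dim=2`, the network expects the concatenated position and velocity `[x, v]` of shape `[bs, 4]`, appends `log(t)` at every layer, and returns a `[bs, 2]` output; `opt` is not used inside this module, so it is passed as `None` here:

    import torch
    from networks.network import ResNet  # assuming the repository root is on sys.path

    net = ResNet(opt=None)     # input_dim=2 => forward expects [bs, 4]
    u = torch.randn(8, 4)      # concatenated [x, v], each 2-dimensional
    t = torch.full((8,), 0.5)  # must be strictly positive: _append_time uses log(t)
    out = net(u, t)
    print(out.shape)           # torch.Size([8, 2])
    print(out.abs().max())     # exactly 0: the last layer is zero-initialized
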
/plot_util.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | import matplotlib.pyplot as plt
6 | import os
7 | import numpy as np
8 | import torch
9 | def save_toy_traj(opt, fn, traj):
10 | fn_pdf = os.path.join(opt.ckpt_path, fn+'.pdf')
11 | n_snapshot=2
12 | lims = [[-4, 4],[-6, 6]]
13 |
14 | total_steps = traj.shape[1]
15 | sample_steps= np.linspace(0, total_steps-1, n_snapshot).astype(int)
16 | traj_steps = np.linspace(0, total_steps-1, 10).astype(int)
17 | if n_snapshot is None: # only store t=0
18 | plt.scatter(traj[:,0,0],traj[:,0,1], s=10)
19 | plt.xlim(*lims[0])
20 | plt.ylim(*lims[0])
21 | else:
22 | fig, axss = plt.subplots(1, 2)
23 | fig.set_size_inches(20, 10)
24 |         cmap = ['Blues', 'Reds']
25 |         colors = ['b', 'r']
26 |         num_samp_lines = 10
27 |         random_idx = np.random.choice(traj.shape[0], num_samp_lines, replace=False)
28 |         means = traj[random_idx, ...]
29 |         for i in range(2):
30 |             ax = axss[i]
31 |             ax.grid(True)
32 |             ax.xaxis.set_ticklabels([])
33 |             ax.yaxis.set_ticklabels([])
34 |             _colors = np.linspace(0.5, 1, len(sample_steps))
35 |             for idx, step in enumerate(sample_steps):
36 |                 ax.scatter(traj[:, step, 2*i], traj[:, step, 2*i+1], s=10, c=_colors[idx].repeat(traj.shape[0]), alpha=0.6, vmin=0, vmax=1, cmap=cmap[i])
37 |             ax.set_xlim(*lims[i])
38 |             ax.set_ylim(*lims[i])
39 |
40 |             for ii in range(num_samp_lines):
41 |                 ax.plot(means[ii, :, 2*i], means[ii, :, 2*i+1], color=colors[i], linewidth=4, alpha=0.5)
42 |             ax.set_title('position' if i == 0 else 'velocity', size=40)
43 |
44 | fig.suptitle('NFE = {}'.format(opt.nfe-1),size=40)
45 | fig.tight_layout()
46 | plt.savefig(fn_pdf)
47 | plt.clf()
48 |
49 | def save_snapshot_traj(opt, fn, pred, gt):
50 |     fn_pdf = os.path.join(opt.ckpt_path, fn + '_static.pdf')
51 |     lims = [-4, 4]
52 |     gt = gt.detach().cpu().numpy()
53 |     pred = pred.detach().cpu().numpy()
54 |     fig, axss = plt.subplots(1, 2)
55 |     fig.set_size_inches(20, 10)
56 |     colors = ['b', 'r']
57 |     ax = axss[0]
58 |     ax.scatter(pred[:, 0], pred[:, 1], color='steelblue', alpha=0.3, s=5)
59 |     ax.set_xlim(*lims)
60 |     ax.set_ylim(*lims)
61 |     ax = axss[1]
62 |     ax.scatter(gt[:, 0], gt[:, 1], color='coral', alpha=0.3, s=5)
63 | ax.set_xlim(*lims)
64 | ax.set_ylim(*lims)
65 | fig.suptitle('NFE = {}'.format(opt.nfe),size=40)
66 | fig.tight_layout()
67 | plt.savefig(fn_pdf)
68 |
69 | plt.clf()
70 |
71 |
72 | def norm_data(x):
73 |     # rescale each trailing 2-D slice of x (e.g. each image channel) to [0, 1]
74 |     _max = torch.max(torch.max(x, dim=-1)[0], dim=-1)[0][..., None, None]
75 |     _min = torch.min(torch.min(x, dim=-1)[0], dim=-1)[0][..., None, None]
76 |     x = (x - _min) / (_max - _min)
77 |     return x
78 |
79 |
80 | def plot_toy(opt,ms,it,pred_m1,x1):
81 | save_toy_traj(opt, 'itr_{}'.format(it), ms.detach().cpu().numpy())
82 | save_snapshot_traj(opt, 'itr_x{}'.format(it), pred_m1[:,0:2],x1)
83 |
84 |
85 | def plot_scatter(x, ts, ax):
86 |     '''
87 |     x: (bs, t, dim)
88 |     '''
89 |     bs, interval, dim = x.shape
90 |     for ii in range(interval):
91 |         ax.scatter(ts[ii].repeat(bs), x[:, ii, :], s=2, color='b', alpha=0.1)
92 |
93 | def plot_plt(x, ts, ax):
94 |     '''
95 |     x: (bs, t, dim)
96 |     '''
97 |     bs, interval, dim = x.shape
98 |
99 |     for ii in range(bs):
100 |         ax.plot(ts, x[ii, :, 0], color='r', alpha=0.1)
101 |
102 | def save_toy_npy_traj(opt, fn, traj, n_snapshot=None, direction='forward'):
103 | #form of traj: [bs, interval, x_dim=2]
104 | fn_pdf = os.path.join(opt.ckpt_path, fn+'.pdf')
105 |
106 | lims = [-5,5]
107 |
108 | if n_snapshot is None: # only store t=0
109 | plt.scatter(traj[:,0,0],traj[:,0,1], s=5)
110 | plt.xlim(*lims)
111 | plt.ylim(*lims)
112 | else:
113 | total_steps = traj.shape[1]
114 | sample_steps = np.linspace(0, total_steps-1, n_snapshot).astype(int)
115 | fig, axs = plt.subplots(1, n_snapshot)
116 | fig.set_size_inches(n_snapshot*6, 6)
117 | color = 'salmon' if direction=='forward' else 'royalblue'
118 | for ax, step in zip(axs, sample_steps):
119 | ax.scatter(traj[:,step,0],traj[:,step,1], s=1, color=color,alpha=0.2)
120 | ax.set_xlim(*lims)
121 | ax.set_ylim(*lims)
122 | ax.set_title('time = {:.2f}'.format(step/(total_steps-1)*opt.T))
123 | fig.tight_layout()
124 |
125 | plt.savefig(fn_pdf)
126 | plt.clf()
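127 |
128 | # Usage sketch for the helpers above (illustrative):
129 | # x: (bs, steps, dim) trajectory tensor; ts: (steps,) time grid
130 | # fig, ax = plt.subplots()
131 | # plot_scatter(x, ts, ax)   # marginal samples at each time step
132 | # plot_plt(x, ts, ax)       # individual sample paths along the first dim
133 | # fig.savefig('traj_debug.pdf')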
--------------------------------------------------------------------------------
/scripts/AFHQv2.sh:
--------------------------------------------------------------------------------
1 | # python sampling.py --n-gpu-per-node 8 --batch-size 1250 --ckpt AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/latest.pt --num-sample 50000 --sampling sscs --pred-x1 --nfe 20
2 | # torchrun --standalone --nproc_per_node=1 fid.py calc --images=results/AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/fid_sample_folder/ --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz
3 |
4 | python sampling.py --n-gpu-per-node 8 --batch-size 1250 --ckpt AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/latest.pt --num-sample 50000 --sampling sscs --pred-x1 --nfe 50
5 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=results/AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/fid_sample_folder/ --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz
6 |
7 | python sampling.py --n-gpu-per-node 8 --batch-size 1250 --ckpt AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/latest.pt --num-sample 50000 --sampling sscs --pred-x1 --nfe 100
8 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=results/AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/fid_sample_folder/ --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz
9 |
10 | python sampling.py --n-gpu-per-node 8 --batch-size 1250 --ckpt AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/latest.pt --num-sample 50000 --sampling sscs --pred-x1 --nfe 150
11 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=results/AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/fid_sample_folder/ --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz
12 |
13 | python sampling.py --n-gpu-per-node 8 --batch-size 1250 --ckpt AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/latest.pt --num-sample 50000 --sampling sscs --pred-x1 --nfe 500
14 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=results/AFHQv2-uniform-recip-precond-bs512-varv12-k0.8-p3-lr5e4-EDM-newlabel/fid_sample_folder/ --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/afhqv2-64x64.npz
15 |
--------------------------------------------------------------------------------
/scripts/cifar10.sh:
--------------------------------------------------------------------------------
1 | python train.py --name ode/varv1-0.99 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 1 --varv 1 --k 0.99 --sampling vanillaODE --reweight reciprocal --ode
2 |
--------------------------------------------------------------------------------
/scripts/example.sh:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | #Cifar10
6 | python sampling.py --n-gpu-per-node 1 --ckpt Cifar10-ODE/latest.pt --pred-x1 --solver gDDIM --T 0.9 --nfe 20 --fid-save-name cifar10-nfe20 --num-sample 64 --batch-size 64 --save-img --img-save-name cifar10-nfe20
7 | #AFHQv2
8 | python sampling.py --n-gpu-per-node 1 --ckpt AFHQv2-ODE/latest.pt --pred-x1 --solver gDDIM --T 0.9 --nfe 20 --fid-save-name AFHQv2-nfe20 --num-sample 64 --batch-size 64 --save-img --img-save-name AFHQv2-nfe20
9 |
10 | #AFHQv2 Conditional Generation
11 | python sampling.py --n-gpu-per-node 1 --ckpt AFHQv2-ODE/latest.pt --pred-x1 --solver gDDIM --save-img --img-save-name cond-AFHQv2 --nfe 100 --T 0.999 --num-sample 64 --batch-size 64 --stroke-path dataset/StrokeData/testFig0.png --stroke-type dyn-v # dyn-v can also be replaced by init-v
12 |
13 | #Imagenet64
14 | #NFE 20 FID=10.55
15 | python sampling.py --n-gpu-per-node 1 --ckpt uncond-ImageNet64-ODE/latest.pt --pred-x1 --solver gDDIM --save-img --img-save-name imagenet-nfe20 --fid-save-name ImageNet64-nfe20 --nfe 20 --T 0.99 --num-sample 64 --batch-size 64
16 |
17 |
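18 | #To score saved samples with FID afterwards (illustrative; follows the pattern used in scripts/release.sh):
19 | # torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/ImageNet64-nfe20 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/imagenet-64x64.npz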
--------------------------------------------------------------------------------
/scripts/imagenet64.sh:
--------------------------------------------------------------------------------
1 | python train.py --name imagenet-uniform-recip-precond-bs512-varv1-varx1-k0.2-p3-lr1e3-ADM-newlabel-probode --varv 1 --varx 1 --k 0.2 --p 3 --microbatch 64 --n-gpu-per-node 8 --lr 1e-3 --exp imagenet64 --t-samp uniform --precond --reweight-type reciprocal --probablistic-ode --sampling vanillaODE
2 |
--------------------------------------------------------------------------------
/scripts/release.sh:
--------------------------------------------------------------------------------
1 | #**********Toy******************
2 | #SSS Stochastic dynamics
3 | #train
4 | python train.py --name train-toy/sde-spiral --exp toy --toy-exp spiral
5 | #sampling
6 | python train.py --name train-toy/sde-spiral-eval --exp toy --toy-exp spiral --eval --ckpt train-toy/sde-spiral --nfe 10
7 | #Exponential Integrator ode dynamics
8 | # train
9 | python train.py --name train-toy/ode-spiral --exp toy --toy-exp spiral --DE-type probODE --solver gDDIM
10 | #Sampling
11 | python train.py --name train-toy/ode-spiral-eval --exp toy --toy-exp spiral --DE-type probODE --solver gDDIM --nfe 10 --eval --ckpt train-toy/ode-spiral/latest.pt
12 | #**************Cifar10*************
13 | #NFE=5 FID=11.88
14 | CUDA_VISIBLE_DEVICES=1 python sampling.py --n-gpu-per-node 1 --ckpt Remote_Cifar10_ODE --pred-x1 --solver gDDIM --T 0.4 --nfe 5 --fid-save-name cifar10-nfe5 --port 6024 --num-sample 50000 --batch-size 1000
15 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe5 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
16 |
17 | #NFE=10 FID=4.54
18 | CUDA_VISIBLE_DEVICES=1 python sampling.py --n-gpu-per-node 1 --ckpt Remote_Cifar10_ODE --pred-x1 --solver gDDIM --T 0.7 --nfe 10 --fid-save-name cifar10-nfe10 --port 6024 --num-sample 50000 --batch-size 1000
19 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe10 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
20 |
21 | #NFE=20 FID=2.58
22 | CUDA_VISIBLE_DEVICES=0 python sampling.py --n-gpu-per-node 1 --ckpt Remote_Cifar10_ODE --pred-x1 --solver gDDIM --T 0.9 --nfe 20 --fid-save-name cifar10-nfe20 --port 6024 --num-sample 50000 --batch-size 1000
23 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe20 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
24 |
25 | #**************AFHQv2*************
26 | #Conditional Generation
27 | python sampling.py --n-gpu-per-node 1 --ckpt AFHQ-cond/uuysqcynde/latest.pt --solver gDDIM --save-img --img-save-name cond-test0129 --nfe 100 --T 0.999 --num-sample 8 --batch-size 8 --stroke-path dataset/StrokeData/testFig0.png --stroke-type dyn-v
28 | python sampling.py --n-gpu-per-node 1 --ckpt AFHQ-cond/uuysqcynde/latest.pt --solver gDDIM --save-img --img-save-name cond-test0129 --nfe 100 --T 0.999 --num-sample 8 --batch-size 8 --stroke-path dataset/StrokeData/testFig1.png --stroke-type dyn-v
29 | #Inpainting Generation (the CLI flag and file names keep the repo's 'impainting' spelling)
30 | python sampling.py --n-gpu-per-node 1 --ckpt Remote_AFHQv2_ODE --pred-x1 --solver gDDIM --save-img --img-save-name impaint-AFHQv2 --nfe 100 --T 0.999 --num-sample 64 --batch-size 64 --stroke-type dyn-v --impainting --stroke-path dataset/StrokeData/testFig0_impainting.png
31 | python sampling.py --n-gpu-per-node 1 --ckpt Remote_AFHQv2_ODE --pred-x1 --solver gDDIM --save-img --img-save-name impaint-AFHQv2 --nfe 100 --T 0.999 --num-sample 64 --batch-size 64 --stroke-type dyn-v --impainting --stroke-path dataset/StrokeData/testFig1_impainting.png
32 |
33 | #NFE=5
34 | CUDA_VISIBLE_DEVICES=1 python sampling.py --n-gpu-per-node 1 --ckpt Remote_Cifar10_ODE --pred-x1 --solver gDDIM --T 0.4 --nfe 5 --fid-save-name cifar10-nfe5 --port 6024 --num-sample 50000 --batch-size 1000
35 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=FID_EVAL/cifar10-nfe5 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
36 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe5 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
37 |
38 | #NFE=10
39 | CUDA_VISIBLE_DEVICES=1 python sampling.py --n-gpu-per-node 1 --ckpt cifar10_ODE/latest.pt --pred-x1 --solver gDDIM --T 0.7 --nfe 10 --fid-save-name cifar10-nfe10 --port 6024 --num-sample 50000 --batch-size 1000
40 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe10 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
41 |
42 | #NFE=20
43 | CUDA_VISIBLE_DEVICES=0 python sampling.py --n-gpu-per-node 1 --ckpt cifar10_ODE/latest.pt --pred-x1 --solver gDDIM --T 0.9 --nfe 20 --fid-save-name cifar10-nfe20 --port 6024 --num-sample 50000 --batch-size 1000
44 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe20 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
45 |
46 | #**************ImageNet*************
47 | #NFE=5
48 | CUDA_VISIBLE_DEVICES=1 python sampling.py --n-gpu-per-node 1 --ckpt Remote_Cifar10_ODE --pred-x1 --solver gDDIM --T 0.4 --nfe 5 --fid-save-name cifar10-nfe5 --port 6024 --num-sample 50000 --batch-size 1000
49 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=FID_EVAL/cifar10-nfe5 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
50 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe5 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
51 |
52 | #NFE=10
53 | CUDA_VISIBLE_DEVICES=1 python sampling.py --n-gpu-per-node 1 --ckpt cifar10_ODE/latest.pt --pred-x1 --solver gDDIM --T 0.7 --nfe 10 --fid-save-name cifar10-nfe10 --port 6024 --num-sample 50000 --batch-size 1000
54 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe10 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
55 |
56 | #NFE=20
57 | CUDA_VISIBLE_DEVICES=0 python sampling.py --n-gpu-per-node 1 --ckpt cifar10_ODE/latest.pt --pred-x1 --solver gDDIM --T 0.9 --nfe 20 --fid-save-name cifar10-nfe20 --port 6024 --num-sample 50000 --batch-size 1000
58 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 edm/fid.py calc --images=FID_EVAL/cifar10-nfe20 --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
59 |
60 |
61 | # https://drive.google.com/file/d/1H92Bgz26hLajYNtcY7zPtI7y9HFLeYtD/view?usp=sharing
62 | # https://drive.google.com/file/d/1u26_iWWaBSW8hXMnAudJB90Awolabta4/view?usp=sharing
--------------------------------------------------------------------------------
/scripts/toy.sh:
--------------------------------------------------------------------------------
1 | python train.py --name prob-ode/varv10-p1 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 1 --varv 3 --p 1 --k 0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode
2 |
3 | # python train.py --name prob-ode/varv7 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 1 --varv 7 --k 0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode
4 |
5 | # python train.py --name prob-ode/varv5 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 1 --varv 5 --k 0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode
6 |
7 | # python train.py --name prob-ode/varv3 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 1 --varv 3 --k 0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode
8 |
9 | # python train.py --name prob-ode/varv1 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 1 --varv 1 --k 0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode
10 |
11 |
12 | # python train.py --name prob-ode/varx2-varv10 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 2
13 | # --varv 10 --k 0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode
14 |
15 | # python train.py --name prob-ode/varx4-varv10 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 4 --varv 10 --k 0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode
16 |
17 | # python train.py --name prob-ode/varx4-varv6 --exp toy --toy-exp spiral --reweight-type reciprocal --t-samp uniform --varx 4
18 | # --varv 6 --k 0.8 --sampling vanillaODE --reweight reciprocal --probablistic-ode
19 |
20 | # #Training
21 | # python train.py --name train-toy/ode-spiral --exp toy --toy-exp spiral --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --ode --sampling vanillaODE --reweight reciprocal
22 |
23 | # python train.py --name train-toy/ode-gmm --exp toy --toy-exp gmm --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --ode --sampling vanillaODE --reweight reciprocal
24 |
25 | # #Sampling 20 NFE
26 | # python train.py --name eval-toy/ode-spiral --exp toy --toy-exp spiral --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --ode --sampling vanillaODE --reweight reciprocal --ckpt train-toy/ode-spiral --eval --interval 20
27 |
28 | # python train.py --name eval-toy/ode-gmm --exp toy --toy-exp gmm --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --ode --sampling vanillaODE --reweight reciprocal --ckpt train-toy/ode-gmm --eval --interval 20
29 |
30 | # #Training
31 | # python train.py --name train-toy/sde-spiral --exp toy --toy-exp spiral --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --sampling sscs --reweight reciprocal
32 |
33 | # python train.py --name train-toy/sde-gmm --exp toy --toy-exp gmm --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --sampling sscs --reweight reciprocal
34 |
35 | # #Sampling 20 NFE
36 | # python train.py --name eval-toy/sde-spiral --exp toy --toy-exp spiral --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --sampling sscs --reweight reciprocal --ckpt train-toy/sde-spiral --eval --interval 20
37 |
38 | # python train.py --name eval-toy/sde-gmm --exp toy --toy-exp gmm --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --sampling sscs --reweight reciprocal --ckpt train-toy/sde-gmm --eval --interval 20
39 |
40 |
41 |
42 |
43 | #Source Command
44 | # python train.py --name test-ode --exp toy --reweight-type ones --t-samp uniform --varx 1 --varv 10 --k 0.8 --ode --sampling vanillaODE --reweight reciprocal
45 |
--------------------------------------------------------------------------------
/setup/conda_install.sh:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | pip install --upgrade pip
6 | pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu118
7 | pip install torch_ema
8 | pip install pytorch_warmup
9 |
10 | python setup/download_datasets.py
11 | pip install -U pytorch_warmup
12 | pip install --upgrade wandb
13 | # AFHQ download
14 | URL=https://www.dropbox.com/s/vkzjokiwof5h8w6/afhq_v2.zip?dl=0
15 | ZIP_FILE=./dataset/afhq_v2.zip
16 | mkdir -p ./dataset/afhqv2
17 | # wget -N $URL -O $ZIP_FILE
18 | wget --wait 10 --random-wait --continue -N $URL -O $ZIP_FILE
19 | unzip $ZIP_FILE -d ./dataset/afhqv2
20 | rm $ZIP_FILE
21 | # Convert AFHQ to a 64x64 dataset zip
22 | python edm/dataset_tool.py --source=dataset/afhqv2 --dest=dataset/afhqv2-64x64.zip --resolution=64x64
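23 | # Optional sanity check (illustrative): confirm the converted archive exists
24 | # ls -lh dataset/afhqv2-64x64.zip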
--------------------------------------------------------------------------------
/setup/download_datasets.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | import torchvision.datasets as datasets
6 | datasets.CIFAR10(
7 | './dataset',
8 | train= True,
9 | download=True,
10 | )
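11 |
12 | # The CIFAR-10 test split can be fetched the same way if needed for evaluation:
13 | # datasets.CIFAR10('./dataset', train=False, download=True)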
--------------------------------------------------------------------------------
/setup/environments.yml:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | name: agm
6 | channels:
7 | # - pytorch
8 | - nvidia
9 | dependencies:
10 |   - python>=3.8,<3.10  # package build failures on 3.10
11 | - pip
12 | - numpy>=1.20
13 | - click>=8.0
14 | - pillow>=8.3.1
15 | - scipy>=1.7.1
16 | - psutil
17 | - requests
18 | - tqdm
19 | - imageio
20 | - matplotlib
21 | - scikit-learn
22 | - pip:
23 | - imageio-ffmpeg>=0.4.3
24 | - pyspng
25 | - easydict
26 | - ipdb
27 | # - torch_ema
28 | - prefetch_generator
29 | - wandb
30 | - termcolor
31 | - rich
32 | # - pytorch_warmup
33 | - colored_traceback
34 | - ml_collections
35 | - gdown
36 |
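37 | # To create this environment (illustrative): conda env create -f setup/environments.yml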
--------------------------------------------------------------------------------
/setup/requirement.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scipy
3 | termcolor
4 | easydict
5 | ipdb
6 | tqdm
7 | scikit-learn
8 | imageio
9 | matplotlib
10 | tensorboard
11 | torchmetrics==0.9.3
12 | prefetch_generator
13 | colored-traceback
14 | torch-ema
15 | gdown==4.6.0
16 | clean-fid==0.1.35
17 | rich
18 | lmdb
19 | wandb
20 | Ninja
21 | ml_collections
--------------------------------------------------------------------------------
/setup/setup.sh:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | pip install --upgrade pip
6 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
7 | pip install -r setup/requirement.txt
8 | python setup/download_datasets.py
9 | pip install -U pytorch_warmup
10 | pip install --upgrade wandb
11 | # # AFHQ download
12 | URL=https://www.dropbox.com/s/vkzjokiwof5h8w6/afhq_v2.zip?dl=0
13 | ZIP_FILE=./dataset/afhq_v2.zip
14 | mkdir -p ./dataset/afhqv2
15 | wget --wait 10 --random-wait --continue -N $URL -O $ZIP_FILE
16 | unzip $ZIP_FILE -d ./dataset/afhqv2
17 | rm $ZIP_FILE
18 | # AFHQ download
19 | python edm/dataset_tool.py --source=dataset/afhqv2 --dest=dataset/afhqv2-64x64.zip --resolution=64x64
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | from __future__ import absolute_import, division, print_function, unicode_literals
6 |
7 | import os
8 | import sys
9 | import random
10 | import argparse
11 |
12 | import copy
13 | from pathlib import Path
14 | import numpy as np
15 | import torch
16 | from torch.multiprocessing import Process
17 |
18 | from edm.logger import Logger
19 | from edm.distributed_util import init_processes
20 | from dataset import spiral,cifar10, AFHQv2,imagenet64
21 | from AM import runner
22 | from configs import cifar10_config,toy_config,afhqv2_config,imagenet64_config
23 | import colored_traceback.always
24 | import torch.multiprocessing
25 |
26 | RESULT_DIR = Path("results")
27 | def set_seed(seed):
28 | # https://github.com/pytorch/pytorch/issues/7068
29 | random.seed(seed)
30 | os.environ['PYTHONHASHSEED'] = str(seed)
31 | np.random.seed(seed)
32 | torch.manual_seed(seed)
33 | torch.cuda.manual_seed(seed)
34 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
35 | torch.backends.cudnn.enabled = True
36 | torch.backends.cudnn.benchmark = True
37 | torch.backends.cuda.matmul.allow_tf32 = True
38 | # https://stackoverflow.com/questions/73125231/pytorch-dataloaders-bad-file-descriptor-and-eof-for-workers0
39 | torch.multiprocessing.set_sharing_strategy('file_system')
40 |
41 | def create_training_options():
42 | # --------------- basic ---------------
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument("--seed", type=int, default=42)
45 | parser.add_argument("--name", type=str, default=None, help="experiment ID")
46 | parser.add_argument("--exp", type=str, default='toy', choices=['toy','cifar10','AFHQv2','imagenet64','cond-imagenet64'], help="experiment type")
47 | parser.add_argument("--toy-exp", type=str, default='gmm', choices=['gmm','spiral'], help="experiment type")
48 | parser.add_argument("--ckpt", type=str, default=None, help="resumed checkpoint name")
49 | parser.add_argument("--cond", action="store_true", help="whether or not use class cond")
50 | parser.add_argument("--gpu", type=int, default=None, help="set only if you wish to run on a particular device")
51 | parser.add_argument("--n-gpu-per-node", type=int, default=1, help="number of gpu on each node")
52 | parser.add_argument("--master-address", type=str, default='localhost', help="address for master")
53 | parser.add_argument("--node-rank", type=int, default=0, help="the index of node")
54 | parser.add_argument("--num-proc-node", type=int, default=1, help="The number of nodes in multi node env")
55 | parser.add_argument("--port", type=str, default='6022', help="localhost port")
56 |
57 | # --------------- Dynamics Hyperparameters ---------------
58 | parser.add_argument("--n-train", type=int, default=5000)
59 | parser.add_argument("--t0", type=float, default=1e-4, help="Number of Training sample for toy dataset")
60 | parser.add_argument("--T", type=float, default=0.999, help="Terminal Time for the dynamics")
61 | parser.add_argument("--nfe", type=int, default=1000, help="number of interval")
62 | parser.add_argument("--varx", type=float, default=1.0, help="variance of position for prior")
63 | parser.add_argument("--varv", type=float, default=1.0, help="variance of velocity for prior")
64 | parser.add_argument("--k", type=float, default=0.0, help="Correlation/Covariance of position of velocity for prior distribution")
65 | parser.add_argument("--p", type=float, default=3, help="diffusion coefficient for Time Variant value g(t)=p*(damp_t-t)")
66 | parser.add_argument("--damp-t", type=float, default=1, help="diffusion coefficient for Time Variant value g(t)=p*(damp_t-t)")
67 | parser.add_argument("--DE-type", type=str, default='probODE', choices=['probODE','SDE'],\
68 | help="Choose the type of SDE, which includes Time Varing g Model Predictive Control (TVgMPC) ,\
69 | Time Invariant g Model Predictive Control (TIVgMPC),Time Invariant g Flow Matching (TIVgFM)")
70 | # --------------- optimizer and loss ---------------
71 | parser.add_argument("--microbatch", type=int, default=512, help="mini batch size for gradient descent")
72 | parser.add_argument("--num-itr", type=int, default=50000, help="number of training iteration")
73 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
74 | parser.add_argument("--ema", type=float, default=0.9999, help='ema decay rate')
75 | parser.add_argument("--l2-norm", type=float, default=0, help='L2 norm for optimizer')
76 | parser.add_argument("--t-samp", type=str, default='uniform', choices=['uniform','debug'],\
77 | help="the way to sample t during sampling")
78 | parser.add_argument("--precond", action="store_true", help="preconditioning for the network output")
79 | parser.add_argument("--clip-grad", type=float, default=None, help="whether to clip the gradient.")
80 | parser.add_argument("--xflip", action="store_true", help="Whether flip the dataset in x-horizon")
81 | parser.add_argument("--reweight-type", type=str, default='ones', choices=['ones','reciprocal','reciprocalinv'], help="How to reweight the training")
82 | # --------------- sampling and evaluating ---------------
83 | parser.add_argument("--train-fid-sample",type=int, default=None, help="number of samples used for evaluating FID")
84 | parser.add_argument("--sampling-batch", type=int, default=512, help="mini batch size for gradient descent")
85 | parser.add_argument("--eval", action="store_true", help="evaluating mode. Wont save ckpt")
86 | parser.add_argument("--clip-x1", action="store_true", help="similar to DDPM, clip the estimiated data to [-1,1]")
87 | parser.add_argument("--debug", action="store_true", help="Using single GPU to evaluate. raise this flag for fast testing")
88 | parser.add_argument("--pred-x1", action="store_true", help="Using predict x1 as the sampling output")
89 | parser.add_argument("--gDDIM-r", type=int, default=2, help="the checkpoint name from which we wish to sample")
90 | parser.add_argument("--diz-order", type=int, default=2, help="the checkpoint name from which we wish to sample")
91 | parser.add_argument("--solver", type=str, default='em', help="sampler")
92 | parser.add_argument("--diz", type=str, default='Euler', choices=['Euler','sigmoid'], help="The discretization scheme")
93 | parser.add_argument("--sanity", action="store_true", help="quick sanity check for the proposed dyanmics")
94 | # --------------- path and logging ---------------
95 | parser.add_argument("--log-dir", type=Path, default=".log", help="path to log std outputs and writer data")
96 | parser.add_argument("--log-writer", type=str, default=None, help="log writer: can be tensorbard, wandb, or None")
97 | parser.add_argument("--wandb-api-key", type=str, default=None, help="unique API key of your W&B account; see https://wandb.ai/authorize")
98 | parser.add_argument("--wandb-user", type=str, default=None, help="user name of your W&B account")
99 |
100 |
101 | default_config, model_configs = {
102 | 'toy': toy_config.get_toy_default_configs,
103 | 'cifar10': cifar10_config.get_cifar10_default_configs,
104 | 'AFHQv2': afhqv2_config.get_afhqv2_default_configs,
105 | 'imagenet64': imagenet64_config.get_imagenet64_default_configs,
106 | }.get(parser.parse_args().exp)()
107 | parser.set_defaults(**default_config)
108 |
109 | opt = parser.parse_args()
110 | opt.model_config=model_configs
111 |
112 | # ========= auto setup =========
113 |     opt.device = 'cuda' if opt.gpu is None else f'cuda:{opt.gpu}'
114 | opt.distributed = opt.n_gpu_per_node > 1
115 |
116 | if opt.solver=='sscs':
117 | assert opt.DE_type=='SDE'
118 |
119 | # log ngc meta data
120 | if "NGC_JOB_ID" in os.environ.keys():
121 | opt.ngc_job_id = os.environ["NGC_JOB_ID"]
122 |
123 | # ========= path handle =========
124 | os.makedirs(opt.log_dir, exist_ok=True)
125 | opt.ckpt_path = RESULT_DIR / opt.name
126 | os.makedirs(opt.ckpt_path, exist_ok=True)
127 |
128 | if opt.train_fid_sample is None:
129 | opt.train_fid_sample = opt.n_gpu_per_node*opt.microbatch
130 |
131 | if opt.ckpt is not None:
132 | ckpt_file = RESULT_DIR / opt.ckpt / "latest.pt"
133 | assert ckpt_file.exists()
134 | opt.load = ckpt_file
135 | else:
136 | opt.load = None
137 |
138 |
139 | return opt
140 |
141 | def main(opt):
142 | log = Logger(opt.global_rank, opt.log_dir)
143 | log.info("=======================================================")
144 | log.info(" Accelerate Model")
145 | log.info("=======================================================")
146 | log.info("Command used:\n{}".format(" ".join(sys.argv)))
147 | log.info(f"Experiment ID: {opt.name}")
148 |
149 |     # set seed: make sure each gpu has a different seed!
150 | if opt.seed is not None:
151 | set_seed(opt.seed + opt.global_rank)
152 |
153 | # build dataset
154 | if opt.exp=='toy':
155 | train_loader = spiral.spiral_data(opt)
156 | elif opt.exp=='cifar10':
157 | train_loader = cifar10.cifar10_data(opt)
158 | if opt.cond: opt.cond_dim = 10
159 | elif opt.exp=='AFHQv2':
160 | train_loader = AFHQv2.AFHQv2_data(opt)
161 | elif opt.exp=='imagenet64':
162 | train_loader = imagenet64.imagenet64_data(opt)
163 | if opt.cond: opt.cond_dim = 1000
164 |
165 | run = runner.Runner(opt, log)
166 | if opt.eval:
167 | run.ema.copy_to()
168 | run.evaluation(opt,0,train_loader)
169 | else:
170 | run.train(opt, train_loader)
171 | log.info("Finish!")
172 |
173 | if __name__ == '__main__':
174 | opt = create_training_options()
175 | if opt.debug:
176 | opt.distributed =False
177 | opt.global_rank = 0
178 | opt.local_rank = 0
179 | opt.global_size = 1
180 | with torch.cuda.device(opt.gpu):
181 | main(opt)
182 | else:
183 | torch.multiprocessing.set_start_method('spawn')
184 | if opt.distributed:
185 | size = opt.n_gpu_per_node
186 | processes = []
187 |
188 | for rank in range(size):
189 | opt = copy.deepcopy(opt)
190 | opt.local_rank = rank
191 | global_rank = rank + opt.node_rank * opt.n_gpu_per_node
192 | global_size = opt.num_proc_node * opt.n_gpu_per_node
193 | opt.global_rank = global_rank
194 | opt.global_size = global_size
195 | print('Node rank %d, local proc %d, global proc %d, global_size %d' % (opt.node_rank, rank, global_rank, global_size))
196 |
197 | p = Process(target=init_processes, args=(global_rank, global_size, main, opt))
198 | p.start()
199 | processes.append(p)
200 |
201 | for p in processes:
202 | p.join()
203 | else:
204 | torch.cuda.set_device(0)
205 | opt.global_rank = 0
206 | opt.local_rank = 0
207 | opt.global_size = 1
208 | init_processes(0, opt.n_gpu_per_node, main, opt)
209 |
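210 | # Minimal single-GPU smoke test (illustrative; cf. scripts/release.sh):
211 | # python train.py --name train-toy/sde-spiral --exp toy --toy-exp spiral --debug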
--------------------------------------------------------------------------------