├── EDM2_sample.py ├── LICENSE ├── README.md ├── VAR_sample.py └── assets └── grid_ddo.jpg /EDM2_sample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Generate random images using the given model.""" 9 | 10 | import math 11 | import os 12 | import re 13 | import warnings 14 | import click 15 | import tqdm 16 | import pickle 17 | import numpy as np 18 | import torch 19 | import PIL.Image 20 | import dnnlib 21 | from torch_utils import distributed as dist 22 | 23 | warnings.filterwarnings('ignore', '`resume_download` is deprecated') 24 | warnings.filterwarnings('ignore', 'You are using `torch.load` with `weights_only=False`') 25 | warnings.filterwarnings('ignore', '1Torch was not compiled with flash attention') 26 | 27 | 28 | #---------------------------------------------------------------------------- 29 | # DPM-Solver-v3 from the paper 30 | # "DPM-Solver-v3: Improved Diffusion ODE Solver with Empirical Model Statistics", 31 | # using a training-free version. 32 | 33 | class NoiseScheduleEDM: 34 | def marginal_log_mean_coeff(self, t): 35 | """ 36 | Compute log(alpha_t) of a given continuous-time label t in [0, T]. 37 | """ 38 | return torch.zeros_like(t).to(torch.float64) 39 | 40 | def marginal_alpha(self, t): 41 | """ 42 | Compute alpha_t of a given continuous-time label t in [0, T]. 43 | """ 44 | return torch.ones_like(t).to(torch.float64) 45 | 46 | def marginal_std(self, t): 47 | """ 48 | Compute sigma_t of a given continuous-time label t in [0, T]. 49 | """ 50 | return t.to(torch.float64) 51 | 52 | def marginal_lambda(self, t): 53 | """ 54 | Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. 55 | """ 56 | return -torch.log(t).to(torch.float64) 57 | 58 | def inverse_lambda(self, lamb): 59 | """ 60 | Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. 61 | """ 62 | return torch.exp(-lamb).to(torch.float64) 63 | 64 | 65 | def weighted_cumsumexp_trapezoid(a, x, b, cumsum=True): 66 | # ∫ b*e^a dx 67 | # Input: a,x,b: shape (N+1,...) 68 | # Output: y: shape (N+1,...) 
69 | # y_0 = 0 70 | # y_n = sum_{i=1}^{n} 0.5*(x_{i}-x_{i-1})*(b_{i}*e^{a_{i}}+b_{i-1}*e^{a_{i-1}}) (n from 1 to N) 71 | 72 | assert x.shape[0] == a.shape[0] and x.ndim == a.ndim 73 | if b is not None: 74 | assert a.shape[0] == b.shape[0] and a.ndim == b.ndim 75 | 76 | a_max = np.amax(a, axis=0, keepdims=True) 77 | 78 | if b is not None: 79 | b = np.asarray(b) 80 | tmp = b * np.exp(a - a_max) 81 | else: 82 | tmp = np.exp(a - a_max) 83 | 84 | out = 0.5 * (x[1:] - x[:-1]) * (tmp[1:] + tmp[:-1]) 85 | if not cumsum: 86 | return np.sum(out, axis=0) * np.exp(a_max) 87 | out = np.cumsum(out, axis=0) 88 | out *= np.exp(a_max) 89 | return np.concatenate([np.zeros_like(out[[0]]), out], axis=0) 90 | 91 | 92 | def weighted_cumsumexp_trapezoid_torch(a, x, b, cumsum=True): 93 | 94 | assert x.shape[0] == a.shape[0] and x.ndim == a.ndim 95 | if b is not None: 96 | assert a.shape[0] == b.shape[0] and a.ndim == b.ndim 97 | 98 | a_max = torch.amax(a, dim=0, keepdims=True) 99 | 100 | if b is not None: 101 | tmp = b * torch.exp(a - a_max) 102 | else: 103 | tmp = torch.exp(a - a_max) 104 | 105 | out = 0.5 * (x[1:] - x[:-1]) * (tmp[1:] + tmp[:-1]) 106 | if not cumsum: 107 | return torch.sum(out, dim=0) * torch.exp(a_max) 108 | out = torch.cumsum(out, dim=0) 109 | out *= torch.exp(a_max) 110 | return torch.concat([torch.zeros_like(out[[0]]), out], dim=0) 111 | 112 | 113 | def index_list(lst, index): 114 | new_lst = [] 115 | for i in index: 116 | new_lst.append(lst[i]) 117 | return new_lst 118 | 119 | 120 | class DPM_Solver_v3: 121 | def __init__(self, statistics_steps, noise_schedule, steps=10, t_start=None, t_end=None, skip_type="logSNR", device="cuda"): 122 | # precompute 123 | self.device = device 124 | self.noise_schedule = noise_schedule 125 | self.steps = steps 126 | t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end 127 | t_T = self.noise_schedule.T if t_start is None else t_start 128 | assert ( 129 | t_0 > 0 and t_T > 0 130 | ), "Time range needs to be greater than 0. 
For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" 131 | 132 | self.statistics_steps = statistics_steps 133 | ts = noise_schedule.marginal_lambda(self.get_time_steps("logSNR", t_T, t_0, self.statistics_steps, "cpu")).numpy()[:, None, None, None] 134 | self.ts = torch.from_numpy(ts).cuda() 135 | self.lambda_T = self.ts[0].cpu().item() 136 | self.lambda_0 = self.ts[-1].cpu().item() 137 | shape = (statistics_steps + 1, 1, 1, 1) 138 | l = np.ones(shape) 139 | s = np.zeros(shape) 140 | b = np.zeros(shape) 141 | z = np.zeros_like(l) 142 | o = np.ones_like(l) 143 | L = weighted_cumsumexp_trapezoid(z, ts, l) 144 | S = weighted_cumsumexp_trapezoid(z, ts, s) 145 | 146 | I = weighted_cumsumexp_trapezoid(L + S, ts, o) 147 | B = weighted_cumsumexp_trapezoid(-S, ts, b) 148 | C = weighted_cumsumexp_trapezoid(L + S, ts, B) 149 | self.l = torch.from_numpy(l).cuda() 150 | self.s = torch.from_numpy(s).cuda() 151 | self.b = torch.from_numpy(b).cuda() 152 | self.L = torch.from_numpy(L).cuda() 153 | self.S = torch.from_numpy(S).cuda() 154 | self.I = torch.from_numpy(I).cuda() 155 | self.B = torch.from_numpy(B).cuda() 156 | self.C = torch.from_numpy(C).cuda() 157 | 158 | # precompute timesteps 159 | if skip_type == "logSNR" or skip_type == "time_uniform" or skip_type == "time_quadratic": 160 | self.timesteps = self.get_time_steps(skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) 161 | self.indexes = self.convert_to_indexes(self.timesteps) 162 | self.timesteps = self.convert_to_timesteps(self.indexes, device) 163 | elif skip_type == "edm": 164 | self.indexes, self.timesteps = self.get_timesteps_edm(N=steps, device=device) 165 | self.timesteps = self.convert_to_timesteps(self.indexes, device) 166 | else: 167 | raise ValueError(f"Unsupported timestep strategy {skip_type}") 168 | 169 | # store high-order exponential coefficients (lazy) 170 | self.exp_coeffs = {} 171 | 172 | def noise_prediction_fn(self, x, t): 173 | """ 174 | Return the noise prediction model. 175 | """ 176 | return self.model(x, t) 177 | 178 | def convert_to_indexes(self, timesteps): 179 | logSNR_steps = self.noise_schedule.marginal_lambda(timesteps) 180 | indexes = list( 181 | (self.statistics_steps * (logSNR_steps - self.lambda_T) / (self.lambda_0 - self.lambda_T)).round().cpu().numpy().astype(np.int64) 182 | ) 183 | return indexes 184 | 185 | def convert_to_timesteps(self, indexes, device): 186 | logSNR_steps = self.lambda_T + (self.lambda_0 - self.lambda_T) * torch.Tensor(indexes).to(device) / self.statistics_steps 187 | return self.noise_schedule.inverse_lambda(logSNR_steps) 188 | 189 | def get_time_steps(self, skip_type, t_T, t_0, N, device): 190 | """Compute the intermediate time steps for sampling. 191 | 192 | Args: 193 | skip_type: A `str`. The type for the spacing of the time steps. We support three types: 194 | - 'logSNR': uniform logSNR for the time steps. 195 | - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) 196 | - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) 197 | t_T: A `float`. The starting time of the sampling (default is T). 198 | t_0: A `float`. The ending time of the sampling (default is epsilon). 199 | N: A `int`. The total number of the spacing of the time steps. 200 | device: A torch device. 201 | Returns: 202 | A pytorch tensor of the time steps, with the shape (N + 1,). 
203 | """ 204 | if skip_type == "logSNR": 205 | lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) 206 | lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) 207 | logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) 208 | return self.noise_schedule.inverse_lambda(logSNR_steps) 209 | elif skip_type == "time_uniform": 210 | return torch.linspace(t_T, t_0, N + 1).to(device) 211 | elif skip_type == "time_quadratic": 212 | t_order = 2 213 | t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) 214 | return t 215 | else: 216 | raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) 217 | 218 | def get_timesteps_edm(self, N, device): 219 | """Constructs the noise schedule of Karras et al. (2022).""" 220 | 221 | rho = 7.0 222 | 223 | sigma_min: float = np.exp(-self.lambda_0) 224 | sigma_max: float = np.exp(-self.lambda_T) 225 | ramp = np.linspace(0, 1, N + 1) 226 | min_inv_rho = sigma_min ** (1 / rho) 227 | max_inv_rho = sigma_max ** (1 / rho) 228 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 229 | lambdas = torch.Tensor(-np.log(sigmas)).to(device) 230 | timesteps = self.noise_schedule.inverse_lambda(lambdas) 231 | 232 | indexes = list((self.statistics_steps * (lambdas - self.lambda_T) / (self.lambda_0 - self.lambda_T)).round().cpu().numpy().astype(np.int64)) 233 | return indexes, timesteps 234 | 235 | def get_g(self, f_t, i_s, i_t): 236 | return torch.exp(self.S[i_s] - self.S[i_t]) * f_t - torch.exp(self.S[i_s]) * (self.B[i_t] - self.B[i_s]) 237 | 238 | def compute_exponential_coefficients_high_order(self, i_s, i_t, order=2): 239 | key = (i_s, i_t, order) 240 | if key in self.exp_coeffs.keys(): 241 | coeffs = self.exp_coeffs[key] 242 | else: 243 | n = order - 1 244 | a = self.L[i_s : i_t + 1] + self.S[i_s : i_t + 1] - self.L[i_s] - self.S[i_s] 245 | x = self.ts[i_s : i_t + 1] 246 | b = (self.ts[i_s : i_t + 1] - self.ts[i_s]) ** n / math.factorial(n) 247 | coeffs = weighted_cumsumexp_trapezoid_torch(a, x, b, cumsum=False) 248 | self.exp_coeffs[key] = coeffs 249 | return coeffs 250 | 251 | def compute_high_order_derivatives(self, n, lambda_0n, g_0n, pseudo=False): 252 | # return g^(1), ..., g^(n) 253 | if pseudo: 254 | D = [[] for _ in range(n + 1)] 255 | D[0] = g_0n 256 | for i in range(1, n + 1): 257 | for j in range(n - i + 1): 258 | D[i].append((D[i - 1][j] - D[i - 1][j + 1]) / (lambda_0n[j] - lambda_0n[i + j])) 259 | 260 | return [D[i][0] * math.factorial(i) for i in range(1, n + 1)] 261 | else: 262 | R = [] 263 | for i in range(1, n + 1): 264 | R.append(torch.pow(lambda_0n[1:] - lambda_0n[0], i)) 265 | R = torch.stack(R).t() 266 | B = (torch.stack(g_0n[1:]) - g_0n[0]).reshape(n, -1) 267 | shape = g_0n[0].shape 268 | solution = torch.linalg.inv(R) @ B 269 | solution = solution.reshape([n] + list(shape)) 270 | return [solution[i - 1] * math.factorial(i) for i in range(1, n + 1)] 271 | 272 | def multistep_predictor_update(self, x_lst, eps_lst, time_lst, index_lst, t, i_t, order=1, pseudo=False): 273 | # x_lst: [..., x_s] 274 | # eps_lst: [..., eps_s] 275 | # time_lst: [..., time_s] 276 | ns = self.noise_schedule 277 | n = order - 1 278 | indexes = [-i - 1 for i in range(n + 1)] 279 | x_0n = index_list(x_lst, indexes) 280 | eps_0n = index_list(eps_lst, indexes) 281 | time_0n = torch.FloatTensor(index_list(time_lst, indexes)).cuda() 282 | index_0n = index_list(index_lst, 
indexes) 283 | lambda_0n = ns.marginal_lambda(time_0n) 284 | alpha_0n = ns.marginal_alpha(time_0n) 285 | sigma_0n = ns.marginal_std(time_0n) 286 | 287 | alpha_s, alpha_t = alpha_0n[0], ns.marginal_alpha(t) 288 | i_s = index_0n[0] 289 | x_s = x_0n[0] 290 | g_0n = [] 291 | for i in range(n + 1): 292 | f_i = (sigma_0n[i] * eps_0n[i] - self.l[index_0n[i]] * x_0n[i]) / alpha_0n[i] 293 | g_i = self.get_g(f_i, index_0n[0], index_0n[i]) 294 | g_0n.append(g_i) 295 | g_0 = g_0n[0] 296 | x_t = ( 297 | alpha_t / alpha_s * torch.exp(self.L[i_s] - self.L[i_t]) * x_s 298 | - alpha_t * torch.exp(-self.L[i_t] - self.S[i_s]) * (self.I[i_t] - self.I[i_s]) * g_0 299 | - alpha_t * torch.exp(-self.L[i_t]) * (self.C[i_t] - self.C[i_s] - self.B[i_s] * (self.I[i_t] - self.I[i_s])) 300 | ) 301 | if order > 1: 302 | g_d = self.compute_high_order_derivatives(n, lambda_0n, g_0n, pseudo=pseudo) 303 | for i in range(order - 1): 304 | x_t = ( 305 | x_t 306 | - alpha_t 307 | * torch.exp(self.L[i_s] - self.L[i_t]) 308 | * self.compute_exponential_coefficients_high_order(i_s, i_t, order=i + 2) 309 | * g_d[i] 310 | ) 311 | return x_t 312 | 313 | def multistep_corrector_update(self, x_lst, eps_lst, time_lst, index_lst, order=1, pseudo=False): 314 | # x_lst: [..., x_s, x_t] 315 | # eps_lst: [..., eps_s, eps_t] 316 | # lambda_lst: [..., lambda_s, lambda_t] 317 | ns = self.noise_schedule 318 | n = order - 1 319 | indexes = [-i - 1 for i in range(n + 1)] 320 | indexes[0] = -2 321 | indexes[1] = -1 322 | x_0n = index_list(x_lst, indexes) 323 | eps_0n = index_list(eps_lst, indexes) 324 | time_0n = torch.FloatTensor(index_list(time_lst, indexes)).cuda() 325 | index_0n = index_list(index_lst, indexes) 326 | lambda_0n = ns.marginal_lambda(time_0n) 327 | alpha_0n = ns.marginal_alpha(time_0n) 328 | sigma_0n = ns.marginal_std(time_0n) 329 | 330 | alpha_s, alpha_t = alpha_0n[0], alpha_0n[1] 331 | i_s, i_t = index_0n[0], index_0n[1] 332 | x_s = x_0n[0] 333 | g_0n = [] 334 | for i in range(n + 1): 335 | f_i = (sigma_0n[i] * eps_0n[i] - self.l[index_0n[i]] * x_0n[i]) / alpha_0n[i] 336 | g_i = self.get_g(f_i, index_0n[0], index_0n[i]) 337 | g_0n.append(g_i) 338 | g_0 = g_0n[0] 339 | x_t_new = ( 340 | alpha_t / alpha_s * torch.exp(self.L[i_s] - self.L[i_t]) * x_s 341 | - alpha_t * torch.exp(-self.L[i_t] - self.S[i_s]) * (self.I[i_t] - self.I[i_s]) * g_0 342 | - alpha_t * torch.exp(-self.L[i_t]) * (self.C[i_t] - self.C[i_s] - self.B[i_s] * (self.I[i_t] - self.I[i_s])) 343 | ) 344 | if order > 1: 345 | g_d = self.compute_high_order_derivatives(n, lambda_0n, g_0n, pseudo=pseudo) 346 | for i in range(order - 1): 347 | x_t_new = ( 348 | x_t_new 349 | - alpha_t 350 | * torch.exp(self.L[i_s] - self.L[i_t]) 351 | * self.compute_exponential_coefficients_high_order(i_s, i_t, order=i + 2) 352 | * g_d[i] 353 | ) 354 | return x_t_new 355 | 356 | @torch.no_grad() 357 | def sample(self, model_fn, x, order, p_pseudo, use_corrector, c_pseudo, lower_order_final, return_intermediate=False): 358 | self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) 359 | steps = self.steps 360 | cached_x = [] 361 | cached_model_output = [] 362 | cached_time = [] 363 | cached_index = [] 364 | indexes, timesteps = self.indexes, self.timesteps 365 | step_p_order = 0 366 | 367 | for step in range(1, steps + 1): 368 | cached_x.append(x) 369 | cached_model_output.append(self.noise_prediction_fn(x, timesteps[step - 1])) 370 | cached_time.append(timesteps[step - 1]) 371 | cached_index.append(indexes[step - 1]) 372 | if use_corrector: 373 | step_c_order = step_p_order + c_pseudo 
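# Corrector: when step_c_order > 1, the state cached at the current time is
# refined using the model output just computed there. The quantity
# N = sigma_t * eps - l_t * x is held fixed, so after x is replaced by the
# corrected value, the cached noise prediction is rewritten to stay
# consistent with it for the subsequent predictor steps.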
374 | if step_c_order > 1: 375 | x_new = self.multistep_corrector_update( 376 | cached_x, cached_model_output, cached_time, cached_index, order=step_c_order, pseudo=c_pseudo 377 | ) 378 | sigma_t = self.noise_schedule.marginal_std(cached_time[-1]) 379 | l_t = self.l[cached_index[-1]] 380 | N_old = sigma_t * cached_model_output[-1] - l_t * cached_x[-1] 381 | cached_x[-1] = x_new 382 | cached_model_output[-1] = (N_old + l_t * cached_x[-1]) / sigma_t 383 | if step < order: 384 | step_p_order = step 385 | else: 386 | step_p_order = order 387 | if lower_order_final: 388 | step_p_order = min(step_p_order, steps + 1 - step) 389 | t = timesteps[step] 390 | i_t = indexes[step] 391 | 392 | x = self.multistep_predictor_update(cached_x, cached_model_output, cached_time, cached_index, t, i_t, order=step_p_order, pseudo=p_pseudo) 393 | 394 | if return_intermediate: 395 | return x, cached_x 396 | else: 397 | return x 398 | 399 | 400 | def model_wrapper(denoiser, noise_schedule): 401 | def noise_pred_fn(x, t, **kwargs): 402 | output = denoiser(x, t, **kwargs) 403 | alpha_t, sigma_t = noise_schedule.marginal_alpha(t), noise_schedule.marginal_std(t) 404 | return (x - alpha_t[:, None, None, None] * output) / sigma_t[:, None, None, None] 405 | 406 | def model_fn(x, t): 407 | return noise_pred_fn(x, t).to(torch.float64) 408 | 409 | return model_fn 410 | 411 | 412 | def get_dpmv3_sampler(num_steps=32, sigma_min=0.002, sigma_max=80): 413 | ns = NoiseScheduleEDM() 414 | 415 | dpm_solver_v3 = DPM_Solver_v3(250, ns, steps=num_steps, t_start=sigma_max, t_end=sigma_min, skip_type="edm", device="cuda") 416 | 417 | def dpm_solver_v3_sampler(model_fn, z): 418 | with torch.no_grad(): 419 | x = dpm_solver_v3.sample(model_fn, z, order=2, p_pseudo=False, use_corrector=True, c_pseudo=False, lower_order_final=True) 420 | return x 421 | 422 | return dpm_solver_v3_sampler 423 | 424 | 425 | def edm_sampler( 426 | net, noise, labels=None, gnet=None, 427 | num_steps=32, sigma_min=0.002, sigma_max=80, rho=7, guidance=1, 428 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 429 | dtype=torch.float32, randn_like=torch.randn_like, sampler=None, 430 | ): 431 | # Guided denoiser. 432 | def denoise(x, t): 433 | Dx = net(x, t, labels).to(dtype) 434 | if guidance == 1: 435 | return Dx 436 | ref_Dx = gnet(x, t, labels).to(dtype) 437 | return ref_Dx.lerp(Dx, guidance) 438 | 439 | if sampler is not None: 440 | ns = NoiseScheduleEDM() 441 | noise_pred_fn = model_wrapper(denoise, ns) 442 | return sampler(noise_pred_fn, noise * sigma_max) 443 | 444 | #---------------------------------------------------------------------------- 445 | # EDM sampler from the paper 446 | # "Elucidating the Design Space of Diffusion-Based Generative Models", 447 | # extended to support guidance. 448 | 449 | # Time step discretization. 450 | step_indices = torch.arange(num_steps, dtype=dtype, device=noise.device) 451 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 452 | t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 453 | 454 | # Main sampling loop. 455 | x_next = noise.to(dtype) * t_steps[0] 456 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 457 | x_cur = x_next 458 | 459 | # Increase noise temporarily. 
460 | if S_churn > 0 and S_min <= t_cur <= S_max: 461 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) 462 | t_hat = t_cur + gamma * t_cur 463 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) 464 | else: 465 | t_hat = t_cur 466 | x_hat = x_cur 467 | 468 | # Euler step. 469 | d_cur = (x_hat - denoise(x_hat, t_hat)) / t_hat 470 | x_next = x_hat + (t_next - t_hat) * d_cur 471 | 472 | # Apply 2nd order correction. 473 | if i < num_steps - 1: 474 | d_prime = (x_next - denoise(x_next, t_next)) / t_next 475 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 476 | 477 | return x_next 478 | 479 | #---------------------------------------------------------------------------- 480 | # Wrapper for torch.Generator that allows specifying a different random seed 481 | # for each sample in a minibatch. 482 | 483 | class StackedRandomGenerator: 484 | def __init__(self, device, seeds): 485 | super().__init__() 486 | self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds] 487 | 488 | def randn(self, size, **kwargs): 489 | assert size[0] == len(self.generators) 490 | return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]) 491 | 492 | def randn_like(self, input): 493 | return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device) 494 | 495 | def randint(self, *args, size, **kwargs): 496 | assert size[0] == len(self.generators) 497 | return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators]) 498 | 499 | #---------------------------------------------------------------------------- 500 | # Generate images for the given seeds in a distributed fashion. 501 | # Returns an iterable that yields 502 | # dnnlib.EasyDict(images, labels, noise, batch_idx, num_batches, indices, seeds) 503 | 504 | def generate_images( 505 | net, # Main network. Path, URL, or torch.nn.Module. 506 | gnet = None, # Guiding network. None = same as main network. 507 | encoder = None, # Instance of training.encoders.Encoder. None = load from network pickle. 508 | outdir = None, # Where to save the output images. None = do not save. 509 | subdirs = False, # Create subdirectory for every 1000 seeds? 510 | seeds = range(16, 24), # List of random seeds. 511 | class_idx = None, # Class label. None = select randomly. 512 | max_batch_size = 32, # Maximum batch size for the diffusion model. 513 | encoder_batch_size = 4, # Maximum batch size for the encoder. None = default. 514 | verbose = True, # Enable status prints? 515 | device = torch.device('cuda'), # Which compute device to use. 516 | dpmv3 = False, # Use DPM-Solver-v3 for faster inference? 517 | **sampler_kwargs, # Additional arguments for the sampler function. 518 | ): 519 | # Rank 0 goes first. 520 | if dist.get_rank() != 0: 521 | torch.distributed.barrier() 522 | 523 | # Load main network. 524 | if isinstance(net, str): 525 | if verbose: 526 | dist.print0(f'Loading main network from {net} ...') 527 | with dnnlib.util.open_url(net, verbose=(verbose and dist.get_rank() == 0)) as f: 528 | data = pickle.load(f) 529 | net = data['ema'].to(device) 530 | net.use_fp16 = True 531 | net.force_fp32 = False 532 | if encoder is None: 533 | encoder = data.get('encoder', None) 534 | if encoder is None: 535 | encoder = dnnlib.util.construct_class_by_name(class_name='training.encoders.StandardRGBEncoder') 536 | assert net is not None 537 | 538 | # Load guidance network. 
539 | if isinstance(gnet, str): 540 | if verbose: 541 | dist.print0(f'Loading guiding network from {gnet} ...') 542 | with dnnlib.util.open_url(gnet, verbose=(verbose and dist.get_rank() == 0)) as f: 543 | gnet = pickle.load(f)['ema'].to(device) 544 | if gnet is None: 545 | gnet = net 546 | 547 | # Initialize encoder. 548 | assert encoder is not None 549 | if verbose: 550 | dist.print0(f'Setting up {type(encoder).__name__}...') 551 | encoder.init(device) 552 | if encoder_batch_size is not None and hasattr(encoder, 'batch_size'): 553 | encoder.batch_size = encoder_batch_size 554 | 555 | # Other ranks follow. 556 | if dist.get_rank() == 0: 557 | torch.distributed.barrier() 558 | 559 | # Divide seeds into batches. 560 | num_batches = max((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1, 1) * dist.get_world_size() 561 | rank_batches = np.array_split(np.arange(len(seeds)), num_batches)[dist.get_rank() :: dist.get_world_size()] 562 | if net.label_dim > 0: 563 | all_class_labels = torch.arange(net.label_dim, device=device, dtype=torch.int64).repeat(len(seeds) // net.label_dim).tensor_split(num_batches) 564 | rank_class_labels = all_class_labels[dist.get_rank() :: dist.get_world_size()] 565 | if verbose: 566 | dist.print0(f'Generating {len(seeds)} images...') 567 | 568 | sampler = get_dpmv3_sampler(sampler_kwargs["num_steps"], sampler_kwargs["sigma_min"], sampler_kwargs["sigma_max"]) if dpmv3 else None 569 | # Return an iterable over the batches. 570 | class ImageIterable: 571 | def __len__(self): 572 | return len(rank_batches) 573 | 574 | def __iter__(self): 575 | # Loop over batches. 576 | for batch_idx, indices in enumerate(rank_batches): 577 | r = dnnlib.EasyDict(images=None, labels=None, noise=None, batch_idx=batch_idx, num_batches=len(rank_batches), indices=indices) 578 | r.seeds = [seeds[idx] for idx in indices] 579 | if len(r.seeds) > 0: 580 | 581 | # Pick noise and labels. 582 | rnd = StackedRandomGenerator(device, r.seeds) 583 | r.noise = rnd.randn([len(r.seeds), net.img_channels, net.img_resolution, net.img_resolution], device=device) 584 | r.labels = None 585 | if net.label_dim > 0: 586 | r.labels = torch.eye(net.label_dim, device=device)[rank_class_labels[batch_idx]] 587 | if class_idx is not None: 588 | r.labels[:, :] = 0 589 | r.labels[:, class_idx] = 1 590 | 591 | # Generate images. 592 | latents = dnnlib.util.call_func_by_name(func_name=edm_sampler, net=net, noise=r.noise, 593 | labels=r.labels, gnet=gnet, randn_like=rnd.randn_like, sampler=sampler, **sampler_kwargs) 594 | r.images = encoder.decode(latents) 595 | 596 | # Save images. 597 | if outdir is not None: 598 | for seed, image in zip(r.seeds, r.images.permute(0, 2, 3, 1).cpu().numpy()): 599 | image_dir = os.path.join(outdir, f'{seed//1000*1000:06d}') if subdirs else outdir 600 | os.makedirs(image_dir, exist_ok=True) 601 | PIL.Image.fromarray(image, 'RGB').save(os.path.join(image_dir, f'{seed:06d}.png')) 602 | 603 | # Yield results. 604 | torch.distributed.barrier() # keep the ranks in sync 605 | yield r 606 | 607 | return ImageIterable() 608 | 609 | #---------------------------------------------------------------------------- 610 | # Parse a comma separated list of numbers or ranges and return a list of ints. 
611 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] 612 | 613 | def parse_int_list(s): 614 | if isinstance(s, list): 615 | return s 616 | ranges = [] 617 | range_re = re.compile(r'^(\d+)-(\d+)$') 618 | for p in s.split(','): 619 | m = range_re.match(p) 620 | if m: 621 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 622 | else: 623 | ranges.append(int(p)) 624 | return ranges 625 | 626 | #---------------------------------------------------------------------------- 627 | # Command line interface. 628 | 629 | @click.command() 630 | @click.option('--net', help='Main network pickle filename', metavar='PATH|URL', type=str, default=None) 631 | @click.option('--gnet', help='Guiding network pickle filename', metavar='PATH|URL', type=str, default=None) 632 | @click.option('--outdir', help='Where to save the output images', metavar='DIR', type=str, required=True) 633 | @click.option('--subdirs', help='Create subdirectory for every 1000 seeds', is_flag=True) 634 | @click.option('--seeds', help='List of random seeds (e.g. 1,2,5-10)', metavar='LIST', type=parse_int_list, default='16-19', show_default=True) 635 | @click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None) 636 | @click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=32, show_default=True) 637 | 638 | @click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=32, show_default=True) 639 | @click.option('--sigma_min', help='Lowest noise level', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=0.002, show_default=True) 640 | @click.option('--sigma_max', help='Highest noise level', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=80, show_default=True) 641 | @click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True) 642 | @click.option('--guidance', help='Guidance strength [default: 1; no guidance]', metavar='FLOAT', type=float, default=None) 643 | @click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) 644 | @click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) 645 | @click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True) 646 | @click.option('--S_noise', 'S_noise', help='Stoch. noise inflation', metavar='FLOAT', type=float, default=1, show_default=True) 647 | @click.option('--dpmv3', help='Use DPM-Solver-v3 for faster inference', is_flag=True) 648 | 649 | def cmdline(**opts): 650 | """Generate random images using the given model. 651 | """ 652 | opts = dnnlib.EasyDict(opts) 653 | 654 | # Validate options. 655 | if opts.net is None: 656 | raise click.ClickException('Please specify --net') 657 | if opts.guidance is None or opts.guidance == 1: 658 | opts.guidance = 1 659 | opts.gnet = None 660 | elif opts.gnet is None: 661 | raise click.ClickException('Please specify --gnet when using guidance') 662 | 663 | # Generate. 
664 | dist.init() 665 | image_iter = generate_images(**opts) 666 | for _r in tqdm.tqdm(image_iter, unit='batch', disable=(dist.get_rank() != 0)): 667 | pass 668 | 669 | #---------------------------------------------------------------------------- 670 | 671 | if __name__ == "__main__": 672 | cmdline() 673 | 674 | #---------------------------------------------------------------------------- 675 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | 3 | Attribution-NonCommercial-ShareAlike 4.0 International 4 | 5 | ======================================================================= 6 | 7 | Creative Commons Corporation ("Creative Commons") is not a law firm and 8 | does not provide legal services or legal advice. Distribution of 9 | Creative Commons public licenses does not create a lawyer-client or 10 | other relationship. Creative Commons makes its licenses and related 11 | information available on an "as-is" basis. Creative Commons gives no 12 | warranties regarding its licenses, any material licensed under their 13 | terms and conditions, or any related information. Creative Commons 14 | disclaims all liability for damages resulting from their use to the 15 | fullest extent possible. 16 | 17 | Using Creative Commons Public Licenses 18 | 19 | Creative Commons public licenses provide a standard set of terms and 20 | conditions that creators and other rights holders may use to share 21 | original works of authorship and other material subject to copyright 22 | and certain other rights specified in the public license below. The 23 | following considerations are for informational purposes only, are not 24 | exhaustive, and do not form part of our licenses. 25 | 26 | Considerations for licensors: Our public licenses are 27 | intended for use by those authorized to give the public 28 | permission to use material in ways otherwise restricted by 29 | copyright and certain other rights. Our licenses are 30 | irrevocable. Licensors should read and understand the terms 31 | and conditions of the license they choose before applying it. 32 | Licensors should also secure all rights necessary before 33 | applying our licenses so that the public can reuse the 34 | material as expected. Licensors should clearly mark any 35 | material not subject to the license. This includes other CC- 36 | licensed material, or material used under an exception or 37 | limitation to copyright. More considerations for licensors: 38 | wiki.creativecommons.org/Considerations_for_licensors 39 | 40 | Considerations for the public: By using one of our public 41 | licenses, a licensor grants the public permission to use the 42 | licensed material under specified terms and conditions. If 43 | the licensor's permission is not necessary for any reason--for 44 | example, because of any applicable exception or limitation to 45 | copyright--then that use is not regulated by the license. Our 46 | licenses grant only permissions under copyright and certain 47 | other rights that a licensor has authority to grant. Use of 48 | the licensed material may still be restricted for other 49 | reasons, including because others have copyright or other 50 | rights in the material. A licensor may make special requests, 51 | such as asking that all changes be marked or described. 
52 | Although not required by our licenses, you are encouraged to 53 | respect those requests where reasonable. More considerations 54 | for the public: 55 | wiki.creativecommons.org/Considerations_for_licensees 56 | 57 | ======================================================================= 58 | 59 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 60 | Public License 61 | 62 | By exercising the Licensed Rights (defined below), You accept and agree 63 | to be bound by the terms and conditions of this Creative Commons 64 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 65 | ("Public License"). To the extent this Public License may be 66 | interpreted as a contract, You are granted the Licensed Rights in 67 | consideration of Your acceptance of these terms and conditions, and the 68 | Licensor grants You such rights in consideration of benefits the 69 | Licensor receives from making the Licensed Material available under 70 | these terms and conditions. 71 | 72 | 73 | Section 1 -- Definitions. 74 | 75 | a. Adapted Material means material subject to Copyright and Similar 76 | Rights that is derived from or based upon the Licensed Material 77 | and in which the Licensed Material is translated, altered, 78 | arranged, transformed, or otherwise modified in a manner requiring 79 | permission under the Copyright and Similar Rights held by the 80 | Licensor. For purposes of this Public License, where the Licensed 81 | Material is a musical work, performance, or sound recording, 82 | Adapted Material is always produced where the Licensed Material is 83 | synched in timed relation with a moving image. 84 | 85 | b. Adapter's License means the license You apply to Your Copyright 86 | and Similar Rights in Your contributions to Adapted Material in 87 | accordance with the terms and conditions of this Public License. 88 | 89 | c. BY-NC-SA Compatible License means a license listed at 90 | creativecommons.org/compatiblelicenses, approved by Creative 91 | Commons as essentially the equivalent of this Public License. 92 | 93 | d. Copyright and Similar Rights means copyright and/or similar rights 94 | closely related to copyright including, without limitation, 95 | performance, broadcast, sound recording, and Sui Generis Database 96 | Rights, without regard to how the rights are labeled or 97 | categorized. For purposes of this Public License, the rights 98 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 99 | Rights. 100 | 101 | e. Effective Technological Measures means those measures that, in the 102 | absence of proper authority, may not be circumvented under laws 103 | fulfilling obligations under Article 11 of the WIPO Copyright 104 | Treaty adopted on December 20, 1996, and/or similar international 105 | agreements. 106 | 107 | f. Exceptions and Limitations means fair use, fair dealing, and/or 108 | any other exception or limitation to Copyright and Similar Rights 109 | that applies to Your use of the Licensed Material. 110 | 111 | g. License Elements means the license attributes listed in the name 112 | of a Creative Commons Public License. The License Elements of this 113 | Public License are Attribution, NonCommercial, and ShareAlike. 114 | 115 | h. Licensed Material means the artistic or literary work, database, 116 | or other material to which the Licensor applied this Public 117 | License. 118 | 119 | i. 
Licensed Rights means the rights granted to You subject to the 120 | terms and conditions of this Public License, which are limited to 121 | all Copyright and Similar Rights that apply to Your use of the 122 | Licensed Material and that the Licensor has authority to license. 123 | 124 | j. Licensor means the individual(s) or entity(ies) granting rights 125 | under this Public License. 126 | 127 | k. NonCommercial means not primarily intended for or directed towards 128 | commercial advantage or monetary compensation. For purposes of 129 | this Public License, the exchange of the Licensed Material for 130 | other material subject to Copyright and Similar Rights by digital 131 | file-sharing or similar means is NonCommercial provided there is 132 | no payment of monetary compensation in connection with the 133 | exchange. 134 | 135 | l. Share means to provide material to the public by any means or 136 | process that requires permission under the Licensed Rights, such 137 | as reproduction, public display, public performance, distribution, 138 | dissemination, communication, or importation, and to make material 139 | available to the public including in ways that members of the 140 | public may access the material from a place and at a time 141 | individually chosen by them. 142 | 143 | m. Sui Generis Database Rights means rights other than copyright 144 | resulting from Directive 96/9/EC of the European Parliament and of 145 | the Council of 11 March 1996 on the legal protection of databases, 146 | as amended and/or succeeded, as well as other essentially 147 | equivalent rights anywhere in the world. 148 | 149 | n. You means the individual or entity exercising the Licensed Rights 150 | under this Public License. Your has a corresponding meaning. 151 | 152 | 153 | Section 2 -- Scope. 154 | 155 | a. License grant. 156 | 157 | 1. Subject to the terms and conditions of this Public License, 158 | the Licensor hereby grants You a worldwide, royalty-free, 159 | non-sublicensable, non-exclusive, irrevocable license to 160 | exercise the Licensed Rights in the Licensed Material to: 161 | 162 | a. reproduce and Share the Licensed Material, in whole or 163 | in part, for NonCommercial purposes only; and 164 | 165 | b. produce, reproduce, and Share Adapted Material for 166 | NonCommercial purposes only. 167 | 168 | 2. Exceptions and Limitations. For the avoidance of doubt, where 169 | Exceptions and Limitations apply to Your use, this Public 170 | License does not apply, and You do not need to comply with 171 | its terms and conditions. 172 | 173 | 3. Term. The term of this Public License is specified in Section 174 | 6(a). 175 | 176 | 4. Media and formats; technical modifications allowed. The 177 | Licensor authorizes You to exercise the Licensed Rights in 178 | all media and formats whether now known or hereafter created, 179 | and to make technical modifications necessary to do so. The 180 | Licensor waives and/or agrees not to assert any right or 181 | authority to forbid You from making technical modifications 182 | necessary to exercise the Licensed Rights, including 183 | technical modifications necessary to circumvent Effective 184 | Technological Measures. For purposes of this Public License, 185 | simply making modifications authorized by this Section 2(a) 186 | (4) never produces Adapted Material. 187 | 188 | 5. Downstream recipients. 189 | 190 | a. Offer from the Licensor -- Licensed Material. 
Every 191 | recipient of the Licensed Material automatically 192 | receives an offer from the Licensor to exercise the 193 | Licensed Rights under the terms and conditions of this 194 | Public License. 195 | 196 | b. Additional offer from the Licensor -- Adapted Material. 197 | Every recipient of Adapted Material from You 198 | automatically receives an offer from the Licensor to 199 | exercise the Licensed Rights in the Adapted Material 200 | under the conditions of the Adapter's License You apply. 201 | 202 | c. No downstream restrictions. You may not offer or impose 203 | any additional or different terms or conditions on, or 204 | apply any Effective Technological Measures to, the 205 | Licensed Material if doing so restricts exercise of the 206 | Licensed Rights by any recipient of the Licensed 207 | Material. 208 | 209 | 6. No endorsement. Nothing in this Public License constitutes or 210 | may be construed as permission to assert or imply that You 211 | are, or that Your use of the Licensed Material is, connected 212 | with, or sponsored, endorsed, or granted official status by, 213 | the Licensor or others designated to receive attribution as 214 | provided in Section 3(a)(1)(A)(i). 215 | 216 | b. Other rights. 217 | 218 | 1. Moral rights, such as the right of integrity, are not 219 | licensed under this Public License, nor are publicity, 220 | privacy, and/or other similar personality rights; however, to 221 | the extent possible, the Licensor waives and/or agrees not to 222 | assert any such rights held by the Licensor to the limited 223 | extent necessary to allow You to exercise the Licensed 224 | Rights, but not otherwise. 225 | 226 | 2. Patent and trademark rights are not licensed under this 227 | Public License. 228 | 229 | 3. To the extent possible, the Licensor waives any right to 230 | collect royalties from You for the exercise of the Licensed 231 | Rights, whether directly or through a collecting society 232 | under any voluntary or waivable statutory or compulsory 233 | licensing scheme. In all other cases the Licensor expressly 234 | reserves any right to collect such royalties, including when 235 | the Licensed Material is used other than for NonCommercial 236 | purposes. 237 | 238 | 239 | Section 3 -- License Conditions. 240 | 241 | Your exercise of the Licensed Rights is expressly made subject to the 242 | following conditions. 243 | 244 | a. Attribution. 245 | 246 | 1. If You Share the Licensed Material (including in modified 247 | form), You must: 248 | 249 | a. retain the following if it is supplied by the Licensor 250 | with the Licensed Material: 251 | 252 | i. identification of the creator(s) of the Licensed 253 | Material and any others designated to receive 254 | attribution, in any reasonable manner requested by 255 | the Licensor (including by pseudonym if 256 | designated); 257 | 258 | ii. a copyright notice; 259 | 260 | iii. a notice that refers to this Public License; 261 | 262 | iv. a notice that refers to the disclaimer of 263 | warranties; 264 | 265 | v. a URI or hyperlink to the Licensed Material to the 266 | extent reasonably practicable; 267 | 268 | b. indicate if You modified the Licensed Material and 269 | retain an indication of any previous modifications; and 270 | 271 | c. indicate the Licensed Material is licensed under this 272 | Public License, and include the text of, or the URI or 273 | hyperlink to, this Public License. 274 | 275 | 2. 
You may satisfy the conditions in Section 3(a)(1) in any 276 | reasonable manner based on the medium, means, and context in 277 | which You Share the Licensed Material. For example, it may be 278 | reasonable to satisfy the conditions by providing a URI or 279 | hyperlink to a resource that includes the required 280 | information. 281 | 3. If requested by the Licensor, You must remove any of the 282 | information required by Section 3(a)(1)(A) to the extent 283 | reasonably practicable. 284 | 285 | b. ShareAlike. 286 | 287 | In addition to the conditions in Section 3(a), if You Share 288 | Adapted Material You produce, the following conditions also apply. 289 | 290 | 1. The Adapter's License You apply must be a Creative Commons 291 | license with the same License Elements, this version or 292 | later, or a BY-NC-SA Compatible License. 293 | 294 | 2. You must include the text of, or the URI or hyperlink to, the 295 | Adapter's License You apply. You may satisfy this condition 296 | in any reasonable manner based on the medium, means, and 297 | context in which You Share Adapted Material. 298 | 299 | 3. You may not offer or impose any additional or different terms 300 | or conditions on, or apply any Effective Technological 301 | Measures to, Adapted Material that restrict exercise of the 302 | rights granted under the Adapter's License You apply. 303 | 304 | 305 | Section 4 -- Sui Generis Database Rights. 306 | 307 | Where the Licensed Rights include Sui Generis Database Rights that 308 | apply to Your use of the Licensed Material: 309 | 310 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 311 | to extract, reuse, reproduce, and Share all or a substantial 312 | portion of the contents of the database for NonCommercial purposes 313 | only; 314 | 315 | b. if You include all or a substantial portion of the database 316 | contents in a database in which You have Sui Generis Database 317 | Rights, then the database in which You have Sui Generis Database 318 | Rights (but not its individual contents) is Adapted Material, 319 | including for purposes of Section 3(b); and 320 | 321 | c. You must comply with the conditions in Section 3(a) if You Share 322 | all or a substantial portion of the contents of the database. 323 | 324 | For the avoidance of doubt, this Section 4 supplements and does not 325 | replace Your obligations under this Public License where the Licensed 326 | Rights include other Copyright and Similar Rights. 327 | 328 | 329 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 330 | 331 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 332 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 333 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 334 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 335 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 336 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 337 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 338 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 339 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 340 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 341 | 342 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 343 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 344 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 345 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 346 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 347 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 348 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 349 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 350 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 351 | 352 | c. The disclaimer of warranties and limitation of liability provided 353 | above shall be interpreted in a manner that, to the extent 354 | possible, most closely approximates an absolute disclaimer and 355 | waiver of all liability. 356 | 357 | 358 | Section 6 -- Term and Termination. 359 | 360 | a. This Public License applies for the term of the Copyright and 361 | Similar Rights licensed here. However, if You fail to comply with 362 | this Public License, then Your rights under this Public License 363 | terminate automatically. 364 | 365 | b. Where Your right to use the Licensed Material has terminated under 366 | Section 6(a), it reinstates: 367 | 368 | 1. automatically as of the date the violation is cured, provided 369 | it is cured within 30 days of Your discovery of the 370 | violation; or 371 | 372 | 2. upon express reinstatement by the Licensor. 373 | 374 | For the avoidance of doubt, this Section 6(b) does not affect any 375 | right the Licensor may have to seek remedies for Your violations 376 | of this Public License. 377 | 378 | c. For the avoidance of doubt, the Licensor may also offer the 379 | Licensed Material under separate terms or conditions or stop 380 | distributing the Licensed Material at any time; however, doing so 381 | will not terminate this Public License. 382 | 383 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 384 | License. 385 | 386 | 387 | Section 7 -- Other Terms and Conditions. 388 | 389 | a. The Licensor shall not be bound by any additional or different 390 | terms or conditions communicated by You unless expressly agreed. 391 | 392 | b. Any arrangements, understandings, or agreements regarding the 393 | Licensed Material not stated herein are separate from and 394 | independent of the terms and conditions of this Public License. 395 | 396 | 397 | Section 8 -- Interpretation. 398 | 399 | a. For the avoidance of doubt, this Public License does not, and 400 | shall not be interpreted to, reduce, limit, restrict, or impose 401 | conditions on any use of the Licensed Material that could lawfully 402 | be made without permission under this Public License. 403 | 404 | b. To the extent possible, if any provision of this Public License is 405 | deemed unenforceable, it shall be automatically reformed to the 406 | minimum extent necessary to make it enforceable. If the provision 407 | cannot be reformed, it shall be severed from this Public License 408 | without affecting the enforceability of the remaining terms and 409 | conditions. 410 | 411 | c. No term or condition of this Public License will be waived and no 412 | failure to comply consented to unless expressly agreed to by the 413 | Licensor. 414 | 415 | d. 
Nothing in this Public License constitutes or may be interpreted 416 | as a limitation upon, or waiver of, any privileges and immunities 417 | that apply to the Licensor or You, including from the legal 418 | processes of any jurisdiction or authority. 419 | 420 | ======================================================================= 421 | 422 | Creative Commons is not a party to its public 423 | licenses. Notwithstanding, Creative Commons may elect to apply one of 424 | its public licenses to material it publishes and in those instances 425 | will be considered the "Licensor." The text of the Creative Commons 426 | public licenses is dedicated to the public domain under the CC0 Public 427 | Domain Dedication. Except for the limited purpose of indicating that 428 | material is shared under a Creative Commons public license or as 429 | otherwise permitted by the Creative Commons policies published at 430 | creativecommons.org/policies, Creative Commons does not authorize the 431 | use of the trademark "Creative Commons" or any other trademark or logo 432 | of Creative Commons without its prior written consent including, 433 | without limitation, in connection with any unauthorized modifications 434 | to any of its public licenses or any other arrangements, 435 | understandings, or agreements concerning use of licensed material. For 436 | the avoidance of doubt, this paragraph does not form part of the 437 | public licenses. 438 | 439 | Creative Commons may be contacted at creativecommons.org. 440 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DDO: A Universal Supercharger for Visual Diffusion/Autoregressive Models 🚀 SOTA on ImageNet 2 | 3 |
19 | **ICML 2025 Spotlight** 20 |
24 | ![DDO sample grid](assets/grid_ddo.jpg) 25 |
26 | FID=1.26 on ImageNet 512x512, without any guidance. 27 |
28 | 29 | 30 | 31 | ## Introduction 32 | 33 | **Direct Discriminative Optimization (DDO)** enables **GAN-style finetuning** of likelihood-based generative models, such as **diffusion and autoregressive** models, without requiring an explicit discriminator network. By incorporating **reverse KL divergence** and **self-generated negative samples** (similar in spirit to **reinforcement learning methods used in large language models**), DDO overcomes the limitations of traditional **maximum likelihood training**, which relies on **forward KL** and often leads to **mode-covering behavior**. As a result, DDO can **substantially improve generation quality without changing the network architecture or inference protocol**. 34 | 35 | ## Available Models 36 | We release our finetuned checkpoints at [nvidia/DirectDiscriminativeOptimization · Hugging Face](https://huggingface.co/nvidia/DirectDiscriminativeOptimization). 37 | 38 | | Model Files | Base Repository | 39 | |---------------------------------------------------------|-------------------------------------------------------------| 40 | | `edm-cifar10-uncond-vp-ddo.pkl`,`edm-cifar10-cond-vp-ddo.pkl` | [EDM](https://github.com/NVlabs/edm) | 41 | | `edm2-img64-s-ddo.pkl`,`edm2-img512-l-ddo.pkl` | [EDM2](https://github.com/NVlabs/edm2) | 42 | | `var_d16-ddo.pth`,`var_d30-ddo.pth` | [VAR](https://github.com/FoundationVision/VAR) | 43 | 44 | | Model | #Parameters | FID w/o guidance | FID w/ guidance | 45 | |---------------------------|------------------|------------------|-----------------| 46 | | `edm-cifar10-uncond-vp-ddo.pkl` | 56M | 1.38 | - | 47 | | `edm-cifar10-cond-vp-ddo.pkl` | 56M | 1.30 | - | 48 | | `edm2-img64-s-ddo.pkl` | 280M | 0.97 | - | 49 | | `edm2-img512-l-ddo.pkl` | 777M | 1.26 | 1.21 | 50 | | `var_d16-ddo.pth` | 310M | 3.12 | 2.54 | 51 | | `var_d30-ddo.pth` | 2.0B | 1.79 | 1.73 | 52 | 53 | ## Reproducing FID Results 54 | 55 | Each model type (EDM-based, EDM2-based, and VAR-based) should be used within its corresponding base repository and environment. Please ensure you have cloned the appropriate repository and set up its environment. 56 | 57 | #### EDM-based 58 | 59 | We use EDM's original inference code. 60 | 61 | - Generate samples 62 | 63 | ```bash 64 | # Generate 50000 images and save them as fid-tmp/*/*.png 65 | 66 | # For CIFAR-10 (unconditional) 67 | torchrun --nproc_per_node=8 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs --network=https://huggingface.co/nvidia/DirectDiscriminativeOptimization/resolve/main/edm-cifar10-uncond-vp-ddo.pkl 68 | 69 | # For CIFAR-10 (conditional) 70 | torchrun --nproc_per_node=8 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs --network=https://huggingface.co/nvidia/DirectDiscriminativeOptimization/resolve/main/edm-cifar10-cond-vp-ddo.pkl 71 | ``` 72 | 73 | - Calculate FID 74 | 75 | ```bash 76 | torchrun --nproc_per_node=8 fid.py calc --images=fid-tmp --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 77 | ``` 78 | 79 | #### EDM2-based 80 | 81 | We provide `EDM2_sample.py`, which performs class rebalancing and integrates [DPM-Solver-v3](https://arxiv.org/abs/2310.13268) for accelerated sampling. Please place this script inside the EDM2 repository. 
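Here, *class rebalancing* means that class labels are assigned to the seeds deterministically and uniformly rather than drawn at random, so every ImageNet class contributes an equal number of samples to the FID statistics. A condensed sketch of the label assignment performed inside the script's `generate_images` (illustrative only, with hard-coded sizes):

```python
import torch

# With 50000 seeds and 1000 classes, each class receives exactly 50 images.
num_seeds, label_dim = 50000, 1000
class_labels = torch.arange(label_dim).repeat(num_seeds // label_dim)  # 0..999, 0..999, ...
# EDM2_sample.py splits these labels across ranks and batches alongside the
# seeds, then converts them to one-hot vectors before calling the sampler.
```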
82 | 83 | - Generate samples 84 | 85 | ```bash 86 | # Generate 50000 images and save them as fid-tmp/*/*.png 87 | 88 | # For ImageNet-64 89 | torchrun --nproc_per_node=8 EDM2_sample.py --outdir=fid-tmp --seeds=0-49999 --subdirs --net=https://huggingface.co/nvidia/DirectDiscriminativeOptimization/resolve/main/edm2-img64-s-ddo.pkl 90 | 91 | # For ImageNet 512x512 92 | torchrun --nproc_per_node=8 EDM2_sample.py --outdir=fid-tmp --seeds=0-49999 --subdirs --net=https://huggingface.co/nvidia/DirectDiscriminativeOptimization/resolve/main/edm2-img512-l-ddo.pkl 93 | 94 | # For ImageNet 512x512 (with guidance) 95 | torchrun --nproc_per_node=8 EDM2_sample.py --outdir=fid-tmp --seeds=0-49999 --subdirs --net=https://huggingface.co/nvidia/DirectDiscriminativeOptimization/resolve/main/edm2-img512-l-ddo.pkl --gnet=https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions/edm2-img512-xs-0134217-0.165.pkl --guidance=1.1 96 | 97 | # For ImageNet 512x512 (DPM-Solver-v3) 98 | torchrun --nproc_per_node=8 EDM2_sample.py --dpmv3 --steps=25 --outdir=fid-tmp --seeds=0-49999 --subdirs --net=https://huggingface.co/nvidia/DirectDiscriminativeOptimization/resolve/main/edm2-img512-l-ddo.pkl 99 | 100 | # For ImageNet 512x512 (with guidance, DPM-Solver-v3) 101 | torchrun --nproc_per_node=8 EDM2_sample.py --dpmv3 --steps=25 --outdir=fid-tmp --seeds=0-49999 --subdirs --net=https://huggingface.co/nvidia/DirectDiscriminativeOptimization/resolve/main/edm2-img512-l-ddo.pkl --gnet=https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions/edm2-img512-xs-0134217-0.165.pkl --guidance=1.1 102 | ``` 103 | 104 | - Calculate FID 105 | 106 | ```bash 107 | # For ImageNet-64 108 | torchrun --nproc_per_node=8 calculate_metrics.py calc --metrics=fid --images=fid-tmp --ref=https://nvlabs-fi-cdn.nvidia.com/edm2/dataset-refs/img64.pkl 109 | 110 | # For ImageNet 512x512 111 | torchrun --nproc_per_node=8 calculate_metrics.py calc --metrics=fid --images=fid-tmp --ref=https://nvlabs-fi-cdn.nvidia.com/edm2/dataset-refs/img512.pkl 112 | ``` 113 | 114 | #### VAR-based 115 | 116 | We provide `VAR_sample.py`, an extended version of VAR's original sampling demo. Please place this script in the VAR repository. 117 | 118 | - Generate samples 119 | 120 | ```bash 121 | # Generate 50000 images and save both .png and compressed .npz files in samples/ 122 | 123 | # For ImageNet 256x256 124 | # $DEPTH in {16,30} is supported 125 | # $CFG can be set to 1.0 (guidance-free) or 1.4 (best FID) 126 | torchrun --nproc_per_node=8 VAR_sample.py --depth $DEPTH --cfg $CFG 127 | ``` 128 | 129 | - Calculate FID 130 | 131 | Please use [OpenAI's FID evaluation toolkit](https://github.com/openai/guided-diffusion/tree/main/evaluations) and the reference stats file [VIRTUAL_imagenet256_labeled.npz](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) to evaluate FID, IS, precision, and recall. 132 | 133 | ## License 134 | 135 | Copyright © 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 136 | 137 | All materials, including source code and pre-trained models, are licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-nc-sa/4.0/). 
## License

Copyright © 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

All materials, including source code and pre-trained models, are licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-nc-sa/4.0/).

## Citation

If our work assists your research, feel free to give us a star ⭐ or cite us using:

```
@article{zheng2025direct,
  title={Direct Discriminative Optimization: Your Likelihood-Based Visual Generative Model is Secretly a GAN Discriminator},
  author={Zheng, Kaiwen and Chen, Yongxin and Chen, Huayu and He, Guande and Liu, Ming-Yu and Zhu, Jun and Zhang, Qinsheng},
  journal={arXiv preprint arXiv:2503.01103},
  year={2025}
}
```

--------------------------------------------------------------------------------
/VAR_sample.py:
--------------------------------------------------------------------------------
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

################## 1. Download checkpoints and build models
import argparse
import os
import random

import numpy as np
import torch
from tqdm import tqdm
import PIL.Image as PImage

# Disable default parameter init for faster model construction; the weights
# are loaded from checkpoints immediately afterwards.
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
from models import build_vae_var
from PIL import Image

import dist


parser = argparse.ArgumentParser()
parser.add_argument("--cfg", type=float, default=1.0)
parser.add_argument("--depth", type=int, default=16)
parser.add_argument("--sample_dir", type=str, default="./samples")
# Note: `type=bool` would treat any non-empty string (including "False") as
# True, so this is a store_true flag instead.
parser.add_argument("--trick", action="store_true")

args = parser.parse_args()

MODEL_DEPTH = args.depth
assert MODEL_DEPTH in {16, 20, 24, 30}

dist.initialize()

vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}-ddo.pth'

# Download checkpoints on rank 0 only.
if dist.get_rank() == 0:
    if not os.path.exists(vae_ckpt): os.system(f'wget https://huggingface.co/FoundationVision/var/resolve/main/{vae_ckpt}')
    if not os.path.exists(var_ckpt): os.system(f'wget https://huggingface.co/nvidia/DirectDiscriminativeOptimization/resolve/main/{var_ckpt}')

torch.distributed.barrier()

# Build the VQVAE and the VAR transformer.
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
device = "cuda" if torch.cuda.is_available() else "cpu"
if "vae" not in globals() or "var" not in globals():
    vae, var = build_vae_var(
        V=4096,
        Cvae=32,
        ch=160,
        share_quant_resi=4,  # hard-coded VQVAE hyperparameters
        device=device,
        patch_nums=patch_nums,
        num_classes=1000,
        depth=MODEL_DEPTH,
        shared_aln=False,
    )

# Load checkpoints and freeze all parameters.
vae.load_state_dict(torch.load(vae_ckpt, map_location="cpu"), strict=True)
var.load_state_dict(torch.load(var_ckpt, map_location="cpu"), strict=True)
vae.eval(), var.eval()
for p in vae.parameters():
    p.requires_grad_(False)
for p in var.parameters():
    p.requires_grad_(False)
print("Preparation finished.")
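# The sampling stage below: fix all RNG seeds for reproducibility, enable
# TF32 for faster matmuls, split the 1000 ImageNet classes evenly across
# ranks, draw 50 images per class (matching the class-balanced FID reference
# batch), and finally pack the PNGs into a single .npz file on rank 0.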
################## 2. Sample with classifier-free guidance

# Sampling settings.
seed = 1
cfg = args.cfg
more_smooth = False  # True for smoother output

# Seed everything for deterministic sampling.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Run faster: allow TF32 matmuls.
tf32 = True
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
torch.set_float32_matmul_precision("high" if tf32 else "highest")

path_parts = var_ckpt.replace(".pth", "").replace(".pt", "").split("/")
ckpt_string_name = f"{path_parts[-1]}"

folder_name = f"d{MODEL_DEPTH}-{ckpt_string_name}-cfg-{args.cfg}-seed-{seed}"
sample_folder_dir = f"{args.sample_dir}/{folder_name}"
os.makedirs(sample_folder_dir, exist_ok=True)

# Split the 1000 classes across ranks; each rank samples only its own classes.
total_classes = 1000
rank_classes = np.array_split(np.arange(total_classes), dist.get_world_size())[dist.get_rank()]

# Sample 50 images per class in batches of B.
B = 25
for img_cls in tqdm(rank_classes, disable=(dist.get_rank() != 0)):
    for i in range(50 // B):
        label_B = torch.tensor([img_cls] * B, device=device)
        with torch.inference_mode():
            with torch.autocast("cuda", enabled=True, dtype=torch.float16, cache_enabled=True):  # bfloat16 can be faster on some GPUs
                recon_B3HW = var.autoregressive_infer_cfg(
                    B=B,
                    label_B=label_B,
                    cfg=cfg,
                    top_k=900 if args.trick else 0,
                    top_p=0.96 if args.trick else 0,
                    more_smooth=more_smooth,
                    g_seed=int(seed + img_cls * (50 // B) + i),
                )
            # (B, 3, H, W) in [0, 1] -> (B, H, W, 3) uint8 in [0, 255].
            bhwc = recon_B3HW.permute(0, 2, 3, 1).mul_(255).cpu().numpy()
        bhwc = bhwc.astype(np.uint8)
        for j in range(B):
            img = PImage.fromarray(bhwc[j])
            img.save(f"{sample_folder_dir}/{(img_cls * 50 + i * B + j):06d}.png")


def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Build a single .npz file from a folder of .png samples.
    """
    samples, labels = [], []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
        sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)
        labels.append(i // 50)  # 50 consecutive images per class
    samples = np.stack(samples)
    labels = np.asarray(labels)
    p = np.random.permutation(num)
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    # The OpenAI evaluator reads arr_0 only; labels are kept here for reference.
    # np.savez(npz_path, samples=samples[p], label=labels[p])
    np.savez(npz_path, arr_0=samples[p])
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path


torch.distributed.barrier()
if dist.get_rank() == 0:
    create_npz_from_sample_folder(sample_folder_dir)

--------------------------------------------------------------------------------
/assets/grid_ddo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DDO/19fbf23d5131ee4724abf1fed5f5caf62adbf551/assets/grid_ddo.jpg
--------------------------------------------------------------------------------