├── .gitignore ├── README.md ├── diffusion_tf ├── __init__.py ├── diffusion_utils.py ├── diffusion_utils_2.py ├── models │ ├── __init__.py │ └── unet.py ├── nn.py ├── tpu_utils │ ├── __init__.py │ ├── classifier_metrics_numpy.py │ ├── datasets.py │ ├── simple_eval_worker.py │ ├── tpu_summaries.py │ └── tpu_utils.py └── utils.py ├── requirements.txt ├── resources └── samples.png └── scripts ├── __init__.py ├── run_celebahq.py ├── run_cifar.py └── run_lsun.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .idea 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Denoising Diffusion Probabilistic Models 2 | 3 | Jonathan Ho, Ajay Jain, Pieter Abbeel 4 | 5 | Paper: https://arxiv.org/abs/2006.11239 6 | 7 | Website: https://hojonathanho.github.io/diffusion 8 | 9 | ![Samples generated by our model](resources/samples.png) 10 | 11 | Experiments run on Google Cloud TPU v3-8. 12 | Requires TensorFlow 1.15 and Python 3.5, and these dependencies for CPU instances (see `requirements.txt`): 13 | ``` 14 | pip3 install fire 15 | pip3 install scipy 16 | pip3 install pillow 17 | pip3 install tensorflow-probability==0.8 18 | pip3 install tensorflow-gan==0.0.0.dev0 19 | pip3 install tensorflow-datasets==2.1.0 20 | ``` 21 | 22 | The training and evaluation scripts are in the `scripts/` subdirectory. 23 | The commands to run training and evaluation are in comments at the top of the scripts. 24 | Data is stored in GCS buckets. The scripts are written to assume that the bucket names are of the form `gs://mybucketprefix-us-central1`; i.e. some prefix followed by the region. 25 | The prefix should be passed into the scripts using the `--bucket_name_prefix` flag. 
26 | 27 | Models and samples can be found at: https://www.dropbox.com/sh/pm6tn31da21yrx4/AABWKZnBzIROmDjGxpB6vn6Ja 28 | 29 | ## Citation 30 | If you find our work relevant to your research, please cite: 31 | ``` 32 | @article{ho2020denoising, 33 | title={Denoising Diffusion Probabilistic Models}, 34 | author={Jonathan Ho and Ajay Jain and Pieter Abbeel}, 35 | year={2020}, 36 | journal={arXiv preprint arxiv:2006.11239} 37 | } 38 | ``` 39 | -------------------------------------------------------------------------------- /diffusion_tf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hojonathanho/diffusion/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/__init__.py -------------------------------------------------------------------------------- /diffusion_tf/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow.compat.v1 as tf 3 | 4 | from . import nn 5 | 6 | 7 | def normal_kl(mean1, logvar1, mean2, logvar2): 8 | """ 9 | KL divergence between normal distributions parameterized by mean and log-variance. 
10 | """ 11 | return 0.5 * (-1.0 + logvar2 - logvar1 + tf.exp(logvar1 - logvar2) 12 | + tf.squared_difference(mean1, mean2) * tf.exp(-logvar2)) 13 | 14 | 15 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): 16 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 17 | warmup_time = int(num_diffusion_timesteps * warmup_frac) 18 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) 19 | return betas 20 | 21 | 22 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 23 | if beta_schedule == 'quad': 24 | betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64) ** 2 25 | elif beta_schedule == 'linear': 26 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 27 | elif beta_schedule == 'warmup10': 28 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) 29 | elif beta_schedule == 'warmup50': 30 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) 31 | elif beta_schedule == 'const': 32 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 33 | elif beta_schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1 34 | betas = 1. / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64) 35 | else: 36 | raise NotImplementedError(beta_schedule) 37 | assert betas.shape == (num_diffusion_timesteps,) 38 | return betas 39 | 40 | 41 | def noise_like(shape, noise_fn=tf.random_normal, repeat=False, dtype=tf.float32): 42 | repeat_noise = lambda: tf.repeat(noise_fn(shape=(1, *shape[1:]), dtype=dtype), repeats=shape[0], axis=0) 43 | noise = lambda: noise_fn(shape=shape, dtype=dtype) 44 | return repeat_noise() if repeat else noise() 45 | 46 | 47 | class GaussianDiffusion: 48 | """ 49 | Contains utilities for the diffusion model. 
50 | """ 51 | 52 | def __init__(self, *, betas, loss_type, tf_dtype=tf.float32): 53 | self.loss_type = loss_type 54 | 55 | assert isinstance(betas, np.ndarray) 56 | self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy 57 | assert (betas > 0).all() and (betas <= 1).all() 58 | timesteps, = betas.shape 59 | self.num_timesteps = int(timesteps) 60 | 61 | alphas = 1. - betas 62 | alphas_cumprod = np.cumprod(alphas, axis=0) 63 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 64 | assert alphas_cumprod_prev.shape == (timesteps,) 65 | 66 | self.betas = tf.constant(betas, dtype=tf_dtype) 67 | self.alphas_cumprod = tf.constant(alphas_cumprod, dtype=tf_dtype) 68 | self.alphas_cumprod_prev = tf.constant(alphas_cumprod_prev, dtype=tf_dtype) 69 | 70 | # calculations for diffusion q(x_t | x_{t-1}) and others 71 | self.sqrt_alphas_cumprod = tf.constant(np.sqrt(alphas_cumprod), dtype=tf_dtype) 72 | self.sqrt_one_minus_alphas_cumprod = tf.constant(np.sqrt(1. - alphas_cumprod), dtype=tf_dtype) 73 | self.log_one_minus_alphas_cumprod = tf.constant(np.log(1. - alphas_cumprod), dtype=tf_dtype) 74 | self.sqrt_recip_alphas_cumprod = tf.constant(np.sqrt(1. / alphas_cumprod), dtype=tf_dtype) 75 | self.sqrt_recipm1_alphas_cumprod = tf.constant(np.sqrt(1. / alphas_cumprod - 1), dtype=tf_dtype) 76 | 77 | # calculations for posterior q(x_{t-1} | x_t, x_0) 78 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 79 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 80 | self.posterior_variance = tf.constant(posterior_variance, dtype=tf_dtype) 81 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 82 | self.posterior_log_variance_clipped = tf.constant(np.log(np.maximum(posterior_variance, 1e-20)), dtype=tf_dtype) 83 | self.posterior_mean_coef1 = tf.constant( 84 | betas * np.sqrt(alphas_cumprod_prev) / (1. 
- alphas_cumprod), dtype=tf_dtype) 85 | self.posterior_mean_coef2 = tf.constant( 86 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod), dtype=tf_dtype) 87 | 88 | @staticmethod 89 | def _extract(a, t, x_shape): 90 | """ 91 | Extract some coefficients at specified timesteps, 92 | then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. 93 | """ 94 | bs, = t.shape 95 | assert x_shape[0] == bs 96 | out = tf.gather(a, t) 97 | assert out.shape == [bs] 98 | return tf.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) 99 | 100 | def q_mean_variance(self, x_start, t): 101 | mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 102 | variance = self._extract(1. - self.alphas_cumprod, t, x_start.shape) 103 | log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) 104 | return mean, variance, log_variance 105 | 106 | def q_sample(self, x_start, t, noise=None): 107 | """ 108 | Diffuse the data (t == 0 means diffused for 1 step) 109 | """ 110 | if noise is None: 111 | noise = tf.random_normal(shape=x_start.shape) 112 | assert noise.shape == x_start.shape 113 | return ( 114 | self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 115 | self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 116 | ) 117 | 118 | def predict_start_from_noise(self, x_t, t, noise): 119 | assert x_t.shape == noise.shape 120 | return ( 121 | self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 122 | self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 123 | ) 124 | 125 | def q_posterior(self, x_start, x_t, t): 126 | """ 127 | Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) 128 | """ 129 | assert x_start.shape == x_t.shape 130 | posterior_mean = ( 131 | self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 132 | self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 133 | ) 134 | posterior_variance = 
self._extract(self.posterior_variance, t, x_t.shape) 135 | posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape) 136 | assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == 137 | x_start.shape[0]) 138 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 139 | 140 | def p_losses(self, denoise_fn, x_start, t, noise=None): 141 | """ 142 | Training loss calculation 143 | """ 144 | B, H, W, C = x_start.shape.as_list() 145 | assert t.shape == [B] 146 | 147 | if noise is None: 148 | noise = tf.random_normal(shape=x_start.shape, dtype=x_start.dtype) 149 | assert noise.shape == x_start.shape and noise.dtype == x_start.dtype 150 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) 151 | x_recon = denoise_fn(x_noisy, t) 152 | assert x_noisy.shape == x_start.shape 153 | assert x_recon.shape[:3] == [B, H, W] and len(x_recon.shape) == 4 154 | 155 | if self.loss_type == 'noisepred': 156 | # predict the noise instead of x_start. seems to be weighted naturally like SNR 157 | assert x_recon.shape == x_start.shape 158 | losses = nn.meanflat(tf.squared_difference(noise, x_recon)) 159 | else: 160 | raise NotImplementedError(self.loss_type) 161 | 162 | assert losses.shape == [B] 163 | return losses 164 | 165 | def p_mean_variance(self, denoise_fn, *, x, t, clip_denoised: bool): 166 | if self.loss_type == 'noisepred': 167 | x_recon = self.predict_start_from_noise(x, t=t, noise=denoise_fn(x, t)) 168 | else: 169 | raise NotImplementedError(self.loss_type) 170 | 171 | if clip_denoised: 172 | x_recon = tf.clip_by_value(x_recon, -1., 1.) 
173 | 174 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) 175 | assert model_mean.shape == x_recon.shape == x.shape 176 | assert posterior_variance.shape == posterior_log_variance.shape == [x.shape[0], 1, 1, 1] 177 | return model_mean, posterior_variance, posterior_log_variance 178 | 179 | def p_sample(self, denoise_fn, *, x, t, noise_fn, clip_denoised=True, repeat_noise=False): 180 | """ 181 | Sample from the model 182 | """ 183 | model_mean, _, model_log_variance = self.p_mean_variance(denoise_fn, x=x, t=t, clip_denoised=clip_denoised) 184 | noise = noise_like(x.shape, noise_fn, repeat_noise) 185 | assert noise.shape == x.shape 186 | # no noise when t == 0 187 | nonzero_mask = tf.reshape(1 - tf.cast(tf.equal(t, 0), tf.float32), [x.shape[0]] + [1] * (len(x.shape) - 1)) 188 | return model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise 189 | 190 | def p_sample_loop(self, denoise_fn, *, shape, noise_fn=tf.random_normal): 191 | """ 192 | Generate samples 193 | """ 194 | i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32) 195 | assert isinstance(shape, (tuple, list)) 196 | img_0 = noise_fn(shape=shape, dtype=tf.float32) 197 | _, img_final = tf.while_loop( 198 | cond=lambda i_, _: tf.greater_equal(i_, 0), 199 | body=lambda i_, img_: [ 200 | i_ - 1, 201 | self.p_sample(denoise_fn=denoise_fn, x=img_, t=tf.fill([shape[0]], i_), noise_fn=noise_fn) 202 | ], 203 | loop_vars=[i_0, img_0], 204 | shape_invariants=[i_0.shape, img_0.shape], 205 | back_prop=False 206 | ) 207 | assert img_final.shape == shape 208 | return img_final 209 | 210 | def p_sample_loop_trajectory(self, denoise_fn, *, shape, noise_fn=tf.random_normal, repeat_noise_steps=-1): 211 | """ 212 | Generate samples, returning intermediate images 213 | Useful for visualizing how denoised images evolve over time 214 | Args: 215 | repeat_noise_steps (int): Number of denoising timesteps in which the same noise 216 | is used across the batch. 
If >= 0, the initial noise is the same for all batch elemements. 217 | """ 218 | i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32) 219 | assert isinstance(shape, (tuple, list)) 220 | img_0 = noise_like(shape, noise_fn, repeat_noise_steps >= 0) 221 | times = tf.Variable([i_0]) 222 | imgs = tf.Variable([img_0]) 223 | # Steps with repeated noise 224 | times, imgs = tf.while_loop( 225 | cond=lambda times_, _: tf.less_equal(self.num_timesteps - times_[-1], repeat_noise_steps), 226 | body=lambda times_, imgs_: [ 227 | tf.concat([times_, [times_[-1] - 1]], 0), 228 | tf.concat([imgs_, [self.p_sample(denoise_fn=denoise_fn, 229 | x=imgs_[-1], 230 | t=tf.fill([shape[0]], times_[-1]), 231 | noise_fn=noise_fn, 232 | repeat_noise=True)]], 0) 233 | ], 234 | loop_vars=[times, imgs], 235 | shape_invariants=[tf.TensorShape([None, *i_0.shape]), 236 | tf.TensorShape([None, *img_0.shape])], 237 | back_prop=False 238 | ) 239 | # Steps with different noise for each batch element 240 | times, imgs = tf.while_loop( 241 | cond=lambda times_, _: tf.greater_equal(times_[-1], 0), 242 | body=lambda times_, imgs_: [ 243 | tf.concat([times_, [times_[-1] - 1]], 0), 244 | tf.concat([imgs_, [self.p_sample(denoise_fn=denoise_fn, 245 | x=imgs_[-1], 246 | t=tf.fill([shape[0]], times_[-1]), 247 | noise_fn=noise_fn, 248 | repeat_noise=False)]], 0) 249 | ], 250 | loop_vars=[times, imgs], 251 | shape_invariants=[tf.TensorShape([None, *i_0.shape]), 252 | tf.TensorShape([None, *img_0.shape])], 253 | back_prop=False 254 | ) 255 | assert imgs[-1].shape == shape 256 | return times, imgs 257 | 258 | def interpolate(self, denoise_fn, *, shape, noise_fn=tf.random_normal): 259 | """ 260 | Interpolate between images. 261 | t == 0 means diffuse images for 1 timestep before mixing. 
262 | """ 263 | assert isinstance(shape, (tuple, list)) 264 | 265 | # Placeholders for real samples to interpolate 266 | x1 = tf.placeholder(tf.float32, shape) 267 | x2 = tf.placeholder(tf.float32, shape) 268 | # lam == 0.5 averages diffused images. 269 | lam = tf.placeholder(tf.float32, shape=()) 270 | t = tf.placeholder(tf.int32, shape=()) 271 | 272 | # Add noise via forward diffusion 273 | # TODO: use the same noise for both endpoints? 274 | # t_batched = tf.constant([t] * x1.shape[0], dtype=tf.int32) 275 | t_batched = tf.stack([t] * x1.shape[0]) 276 | xt1 = self.q_sample(x1, t=t_batched) 277 | xt2 = self.q_sample(x2, t=t_batched) 278 | 279 | # Mix latents 280 | # Linear interpolation 281 | xt_interp = (1 - lam) * xt1 + lam * xt2 282 | # Constant variance interpolation 283 | # xt_interp = tf.sqrt(1 - lam * lam) * xt1 + lam * xt2 284 | 285 | # Reverse diffusion (similar to self.p_sample_loop) 286 | # t = tf.constant(t, dtype=tf.int32) 287 | _, x_interp = tf.while_loop( 288 | cond=lambda i_, _: tf.greater_equal(i_, 0), 289 | body=lambda i_, img_: [ 290 | i_ - 1, 291 | self.p_sample(denoise_fn=denoise_fn, x=img_, t=tf.fill([shape[0]], i_), noise_fn=noise_fn) 292 | ], 293 | loop_vars=[t, xt_interp], 294 | shape_invariants=[t.shape, xt_interp.shape], 295 | back_prop=False 296 | ) 297 | assert x_interp.shape == shape 298 | 299 | return x1, x2, lam, x_interp, t 300 | -------------------------------------------------------------------------------- /diffusion_tf/diffusion_utils_2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow.compat.v1 as tf 3 | 4 | from . import nn 5 | from . import utils 6 | 7 | 8 | def normal_kl(mean1, logvar1, mean2, logvar2): 9 | """ 10 | KL divergence between normal distributions parameterized by mean and log-variance. 
11 | """ 12 | return 0.5 * (-1.0 + logvar2 - logvar1 + tf.exp(logvar1 - logvar2) 13 | + tf.squared_difference(mean1, mean2) * tf.exp(-logvar2)) 14 | 15 | 16 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): 17 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 18 | warmup_time = int(num_diffusion_timesteps * warmup_frac) 19 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) 20 | return betas 21 | 22 | 23 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 24 | if beta_schedule == 'quad': 25 | betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64) ** 2 26 | elif beta_schedule == 'linear': 27 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 28 | elif beta_schedule == 'warmup10': 29 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) 30 | elif beta_schedule == 'warmup50': 31 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) 32 | elif beta_schedule == 'const': 33 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 34 | elif beta_schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1 35 | betas = 1. / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64) 36 | else: 37 | raise NotImplementedError(beta_schedule) 38 | assert betas.shape == (num_diffusion_timesteps,) 39 | return betas 40 | 41 | 42 | class GaussianDiffusion2: 43 | """ 44 | Contains utilities for the diffusion model. 45 | 46 | Arguments: 47 | - what the network predicts (x_{t-1}, x_0, or epsilon) 48 | - which loss function (kl or unweighted MSE) 49 | - what is the variance of p(x_{t-1}|x_t) (learned, fixed to beta, or fixed to weighted beta) 50 | - what type of decoder, and how to weight its loss? is its variance learned too? 
51 | """ 52 | 53 | def __init__(self, *, betas, model_mean_type, model_var_type, loss_type): 54 | self.model_mean_type = model_mean_type # xprev, xstart, eps 55 | self.model_var_type = model_var_type # learned, fixedsmall, fixedlarge 56 | self.loss_type = loss_type # kl, mse 57 | 58 | assert isinstance(betas, np.ndarray) 59 | self.betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy 60 | assert (betas > 0).all() and (betas <= 1).all() 61 | timesteps, = betas.shape 62 | self.num_timesteps = int(timesteps) 63 | 64 | alphas = 1. - betas 65 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 66 | self.alphas_cumprod_prev = np.append(1., self.alphas_cumprod[:-1]) 67 | assert self.alphas_cumprod_prev.shape == (timesteps,) 68 | 69 | # calculations for diffusion q(x_t | x_{t-1}) and others 70 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 71 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1. - self.alphas_cumprod) 72 | self.log_one_minus_alphas_cumprod = np.log(1. - self.alphas_cumprod) 73 | self.sqrt_recip_alphas_cumprod = np.sqrt(1. / self.alphas_cumprod) 74 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1. / self.alphas_cumprod - 1) 75 | 76 | # calculations for posterior q(x_{t-1} | x_t, x_0) 77 | self.posterior_variance = betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod) 78 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 79 | self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) 80 | self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1. - self.alphas_cumprod) 81 | self.posterior_mean_coef2 = (1. - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1. - self.alphas_cumprod) 82 | 83 | @staticmethod 84 | def _extract(a, t, x_shape): 85 | """ 86 | Extract some coefficients at specified timesteps, 87 | then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. 
88 | """ 89 | bs, = t.shape 90 | assert x_shape[0] == bs 91 | out = tf.gather(tf.convert_to_tensor(a, dtype=tf.float32), t) 92 | assert out.shape == [bs] 93 | return tf.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) 94 | 95 | def q_mean_variance(self, x_start, t): 96 | mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 97 | variance = self._extract(1. - self.alphas_cumprod, t, x_start.shape) 98 | log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) 99 | return mean, variance, log_variance 100 | 101 | def q_sample(self, x_start, t, noise=None): 102 | """ 103 | Diffuse the data (t == 0 means diffused for 1 step) 104 | """ 105 | if noise is None: 106 | noise = tf.random_normal(shape=x_start.shape) 107 | assert noise.shape == x_start.shape 108 | return ( 109 | self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 110 | self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 111 | ) 112 | 113 | def q_posterior_mean_variance(self, x_start, x_t, t): 114 | """ 115 | Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) 116 | """ 117 | assert x_start.shape == x_t.shape 118 | posterior_mean = ( 119 | self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 120 | self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 121 | ) 122 | posterior_variance = self._extract(self.posterior_variance, t, x_t.shape) 123 | posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape) 124 | assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == 125 | x_start.shape[0]) 126 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 127 | 128 | def p_mean_variance(self, denoise_fn, *, x, t, clip_denoised: bool, return_pred_xstart: bool): 129 | B, H, W, C = x.shape 130 | assert t.shape == [B] 131 | model_output = denoise_fn(x, t) 132 | 133 | # Learned or fixed 
variance? 134 | if self.model_var_type == 'learned': 135 | assert model_output.shape == [B, H, W, C * 2] 136 | model_output, model_log_variance = tf.split(model_output, 2, axis=-1) 137 | model_variance = tf.exp(model_log_variance) 138 | elif self.model_var_type in ['fixedsmall', 'fixedlarge']: 139 | # below: only log_variance is used in the KL computations 140 | model_variance, model_log_variance = { 141 | # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood 142 | 'fixedlarge': (self.betas, np.log(np.append(self.posterior_variance[1], self.betas[1:]))), 143 | 'fixedsmall': (self.posterior_variance, self.posterior_log_variance_clipped), 144 | }[self.model_var_type] 145 | model_variance = self._extract(model_variance, t, x.shape) * tf.ones(x.shape.as_list()) 146 | model_log_variance = self._extract(model_log_variance, t, x.shape) * tf.ones(x.shape.as_list()) 147 | else: 148 | raise NotImplementedError(self.model_var_type) 149 | 150 | # Mean parameterization 151 | _maybe_clip = lambda x_: (tf.clip_by_value(x_, -1., 1.) 
if clip_denoised else x_) 152 | if self.model_mean_type == 'xprev': # the model predicts x_{t-1} 153 | pred_xstart = _maybe_clip(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)) 154 | model_mean = model_output 155 | elif self.model_mean_type == 'xstart': # the model predicts x_0 156 | pred_xstart = _maybe_clip(model_output) 157 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) 158 | elif self.model_mean_type == 'eps': # the model predicts epsilon 159 | pred_xstart = _maybe_clip(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) 160 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) 161 | else: 162 | raise NotImplementedError(self.model_mean_type) 163 | 164 | assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 165 | if return_pred_xstart: 166 | return model_mean, model_variance, model_log_variance, pred_xstart 167 | else: 168 | return model_mean, model_variance, model_log_variance 169 | 170 | def _predict_xstart_from_eps(self, x_t, t, eps): 171 | assert x_t.shape == eps.shape 172 | return ( 173 | self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 174 | self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 175 | ) 176 | 177 | def _predict_xstart_from_xprev(self, x_t, t, xprev): 178 | assert x_t.shape == xprev.shape 179 | return ( # (xprev - coef2*x_t) / coef1 180 | self._extract(1. 
/ self.posterior_mean_coef1, t, x_t.shape) * xprev - 181 | self._extract(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t 182 | ) 183 | 184 | # === Sampling === 185 | 186 | def p_sample(self, denoise_fn, *, x, t, noise_fn, clip_denoised=True, return_pred_xstart: bool): 187 | """ 188 | Sample from the model 189 | """ 190 | model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( 191 | denoise_fn, x=x, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) 192 | noise = noise_fn(shape=x.shape, dtype=x.dtype) 193 | assert noise.shape == x.shape 194 | # no noise when t == 0 195 | nonzero_mask = tf.reshape(1 - tf.cast(tf.equal(t, 0), tf.float32), [x.shape[0]] + [1] * (len(x.shape) - 1)) 196 | sample = model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise 197 | assert sample.shape == pred_xstart.shape 198 | return (sample, pred_xstart) if return_pred_xstart else sample 199 | 200 | def p_sample_loop(self, denoise_fn, *, shape, noise_fn=tf.random_normal): 201 | """ 202 | Generate samples 203 | """ 204 | assert isinstance(shape, (tuple, list)) 205 | i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32) 206 | img_0 = noise_fn(shape=shape, dtype=tf.float32) 207 | _, img_final = tf.while_loop( 208 | cond=lambda i_, _: tf.greater_equal(i_, 0), 209 | body=lambda i_, img_: [ 210 | i_ - 1, 211 | self.p_sample( 212 | denoise_fn=denoise_fn, x=img_, t=tf.fill([shape[0]], i_), noise_fn=noise_fn, return_pred_xstart=False) 213 | ], 214 | loop_vars=[i_0, img_0], 215 | shape_invariants=[i_0.shape, img_0.shape], 216 | back_prop=False 217 | ) 218 | assert img_final.shape == shape 219 | return img_final 220 | 221 | def p_sample_loop_progressive(self, denoise_fn, *, shape, noise_fn=tf.random_normal, include_xstartpred_freq=50): 222 | """ 223 | Generate samples and keep track of prediction of x0 224 | """ 225 | assert isinstance(shape, (tuple, list)) 226 | i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32) 227 | img_0 = 
noise_fn(shape=shape, dtype=tf.float32) # [B, H, W, C] 228 | 229 | num_recorded_xstartpred = self.num_timesteps // include_xstartpred_freq 230 | xstartpreds_0 = tf.zeros([shape[0], num_recorded_xstartpred, *shape[1:]], dtype=tf.float32) # [B, N, H, W, C] 231 | 232 | def _loop_body(i_, img_, xstartpreds_): 233 | # Sample p(x_{t-1} | x_t) as usual 234 | sample, pred_xstart = self.p_sample( 235 | denoise_fn=denoise_fn, x=img_, t=tf.fill([shape[0]], i_), noise_fn=noise_fn, return_pred_xstart=True) 236 | assert sample.shape == pred_xstart.shape == shape 237 | # Keep track of prediction of x0 238 | insert_mask = tf.equal(tf.floordiv(i_, include_xstartpred_freq), 239 | tf.range(num_recorded_xstartpred, dtype=tf.int32)) 240 | insert_mask = tf.reshape(tf.cast(insert_mask, dtype=tf.float32), 241 | [1, num_recorded_xstartpred, *([1] * len(shape[1:]))]) # [1, N, 1, 1, 1] 242 | new_xstartpreds = insert_mask * pred_xstart[:, None, ...] + (1. - insert_mask) * xstartpreds_ 243 | return [i_ - 1, sample, new_xstartpreds] 244 | 245 | _, img_final, xstartpreds_final = tf.while_loop( 246 | cond=lambda i_, img_, xstartpreds_: tf.greater_equal(i_, 0), 247 | body=_loop_body, 248 | loop_vars=[i_0, img_0, xstartpreds_0], 249 | shape_invariants=[i_0.shape, img_0.shape, xstartpreds_0.shape], 250 | back_prop=False 251 | ) 252 | assert img_final.shape == shape and xstartpreds_final.shape == xstartpreds_0.shape 253 | return img_final, xstartpreds_final # xstart predictions should agree with img_final at step 0 254 | 255 | # === Log likelihood calculation === 256 | 257 | def _vb_terms_bpd(self, denoise_fn, x_start, x_t, t, *, clip_denoised: bool, return_pred_xstart: bool): 258 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) 259 | model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( 260 | denoise_fn, x=x_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) 261 | kl = normal_kl(true_mean, true_log_variance_clipped, 
model_mean, model_log_variance) 262 | kl = nn.meanflat(kl) / np.log(2.) 263 | 264 | decoder_nll = -utils.discretized_gaussian_log_likelihood( 265 | x_start, means=model_mean, log_scales=0.5 * model_log_variance) 266 | assert decoder_nll.shape == x_start.shape 267 | decoder_nll = nn.meanflat(decoder_nll) / np.log(2.) 268 | 269 | # At the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 270 | assert kl.shape == decoder_nll.shape == t.shape == [x_start.shape[0]] 271 | output = tf.where(tf.equal(t, 0), decoder_nll, kl) 272 | return (output, pred_xstart) if return_pred_xstart else output 273 | 274 | def training_losses(self, denoise_fn, x_start, t, noise=None): 275 | """ 276 | Training loss calculation 277 | """ 278 | 279 | # Add noise to data 280 | assert t.shape == [x_start.shape[0]] 281 | if noise is None: 282 | noise = tf.random_normal(shape=x_start.shape, dtype=x_start.dtype) 283 | assert noise.shape == x_start.shape and noise.dtype == x_start.dtype 284 | x_t = self.q_sample(x_start=x_start, t=t, noise=noise) 285 | 286 | # Calculate the loss 287 | if self.loss_type == 'kl': # the variational bound 288 | losses = self._vb_terms_bpd( 289 | denoise_fn=denoise_fn, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, return_pred_xstart=False) 290 | elif self.loss_type == 'mse': # unweighted MSE 291 | assert self.model_var_type != 'learned' 292 | target = { 293 | 'xprev': self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0], 294 | 'xstart': x_start, 295 | 'eps': noise 296 | }[self.model_mean_type] 297 | model_output = denoise_fn(x_t, t) 298 | assert model_output.shape == target.shape == x_start.shape 299 | losses = nn.meanflat(tf.squared_difference(target, model_output)) 300 | else: 301 | raise NotImplementedError(self.loss_type) 302 | 303 | assert losses.shape == t.shape 304 | return losses 305 | 306 | def _prior_bpd(self, x_start): 307 | B, T = x_start.shape[0], self.num_timesteps 308 | qt_mean, _, 
qt_log_variance = self.q_mean_variance(x_start, t=tf.fill([B], tf.constant(T - 1, dtype=tf.int32))) 309 | kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0., logvar2=0.) 310 | assert kl_prior.shape == x_start.shape 311 | return nn.meanflat(kl_prior) / np.log(2.) 312 | 313 | def calc_bpd_loop(self, denoise_fn, x_start, *, clip_denoised=True): 314 | (B, H, W, C), T = x_start.shape, self.num_timesteps 315 | 316 | def _loop_body(t_, cur_vals_bt_, cur_mse_bt_): 317 | assert t_.shape == [] 318 | t_b = tf.fill([B], t_) 319 | # Calculate VLB term at the current timestep 320 | new_vals_b, pred_xstart = self._vb_terms_bpd( 321 | denoise_fn, x_start=x_start, x_t=self.q_sample(x_start=x_start, t=t_b), t=t_b, 322 | clip_denoised=clip_denoised, return_pred_xstart=True) 323 | # MSE for progressive prediction loss 324 | assert pred_xstart.shape == x_start.shape 325 | new_mse_b = nn.meanflat(tf.squared_difference(pred_xstart, x_start)) 326 | assert new_vals_b.shape == new_mse_b.shape == [B] 327 | # Insert the calculated term into the tensor of all terms 328 | mask_bt = tf.cast(tf.equal(t_b[:, None], tf.range(T)[None, :]), dtype=tf.float32) 329 | new_vals_bt = cur_vals_bt_ * (1. - mask_bt) + new_vals_b[:, None] * mask_bt 330 | new_mse_bt = cur_mse_bt_ * (1. 
- mask_bt) + new_mse_b[:, None] * mask_bt 331 | assert mask_bt.shape == cur_vals_bt_.shape == new_vals_bt.shape == [B, T] 332 | return t_ - 1, new_vals_bt, new_mse_bt 333 | 334 | t_0 = tf.constant(T - 1, dtype=tf.int32) 335 | terms_0 = tf.zeros([B, T]) 336 | mse_0 = tf.zeros([B, T]) 337 | _, terms_bpd_bt, mse_bt = tf.while_loop( # Note that this can be implemented with tf.map_fn instead 338 | cond=lambda t_, cur_vals_bt_, cur_mse_bt_: tf.greater_equal(t_, 0), 339 | body=_loop_body, 340 | loop_vars=[t_0, terms_0, mse_0], 341 | shape_invariants=[t_0.shape, terms_0.shape, mse_0.shape], 342 | back_prop=False 343 | ) 344 | prior_bpd_b = self._prior_bpd(x_start) 345 | total_bpd_b = tf.reduce_sum(terms_bpd_bt, axis=1) + prior_bpd_b 346 | assert terms_bpd_bt.shape == mse_bt.shape == [B, T] and total_bpd_b.shape == prior_bpd_b.shape == [B] 347 | return total_bpd_b, terms_bpd_bt, prior_bpd_b, mse_bt 348 | -------------------------------------------------------------------------------- /diffusion_tf/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hojonathanho/diffusion/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/__init__.py -------------------------------------------------------------------------------- /diffusion_tf/models/unet.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | import tensorflow.contrib as tf_contrib 3 | 4 | from .. 
import nn 5 | 6 | 7 | def nonlinearity(x): 8 | return tf.nn.swish(x) 9 | 10 | 11 | def normalize(x, *, temb, name): 12 | return tf_contrib.layers.group_norm(x, scope=name) 13 | 14 | 15 | def upsample(x, *, name, with_conv): 16 | with tf.variable_scope(name): 17 | B, H, W, C = x.shape 18 | x = tf.image.resize(x, size=[H * 2, W * 2], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, align_corners=True) 19 | assert x.shape == [B, H * 2, W * 2, C] 20 | if with_conv: 21 | x = nn.conv2d(x, name='conv', num_units=C, filter_size=3, stride=1) 22 | assert x.shape == [B, H * 2, W * 2, C] 23 | return x 24 | 25 | 26 | def downsample(x, *, name, with_conv): 27 | with tf.variable_scope(name): 28 | B, H, W, C = x.shape 29 | if with_conv: 30 | x = nn.conv2d(x, name='conv', num_units=C, filter_size=3, stride=2) 31 | else: 32 | x = tf.nn.avg_pool(x, 2, 2, 'SAME') 33 | assert x.shape == [B, H // 2, W // 2, C] 34 | return x 35 | 36 | 37 | def resnet_block(x, *, temb, name, out_ch=None, conv_shortcut=False, dropout): 38 | B, H, W, C = x.shape 39 | if out_ch is None: 40 | out_ch = C 41 | 42 | with tf.variable_scope(name): 43 | h = x 44 | 45 | h = nonlinearity(normalize(h, temb=temb, name='norm1')) 46 | h = nn.conv2d(h, name='conv1', num_units=out_ch) 47 | 48 | # add in timestep embedding 49 | h += nn.dense(nonlinearity(temb), name='temb_proj', num_units=out_ch)[:, None, None, :] 50 | 51 | h = nonlinearity(normalize(h, temb=temb, name='norm2')) 52 | h = tf.nn.dropout(h, rate=dropout) 53 | h = nn.conv2d(h, name='conv2', num_units=out_ch, init_scale=0.) 
54 | 55 | if C != out_ch: 56 | if conv_shortcut: 57 | x = nn.conv2d(x, name='conv_shortcut', num_units=out_ch) 58 | else: 59 | x = nn.nin(x, name='nin_shortcut', num_units=out_ch) 60 | 61 | assert x.shape == h.shape 62 | print('{}: x={} temb={}'.format(tf.get_default_graph().get_name_scope(), x.shape, temb.shape)) 63 | return x + h 64 | 65 | 66 | def attn_block(x, *, name, temb): 67 | B, H, W, C = x.shape 68 | with tf.variable_scope(name): 69 | h = normalize(x, temb=temb, name='norm') 70 | q = nn.nin(h, name='q', num_units=C) 71 | k = nn.nin(h, name='k', num_units=C) 72 | v = nn.nin(h, name='v', num_units=C) 73 | 74 | w = tf.einsum('bhwc,bHWc->bhwHW', q, k) * (int(C) ** (-0.5)) 75 | w = tf.reshape(w, [B, H, W, H * W]) 76 | w = tf.nn.softmax(w, -1) 77 | w = tf.reshape(w, [B, H, W, H, W]) 78 | 79 | h = tf.einsum('bhwHW,bHWc->bhwc', w, v) 80 | h = nn.nin(h, name='proj_out', num_units=C, init_scale=0.) 81 | 82 | assert h.shape == x.shape 83 | print(tf.get_default_graph().get_name_scope(), x.shape) 84 | return x + h 85 | 86 | 87 | def model(x, *, t, y, name, num_classes, reuse=tf.AUTO_REUSE, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, 88 | attn_resolutions, dropout=0., resamp_with_conv=True): 89 | B, S, _, _ = x.shape 90 | assert x.dtype == tf.float32 and x.shape[2] == S 91 | assert t.dtype in [tf.int32, tf.int64] 92 | num_resolutions = len(ch_mult) 93 | 94 | assert num_classes == 1 and y is None, 'not supported' 95 | del y 96 | 97 | with tf.variable_scope(name, reuse=reuse): 98 | # Timestep embedding 99 | with tf.variable_scope('temb'): 100 | temb = nn.get_timestep_embedding(t, ch) 101 | temb = nn.dense(temb, name='dense0', num_units=ch * 4) 102 | temb = nn.dense(nonlinearity(temb), name='dense1', num_units=ch * 4) 103 | assert temb.shape == [B, ch * 4] 104 | 105 | # Downsampling 106 | hs = [nn.conv2d(x, name='conv_in', num_units=ch)] 107 | for i_level in range(num_resolutions): 108 | with tf.variable_scope('down_{}'.format(i_level)): 109 | # Residual blocks for 
this resolution 110 | for i_block in range(num_res_blocks): 111 | h = resnet_block( 112 | hs[-1], name='block_{}'.format(i_block), temb=temb, out_ch=ch * ch_mult[i_level], dropout=dropout) 113 | if h.shape[1] in attn_resolutions: 114 | h = attn_block(h, name='attn_{}'.format(i_block), temb=temb) 115 | hs.append(h) 116 | # Downsample 117 | if i_level != num_resolutions - 1: 118 | hs.append(downsample(hs[-1], name='downsample', with_conv=resamp_with_conv)) 119 | 120 | # Middle 121 | with tf.variable_scope('mid'): 122 | h = hs[-1] 123 | h = resnet_block(h, temb=temb, name='block_1', dropout=dropout) 124 | h = attn_block(h, name='attn_1'.format(i_block), temb=temb) 125 | h = resnet_block(h, temb=temb, name='block_2', dropout=dropout) 126 | 127 | # Upsampling 128 | for i_level in reversed(range(num_resolutions)): 129 | with tf.variable_scope('up_{}'.format(i_level)): 130 | # Residual blocks for this resolution 131 | for i_block in range(num_res_blocks + 1): 132 | h = resnet_block(tf.concat([h, hs.pop()], axis=-1), name='block_{}'.format(i_block), 133 | temb=temb, out_ch=ch * ch_mult[i_level], dropout=dropout) 134 | if h.shape[1] in attn_resolutions: 135 | h = attn_block(h, name='attn_{}'.format(i_block), temb=temb) 136 | # Upsample 137 | if i_level != 0: 138 | h = upsample(h, name='upsample', with_conv=resamp_with_conv) 139 | assert not hs 140 | 141 | # End 142 | h = nonlinearity(normalize(h, temb=temb, name='norm_out')) 143 | h = nn.conv2d(h, name='conv_out', num_units=out_ch, init_scale=0.) 
144 | assert h.shape == x.shape[:3] + [out_ch] 145 | return h 146 | -------------------------------------------------------------------------------- /diffusion_tf/nn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import string 3 | 4 | import tensorflow.compat.v1 as tf 5 | 6 | # ===== Neural network building defaults ===== 7 | DEFAULT_DTYPE = tf.float32 8 | 9 | 10 | def default_init(scale): 11 | return tf.initializers.variance_scaling(scale=1e-10 if scale == 0 else scale, mode='fan_avg', distribution='uniform') 12 | 13 | 14 | # ===== Utilities ===== 15 | 16 | def _wrapped_print(x, *args, **kwargs): 17 | print_op = tf.print(*args, **kwargs) 18 | with tf.control_dependencies([print_op]): 19 | return tf.identity(x) 20 | 21 | 22 | def debug_print(x, name): 23 | return _wrapped_print(x, name, tf.reduce_mean(x), tf.math.reduce_std(x), tf.reduce_min(x), tf.reduce_max(x)) 24 | 25 | 26 | def flatten(x): 27 | return tf.reshape(x, [int(x.shape[0]), -1]) 28 | 29 | 30 | def sumflat(x): 31 | return tf.reduce_sum(x, axis=list(range(1, len(x.shape)))) 32 | 33 | 34 | def meanflat(x): 35 | return tf.reduce_mean(x, axis=list(range(1, len(x.shape)))) 36 | 37 | 38 | # ===== Neural network layers ===== 39 | 40 | def _einsum(a, b, c, x, y): 41 | einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c)) 42 | return tf.einsum(einsum_str, x, y) 43 | 44 | 45 | def contract_inner(x, y): 46 | """tensordot(x, y, 1).""" 47 | x_chars = list(string.ascii_lowercase[:len(x.shape)]) 48 | y_chars = list(string.ascii_uppercase[:len(y.shape)]) 49 | assert len(x_chars) == len(x.shape) and len(y_chars) == len(y.shape) 50 | y_chars[0] = x_chars[-1] # first axis of y and last of x get summed 51 | out_chars = x_chars[:-1] + y_chars[1:] 52 | return _einsum(x_chars, y_chars, out_chars, x, y) 53 | 54 | 55 | def nin(x, *, name, num_units, init_scale=1.): 56 | with tf.variable_scope(name): 57 | in_dim = int(x.shape[-1]) 58 | W = 
tf.get_variable('W', shape=[in_dim, num_units], initializer=default_init(scale=init_scale), dtype=DEFAULT_DTYPE) 59 | b = tf.get_variable('b', shape=[num_units], initializer=tf.constant_initializer(0.), dtype=DEFAULT_DTYPE) 60 | y = contract_inner(x, W) + b 61 | assert y.shape == x.shape[:-1] + [num_units] 62 | return y 63 | 64 | 65 | def dense(x, *, name, num_units, init_scale=1., bias=True): 66 | with tf.variable_scope(name): 67 | _, in_dim = x.shape 68 | W = tf.get_variable('W', shape=[in_dim, num_units], initializer=default_init(scale=init_scale), dtype=DEFAULT_DTYPE) 69 | z = tf.matmul(x, W) 70 | if not bias: 71 | return z 72 | b = tf.get_variable('b', shape=[num_units], initializer=tf.constant_initializer(0.), dtype=DEFAULT_DTYPE) 73 | return z + b 74 | 75 | 76 | def conv2d(x, *, name, num_units, filter_size=(3, 3), stride=1, dilation=None, pad='SAME', init_scale=1., bias=True): 77 | with tf.variable_scope(name): 78 | assert x.shape.ndims == 4 79 | if isinstance(filter_size, int): 80 | filter_size = (filter_size, filter_size) 81 | W = tf.get_variable('W', shape=[*filter_size, int(x.shape[-1]), num_units], 82 | initializer=default_init(scale=init_scale), dtype=DEFAULT_DTYPE) 83 | z = tf.nn.conv2d(x, W, strides=stride, padding=pad, dilations=dilation) 84 | if not bias: 85 | return z 86 | b = tf.get_variable('b', shape=[num_units], initializer=tf.constant_initializer(0.), dtype=DEFAULT_DTYPE) 87 | return z + b 88 | 89 | 90 | def get_timestep_embedding(timesteps, embedding_dim: int): 91 | """ 92 | From Fairseq. 93 | Build sinusoidal embeddings. 94 | This matches the implementation in tensor2tensor, but differs slightly 95 | from the description in Section 3.5 of "Attention Is All You Need". 
96 | """ 97 | assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 98 | 99 | half_dim = embedding_dim // 2 100 | emb = math.log(10000) / (half_dim - 1) 101 | emb = tf.exp(tf.range(half_dim, dtype=DEFAULT_DTYPE) * -emb) 102 | # emb = tf.range(num_embeddings, dtype=DEFAULT_DTYPE)[:, None] * emb[None, :] 103 | emb = tf.cast(timesteps, dtype=DEFAULT_DTYPE)[:, None] * emb[None, :] 104 | emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=1) 105 | if embedding_dim % 2 == 1: # zero pad 106 | # emb = tf.concat([emb, tf.zeros([num_embeddings, 1])], axis=1) 107 | emb = tf.pad(emb, [[0, 0], [0, 1]]) 108 | assert emb.shape == [timesteps.shape[0], embedding_dim] 109 | return emb 110 | -------------------------------------------------------------------------------- /diffusion_tf/tpu_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hojonathanho/diffusion/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/tpu_utils/__init__.py -------------------------------------------------------------------------------- /diffusion_tf/tpu_utils/classifier_metrics_numpy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Direct NumPy port of tfgan.eval.classifier_metrics 3 | """ 4 | 5 | import numpy as np 6 | import scipy.special 7 | 8 | 9 | def log_softmax(x, axis): 10 | return x - scipy.special.logsumexp(x, axis=axis, keepdims=True) 11 | 12 | 13 | def kl_divergence(p, p_logits, q): 14 | assert len(p.shape) == len(p_logits.shape) == 2 15 | assert len(q.shape) == 1 16 | return np.sum(p * (log_softmax(p_logits, axis=1) - np.log(q)[None, :]), axis=1) 17 | 18 | 19 | def _symmetric_matrix_square_root(mat, eps=1e-10): 20 | """Compute square root of a symmetric matrix. 21 | 22 | Note that this is different from an elementwise square root. We want to 23 | compute M' where M' = sqrt(mat) such that M' * M' = mat. 
24 | 25 | Also note that this method **only** works for symmetric matrices. 26 | 27 | Args: 28 | mat: Matrix to take the square root of. 29 | eps: Small epsilon such that any element less than eps will not be square 30 | rooted to guard against numerical instability. 31 | 32 | Returns: 33 | Matrix square root of mat. 34 | """ 35 | u, s, vt = np.linalg.svd(mat) 36 | # sqrt is unstable around 0, just use 0 in such case 37 | si = np.where(s < eps, s, np.sqrt(s)) 38 | return u.dot(np.diag(si)).dot(vt) 39 | 40 | 41 | def trace_sqrt_product(sigma, sigma_v): 42 | """Find the trace of the positive sqrt of product of covariance matrices. 43 | 44 | '_symmetric_matrix_square_root' only works for symmetric matrices, so we 45 | cannot just take _symmetric_matrix_square_root(sigma * sigma_v). 46 | ('sigma' and 'sigma_v' are symmetric, but their product is not necessarily). 47 | 48 | Let sigma = A A so A = sqrt(sigma), and sigma_v = B B. 49 | We want to find trace(sqrt(sigma sigma_v)) = trace(sqrt(A A B B)) 50 | Note the following properties: 51 | (i) forall M1, M2: eigenvalues(M1 M2) = eigenvalues(M2 M1) 52 | => eigenvalues(A A B B) = eigenvalues (A B B A) 53 | (ii) if M1 = sqrt(M2), then eigenvalues(M1) = sqrt(eigenvalues(M2)) 54 | => eigenvalues(sqrt(sigma sigma_v)) = sqrt(eigenvalues(A B B A)) 55 | (iii) forall M: trace(M) = sum(eigenvalues(M)) 56 | => trace(sqrt(sigma sigma_v)) = sum(eigenvalues(sqrt(sigma sigma_v))) 57 | = sum(sqrt(eigenvalues(A B B A))) 58 | = sum(eigenvalues(sqrt(A B B A))) 59 | = trace(sqrt(A B B A)) 60 | = trace(sqrt(A sigma_v A)) 61 | A = sqrt(sigma). Both sigma and A sigma_v A are symmetric, so we **can** 62 | use the _symmetric_matrix_square_root function to find the roots of these 63 | matrices. 
64 | 65 | Args: 66 | sigma: a square, symmetric, real, positive semi-definite covariance matrix 67 | sigma_v: same as sigma 68 | 69 | Returns: 70 | The trace of the positive square root of sigma*sigma_v 71 | """ 72 | 73 | # Note sqrt_sigma is called "A" in the proof above 74 | sqrt_sigma = _symmetric_matrix_square_root(sigma) 75 | 76 | # This is sqrt(A sigma_v A) above 77 | sqrt_a_sigmav_a = sqrt_sigma.dot(sigma_v.dot(sqrt_sigma)) 78 | 79 | return np.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) 80 | 81 | 82 | def classifier_score_from_logits(logits): 83 | """Classifier score for evaluating a generative model from logits. 84 | 85 | This method computes the classifier score for a set of logits. This can be 86 | used independently of the classifier_score() method, especially in the case 87 | of using large batches during evaluation where we would like precompute all 88 | of the logits before computing the classifier score. 89 | 90 | This technique is described in detail in https://arxiv.org/abs/1606.03498. In 91 | summary, this function calculates: 92 | 93 | exp( E[ KL(p(y|x) || p(y)) ] ) 94 | 95 | which captures how different the network's classification prediction is from 96 | the prior distribution over classes. 97 | 98 | Args: 99 | logits: Precomputed 2D tensor of logits that will be used to compute the 100 | classifier score. 101 | 102 | Returns: 103 | The classifier score. A floating-point scalar of the same type as the output 104 | of `logits`. 105 | """ 106 | assert len(logits.shape) == 2 107 | 108 | # Use maximum precision for best results. 
109 | logits_dtype = logits.dtype 110 | if logits_dtype != np.float64: 111 | logits = logits.astype(np.float64) 112 | 113 | p = scipy.special.softmax(logits, axis=1) 114 | q = np.mean(p, axis=0) 115 | kl = kl_divergence(p, logits, q) 116 | assert len(kl.shape) == 1 117 | log_score = np.mean(kl) 118 | final_score = np.exp(log_score) 119 | 120 | if logits_dtype != np.float64: 121 | final_score = final_score.astype(logits_dtype) 122 | 123 | return final_score 124 | 125 | 126 | def frechet_classifier_distance_from_activations(real_activations, 127 | generated_activations): 128 | """Classifier distance for evaluating a generative model. 129 | 130 | This methods computes the Frechet classifier distance from activations of 131 | real images and generated images. This can be used independently of the 132 | frechet_classifier_distance() method, especially in the case of using large 133 | batches during evaluation where we would like precompute all of the 134 | activations before computing the classifier distance. 135 | 136 | This technique is described in detail in https://arxiv.org/abs/1706.08500. 137 | Given two Gaussian distribution with means m and m_w and covariance matrices 138 | C and C_w, this function calculates 139 | 140 | |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) 141 | 142 | which captures how different the distributions of real images and generated 143 | images (or more accurately, their visual features) are. Note that unlike the 144 | Inception score, this is a true distance and utilizes information about real 145 | world images. 146 | 147 | Note that when computed using sample means and sample covariance matrices, 148 | Frechet distance is biased. It is more biased for small sample sizes. (e.g. 149 | even if the two distributions are the same, for a small sample size, the 150 | expected Frechet distance is large). It is important to use the same 151 | sample size to compute frechet classifier distance when comparing two 152 | generative models. 
153 | 154 | Args: 155 | real_activations: 2D Tensor containing activations of real data. Shape is 156 | [batch_size, activation_size]. 157 | generated_activations: 2D Tensor containing activations of generated data. 158 | Shape is [batch_size, activation_size]. 159 | 160 | Returns: 161 | The Frechet Inception distance. A floating-point scalar of the same type 162 | as the output of the activations. 163 | 164 | """ 165 | assert len(real_activations.shape) == len(generated_activations.shape) == 2 166 | 167 | activations_dtype = real_activations.dtype 168 | if activations_dtype != np.float64: 169 | real_activations = real_activations.astype(np.float64) 170 | generated_activations = generated_activations.astype(np.float64) 171 | 172 | # Compute mean and covariance matrices of activations. 173 | m = np.mean(real_activations, 0) 174 | m_w = np.mean(generated_activations, 0) 175 | num_examples_real = float(real_activations.shape[0]) 176 | num_examples_generated = float(generated_activations.shape[0]) 177 | 178 | # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T 179 | real_centered = real_activations - m 180 | sigma = real_centered.T.dot(real_centered) / (num_examples_real - 1) 181 | 182 | gen_centered = generated_activations - m_w 183 | sigma_w = gen_centered.T.dot(gen_centered) / (num_examples_generated - 1) 184 | 185 | # Find the Tr(sqrt(sigma sigma_w)) component of FID 186 | sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) 187 | 188 | # Compute the two components of FID. 189 | 190 | # First the covariance component. 191 | # Here, note that trace(A + B) = trace(A) + trace(B) 192 | trace = np.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component 193 | 194 | # Next the distance between means. 195 | mean = np.sum(np.square(m - m_w)) # Equivalent to L2 but more stable. 
196 | fid = trace + mean 197 | if activations_dtype != np.float64: 198 | fid = fid.astype(activations_dtype) 199 | 200 | return fid 201 | 202 | 203 | def test_all(): 204 | """ 205 | Test against tfgan.eval.classifier_metrics 206 | """ 207 | 208 | import tensorflow.compat.v1 as tf 209 | import tensorflow_gan as tfgan 210 | 211 | rand = np.random.RandomState(1234) 212 | logits = rand.randn(64, 1008) 213 | asdf1, asdf2 = rand.randn(64, 2048), rand.rand(256, 2048) 214 | with tf.Session() as sess: 215 | assert np.allclose( 216 | sess.run(tfgan.eval.classifier_score_from_logits(tf.convert_to_tensor(logits))), 217 | classifier_score_from_logits(logits)) 218 | assert np.allclose( 219 | sess.run(tfgan.eval.frechet_classifier_distance_from_activations( 220 | tf.convert_to_tensor(asdf1), tf.convert_to_tensor(asdf2))), 221 | frechet_classifier_distance_from_activations(asdf1, asdf2)) 222 | print('all ok') 223 | 224 | 225 | if __name__ == '__main__': 226 | test_all() 227 | -------------------------------------------------------------------------------- /diffusion_tf/tpu_utils/datasets.py: -------------------------------------------------------------------------------- 1 | """Dataset loading utilities. 
2 | 3 | All images are scaled to [0, 255] instead of [0, 1] 4 | """ 5 | 6 | import functools 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | import tensorflow_datasets as tfds 11 | 12 | 13 | def pack(image, label): 14 | label = tf.cast(label, tf.int32) 15 | return {'image': image, 'label': label} 16 | 17 | 18 | class SimpleDataset: 19 | DATASET_NAMES = ('cifar10', 'celebahq256') 20 | 21 | def __init__(self, name, tfds_data_dir): 22 | self._name = name 23 | self._data_dir = tfds_data_dir 24 | self._img_size = {'cifar10': 32, 'celebahq256': 256}[name] 25 | self._img_shape = [self._img_size, self._img_size, 3] 26 | self._tfds_name = { 27 | 'cifar10': 'cifar10:3.0.0', 28 | 'celebahq256': 'celeb_a_hq/256:2.0.0', 29 | }[name] 30 | self.num_train_examples, self.num_eval_examples = { 31 | 'cifar10': (50000, 10000), 32 | 'celebahq256': (30000, 0), 33 | }[name] 34 | self.num_classes = 1 # unconditional 35 | self.eval_split_name = { 36 | 'cifar10': 'test', 37 | 'celebahq256': None, 38 | }[name] 39 | 40 | @property 41 | def image_shape(self): 42 | """Returns a tuple with the image shape.""" 43 | return tuple(self._img_shape) 44 | 45 | def _proc_and_batch(self, ds, batch_size): 46 | def _process_data(x_): 47 | img_ = tf.cast(x_['image'], tf.int32) 48 | img_.set_shape(self._img_shape) 49 | return pack(image=img_, label=tf.constant(0, dtype=tf.int32)) 50 | 51 | ds = ds.map(_process_data, num_parallel_calls=tf.data.experimental.AUTOTUNE) 52 | ds = ds.batch(batch_size, drop_remainder=True) 53 | ds = ds.prefetch(tf.data.experimental.AUTOTUNE) 54 | return ds 55 | 56 | def train_input_fn(self, params): 57 | ds = tfds.load(self._tfds_name, split='train', shuffle_files=True, data_dir=self._data_dir) 58 | ds = ds.repeat() 59 | ds = ds.shuffle(50000) 60 | return self._proc_and_batch(ds, params['batch_size']) 61 | 62 | def train_one_pass_input_fn(self, params): 63 | ds = tfds.load(self._tfds_name, split='train', shuffle_files=False, data_dir=self._data_dir) 64 | return 
self._proc_and_batch(ds, params['batch_size']) 65 | 66 | def eval_input_fn(self, params): 67 | if self.eval_split_name is None: 68 | return None 69 | ds = tfds.load(self._tfds_name, split=self.eval_split_name, shuffle_files=False, data_dir=self._data_dir) 70 | return self._proc_and_batch(ds, params['batch_size']) 71 | 72 | 73 | class LsunDataset: 74 | def __init__(self, 75 | tfr_file, # Path to tfrecord file. 76 | resolution=256, # Dataset resolution. 77 | max_images=None, # Maximum number of images to use, None = use all images. 78 | shuffle_mb=4096, # Shuffle data within specified window (megabytes), 0 = disable shuffling. 79 | buffer_mb=256, # Read buffer size (megabytes). 80 | ): 81 | """Adapted from https://github.com/NVlabs/stylegan2/blob/master/training/dataset.py. 82 | Use StyleGAN2 dataset_tool.py to generate tf record files. 83 | """ 84 | self.tfr_file = tfr_file 85 | self.dtype = 'int32' 86 | self.max_images = max_images 87 | self.buffer_mb = buffer_mb 88 | self.num_classes = 1 # unconditional 89 | 90 | # Determine shape and resolution. 91 | self.resolution = resolution 92 | self.resolution_log2 = int(np.log2(self.resolution)) 93 | self.image_shape = [self.resolution, self.resolution, 3] 94 | 95 | def _train_input_fn(self, params, one_pass: bool): 96 | # Build TF expressions. 
97 | dset = tf.data.TFRecordDataset(self.tfr_file, compression_type='', buffer_size=self.buffer_mb<<20) 98 | if self.max_images is not None: 99 | dset = dset.take(self.max_images) 100 | if not one_pass: 101 | dset = dset.repeat() 102 | dset = dset.map(self._parse_tfrecord_tf, num_parallel_calls=tf.data.experimental.AUTOTUNE) 103 | # Shuffle and prefetch 104 | dset = dset.shuffle(50000) 105 | dset = dset.batch(params['batch_size'], drop_remainder=True) 106 | dset = dset.prefetch(tf.data.experimental.AUTOTUNE) 107 | return dset 108 | 109 | def train_input_fn(self, params): 110 | return self._train_input_fn(params, one_pass=False) 111 | 112 | def train_one_pass_input_fn(self, params): 113 | return self._train_input_fn(params, one_pass=True) 114 | 115 | def eval_input_fn(self, params): 116 | return None 117 | 118 | # Parse individual image from a tfrecords file into TensorFlow expression. 119 | def _parse_tfrecord_tf(self, record): 120 | features = tf.parse_single_example(record, features={ 121 | 'shape': tf.FixedLenFeature([3], tf.int64), 122 | 'data': tf.FixedLenFeature([], tf.string)}) 123 | data = tf.decode_raw(features['data'], tf.uint8) 124 | data = tf.cast(data, tf.int32) 125 | data = tf.reshape(data, features['shape']) 126 | data = tf.transpose(data, [1, 2, 0]) # CHW -> HWC 127 | data.set_shape(self.image_shape) 128 | return pack(image=data, label=tf.constant(0, dtype=tf.int32)) 129 | 130 | 131 | DATASETS = { 132 | "cifar10": functools.partial(SimpleDataset, name="cifar10"), 133 | "celebahq256": functools.partial(SimpleDataset, name="celebahq256"), 134 | "lsun": LsunDataset, 135 | } 136 | 137 | 138 | def get_dataset(name, *, tfds_data_dir=None, tfr_file=None, seed=547): 139 | """Instantiates a data set and sets the random seed.""" 140 | if name not in DATASETS: 141 | raise ValueError("Dataset %s is not available." 
% name) 142 | kwargs = {} 143 | 144 | if name == 'lsun': 145 | # LsunDataset takes the path to the tf record, not a directory 146 | assert tfr_file is not None 147 | kwargs['tfr_file'] = tfr_file 148 | else: 149 | kwargs['tfds_data_dir'] = tfds_data_dir 150 | 151 | if name not in ['lsun', *SimpleDataset.DATASET_NAMES]: 152 | kwargs['seed'] = seed 153 | 154 | return DATASETS[name](**kwargs) 155 | -------------------------------------------------------------------------------- /diffusion_tf/tpu_utils/simple_eval_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | "One-shot" evaluation worker (i.e. run something once, not in a loop over the course of training) 3 | 4 | - Computes log prob 5 | - Generates samples progressively 6 | """ 7 | 8 | import os 9 | import pickle 10 | import time 11 | 12 | import numpy as np 13 | import tensorflow.compat.v1 as tf 14 | from tqdm import trange 15 | 16 | from .tpu_utils import Model, make_ema, distributed, normalize_data 17 | from .. 
import utils 18 | 19 | 20 | def _make_ds_iterator(strategy, ds): 21 | return strategy.experimental_distribute_dataset(ds).make_initializable_iterator() 22 | 23 | 24 | class SimpleEvalWorker: 25 | def __init__(self, tpu_name, model_constructor, total_bs, dataset): 26 | tf.logging.set_verbosity(tf.logging.INFO) 27 | 28 | self.resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name) 29 | tf.tpu.experimental.initialize_tpu_system(self.resolver) 30 | self.strategy = tf.distribute.experimental.TPUStrategy(self.resolver) 31 | 32 | self.num_cores = self.strategy.num_replicas_in_sync 33 | assert total_bs % self.num_cores == 0 34 | self.total_bs = total_bs 35 | self.local_bs = total_bs // self.num_cores 36 | print('num cores: {}'.format(self.num_cores)) 37 | print('total batch size: {}'.format(self.total_bs)) 38 | print('local batch size: {}'.format(self.local_bs)) 39 | self.dataset = dataset 40 | 41 | # TPU context 42 | with self.strategy.scope(): 43 | # Dataset iterators 44 | self.train_ds_iterator = _make_ds_iterator( 45 | self.strategy, dataset.train_one_pass_input_fn(params={'batch_size': total_bs})) 46 | self.eval_ds_iterator = _make_ds_iterator( 47 | self.strategy, dataset.eval_input_fn(params={'batch_size': total_bs})) 48 | 49 | img_batch_shape = self.train_ds_iterator.output_shapes['image'].as_list() 50 | assert img_batch_shape[0] == self.local_bs 51 | 52 | # Model 53 | self.model = model_constructor() 54 | assert isinstance(self.model, Model) 55 | 56 | # Eval/samples graphs 57 | print('===== SAMPLES =====') 58 | self.samples_outputs = self._make_progressive_sampling_graph(img_shape=img_batch_shape[1:]) 59 | 60 | # Model with EMA parameters 61 | print('===== EMA =====') 62 | self.global_step = tf.train.get_or_create_global_step() 63 | ema, _ = make_ema(global_step=self.global_step, ema_decay=1e-10, trainable_variables=tf.trainable_variables()) 64 | 65 | # EMA versions of the above 66 | with utils.ema_scope(ema): 67 | print('===== EMA SAMPLES 
=====') 68 | self.ema_samples_outputs = self._make_progressive_sampling_graph(img_shape=img_batch_shape[1:]) 69 | print('===== EMA BPD =====') 70 | self.bpd_train = self._make_bpd_graph(self.train_ds_iterator) 71 | self.bpd_eval = self._make_bpd_graph(self.eval_ds_iterator) 72 | 73 | def _make_progressive_sampling_graph(self, img_shape): 74 | return distributed( 75 | lambda x_: self.model.progressive_samples_fn( 76 | x_, tf.random_uniform([self.local_bs], 0, self.dataset.num_classes, dtype=tf.int32)), 77 | args=(tf.fill([self.local_bs, *img_shape], value=np.nan),), 78 | reduction='concat', strategy=self.strategy) 79 | 80 | def _make_bpd_graph(self, ds_iterator): 81 | return distributed( 82 | lambda x_: self.model.bpd_fn(normalize_data(tf.cast(x_['image'], tf.float32)), x_['label']), 83 | args=(next(ds_iterator),), reduction='concat', strategy=self.strategy) 84 | 85 | def init_all_iterators(self, sess): 86 | sess.run([self.train_ds_iterator.initializer, self.eval_ds_iterator.initializer]) 87 | 88 | def dump_progressive_samples(self, sess, curr_step, samples_dir, ema: bool, num_samples=50000, batches_per_flush=20): 89 | if not tf.gfile.IsDirectory(samples_dir): 90 | tf.gfile.MakeDirs(samples_dir) 91 | 92 | batch_cache, num_flushes_so_far = [], 0 93 | 94 | def _write_batch_cache(): 95 | nonlocal batch_cache, num_flushes_so_far 96 | # concat all the batches 97 | assert all(set(b.keys()) == set(self.samples_outputs.keys()) for b in batch_cache) 98 | concatenated = { 99 | k: np.concatenate([b[k].astype(np.float32) for b in batch_cache], axis=0) 100 | for k in self.samples_outputs.keys() 101 | } 102 | assert len(set(len(v) for v in concatenated.values())) == 1 103 | # write the file 104 | filename = os.path.join( 105 | samples_dir, 'samples_xstartpred_ema{}_step{:09d}_part{:06d}.pkl'.format( 106 | int(ema), curr_step, num_flushes_so_far)) 107 | assert not tf.io.gfile.exists(filename), 'samples file already exists: {}'.format(filename) 108 | print('writing samples batch 
to:', filename) 109 | with tf.io.gfile.GFile(filename, 'wb') as f: 110 | f.write(pickle.dumps(concatenated, protocol=pickle.HIGHEST_PROTOCOL)) 111 | print('done writing samples batch') 112 | num_flushes_so_far += 1 113 | batch_cache = [] 114 | 115 | num_gen_batches = int(np.ceil(num_samples / self.total_bs)) 116 | print('generating {} samples ({} batches)...'.format(num_samples, num_gen_batches)) 117 | self.init_all_iterators(sess) 118 | for i_batch in trange(num_gen_batches, desc='sampling'): 119 | batch_cache.append(sess.run(self.ema_samples_outputs if ema else self.samples_outputs)) 120 | if i_batch != 0 and i_batch % batches_per_flush == 0: 121 | _write_batch_cache() 122 | if batch_cache: 123 | _write_batch_cache() 124 | 125 | def dump_bpd(self, sess, curr_step, output_dir, train: bool, ema: bool): 126 | assert ema 127 | if not tf.gfile.IsDirectory(output_dir): 128 | tf.gfile.MakeDirs(output_dir) 129 | filename = os.path.join( 130 | output_dir, 'bpd_{}_ema{}_step{:09d}.pkl'.format('train' if train else 'eval', int(ema), curr_step)) 131 | assert not tf.io.gfile.exists(filename), 'bpd file already exists: {}'.format(filename) 132 | print('will write bpd data to:', filename) 133 | 134 | batches = [] 135 | tf_op = self.bpd_train if train else self.bpd_eval 136 | self.init_all_iterators(sess) 137 | last_print_time = time.time() 138 | while True: 139 | try: 140 | batches.append(sess.run(tf_op)) 141 | if time.time() - last_print_time > 30: 142 | print('num batches so far: {} ({:.2f} sec)'.format(len(batches), time.time() - last_print_time)) 143 | last_print_time = time.time() 144 | except tf.errors.OutOfRangeError: 145 | break 146 | 147 | assert all(set(b.keys()) == set(tf_op.keys()) for b in batches) 148 | concatenated = { 149 | k: np.concatenate([b[k].astype(np.float32) for b in batches], axis=0) 150 | for k in tf_op.keys() 151 | } 152 | num_samples = len(list(concatenated.values())[0]) 153 | assert all(len(v) == num_samples for v in concatenated.values()) 154 | 
print('evaluated on {} examples'.format(num_samples)) 155 | 156 | print('writing results to:', filename) 157 | with tf.io.gfile.GFile(filename, 'wb') as f: 158 | f.write(pickle.dumps(concatenated, protocol=pickle.HIGHEST_PROTOCOL)) 159 | print('done writing results') 160 | 161 | def run(self, mode: str, logdir: str, load_ckpt: str): 162 | """ 163 | Main entry point. 164 | 165 | :param mode: what to do 166 | :param logdir: model directory for the checkpoint to load 167 | :param load_ckpt: the name of the checkpoint, e.g. "model.ckpt-1000000" 168 | """ 169 | 170 | # Input checkpoint: load_ckpt should be of the form: model.ckpt-1000000 171 | ckpt = os.path.join(logdir, load_ckpt) 172 | assert tf.io.gfile.exists(ckpt + '.index') 173 | 174 | # Output dir 175 | output_dir = os.path.join(logdir, 'simple_eval') 176 | print('Writing output to: {}'.format(output_dir)) 177 | 178 | # Make the session 179 | config = tf.ConfigProto() 180 | config.allow_soft_placement = True 181 | cluster_spec = self.resolver.cluster_spec() 182 | if cluster_spec: 183 | config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) 184 | print('making session...') 185 | with tf.Session(target=self.resolver.master(), config=config) as sess: 186 | 187 | print('initializing global variables') 188 | sess.run(tf.global_variables_initializer()) 189 | 190 | # Checkpoint loading 191 | print('making saver') 192 | saver = tf.train.Saver() 193 | saver.restore(sess, ckpt) 194 | global_step_val = sess.run(self.global_step) 195 | print('restored global step: {}'.format(global_step_val)) 196 | 197 | if mode in ['bpd_train', 'bpd_eval']: 198 | self.dump_bpd( 199 | sess, curr_step=global_step_val, output_dir=os.path.join(output_dir, 'bpd'), ema=True, 200 | train=mode == 'bpd_train') 201 | elif mode == 'progressive_samples': 202 | return self.dump_progressive_samples( 203 | sess, curr_step=global_step_val, samples_dir=os.path.join(output_dir, 'progressive_samples'), ema=True) 204 | else: 205 | raise 
NotImplementedError(mode) -------------------------------------------------------------------------------- /diffusion_tf/tpu_utils/tpu_summaries.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import collections 6 | 7 | from absl import logging 8 | import tensorflow as tf 9 | 10 | 11 | summary = tf.contrib.summary # TensorFlow Summary API v2. 12 | 13 | 14 | TpuSummaryEntry = collections.namedtuple( 15 | "TpuSummaryEntry", "summary_fn name tensor reduce_fn") 16 | 17 | 18 | class TpuSummaries(object): 19 | """Class to simplify TF summaries on TPU. 20 | 21 | An instance of the class provides simple methods for writing summaries in the 22 | similar way to tf.summary. The difference is that each summary entry must 23 | provide a reduction function that is used to reduce the summary values from 24 | all the TPU cores. 25 | """ 26 | 27 | def __init__(self, log_dir, save_summary_steps=250): 28 | self._log_dir = log_dir 29 | self._entries = [] 30 | # While False no summary entries will be added. On TPU we unroll the graph 31 | # and don't want to add multiple summaries per step. 32 | self.record = True 33 | self._save_summary_steps = save_summary_steps 34 | 35 | def image(self, name, tensor, reduce_fn): 36 | """Add a summary for images. 
Tensor must be of 4-D tensor.""" 37 | if not self.record: 38 | return 39 | self._entries.append( 40 | TpuSummaryEntry(summary.image, name, tensor, reduce_fn)) 41 | 42 | def scalar(self, name, tensor, reduce_fn=tf.math.reduce_mean): 43 | """Add a summary for a scalar tensor.""" 44 | if not self.record: 45 | return 46 | tensor = tf.convert_to_tensor(tensor) 47 | if tensor.shape.ndims == 0: 48 | tensor = tf.expand_dims(tensor, 0) 49 | self._entries.append( 50 | TpuSummaryEntry(summary.scalar, name, tensor, reduce_fn)) 51 | 52 | def get_host_call(self): 53 | """Returns the tuple (host_call_fn, host_call_args) for TPUEstimatorSpec.""" 54 | # All host_call_args must be tensors with batch dimension. 55 | # All tensors are streamed to the host machine (mind the band width). 56 | global_step = tf.train.get_or_create_global_step() 57 | host_call_args = [tf.expand_dims(global_step, 0)] 58 | host_call_args.extend([e.tensor for e in self._entries]) 59 | logging.info("host_call_args: %s", host_call_args) 60 | return (self._host_call_fn, host_call_args) 61 | 62 | def _host_call_fn(self, step, *args): 63 | """Function that will run on the host machine.""" 64 | # Host call receives values from all tensor cores (concatenate on the 65 | # batch dimension). Step is the same for all cores. 
66 | step = step[0] 67 | logging.info("host_call_fn: args=%s", args) 68 | with summary.create_file_writer(self._log_dir).as_default(): 69 | with summary.record_summaries_every_n_global_steps( 70 | self._save_summary_steps, step): 71 | for i, e in enumerate(self._entries): 72 | value = e.reduce_fn(args[i]) 73 | e.summary_fn(e.name, value, step=step) 74 | return summary.all_summary_ops() 75 | -------------------------------------------------------------------------------- /diffusion_tf/tpu_utils/tpu_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import time 5 | from datetime import datetime 6 | 7 | import numpy as np 8 | import tensorflow.compat.v1 as tf 9 | import tensorflow_gan as tfgan 10 | from tensorflow.contrib.tpu.python.ops import tpu_ops 11 | from tensorflow.python.tpu import tpu_function 12 | from tqdm import trange 13 | 14 | from . import classifier_metrics_numpy 15 | from .tpu_summaries import TpuSummaries 16 | from .. import utils 17 | 18 | 19 | # ========== TPU utilities ========== 20 | 21 | def num_tpu_replicas(): 22 | return tpu_function.get_tpu_context().number_of_shards 23 | 24 | 25 | def get_tpu_replica_id(): 26 | with tf.control_dependencies(None): 27 | return tpu_ops.tpu_replicated_input(list(range(num_tpu_replicas()))) 28 | 29 | 30 | def distributed(fn, *, args, reduction, strategy): 31 | """ 32 | Sharded computation followed by concat/mean for TPUStrategy. 
33 | """ 34 | out = strategy.experimental_run_v2(fn, args=args) 35 | if reduction == 'mean': 36 | return tf.nest.map_structure(lambda x: tf.reduce_mean(strategy.reduce('mean', x)), out) 37 | elif reduction == 'concat': 38 | return tf.nest.map_structure(lambda x: tf.concat(strategy.experimental_local_results(x), axis=0), out) 39 | else: 40 | raise NotImplementedError(reduction) 41 | 42 | 43 | # ========== Inception utilities ========== 44 | 45 | INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05_v4.tar.gz' 46 | INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score_tpu.pb' 47 | INCEPTION_GRAPH_DEF = tfgan.eval.get_graph_def_from_url_tarball( 48 | INCEPTION_URL, INCEPTION_FROZEN_GRAPH, os.path.basename(INCEPTION_URL)) 49 | 50 | 51 | def run_inception(images): 52 | assert images.dtype == tf.float32 # images should be in [-1, 1] 53 | out = tfgan.eval.run_inception( 54 | images, 55 | graph_def=INCEPTION_GRAPH_DEF, 56 | default_graph_def_fn=None, 57 | output_tensor=['pool_3:0', 'logits:0'] 58 | ) 59 | return {'pool_3': out[0], 'logits': out[1]} 60 | 61 | 62 | # ========== Training ========== 63 | 64 | normalize_data = lambda x_: x_ / 127.5 - 1. 65 | unnormalize_data = lambda x_: (x_ + 1.) * 127.5 66 | 67 | 68 | class Model: 69 | # All images (inputs and outputs) should be normalized to [-1, 1] 70 | def train_fn(self, x, y) -> dict: 71 | raise NotImplementedError 72 | 73 | def samples_fn(self, dummy_x, y) -> dict: 74 | raise NotImplementedError 75 | 76 | def sample_and_run_inception(self, dummy_x, y, clip_samples=True): 77 | samples_dict = self.samples_fn(dummy_x, y) 78 | assert isinstance(samples_dict, dict) 79 | return { 80 | k: run_inception(tfgan.eval.preprocess_image(unnormalize_data( 81 | tf.clip_by_value(v, -1., 1.) 
if clip_samples else v))) 82 | for (k, v) in samples_dict.items() 83 | } 84 | 85 | def bpd_fn(self, x, y) -> dict: 86 | return None 87 | 88 | 89 | def make_ema(global_step, ema_decay, trainable_variables): 90 | ema = tf.train.ExponentialMovingAverage(decay=tf.where(tf.less(global_step, 1), 1e-10, ema_decay)) 91 | ema_op = ema.apply(trainable_variables) 92 | return ema, ema_op 93 | 94 | 95 | def load_train_kwargs(model_dir): 96 | with tf.io.gfile.GFile(os.path.join(model_dir, 'kwargs.json'), 'r') as f: 97 | kwargs = json.loads(f.read()) 98 | return kwargs 99 | 100 | 101 | def run_training( 102 | *, model_constructor, train_input_fn, total_bs, 103 | optimizer, lr, warmup, grad_clip, ema_decay=0.9999, 104 | tpu=None, zone=None, project=None, 105 | log_dir, exp_name, dump_kwargs=None, 106 | date_str=None, iterations_per_loop=1000, keep_checkpoint_max=2, max_steps=int(1e10), 107 | warm_start_from=None 108 | ): 109 | tf.logging.set_verbosity(tf.logging.INFO) 110 | 111 | # Create checkpoint directory 112 | model_dir = os.path.join( 113 | log_dir, 114 | datetime.now().strftime('%Y-%m-%d') if date_str is None else date_str, 115 | exp_name 116 | ) 117 | print('model dir:', model_dir) 118 | if tf.io.gfile.exists(model_dir): 119 | print('model dir already exists: {}'.format(model_dir)) 120 | if input('continue training? 
[y/n] ') != 'y': 121 | print('aborting') 122 | return 123 | 124 | # Save kwargs in json format 125 | if dump_kwargs is not None: 126 | with tf.io.gfile.GFile(os.path.join(model_dir, 'kwargs.json'), 'w') as f: 127 | f.write(json.dumps(dump_kwargs, indent=2, sort_keys=True) + '\n') 128 | 129 | # model_fn for TPUEstimator 130 | def model_fn(features, params, mode): 131 | local_bs = params['batch_size'] 132 | print('Global batch size: {}, local batch size: {}'.format(total_bs, local_bs)) 133 | assert total_bs == num_tpu_replicas() * local_bs 134 | 135 | assert mode == tf.estimator.ModeKeys.TRAIN, 'only TRAIN mode supported' 136 | assert features['image'].shape[0] == local_bs 137 | assert features['label'].shape == [local_bs] and features['label'].dtype == tf.int32 138 | # assert labels.dtype == features['label'].dtype and labels.shape == features['label'].shape 139 | 140 | del params 141 | 142 | ########## 143 | 144 | # create model 145 | model = model_constructor() 146 | assert isinstance(model, Model) 147 | 148 | # training loss 149 | train_info_dict = model.train_fn(normalize_data(tf.cast(features['image'], tf.float32)), features['label']) 150 | loss = train_info_dict['loss'] 151 | assert loss.shape == [] 152 | 153 | # train op 154 | trainable_variables = tf.trainable_variables() 155 | print('num params: {:,}'.format(sum(int(np.prod(p.shape.as_list())) for p in trainable_variables))) 156 | global_step = tf.train.get_or_create_global_step() 157 | warmed_up_lr = utils.get_warmed_up_lr(max_lr=lr, warmup=warmup, global_step=global_step) 158 | train_op, gnorm = utils.make_optimizer( 159 | loss=loss, 160 | trainable_variables=trainable_variables, 161 | global_step=global_step, 162 | lr=warmed_up_lr, 163 | optimizer=optimizer, 164 | grad_clip=grad_clip / float(num_tpu_replicas()), 165 | tpu=True 166 | ) 167 | 168 | # ema 169 | ema, ema_op = make_ema(global_step=global_step, ema_decay=ema_decay, trainable_variables=trainable_variables) 170 | with 
tf.control_dependencies([train_op]): 171 | train_op = tf.group(ema_op) 172 | 173 | # summary 174 | tpu_summary = TpuSummaries(model_dir, save_summary_steps=100) 175 | tpu_summary.scalar('train/loss', loss) 176 | tpu_summary.scalar('train/gnorm', gnorm) 177 | tpu_summary.scalar('train/pnorm', utils.rms(trainable_variables)) 178 | tpu_summary.scalar('train/lr', warmed_up_lr) 179 | return tf.estimator.tpu.TPUEstimatorSpec( 180 | mode=mode, host_call=tpu_summary.get_host_call(), loss=loss, train_op=train_op) 181 | 182 | # Set up Estimator and train 183 | print("warm_start_from:", warm_start_from) 184 | estimator = tf.estimator.tpu.TPUEstimator( 185 | model_fn=model_fn, 186 | use_tpu=True, 187 | train_batch_size=total_bs, 188 | eval_batch_size=total_bs, 189 | config=tf.estimator.tpu.RunConfig( 190 | cluster=tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu, zone=zone, project=project), 191 | model_dir=model_dir, 192 | session_config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True), 193 | tpu_config=tf.estimator.tpu.TPUConfig( 194 | iterations_per_loop=iterations_per_loop, 195 | num_shards=None, 196 | per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2 197 | ), 198 | save_checkpoints_secs=1600, # ~27 minutes (1600 s; 30 minutes would be 1800 s) 199 | keep_checkpoint_max=keep_checkpoint_max 200 | ), 201 | warm_start_from=warm_start_from 202 | ) 203 | estimator.train(input_fn=train_input_fn, max_steps=max_steps) 204 | 205 | 206 | # ========== Evaluation / sampling ========== 207 | 208 | 209 | class InceptionFeatures: 210 | """ 211 | Compute and store Inception features for a dataset 212 | """ 213 | 214 | def __init__(self, dataset, strategy, limit_dataset_size=0): 215 | # distributed dataset iterator 216 | if limit_dataset_size > 0: 217 | dataset = dataset.take(limit_dataset_size) 218 | self.ds_iterator = strategy.experimental_distribute_dataset(dataset).make_initializable_iterator() 219 | 220 | # inception network on the dataset 221 | self.inception_real =
distributed( 222 | lambda x_: run_inception(tfgan.eval.preprocess_image(x_['image'])), 223 | args=(next(self.ds_iterator),), reduction='concat', strategy=strategy) 224 | 225 | self.cached_inception_real = None # cached inception features 226 | self.real_inception_score = None # saved inception scores for the dataset 227 | 228 | def get(self, sess): 229 | # On the first invocation, compute Inception activations for the eval dataset 230 | if self.cached_inception_real is None: 231 | print('computing inception features on the eval set...') 232 | sess.run(self.ds_iterator.initializer) # reset the eval dataset iterator 233 | inception_real_batches, tstart = [], time.time() 234 | while True: 235 | try: 236 | inception_real_batches.append(sess.run(self.inception_real)) 237 | except tf.errors.OutOfRangeError: 238 | break 239 | self.cached_inception_real = { 240 | feat_key: np.concatenate([batch[feat_key] for batch in inception_real_batches], axis=0).astype(np.float64) 241 | for feat_key in ['pool_3', 'logits'] 242 | } 243 | print('cached eval inception tensors: logits: {}, pool_3: {} (time: {})'.format( 244 | self.cached_inception_real['logits'].shape, self.cached_inception_real['pool_3'].shape, 245 | time.time() - tstart)) 246 | 247 | self.real_inception_score = float( 248 | classifier_metrics_numpy.classifier_score_from_logits(self.cached_inception_real['logits'])) 249 | del self.cached_inception_real['logits'] # save memory 250 | print('real inception score', self.real_inception_score) 251 | 252 | return self.cached_inception_real, self.real_inception_score 253 | 254 | 255 | class EvalWorker: 256 | def __init__(self, tpu_name, model_constructor, total_bs, dataset, inception_bs=8, num_inception_samples=1024, limit_dataset_size=0): 257 | 258 | self.resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name) 259 | tf.tpu.experimental.initialize_tpu_system(self.resolver) 260 | self.strategy = tf.distribute.experimental.TPUStrategy(self.resolver) 261 | 262 | 
self.num_cores = self.strategy.num_replicas_in_sync 263 | assert total_bs % self.num_cores == 0 264 | self.total_bs = total_bs 265 | self.local_bs = total_bs // self.num_cores 266 | print('num cores: {}'.format(self.num_cores)) 267 | print('total batch size: {}'.format(self.total_bs)) 268 | print('local batch size: {}'.format(self.local_bs)) 269 | self.num_inception_samples = num_inception_samples 270 | assert inception_bs % self.num_cores == 0 271 | self.inception_bs = inception_bs 272 | self.inception_local_bs = inception_bs // self.num_cores 273 | self.dataset = dataset 274 | assert dataset.num_classes == 1, 'not supported' 275 | 276 | # TPU context 277 | with self.strategy.scope(): 278 | # Inception network on real data 279 | print('===== INCEPTION =====') 280 | # Eval dataset iterator (this is the training set without repeat & shuffling) 281 | self.inception_real_train = InceptionFeatures( 282 | dataset=dataset.train_one_pass_input_fn(params={'batch_size': total_bs}), strategy=self.strategy, limit_dataset_size=limit_dataset_size // total_bs) 283 | # Val dataset, if it exists 284 | val_ds = dataset.eval_input_fn(params={'batch_size': total_bs}) 285 | self.inception_real_val = None if val_ds is None else InceptionFeatures(dataset=val_ds, strategy=self.strategy, limit_dataset_size=limit_dataset_size // total_bs) 286 | 287 | img_batch_shape = self.inception_real_train.ds_iterator.output_shapes['image'].as_list() 288 | assert img_batch_shape[0] == self.local_bs 289 | 290 | # Model 291 | self.model = model_constructor() 292 | assert isinstance(self.model, Model) 293 | 294 | # Eval/samples graphs 295 | print('===== SAMPLES =====') 296 | self.samples_outputs, self.samples_inception = self._make_sampling_graph( 297 | img_shape=img_batch_shape[1:], with_inception=True) 298 | 299 | # Model with EMA parameters 300 | self.global_step = tf.train.get_or_create_global_step() 301 | print('===== EMA =====') 302 | ema, _ = make_ema(global_step=self.global_step, ema_decay=1e-10, 
trainable_variables=tf.trainable_variables()) 303 | 304 | # EMA versions of the above 305 | with utils.ema_scope(ema): 306 | print('===== EMA SAMPLES =====') 307 | self.ema_samples_outputs, self.ema_samples_inception = self._make_sampling_graph( 308 | img_shape=img_batch_shape[1:], with_inception=True) 309 | 310 | def _make_sampling_graph(self, img_shape, with_inception): 311 | 312 | def _make_inputs(total_bs, local_bs): 313 | # Dummy inputs to feed to samplers 314 | input_x = tf.fill([local_bs, *img_shape], value=np.nan) 315 | input_y = tf.random_uniform([local_bs], 0, self.dataset.num_classes, dtype=tf.int32) 316 | return input_x, input_y 317 | 318 | # Samples 319 | samples_outputs = distributed( 320 | self.model.samples_fn, 321 | args=_make_inputs(self.total_bs, self.local_bs), 322 | reduction='concat', strategy=self.strategy) 323 | if not with_inception: 324 | return samples_outputs 325 | 326 | # Inception activations of samples 327 | samples_inception = distributed( 328 | self.model.sample_and_run_inception, 329 | args=_make_inputs(self.inception_bs, self.inception_local_bs), 330 | reduction='concat', strategy=self.strategy) 331 | return samples_outputs, samples_inception 332 | 333 | def _run_sampling(self, sess, ema: bool): 334 | out = {} 335 | print('sampling...') 336 | tstart = time.time() 337 | samples = sess.run(self.ema_samples_outputs if ema else self.samples_outputs) 338 | print('sampling done in {} sec'.format(time.time() - tstart)) 339 | for k, v in samples.items(): 340 | out['samples/{}'.format(k)] = v 341 | return out 342 | 343 | def _run_metrics(self, sess, ema: bool): 344 | print('computing sample quality metrics...') 345 | metrics = {} 346 | 347 | # Get Inception activations on the real dataset 348 | cached_inception_real_train, metrics['real_inception_score_train'] = self.inception_real_train.get(sess) 349 | if self.inception_real_val is not None: 350 | cached_inception_real_val, metrics['real_inception_score'] = 
self.inception_real_val.get(sess) 351 | else: 352 | cached_inception_real_val = None 353 | 354 | # Generate lots of samples 355 | num_inception_gen_batches = int(np.ceil(self.num_inception_samples / self.inception_bs)) 356 | print('generating {} samples and inception features ({} batches)...'.format( 357 | self.num_inception_samples, num_inception_gen_batches)) 358 | inception_gen_batches = [ 359 | sess.run(self.ema_samples_inception if ema else self.samples_inception) 360 | for _ in trange(num_inception_gen_batches, desc='sampling inception batch') 361 | ] 362 | 363 | # Compute FID and Inception score 364 | assert set(self.samples_outputs.keys()) == set(inception_gen_batches[0].keys()) 365 | for samples_key in self.samples_outputs.keys(): 366 | # concat features from all batches into a single array 367 | inception_gen = { 368 | feat_key: np.concatenate( 369 | [batch[samples_key][feat_key] for batch in inception_gen_batches], axis=0 370 | )[:self.num_inception_samples].astype(np.float64) 371 | for feat_key in ['pool_3', 'logits'] 372 | } 373 | assert all(v.shape[0] == self.num_inception_samples for v in inception_gen.values()) 374 | 375 | # Inception score 376 | metrics['{}/inception{}'.format(samples_key, self.num_inception_samples)] = float( 377 | classifier_metrics_numpy.classifier_score_from_logits(inception_gen['logits'])) 378 | 379 | # FID vs training set 380 | metrics['{}/trainfid{}'.format(samples_key, self.num_inception_samples)] = float( 381 | classifier_metrics_numpy.frechet_classifier_distance_from_activations( 382 | cached_inception_real_train['pool_3'], inception_gen['pool_3'])) 383 | 384 | # FID vs val set 385 | if cached_inception_real_val is not None: 386 | metrics['{}/fid{}'.format(samples_key, self.num_inception_samples)] = float( 387 | classifier_metrics_numpy.frechet_classifier_distance_from_activations( 388 | cached_inception_real_val['pool_3'], inception_gen['pool_3'])) 389 | 390 | return metrics 391 | 392 | def _write_eval_and_samples(self, 
sess, log: utils.SummaryWriter, curr_step, prefix, ema: bool): 393 | # Samples 394 | for k, v in self._run_sampling(sess, ema=ema).items(): 395 | assert len(v.shape) == 4 and v.shape[0] == self.total_bs 396 | log.images('{}/{}'.format(prefix, k), np.clip(unnormalize_data(v), 0, 255).astype('uint8'), step=curr_step) 397 | log.flush() 398 | 399 | # Metrics 400 | metrics = self._run_metrics(sess, ema=ema) 401 | print('metrics:', json.dumps(metrics, indent=2, sort_keys=True)) 402 | for k, v in metrics.items(): 403 | log.scalar('{}/{}'.format(prefix, k), v, step=curr_step) 404 | log.flush() 405 | 406 | def _dump_samples(self, sess, curr_step, samples_dir, ema: bool, num_samples=50000): 407 | print('will dump samples to', samples_dir) 408 | if not tf.gfile.IsDirectory(samples_dir): 409 | tf.gfile.MakeDirs(samples_dir) 410 | filename = os.path.join( 411 | samples_dir, 'samples_ema{}_step{:09d}.pkl'.format(int(ema), curr_step)) 412 | assert not tf.io.gfile.exists(filename), 'samples file already exists: {}'.format(filename) 413 | 414 | num_gen_batches = int(np.ceil(num_samples / self.total_bs)) 415 | print('generating {} samples ({} batches)...'.format(num_samples, num_gen_batches)) 416 | 417 | # gen_batches = [ 418 | # sess.run(self.ema_samples_outputs if ema else self.samples_outputs) 419 | # for _ in trange(num_gen_batches, desc='sampling') 420 | # ] 421 | # assert all(set(b.keys()) == set(self.samples_outputs.keys()) for b in gen_batches) 422 | # concatenated = { 423 | # k: np.concatenate([b[k].astype(np.float32) for b in gen_batches], axis=0)[:num_samples] 424 | # for k in self.samples_outputs.keys() 425 | # } 426 | # assert all(len(v) == num_samples for v in concatenated.values()) 427 | # 428 | # print('writing samples to:', filename) 429 | # with tf.io.gfile.GFile(filename, 'wb') as f: 430 | # f.write(pickle.dumps(concatenated, protocol=pickle.HIGHEST_PROTOCOL)) 431 | 432 | for i in trange(num_gen_batches, desc='sampling'): 433 | b = 
sess.run(self.ema_samples_outputs if ema else self.samples_outputs) 434 | assert set(b.keys()) == set(self.samples_outputs.keys()) 435 | b = { 436 | k: b[k].astype(np.float32) for k in self.samples_outputs.keys() 437 | } 438 | #assert all(len(v) == num_samples for v in concatenated.values()) 439 | 440 | filename_i = "{}.batch{:05d}".format(filename, i) 441 | print('writing samples for batch', i, 'to:', filename_i) 442 | with tf.io.gfile.GFile(filename_i, 'wb') as f: 443 | f.write(pickle.dumps(b, protocol=pickle.HIGHEST_PROTOCOL)) 444 | print('done writing samples') 445 | 446 | def run(self, logdir, once: bool, skip_non_ema_pass=True, dump_samples_only=False, load_ckpt=None, samples_dir=None, seed=0): 447 | """Runs the eval/sampling worker loop. 448 | Args: 449 | logdir: directory to read checkpoints from 450 | once: if True, writes results to a temporary directory (not to logdir), 451 | and exits after evaluating one checkpoint. 452 | """ 453 | tf.logging.set_verbosity(tf.logging.INFO) 454 | 455 | # Are we evaluating a single checkpoint or looping on the latest? 
456 | if load_ckpt is not None: 457 | # load_ckpt should be of the form: model.ckpt-1000000 458 | assert tf.io.gfile.exists(os.path.join(logdir, load_ckpt) + '.index') 459 | ckpt_iterator = [os.path.join(logdir, load_ckpt)] # load this one checkpoint only 460 | else: 461 | ckpt_iterator = tf.train.checkpoints_iterator(logdir) # wait for checkpoints to come in 462 | assert tf.io.gfile.isdir(logdir), 'expected {} to be a directory'.format(logdir) 463 | 464 | # Set up eval SummaryWriter 465 | if once: 466 | eval_logdir = os.path.join(logdir, 'eval_once_{}'.format(time.time())) 467 | else: 468 | eval_logdir = os.path.join(logdir, 'eval') 469 | print('Writing eval data to: {}'.format(eval_logdir)) 470 | eval_log = utils.SummaryWriter(eval_logdir, write_graph=False) 471 | 472 | # Make the session 473 | config = tf.ConfigProto() 474 | config.allow_soft_placement = True 475 | cluster_spec = self.resolver.cluster_spec() 476 | if cluster_spec: 477 | config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) 478 | print('making session...') 479 | with tf.Session(target=self.resolver.master(), config=config) as sess: 480 | 481 | print('initializing global variables') 482 | sess.run(tf.global_variables_initializer()) 483 | 484 | # Checkpoint loading 485 | print('making saver') 486 | saver = tf.train.Saver() 487 | 488 | for ckpt in ckpt_iterator: 489 | # Restore params 490 | saver.restore(sess, ckpt) 491 | global_step_val = sess.run(self.global_step) 492 | print('restored global step: {}'.format(global_step_val)) 493 | 494 | print('seeding') 495 | utils.seed_all(seed) 496 | 497 | print('ema pass') 498 | if dump_samples_only: 499 | if not samples_dir: 500 | samples_dir = os.path.join(eval_logdir, '{}_samples{}'.format(type(self.dataset).__name__, global_step_val)) 501 | self._dump_samples( 502 | sess, curr_step=global_step_val, samples_dir=samples_dir, ema=True) 503 | else: 504 | self._write_eval_and_samples(sess, log=eval_log, curr_step=global_step_val, prefix='eval_ema', 
ema=True) 505 | 506 | if not skip_non_ema_pass: 507 | print('non-ema pass') 508 | if dump_samples_only: 509 | self._dump_samples( 510 | sess, curr_step=global_step_val, samples_dir=os.path.join(eval_logdir, 'samples'), ema=False) 511 | else: 512 | self._write_eval_and_samples(sess, log=eval_log, curr_step=global_step_val, prefix='eval', ema=False) 513 | 514 | if once: 515 | break 516 | -------------------------------------------------------------------------------- /diffusion_tf/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import io 3 | import random 4 | import time 5 | 6 | import numpy as np 7 | import tensorflow.compat.v1 as tf 8 | from PIL import Image 9 | from tensorflow.compat.v1 import gfile 10 | from tensorflow.core.framework.summary_pb2 import Summary 11 | from tensorflow.core.util.event_pb2 import Event 12 | 13 | 14 | class SummaryWriter: 15 | """Tensorflow summary writer inspired by Jaxboard. 16 | This version doesn't try to avoid Tensorflow dependencies, because this 17 | project uses Tensorflow. 
18 | """ 19 | 20 | def __init__(self, dir, write_graph=True): 21 | if not gfile.IsDirectory(dir): 22 | gfile.MakeDirs(dir) 23 | self.writer = tf.summary.FileWriter( 24 | dir, graph=tf.get_default_graph() if write_graph else None) 25 | 26 | def flush(self): 27 | self.writer.flush() 28 | 29 | def close(self): 30 | self.writer.close() 31 | 32 | def _write_event(self, summary_value, step): 33 | self.writer.add_event( 34 | Event( 35 | wall_time=round(time.time()), 36 | step=step, 37 | summary=Summary(value=[summary_value]))) 38 | 39 | def scalar(self, tag, value, step): 40 | self._write_event(Summary.Value(tag=tag, simple_value=float(value)), step) 41 | 42 | def image(self, tag, image, step): 43 | image = np.asarray(image) 44 | if image.ndim == 2: 45 | image = image[:, :, None] 46 | if image.shape[-1] == 1: 47 | image = np.repeat(image, 3, axis=-1) 48 | 49 | bytesio = io.BytesIO() 50 | Image.fromarray(image).save(bytesio, 'PNG') 51 | image_summary = Summary.Image( 52 | encoded_image_string=bytesio.getvalue(), 53 | colorspace=3, 54 | height=image.shape[0], 55 | width=image.shape[1]) 56 | self._write_event(Summary.Value(tag=tag, image=image_summary), step) 57 | 58 | def images(self, tag, images, step): 59 | self.image(tag, tile_imgs(images), step=step) 60 | 61 | 62 | def seed_all(seed): 63 | random.seed(seed) 64 | np.random.seed(seed) 65 | tf.set_random_seed(seed) 66 | 67 | 68 | def tile_imgs(imgs, *, pad_pixels=1, pad_val=255, num_col=0): 69 | assert pad_pixels >= 0 and 0 <= pad_val <= 255 70 | 71 | imgs = np.asarray(imgs) 72 | assert imgs.dtype == np.uint8 73 | if imgs.ndim == 3: 74 | imgs = imgs[..., None] 75 | n, h, w, c = imgs.shape 76 | assert c == 1 or c == 3, 'Expected 1 or 3 channels' 77 | 78 | if num_col <= 0: 79 | # Make a square 80 | ceil_sqrt_n = int(np.ceil(np.sqrt(float(n)))) 81 | num_row = ceil_sqrt_n 82 | num_col = ceil_sqrt_n 83 | else: 84 | # Make a B/num_per_row x num_per_row grid 85 | assert n % num_col == 0 86 | num_row = int(np.ceil(n / num_col)) 87 
| 88 | imgs = np.pad( 89 | imgs, 90 | pad_width=((0, num_row * num_col - n), (pad_pixels, pad_pixels), (pad_pixels, pad_pixels), (0, 0)), 91 | mode='constant', 92 | constant_values=pad_val 93 | ) 94 | h, w = h + 2 * pad_pixels, w + 2 * pad_pixels 95 | imgs = imgs.reshape(num_row, num_col, h, w, c) 96 | imgs = imgs.transpose(0, 2, 1, 3, 4) 97 | imgs = imgs.reshape(num_row * h, num_col * w, c) 98 | 99 | if pad_pixels > 0: 100 | imgs = imgs[pad_pixels:-pad_pixels, pad_pixels:-pad_pixels, :] 101 | if c == 1: 102 | imgs = imgs[..., 0] 103 | return imgs 104 | 105 | 106 | def save_tiled_imgs(filename, imgs, pad_pixels=1, pad_val=255, num_col=0): 107 | Image.fromarray(tile_imgs(imgs, pad_pixels=pad_pixels, pad_val=pad_val, num_col=num_col)).save(filename) 108 | 109 | 110 | # === 111 | 112 | def approx_standard_normal_cdf(x): 113 | return 0.5 * (1.0 + tf.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))) 114 | 115 | 116 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 117 | # Assumes data is integers [0, 255] rescaled to [-1, 1] 118 | assert x.shape == means.shape == log_scales.shape 119 | centered_x = x - means 120 | inv_stdv = tf.exp(-log_scales) 121 | plus_in = inv_stdv * (centered_x + 1. / 255.) 122 | cdf_plus = approx_standard_normal_cdf(plus_in) 123 | min_in = inv_stdv * (centered_x - 1. / 255.) 124 | cdf_min = approx_standard_normal_cdf(min_in) 125 | log_cdf_plus = tf.log(tf.maximum(cdf_plus, 1e-12)) 126 | log_one_minus_cdf_min = tf.log(tf.maximum(1. 
- cdf_min, 1e-12)) 127 | cdf_delta = cdf_plus - cdf_min 128 | log_probs = tf.where( 129 | x < -0.999, log_cdf_plus, 130 | tf.where(x > 0.999, log_one_minus_cdf_min, 131 | tf.log(tf.maximum(cdf_delta, 1e-12)))) 132 | assert log_probs.shape == x.shape 133 | return log_probs 134 | 135 | 136 | # === 137 | 138 | 139 | def rms(variables): 140 | return tf.sqrt( 141 | sum([tf.reduce_sum(tf.square(v)) for v in variables]) / 142 | sum(int(np.prod(v.shape.as_list())) for v in variables)) 143 | 144 | 145 | def get_warmed_up_lr(max_lr, warmup, global_step): 146 | if warmup == 0: 147 | return max_lr 148 | return max_lr * tf.minimum(tf.cast(global_step, tf.float32) / float(warmup), 1.0) 149 | 150 | 151 | def make_optimizer( 152 | *, 153 | loss, trainable_variables, global_step, tpu: bool, 154 | optimizer: str, lr: float, grad_clip: float, 155 | rmsprop_decay=0.95, rmsprop_momentum=0.9, epsilon=1e-8 156 | ): 157 | if optimizer == 'adam': 158 | optimizer = tf.train.AdamOptimizer( 159 | learning_rate=lr, epsilon=epsilon) 160 | elif optimizer == 'rmsprop': 161 | optimizer = tf.train.RMSPropOptimizer( 162 | learning_rate=lr, decay=rmsprop_decay, momentum=rmsprop_momentum, epsilon=epsilon) 163 | else: 164 | raise NotImplementedError(optimizer) 165 | 166 | if tpu: 167 | optimizer = tf.tpu.CrossShardOptimizer(optimizer) 168 | 169 | # compute gradient 170 | grads_and_vars = optimizer.compute_gradients(loss, var_list=trainable_variables) 171 | 172 | # clip gradient 173 | clipped_grads, gnorm = tf.clip_by_global_norm([g for (g, _) in grads_and_vars], grad_clip) 174 | grads_and_vars = [(g, v) for g, (_, v) in zip(clipped_grads, grads_and_vars)] 175 | 176 | # train 177 | train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step) 178 | return train_op, gnorm 179 | 180 | 181 | @contextlib.contextmanager 182 | def ema_scope(orig_model_ema): 183 | def _ema_getter(getter, name, *args, **kwargs): 184 | v = getter(name, *args, **kwargs) 185 | v = orig_model_ema.average(v) 186 | if 
v is None: 187 | raise RuntimeError('Variable {} has no EMA counterpart'.format(name)) 188 | return v 189 | 190 | with tf.variable_scope(tf.get_variable_scope(), custom_getter=_ema_getter, reuse=True): 191 | with tf.name_scope('ema_scope'): 192 | yield 193 | 194 | 195 | def get_gcp_region(): 196 | # https://stackoverflow.com/a/31689692 197 | import requests 198 | metadata_server = "http://metadata/computeMetadata/v1/instance/" 199 | metadata_flavor = {'Metadata-Flavor': 'Google'} 200 | zone = requests.get(metadata_server + 'zone', headers=metadata_flavor).text 201 | zone = zone.split('/')[-1] 202 | region = '-'.join(zone.split('-')[:-1]) 203 | return region 204 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | appdirs==1.4.3 3 | astor==0.8.1 4 | attrs==19.3.0 5 | backcall==0.1.0 6 | bleach==3.1.0 7 | cachetools==4.0.0 8 | certifi==2019.11.28 9 | chardet==3.0.4 10 | cloud-tpu-client==0.5 11 | cloud-tpu-profiler==1.15.0rc1 12 | cloudpickle==1.1.1 13 | cycler==0.10.0 14 | decorator==4.4.1 15 | defusedxml==0.6.0 16 | dill==0.3.1.1 17 | distlib==0.3.0 18 | distro==1.0.1 19 | entrypoints==0.3 20 | filelock==3.0.12 21 | fire==0.2.1 22 | future==0.18.2 23 | gast==0.2.2 24 | google-api-python-client==1.7.11 25 | google-auth==1.11.2 26 | google-auth-httplib2==0.0.3 27 | google-compute-engine==20191210.0 28 | google-pasta==0.1.8 29 | googleapis-common-protos==1.51.0 30 | grpcio==1.27.2 31 | h5py==2.10.0 32 | httplib2==0.17.0 33 | idna==2.8 34 | importlib-metadata==1.5.0 35 | importlib-resources==1.0.2 36 | ipykernel==5.1.4 37 | ipython==7.9.0 38 | ipython-genutils==0.2.0 39 | ipywidgets==7.5.1 40 | jedi==0.16.0 41 | Jinja2==2.11.1 42 | jsonschema==3.2.0 43 | jupyter==1.0.0 44 | jupyter-client==5.3.4 45 | jupyter-console==6.1.0 46 | jupyter-core==4.6.3 47 | Keras-Applications==1.0.8 48 | Keras-Preprocessing==1.1.0 49 | 
kiwisolver==1.1.0 50 | Markdown==3.2.1 51 | MarkupSafe==1.1.1 52 | matplotlib==3.0.3 53 | mistune==0.8.4 54 | nbconvert==5.6.1 55 | nbformat==5.0.4 56 | notebook==6.0.3 57 | numpy==1.18.1 58 | oauth2client==4.1.3 59 | opt-einsum==3.1.0 60 | pandas==0.25.3 61 | pandocfilters==1.4.2 62 | parso==0.6.1 63 | pexpect==4.8.0 64 | pickleshare==0.7.5 65 | Pillow==7.0.0 66 | prometheus-client==0.7.1 67 | promise==2.3 68 | prompt-toolkit==2.0.10 69 | protobuf==3.11.3 70 | psutil==5.7.0 71 | ptyprocess==0.6.0 72 | pyasn1==0.4.8 73 | pyasn1-modules==0.2.8 74 | pycurl==7.43.0 75 | Pygments==2.5.2 76 | pygobject==3.22.0 77 | pyparsing==2.4.6 78 | pyrsistent==0.15.7 79 | python-apt==1.4.1 80 | python-dateutil==2.8.1 81 | pytz==2019.3 82 | PyYAML==5.3 83 | pyzmq==18.1.1 84 | qtconsole==4.6.0 85 | requests==2.22.0 86 | rsa==4.0 87 | scipy==1.4.1 88 | seaborn==0.9.1 89 | Send2Trash==1.5.0 90 | six==1.14.0 91 | tensorboard==1.15.0 92 | tensorflow==1.15.0 93 | tensorflow-datasets==2.1.0 94 | tensorflow-estimator==1.15.1 95 | tensorflow-gan==0.0.0.dev0 96 | tensorflow-hub==0.7.0 97 | tensorflow-metadata==0.21.1 98 | tensorflow-probability==0.8.0 99 | tensorflow-serving-api==1.14.0 100 | termcolor==1.1.0 101 | terminado==0.8.3 102 | testpath==0.4.4 103 | tornado==6.0.3 104 | tqdm==4.42.1 105 | traitlets==4.3.3 106 | unattended-upgrades==0.1 107 | uritemplate==3.0.1 108 | urllib3==1.25.7 109 | virtualenv==20.0.4 110 | wcwidth==0.1.8 111 | webencodings==0.5.1 112 | Werkzeug==1.0.0 113 | widgetsnbextension==3.5.1 114 | wrapt==1.12.0 115 | zipp==1.2.0 116 | -------------------------------------------------------------------------------- /resources/samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hojonathanho/diffusion/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/resources/samples.png -------------------------------------------------------------------------------- /scripts/__init__.py: 
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/hojonathanho/diffusion/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/scripts/__init__.py
# --------------------------------------------------------------------------------
# /scripts/run_celebahq.py:
# --------------------------------------------------------------------------------
"""
CelebaHQ 256x256

python3 scripts/run_celebahq.py train --bucket_name_prefix $BUCKET_PREFIX --exp_name $EXPERIMENT_NAME --tpu_name $TPU_NAME
python3 scripts/run_celebahq.py evaluation --bucket_name_prefix $BUCKET_PREFIX --tpu_name $EVAL_TPU_NAME --model_dir $MODEL_DIR
"""

import functools

import fire
import numpy as np
import tensorflow.compat.v1 as tf

from diffusion_tf import utils
from diffusion_tf.diffusion_utils import get_beta_schedule, GaussianDiffusion
from diffusion_tf.models import unet
from diffusion_tf.tpu_utils import tpu_utils, datasets


class Model(tpu_utils.Model):
    """CelebA-HQ diffusion model: a UNet denoiser driven by ``GaussianDiffusion``.

    Args:
        model_name: architecture key; only ``'unet2d16b2c112244'`` is implemented.
        betas: noise schedule handed to ``GaussianDiffusion``.
        loss_type: loss identifier handed to ``GaussianDiffusion``.
        num_classes: forwarded to ``unet.model``.
        dropout: dropout rate used by ``train_fn`` (sampling paths always use 0).
        randflip: if truthy, ``train_fn`` randomly flips images left/right.
        block_size: ``space_to_depth`` factor applied around the network; 1 disables it.
    """

    def __init__(self, *, model_name, betas: np.ndarray, loss_type: str, num_classes: int,
                 dropout: float, randflip, block_size: int):
        self.model_name = model_name
        self.diffusion = GaussianDiffusion(betas=betas, loss_type=loss_type)
        self.num_classes = num_classes
        self.dropout = dropout
        self.randflip = randflip
        self.block_size = block_size

    def _denoise(self, x, t, y, dropout):
        """Run the denoising network on images ``x`` at timesteps ``t``.

        Returns a tensor with the same static shape as ``x``. Labels ``y`` are
        shape/dtype-checked but then discarded: the network is called with ``y=None``.
        """
        B, H, W, C = x.shape.as_list()
        assert x.dtype == tf.float32
        assert t.shape == [B] and t.dtype in [tf.int32, tf.int64]
        assert y.shape == [B] and y.dtype in [tf.int32, tf.int64]
        orig_out_ch = out_ch = C

        if self.block_size != 1:  # this can be used to reduce memory consumption
            # Fold spatial blocks into channels; the network must then emit
            # block_size**2 times as many channels so depth_to_space can undo it.
            x = tf.nn.space_to_depth(x, self.block_size)
            out_ch *= self.block_size ** 2

        y = None  # labels are intentionally not passed to the network
        if self.model_name == 'unet2d16b2c112244':  # 114M for block_size=1
            out = unet.model(
                x, t=t, y=y, name='model', ch=128, ch_mult=(1, 1, 2, 2, 4, 4), num_res_blocks=2, attn_resolutions=(16,),
                out_ch=out_ch, num_classes=self.num_classes, dropout=dropout
            )
        else:
            raise NotImplementedError(self.model_name)

        if self.block_size != 1:
            out = tf.nn.depth_to_space(out, self.block_size)
        assert out.shape == [B, H, W, orig_out_ch]
        return out

    def train_fn(self, x, y):
        """Return ``{'loss': mean training loss}`` for a batch, with timesteps drawn uniformly."""
        B, H, W, C = x.shape
        if self.randflip:
            x = tf.image.random_flip_left_right(x)
            assert x.shape == [B, H, W, C]
        t = tf.random_uniform([B], 0, self.diffusion.num_timesteps, dtype=tf.int32)
        losses = self.diffusion.p_losses(
            denoise_fn=functools.partial(self._denoise, y=y, dropout=self.dropout), x_start=x, t=t)
        assert losses.shape == t.shape == [B]
        return {'loss': tf.reduce_mean(losses)}

    def samples_fn(self, dummy_noise, y):
        """Sample via the reverse process; ``dummy_noise`` only supplies the output shape."""
        return {
            'samples': self.diffusion.p_sample_loop(
                denoise_fn=functools.partial(self._denoise, y=y, dropout=0),
                shape=dummy_noise.shape.as_list(),
                noise_fn=tf.random_normal
            )
        }

    def samples_fn_denoising_trajectory(self, dummy_noise, y, repeat_noise_steps=0):
        """Like ``samples_fn`` but also return the intermediate images and their timesteps.

        ``samples`` is the last image of the returned trajectory.
        """
        times, imgs = self.diffusion.p_sample_loop_trajectory(
            denoise_fn=functools.partial(self._denoise, y=y, dropout=0),
            shape=dummy_noise.shape.as_list(),
            noise_fn=tf.random_normal,
            repeat_noise_steps=repeat_noise_steps
        )
        return {
            'samples': imgs[-1],
            'denoising_trajectory_times': times,
            'denoising_trajectory_images': imgs
        }

    def interpolate_fn(self, dummy_noise, y):
        """Return the outputs of ``diffusion.interpolate``; ``dummy_noise`` only supplies the shape."""
        x1, x2, lam, x_interp, t = self.diffusion.interpolate(
            denoise_fn=functools.partial(self._denoise, y=y, dropout=0),
            shape=dummy_noise.shape.as_list(),
            noise_fn=tf.random_normal,
        )
        return {
            'x1': x1,  # placeholder
            'x2': x2,  # placeholder
            'lam': lam,  # placeholder
            't': t,  # placeholder
            'x_interp': x_interp
        }


def evaluation(
        model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
        tfds_data_dir='tensorflow_datasets',
):
    """Run ``tpu_utils.EvalWorker`` on ``model_dir``, rebuilding the model from its saved train kwargs.

    ``tfds_data_dir`` is resolved to ``gs://{bucket_name_prefix}-{region}/{tfds_data_dir}``
    with the region from ``utils.get_gcp_region()``.
    """
    region = utils.get_gcp_region()
    tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
    kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', kwargs)
    ds = datasets.get_dataset(kwargs['dataset'], tfds_data_dir=tfds_data_dir)
    worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=lambda: Model(
            model_name=kwargs['model_name'],
            betas=get_beta_schedule(
                kwargs['beta_schedule'], beta_start=kwargs['beta_start'], beta_end=kwargs['beta_end'],
                num_diffusion_timesteps=kwargs['num_diffusion_timesteps']
            ),
            loss_type=kwargs['loss_type'],
            num_classes=ds.num_classes,
            dropout=kwargs['dropout'],
            randflip=kwargs['randflip'],
            block_size=kwargs['block_size']
        ),
        total_bs=total_bs, inception_bs=total_bs, num_inception_samples=2048,
        dataset=ds,
    )
    worker.run(logdir=model_dir, once=once, skip_non_ema_pass=True, dump_samples_only=dump_samples_only)


def train(
        exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='celebahq256',
        optimizer='adam', total_bs=64, grad_clip=1., lr=0.00002, warmup=5000,
        num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', loss_type='noisepred',
        dropout=0.0, randflip=1, block_size=1,
        tfds_data_dir='tensorflow_datasets', log_dir='logs'
):
    """Train the CelebA-HQ model on a TPU.

    ``tfds_data_dir`` and ``log_dir`` are resolved to ``gs://{bucket_name_prefix}-{region}/...``
    with the region from ``utils.get_gcp_region()``; the experiment name encodes the hyperparameters.
    """
    region = utils.get_gcp_region()
    tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
    log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
    # Snapshot of all arguments (plus region and resolved paths) used for the
    # experiment name and passed as dump_kwargs; must precede any other local.
    kwargs = dict(locals())
    ds = datasets.get_dataset(dataset, tfds_data_dir=tfds_data_dir)
    tpu_utils.run_training(
        date_str='9999-99-99',
        exp_name='{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}_{loss_type}_dropout{dropout}_randflip{randflip}_blk{block_size}'.format(
            **kwargs),
        model_constructor=lambda: Model(
            model_name=model_name,
            betas=get_beta_schedule(
                beta_schedule, beta_start=beta_start, beta_end=beta_end, num_diffusion_timesteps=num_diffusion_timesteps
            ),
            loss_type=loss_type,
            num_classes=ds.num_classes,
            dropout=dropout,
            randflip=randflip,
            block_size=block_size
        ),
        optimizer=optimizer, total_bs=total_bs, lr=lr, warmup=warmup, grad_clip=grad_clip,
        train_input_fn=ds.train_input_fn,
        tpu=tpu_name, log_dir=log_dir, dump_kwargs=kwargs
    )


if __name__ == '__main__':
    fire.Fire()
# --------------------------------------------------------------------------------
# /scripts/run_cifar.py:
# --------------------------------------------------------------------------------
"""
Unconditional CIFAR10

python3 scripts/run_cifar.py train --bucket_name_prefix $BUCKET_PREFIX --exp_name $EXPERIMENT_NAME --tpu_name $TPU_NAME
python3 scripts/run_cifar.py evaluation --bucket_name_prefix $BUCKET_PREFIX --tpu_name $EVAL_TPU_NAME --model_dir $MODEL_DIR
"""

import functools

import fire
import numpy as np
import tensorflow.compat.v1 as tf

from diffusion_tf import utils
from diffusion_tf.diffusion_utils_2 import get_beta_schedule, GaussianDiffusion2
from diffusion_tf.models import unet
from diffusion_tf.tpu_utils import tpu_utils, datasets, simple_eval_worker


class Model(tpu_utils.Model):
    """CIFAR-10 diffusion model: a UNet denoiser driven by ``GaussianDiffusion2``.

    Args:
        model_name: architecture key; only ``'unet2d16b2'`` is implemented.
        betas: noise schedule handed to ``GaussianDiffusion2``.
        model_mean_type, model_var_type, loss_type: handed to ``GaussianDiffusion2``.
        num_classes: forwarded to ``unet.model``.
        dropout: dropout rate used by ``train_fn`` (evaluation paths always use 0).
        randflip: if truthy, ``train_fn`` randomly flips images left/right.
    """

    def __init__(self, *, model_name, betas: np.ndarray, model_mean_type: str, model_var_type: str, loss_type: str,
                 num_classes: int, dropout: float, randflip):
        self.model_name = model_name
        self.diffusion = GaussianDiffusion2(
            betas=betas, model_mean_type=model_mean_type, model_var_type=model_var_type, loss_type=loss_type)
        self.num_classes = num_classes
        self.dropout = dropout
        self.randflip = randflip

    def _denoise(self, x, t, y, dropout):
        """Run the UNet on images ``x`` at timesteps ``t``.

        The network emits ``2 * C`` channels when the diffusion's ``model_var_type`` is
        ``'learned'`` and ``C`` otherwise. Labels ``y`` are checked, then discarded (``y=None``).
        """
        B, H, W, C = x.shape.as_list()
        assert x.dtype == tf.float32
        assert t.shape == [B] and t.dtype in [tf.int32, tf.int64]
        assert y.shape == [B] and y.dtype in [tf.int32, tf.int64]
        out_ch = (C * 2) if self.diffusion.model_var_type == 'learned' else C
        y = None  # unconditional model: labels are not passed to the network
        if self.model_name == 'unet2d16b2':  # 35.7M
            return unet.model(
                x, t=t, y=y, name='model', ch=128, ch_mult=(1, 2, 2, 2), num_res_blocks=2, attn_resolutions=(16,),
                out_ch=out_ch, num_classes=self.num_classes, dropout=dropout
            )
        raise NotImplementedError(self.model_name)

    def train_fn(self, x, y):
        """Return ``{'loss': mean training loss}`` for a batch, with timesteps drawn uniformly."""
        B, H, W, C = x.shape
        if self.randflip:
            x = tf.image.random_flip_left_right(x)
            assert x.shape == [B, H, W, C]
        t = tf.random_uniform([B], 0, self.diffusion.num_timesteps, dtype=tf.int32)
        losses = self.diffusion.training_losses(
            denoise_fn=functools.partial(self._denoise, y=y, dropout=self.dropout), x_start=x, t=t)
        assert losses.shape == t.shape == [B]
        return {'loss': tf.reduce_mean(losses)}

    def samples_fn(self, dummy_noise, y):
        """Sample via the reverse process; ``dummy_noise`` only supplies the output shape."""
        return {
            'samples': self.diffusion.p_sample_loop(
                denoise_fn=functools.partial(self._denoise, y=y, dropout=0),
                shape=dummy_noise.shape.as_list(),
                noise_fn=tf.random_normal
            )
        }

    def progressive_samples_fn(self, dummy_noise, y):
        """Return final samples plus the progressive intermediate outputs of the reverse process."""
        samples, progressive_samples = self.diffusion.p_sample_loop_progressive(
            denoise_fn=functools.partial(self._denoise, y=y, dropout=0),
            shape=dummy_noise.shape.as_list(),
            noise_fn=tf.random_normal
        )
        return {'samples': samples, 'progressive_samples': progressive_samples}

    def bpd_fn(self, x, y):
        """Return the bits-per-dim terms computed by ``diffusion.calc_bpd_loop`` for images ``x``."""
        total_bpd_b, terms_bpd_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(
            denoise_fn=functools.partial(self._denoise, y=y, dropout=0),
            x_start=x
        )
        return {
            'total_bpd': total_bpd_b,
            'terms_bpd': terms_bpd_bt,
            'prior_bpd': prior_bpd_b,
            'mse': mse_bt
        }


def _load_model(kwargs, ds):
    """Rebuild a ``Model`` from saved training kwargs and the dataset's class count."""
    return Model(
        model_name=kwargs['model_name'],
        betas=get_beta_schedule(
            kwargs['beta_schedule'], beta_start=kwargs['beta_start'], beta_end=kwargs['beta_end'],
            num_diffusion_timesteps=kwargs['num_diffusion_timesteps']
        ),
        model_mean_type=kwargs['model_mean_type'],
        model_var_type=kwargs['model_var_type'],
        loss_type=kwargs['loss_type'],
        num_classes=ds.num_classes,
        dropout=kwargs['dropout'],
        randflip=kwargs['randflip']
    )


def simple_eval(model_dir, tpu_name, bucket_name_prefix, mode, load_ckpt, total_bs=256, tfds_data_dir='tensorflow_datasets'):
    """Run ``simple_eval_worker.SimpleEvalWorker`` in ``mode`` on checkpoint ``load_ckpt`` of ``model_dir``."""
    region = utils.get_gcp_region()
    tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
    kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', kwargs)
    ds = datasets.get_dataset(kwargs['dataset'], tfds_data_dir=tfds_data_dir)
    worker = simple_eval_worker.SimpleEvalWorker(
        tpu_name=tpu_name, model_constructor=functools.partial(_load_model, kwargs=kwargs, ds=ds),
        total_bs=total_bs, dataset=ds)
    worker.run(mode=mode, logdir=model_dir, load_ckpt=load_ckpt)


def evaluation(  # evaluation loop for use during training
        model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=256,
        tfds_data_dir='tensorflow_datasets', load_ckpt=None
):
    """Run ``tpu_utils.EvalWorker`` on ``model_dir`` (50000 Inception samples)."""
    region = utils.get_gcp_region()
    tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
    kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', kwargs)
    ds = datasets.get_dataset(kwargs['dataset'], tfds_data_dir=tfds_data_dir)
    worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=functools.partial(_load_model, kwargs=kwargs, ds=ds),
        total_bs=total_bs, inception_bs=total_bs, num_inception_samples=50000,
        dataset=ds,
    )
    worker.run(
        logdir=model_dir, once=once, skip_non_ema_pass=True, dump_samples_only=dump_samples_only, load_ckpt=load_ckpt)


def train(
        exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2', dataset='cifar10',
        optimizer='adam', total_bs=128, grad_clip=1., lr=2e-4, warmup=5000,
        num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear',
        model_mean_type='eps', model_var_type='fixedlarge', loss_type='mse',
        dropout=0.1, randflip=1,
        tfds_data_dir='tensorflow_datasets', log_dir='logs', keep_checkpoint_max=2
):
    """Train the CIFAR-10 model on a TPU.

    ``tfds_data_dir`` and ``log_dir`` are resolved to ``gs://{bucket_name_prefix}-{region}/...``
    with the region from ``utils.get_gcp_region()``; the experiment name encodes the hyperparameters.
    """
    region = utils.get_gcp_region()
    tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
    log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
    # Snapshot of all arguments (plus region and resolved paths) used for the
    # experiment name and passed as dump_kwargs; must precede any other local.
    kwargs = dict(locals())
    ds = datasets.get_dataset(dataset, tfds_data_dir=tfds_data_dir)
    tpu_utils.run_training(
        date_str='9999-99-99',
        exp_name='{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}_{model_mean_type}-{model_var_type}-{loss_type}_dropout{dropout}_randflip{randflip}'.format(
            **kwargs),
        model_constructor=lambda: Model(
            model_name=model_name,
            betas=get_beta_schedule(
                beta_schedule, beta_start=beta_start, beta_end=beta_end, num_diffusion_timesteps=num_diffusion_timesteps
            ),
            model_mean_type=model_mean_type,
            model_var_type=model_var_type,
            loss_type=loss_type,
            num_classes=ds.num_classes,
            dropout=dropout,
            randflip=randflip
        ),
        optimizer=optimizer, total_bs=total_bs, lr=lr, warmup=warmup, grad_clip=grad_clip,
        train_input_fn=ds.train_input_fn,
        tpu=tpu_name, log_dir=log_dir, dump_kwargs=kwargs, iterations_per_loop=2000, keep_checkpoint_max=keep_checkpoint_max
    )


if __name__ == '__main__':
    fire.Fire()
# --------------------------------------------------------------------------------
# /scripts/run_lsun.py:
# --------------------------------------------------------------------------------
"""
LSUN church, bedroom and cat 256x256

# LSUN church
python3 scripts/run_lsun.py train --bucket_name_prefix $BUCKET_PREFIX --exp_name $EXPERIMENT_NAME --tpu_name $TPU_NAME --tfr_file 'tensorflow_datasets/lsun/church/church-r08.tfrecords'
python3 scripts/run_lsun.py evaluation --bucket_name_prefix $BUCKET_PREFIX --tpu_name $EVAL_TPU_NAME --model_dir $MODEL_DIR --tfr_file 'tensorflow_datasets/lsun/church/church-r08.tfrecords'

# LSUN bedroom
python3 scripts/run_lsun.py train --bucket_name_prefix $BUCKET_PREFIX --exp_name $EXPERIMENT_NAME --tpu_name $TPU_NAME --tfr_file 'tensorflow_datasets/lsun/bedroom-full/bedroom-full-r08.tfrecords'
python3 scripts/run_lsun.py evaluation --bucket_name_prefix $BUCKET_PREFIX --tpu_name $EVAL_TPU_NAME --model_dir $MODEL_DIR --tfr_file 'tensorflow_datasets/lsun/bedroom-full/bedroom-full-r08.tfrecords'

# LSUN cat
python3 scripts/run_lsun.py train --bucket_name_prefix $BUCKET_PREFIX --exp_name $EXPERIMENT_NAME --tpu_name $TPU_NAME --tfr_file 'tensorflow_datasets/lsun/cat/cat-r08.tfrecords' --randflip 0
python3 scripts/run_lsun.py evaluation --bucket_name_prefix $BUCKET_PREFIX --tpu_name $EVAL_TPU_NAME --model_dir $MODEL_DIR --tfr_file 'tensorflow_datasets/lsun/cat/cat-r08.tfrecords'
"""

import functools

import fire
import numpy as np
import tensorflow.compat.v1 as tf

from diffusion_tf import utils
from diffusion_tf.diffusion_utils import get_beta_schedule, GaussianDiffusion
from diffusion_tf.models import unet
from diffusion_tf.tpu_utils import tpu_utils, datasets


class Model(tpu_utils.Model):
    """LSUN diffusion model: a UNet denoiser driven by ``GaussianDiffusion``.

    Args:
        model_name: architecture key; only ``'unet2d16b2c112244'`` is implemented.
        betas: noise schedule handed to ``GaussianDiffusion``.
        loss_type: loss identifier handed to ``GaussianDiffusion``.
        num_classes: forwarded to ``unet.model``.
        dropout: dropout rate used by ``train_fn`` (sampling always uses 0).
        randflip: if truthy, ``train_fn`` randomly flips images left/right.
        block_size: ``space_to_depth`` factor applied around the network; 1 disables it.
    """

    def __init__(self, *, model_name, betas: np.ndarray, loss_type: str, num_classes: int,
                 dropout: float, randflip, block_size: int):
        self.model_name = model_name
        self.diffusion = GaussianDiffusion(betas=betas, loss_type=loss_type)
        self.num_classes = num_classes
        self.dropout = dropout
        self.randflip = randflip
        self.block_size = block_size

    def _denoise(self, x, t, y, dropout):
        """Run the denoising network on images ``x`` at timesteps ``t``.

        Returns a tensor with the same static shape as ``x``. Labels ``y`` are
        shape/dtype-checked but then discarded: the network is called with ``y=None``.
        """
        B, H, W, C = x.shape.as_list()
        assert x.dtype == tf.float32
        assert t.shape == [B] and t.dtype in [tf.int32, tf.int64]
        assert y.shape == [B] and y.dtype in [tf.int32, tf.int64]
        orig_out_ch = out_ch = C

        if self.block_size != 1:
            # Fold spatial blocks into channels; undone by depth_to_space below.
            x = tf.nn.space_to_depth(x, self.block_size)
            out_ch *= self.block_size ** 2

        y = None  # labels are intentionally not passed to the network
        if self.model_name == 'unet2d16b2c112244':  # 114M for block_size=1
            out = unet.model(
                x, t=t, y=y, name='model', ch=128, ch_mult=(1, 1, 2, 2, 4, 4), num_res_blocks=2, attn_resolutions=(16,),
                out_ch=out_ch, num_classes=self.num_classes, dropout=dropout
            )
        else:
            raise NotImplementedError(self.model_name)

        if self.block_size != 1:
            out = tf.nn.depth_to_space(out, self.block_size)
        assert out.shape == [B, H, W, orig_out_ch]
        return out

    def train_fn(self, x, y):
        """Return ``{'loss': mean training loss}`` for a batch, with timesteps drawn uniformly."""
        B, H, W, C = x.shape
        if self.randflip:
            x = tf.image.random_flip_left_right(x)
            assert x.shape == [B, H, W, C]
        t = tf.random_uniform([B], 0, self.diffusion.num_timesteps, dtype=tf.int32)
        losses = self.diffusion.p_losses(
            denoise_fn=functools.partial(self._denoise, y=y, dropout=self.dropout), x_start=x, t=t)
        assert losses.shape == t.shape == [B]
        return {'loss': tf.reduce_mean(losses)}

    def samples_fn(self, dummy_noise, y):
        """Sample via the reverse process; ``dummy_noise`` only supplies the output shape."""
        return {
            'samples': self.diffusion.p_sample_loop(
                denoise_fn=functools.partial(self._denoise, y=y, dropout=0),
                shape=dummy_noise.shape.as_list(),
                noise_fn=tf.random_normal
            )
        }


def evaluation(
        model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
        tfr_file='tensorflow_datasets/lsun/church/church-r08.tfrecords', samples_dir=None, num_inception_samples=2048,
):
    """Run ``tpu_utils.EvalWorker`` on ``model_dir``, rebuilding the model from its saved train kwargs.

    ``tfr_file`` is resolved to ``gs://{bucket_name_prefix}-{region}/{tfr_file}``; its default
    matches the default training file used by ``train``.
    """
    region = utils.get_gcp_region()
    tfr_file = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfr_file)
    kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', kwargs)
    ds = datasets.get_dataset(kwargs['dataset'], tfr_file=tfr_file)
    worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=lambda: Model(
            model_name=kwargs['model_name'],
            betas=get_beta_schedule(
                kwargs['beta_schedule'], beta_start=kwargs['beta_start'], beta_end=kwargs['beta_end'],
                num_diffusion_timesteps=kwargs['num_diffusion_timesteps']
            ),
            loss_type=kwargs['loss_type'],
            num_classes=ds.num_classes,
            dropout=kwargs['dropout'],
            randflip=kwargs['randflip'],
            block_size=kwargs['block_size']
        ),
        total_bs=total_bs, inception_bs=total_bs, num_inception_samples=num_inception_samples,
        dataset=ds,
        limit_dataset_size=30000  # limit size of dataset for computing Inception features, for memory reasons
    )
    worker.run(logdir=model_dir, once=once, skip_non_ema_pass=True, dump_samples_only=dump_samples_only,
               samples_dir=samples_dir)


def train(
        exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='lsun',
        optimizer='adam', total_bs=64, grad_clip=1., lr=2e-5, warmup=5000,
        num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', loss_type='noisepred',
        dropout=0.0, randflip=1, block_size=1,
        tfr_file='tensorflow_datasets/lsun/church/church-r08.tfrecords', log_dir='logs',
        warm_start_model_dir=None
):
    """Train the LSUN model on a TPU.

    ``tfr_file`` and ``log_dir`` are resolved to ``gs://{bucket_name_prefix}-{region}/...``.
    If ``warm_start_model_dir`` is given, every variable (``".*"``) is warm-started from the
    latest checkpoint in that directory.

    Raises:
        ValueError: if ``warm_start_model_dir`` is given but contains no checkpoint.
    """
    region = utils.get_gcp_region()
    tfr_file = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfr_file)
    log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
    print("tfr_file:", tfr_file)
    print("log_dir:", log_dir)
    # Snapshot of all arguments (plus region and resolved paths) used for the
    # experiment name and passed as dump_kwargs; must precede any other local.
    kwargs = dict(locals())

    warm_start_from = None
    if warm_start_model_dir:
        # latest_checkpoint returns None for a directory without checkpoints;
        # fail with a clear message instead of handing None to WarmStartSettings.
        ckpt = tf.train.latest_checkpoint(warm_start_model_dir)
        if ckpt is None:
            raise ValueError('No checkpoint found in warm_start_model_dir: {}'.format(warm_start_model_dir))
        warm_start_from = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=ckpt,
            vars_to_warm_start=[".*"]
        )

    ds = datasets.get_dataset(dataset, tfr_file=tfr_file)
    tpu_utils.run_training(
        date_str='9999-99-99',
        exp_name='{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}_{loss_type}_dropout{dropout}_randflip{randflip}_blk{block_size}'.format(
            **kwargs),
        model_constructor=lambda: Model(
            model_name=model_name,
            betas=get_beta_schedule(
                beta_schedule, beta_start=beta_start, beta_end=beta_end, num_diffusion_timesteps=num_diffusion_timesteps
            ),
            loss_type=loss_type,
            num_classes=ds.num_classes,
            dropout=dropout,
            randflip=randflip,
            block_size=block_size
        ),
        optimizer=optimizer, total_bs=total_bs, lr=lr, warmup=warmup, grad_clip=grad_clip,
        train_input_fn=ds.train_input_fn,
        tpu=tpu_name, log_dir=log_dir, dump_kwargs=kwargs,
        warm_start_from=warm_start_from
    )


if __name__ == '__main__':
    fire.Fire()
# --------------------------------------------------------------------------------