├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── cifar10.msgpack ├── consistency_models │   ├── __init__.py │   ├── consistency.py │   ├── model.py │   └── utils.py ├── jax_local_cluster.py ├── mnist.msgpack ├── requirements.txt └── train.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.msgpack filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv* 2 | __pycache__ 3 | .ipynb_checkpoints 4 | *.egg-info 5 | .vscode 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 Katherine Crowson 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # consistency-models 2 | 3 | `consistency-models` is a [JAX](https://jax.readthedocs.io/en/latest/) implementation of the continuous-time formulation of [Consistency Models](https://arxiv.org/abs/2303.01469), which allows distillation of a [diffusion model](https://arxiv.org/abs/2006.11239) into a single-step generative model. 4 | 5 | **This code is a WORK IN PROGRESS: it is not done and does not yet produce high-quality results. I am releasing it due to general interest in consistency model implementations.** 6 | 7 | ## Requirements 8 | 9 | ```bash 10 | pip install git+https://github.com/crowsonkb/jax-wavelets 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | ## Notes 15 | 16 | `train.py` trains a diffusion model and a consistency model at the same time, and uses L_CD to continuously distill the EMA diffusion model into the consistency model. The consistency model is then used to generate samples in one step. This seems to work better than training the consistency model directly with L_CT.
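Sampling is a single forward pass: draw x_T from N(0, t_max^2 I) and evaluate the consistency model once at t = t_max, which is what the `sample` function in `train.py` does. A minimal sketch of that step, assuming a trained model and its variables are already in scope (the `model` and `variables` names, the image shape, and the `t_max` value are placeholders, not a fixed API):

```python
import jax
import jax.numpy as jnp

# Hypothetical objects: `model` is the consistency TransformerModel and
# `variables` holds its EMA parameters plus the precomputed wavelet kernels.
def sample_one_step(model, variables, key, n, size=(28, 28), ch=1, t_max=160.0):
    xt = jax.random.normal(key, (n, *size, ch)) * t_max  # x_T ~ N(0, t_max^2 I)
    t = jnp.full((n,), t_max)
    return model.apply(variables, xt, t)  # one evaluation gives the x_0 estimate
```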
17 | -------------------------------------------------------------------------------- /cifar10.msgpack: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:65a6942e659f584f486f7e6fac9c0d3479d3e438badd8337dbc81933fd657a66 3 | size 184560111 4 | -------------------------------------------------------------------------------- /consistency_models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import model, utils 2 | from .consistency import ( 3 | l2_metric, 4 | uniform_weight, 5 | cosine_weight, 6 | consistency_loss, 7 | score_matching_loss, 8 | ) 9 | -------------------------------------------------------------------------------- /consistency_models/consistency.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.tree_util import Partial 4 | 5 | from .utils import rb 6 | 7 | 8 | def l2_metric(x, y): 9 | return jnp.sum((x - y) ** 2) 10 | 11 | 12 | # TODO: Rescale based on tmin and tmax 13 | def uniform_weight(t): 14 | return jnp.ones_like(t) 15 | 16 | 17 | # TODO: Rescale based on tmin and tmax 18 | def cosine_weight(t): 19 | return (t**2 + 1) * jnp.pi / 2 20 | 21 | 22 | batchmean = jax.vmap(jnp.mean) 23 | 24 | 25 | def hvp(fun, x, v): 26 | return jax.jvp(jax.grad(fun), (x,), (v,))[1] 27 | 28 | 29 | def consistency_loss( 30 | params, 31 | x0, 32 | t, 33 | noise, 34 | model_fun, 35 | weight_fun=uniform_weight, 36 | metric_fun=l2_metric, 37 | teacher_fun=None, 38 | stopgrad=False, 39 | ): 40 | xt = x0 + noise * rb(t, x0) 41 | if teacher_fun is None: 42 | eps = noise 43 | else: 44 | eps = (xt - teacher_fun(xt, t)) / rb(t, x0) 45 | out, out_jvp = jax.jvp(Partial(model_fun, params), (xt, t), (eps, jnp.ones_like(t))) 46 | if stopgrad: 47 | out_hvp = jax.lax.stop_gradient(hvp(lambda y: metric_fun(y, out), out, out_jvp)) 48 | return jnp.mean(weight_fun(t) * batchmean(out * out_hvp)) 49 | out_hvp = hvp(lambda y: metric_fun(out, y), out, out_jvp) 50 | return jnp.mean(weight_fun(t) * batchmean(out_jvp * out_hvp)) / 2 51 | 52 | 53 | def score_matching_loss(params, x0, t, noise, model_fun, weight_fun=uniform_weight): 54 | xt = x0 + noise * rb(t, x0) 55 | eps = (xt - model_fun(params, xt, t)) / rb(t, x0) 56 | return jnp.mean(weight_fun(t) * batchmean((eps - noise) ** 2)) 57 | -------------------------------------------------------------------------------- /consistency_models/model.py: -------------------------------------------------------------------------------- 1 | from einshape import jax_einshape as einshape 2 | import flax 3 | import flax.linen as nn 4 | import jax 5 | import jax.numpy as jnp 6 | import jax_wavelets as jw 7 | 8 | from .utils import rb 9 | 10 | 11 | class FourierFeatures(nn.Module): 12 | features: int 13 | std: float = 1.0 14 | 15 | @nn.compact 16 | def __call__(self, x): 17 | assert self.features % 2 == 0 18 | kernel = self.param( 19 | "kernel", 20 | nn.initializers.normal(self.std), 21 | (x.shape[-1], self.features // 2), 22 | ) 23 | x = 2 * jnp.pi * x @ kernel 24 | return jnp.concatenate([jnp.cos(x), jnp.sin(x)], axis=-1) 25 | 26 | 27 | class TransformerBlock(nn.Module): 28 | dropout_rate: float = 0.0 29 | dtype: jnp.dtype = jnp.float32 30 | 31 | @nn.compact 32 | def __call__(self, x, z, deterministic=True): 33 | init = nn.initializers.variance_scaling(1.0, "fan_in", "normal") 34 | init_out = nn.initializers.zeros 35 | 36 | # Self attention 37 | 
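        # The attention below is written out by hand: separate q/k/v projections,
        # a fixed head size of 64, LayerNorm applied to the queries and keys
        # ("QK norm"), dropout applied to the attention logits before the softmax,
        # and a zero-initialized output projection feeding a residual connection.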
x_skip = x 38 | x = nn.LayerNorm(name="self_attn_norm")(x) 39 | q = nn.Dense(x.shape[-1], kernel_init=init, dtype=self.dtype, name="query")(x) 40 | k = nn.Dense(x.shape[-1], kernel_init=init, dtype=self.dtype, name="key")(x) 41 | v = nn.Dense(x.shape[-1], kernel_init=init, dtype=self.dtype, name="value")(x) 42 | q = einshape("ns(hd)->nshd", q, d=64) 43 | k = einshape("ns(hd)->nshd", k, d=64) 44 | v = einshape("ns(hd)->nshd", v, d=64) 45 | q = nn.LayerNorm(feature_axes=(-2, -1), name="self_attn_query_norm")(q) 46 | k = nn.LayerNorm(feature_axes=(-2, -1), name="self_attn_key_norm")(k) 47 | attn_weights = jnp.einsum("...qhd,...khd->...hqk", q, k) / jnp.sqrt(q.shape[-1]) 48 | attn_weights = nn.Dropout(self.dropout_rate)(attn_weights, deterministic) 49 | attn_weights = jax.nn.softmax(attn_weights) 50 | out = jnp.einsum("...hqk,...khd->...qhd", attn_weights, v) 51 | out = einshape("nshd->ns(hd)", out) 52 | x = nn.Dense( 53 | x_skip.shape[-1], kernel_init=init_out, dtype=self.dtype, name="out" 54 | )(out) 55 | x = nn.Dropout(self.dropout_rate)(x, deterministic) 56 | x = x_skip + x 57 | 58 | # Feedforward 59 | x_skip = x 60 | x = nn.LayerNorm(name="ff_norm")(x) 61 | x1 = nn.Dense( 62 | x.shape[-1] * 4, kernel_init=init, dtype=self.dtype, name="ff_1_1" 63 | )(x) 64 | x2 = nn.Dense( 65 | x.shape[-1] * 4, kernel_init=init, dtype=self.dtype, name="ff_1_2" 66 | )(x) 67 | x = x1 * nn.gelu(x2) 68 | x = nn.Dropout(self.dropout_rate)(x, deterministic) 69 | x = nn.Dense( 70 | x_skip.shape[-1], kernel_init=init_out, dtype=self.dtype, name="ff_2" 71 | )(x) 72 | x = nn.Dropout(self.dropout_rate)(x, deterministic) 73 | x = x_skip + x 74 | 75 | return x 76 | 77 | 78 | class TransformerModel(nn.Module): 79 | width: int 80 | depth: int 81 | wavelet_levels: int 82 | wavelet: str = "bior4.4" 83 | dropout_rate: float = 0.0 84 | dtype: jnp.dtype = jnp.float32 85 | 86 | @nn.compact 87 | def __call__(self, xt, t, deterministic=True): 88 | n, h, w, c = xt.shape 89 | 90 | # Precompute wavelet transform kernels 91 | if not self.has_variable("kernels", "kernel"): 92 | filt = jw.make_kernels(jw.get_filter_bank(self.wavelet), c) 93 | filt = self.variable("kernels", "kernel", lambda: filt).value 94 | else: 95 | filt = self.variable("kernels", "kernel").value 96 | 97 | # Karras preconditioner 98 | c_in = 1 / jnp.sqrt(t**2 + 1) 99 | c_out = t / jnp.sqrt(t**2 + 1) 100 | c_skip = 1 / (t**2 + 1) 101 | x = xt * rb(c_in, xt) 102 | 103 | # Input patching 104 | x = jw.wavelet_dec(x, filt[0], self.wavelet_levels, mode="reflect") 105 | _, h2, w2, c2 = x.shape 106 | x = einshape("nhwc->n(hw)c", x) 107 | x = nn.Dense(self.width, dtype=self.dtype, name="proj_in")(x) 108 | 109 | # Timestep embedding 110 | z = FourierFeatures(self.width, std=1.0, name="timestep_embed")( 111 | jnp.log(t)[:, None] 112 | ) 113 | z = nn.Dense(self.width, dtype=self.dtype, name="timestep_embed_in")(z)[:, None] 114 | 115 | # Positional embedding 116 | n_image_toks = x.shape[1] 117 | x = jnp.concatenate([x, z], axis=1) 118 | x = x + self.param( 119 | "pos_emb", nn.initializers.normal(1.0), (x.shape[1], self.width) 120 | ) 121 | 122 | # Transformer 123 | x = nn.LayerNorm(name="norm_in")(x) 124 | for i in range(self.depth): 125 | x = TransformerBlock( 126 | dropout_rate=self.dropout_rate, 127 | dtype=self.dtype, 128 | name=f"transformer_{i}", 129 | )(x, z, deterministic=deterministic) 130 | x = x[:, :n_image_toks] 131 | x = nn.LayerNorm(name="norm_out")(x) 132 | 133 | # Output unpatching 134 | x = nn.Dense(c2, kernel_init=nn.initializers.zeros, name="proj_out")(x) 135 
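        # Reshape the tokens back into the wavelet coefficient grid, invert the
        # wavelet transform to return to pixel space, then apply the output half
        # of the preconditioner, D(x_t, t) = c_skip * x_t + c_out * F(c_in * x_t, t)
        # (the Karras/EDM preconditioner with sigma_data = 1), so the model output
        # is an estimate of the denoised image x_0.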
| x = einshape("n(hw)c->nhwc", x, h=h2, w=w2) 136 | x = jw.wavelet_rec(x, filt[1], self.wavelet_levels, mode="reflect") 137 | 138 | # Karras preconditioner, output 139 | x = x * rb(c_out, x) + xt * rb(c_skip, xt) 140 | 141 | return x 142 | 143 | @staticmethod 144 | def weight_decay_mask(params): 145 | return flax.core.FrozenDict( 146 | flax.traverse_util.path_aware_map( 147 | lambda p, v: len(p) >= 3 and p[-1] == "kernel", params 148 | ) 149 | ) 150 | -------------------------------------------------------------------------------- /consistency_models/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import jax 4 | 5 | 6 | def ema_update(tree_ema, tree, decay): 7 | """Update the exponential moving average of a tree. 8 | 9 | Parameters 10 | ---------- 11 | tree_ema : Any 12 | The current value of the exponential moving average. 13 | tree : Any 14 | The new value to update the exponential moving average with. 15 | decay : float 16 | The decay factor of the exponential moving average. 17 | 18 | Returns 19 | ------- 20 | Any 21 | The updated exponential moving average. 22 | """ 23 | return jax.tree_map( 24 | lambda x_ema, x: x_ema * decay + x * (1 - decay), tree_ema, tree 25 | ) 26 | 27 | 28 | class PerfCounter: 29 | """Tracks an exponential moving average of the time between events. 30 | 31 | Parameters 32 | ---------- 33 | decay : float 34 | The decay factor of the exponential moving average. 35 | """ 36 | 37 | def __init__(self, decay=0.99): 38 | self.decay = decay 39 | self.count = 0 40 | self.value = 0.0 41 | self.decay_accum = 1.0 42 | self.last_time = None 43 | self.pause_time = None 44 | 45 | def get(self): 46 | """Get the current average time between events. 47 | 48 | Returns 49 | ------- 50 | float 51 | The current average time between events. If fewer than two events have been 52 | recorded, returns NaN. 53 | """ 54 | try: 55 | return self.value / (1 - self.decay_accum) 56 | except ZeroDivisionError: 57 | return float("nan") 58 | 59 | def get_count(self): 60 | """Get the number of events that have been recorded. 61 | 62 | Returns 63 | ------- 64 | int 65 | The number of events that have been recorded. 66 | """ 67 | return self.count 68 | 69 | def update(self, time_value=None): 70 | """Record the occurrence of an event. 71 | 72 | Parameters 73 | ---------- 74 | time_value : float 75 | The time of the event. If None, the current time is used. 76 | 77 | Returns 78 | ------- 79 | float 80 | The current average time between events. 81 | """ 82 | time_value = time_value or time.time() 83 | self.resume(time_value) 84 | if self.last_time is not None: 85 | self.decay_accum *= self.decay 86 | self.value *= self.decay 87 | self.value += (time_value - self.last_time) * (1 - self.decay) 88 | self.last_time = time_value 89 | self.count += 1 90 | return self.get() 91 | 92 | def pause(self, time_value=None): 93 | """Pause the timer. If the timer is already paused, this does nothing. 94 | 95 | Parameters 96 | ---------- 97 | time_value : float 98 | The time to pause the timer at. If None, the current time is used. 99 | """ 100 | time_value = time_value or time.time() 101 | if self.pause_time is None and self.last_time is not None: 102 | self.pause_time = time_value 103 | 104 | def resume(self, time_value=None): 105 | """Resume the timer. 106 | 107 | Parameters 108 | ---------- 109 | time_value : float 110 | The time to resume the timer at. If None, the current time is used. 
111 | 112 | Returns 113 | ------- 114 | float 115 | The duration of the pause, or 0.0 if the timer was not paused. 116 | """ 117 | time_value = time_value or time.time() 118 | if self.pause_time is not None: 119 | pause_duration = time_value - self.pause_time 120 | self.last_time += pause_duration 121 | self.pause_time = None 122 | return pause_duration 123 | return 0.0 124 | 125 | 126 | def rb(x, y): 127 | """Prepare x for right broadcasting against y by inserting trailing axes. 128 | 129 | Ordinary JAX broadcasting is left broadcasting: if x has fewer axes than y, axes 130 | are inserted on the left (leading axes) to match the number of dimensions of y. This 131 | function inserts axes on the right (trailing axes) instead. 132 | 133 | Parameters 134 | ---------- 135 | x : jax.Array 136 | The array to insert trailing axes into. 137 | y : jax.Array 138 | The array to prepare x to be right broadcast against. 139 | 140 | Returns 141 | ------- 142 | jax.Array 143 | x, with trailing axes inserted to match the number of dimensions of y. 144 | 145 | Examples 146 | -------- 147 | >>> x = jnp.zeros((32,)) 148 | >>> y = jnp.zeros((32, 224, 224, 3)) 149 | >>> rb(x, y).shape 150 | (32, 1, 1, 1) 151 | """ 152 | axes_to_insert = y.ndim - x.ndim 153 | if axes_to_insert < 0: 154 | raise ValueError(f"x has {x.ndim} dims but y has {y.ndim}, which is fewer") 155 | return x[(...,) + (None,) * axes_to_insert] 156 | 157 | 158 | def split_by_process(key): 159 | """Splits a PRNG key, returning a different key in each JAX process.""" 160 | return jax.random.split(key, jax.process_count())[jax.process_index()] 161 | 162 | 163 | def tree_size(tree): 164 | """Return the number of elements in a tree.""" 165 | return sum(x.size for x in jax.tree_util.tree_leaves(tree)) 166 | -------------------------------------------------------------------------------- /jax_local_cluster.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """A simple JAX process launcher for multiple devices on a single host. 4 | 5 | You must import jax_local_cluster somewhere inside the script you are launching. 
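Importing it registers the LocalCluster environment defined below with JAX's cluster
auto-detection, so jax.distributed.initialize() in the launched script picks up the
coordinator address and process ids that this launcher sets. Example invocation
(illustrative; -n defaults to one process per local device):

    python jax_local_cluster.py -n 2 python train.py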
6 | """ 7 | 8 | import argparse 9 | from functools import partial 10 | import os 11 | import signal 12 | import socketserver 13 | from subprocess import Popen, TimeoutExpired 14 | import sys 15 | 16 | import jax 17 | import jax._src as _src 18 | 19 | error = partial(print, file=sys.stderr) 20 | 21 | 22 | class LocalCluster(_src.clusters.ClusterEnv): 23 | @classmethod 24 | def is_env_present(cls): 25 | return "JAX_COORDINATOR_ADDRESS" in os.environ 26 | 27 | @classmethod 28 | def get_coordinator_address(cls): 29 | return os.environ["JAX_COORDINATOR_ADDRESS"] 30 | 31 | @classmethod 32 | def get_process_count(cls): 33 | return int(os.environ["JAX_PROCESS_COUNT"]) 34 | 35 | @classmethod 36 | def get_process_id(cls): 37 | return int(os.environ["JAX_PROCESS_ID"]) 38 | 39 | @classmethod 40 | def get_local_process_id(cls): 41 | return int(os.environ["JAX_LOCAL_PROCESS_ID"]) 42 | 43 | 44 | def get_free_port(): 45 | with socketserver.TCPServer(("127.0.0.1", 0), None) as s: 46 | return s.server_address[1] 47 | 48 | 49 | def signal_and_wait(signum, procs, ctx, timeout=None): 50 | for proc in procs: 51 | proc.send_signal(signum) 52 | for i, proc in enumerate(procs): 53 | ctx["i"] = i 54 | proc.wait(timeout) 55 | ctx["i"] = None 56 | 57 | 58 | def interactive_shutdown(procs): 59 | ctx = {"i": None} 60 | try: 61 | signal_and_wait(signal.SIGINT, procs, ctx) 62 | except KeyboardInterrupt: 63 | try: 64 | error( 65 | f"Process {ctx['i']} (pid {procs[ctx['i']].pid}) did not exit on SIGINT, trying SIGTERM" 66 | ) 67 | signal_and_wait(signal.SIGTERM, procs, ctx, timeout=1) 68 | except (KeyboardInterrupt, TimeoutExpired): 69 | error( 70 | f"Process {ctx['i']} (pid {procs[ctx['i']].pid}) did not exit on SIGTERM, trying SIGKILL" 71 | ) 72 | for proc in procs: 73 | proc.kill() 74 | 75 | 76 | class TerminationHandler: 77 | def __init__(self, procs, verbose): 78 | self.procs = procs 79 | self.verbose = verbose 80 | self.was_called = False 81 | 82 | def __call__(self, signum, frame): 83 | self.was_called = True 84 | if self.verbose: 85 | error("SIGTERM received, shutting down") 86 | try: 87 | signal_and_wait(signal.SIGTERM, self.procs, {}, timeout=1) 88 | except TimeoutExpired: 89 | if self.verbose: 90 | error("SIGTERM timed out, sending SIGKILL") 91 | for proc in self.procs: 92 | proc.kill() 93 | raise KeyboardInterrupt 94 | 95 | 96 | def main(): 97 | p = argparse.ArgumentParser(description=__doc__) 98 | p.add_argument( 99 | "-n", 100 | type=int, 101 | default=0, 102 | help="Number of processes to launch (default: one per local device)", 103 | ) 104 | p.add_argument( 105 | "--port", 106 | type=int, 107 | default=0, 108 | help="Port to use for the coordinator (default: a free port)", 109 | ) 110 | p.add_argument("-v", "--verbose", action="store_true", help="Verbose output") 111 | p.add_argument("command", type=str, nargs=argparse.REMAINDER, help="Command to run") 112 | args = p.parse_args() 113 | 114 | if not args.command: 115 | p.print_help() 116 | sys.exit(1) 117 | 118 | n = args.n if args.n else jax.local_device_count() 119 | if args.verbose: 120 | error(f"Launching {n} processes") 121 | port = args.port if args.port else get_free_port() 122 | if args.verbose: 123 | error(f"Using port {port} for coordinator") 124 | 125 | procs = [] 126 | sigterm_handler = TerminationHandler(procs, args.verbose) 127 | signal.signal(signal.SIGTERM, sigterm_handler) 128 | 129 | try: 130 | for i in range(n): 131 | env = os.environ.copy() 132 | env.pop("OMPI_MCA_orte_hnp_uri", None) 133 | env.pop("SLURM_JOB_ID", None) 134 | 
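            # These environment variables are read back by the LocalCluster class
            # above via JAX's cluster auto-detection, so each child process joins
            # the same coordinator with its own process id; the OMPI and SLURM
            # variables are popped first so the children do not auto-detect an MPI
            # or SLURM cluster instead.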
env["JAX_COORDINATOR_ADDRESS"] = f"127.0.0.1:{port}" 135 | env["JAX_PROCESS_COUNT"] = str(n) 136 | env["JAX_PROCESS_ID"] = str(i) 137 | env["JAX_LOCAL_PROCESS_ID"] = str(i) 138 | proc = Popen(args.command, env=env) 139 | if args.verbose: 140 | error(f"Launched process {i} (pid {proc.pid})") 141 | procs.append(proc) 142 | for proc in procs: 143 | proc.wait() 144 | except KeyboardInterrupt: 145 | pass 146 | finally: 147 | if not sigterm_handler.was_called: 148 | interactive_shutdown(procs) 149 | if args.verbose: 150 | error("All processes terminated") 151 | 152 | 153 | if __name__ == "__main__": 154 | main() 155 | -------------------------------------------------------------------------------- /mnist.msgpack: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7b337a37a61eda5c4020dde549606b329203b7305b6ed6d0f7a1aded33e2e3f0 3 | size 55160111 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einshape 2 | flax 3 | jax 4 | git+https://github.com/crowsonkb/jax-wavelets 5 | numpy 6 | optax 7 | Pillow 8 | rich 9 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Trains a consistency model.""" 4 | 5 | import argparse 6 | import math 7 | from typing import Any 8 | 9 | from einshape import jax_einshape as einshape 10 | import flax 11 | import jax 12 | from jax.experimental.pjit import pjit 13 | from jax.experimental import mesh_utils, multihost_utils 14 | import jax.numpy as jnp 15 | from jax.tree_util import Partial 16 | import numpy as np 17 | import optax 18 | from PIL import Image 19 | from rich import print 20 | from rich.traceback import install 21 | 22 | import consistency_models as cm 23 | import jax_local_cluster 24 | 25 | 26 | class Normalize(flax.struct.PyTreeNode): 27 | mean: jax.Array 28 | std: jax.Array 29 | 30 | def forward(self, x): 31 | return (x - self.mean) / self.std 32 | 33 | def inverse(self, x): 34 | return x * self.std + self.mean 35 | 36 | def logdet_forward(self, x): 37 | return -jnp.sum(jnp.broadcast_arrays(x, jnp.log(self.std))[1]) 38 | 39 | def logdet_inverse(self, x): 40 | return jnp.sum(jnp.broadcast_arrays(x, jnp.log(self.std))[1]) 41 | 42 | 43 | class ModelState(flax.struct.PyTreeNode): 44 | params: Any 45 | params_ema: Any 46 | state: Any 47 | opt_state: Any 48 | 49 | 50 | class TrainState(flax.struct.PyTreeNode): 51 | step: jax.Array 52 | key: jax.Array 53 | state_s: ModelState 54 | state_c: ModelState 55 | 56 | 57 | def main(): 58 | install() 59 | parser = argparse.ArgumentParser(description=__doc__) 60 | parser.add_argument("--name", type=str, default="run", help="the run name") 61 | parser.add_argument("--seed", type=int, default=43292, help="the random seed") 62 | args = parser.parse_args() 63 | 64 | try: 65 | jax.distributed.initialize() 66 | except ValueError: 67 | pass 68 | 69 | key = jax.random.PRNGKey(args.seed) 70 | 71 | batch_size_per_device = 64 72 | batch_size_per_process = batch_size_per_device * jax.local_device_count() 73 | batch_size = batch_size_per_device * jax.device_count() 74 | 75 | if jax.process_index() == 0: 76 | print("Processes:", jax.process_count()) 77 | print("Devices:", jax.device_count()) 78 | print("Batch size per device:", batch_size_per_device) 79 
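        # batch_size is the global batch: each step, select_from_dataset draws this
        # many examples and pjit shards them over the "n" mesh axis, so every device
        # processes batch_size_per_device of them.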
| print("Batch size per process:", batch_size_per_process) 80 | print("Batch size:", batch_size, flush=True) 81 | 82 | devices = mesh_utils.create_device_mesh((jax.device_count(),)) 83 | mesh = jax.sharding.Mesh(devices, axis_names=("n",)) 84 | pspec_image = jax.sharding.PartitionSpec("n", None, None, None) 85 | 86 | target = {"train": (None, None), "val": (None, None)} 87 | d_ = flax.serialization.from_bytes(target, open("mnist.msgpack", "rb").read()) 88 | d_ = jnp.array(d_["train"][0]) 89 | dataset = d_ / 255 90 | size = dataset.shape[1:3] 91 | ch = dataset.shape[3] 92 | del d_ 93 | 94 | model_s = cm.model.TransformerModel( 95 | width=512, depth=8, wavelet_levels=2, dtype=jnp.bfloat16 96 | ) 97 | key, subkey = jax.random.split(key) 98 | variables = model_s.init(subkey, jnp.zeros([1, *size, ch]), jnp.zeros((1,))) 99 | state, params = variables.pop("params") 100 | mask = model_s.weight_decay_mask(params) 101 | opt_s = optax.adamw(2e-4, b2=0.99, weight_decay=1e-2, mask=mask) 102 | opt_state = opt_s.init(params) 103 | state_s = ModelState(params, jax.tree_map(jnp.copy, params), state, opt_state) 104 | if jax.process_index() == 0: 105 | print("Score model parameters:", cm.utils.tree_size(params)) 106 | del variables, state, params, opt_state 107 | 108 | model_c = cm.model.TransformerModel( 109 | width=512, depth=8, wavelet_levels=2, dtype=jnp.bfloat16 110 | ) 111 | key, subkey = jax.random.split(key) 112 | variables = model_c.init(subkey, jnp.zeros([1, *size, ch]), jnp.zeros((1,))) 113 | state, params = variables.pop("params") 114 | mask = model_c.weight_decay_mask(params) 115 | opt_c = optax.adamw(2e-4, b2=0.99, weight_decay=1e-2, mask=mask) 116 | opt_state = opt_c.init(params) 117 | state_c = ModelState(params, jax.tree_map(jnp.copy, params), state, opt_state) 118 | if jax.process_index() == 0: 119 | print("Consistency model parameters:", cm.utils.tree_size(params)) 120 | del variables, state, params, opt_state 121 | 122 | key, subkey = jax.random.split(key) 123 | state = TrainState( 124 | step=jnp.array(0, jnp.int32), 125 | key=subkey, 126 | state_s=state_s, 127 | state_c=state_c, 128 | ) 129 | del state_s, state_c 130 | 131 | # TODO: use decode() and encode() instead 132 | normalize = Normalize(jnp.array(0.5), jnp.array(0.25)) 133 | 134 | def ema_decay(step): 135 | return jnp.minimum(0.9999, 1 - (step + 1) ** -(2 / 3)) 136 | 137 | @Partial(pjit, in_axis_resources=(None, pspec_image), donate_argnums=0) 138 | def update(state, x): 139 | # Prepare data 140 | x = normalize.forward(x) 141 | 142 | # Sample timesteps and noise 143 | key, *keys = jax.random.split(state.key, 5) 144 | u = jax.random.uniform(keys[0], x.shape[:1]) 145 | # TODO: make this configurable 146 | t = jnp.tan((u * 0.998 + 0.001) * jnp.pi / 2) 147 | weight_fun = cm.cosine_weight 148 | noise = jax.random.normal(keys[1], x.shape) 149 | 150 | # Update score model 151 | def model_fun_s(params, xt, t): 152 | return model_s.apply( 153 | {"params": params, **state.state_s.state}, 154 | xt, 155 | t, 156 | deterministic=False, 157 | rngs={"dropout": keys[2]}, 158 | ) 159 | 160 | loss_s, grad_s = jax.value_and_grad(cm.score_matching_loss)( 161 | state.state_s.params, 162 | x, 163 | t, 164 | noise, 165 | model_fun=model_fun_s, 166 | weight_fun=weight_fun, 167 | ) 168 | updates, opt_state_s = opt_s.update( 169 | grad_s, state.state_s.opt_state, state.state_s.params 170 | ) 171 | params_s = optax.apply_updates(state.state_s.params, updates) 172 | params_s_ema = cm.utils.ema_update( 173 | state.state_s.params_ema, params_s, 
ema_decay(state.step) 174 | ) 175 | state_s = state.state_s.replace( 176 | params=params_s, params_ema=params_s_ema, opt_state=opt_state_s 177 | ) 178 | 179 | # Update consistency model 180 | def model_fun_c(params, xt, t): 181 | return model_c.apply( 182 | {"params": params, **state.state_c.state}, 183 | xt, 184 | t, 185 | deterministic=False, 186 | rngs={"dropout": keys[3]}, 187 | ) 188 | 189 | loss_c, grad_c = jax.value_and_grad(cm.consistency_loss)( 190 | state.state_c.params, 191 | x, 192 | t, 193 | noise, 194 | model_fun=model_fun_c, 195 | weight_fun=weight_fun, 196 | metric_fun=cm.l2_metric, 197 | teacher_fun=Partial(model_fun_s, params_s_ema), 198 | stopgrad=False, 199 | ) 200 | updates, opt_state_c = opt_c.update( 201 | grad_c, state.state_c.opt_state, state.state_c.params 202 | ) 203 | params_c = optax.apply_updates(state.state_c.params, updates) 204 | params_c_ema = cm.utils.ema_update( 205 | state.state_c.params_ema, params_c, ema_decay(state.step) 206 | ) 207 | state_c = state.state_c.replace( 208 | params=params_c, params_ema=params_c_ema, opt_state=opt_state_c 209 | ) 210 | 211 | # Assemble new training state 212 | state = state.replace( 213 | step=state.step + 1, key=key, state_s=state_s, state_c=state_c 214 | ) 215 | aux = {"loss_s": loss_s, "loss_c": loss_c} 216 | return state, aux 217 | 218 | @Partial(pjit, out_shardings=jax.sharding.PartitionSpec(None), static_argnums=2) 219 | def sample(state, key, n): 220 | tmax = 160.0 221 | 222 | key, subkey = jax.random.split(key, 2) 223 | xt = jax.random.normal(subkey, (n, *size, ch)) * tmax 224 | t = jnp.full((n,), tmax) 225 | x0 = model_c.apply( 226 | {"params": state.state_c.params_ema, **state.state_c.state}, xt, t 227 | ) 228 | 229 | x0 = normalize.inverse(x0) 230 | return x0 231 | 232 | def demo(state, key): 233 | rows = 10 234 | cols = 10 235 | n = rows * cols 236 | n_adj = math.ceil(n / jax.device_count()) * jax.device_count() 237 | x0 = sample(state, key, n_adj) 238 | with jax.spmd_mode("allow_all"): 239 | x0 = x0[:n] 240 | grid = einshape("(ab)hwc->(ah)(bw)c", x0, a=rows, b=cols) 241 | grid = np.array(jnp.round(jnp.clip(grid * 255, 0, 255)).astype(jnp.uint8)) 242 | if jax.process_index() == 0: 243 | if ch == 1: 244 | grid = grid[..., 0] 245 | Image.fromarray(grid).save(f"{args.name}_demo_{step:08}.png") 246 | print("📸 Output demo grid!", flush=True) 247 | 248 | step = state.step.item() 249 | perf_ctr = cm.utils.PerfCounter() 250 | 251 | @Partial(pjit, out_axis_resources=pspec_image) 252 | def select_from_dataset(key, dataset): 253 | idx = jax.random.choice(key, len(dataset), [batch_size]) 254 | return dataset[idx] 255 | 256 | try: 257 | while True: 258 | key, *keys = jax.random.split(key, 3) 259 | if step % 500 == 0: 260 | with mesh: 261 | demo(state, keys[0]) 262 | # TODO: implement saving 263 | # if step > 0 and step % 20000 == 0: 264 | # if jax.process_index() == 0: 265 | # save(train_state, key) 266 | with mesh: 267 | x = select_from_dataset(keys[1], dataset) 268 | state, aux = update(state, x) 269 | average_time = perf_ctr.update() 270 | if step % 25 == 0: 271 | if jax.process_index() == 0: 272 | print( 273 | f'step: {step}, loss_s: {aux["loss_s"].item():g}, loss_c: {aux["loss_c"].item():g}, {1 / average_time:g} it/s', 274 | flush=True, 275 | ) 276 | step += 1 277 | except KeyboardInterrupt: 278 | pass 279 | 280 | 281 | if __name__ == "__main__": 282 | main() 283 | --------------------------------------------------------------------------------