├── README.md
├── assets
│   └── images
│       ├── hatsune_miku(Anim-E).png
│       ├── hatsune_miku(Craiyon).png
│       ├── hatsune_miku.png
│       ├── myname-vae-anime.png
│       └── myname-vae.png
├── configs
│   └── vqgan
│       └── discriminator
│           └── config.json
├── demo_vae.ipynb
├── requirements.txt
├── run_finetune_vae.py
└── run_finetune_vqgan.py


/README.md:
--------------------------------------------------------------------------------
1 | # fine-tune models
2 | 
3 | This repository contains the code for fine-tuning the following models:
4 | 
5 | 1. the VQGAN decoder of Craiyon (dalle-mini)
6 | 2. the VAE decoder of stable-diffusion
7 | 
8 | ---
9 | ## fine-tuning the VQGAN decoder of Craiyon (dalle-mini)
10 | 
11 | 
12 | * Training code: [run_finetune_vqgan.py](run_finetune_vqgan.py)
13 | * Demo: [Anim·E](https://github.com/cccntu/anim_e)
14 | 
15 | | Anim-E | Craiyon (formerly DALL-E mini) |
16 | :-------------------------:|:-------------------------:
17 | ![](assets/images/hatsune_miku(Anim-E).png) | ![](assets/images/hatsune_miku(Craiyon).png)
18 | 
19 | * Notes:
20 |     - Reconstructing an anime-styled image with the VQGAN used by Craiyon (dalle-mini) gives noticeably poor results.
21 |     - So I fine-tuned the VQGAN decoder of Craiyon (dalle-mini) on anime images, and the results are much better. See [Anim·E](https://github.com/cccntu/anim_e).
22 | 
23 | 
24 | 
25 | ---
26 | 
27 | ## fine-tuning the VAE decoder of stable-diffusion
28 | 
29 | * Training code: [run_finetune_vae.py](run_finetune_vae.py)
30 | * Demo: [demo_vae.ipynb](demo_vae.ipynb)
31 | 
32 | | vae-anime reconstruction | vae reconstruction |
33 | :-------------------------:|:-------------------------:
34 | ![](assets/images/myname-vae-anime.png) | ![](assets/images/myname-vae.png)
35 | 
36 | 
37 | * Notes:
38 |     - I fine-tuned at 256x256 resolution, although the VAE in stable-diffusion can handle 512x512.
39 |     - I did this because the VAE was already very good at reconstructing every image I tried at 512x512 resolution, except that it sometimes struggles with fine details, such as a small face in a larger image (see the demo notebook [demo_vae.ipynb](demo_vae.ipynb)). So I think fine-tuning it on images downsampled to 256x256 resolution can help it reconstruct details better.
40 |     - The result is not as impressive as [Anim·E](https://github.com/cccntu/anim_e), but I think that is because the UNet diffusion model of stable-diffusion is not trained to generate anime-styled images, so it still struggles to generate the **latent** of anime-styled images in detail.
41 | 
42 | ## Future work
43 | 
44 | * fine-tune the UNet diffusion model of stable-diffusion
45 | 
46 | ---
47 | 
48 | ## Note: Comparison between the VQGAN in dalle-mini and the VAE in stable-diffusion
49 | 
50 | * dalle-mini: restricted to 256x256 resolution
51 | * stable-diffusion: no restriction on resolution
52 | ### At 256x256 resolution
53 | * dalle-mini: 256 tokens -> (1x16x16), with a vocab size of 16384 (14 bits per token)
54 | * stable-diffusion: (4x32x32) -> 4096 "tokens" in float (32 or 16 bits each, depending on the precision)
55 | 
56 | In conclusion, at float32 precision the VAE in stable-diffusion has (4096 * 32) / (256 * 14) ≈ **36x more information** than the VQGAN in dalle-mini to reconstruct the same image at the same resolution.
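To make the comparison concrete, here is a quick back-of-envelope check of the arithmetic above. This is a standalone snippet, not part of the training code; it simply treats each float32 latent value as 32 bits of raw information.

```python
# Back-of-envelope check of the capacity comparison above (float32 case).
vqgan_bits = 256 * 14          # 256 tokens, vocab 16384 = 2**14   -> 3,584 bits
vae_bits = 4 * 32 * 32 * 32    # 4x32x32 latent values in float32  -> 131,072 bits
print(vae_bits / vqgan_bits)   # ~36.6, i.e. roughly 36x more raw information
```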
57 | 
58 | 
59 | 
60 | 
61 | 
62 | 
63 | 
--------------------------------------------------------------------------------
/assets/images/hatsune_miku(Anim-E).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cccntu/fine-tune-models/09b059c0e895c7f89fdbc81f97f7f96620e0d889/assets/images/hatsune_miku(Anim-E).png
--------------------------------------------------------------------------------
/assets/images/hatsune_miku(Craiyon).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cccntu/fine-tune-models/09b059c0e895c7f89fdbc81f97f7f96620e0d889/assets/images/hatsune_miku(Craiyon).png
--------------------------------------------------------------------------------
/assets/images/hatsune_miku.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cccntu/fine-tune-models/09b059c0e895c7f89fdbc81f97f7f96620e0d889/assets/images/hatsune_miku.png
--------------------------------------------------------------------------------
/assets/images/myname-vae-anime.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cccntu/fine-tune-models/09b059c0e895c7f89fdbc81f97f7f96620e0d889/assets/images/myname-vae-anime.png
--------------------------------------------------------------------------------
/assets/images/myname-vae.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cccntu/fine-tune-models/09b059c0e895c7f89fdbc81f97f7f96620e0d889/assets/images/myname-vae.png
--------------------------------------------------------------------------------
/configs/vqgan/discriminator/config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "base_features": 32,
3 |   "image_size": [
4 |     256,
5 |     256
6 |   ],
7 |   "max_hidden_feature_size": 512,
8 |   "mbstd_group_size": 4,
9 |   "mbstd_num_features": 1
10 | }
11 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # My code requires Python 3.9 for the dictionary union operator
2 | # lpips requires python >= 3.8
3 | stable_diffusion_jax @ git+https://github.com/patil-suraj/stable-diffusion-jax.git@47297f53bb4907f119079654310bfb14134c2714
4 | vqgan-jax @ git+https://github.com/patil-suraj/vqgan-jax.git@10ef240f8ace869e437f3c32d14898f61512db12
5 | vit-vqgan @ git+https://github.com/patil-suraj/vit-vqgan.git@6dce733329541129f0d60cdce2487a340e726abf
6 | lpips-j @ git+https://github.com/pcuenca/lpips-j.git@346edee27d373d4b19265e33cb588ca17a189cb1
7 | datasets~=2.4.0
8 | flax~=0.5.3
9 | optax~=0.1.3
10 | Pillow~=9.2.0
11 | wandb
12 | # install these yourself
13 | #torch
14 | #torchvision
15 | #jax==0.3.16
16 | # diffusers # this is required because stable-diffusion-jax imports it
17 | # jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
18 | 
--------------------------------------------------------------------------------
/run_finetune_vae.py:
--------------------------------------------------------------------------------
1 | # %%
2 | # credits:
3 | # Flax code is adapted from
https://github.com/huggingface/transformers/blob/main/examples/flax/vision/run_image_classification.py 4 | # GAN related code are adapted from https://github.com/patil-suraj/vit-vqgan/ 5 | import inspect 6 | import os 7 | from functools import partial 8 | 9 | # os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5' 10 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 11 | # cuda 12 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1" 13 | 14 | from copy import deepcopy 15 | from pathlib import Path 16 | 17 | import flax.linen as nn 18 | import jax 19 | import jax.numpy as jnp 20 | import numpy as np 21 | import optax 22 | import torch 23 | import torchvision.transforms as T 24 | from datasets import Dataset as HFDataset 25 | from flax import jax_utils 26 | from flax.jax_utils import pad_shard_unpad, unreplicate 27 | from flax.serialization import from_bytes, to_bytes 28 | from flax.training import train_state 29 | from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key 30 | from lpips_j.lpips import LPIPS 31 | from PIL import Image 32 | from stable_diffusion_jax import AutoencoderKL 33 | from torch.utils.data import DataLoader 34 | from tqdm import tqdm 35 | from vit_vqgan import StyleGANDiscriminator, StyleGANDiscriminatorConfig 36 | 37 | import wandb 38 | 39 | # %% 40 | # since we don't fine-tune the encoder: we don't have kl loss 41 | kl_loss = False # changin this have no effect 42 | # It's not clear if sampling from the distribution is better than using the mean 43 | # For simplicity, we use the mean 44 | sample_from_distribution = False # changing this have no effect 45 | # %% 46 | # paths and configs 47 | wandb.init(project="vae") 48 | 49 | learning_rate = 1e-4 50 | gradient_accumulation_steps = 1 51 | warmup_steps = 4000 * gradient_accumulation_steps 52 | log_steps = 1 * gradient_accumulation_steps 53 | eval_steps = 100 * gradient_accumulation_steps 54 | log_steps = 10 * gradient_accumulation_steps 55 | eval_steps = 100 * gradient_accumulation_steps 56 | total_steps = 150_000 * gradient_accumulation_steps 57 | # skip disc loss for the first 1000 steps, because discriminator is not trained yet 58 | disc_loss_skip_steps = 1000 * gradient_accumulation_steps 59 | 60 | # model = VQModel.from_pretrained("dalle-mini/vqgan_imagenet_f16_16384") 61 | data_root = "/disks" 62 | # a huggingface dataset containing columns "path" and optionally "indices" 63 | # path: can be absolute or relative to `data_root` 64 | # indices: VQ indices of the image at `path` 65 | hfds = HFDataset.from_json("danbooru_image_paths_ds.json") 66 | 67 | # this corresponds to a local dir containing the config.json file 68 | # the config.json file is copied from https://github.com/patil-suraj/vit-vqgan/ 69 | disc_config_path = "configs/vqgan/discriminator/config.json" 70 | 71 | output_dir = Path("output-dir-vae") 72 | output_dir.mkdir(exist_ok=True) 73 | 74 | # the empereically observed values from initial runs, we will scale them closer to the scale of l2 loss 75 | scale_l2 = 0.001 76 | scale_lpips = 0.25 77 | # adjust scale to make the loss comparable to l2 loss 78 | cost_l2 = 0.5 79 | cost_lpips = scale_l2 / scale_lpips * 5 80 | cost_gradient_penalty = 100000000 # this follows vit-vqgan repo 81 | cost_disc = 0.005 82 | 83 | # %% 84 | # convert the weight to jax first, see: 85 | # https://github.com/patil-suraj/stable-diffusion-jax/blob/47297f53bb4907f119079654310bfb14134c2714/example.py#L23 86 | fx_path = Path.home() / "models/stable-diffusion-v1-4-jax" 87 | vae, vae_params = 
AutoencoderKL.from_pretrained(f"{fx_path}/vae", _do_init=False) 88 | # default to float 32, I don't care 89 | 90 | # %% 91 | model = vae 92 | original_params = deepcopy(vae_params) 93 | # %% 94 | vae_params.keys() 95 | # %% 96 | class EncoderImageDataset(torch.utils.data.Dataset): 97 | # this class was originally used to preprocess images into VQ indices 98 | # now we only use its load() method, and preprocess images into VQ indices on the fly 99 | def __init__(self, df, shape=(256, 256)): 100 | self.df = df 101 | self.shape = shape 102 | 103 | def __len__(self): 104 | return len(self.df) 105 | 106 | @staticmethod 107 | def load(path): 108 | img = Image.open(path).convert("RGB").resize((256, 256)) 109 | img = torch.unsqueeze(T.ToTensor()(img), 0) 110 | return img.permute(0, 2, 3, 1).numpy() 111 | 112 | def __getitem__(self, idx): 113 | row = self.df.iloc[idx] 114 | path = row["resized_path"] 115 | return self.load(path) 116 | 117 | 118 | class DecoderImageDataset(torch.utils.data.Dataset): 119 | def __init__(self, hfds, root=None): 120 | """hdfs: HFDataset""" 121 | self.hfds = hfds 122 | self.root = root 123 | 124 | def __len__(self): 125 | return len(self.hfds) 126 | 127 | def __getitem__(self, idx): 128 | example = self.hfds[idx] 129 | # indices = example["indices"] 130 | path = example["path"] 131 | if self.root is not None: 132 | path = os.path.join(self.root, path.lstrip("/")) 133 | orig_arr = EncoderImageDataset.load(path) 134 | return { 135 | # "indices": indices, 136 | "original": orig_arr, 137 | "name": Path(path).name, 138 | } 139 | 140 | @staticmethod 141 | def collate_fn(examples, return_names=False): 142 | res = { 143 | # "indices": [example["indices"] for example in examples], 144 | "original": np.concatenate( 145 | [example["original"] for example in examples], axis=0 146 | ), 147 | } 148 | if return_names: 149 | res["name"] = [example["name"] for example in examples] 150 | return res 151 | 152 | 153 | def try_batch_size(fn, start_batch_size=1): 154 | # try batch size 155 | batch_size = start_batch_size 156 | while True: 157 | try: 158 | print(f"Trying batch size {batch_size}") 159 | fn(batch_size * 2) 160 | batch_size *= 2 161 | except Exception as e: 162 | return batch_size 163 | 164 | 165 | def get_param_counts(params): 166 | param_counts = [k.size for k in jax.tree_util.tree_leaves(params)] 167 | param_counts = sum(param_counts) 168 | return param_counts 169 | 170 | 171 | def get_training_params(): 172 | keys = ["decoder", "post_quant_conv", "quantize"] 173 | decoder_params = {k: v for k, v in original_params.items() if k in keys} 174 | return deepcopy(decoder_params) 175 | 176 | 177 | # %% 178 | for k, v in original_params.items(): 179 | print(k, get_param_counts(v) / 1e6) 180 | # %% 181 | disc_config = StyleGANDiscriminatorConfig.from_pretrained(disc_config_path) 182 | disc_model = StyleGANDiscriminator( 183 | disc_config, 184 | seed=42, 185 | _do_init=True, 186 | ) 187 | lpips_fn = LPIPS() 188 | 189 | 190 | def init_lpips(rng, image_size): 191 | x = jax.random.normal(rng, shape=(1, image_size, image_size, 3)) 192 | return lpips_fn.init(rng, x, x) 193 | 194 | 195 | # %% 196 | 197 | # encoder_params = {k: v for k, v in params.items() if k not in keys} 198 | rng = jax.random.PRNGKey(0) 199 | rng, dropout_rng = jax.random.split(rng) 200 | 201 | lpips_params = init_lpips(rng, image_size=256) 202 | params = get_training_params() 203 | 204 | warmup_fn = optax.linear_schedule( 205 | init_value=0.0, 206 | end_value=learning_rate, 207 | transition_steps=warmup_steps + 1, # 
ensure not 0 208 | ) 209 | decay_fn = optax.linear_schedule( 210 | init_value=learning_rate, 211 | end_value=0, 212 | transition_steps=total_steps - warmup_steps, 213 | ) 214 | schedule_fn = optax.join_schedules( 215 | schedules=[warmup_fn, decay_fn], 216 | boundaries=[warmup_steps], 217 | ) 218 | 219 | disc_loss_skip_schedule = optax.join_schedules( 220 | schedules=[ 221 | optax.constant_schedule(0), 222 | optax.constant_schedule(1), 223 | ], 224 | boundaries=[disc_loss_skip_steps], 225 | ) 226 | optimizer = optax.adamw(learning_rate=schedule_fn) 227 | # discriminator_optimizer 228 | optimizer_disc = optax.adamw(learning_rate=schedule_fn) 229 | # gradient accumulation for main optimizer 230 | optimizer = optax.MultiSteps(optimizer, gradient_accumulation_steps) 231 | 232 | # Setup train state 233 | class TrainState(train_state.TrainState): 234 | dropout_rng: jnp.ndarray 235 | 236 | def replicate(self): 237 | return jax_utils.replicate(self).replace( 238 | dropout_rng=shard_prng_key(self.dropout_rng) 239 | ) 240 | 241 | 242 | state = TrainState.create( 243 | apply_fn=model.decode_code, 244 | params=jax.device_put(params), 245 | tx=optimizer, 246 | dropout_rng=dropout_rng, 247 | ) 248 | state_disc = TrainState.create( 249 | apply_fn=disc_model, 250 | params=jax.device_put(disc_model.params), 251 | tx=optimizer_disc, 252 | dropout_rng=dropout_rng, 253 | ) 254 | 255 | loss_fn = optax.l2_loss 256 | 257 | # 258 | def reconstruct(params_with_encoder, params_with_decoder, original, train=False): 259 | distribution = vae.encode(original, params=params_with_encoder) 260 | latent = distribution.mode() 261 | reconstruction = model.decode(latent, params_with_decoder, train=train) 262 | return reconstruction 263 | 264 | 265 | def train_step(state, batch, state_disc): 266 | """Returns new_state, metrics, reconstruction""" 267 | dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) 268 | 269 | def compute_loss(params, batch, dropout_rng, train=True): 270 | original = batch["original"] 271 | reconstruction = reconstruct(original_params, params, original, train=train) 272 | loss_l2 = loss_fn(reconstruction, original).mean() 273 | disc_fake_scores = state_disc.apply_fn( 274 | reconstruction, 275 | params=state_disc.params, 276 | dropout_rng=dropout_rng, 277 | train=train, 278 | ) 279 | loss_disc = jnp.mean(nn.softplus(-disc_fake_scores)) 280 | loss_lpips = jnp.mean(lpips_fn.apply(lpips_params, original, reconstruction)) 281 | 282 | loss = ( 283 | loss_l2 * cost_l2 284 | + loss_lpips * cost_lpips 285 | + loss_disc * cost_disc * disc_loss_skip_schedule(state.step) 286 | ) 287 | loss_details = { 288 | "loss_l2": loss_l2 * cost_l2, 289 | "loss_lpips": loss_lpips * cost_lpips, 290 | "loss_disc": loss_disc * cost_disc, 291 | } 292 | return loss, (loss_details, reconstruction) 293 | 294 | grad_fn = jax.value_and_grad(compute_loss, has_aux=True) 295 | (loss, (loss_details, reconstruction)), grad = grad_fn( 296 | state.params, batch, dropout_rng, train=True 297 | ) 298 | # legacy code, I didn't use multi gpu 299 | # grad = jax.lax.pmean(grad, "batch") 300 | 301 | new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) 302 | 303 | metrics = loss_details | {"learning_rate": schedule_fn(state.step)} 304 | # metrics = jax.lax.pmean(metrics, axis_name="batch") 305 | return new_state, metrics, reconstruction 306 | 307 | 308 | # %% 309 | def compute_stylegan_loss( 310 | disc_params, batch, fake_images, dropout_rng, disc_model_fn, train 311 | ): 312 | disc_fake_scores = disc_model_fn( 313 | 
fake_images, params=disc_params, dropout_rng=dropout_rng, train=train 314 | ) 315 | disc_real_scores = disc_model_fn( 316 | batch, params=disc_params, dropout_rng=dropout_rng, train=train 317 | ) 318 | # -log sigmoid(f(x)) = log (1 + exp(-f(x))) = softplus(-f(x)) 319 | # -log(1-sigmoid(f(x))) = log (1 + exp(f(x))) = softplus(f(x)) 320 | # https://github.com/pfnet-research/sngan_projection/issues/18#issuecomment-392683263 321 | loss_real = nn.softplus(-disc_real_scores) 322 | loss_fake = nn.softplus(disc_fake_scores) 323 | disc_loss_stylegan = jnp.mean(loss_real + loss_fake) 324 | 325 | # gradient penalty r1: https://github.com/NVlabs/stylegan2/blob/bf0fe0baba9fc7039eae0cac575c1778be1ce3e3/training/loss.py#L63-L67 326 | r1_grads = jax.grad( 327 | lambda x: jnp.mean( 328 | disc_model_fn(x, params=disc_params, dropout_rng=dropout_rng, train=train) 329 | ) 330 | )(batch) 331 | # get the squares of gradients 332 | r1_grads = jnp.mean(r1_grads**2) 333 | 334 | disc_loss = disc_loss_stylegan + cost_gradient_penalty * r1_grads 335 | disc_loss_details = { 336 | "pred_p_real": jnp.exp(-loss_real).mean(), # p = 1 -> predict real is real 337 | "pred_p_fake": jnp.exp(-loss_fake).mean(), # p = 1 -> predict fake is fake 338 | "loss_real": loss_real.mean(), 339 | "loss_fake": loss_fake.mean(), 340 | "loss_stylegan": disc_loss_stylegan, 341 | "loss_gradient_penalty": cost_gradient_penalty * r1_grads, 342 | "loss": disc_loss, 343 | } 344 | return disc_loss, disc_loss_details 345 | 346 | 347 | train_compute_stylegan_loss = partial(compute_stylegan_loss, train=True) 348 | grad_stylegan_fn = jax.value_and_grad(train_compute_stylegan_loss, has_aux=True) 349 | 350 | 351 | def train_step_disc(state_disc, batch, fake_images): 352 | dropout_rng, new_dropout_rng = jax.random.split(state_disc.dropout_rng) 353 | # convert fake images to int then back to float, so discriminator can't cheat 354 | dtype = fake_images.dtype 355 | fake_images = (fake_images.clip(0, 1) * 255).astype(jnp.uint8).astype(dtype) / 255 356 | (disc_loss, disc_loss_details), disc_grads = grad_stylegan_fn( 357 | state_disc.params, 358 | batch, 359 | fake_images, 360 | dropout_rng, 361 | disc_model, 362 | ) 363 | new_state = state_disc.apply_gradients( 364 | grads=disc_grads, dropout_rng=new_dropout_rng 365 | ) 366 | metrics = disc_loss_details | {"learning_rate_disc": schedule_fn(state_disc.step)} 367 | # metrics = jax.lax.pmean(metrics, axis_name="batch") 368 | return new_state, metrics 369 | 370 | 371 | # %% 372 | # Take the first 100 images as validation set 373 | train_ds = DecoderImageDataset(hfds.select(range(100, len(hfds))), root=data_root) 374 | test_ds = DecoderImageDataset(hfds.select(range(100)), root=data_root) 375 | # %% 376 | jit_train_step = jax.jit(train_step) 377 | jit_train_step_disc = jax.jit(train_step_disc) 378 | # %% 379 | def try_train_batch_size_fn(batch_size): 380 | example = train_ds[0] 381 | batch = train_ds.collate_fn([example] * batch_size) 382 | new_state, metrics, reconstruction = jit_train_step(state, batch, state_disc) 383 | new_state, metrics = jit_train_step_disc( 384 | state_disc, batch["original"], reconstruction 385 | ) 386 | return 387 | 388 | 389 | # this takes about 20 GB of memory, adjust batch size accordingly for your GPU 390 | train_batch_size = 8 391 | state = jax.device_put(state, jax.devices()[0]) 392 | train_batch_size = try_batch_size( 393 | try_train_batch_size_fn, start_batch_size=train_batch_size 394 | ) 395 | print(f"Training batch size: {train_batch_size}") 396 | # %% 397 | # %% 398 | # %% 399 | 
# try it again, make sure there is no error 400 | try_train_batch_size_fn(train_batch_size) 401 | print(f"Training batch size: {train_batch_size}") 402 | 403 | # %% 404 | wandb.log({"train_dataset_size": len(train_ds)}) 405 | # %% 406 | 407 | dataloader = DataLoader( 408 | train_ds, 409 | batch_size=train_batch_size, 410 | shuffle=True, 411 | collate_fn=partial(DecoderImageDataset.collate_fn, return_names=False), 412 | num_workers=4, 413 | drop_last=True, 414 | prefetch_factor=4, 415 | persistent_workers=True, 416 | ) 417 | 418 | # %% 419 | # recreate states, because we tried training them before 420 | state = TrainState.create( 421 | apply_fn=model.decode_code, 422 | params=jax.device_put(params), 423 | tx=optimizer, 424 | dropout_rng=dropout_rng, 425 | ) 426 | state_disc = TrainState.create( 427 | apply_fn=disc_model, 428 | params=jax.device_put(disc_model.params), 429 | tx=optimizer_disc, 430 | dropout_rng=dropout_rng, 431 | ) 432 | state = jax.device_put(state, jax.devices()[0]) 433 | state_disc = jax.device_put(state_disc, jax.devices()[0]) 434 | # %% 435 | # data loader without shuffle, so we can see the progress on the same images 436 | train_dl_eval = DataLoader( 437 | train_ds, 438 | batch_size=train_batch_size, 439 | shuffle=False, 440 | collate_fn=partial(DecoderImageDataset.collate_fn, return_names=True), 441 | num_workers=4, 442 | drop_last=True, 443 | prefetch_factor=4, 444 | persistent_workers=True, 445 | ) 446 | test_dl = DataLoader( 447 | test_ds, 448 | batch_size=train_batch_size, 449 | shuffle=False, 450 | collate_fn=partial(DecoderImageDataset.collate_fn, return_names=True), 451 | num_workers=4, 452 | drop_last=True, 453 | prefetch_factor=4, 454 | persistent_workers=True, 455 | ) 456 | # %% 457 | # evaluation functions 458 | @jax.jit 459 | def infer_fn(batch, state): 460 | original = batch["original"] 461 | reconstruction = reconstruct(original_params, state.params, original) 462 | return reconstruction 463 | 464 | 465 | def evaluate(use_tqdm=False, step=None): 466 | losses = [] 467 | iterable = test_dl if not use_tqdm else tqdm(test_dl) 468 | for batch in iterable: 469 | name = batch.pop("name") 470 | reconstruction = infer_fn(batch, state) 471 | losses.append(loss_fn(reconstruction, batch["original"]).mean()) 472 | loss = np.mean(jax.device_get(losses)) 473 | wandb.log({"test_loss": loss, "step": step}) 474 | 475 | 476 | def postpro(decoded_images): 477 | """util function to postprocess images""" 478 | decoded_images = decoded_images.clip(0.0, 1.0) # .reshape((-1, 256, 256, 3)) 479 | return [ 480 | Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8)) 481 | for decoded_img in decoded_images 482 | ] 483 | 484 | 485 | def log_images(dl, num_images=8, suffix="", step=None): 486 | logged_images = 0 487 | 488 | def batch_gen(): 489 | while True: 490 | for batch in dl: 491 | yield batch 492 | 493 | batch_iter = batch_gen() 494 | while logged_images < num_images: 495 | batch = next(batch_iter) 496 | 497 | names = batch.pop("name") 498 | reconstruction = infer_fn(batch, state) 499 | left_right = np.concatenate([batch["original"], reconstruction], axis=2) 500 | 501 | images = postpro(left_right) 502 | for name, image in zip(names, images): 503 | wandb.log( 504 | {f"{name}{suffix}": wandb.Image(image, caption=name), "step": step} 505 | ) 506 | logged_images += len(images) 507 | 508 | 509 | def log_test_images(num_images=8, step=None): 510 | return log_images(dl=test_dl, num_images=num_images, step=step) 511 | 512 | 513 | def log_train_images(num_images=8, step=None): 
514 | return log_images( 515 | dl=train_dl_eval, num_images=num_images, suffix="|train", step=step 516 | ) 517 | 518 | 519 | def data_iter(): 520 | while True: 521 | for batch in dataloader: 522 | yield batch 523 | 524 | 525 | # %% 526 | for steps, batch in zip(tqdm(range(total_steps)), data_iter()): 527 | state, metrics, reconstruction = jit_train_step(state, batch, state_disc) 528 | state_disc, metrics_disc = jit_train_step_disc( 529 | state_disc, batch["original"], reconstruction 530 | ) 531 | # metrics = metrics | metrics_disc 532 | metrics["disc_step"] = metrics_disc 533 | metrics["step"] = steps 534 | if steps % log_steps == 1: 535 | wandb.log(metrics) 536 | if steps % eval_steps == 1: 537 | evaluate(step=steps) 538 | log_test_images(step=steps) 539 | log_train_images(step=steps) 540 | with Path(output_dir / "latest_state_disc.msgpack").open("wb") as f: 541 | f.write(to_bytes(jax.device_get(state_disc))) 542 | with Path(output_dir / "latest_state.msgpack").open("wb") as f: 543 | f.write(to_bytes(jax.device_get(state))) 544 | 545 | # how to use the model 546 | 547 | """ 548 | # load the model to stable_diffusion_jax 549 | # https://github.com/patil-suraj/stable-diffusion-jax/tree/main/stable_diffusion_jax 550 | from stable_diffusion_jax.convertkk_diffusers_to_jax import convert_diffusers_to_jax 551 | from stable_diffusion_jax import AutoencoderKL 552 | from pathlib import Path 553 | pt_path = Path.home()/"models/stable-diffusion-v1-4" 554 | fx_path = Path.home()/"models/stable-diffusion-v1-4-jax" 555 | 556 | #convert_diffusers_to_jax(pt_path, fx_path) 557 | # %% 558 | # inference with jax 559 | dtype = jnp.bfloat16 560 | vae, vae_params = AutoencoderKL.from_pretrained(f"{fx_path}/vae", _do_init=False, dtype=dtype) 561 | 562 | # %% 563 | from flax.serialization import msgpack_restore 564 | 565 | weight_dir = Path('.') 566 | path = weight_dir/'latest_state.msgpack' 567 | with open(path, "rb") as f: 568 | state_dict = msgpack_restore(f.read()) 569 | state_dict.keys() 570 | # %% 571 | from copy import deepcopy 572 | 573 | new_params = deepcopy(vae_params) 574 | for k, v in state_dict['params'].items(): 575 | if k in new_params: 576 | new_params[k] = v 577 | vae.save_pretrained(f"{fx_path}/vae-anime", params=new_params) 578 | # after this, you can use the model in stable-diffusion-jax, as: 579 | # vae, vae_params = AutoencoderKL.from_pretrained(f"{fx_path}/vae-anime", _do_init=False, dtype=dtype) 580 | 581 | """ 582 | -------------------------------------------------------------------------------- /run_finetune_vqgan.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # credits: 3 | # Flax code is adapted from https://github.com/huggingface/transformers/blob/main/examples/flax/vision/run_image_classification.py 4 | # GAN related code are adapted from https://github.com/patil-suraj/vit-vqgan/ 5 | import inspect 6 | import os 7 | from functools import partial 8 | 9 | # os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5' 10 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 11 | 12 | from copy import deepcopy 13 | from pathlib import Path 14 | 15 | import flax.linen as nn 16 | import jax 17 | import jax.numpy as jnp 18 | import numpy as np 19 | import optax 20 | import torch 21 | import torchvision.transforms as T 22 | from datasets import Dataset as HFDataset 23 | from flax import jax_utils 24 | from flax.jax_utils import pad_shard_unpad, unreplicate 25 | from flax.serialization import from_bytes, to_bytes 26 | from flax.training import 
train_state 27 | from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key 28 | from lpips_j.lpips import LPIPS 29 | from PIL import Image 30 | from torch.utils.data import DataLoader 31 | from tqdm import tqdm 32 | from vit_vqgan import StyleGANDiscriminator, StyleGANDiscriminatorConfig 33 | from vqgan_jax.modeling_flax_vqgan import Decoder, VQModel 34 | 35 | import wandb 36 | 37 | # %% 38 | # paths and configs 39 | wandb.init(project="vqgan") 40 | 41 | learning_rate = 1e-4 42 | gradient_accumulation_steps = 1 43 | warmup_steps = 4000 * gradient_accumulation_steps 44 | log_steps = 1 * gradient_accumulation_steps 45 | eval_steps = 100 * gradient_accumulation_steps 46 | log_steps = 10 * gradient_accumulation_steps 47 | eval_steps = 100 * gradient_accumulation_steps 48 | total_steps = 150_000 * gradient_accumulation_steps 49 | 50 | model = VQModel.from_pretrained("dalle-mini/vqgan_imagenet_f16_16384") 51 | data_root = "/disks" 52 | # a huggingface dataset containing columns "path" and optionally "indices" 53 | # path: can be absolute or relative to `data_root` 54 | # indices: VQ indices of the image at `path` 55 | hfds = HFDataset.from_json("danbooru_image_paths_ds.json") 56 | 57 | # this corresponds to a local dir containing the config.json file 58 | # the config.json file is copied from https://github.com/patil-suraj/vit-vqgan/ 59 | disc_config_path = "configs/vqgan/discriminator/config.json" 60 | 61 | output_dir = Path("output-dir") 62 | output_dir.mkdir(exist_ok=True) 63 | 64 | # the empereically observed values from initial runs, we will scale them closer to the scale of l2 loss 65 | scale_l2 = 0.001 66 | scale_lpips = 0.25 67 | # adjust scale to make the loss comparable to l2 loss 68 | cost_l2 = 0.5 69 | cost_lpips = scale_l2 / scale_lpips * 5 70 | cost_gradient_penalty = 100000000 # this follows vit-vqgan repo 71 | cost_disc = 0.005 72 | 73 | # %% 74 | class EncoderImageDataset(torch.utils.data.Dataset): 75 | # this class was originally used to preprocess images into VQ indices 76 | # now we only use its load() method, and preprocess images into VQ indices on the fly 77 | def __init__(self, df, shape=(256, 256)): 78 | self.df = df 79 | self.shape = shape 80 | 81 | def __len__(self): 82 | return len(self.df) 83 | 84 | @staticmethod 85 | def load(path): 86 | img = Image.open(path).convert("RGB").resize((256, 256)) 87 | img = torch.unsqueeze(T.ToTensor()(img), 0) 88 | return img.permute(0, 2, 3, 1).numpy() 89 | 90 | def __getitem__(self, idx): 91 | row = self.df.iloc[idx] 92 | path = row["resized_path"] 93 | return self.load(path) 94 | 95 | 96 | class DecoderImageDataset(torch.utils.data.Dataset): 97 | def __init__(self, hfds, root=None): 98 | """hdfs: HFDataset""" 99 | self.hfds = hfds 100 | self.root = root 101 | 102 | def __len__(self): 103 | return len(self.hfds) 104 | 105 | def __getitem__(self, idx): 106 | example = self.hfds[idx] 107 | # indices = example["indices"] 108 | path = example["path"] 109 | if self.root is not None: 110 | path = os.path.join(self.root, path.lstrip("/")) 111 | orig_arr = EncoderImageDataset.load(path) 112 | return { 113 | # "indices": indices, 114 | "original": orig_arr, 115 | "name": Path(path).name, 116 | } 117 | 118 | @staticmethod 119 | def collate_fn(examples, return_names=False): 120 | res = { 121 | # "indices": [example["indices"] for example in examples], 122 | "original": np.concatenate( 123 | [example["original"] for example in examples], axis=0 124 | ), 125 | } 126 | if return_names: 127 | res["name"] = [example["name"] 
for example in examples] 128 | return res 129 | 130 | 131 | def try_batch_size(fn, start_batch_size=1): 132 | # try batch size 133 | batch_size = start_batch_size 134 | while True: 135 | try: 136 | print(f"Trying batch size {batch_size}") 137 | fn(batch_size * 2) 138 | batch_size *= 2 139 | except Exception as e: 140 | return batch_size 141 | 142 | 143 | def get_param_counts(params): 144 | param_counts = [k.size for k in jax.tree_util.tree_leaves(params)] 145 | param_counts = sum(param_counts) 146 | return param_counts 147 | 148 | 149 | def get_training_params(): 150 | keys = ["decoder", "post_quant_conv", "quantize"] 151 | decoder_params = {k: v for k, v in original_vqparams.items() if k in keys} 152 | return deepcopy(decoder_params) 153 | 154 | 155 | # %% 156 | original_vqparams = deepcopy(model.params) 157 | for k, v in original_vqparams.items(): 158 | print(k, get_param_counts(v) / 1e6) 159 | # %% 160 | disc_config = StyleGANDiscriminatorConfig.from_pretrained(disc_config_path) 161 | disc_model = StyleGANDiscriminator( 162 | disc_config, 163 | seed=42, 164 | _do_init=True, 165 | ) 166 | lpips_fn = LPIPS() 167 | 168 | 169 | def init_lpips(rng, image_size): 170 | x = jax.random.normal(rng, shape=(1, image_size, image_size, 3)) 171 | return lpips_fn.init(rng, x, x) 172 | 173 | 174 | # %% 175 | 176 | # encoder_params = {k: v for k, v in params.items() if k not in keys} 177 | rng = jax.random.PRNGKey(0) 178 | rng, dropout_rng = jax.random.split(rng) 179 | 180 | lpips_params = init_lpips(rng, image_size=256) 181 | params = get_training_params() 182 | 183 | warmup_fn = optax.linear_schedule( 184 | init_value=0.0, 185 | end_value=learning_rate, 186 | transition_steps=warmup_steps + 1, # ensure not 0 187 | ) 188 | decay_fn = optax.linear_schedule( 189 | init_value=learning_rate, 190 | end_value=0, 191 | transition_steps=total_steps - warmup_steps, 192 | ) 193 | schedule_fn = optax.join_schedules( 194 | schedules=[warmup_fn, decay_fn], 195 | boundaries=[warmup_steps], 196 | ) 197 | 198 | optimizer = optax.adamw(learning_rate=schedule_fn) 199 | # discriminator_optimizer 200 | optimizer_disc = optax.adamw(learning_rate=schedule_fn) 201 | # gradient accumulation for main optimizer 202 | optimizer = optax.MultiSteps(optimizer, gradient_accumulation_steps) 203 | 204 | # Setup train state 205 | class TrainState(train_state.TrainState): 206 | dropout_rng: jnp.ndarray 207 | 208 | def replicate(self): 209 | return jax_utils.replicate(self).replace( 210 | dropout_rng=shard_prng_key(self.dropout_rng) 211 | ) 212 | 213 | 214 | state = TrainState.create( 215 | apply_fn=model.decode_code, 216 | params=jax.device_put(params), 217 | tx=optimizer, 218 | dropout_rng=dropout_rng, 219 | ) 220 | state_disc = TrainState.create( 221 | apply_fn=disc_model, 222 | params=jax.device_put(disc_model.params), 223 | tx=optimizer_disc, 224 | dropout_rng=dropout_rng, 225 | ) 226 | 227 | loss_fn = optax.l2_loss 228 | 229 | 230 | def train_step(state, batch, state_disc): 231 | """Returns new_state, metrics, reconstruction""" 232 | dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) 233 | 234 | def compute_loss(params, batch, dropout_rng, train=True): 235 | original = batch["original"] 236 | if "indices" in batch: 237 | indices = batch["indices"] 238 | else: 239 | quant_states, indices = model.encode(original) 240 | reconstruction = state.apply_fn(indices, params=params) 241 | loss_l2 = loss_fn(reconstruction, original).mean() 242 | disc_fake_scores = state_disc.apply_fn( 243 | reconstruction, 244 | 
params=state_disc.params, 245 | dropout_rng=dropout_rng, 246 | train=train, 247 | ) 248 | loss_disc = jnp.mean(nn.softplus(-disc_fake_scores)) 249 | loss_lpips = jnp.mean(lpips_fn.apply(lpips_params, original, reconstruction)) 250 | 251 | loss = loss_l2 * cost_l2 + loss_lpips * cost_lpips + loss_disc * cost_disc 252 | loss_details = { 253 | "loss_l2": loss_l2 * cost_l2, 254 | "loss_lpips": loss_lpips * cost_lpips, 255 | "loss_disc": loss_disc * cost_disc, 256 | } 257 | return loss, (loss_details, reconstruction) 258 | 259 | grad_fn = jax.value_and_grad(compute_loss, has_aux=True) 260 | (loss, (loss_details, reconstruction)), grad = grad_fn( 261 | state.params, batch, dropout_rng, train=True 262 | ) 263 | # legacy code, I didn't use multi gpu 264 | # grad = jax.lax.pmean(grad, "batch") 265 | 266 | new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) 267 | 268 | metrics = loss_details | {"learning_rate": schedule_fn(state.step)} 269 | # metrics = jax.lax.pmean(metrics, axis_name="batch") 270 | return new_state, metrics, reconstruction 271 | 272 | 273 | # %% 274 | def compute_stylegan_loss( 275 | disc_params, batch, fake_images, dropout_rng, disc_model_fn, train 276 | ): 277 | disc_fake_scores = disc_model_fn( 278 | fake_images, params=disc_params, dropout_rng=dropout_rng, train=train 279 | ) 280 | disc_real_scores = disc_model_fn( 281 | batch, params=disc_params, dropout_rng=dropout_rng, train=train 282 | ) 283 | # -log sigmoid(f(x)) = log (1 + exp(-f(x))) = softplus(-f(x)) 284 | # -log(1-sigmoid(f(x))) = log (1 + exp(f(x))) = softplus(f(x)) 285 | # https://github.com/pfnet-research/sngan_projection/issues/18#issuecomment-392683263 286 | loss_real = nn.softplus(-disc_real_scores) 287 | loss_fake = nn.softplus(disc_fake_scores) 288 | disc_loss_stylegan = jnp.mean(loss_real + loss_fake) 289 | 290 | # gradient penalty r1: https://github.com/NVlabs/stylegan2/blob/bf0fe0baba9fc7039eae0cac575c1778be1ce3e3/training/loss.py#L63-L67 291 | r1_grads = jax.grad( 292 | lambda x: jnp.mean( 293 | disc_model_fn(x, params=disc_params, dropout_rng=dropout_rng, train=train) 294 | ) 295 | )(batch) 296 | # get the squares of gradients 297 | r1_grads = jnp.mean(r1_grads**2) 298 | 299 | disc_loss = disc_loss_stylegan + cost_gradient_penalty * r1_grads 300 | disc_loss_details = { 301 | "pred_p_real": jnp.exp(-loss_real).mean(), # p = 1 -> predict real is real 302 | "pred_p_fake": jnp.exp(-loss_fake).mean(), # p = 1 -> predict fake is fake 303 | "loss_real": loss_real.mean(), 304 | "loss_fake": loss_fake.mean(), 305 | "loss_stylegan": disc_loss_stylegan, 306 | "loss_gradient_penalty": cost_gradient_penalty * r1_grads, 307 | "loss": disc_loss, 308 | } 309 | return disc_loss, disc_loss_details 310 | 311 | 312 | train_compute_stylegan_loss = partial(compute_stylegan_loss, train=True) 313 | grad_stylegan_fn = jax.value_and_grad(train_compute_stylegan_loss, has_aux=True) 314 | 315 | 316 | def train_step_disc(state_disc, batch, fake_images): 317 | dropout_rng, new_dropout_rng = jax.random.split(state_disc.dropout_rng) 318 | # convert fake images to int then back to float, so discriminator can't cheat 319 | dtype = fake_images.dtype 320 | fake_images = (fake_images.clip(0, 1) * 255).astype(jnp.uint8).astype(dtype) / 255 321 | (disc_loss, disc_loss_details), disc_grads = grad_stylegan_fn( 322 | state_disc.params, 323 | batch, 324 | fake_images, 325 | dropout_rng, 326 | disc_model, 327 | ) 328 | new_state = state_disc.apply_gradients( 329 | grads=disc_grads, dropout_rng=new_dropout_rng 330 | ) 331 | 
metrics = disc_loss_details | {"learning_rate_disc": schedule_fn(state_disc.step)} 332 | # metrics = jax.lax.pmean(metrics, axis_name="batch") 333 | return new_state, metrics 334 | 335 | 336 | # %% 337 | # Take the first 100 images as validation set 338 | train_ds = DecoderImageDataset(hfds.select(range(100, len(hfds))), root=data_root) 339 | test_ds = DecoderImageDataset(hfds.select(range(100)), root=data_root) 340 | # %% 341 | jit_train_step = jax.jit(train_step) 342 | jit_train_step_disc = jax.jit(train_step_disc) 343 | # %% 344 | def try_train_batch_size_fn(batch_size): 345 | example = train_ds[0] 346 | batch = train_ds.collate_fn([example] * batch_size) 347 | new_state, metrics, reconstruction = jit_train_step(state, batch, state_disc) 348 | new_state, metrics = jit_train_step_disc( 349 | state_disc, batch["original"], reconstruction 350 | ) 351 | return 352 | 353 | 354 | # this takes about 20 GB of memory, adjust batch size accordingly for your GPU 355 | train_batch_size = 8 356 | train_batch_size = try_batch_size( 357 | try_train_batch_size_fn, start_batch_size=train_batch_size 358 | ) 359 | print(f"Training batch size: {train_batch_size}") 360 | # %% 361 | # try it again, make sure there is no error 362 | try_train_batch_size_fn(train_batch_size) 363 | print(f"Training batch size: {train_batch_size}") 364 | 365 | # %% 366 | def reconstruct_from_original(batch): 367 | original = batch["original"] 368 | if "indices" in batch: 369 | indices = batch["indices"] 370 | else: 371 | quant_states, indices = model.encode(original) 372 | reconstruction = state.apply_fn(indices) 373 | return reconstruction 374 | 375 | 376 | # %% 377 | wandb.log({"train_dataset_size": len(train_ds)}) 378 | # %% 379 | 380 | dataloader = DataLoader( 381 | train_ds, 382 | batch_size=train_batch_size, 383 | shuffle=True, 384 | collate_fn=partial(DecoderImageDataset.collate_fn, return_names=False), 385 | num_workers=4, 386 | drop_last=True, 387 | prefetch_factor=4, 388 | persistent_workers=True, 389 | ) 390 | 391 | # %% 392 | state = TrainState.create( 393 | apply_fn=model.decode_code, 394 | params=jax.device_put(params), 395 | tx=optimizer, 396 | dropout_rng=dropout_rng, 397 | ) 398 | state_disc = TrainState.create( 399 | apply_fn=disc_model, 400 | params=jax.device_put(disc_model.params), 401 | tx=optimizer_disc, 402 | dropout_rng=dropout_rng, 403 | ) 404 | # %% 405 | # data loader without shuffle, so we can see the progress on the same images 406 | train_dl_eval = DataLoader( 407 | train_ds, 408 | batch_size=train_batch_size, 409 | shuffle=False, 410 | collate_fn=partial(DecoderImageDataset.collate_fn, return_names=True), 411 | num_workers=4, 412 | drop_last=True, 413 | prefetch_factor=4, 414 | persistent_workers=True, 415 | ) 416 | test_dl = DataLoader( 417 | test_ds, 418 | batch_size=train_batch_size, 419 | shuffle=False, 420 | collate_fn=partial(DecoderImageDataset.collate_fn, return_names=True), 421 | num_workers=4, 422 | drop_last=True, 423 | prefetch_factor=4, 424 | persistent_workers=True, 425 | ) 426 | # %% 427 | # evaluation functions 428 | @jax.jit 429 | def infer_fn(batch, state): 430 | original = batch["original"] 431 | if "indices" in batch: 432 | indices = batch["indices"] 433 | else: 434 | quant_states, indices = model.encode(original) 435 | reconstruction = state.apply_fn(indices, params=state.params) 436 | return reconstruction 437 | 438 | 439 | def evaluate(use_tqdm=False, step=None): 440 | losses = [] 441 | iterable = test_dl if not use_tqdm else tqdm(test_dl) 442 | for batch in iterable: 443 | 
name = batch.pop("name") 444 | reconstruction = infer_fn(batch, state) 445 | losses.append(loss_fn(reconstruction, batch["original"]).mean()) 446 | loss = np.mean(jax.device_get(losses)) 447 | wandb.log({"test_loss": loss, "step": step}) 448 | 449 | 450 | def postpro(decoded_images): 451 | """util function to postprocess images""" 452 | decoded_images = decoded_images.clip(0.0, 1.0) # .reshape((-1, 256, 256, 3)) 453 | return [ 454 | Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8)) 455 | for decoded_img in decoded_images 456 | ] 457 | 458 | 459 | def log_images(dl, num_images=8, suffix="", step=None): 460 | logged_images = 0 461 | 462 | def batch_gen(): 463 | while True: 464 | for batch in dl: 465 | yield batch 466 | 467 | batch_iter = batch_gen() 468 | while logged_images < num_images: 469 | batch = next(batch_iter) 470 | 471 | names = batch.pop("name") 472 | reconstruction = infer_fn(batch, state) 473 | left_right = np.concatenate([batch["original"], reconstruction], axis=2) 474 | 475 | images = postpro(left_right) 476 | for name, image in zip(names, images): 477 | wandb.log( 478 | {f"{name}{suffix}": wandb.Image(image, caption=name), "step": step} 479 | ) 480 | logged_images += len(images) 481 | 482 | 483 | def log_test_images(num_images=8, step=None): 484 | return log_images(dl=test_dl, num_images=num_images, step=step) 485 | 486 | 487 | def log_train_images(num_images=8, step=None): 488 | return log_images( 489 | dl=train_dl_eval, num_images=num_images, suffix="|train", step=step 490 | ) 491 | 492 | 493 | def data_iter(): 494 | while True: 495 | for batch in dataloader: 496 | yield batch 497 | 498 | 499 | # %% 500 | for steps, batch in zip(tqdm(range(total_steps)), data_iter()): 501 | state, metrics, reconstruction = jit_train_step(state, batch, state_disc) 502 | state_disc, metrics_disc = jit_train_step_disc( 503 | state_disc, batch["original"], reconstruction 504 | ) 505 | # metrics = metrics | metrics_disc 506 | metrics["disc_step"] = metrics_disc 507 | metrics["step"] = steps 508 | if steps % log_steps == 1: 509 | wandb.log(metrics) 510 | if steps % eval_steps == 1: 511 | evaluate(step=steps) 512 | log_test_images(step=steps) 513 | log_train_images(step=steps) 514 | with Path(output_dir / "latest_state_disc.msgpack").open("wb") as f: 515 | f.write(to_bytes(jax.device_get(state_disc))) 516 | with Path(output_dir / "latest_state.msgpack").open("wb") as f: 517 | f.write(to_bytes(jax.device_get(state))) 518 | 519 | # how to use the model 520 | 521 | ## you can load .msgpack to dictionary with: 522 | """ 523 | from flax.serialization import msgpack_restore 524 | with open("{path}.msgpack", "rb") as f: 525 | x = f.read() 526 | state_dict = msgpack_restore(x) 527 | """ 528 | ## you can then replace the original decoder with the trained one, and save it 529 | """ 530 | for k,v in state_dict['params'].items(): 531 | vqgan_params[k] = v 532 | vqgan.save_pretrained("vqgan", params=vqgan_params) 533 | """ 534 | ## you can convert it to min-dalle (pytorch) format with: 535 | # https://github.com/kuprel/min-dalle-flax/blob/main/min_dalle_flax/load_params.py 536 | # it's mentioned in the readme of the repo https://github.com/kuprel/min-dalle 537 | --------------------------------------------------------------------------------
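As a follow-up to the usage notes at the end of run_finetune_vqgan.py, here is a minimal sketch of reconstructing an image with the fine-tuned decoder weights. It assumes the checkpoint written to `output-dir/latest_state.msgpack` by the training loop above keeps the `params` layout shown in the script (the `decoder` / `post_quant_conv` / `quantize` groups); `example.png` is a hypothetical input image.

```python
# Minimal sketch, assuming the checkpoint layout written by run_finetune_vqgan.py
# and a hypothetical input image "example.png".
import numpy as np
from PIL import Image
from flax.serialization import msgpack_restore
from vqgan_jax.modeling_flax_vqgan import VQModel

# original VQGAN (encoder is unchanged by the fine-tuning)
vqgan = VQModel.from_pretrained("dalle-mini/vqgan_imagenet_f16_16384")

# restore the fine-tuned decoder-side parameter groups from the saved TrainState
with open("output-dir/latest_state.msgpack", "rb") as f:
    state_dict = msgpack_restore(f.read())
finetuned_params = state_dict["params"]

# encode with the original encoder, decode the indices with the fine-tuned decoder
img = Image.open("example.png").convert("RGB").resize((256, 256))
x = np.asarray(img, dtype=np.float32)[None] / 255.0  # (1, 256, 256, 3), values in [0, 1]
_, indices = vqgan.encode(x)
reconstruction = vqgan.decode_code(indices, params=finetuned_params)

# same post-processing as postpro() in the training script
decoded = reconstruction[0].clip(0.0, 1.0)
Image.fromarray(np.asarray(decoded * 255, dtype=np.uint8)).save("example_reconstructed.png")
```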