├── .gitignore
├── .markdownlint.json
├── README.md
├── architecture.py
├── args.py
├── batch.py
├── dataset.py
├── dataset.sh
├── docs
│   ├── POSTMORTEM.md
│   └── ROADMAP.md
├── inference_jax.py
├── inference_jax.sh
├── loss.py
├── main.py
├── monitoring.py
├── optimizer.py
├── repository.py
├── repository.sh
├── requirements.txt
├── training.sh
├── training_loop.py
├── training_step.py
├── validation.py
└── validation.sh


/.gitignore:
--------------------------------------------------------------------------------
__pycache__
.cache
dataset-cache
wandb
Pipfile.lock
Pipfile
.DS_Store

--------------------------------------------------------------------------------
/.markdownlint.json:
--------------------------------------------------------------------------------
{
    "no-inline-html": {
        "allowed_elements": [ "h1", "img" ]
    },
    "line-length": {
        "line_length": 2048
    }
}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CHARacter-awaRE Diffusion: Multilingual Character-Aware Encoders for Font-Aware Diffusers That Can Actually Spell

Tired of text-to-image models that can't spell or deal with fonts and typography correctly? [The secret seems to be in the use of multilingual, tokenization-free, character-aware transformer encoders](https://arxiv.org/abs/2212.10562) such as [ByT5](https://arxiv.org/abs/2105.13626) and [CANINE-c](https://arxiv.org/abs/2103.06874).

## Replace CLIP with ByT5 in HF's `text-to-image` Pipeline

As part of the [Hugging Face JAX Diffuser Sprint](https://github.com/huggingface/community-events/tree/main/jax-controlnet-sprint), we will replace [CLIP](https://arxiv.org/abs/2103.00020)'s tokenizer and encoder with [ByT5](https://arxiv.org/abs/2105.13626)'s in [HF's JAX/FLAX text-to-image pre-training code](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py) and run it on the sponsored TPU resources provided by Google for the event.

More specifically, here are the main tasks we will try to accomplish during the sprint:

- Pre-training dataset preparation: we are NOT going to train on `lambdalabs/pokemon-blip-captions`. So what is it going to be, and what are the options? Does [anything in here](https://analyticsindiamag.com/top-used-datasets-for-text-to-image-synthesis-models/) or [here](https://github.com/Yutong-Zhou-cv/Awesome-Text-to-Image#head3) take your fancy? Or maybe [DiffusionDB](https://poloclub.github.io/diffusiondb/)? Or a careful mix of many datasets? We will probably need to combine many datasets, as we are looking to cover these requirements:
  - We need samples for which there is text in the scene that is explicitly specified in the caption, with priority given to full-scene photos. If we can't find enough, we will integrate more specialized datasets for OCR;
  - Approximately the same language distribution as ByT5, but also including Indonesian (not in ByT5) to see how character-awareness behaves when the text in the prompt is written in that language. We need to build testing facilities around the languages spoken by team members and friends: Indonesian, Japanese, French, Amharic, Arabic, Norwegian, Swedish, Hindi, Urdu and English.

  We should use the [Hugging Face Datasets library](https://huggingface.co/docs/datasets) as much as possible since it [supports JAX out of the box](https://huggingface.co/docs/datasets/en/use_with_jax). For simplicity's sake, we will limit ourselves to [concatenated](https://huggingface.co/docs/datasets/en/process#concatenate) Hugging Face datasets such as LAION2B [EN](https://huggingface.co/datasets/laion/laion2B-en), [MULTI](https://huggingface.co/datasets/laion/laion2B-multi) and [NOLANG](https://huggingface.co/datasets/laion/laion1B-nolang). We shall, however, [pre-load](https://huggingface.co/docs/datasets/en/loading), [pre-process](https://huggingface.co/docs/datasets/en/process) and [cache](https://huggingface.co/docs/datasets/en/cache) the dataset on disk before training on it.
- Improvements to the [original code](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py):
  - ~~Make sure we can run the original code as-is on the TPU VM.~~
  - Audit and optimize the code for the Google Cloud TPU v4-8 VM: [`jnp`](https://jax.readthedocs.io/en/latest/jax.numpy.html) (instead of `np`), [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), [`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html), [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) and [`pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html) everywhere! We should also make sure we do not miss any [optimization made in the sprint code](https://github.com/huggingface/community-events/blob/main/jax-controlnet-sprint/training_scripts/train_controlnet_flax.py).
  - Instrumentation for TPU remote monitoring with [OpenTelemetry](https://opentelemetry.io/docs/instrumentation/python/), [TensorBoard](https://www.tensorflow.org/tensorboard/), [Perfetto](https://perfetto.dev), [Weights & Biases](https://wandb.ai) and [JAX's own profiler](https://jax.readthedocs.io/en/latest/profiling.html).
  - Implement checkpoint milestone snapshot uploading to cloud storage: we need to be able to download the model for local inference benchmarking to make sure we are on the right track. There seems to be [rudimentary checkpoint support in the original code](https://huggingface.co/docs/diffusers/training/text2image#save-and-load-checkpoints).
  - ~~No time for politics. NSFW filtering will be turned off. So we get `FlaxStableDiffusionSafetyChecker` out of the way.~~
- Replace CLIP with ByT5 in the [original code](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py):
  - ~~Replacing `CLIPTokenizer` with `ByT5Tokenizer`. Since this will run on the CPUs, there is no need for JAX/FLAX unless there is hope for huge performance improvements. This should be trivial.~~ Merged. Needs testing.
  - ~~Replacing `FlaxCLIPTextModel` with `FlaxT5EncoderModel`. This *might* be almost as easy as replacing the tokenizer.~~ Merged. Needs testing.
  - ~~Rewrite `CLIPImageProcessor` for ByT5. This is still under investigation. It's unclear how hard it will be.~~ Done. Needs testing.
  - ~~Adapt `FlaxAutoencoderKL` and `FlaxUNet2DConditionModel` for ByT5 if necessary.~~ Done. Needs testing.
  - ~~Break down the main pretraining loop into many functions in different source files for readability and easier maintenance.~~

## Introducing a Calligraphic & Typographic ControlNet

Secondly, we will integrate into the above a [Hugging Face JAX/FLAX ControlNet implementation](https://github.com/huggingface/diffusers/tree/main/examples/controlnet) for better typographic control over the generated images. To the orthographically-enhanced SD above, and as per [Patrick von Platen](https://github.com/patrickvonplaten)'s suggestion, we also introduce the idea of a typographic [ControlNet](https://arxiv.org/abs/2302.05543) trained on a synthetic dataset of images paired with multilingual specifications of the textual content, font taxonomy, weight, kerning, leading, slant and any other typographic attribute supported by the [CSS3](https://www.w3.org/Style/CSS/) [Text](https://www.w3.org/TR/css-text-3/), [Fonts](https://www.w3.org/TR/css-fonts-3) and [Writing Modes](https://www.w3.org/TR/css-writing-modes-3/) modules, as implemented by the latest version of [Chromium](https://www.chromium.org/Home/).
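
To help with the tokenizer swap that still "Needs testing", here is a minimal sketch of the byte-level ids we expect `ByT5Tokenizer` to produce, assuming ByT5's standard id layout (0 = padding, 1 = end-of-sequence, 2 = unknown, and every UTF-8 byte `b` mapped to `b + 3`). The `byt5_encode` and `byt5_decode` helpers are hypothetical and not part of this repository; truncation and padding mirror the `max_length`, `padding="max_length"` and `truncation=True` arguments used in `dataset.py`, so the output can be compared against `ByT5Tokenizer()(text, max_length=..., padding="max_length", truncation=True).input_ids`.

```python
import numpy as np

# Assumed ByT5 id layout: 0 = pad, 1 = end-of-sequence, 2 = unknown,
# and every UTF-8 byte value b (0..255) is mapped to the id b + 3.
PAD_ID, EOS_ID, BYTE_OFFSET = 0, 1, 3


def byt5_encode(text: str, max_length: int = 1024) -> np.ndarray:
    """Hypothetical helper: byte-level ids, truncated and padded to max_length."""
    ids = [byte + BYTE_OFFSET for byte in text.encode("utf-8")]
    # Truncate first so that the end-of-sequence id always fits.
    ids = ids[: max_length - 1] + [EOS_ID]
    ids += [PAD_ID] * (max_length - len(ids))
    # Flax T5 encoders take int32 ("i4") input ids.
    return np.asarray(ids, dtype=np.int32)


def byt5_decode(ids: np.ndarray) -> str:
    """Hypothetical helper: drop special ids, undo the offset, decode the bytes."""
    byte_values = [
        int(i) - BYTE_OFFSET
        for i in ids
        if BYTE_OFFSET <= i < BYTE_OFFSET + 256
    ]
    # A multi-byte character cut by truncation is incomplete; skip it.
    return bytes(byte_values).decode("utf-8", errors="ignore")


if __name__ == "__main__":
    ids = byt5_encode("Café", max_length=8)
    print(ids.tolist())  # [70, 100, 105, 198, 172, 1, 0, 0]
    print(byt5_decode(ids))  # Café
```

Note how the single character `é` already takes two ids. Japanese, Hindi and Amharic characters are encoded as three UTF-8 bytes each, so prompts in those languages consume three ids per character, which is worth keeping in mind when choosing `max_length`.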

--------------------------------------------------------------------------------
/architecture.py:
--------------------------------------------------------------------------------
import os

import jax.numpy as jnp

from transformers import FlaxT5ForConditionalGeneration, set_seed

from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel


def setup_model(
    seed,
    load_pretrained,
    output_dir,
    training_from_scratch_rng_params,
):
    set_seed(seed)

    # Load models and create wrapper for stable diffusion

    language_model = FlaxT5ForConditionalGeneration.from_pretrained(
        "/data/byt5-base",
        dtype=jnp.float32,
    )

    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
        "/data/stable-diffusion-2-1-vae",
        dtype=jnp.float32,
    )

    if load_pretrained:
        if os.path.isdir(output_dir):
            # find latest epoch output (assumes sub-directory names sort lexicographically by epoch)
            # list.sort() returns None, so use sorted() and keep the full path
            pretrained_dir = [
                os.path.join(output_dir, entry)
                for entry in sorted(os.listdir(output_dir), reverse=True)
                if os.path.isdir(os.path.join(output_dir, entry))
            ][0]
        else:
            pretrained_dir = output_dir

        unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
            pretrained_dir,
            dtype=jnp.float32,
        )

        print("loaded unet from pre-trained...")
    else:
        unet = FlaxUNet2DConditionModel.from_config(
            config={
                "_diffusers_version": "0.16.0",
                "attention_head_dim": [5, 10, 20, 20],
                "block_out_channels": [320, 640, 1280, 1280],
                # matches the hidden size (d_model) of the ByT5-base encoder
                "cross_attention_dim": 1536,
                "down_block_types": [
                    "CrossAttnDownBlock2D",
                    "CrossAttnDownBlock2D",
                    "CrossAttnDownBlock2D",
                    "DownBlock2D",
                ],
                "dropout": 0.0,
                "flip_sin_to_cos": True,
                "freq_shift": 0,
                "in_channels": 4,
                "layers_per_block": 2,
                "only_cross_attention": False,
                "out_channels": 4,
                "sample_size": 64,
                "up_block_types": [
                    "UpBlock2D",
                    "CrossAttnUpBlock2D",
"CrossAttnUpBlock2D", 72 | "CrossAttnUpBlock2D", 73 | ], 74 | "use_linear_projection": True, 75 | }, 76 | dtype=jnp.float32, 77 | ) 78 | unet_params = unet.init_weights(rng=training_from_scratch_rng_params) 79 | print("training unet from scratch...") 80 | 81 | return ( 82 | language_model.encode, 83 | language_model.params, 84 | vae, 85 | vae_params, 86 | unet, 87 | unet_params, 88 | ) 89 | 90 | 91 | if __name__ == "__main__": 92 | language_model = FlaxT5ForConditionalGeneration.from_pretrained( 93 | "google/byt5-base", 94 | dtype=jnp.float32, 95 | ) 96 | 97 | language_model.save_pretrained("/data/byt5-base") 98 | 99 | vae, vae_params = FlaxAutoencoderKL.from_pretrained( 100 | "flax/stable-diffusion-2-1", 101 | subfolder="vae", 102 | dtype=jnp.float32, 103 | ) 104 | 105 | vae.save_pretrained("/data/stable-diffusion-2-1-vae", params=vae_params) 106 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def parse_args(): 5 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 6 | parser.add_argument( 7 | "--output_dir", 8 | type=str, 9 | default="/data/output", 10 | help="The output directory where the model predictions and checkpoints will be written.", 11 | ) 12 | parser.add_argument( 13 | "--cache_dir", 14 | type=str, 15 | default="/data/dataset/cache", 16 | help="The directory where the downloaded models and datasets will be stored.", 17 | ) 18 | parser.add_argument( 19 | "--seed", type=int, default=0, help="A seed for reproducible training." 
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=1024,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=1_000_000)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=1_048_577,  # 33_554_432, 67_108_864, 134_217_728, 1_073_741_824, 2_147_483_648, 17_179_869_184
        help="Total number of training steps per epoch to perform.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--log_wandb",
        # argparse's type=bool turns any non-empty string (including "False") into True,
        # so parse the string explicitly instead
        type=lambda value: str(value).strip().lower() in ("1", "true", "yes"),
        default=True,
        choices=[True, False],
        help=("Whether to use WandB to log the metrics or not"),
    )

    args = parser.parse_args()

    return args

--------------------------------------------------------------------------------
/batch.py:
--------------------------------------------------------------------------------
import torch


def setup_dataloader(dataset, batch_size):
    def _collate(samples):
        # TODO: replace torch.stack with https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.stack.html

        pixel_values = (
            torch.stack([sample["pixel_values"] for sample in samples])
            .to(memory_format=torch.contiguous_format)
            .float()
            .numpy()
        )

        input_ids = (
            torch.stack([sample["input_ids"] for sample in samples])
            .to(memory_format=torch.contiguous_format)
            .numpy()
        )

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
        }

    return torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        collate_fn=_collate,
        batch_size=batch_size,
        num_workers=4,
        drop_last=True,
    )

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
from datasets import load_dataset
import os
from PIL import Image
import requests

from torchvision import transforms

from transformers import ByT5Tokenizer


def _prefilter(sample):
    image_url = sample["URL"]
    caption = sample["TEXT"]
    watermark_probability = sample["pwatermark"]
    unsafe_probability = sample["punsafe"]
    image_hash = sample["hash"]

    return (
        caption is not None
        and isinstance(caption, str)
        and image_url is not None
        and isinstance(image_url, str)
        and watermark_probability is not None
        and watermark_probability < 0.6
        and unsafe_probability is not None
        and unsafe_probability < 1.0
        and image_hash is not None
    )


def _download_image(sample):
    is_ok = False

    image_url = sample["URL"]

    cached_image_image_file_path = os.path.join(
        "/data/image-cache", "%s.jpg" % hex(sample["hash"])
    )

    if os.path.isfile(cached_image_image_file_path):
        pass
    else:
        try:
            # get image data from url
            image_bytes = requests.get(image_url, stream=True, timeout=5).raw

            if image_bytes is not None:
                pil_image = Image.open(image_bytes)

                if pil_image.mode == "RGB":
                    pil_rgb_image = pil_image

                else:
                    # Deal with non RGB images by compositing them over an opaque white background
                    if pil_image.mode == "RGBA":
                        pil_rgba_image = pil_image
                    else:
                        pil_rgba_image = pil_image.convert("RGBA")

                    pil_rgb_image = Image.alpha_composite(
                        Image.new("RGBA", pil_image.size, (255, 255, 255, 255)),
                        pil_rgba_image,
                    ).convert("RGB")

                is_ok = True

        except Exception:
            # leave an empty marker file so that this URL is skipped next time
            with open(cached_image_image_file_path, mode="a"):
                pass

        # save image to disk but do not catch exception. this has to fail because otherwise the mapper will run forever
        if is_ok:
            pil_rgb_image.save(cached_image_image_file_path)

    return is_ok


def _filter_out_unprocessed(sample):
    cached_image_image_file_path = os.path.join(
        "/data/image-cache", "%s.jpg" % hex(sample["hash"])
    )

    if (
        os.path.isfile(cached_image_image_file_path)
        and os.stat(cached_image_image_file_path).st_size > 0
    ):
        try:
            Image.open(cached_image_image_file_path)

            return True

        except Exception:
            pass

    return False


def get_compute_intermediate_values_lambda():
    tokenizer = ByT5Tokenizer()

    image_transforms = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
        ]
    )

    def __get_pixel_values(image_hash):
        # compute file name
        cached_image_image_file_path = os.path.join(
            "/data/image-cache", "%s.jpg" % hex(image_hash)
        )

        # get image data from cache
        pil_rgb_image = Image.open(cached_image_image_file_path)

        transformed_image = image_transforms(pil_rgb_image)

        return transformed_image

    def __compute_intermediate_values_lambda(samples):
        samples["input_ids"] = tokenizer(
            text=samples["TEXT"],
            max_length=1024,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        samples["pixel_values"] = [
            __get_pixel_values(image_hash) for image_hash in samples["hash"]
        ]

        return samples

    return __compute_intermediate_values_lambda


def setup_dataset(n):
    # loading the dataset
    dataset = (
        load_dataset(
            "parquet",
            data_files={
                # "train": "/data/laion-high-resolution-filtered-shuffled.snappy.parquet",
                # "train": "/data/laion-high-resolution-filtered-shuffled-processed-split.zstd.parquet",
                # "train": "/data/laion-high-resolution-filtered-shuffled-processed-split-byt5-vae.zstd.parquet",
                # "train": "/data/laion-high-resolution-filtered-shuffled-validated-10k.zstd.parquet",
                "train": "/data/laion-high-resolution-1M.zstd.parquet",
            },
            split="train[:%d]" % n,
            cache_dir="/data/cache",
            num_proc=32,
        )
        .with_format("torch")
        .map(
            get_compute_intermediate_values_lambda(),
            batched=True,
            batch_size=16,
            num_proc=32,
        )
        .select_columns(["input_ids", "pixel_values"])
    )

    return dataset


def prepare_1m_dataset():
    # Gives 1267072 samples to be exact
    (
        load_dataset(
            "laion/laion-high-resolution",
            split="train",
            cache_dir="/data/cache",
        )
        .with_format("torch")
        .select_columns(["TEXT", "hash"])
        .filter(
            function=_filter_out_unprocessed,
            num_proc=96,
        )
        .to_parquet(
            "/data/laion-high-resolution-1M.zstd.parquet",
            batch_size=128,
            compression="ZSTD",
        )
    )


if __name__ == "__main__":
    prepare_1m_dataset()
    # dataset = setup_dataset(64)

    # dataloader = setup_dataloader(dataset, 16)
    # for batch in dataloader:
    #     print(batch["pixel_values"].shape)

--------------------------------------------------------------------------------
/dataset.sh:
--------------------------------------------------------------------------------
export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=34359738368
#export JAX_PLATFORMS=""
export TF_CPP_MIN_LOG_LEVEL=2
python3 ./dataset.py
--------------------------------------------------------------------------------
/docs/POSTMORTEM.md:
--------------------------------------------------------------------------------
Thanks to Hugging Face and Google for providing 21 days of free compute worth more than 6000 USD, as well as a T4 GPU.

What we did manage to do.

We did not manage to... nor to get in-training validation or post-training inference code to work.
--------------------------------------------------------------------------------
/docs/ROADMAP.md:
--------------------------------------------------------------------------------
# CHARRED Roadmap

## Use cases

Artistic rendition of missing, lost, or stolen textual artifacts (text-to-image), OCR (image-to-text), statistical reconstitution of damaged textual artifacts (image-to-image, inpainting to predict the missing characters, ex: MARI).

Low-resource languages, low-resource domains (ex: perfumery).

## Training

Are we generating the input embeddings correctly?

`FlaxT5PreTrainedModel.encode()`: `input_ids=jnp.array(input_ids, dtype="i4")`? uint32 Unicode code points or uint8 UTF-8 bytes?
`array = np.frombuffer(string.encode("UTF-8", errors="ignore"), dtype=np.uint8) + 3`
`string = (ndarray - 3).astype(np.uint8).tobytes().decode("UTF-8", errors="ignore")`

Use the T5x library instead of "transformers".

Should the shape of the latents in the VAE/UNet be bigger to accommodate more tokens?

Would it be possible to have a vocab-less, raw UTF-8 byte, character-aware, decoder-only language model?

Write the tests first: https://github.com/deepmind/chex

VAE/U-Net hyperparameters to accommodate ByT5's character-awareness better

Try to run the original code

1. DONE: Implement JAX/FLAX SD 2.1 training pipeline with ByT5-Base instead of CLIP: https://github.com/patil-suraj/stable-diffusion-jax https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py https://huggingface.co/google/byt5-base https://huggingface.co/blog/stable_diffusion_jax
2. DONE: WandB monitoring
3. DONE: Implement Min-SNR loss rebalancing: https://arxiv.org/abs/2303.09556
4. DONE: Implement on-the-fly validation: https://huggingface.co/docs/diffusers/en/conceptual/evaluation
5. DONE: save checkpoints to disk
6. Get rid of the PyTorch DataLoader, Flax Linen and HF libraries (transformers, diffusers, datasets), use JAX's new Array code, and write pure functions
7. Make the code independent from device topology (might be hardcoded to 8xTPUv4 at the moment)
8. Implement streaming from the network (instead of from the disk), mini-batching and gradient accumulation with image aspect ratio and tokenized caption size bucketing. Cache the frozen models' caption text embeddings (ByT5) and image embeddings (VAE) in bfloat16 half-precision, explore using ByT5 XXL float32 (51.6GB), XXL bfloat16 (26GB), or XL float32 (15GB), and discard anything unnecessary from the frozen models (e.g. ByT5 decoder weights) to lower the memory requirements: https://github.com/google/jax/issues/1408 https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html https://huggingface.co/google/byt5-xxl https://github.com/google-research/t5x/blob/main/docs/models.md#byt5-checkpoints https://github.com/google-research/t5x/blob/main/t5x/scripts/convert_tf_checkpoint.py https://optax.readthedocs.io/en/latest/gradient_accumulation.html https://optax.readthedocs.io/en/latest/api.html#optax.MultiSteps
9. Better strategy to load and save checkpoints using JAX-native methods: https://flax.readthedocs.io/en/latest/api_reference/flax.training.html https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#save-checkpoints https://arxiv.org/abs/1604.06174

## Inference

1. DONE: Implement JAX/FLAX text-to-image inference pipeline and Gradio demo with ByT5-Base instead of CLIP: https://huggingface.co/docs/diffusers/training/text2image https://github.com/patil-suraj/stable-diffusion-jax
2. Implement AUTOMATIC1111 and Gradio UIs: https://github.com/AUTOMATIC1111/stable-diffusion-webui
3. Load checkpoints using JAX-native methods https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html
4. Implement OCR and document understanding inference pipeline with ByT5 text decoder
5. Implement text encoding CPU offloading with int8 precision and GPU-accelerated U-Net prediction and VAE decoding with int8 precision https://jax.readthedocs.io/en/latest/multi_process.html https://github.com/TimDettmers/bitsandbytes https://huggingface.co/blog/hf-bitsandbytes-integration

## MLOps

### XLA, IREE, HLO, MLIR

https://medium.com/@shivvidhyut/a-brief-introduction-to-distributed-training-with-gradient-descent-a4ba9faefcea
https://www.kaggle.com/code/grez911/tutorial-efficient-gradient-descent-with-jax/notebook
https://github.com/kingoflolz/mesh-transformer-jax
https://github.com/kingoflolz/swarm-jax
https://github.com/openxla/iree https://openxla.github.io/iree/
https://github.com/openxla/xla/blob/main/xla/xla.proto
https://github.com/openxla/xla/blob/main/xla/xla_data.proto
https://github.com/openxla/xla/blob/main/xla/service/hlo.proto
https://github.com/openxla/xla/tree/main/third_party/tsl/tsl/protobuf
https://github.com/openxla/xla/blob/main/xla/pjrt/distributed/protocol.proto
https://github.com/apache/tvm/
https://github.com/openai/triton
https://github.com/onnx/onnx-mlir
https://github.com/plaidml/plaidml
https://github.com/llvm/torch-mlir
https://github.com/pytorch/pytorch/tree/main/torch/_dynamo
https://github.com/pytorch/pytorch/tree/main/torch/_inductor
https://github.com/llvm/llvm-project/tree/main/mlir/ https://mlir.llvm.org/
https://github.com/tensorflow/mlir-hlo
https://github.com/openxla/stablehlo
https://github.com/llvm/torch-mlir
https://research.google/pubs/pub48035/
https://iq.opengenus.org/mlir-compiler-infrastructure/
https://www.youtube.com/watch?v=Z8knnMYRPx0 https://mlir.llvm.org/OpenMeetings/2023-03-23-Nelli.pdf
https://mlir.llvm.org/docs/Tutorials/Toy/
https://mlir.llvm.org/getting_started/
Production AOT with IREE over Java JNI/JNA/Panama https://github.com/openxla/iree https://github.com/iree-org/iree-jax https://jax.readthedocs.io/en/latest/aot.html https://jax.readthedocs.io/en/latest/_autosummary/jax.make_jaxpr.html https://jax.readthedocs.io/en/latest/_autosummary/jax.xla_computation.html https://github.com/openxla/stablehlo https://github.com/openxla/xla https://github.com/openxla/openxla-pjrt-plugin https://github.com/iml130/iree-template-cpp
HLO/MLIR compiler/interpreter/simulator/emulator in Java: https://github.com/oracle/graal/tree/master/sulong https://github.com/oracle/graal/tree/master/visualizer https://github.com/oracle/graal/tree/master/truffle https://github.com/graalvm/simplelanguage https://github.com/graalvm/simpletools https://openjdk.org/jeps/442 https://openjdk.org/jeps/448

### Java pipeline

https://vertx.io/docs/vertx-web-api-service/java/
https://github.com/vert-x3/vertx-infinispan https://github.com/infinispan/infinispan
https://github.com/eclipse-vertx/vertx-grpc
https://github.com/vert-x3/vertx-service-discovery
https://github.com/vert-x3/vertx-service-proxy

https://kafka.apache.org/ https://github.com/provectus/kafka-ui
https://github.com/vert-x3/vertx-kafka-client
https://github.com/vert-x3/vertx-stomp https://github.com/stomp-js/stompjs https://activemq.apache.org/ https://developers.cloudflare.com/queues/
https://github.com/apache/pulsar
https://github.com/datastax/kafka-examples
https://github.com/datastax/kafka-sink
https://github.com/datastax/starlight-for-kafka

https://github.com/apache/arrow/tree/main/java
https://github.com/apache/thrift
https://github.com/apache/avro
https://github.com/apache/orc
https://github.com/apache/parquet-mr
https://github.com/msgpack/msgpack-java
https://github.com/irmen/pickle
https://github.com/jamesmudd/jhdf

https://github.com/apache/iceberg
https://github.com/eclipse/jnosql
https://github.com/trinodb/trino
https://github.com/apache/druid/
https://github.com/apache/hudi
https://github.com/delta-io/delta
https://github.com/apache/pinot

### HA & Telemetry

OpenTelemetry/Grafana monitoring instead of WandB, Perfetto or TensorBoard, attach JAX profiler artifacts
https://github.com/resilience4j/resilience4j https://resilience4j.readme.io/docs/micrometer https://vertx.io/docs/vertx-micrometer-metrics/java/
https://github.com/open-telemetry/opentelemetry-java-instrumentation
https://github.com/grafana/JPProf https://jax.readthedocs.io/en/latest/device_memory_profiling.html https://jax.readthedocs.io/en/latest/_autosummary/jax.profiler.device_memory_profile.html https://github.com/google/pprof/tree/main/proto https://github.com/grafana/pyroscope https://github.com/grafana/otel-profiling-go https://github.com/grafana/metrictank https://github.com/open-telemetry/opentelemetry-collector-contrib/tree/main/extension/pprofextension
https://jax.readthedocs.io/en/latest/profiling.html https://github.com/google/perfetto/tree/master/protos
https://vertx.io/docs/vertx-opentelemetry/java/
https://vertx.io/docs/vertx-micrometer-metrics/java/

## Dataset preprocessing

Make the most of cheap Kubernetes clusters: https://github.com/murphye/cheap-gke-cluster

1. Synthetic data generation: HTML/SVG/CSS graphic/layout/typography
2. Dataset merging: synthetic data, LAION-HR: https://huggingface.co/datasets/laion/laion-high-resolution, WiT dataset: https://huggingface.co/datasets/google/wit https://huggingface.co/datasets/wikimedia/wit_base, handwritten and printed document scans, graphemes-in-the-wild, etc. with language re-sampling to match ByT5's mC4 training distribution so as not to lose the multilingual balance: https://huggingface.co/datasets/allenai/c4 https://huggingface.co/datasets/mc4. More image datasets: https://huggingface.co/datasets/facebook/winoground https://huggingface.co/datasets/huggan/wikiart https://huggingface.co/datasets/kakaobrain/coyo-700m https://github.com/unsplash/datasets https://huggingface.co/datasets/red_caps https://huggingface.co/datasets/gigant/oldbookillustrations
3. Complementary downloads from dataset URLs (mostly images) and JPEG XL archiving (see IIIF)
4. Deduplication of images with fingerprinting and of captions with sentence embeddings (all the sentence-transformers disappeared on May 8, 2023)
5. Scene segmentation, document layout understanding and caption NER. Because NER is to text what segmentation is to a scene and what layout understanding is to a document, we need to annotate all of these to be able to detect captions-within-a-caption (captions that spell out text within the image, for instance) and also score captions based on how exhaustive their "coverage" of the scene or document they describe is: https://medium.com/nlplanet/a-full-guide-to-finetuning-t5-for-text2text-and-building-a-demo-with-streamlit-c72009631887 https://huggingface.co/docs/transformers/model_doc/flan-ul2 https://pytorch.org/text/main/tutorials/t5_demo.html https://towardsdatascience.com/guide-to-fine-tuning-text-generation-models-gpt-2-gpt-neo-and-t5-dc5de6b3bc5e https://programming-review.com/machine-learning/t5/ https://colab.research.google.com/drive/1syXmhEQ5s7C59zU8RtHVru0wAvMXTSQ8 https://github.com/ttengwang/Caption-Anything https://github.com/facebookresearch/segment-anything https://github.com/facebookresearch/detectron2 https://huggingface.co/datasets/joelito/lextreme https://registry.opendata.aws/lei/ https://huggingface.co/datasets/jfrenz/legalglue https://huggingface.co/datasets/super_glue https://huggingface.co/datasets/klue https://huggingface.co/datasets/NbAiLab/norne https://huggingface.co/datasets/indic_glue https://huggingface.co/datasets/acronym_identification https://huggingface.co/datasets/wikicorpus https://huggingface.co/datasets/multi_woz_v22 https://huggingface.co/datasets/wnut_17 https://huggingface.co/datasets/msra_ner https://huggingface.co/datasets/conll2012_ontonotesv5 https://huggingface.co/datasets/conllpp
6. Evaluation and filtering of image aesthetics, caption exhaustiveness (based on #5) and caption meaningfulness (CoLA): https://github.com/google-research/google-research/tree/master/vila https://github.com/google-research/google-research/tree/master/musiq https://github.com/christophschuhmann/improved-aesthetic-predictor https://www.mdpi.com/2313-433X/9/2/30 https://paperswithcode.com/dataset/aesthetic-visual-analysis https://www.tandfonline.com/doi/full/10.1080/09540091.2022.2147902 https://github.com/bcmi/Awesome-Aesthetic-Evaluation-and-Cropping https://github.com/rmokady/CLIP_prefix_caption https://github.com/google-research-datasets/Image-Caption-Quality-Dataset https://github.com/gchhablani/multilingual-image-captioning https://ai.googleblog.com/2022/10/crossmodal-3600-multilingual-reference.html https://www.cl.uni-heidelberg.de/statnlpgroup/wikicaps/ https://huggingface.co/docs/transformers/main/tasks/image_captioning https://www.mdpi.com/2076-3417/13/4/2446 https://arxiv.org/abs/2201.12723 https://laion.ai/blog/laion-aesthetics/ https://github.com/JD-P/simulacra-aesthetic-captions
7. Bucketing and batching (similar caption lengths for padding and truncating, similar image ratio for up/downsampling): https://github.com/NovelAI/novelai-aspect-ratio-bucketing
8. Image preprocessing with JAX-native methods: https://jax.readthedocs.io/en/latest/jax.image.html https://dm-pix.readthedocs.io/ https://github.com/4rtemi5/imax https://github.com/rolandgvc/flaxvision

## CharT5 (ByT5 v2)

Pretrain a better ByT5 with innovations from the T5 family and other character-aware language transformer models:

- Early character-aware language models: https://arxiv.org/abs/1508.06615 https://arxiv.org/abs/2011.01513
- CANINE-C encoder-only character-aware language model (https://arxiv.org/abs/2103.06874, https://github.com/google-research/language/tree/master/language/canine, https://huggingface.co/google/canine-c, https://huggingface.co/vicl/canine-c-finetuned-cola)
- Switch/MoE https://arxiv.org/abs/2101.03961 https://github.com/google-research/t5x/tree/main/t5x/contrib/moe https://towardsdatascience.com/understanding-googles-switch-transformer-904b8bf29f66, https://huggingface.co/google/switch-c-2048, https://towardsdatascience.com/the-switch-transformer-59f3854c7050 https://arxiv.org/abs/2208.02813
- FLAN/PALM/PALM-E/PALM 2/UL2 https://arxiv.org/abs/2210.11416 https://arxiv.org/abs/2301.13688 https://arxiv.org/abs/2109.01652 https://github.com/lucidrains/PaLM-jax https://github.com/conceptofmind/PaLM-flax https://github.com/google-research/t5x/tree/main/t5x/examples/decoder_only https://huggingface.co/docs/transformers/main/model_doc/flan-t5 https://huggingface.co/google/flan-ul2 https://arxiv.org/abs/2205.05131v3 https://github.com/google-research/google-research/tree/master/ul2 https://ai.googleblog.com/2022/10/ul2-20b-open-source-unified-language.html
- T5 v1.1 https://arxiv.org/abs/2002.05202
- CALM https://github.com/google-research/t5x/tree/main/t5x/contrib/calm https://arxiv.org/abs/2207.07061
- FlashAttention https://github.com/HazyResearch/flash-attention https://arxiv.org/abs/2205.14135
- T5x & Seqio https://arxiv.org/abs/2203.17189
- LongT5
https://github.com/google-research/longt5 149 | - WT5 https://github.com/google-research/google-research/tree/master/wt5 150 | - NanoT5 https://github.com/PiotrNawrot/nanoT5 151 | - Tensor Considered Harmful https://nlp.seas.harvard.edu/NamedTensor https://github.com/stanford-crfm/levanter https://github.com/harvardnlp/NamedTensor 152 | - FasterTransformers https://github.com/NVIDIA/FasterTransformer 153 | - FlaxFormers https://github.com/google/flaxformer 154 | 155 | # Beyond SD 2.1 156 | 157 | Integrate and port to JAX as much improvements and ideas from Imagen, SDXL, Deep Floyd, Big Vision, Vision Transformer, etc. as possible : https://github.com/lucidrains/imagen-pytorch https://github.com/deep-floyd/IF https://stable-diffusion-art.com/sdxl-beta/ https://huggingface.co/docs/diffusers/api/pipelines/if https://huggingface.co/spaces/DeepFloyd/IF https://huggingface.co/DeepFloyd/IF-I-XL-v1.0 https://huggingface.co/DeepFloyd/IF-II-L-v1.0 https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler https://huggingface.co/DeepFloyd/IF-notebooks/tree/main https://huggingface.co/blog/if https://huggingface.co/docs/diffusers/main/en/api/pipelines/if https://stability.ai/blog/deepfloyd-if-text-to-image-model https://deepfloyd.ai/ https://www.assemblyai.com/blog/how-imagen-actually-works/ https://www.youtube.com/watch?v=af6WPqvzjjk https://www.youtube.com/watch?v=xqDeAz0U-R4 https://www.assemblyai.com/blog/minimagen-build-your-own-imagen-text-to-image-model/ https://github.com/google-research/big_vision https://github.com/google-research/vision_transformer https://github.com/microsoft/unilm/tree/master/beit https://arxiv.org/abs/2106.04803 https://arxiv.org/abs/2210.01820 https://arxiv.org/abs/2103.15808 https://arxiv.org/abs/2201.10271 https://arxiv.org/abs/2209.15159 https://arxiv.org/abs/2303.14189 https://arxiv.org/abs/2010.11929 https://arxiv.org/abs/2208.10442 https://arxiv.org/abs/2012.12877 https://arxiv.org/abs/2111.06377v3 https://arxiv.org/abs/2107.06263 
https://arxiv.org/abs/1906.00446 https://arxiv.org/abs/2110.04627 https://arxiv.org/abs/2208.06366 https://arxiv.org/abs/2302.00902 https://arxiv.org/abs/2212.03185 https://arxiv.org/abs/2212.07372 https://arxiv.org/abs/2209.09002 https://arxiv.org/abs/2301.00704 https://arxiv.org/abs/2211.09117 https://arxiv.org/abs/2302.05917

--------------------------------------------------------------------------------
/inference_jax.py:
--------------------------------------------------------------------------------
import jax
import jax.numpy as jnp
from monitoring import wandb_close, wandb_inference_init, wandb_inference_log

import numpy as np
from PIL import Image

from diffusers import (
    FlaxAutoencoderKL,
    FlaxDPMSolverMultistepScheduler,
    FlaxUNet2DConditionModel,
)
from transformers import ByT5Tokenizer, FlaxT5ForConditionalGeneration


def get_inference_lambda(seed):
    tokenizer = ByT5Tokenizer()

    language_model = FlaxT5ForConditionalGeneration.from_pretrained(
        "google/byt5-base",
        dtype=jnp.float32,
    )
    text_encoder = language_model.encode
    text_encoder_params = language_model.params
    max_length = 1024
    tokenized_negative_prompt = tokenizer(
        "", padding="max_length", max_length=max_length, return_tensors="np"
    ).input_ids
    negative_prompt_text_encoder_hidden_states = text_encoder(
        tokenized_negative_prompt,
        params=text_encoder_params,
        train=False,
    )[0]

    scheduler = FlaxDPMSolverMultistepScheduler.from_config(
        config={
            "_diffusers_version": "0.16.0",
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "beta_start": 0.00085,
            "clip_sample": False,
            "num_train_timesteps": 1000,
            "prediction_type": "v_prediction",
            "set_alpha_to_one": False,
            "skip_prk_steps": True,
            "steps_offset": 1,
            "trained_betas": None,
        }
    )
    timesteps = 20
    guidance_scale = jnp.array([7.5], dtype=jnp.bfloat16)

    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
        "character-aware-diffusion/charred",
        dtype=jnp.bfloat16,
    )

    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
        "flax/stable-diffusion-2-1",
        subfolder="vae",
        dtype=jnp.bfloat16,
    )
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)

    image_width = image_height = 256

    # Generating latent shape
    latent_shape = (
        negative_prompt_text_encoder_hidden_states.shape[
            0
        ],  # TODO: if this is for the whole context (positive + negative prompts), we should multiply by two
        unet.in_channels,
        image_width // vae_scale_factor,
        image_height // vae_scale_factor,
    )

    def __tokenize_prompt(prompt: str):
        return tokenizer(
            text=prompt,
            max_length=1024,
            padding="max_length",
            truncation=True,
            return_tensors="jax",
        ).input_ids

    def __convert_image(image):
        # create PIL image from JAX tensor converted to numpy
        return Image.fromarray(np.asarray(image), mode="RGB")

    def __get_context(tokenized_prompt: jnp.ndarray):
        # Get the text embedding
        text_encoder_hidden_states = text_encoder(
            tokenized_prompt,
            params=text_encoder_params,
            train=False,
        )[0]

        # context = empty negative prompt embedding + prompt embedding
        return jnp.concatenate(
            [negative_prompt_text_encoder_hidden_states, text_encoder_hidden_states]
        )

    def __predict_image(context: jnp.ndarray):
        def ___timestep(step, step_args):
            latents, scheduler_state = step_args

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]

            # For classifier-free guidance, we need to do two forward passes.
            # Here we duplicate the latents into a single batch matching the
            # [negative prompt, prompt] context, to avoid doing two separate forward passes
            latent_input = jnp.concatenate([latents] * 2)

            timestep = jnp.broadcast_to(t, latent_input.shape[0])

            scaled_latent_input = scheduler.scale_model_input(
                scheduler_state, latent_input, t
            )

            # predict the noise residual
            unet_prediction_sample = unet.apply(
                {"params": unet_params},
                jnp.array(scaled_latent_input),
                jnp.array(timestep, dtype=jnp.int32),
                context,
            ).sample

            # perform guidance
            unet_prediction_sample_uncond, unet_prediction_text = jnp.split(
                unet_prediction_sample, 2, axis=0
            )
            guided_unet_prediction_sample = (
                unet_prediction_sample_uncond
                + guidance_scale
                * (unet_prediction_text - unet_prediction_sample_uncond)
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents, scheduler_state = scheduler.step(
                scheduler_state, guided_unet_prediction_sample, t, latents
            ).to_tuple()

            return latents, scheduler_state

        # initialize scheduler state
        initial_scheduler_state = scheduler.set_timesteps(
            scheduler.create_state(), num_inference_steps=timesteps, shape=latent_shape
        )

        # initialize latents
        initial_latents = (
            jax.random.normal(
                jax.random.PRNGKey(seed), shape=latent_shape, dtype=jnp.bfloat16
            )
            * initial_scheduler_state.init_noise_sigma
        )

        final_latents, _ = jax.lax.fori_loop(
            0, timesteps, ___timestep, (initial_latents, initial_scheduler_state)
        )

        vae_output = vae.apply(
            {"params": vae_params},
            1 / vae.config.scaling_factor * final_latents,
            method=vae.decode,
        ).sample

        # return 8-bit RGB image (height, width, RGB)
        return (
            ((vae_output / 2 + 0.5).transpose(0, 2, 3, 1).clip(0, 1) * 255)
            .round()
            .astype(jnp.uint8)[0]
        )

    jax_jit_compiled_accel_predict_image = jax.jit(__predict_image)

    jax_jit_compiled_cpu_get_context = jax.jit(
        __get_context, device=jax.devices(backend="cpu")[0]
    )

    return lambda prompt: __convert_image(
        jax_jit_compiled_accel_predict_image(
            jax_jit_compiled_cpu_get_context(__tokenize_prompt(prompt))
        )
    )


if __name__ == "__main__":
    wandb_inference_init()

    generate_image_for_prompt = get_inference_lambda(87)

    prompts = [
        "a white car",
        "a running shoe",
        "a forest",
        "two people",
        "a happy cartoon cat",
    ]

    log = []

    for prompt in prompts:
        log.append({"prompt": prompt, "image": generate_image_for_prompt(prompt)})

    wandb_inference_log(log)

    wandb_close()

--------------------------------------------------------------------------------
/inference_jax.sh:
--------------------------------------------------------------------------------
export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=34359738368
#export JAX_PLATFORMS=""
#export JAX_PLATFORMS="cpu"
export TF_CPP_MIN_LOG_LEVEL=2
python3 ./inference_jax.py
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
import jax.numpy as jnp
import jax

from diffusers import FlaxDDPMScheduler

# Min-SNR
snr_gamma = 5.0  # SNR weighting gamma to be used when rebalancing the loss with Min-SNR. Recommended value is 5.0.


def compute_snr_loss_weights(noise_scheduler_state, timesteps):
    """
    Computes Min-SNR loss weights, min(SNR(t), snr_gamma) / SNR(t), with SNR computed as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """
    alphas_cumprod = noise_scheduler_state.common.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod ** 0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    alpha = sqrt_alphas_cumprod[timesteps]
    sigma = sqrt_one_minus_alphas_cumprod[timesteps]

    # Compute SNR: SNR(t) = alpha_bar_t / (1 - alpha_bar_t)
    snr = jnp.array((alpha / sigma) ** 2)

    # Compute Min-SNR loss weights for epsilon prediction: min(SNR, snr_gamma) / SNR
    # (e.g. with snr_gamma = 5.0, SNR = 20 gives a weight of 0.25 and SNR = 2 gives 1.0)
    return jnp.where(snr < snr_gamma, snr, jnp.ones_like(snr) * snr_gamma) / snr


def get_vae_latent_distribution_samples(
    latents,
    sample_rng,
    noise_scheduler,
    noise_scheduler_state,
):

    # Sample noise that we'll add to the latents
    noise_rng, timestep_rng = jax.random.split(sample_rng)
    noise = jax.random.normal(noise_rng, latents.shape)

    # Sample a random timestep for each image
    timesteps = jax.random.randint(
        key=timestep_rng,
        shape=(latents.shape[0],),
        minval=0,
        maxval=noise_scheduler.config.num_train_timesteps,
        dtype=jnp.int32,
    )

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(
        state=noise_scheduler_state,
        original_samples=latents,
        noise=noise,
        timesteps=timesteps,
    )

    return noisy_latents, timesteps, noise


def get_cacheable_samples(
    text_encoder, text_encoder_params, input_ids, vae, vae_params, pixel_values, rng
):

    # Get the text embedding
    # TODO: Cache this
    # TODO: use t5x library
    text_encoder_hidden_states = text_encoder(
        input_ids,
        params=text_encoder_params,
        train=False,
    )[0]

    # Get the image embedding
    vae_outputs = vae.apply(
        {"params": vae_params},
        sample=pixel_values,
        deterministic=True,
        method=vae.encode,
    )

    # Sample the image embedding
    # TODO: Cache this
    image_latent_distribution_sampling = (
        # (NHWC) -> (NCHW)
        jnp.transpose(vae_outputs.latent_dist.sample(rng), (0, 3, 1, 2))
        * vae.config.scaling_factor
    )

    return text_encoder_hidden_states, image_latent_distribution_sampling


def get_compute_losses_lambda(
    text_encoder,  # <-- TODO: take this out of here
    text_encoder_params,  # <-- TODO: take this out of here
    vae,  # <-- TODO: take this out of here
    vae_params,  # <-- TODO: take this out of here
    unet,
):

    # Instantiate training noise scheduler
    # TODO: write pure function
    noise_scheduler = FlaxDDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        prediction_type="epsilon",
        num_train_timesteps=1000,
    )

    def __compute_losses_lambda(
        state_params,
        batch,
        sample_rng,
    ):

        # TODO: take this out of here
        (
            text_encoder_hidden_states,
            image_latent_distribution_sampling,
        ) = get_cacheable_samples(
            text_encoder,
            text_encoder_params,
            batch["input_ids"],
            vae,
            vae_params,
            batch["pixel_values"],
            sample_rng,
        )

        # initialize scheduler state
        noise_scheduler_state = noise_scheduler.create_state()

        # Get the vae latent distribution samples
        (
            image_sampling_noisy_input,
            image_sampling_timesteps,
            noise,
        ) = get_vae_latent_distribution_samples(
            image_latent_distribution_sampling,
            sample_rng,
            noise_scheduler,
            noise_scheduler_state,
        )

        # Predict the noise residual and compute loss
        # TODO: write pure function
        unet_predictions = unet.apply(
            {"params": state_params},
            sample=image_sampling_noisy_input,
            timesteps=image_sampling_timesteps,
            encoder_hidden_states=text_encoder_hidden_states,
            train=True,
        ).sample

        # Compute each batch sample's loss from noisy target
        loss_tensors = (noise - unet_predictions) ** 2

        # Compute Min-SNR loss weights
        snr_loss_weights = compute_snr_loss_weights(
            noise_scheduler_state,
            image_sampling_timesteps,
        )

        # Get one loss scalar per batch sample
        losses = (
            loss_tensors.mean(
                axis=tuple(range(1, loss_tensors.ndim)),
            )
            * snr_loss_weights
        )  # Balance losses with Min-SNR

        # This must be an averaged scalar, otherwise, you get this:
        # TypeError: Gradient only defined for scalar-output functions. Output had shape: (8,).
        return losses.mean(axis=0)

    return __compute_losses_lambda

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os

# jax/flax
import jax
from flax.core.frozen_dict import unfreeze
from flax.training import train_state

from architecture import setup_model

# internal code
from args import parse_args
from optimizer import setup_optimizer
from training_loop import training_loop
from monitoring import wandb_close, wandb_init


def main():
    args = parse_args()

    # number of splits/partitions/devices/shards
    num_devices = jax.local_device_count()

    output_dir = args.output_dir

    load_pretrained = os.path.exists(output_dir) and os.path.isdir(output_dir)

    # Setup WandB for logging & tracking
    log_wandb = args.log_wandb
    if log_wandb:
        wandb_init(args, num_devices)

    # init random number generator
    seed = args.seed
    seed_rng = jax.random.PRNGKey(seed)
    rng, training_from_scratch_rng_params = jax.random.split(seed_rng)
    print("random generator setup...")

    # Pretrained/frozen and training model setup
    text_encoder, text_encoder_params, vae, vae_params, unet, unet_params = setup_model(
        seed,
        load_pretrained,
        output_dir,
        training_from_scratch_rng_params,
    )
    print("models setup...")

    # Optimization & scheduling setup
    optimizer = setup_optimizer(
        args.learning_rate,
        args.adam_beta1,
        args.adam_beta2,
        args.adam_epsilon,
        args.adam_weight_decay,
        args.max_grad_norm,
    )
    print("optimizer setup...")

    # Training state setup
    unet_training_state = train_state.TrainState.create(
        apply_fn=unet,
        params=unfreeze(unet_params),
        tx=optimizer,
    )
    print("training state initialized...")

    if log_wandb:
        get_validation_predictions = None  # TODO: put validation here
    else:
        get_validation_predictions = None

    # JAX device data replication
    # replicated_state = replicate(unet_training_state)
    # NOTE: These can't be replicated here, otherwise, you get this whenever they are used: flax.errors.ScopeParamShapeError: Initializer expected to generate shape (4, 384, 1536) but got shape (384, 1536) instead for parameter "embedding" in "/shared". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)
    # replicated_text_encoder_params = jax_utils.replicate(text_encoder_params)
    # replicated_vae_params = jax_utils.replicate(vae_params)
    # print("states & params replicated to TPUs...")

    # Train!
    print("Training loop init...")
    training_loop(
        text_encoder,
        text_encoder_params,
        vae,
        vae_params,
        unet,
        unet_training_state,
        rng,
        args.max_train_steps,
        args.num_train_epochs,
        args.train_batch_size,
        output_dir,
        log_wandb,
        get_validation_predictions,
        num_devices,
    )
    print("Training loop done...")

    if log_wandb:
        wandb_close()


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/monitoring.py:
--------------------------------------------------------------------------------
import wandb
import numpy as np


def wandb_inference_init():
    wandb.init(
        entity="charred",
        project="charred-inference",
        job_type="inference",
    )

    print("WandB inference init...")


def wandb_inference_log(log: list):
    wandb_log = []

    for entry in log:
        wandb_log.append(wandb.Image(entry["image"], caption=entry["prompt"]))

    wandb.log({"inference": wandb_log})

    print("WandB inference log...")


def wandb_init(args, num_devices):
    wandb.init(
        entity="charred",
        project="charred",
        job_type="train",
        config=args,
    )
    wandb.config.update(
        {
            "num_devices": num_devices,
        }
    )
    wandb.define_metric("*", step_metric="step")
    wandb.define_metric("step", step_metric="walltime")

    print("WandB setup...")


def wandb_close():
    wandb.finish()

    print("WandB closed...")


def get_wandb_log_batch_lambda(
    get_predictions,
):
    def __wandb_log_batch(
        global_walltime,
        global_training_steps,
        delta_time,
        epoch,
        loss,
        unet_params,
        is_milestone,
    ):

        log_data = {
            "walltime": global_walltime,
            "step": global_training_steps,
            "batch_delta_time": delta_time,
            "epoch": epoch,
            "loss": loss.mean(),
        }

        if is_milestone and get_predictions is not None:
            log_data["validation"] = [
                wandb.Image(image, caption=prompt)
                for prompt, image in get_predictions(unet_params)
            ]

        wandb.log(
            data=log_data,
            commit=True,
        )

    return __wandb_log_batch

--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
import optax


def setup_optimizer(
    learning_rate,
    adam_beta1,
    adam_beta2,
    adam_epsilon,
    adam_weight_decay,
    max_grad_norm,
):
    constant_scheduler = optax.constant_schedule(learning_rate)

    adamw = optax.adamw(
        learning_rate=constant_scheduler,
        b1=adam_beta1,
        b2=adam_beta2,
        eps=adam_epsilon,
        weight_decay=adam_weight_decay,
    )

    return optax.chain(
        optax.clip_by_global_norm(max_grad_norm),
        adamw,
    )

--------------------------------------------------------------------------------
/repository.py:
--------------------------------------------------------------------------------
import os

from threading import Thread

# hugging face
from huggingface_hub import upload_folder


def create_output_dir(output_dir):
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)


def save_to_local_directory(
    output_dir,
    unet,
    unet_params,
):
    print("saving trained weights...")
    unet.save_pretrained(
        save_directory=output_dir,
        params=unet_params,
    )
    print("trained weights saved...")


def save_to_repository(
    output_dir,
    unet,
    unet_params,
    repo_id,
):
    print("saving trained weights...")
    unet.save_pretrained(
        save_directory=output_dir,
        params=unet_params,
    )
    print("trained weights saved...")

    # upload in the background; keyword arguments keep output_dir and repo_id
    # from being swapped (upload_to_repository takes output_dir first)
    Thread(
        target=lambda: upload_to_repository(
            output_dir=output_dir,
            repo_id=repo_id,
            commit_message="End of training epoch.",
        )
    ).start()


def upload_to_repository(
    output_dir,
    repo_id,
    commit_message,
):
    upload_folder(
        repo_id=repo_id,
        folder_path=output_dir,
        commit_message=commit_message,
    )


if __name__ == "__main__":
    upload_to_repository(
        "/data/output.bak/000920",
        "character-aware-diffusion/charred",
        "Latest training epoch version as of Apr 28 11:03PM UST.",
    )

--------------------------------------------------------------------------------
/repository.sh:
--------------------------------------------------------------------------------
export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=34359738368
export TF_CPP_MIN_LOG_LEVEL=2
python3 ./repository.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
jax[cuda12_local]
flax
optax
diffusers
transformers
datasets
--extra-index-url https://download.pytorch.org/whl/cpu
torch
torchvision
wandb
Pillow
--------------------------------------------------------------------------------
/training.sh:
--------------------------------------------------------------------------------
export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=34359738368
#export JAX_PLATFORMS=''
python3 main.py
--------------------------------------------------------------------------------
/training_loop.py:
--------------------------------------------------------------------------------
import time

from jax import pmap
from jax.random import split
from flax.training.common_utils import shard
from flax.jax_utils import replicate, unreplicate
from jax.profiler import (
    start_trace,
    stop_trace,
    save_device_memory_profile,
    device_memory_profile,
)
from jaxlib.xla_extension import XlaRuntimeError

from monitoring import get_wandb_log_batch_lambda
from batch import setup_dataloader
from dataset import setup_dataset
from repository import save_to_local_directory
from training_step import get_training_step_lambda


def training_loop(
    text_encoder,
    text_encoder_params,
    vae,
    vae_params,
    unet,
    unet_training_state,
    rng,
    max_train_steps,
    num_train_epochs,
    train_batch_size,
    output_dir,
    log_wandb,
    get_validation_predictions,
    num_devices,
):

    # replication setup
    unet_training_state = replicate(unet_training_state)
    rng = split(rng, num_devices)

    # dataset setup
    train_dataset = setup_dataset(max_train_steps)
    print("dataset loaded...")

    # batch setup
    total_train_batch_size = train_batch_size * num_devices
    train_dataloader = setup_dataloader(train_dataset, total_train_batch_size)
    print("dataloader setup...")

    # milestone setup
    milestone_step_count = min(100_000, max_train_steps)
    print(f"milestone step count: {milestone_step_count}")

    # wandb setup
    if log_wandb:
        wandb_log_batch = get_wandb_log_batch_lambda(
            get_validation_predictions,
        )
        print("wandb log batch setup...")

    # Create parallel version of the train step
    # TODO: Should we try "axis_size=num_devices" or "axis_size=total_train_batch_size"?
    jax_pmapped_training_step = pmap(
        # cannot send these as static broadcasted arguments because they are not hashable
        # TODO: rewrite text_encoder, vae and unet as pure
        fun=get_training_step_lambda(
            text_encoder, text_encoder_params, vae, vae_params, unet
        ),
        axis_name="batch",
        in_axes=(0, 0, 0),
        out_axes=(0, 0, 0),
        static_broadcasted_argnums=(),
        # We cannot donate the "batch" argument. Otherwise, we get this:
        # /site-packages/jax/_src/interpreters/mlir.py:711: UserWarning: Some donated buffers were not usable: ShapedArray(int32[8,1024]), ShapedArray(float32[8,3,512,512]).
        # See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
        # warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")
        # donating rng and training state
        donate_argnums=(
            0,
            1,
        ),
    )

    # Epoch setup
    t0 = time.monotonic()
    global_training_steps = 0
    global_walltime = time.monotonic()
    is_compilation_step = True
    is_first_compiled_step = False
    loss = None
    for epoch in range(num_train_epochs):

        for batch in train_dataloader:

            # getting batch start time
            batch_walltime = time.monotonic()

            if is_compilation_step:
                print("computing compilation batch...")
                # TODO: fix this: 2023-05-05 16:34:23.937383: E external/xla/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace
                device_memory_profile()
                start_trace(
                    log_dir="./profiling/compilation_step",
                    create_perfetto_link=False,
                    create_perfetto_trace=True,
                )
            elif is_first_compiled_step:
                print("computing first compiled batch...")
                device_memory_profile()
                start_trace(
                    log_dir="./profiling/first_compiled_step",
                    create_perfetto_link=False,
                    create_perfetto_trace=True,
                )

            # training step
            # TODO: Fix this jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Error loading program: Attempting to allocate 1.28G. That was not possible. There are 785.61M free.; (0x0x0_HBM0): while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
            try:

                unet_training_state, rng, loss = jax_pmapped_training_step(
                    unet_training_state,
                    rng,
                    shard(batch),
                )

            except XlaRuntimeError:

                if is_compilation_step:
                    stop_trace()
                    save_device_memory_profile(
                        filename="./profiling/compilation_step/pprof_memory_profile_error.pb"
                    )
                    print("compilation batch error...")
                elif is_first_compiled_step:
                    stop_trace()
                    save_device_memory_profile(
                        filename="./profiling/first_compiled_step/pprof_memory_profile_error.pb"
                    )
                    print("first compiled batch error...")

                # re-raise the original error with its traceback
                raise

            # block until train step has completed
            loss.block_until_ready()

            if is_compilation_step:
                stop_trace()
                save_device_memory_profile(
                    filename="./profiling/compilation_step/pprof_memory_profile.pb"
                )
                print("computed compilation batch...")
            elif is_first_compiled_step:
                stop_trace()
                save_device_memory_profile(
                    filename="./profiling/first_compiled_step/pprof_memory_profile.pb"
                )
                print("computed first compiled batch...")

            global_training_steps += num_devices

            # checking if current batch is a milestone
            is_milestone = global_training_steps % milestone_step_count == 0

            if log_wandb:
                # TODO: is this correct? was only unreplicated before, with no averaging
                global_walltime = time.monotonic() - t0
                delta_time = time.monotonic() - batch_walltime
                wandb_log_batch(
                    global_walltime,
                    global_training_steps,
                    delta_time,
                    epoch,
                    loss,
                    unet_training_state.params,
                    is_milestone,
                )

            if is_milestone:
                save_to_local_directory(
                    f"{ output_dir }/{ str(global_training_steps).zfill(12) }",
                    unet,
                    # TODO: is this ok?
                    # was: jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
                    # then: jax.device_get(flax.jax_utils.unreplicate(state.params))
                    # and then, also: jax.device_get(state.params)
                    # and then, again: unreplicate(state.params)
                    # Finally found a way to average along the splits/device/partition/shard axis: jax.tree_util.tree_map(f=lambda x: x.mean(axis=0), tree=unet_training_state.params),
                    unreplicate(tree=unet_training_state.params),
                )

            if is_compilation_step:
                is_compilation_step = False
                is_first_compiled_step = True
            elif is_first_compiled_step:
                is_first_compiled_step = False

--------------------------------------------------------------------------------
/training_step.py:
--------------------------------------------------------------------------------
import jax

from loss import get_compute_losses_lambda


def get_training_step_lambda(text_encoder, text_encoder_params, vae, vae_params, unet):

    # Get loss function lambda
    # TODO: Are we copying all this static data on every batch, here?
    # TODO: Solution #1: avoid copying the static data at every batch
    # TODO: Solution #2: offload frozen model computing to CPU, at least for the text encoding
    # Compile loss function.
13 | # NOTE: Can't have this compiled higher up because jax.value_and_grad-compiled functions require real numbers (floating point) dtypes as arguments 14 | jax_loss_value_and_gradient = jax.value_and_grad( 15 | fun=get_compute_losses_lambda( 16 | text_encoder, 17 | text_encoder_params, 18 | vae, 19 | vae_params, 20 | unet, 21 | ), 22 | argnums=0, 23 | ) 24 | 25 | def __training_step_lambda( 26 | state, 27 | rng, 28 | batch, 29 | ): 30 | 31 | # Split RNGs 32 | sample_rng, new_rng = jax.random.split(rng, 2) 33 | 34 | # Compute loss and gradients 35 | # TODO: why are we doing this here instead of in "value_and_grad" with "reduce_axes"? 36 | loss, grad = jax.lax.pmean( 37 | jax_loss_value_and_gradient( 38 | state.params, 39 | batch, 40 | sample_rng, 41 | ), 42 | axis_name="batch", 43 | ) 44 | 45 | # Apply gradients to training state 46 | new_state = state.apply_gradients(grads=grad) 47 | 48 | return new_state, new_rng, loss 49 | 50 | return __training_step_lambda 51 | -------------------------------------------------------------------------------- /validation.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import jax 4 | import jax 5 | import jax.numpy as jnp 6 | from flax.jax_utils import replicate 7 | from flax.training.common_utils import shard 8 | 9 | import numpy as np 10 | from PIL import Image 11 | 12 | from diffusers import ( 13 | FlaxAutoencoderKL, 14 | FlaxDPMSolverMultistepScheduler, 15 | FlaxUNet2DConditionModel, 16 | ) 17 | from transformers import ByT5Tokenizer 18 | 19 | from architecture import setup_model 20 | 21 | 22 | # TODO: try half-precision 23 | 24 | tokenized_prompt_max_length = 1024 25 | 26 | 27 | def tokenize_prompts(prompt: list[str]): 28 | return ByT5Tokenizer()( 29 | text=prompt, 30 | max_length=tokenized_prompt_max_length, 31 | padding="max_length", 32 | truncation=True, 33 | return_tensors="jax", 34 | ).input_ids 35 | 36 | 37 | def convert_images(images: 
jnp.ndarray): 38 | # create PIL image from JAX tensor converted to numpy 39 | return [Image.fromarray(np.asarray(image), mode="RGB") for image in images] 40 | 41 | 42 | def get_validation_predictions_lambda( 43 | vae: FlaxAutoencoderKL, 44 | vae_params, 45 | unet: FlaxUNet2DConditionModel, 46 | ): 47 | 48 | scheduler = FlaxDPMSolverMultistepScheduler.from_config( 49 | config={ 50 | "_diffusers_version": "0.16.0", 51 | "beta_end": 0.012, 52 | "beta_schedule": "scaled_linear", 53 | "beta_start": 0.00085, 54 | "clip_sample": False, 55 | "num_train_timesteps": 1000, 56 | "prediction_type": "v_prediction", 57 | "set_alpha_to_one": False, 58 | "skip_prk_steps": True, 59 | "steps_offset": 1, 60 | "trained_betas": None, 61 | } 62 | ) 63 | timesteps = 20 64 | 65 | vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) 66 | 67 | image_width = image_height = 256 68 | 69 | # Generating latent shape 70 | latent_shape = ( 71 | 1536, 72 | unet.in_channels, 73 | image_width // vae_scale_factor, 74 | image_height // vae_scale_factor, 75 | ) 76 | 77 | def __predict_images(seed, unet_params, encoded_prompts): 78 | def ___timestep(step, step_args): 79 | latents, scheduler_state = step_args 80 | 81 | t = jnp.asarray(scheduler_state.timesteps, dtype=jnp.int32)[step] 82 | 83 | timestep = jnp.array(jnp.broadcast_to(t, latents.shape[0]), dtype=jnp.int32) 84 | 85 | scaled_latent_input = jnp.array( 86 | scheduler.scale_model_input(scheduler_state, latents, t) 87 | ) 88 | 89 | # predict the noise residual 90 | unet_prediction_sample = unet.apply( 91 | {"params": unet_params}, 92 | sample=scaled_latent_input, 93 | timesteps=timestep, 94 | encoder_hidden_states=encoded_prompts, 95 | train=False, 96 | ).sample 97 | 98 | # compute the previous noisy sample x_t -> x_t-1 99 | return scheduler.step( 100 | scheduler_state, unet_prediction_sample, t, latents 101 | ).to_tuple() 102 | 103 | # initialize scheduler state 104 | initial_scheduler_state = scheduler.set_timesteps( 105 | 
scheduler.create_state(), num_inference_steps=timesteps, shape=latent_shape 106 | ) 107 | 108 | # initialize latents 109 | initial_latents = ( 110 | jax.random.normal( 111 | jax.random.PRNGKey(seed), shape=latent_shape, dtype=jnp.float32 112 | ) 113 | * initial_scheduler_state.init_noise_sigma 114 | ) 115 | 116 | # get denoises latents 117 | final_latents, _ = jax.lax.fori_loop( 118 | 0, timesteps, ___timestep, (initial_latents, initial_scheduler_state) 119 | ) 120 | 121 | # scale latents 122 | scaled_final_latents = 1 / vae.config.scaling_factor * final_latents 123 | 124 | # get image from latents 125 | vae_output = vae.apply( 126 | {"params": vae_params}, 127 | latents=scaled_final_latents, 128 | deterministic=True, 129 | method=vae.decode, 130 | ).sample 131 | 132 | # return 8 bit RGB image (width, height, rgb) 133 | return ( 134 | ( 135 | (vae_output / 2 + 0.5) # TODO: find out why this is necessary 136 | .transpose( 137 | 0, 2, 3, 1 138 | ) # (batch, channel, height, width) => (batch, height, width, channel) 139 | .clip(0, 1) 140 | * 255 141 | ) 142 | .round() 143 | .astype(jnp.uint8) 144 | ) 145 | 146 | return lambda seed, unet_params, encoded_prompts: __predict_images( 147 | seed, unet_params, encoded_prompts 148 | ) 149 | 150 | 151 | if __name__ == "__main__": 152 | # Pretrained/freezed and training model setup 153 | text_encoder, text_encoder_params, vae, vae_params, unet, unet_params = setup_model( 154 | 43, # seed 155 | None, # dtype (defaults to float32) 156 | True, # load pre-trained 157 | "character-aware-diffusion/charred", 158 | None, 159 | ) 160 | # validation prompts 161 | validation_prompts = [ 162 | "a white car", 163 | "une voiture blanche", 164 | "a running shoe", 165 | "une chaussure de course", 166 | "a perfumer and his perfume organ", 167 | "un parfumeur et son orgue à parfums", 168 | "two people", 169 | "deux personnes", 170 | "a happy cartoon cat", 171 | "un dessin de chat heureux", 172 | "a city skyline", 173 | "un panorama urbain", 174 
| "a Marilyn Monroe portrait", 175 | "un portrait de Marilyn Monroe", 176 | "a rainy day in London", 177 | "Londres sous la pluie", 178 | ] 179 | 180 | tokenized_prompts = tokenize_prompts(validation_prompts) 181 | 182 | encoded_prompts = text_encoder( 183 | tokenized_prompts, 184 | params=text_encoder_params, 185 | train=False, 186 | )[0] 187 | 188 | validation_predictions_lambda = get_validation_predictions_lambda( 189 | vae, 190 | vae_params, 191 | unet, 192 | ) 193 | 194 | get_validation_predictions = jax.pmap( 195 | fun=validation_predictions_lambda, 196 | axis_name="encoded_prompts", 197 | donate_argnums=(), 198 | ) 199 | 200 | image_predictions = get_validation_predictions( 201 | replicate(2), replicate(unet_params), shard(encoded_prompts) 202 | ) 203 | 204 | images = convert_images(image_predictions) 205 | -------------------------------------------------------------------------------- /validation.sh: -------------------------------------------------------------------------------- 1 | export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=34359738368 2 | #export JAX_PLATFORMS="" 3 | export TF_CPP_MIN_LOG_LEVEL=2 4 | python3 ./validation.py --------------------------------------------------------------------------------
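The TODOs in `training_loop.py` ask whether keeping device 0's copy of the parameters (`unreplicate`) is acceptable compared with averaging along the device axis. Because `__training_step_lambda` averages `(loss, grad)` with `jax.lax.pmean` before `apply_gradients`, every replica applies the same update, so, assuming identical initial replicas, the copies should stay identical and both choices give the same parameters. The sketch below simulates this in plain Python: lists stand in for the device axis and plain SGD on a toy quadratic loss stands in for the real optimizer and UNet loss (both are assumptions for illustration, not the repository's code).

```python
# Plain-Python stand-in for data-parallel training with jax.pmap:
# Python lists play the role of the device axis (hypothetical simulation).


def pmean(values: list[float]) -> list[float]:
    # Like jax.lax.pmean: every device receives the mean over all devices.
    mean = sum(values) / len(values)
    return [mean] * len(values)


def local_grad(param: float, shard: list[float]) -> float:
    # Gradient of mean(0.5 * (param - x) ** 2) over this device's shard.
    return sum(param - x for x in shard) / len(shard)


def train(num_devices: int, data: list[float], steps: int, lr: float, sync: bool) -> list[float]:
    params = [0.0] * num_devices  # like replicate(): identical initial copies
    shards = [data[i::num_devices] for i in range(num_devices)]  # one shard per device
    for _ in range(steps):
        grads = [local_grad(p, s) for p, s in zip(params, shards)]
        if sync:
            grads = pmean(grads)  # the averaging done before apply_gradients
        params = [p - lr * g for p, g in zip(params, grads)]  # plain SGD update
    return params


data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
synced = train(num_devices=4, data=data, steps=10, lr=0.5, sync=True)
unsynced = train(num_devices=4, data=data, steps=10, lr=0.5, sync=False)

# With pmean, all replicas hold the same value, so keeping device 0 (unreplicate)
# or averaging over the device axis gives the same parameters.
print(max(synced) - min(synced))  # 0.0
# Without pmean, each replica fits only its own shard and the copies drift apart.
print(max(unsynced) - min(unsynced) > 0)  # True
```

The same reasoning applies to the logged `loss`, which is also `pmean`-ed inside the step, so every device reports the same value.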
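In `training_loop.py`, `global_training_steps` grows by `num_devices` per pmapped step, and `is_milestone` is true only when the counter lands exactly on a multiple of `milestone_step_count`. Milestones (checkpoints and milestone logging) therefore fire every lcm(`num_devices`, `milestone_step_count`) counted steps, which is larger than `milestone_step_count` whenever it is not a multiple of `num_devices`. The sketch below mirrors that counting logic with illustrative values (8 devices, a milestone every 100 steps), which are assumptions and not the project's configuration. `crossing_steps` is a hypothetical alternative check, not something the repository uses.

```python
import math


def milestone_steps(num_devices: int, milestone_step_count: int, num_batches: int) -> list[int]:
    # Mirrors training_loop.py: the counter advances by num_devices per
    # pmapped step and a milestone needs an exact multiple.
    global_training_steps = 0
    milestones = []
    for _ in range(num_batches):
        global_training_steps += num_devices
        if global_training_steps % milestone_step_count == 0:
            milestones.append(global_training_steps)
    return milestones


def crossing_steps(num_devices: int, milestone_step_count: int, num_batches: int) -> list[int]:
    # Hypothetical alternative: fire whenever the counter crosses a multiple.
    global_training_steps = 0
    milestones = []
    for _ in range(num_batches):
        previous_steps = global_training_steps
        global_training_steps += num_devices
        if previous_steps // milestone_step_count < global_training_steps // milestone_step_count:
            milestones.append(global_training_steps)
    return milestones


print(milestone_steps(8, 100, 100))  # [200, 400, 600, 800]
print(crossing_steps(8, 100, 100))  # [104, 200, 304, 400, 504, 600, 704, 800]
print(math.lcm(8, 100))  # 200
```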
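In `validation.py`, `shard` splits the 16 encoded prompts across devices, each device denoises a batch of latents conditioned on its share of the prompts, and the decoded images are mapped from the VAE's roughly [-1, 1] output range to 8-bit RGB. The plain-Python sketch below walks through that bookkeeping. The device count (8), `in_channels` (4), and `vae_scale_factor` (8, i.e. four `block_out_channels`) are illustrative assumptions, and the helper names are hypothetical.

```python
def sharded_batch_shape(num_prompts: int, num_devices: int) -> tuple[int, int]:
    # flax.training.common_utils.shard reshapes the leading axis to
    # (num_devices, num_prompts // num_devices), so it needs an exact split.
    if num_prompts % num_devices != 0:
        raise ValueError("number of prompts must be divisible by the device count")
    return num_devices, num_prompts // num_devices


def latent_shape(
    per_device_batch: int, in_channels: int, image_height: int, image_width: int, vae_scale_factor: int
) -> tuple[int, int, int, int]:
    # (batch, channels, height, width): one latent per encoded prompt on this device.
    return (
        per_device_batch,
        in_channels,
        image_height // vae_scale_factor,
        image_width // vae_scale_factor,
    )


def to_uint8(value: float) -> int:
    # Same mapping as (vae_output / 2 + 0.5).clip(0, 1) * 255 followed by round():
    # [-1, 1] becomes [0, 255]; out-of-range values are clipped.
    return round(min(max(value / 2 + 0.5, 0.0), 1.0) * 255)


num_devices, per_device = sharded_batch_shape(num_prompts=16, num_devices=8)
print(latent_shape(per_device, in_channels=4, image_height=256, image_width=256, vae_scale_factor=8))
# (2, 4, 32, 32)
print([to_uint8(v) for v in (-1.5, -1.0, 0.0, 1.0, 1.5)])
# [0, 0, 128, 255, 255]
```

Python's `round` rounds halves to even, as `jnp.round` does, so the midpoint 127.5 becomes 128 in both.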