├── .gitignore
├── .markdownlint.json
├── README.md
├── architecture.py
├── args.py
├── batch.py
├── dataset.py
├── dataset.sh
├── docs
│   ├── POSTMORTEM.md
│   └── ROADMAP.md
├── inference_jax.py
├── inference_jax.sh
├── loss.py
├── main.py
├── monitoring.py
├── optimizer.py
├── repository.py
├── repository.sh
├── requirements.txt
├── training.sh
├── training_loop.py
├── training_step.py
├── validation.py
└── validation.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .cache
3 | dataset-cache
4 | wandb
5 | Pipfile.lock
6 | Pipfile
7 | .DS_Store
8 |
--------------------------------------------------------------------------------
/.markdownlint.json:
--------------------------------------------------------------------------------
1 | {
2 | "no-inline-html": {
3 | "allowed_elements": [ "h1", "img" ]
4 | },
5 | "line-length": {
6 | "line_length": 2048
7 | }
8 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |

4 |
5 | # CHARacter-awaRE Diffusion: Multilingual Character-Aware Encoders for Font-Aware Diffusers That Can Actually Spell
6 |
7 | Tired of text-to-image models that can't spell or deal with fonts and typography correctly? [The secret seems to be in the use of multilingual, tokenization-free, character-aware transformer encoders](https://arxiv.org/abs/2212.10562) such as [ByT5](https://arxiv.org/abs/2105.13626) and [CANINE-c](https://arxiv.org/abs/2103.06874).
8 |
9 | ## Replace CLIP with ByT5 in HF's `text-to-image` Pipeline
10 |
11 | As part of the [Hugging Face JAX Diffuser Sprint](https://github.com/huggingface/community-events/tree/main/jax-controlnet-sprint), we will replace [CLIP](https://arxiv.org/abs/2103.00020)'s tokenizer and encoder with [ByT5](https://arxiv.org/abs/2105.13626)'s in [HF's JAX/FLAX text-to-image pre-training code](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py) and run it on the sponsored TPU resources provided by Google for the event.
12 |
13 | More specifically, here are the main tasks we will try to accomplish during the sprint:
14 |
15 | - Pre-training dataset preparation: we are NOT going to train on `lambdalabs/pokemon-blip-captions`. So what is it going to be? What are the options? Does [anything in here](https://analyticsindiamag.com/top-used-datasets-for-text-to-image-synthesis-models/) or [here](https://github.com/Yutong-Zhou-cv/Awesome-Text-to-Image#head3) take your fancy? Or maybe [DiffusionDB](https://poloclub.github.io/diffusiondb/)? Or a savant mix of many datasets? We will probably need to combine several datasets, as we are looking to cover these requirements:
16 | - We need samples in which text appearing in the scene is explicitly spelled out in the caption, and the priority is to do that in full-scene photos. If we can't find enough, we will integrate more specialized OCR datasets;
17 | - Approximately the same language distribution as ByT5, but also include Indonesian (not in ByT5) to see how character-awareness works when text in the prompt is specified in a language the encoder was not pre-trained on. We need to build testing facilities around the languages that are spoken by team members and friends: Indonesian, Japanese, French, Amharic, Arabic, Norwegian, Swedish, Hindi, Urdu and English.
18 |
19 | We should use the [Hugging Face Datasets library](https://huggingface.co/docs/datasets) as much as possible since it [supports JAX out of the box](https://huggingface.co/docs/datasets/en/use_with_jax). For simplicity's sake, we will limit ourselves to [concatenated](https://huggingface.co/docs/datasets/en/process#concatenate) Hugging Face datasets such as LAION2B [EN](https://huggingface.co/datasets/laion/laion2B-en), [MULTI](https://huggingface.co/datasets/laion/laion2B-multi) and [NOLANG](https://huggingface.co/datasets/laion/laion1B-nolang). We shall, however, [pre-load](https://huggingface.co/docs/datasets/en/loading), [pre-process](https://huggingface.co/docs/datasets/en/loading) and [cache](https://huggingface.co/docs/datasets/en/cache) the dataset on disk before training on it.
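
As a rough illustration only (the split name and the save path below are placeholders, not the final pipeline), concatenating those subsets with the Datasets library could look like this:

```python
# Minimal sketch: concatenate the LAION-2B subsets, then pre-process and cache on disk.
# Assumes the three subsets expose the same columns we need (URL, TEXT, hash, ...).
from datasets import concatenate_datasets, load_dataset

laion_en = load_dataset("laion/laion2B-en", split="train", cache_dir="/data/cache")
laion_multi = load_dataset("laion/laion2B-multi", split="train", cache_dir="/data/cache")
laion_nolang = load_dataset("laion/laion1B-nolang", split="train", cache_dir="/data/cache")

laion = concatenate_datasets([laion_en, laion_multi, laion_nolang])

# Pre-load, pre-process and cache once; training then only reads the processed copy.
laion.save_to_disk("/data/laion-concatenated")
```
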
20 | - Improvements to the [original code](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py):
21 | - ~~Make sure we can run the original code as-is on the TPU VM.~~
22 | - Audit and optimize the code for the Google Cloud TPU v4-8 VM: [`jnp`](https://jax.readthedocs.io/en/latest/jax.numpy.html) (instead of `np`), [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), [`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html), [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html), [`pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html) everywhere! And we should make sure we do not miss any [optimization made in the sprint code](https://github.com/huggingface/community-events/blob/main/jax-controlnet-sprint/training_scripts/train_controlnet_flax.py) either.
23 | - Instrumentation for TPU remote monitoring with [Open Telemetry](https://opentelemetry.io/docs/instrumentation/python/), [TensorBoard](https://www.tensorflow.org/tensorboard/), [Perfetto](https://perfetto.dev), [Weights & Biases](https://wandb.ai) and [JAX's own profiler](https://jax.readthedocs.io/en/latest/profiling.html).
24 | - Implement checkpoint milestone snapshot uploading to cloud storage: we need to be able to download the model for local inference benchmarking to make sure we are on the right track. There seems to be [rudimentary checkpoint support in the original code](https://huggingface.co/docs/diffusers/training/text2image#save-and-load-checkpoints).
25 | - ~~No time for politics. NSFW filtering will be turned off. So we get `FlaxStableDiffusionSafetyChecker` out of the way.~~
26 | - Replace CLIP with ByT5 in the [original code](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py) (see the sketch after this list):
27 | - ~~Replacing `CLIPTokenizer` with `ByT5Tokenizer`. Since this will run on the CPUs, there is no need for JAX/FLAX unless there is hope for huge performance improvements. This should be trivial.~~ Merged. Needs testing.
28 | - ~~Replacing `FlaxCLIPTextModel` with `FlaxT5EncoderModel`. This *might* be almost as easy as replacing the tokenizer.~~ Merged. Needs testing.
29 | - ~~Rewrite `CLIPImageProcessor` for ByT5. This is still under investigation. It's unclear how hard it will be.~~ Done. Needs testing.
30 | - ~~Adapt `FlaxAutoencoderKL` and `FlaxUNet2DConditionModel` for ByT5 if necessary.~~ Done. Needs testing.
31 | - ~~Break down the main pretraining loop into many functions in different source files for readability and easier maintenance.~~
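
As referenced in the list above, here is a minimal sketch of what the CLIP-to-ByT5 swap amounts to (illustrative only; the checkpoint name and the 1024-byte max length mirror the rest of this repository, and the exact wiring into the training script differs):

```python
# Minimal sketch of the tokenizer/encoder swap (not the actual training script).
import jax.numpy as jnp
from transformers import ByT5Tokenizer, FlaxT5EncoderModel

# CLIPTokenizer -> ByT5Tokenizer: byte-level, so no vocabulary or merges files are needed.
tokenizer = ByT5Tokenizer()

# FlaxCLIPTextModel -> FlaxT5EncoderModel: the encoder half of ByT5.
text_encoder = FlaxT5EncoderModel.from_pretrained("google/byt5-base", dtype=jnp.float32)

input_ids = tokenizer(
    ["a storefront with the word CHARRED painted on the window"],
    max_length=1024,  # counted in UTF-8 bytes, not word-piece tokens
    padding="max_length",
    truncation=True,
    return_tensors="np",
).input_ids

# The hidden states feed the UNet cross-attention, so cross_attention_dim must match
# the encoder width (1536 for byt5-base).
encoder_hidden_states = text_encoder(input_ids, train=False)[0]
print(encoder_hidden_states.shape)  # (1, 1024, 1536)
```
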
32 |
33 | ## Introducing a Calligraphic & Typographic ControlNet
34 |
35 | Secondly, we will integrate into the above a [Hugging Face JAX/FLAX ControlNet implementation](https://github.com/huggingface/diffusers/tree/main/examples/controlnet) for better typographic control over the generated images. On top of the orthographically-enhanced SD above, and as per [Patrick von Platen](https://github.com/patrickvonplaten)'s suggestion, we also introduce the idea of a typographic [ControlNet](https://arxiv.org/abs/2302.05543) trained on a synthetic dataset of images paired with multilingual specifications of the textual content, font taxonomy, weight, kerning, leading, slant and any other typographic attribute supported by the [CSS3](https://www.w3.org/Style/CSS/) [Text](https://www.w3.org/TR/css-text-3/), [Fonts](https://www.w3.org/TR/css-fonts-3) and [Writing Modes](https://www.w3.org/TR/css-writing-modes-3/) modules, as implemented by the latest version of [Chromium](https://www.chromium.org/Home/).
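
To make the image/specification pairing concrete, here is a rough sketch of what one synthetic sample could look like before its HTML is rendered (e.g. by headless Chromium) into the paired image; every property value below is an invented placeholder, not a finalized schema:

```python
# Sketch: one synthetic sample = (caption/spec, HTML+CSS to be rendered into the paired image).
def make_typographic_sample(text="selamat pagi", font_family="Noto Sans", weight=700,
                            letter_spacing="0.05em", line_height=1.4, slant="italic"):
    css = (
        f"font-family: '{font_family}'; font-weight: {weight}; font-style: {slant}; "
        f"letter-spacing: {letter_spacing}; line-height: {line_height};"
    )
    html = f'<html><body><p style="{css}">{text}</p></body></html>'
    caption = (
        f'the text "{text}" set in {font_family}, weight {weight}, {slant}, '
        f"letter-spacing {letter_spacing}, line-height {line_height}"
    )
    return {"html": html, "caption": caption}
```
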
36 |
--------------------------------------------------------------------------------
/architecture.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import jax.numpy as jnp
4 |
5 | from transformers import FlaxT5ForConditionalGeneration, set_seed
6 |
7 | from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel
8 |
9 |
10 | def setup_model(
11 | seed,
12 | load_pretrained,
13 | output_dir,
14 | training_from_scratch_rng_params,
15 | ):
16 | set_seed(seed)
17 |
18 | # Load models and create wrapper for stable diffusion
19 |
20 | language_model = FlaxT5ForConditionalGeneration.from_pretrained(
21 | "/data/byt5-base",
22 | dtype=jnp.float32,
23 | )
24 |
25 | vae, vae_params = FlaxAutoencoderKL.from_pretrained(
26 | "/data/stable-diffusion-2-1-vae",
27 | dtype=jnp.float32,
28 | )
29 |
30 | if load_pretrained:
31 | if os.path.isdir(output_dir):
32 | # find latest epoch output
33 |             pretrained_dir = [
34 |                 os.path.join(output_dir, dir)
35 |                 for dir in sorted(os.listdir(output_dir), reverse=True)
36 |                 if os.path.isdir(os.path.join(output_dir, dir))
37 |             ][0]
38 | else:
39 | pretrained_dir = output_dir
40 |
41 | unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
42 | pretrained_dir,
43 | dtype=jnp.float32,
44 | )
45 |
46 | print("loaded unet from pre-trained...")
47 | else:
48 | unet = FlaxUNet2DConditionModel.from_config(
49 | config={
50 | "_diffusers_version": "0.16.0",
51 | "attention_head_dim": [5, 10, 20, 20],
52 | "block_out_channels": [320, 640, 1280, 1280],
53 | "cross_attention_dim": 1536,
54 | "down_block_types": [
55 | "CrossAttnDownBlock2D",
56 | "CrossAttnDownBlock2D",
57 | "CrossAttnDownBlock2D",
58 | "DownBlock2D",
59 | ],
60 | "dropout": 0.0,
61 | "flip_sin_to_cos": True,
62 | "freq_shift": 0,
63 | "in_channels": 4,
64 | "layers_per_block": 2,
65 | "only_cross_attention": False,
66 | "out_channels": 4,
67 | "sample_size": 64,
68 | "up_block_types": [
69 | "UpBlock2D",
70 | "CrossAttnUpBlock2D",
71 | "CrossAttnUpBlock2D",
72 | "CrossAttnUpBlock2D",
73 | ],
74 | "use_linear_projection": True,
75 | },
76 | dtype=jnp.float32,
77 | )
78 | unet_params = unet.init_weights(rng=training_from_scratch_rng_params)
79 | print("training unet from scratch...")
80 |
81 | return (
82 | language_model.encode,
83 | language_model.params,
84 | vae,
85 | vae_params,
86 | unet,
87 | unet_params,
88 | )
89 |
90 |
91 | if __name__ == "__main__":
92 | language_model = FlaxT5ForConditionalGeneration.from_pretrained(
93 | "google/byt5-base",
94 | dtype=jnp.float32,
95 | )
96 |
97 | language_model.save_pretrained("/data/byt5-base")
98 |
99 | vae, vae_params = FlaxAutoencoderKL.from_pretrained(
100 | "flax/stable-diffusion-2-1",
101 | subfolder="vae",
102 | dtype=jnp.float32,
103 | )
104 |
105 | vae.save_pretrained("/data/stable-diffusion-2-1-vae", params=vae_params)
106 |
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def parse_args():
5 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
6 | parser.add_argument(
7 | "--output_dir",
8 | type=str,
9 | default="/data/output",
10 | help="The output directory where the model predictions and checkpoints will be written.",
11 | )
12 | parser.add_argument(
13 | "--cache_dir",
14 | type=str,
15 | default="/data/dataset/cache",
16 | help="The directory where the downloaded models and datasets will be stored.",
17 | )
18 | parser.add_argument(
19 | "--seed", type=int, default=0, help="A seed for reproducible training."
20 | )
21 | parser.add_argument(
22 | "--resolution",
23 | type=int,
24 | default=1024,
25 | help=(
26 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
27 | " resolution"
28 | ),
29 | )
30 | parser.add_argument(
31 | "--train_batch_size",
32 | type=int,
33 | default=4,
34 | help="Batch size (per device) for the training dataloader.",
35 | )
36 | parser.add_argument("--num_train_epochs", type=int, default=1_000_000)
37 | parser.add_argument(
38 | "--max_train_steps",
39 | type=int,
40 | default=1_048_577, # 33_554_432, 67_108_864, 134_217_728, 1_073_741_824, 2_147_483_648, 17_179_869_184
41 | help="Total number of training steps per epoch to perform.",
42 | )
43 | parser.add_argument(
44 | "--learning_rate",
45 | type=float,
46 | default=1e-4,
47 | help="Initial learning rate (after the potential warmup period) to use.",
48 | )
49 | parser.add_argument(
50 | "--adam_beta1",
51 | type=float,
52 | default=0.9,
53 | help="The beta1 parameter for the Adam optimizer.",
54 | )
55 | parser.add_argument(
56 | "--adam_beta2",
57 | type=float,
58 | default=0.999,
59 | help="The beta2 parameter for the Adam optimizer.",
60 | )
61 | parser.add_argument(
62 | "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
63 | )
64 | parser.add_argument(
65 | "--adam_epsilon",
66 | type=float,
67 | default=1e-08,
68 | help="Epsilon value for the Adam optimizer",
69 | )
70 | parser.add_argument(
71 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
72 | )
73 | parser.add_argument(
74 | "--log_wandb",
75 |         action=argparse.BooleanOptionalAction,  # type=bool would treat any non-empty string as True
76 |         default=True,
77 |         # pass --no-log_wandb on the command line to disable
78 |         help=("Whether to use WandB to log the metrics or not"),
79 | )
80 |
81 | args = parser.parse_args()
82 |
83 | return args
84 |
--------------------------------------------------------------------------------
/batch.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def setup_dataloader(dataset, batch_size):
5 | def _collate(samples):
6 | # TODO: replace torch.stack with https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.stack.html
7 |
8 | pixel_values = (
9 | torch.stack([sample["pixel_values"] for sample in samples])
10 | .to(memory_format=torch.contiguous_format)
11 | .float()
12 | .numpy()
13 | )
14 |
15 | input_ids = (
16 | torch.stack([sample["input_ids"] for sample in samples])
17 | .to(memory_format=torch.contiguous_format)
18 | .numpy()
19 | )
20 |
21 | return {
22 | "pixel_values": pixel_values,
23 | "input_ids": input_ids,
24 | }
25 |
26 | return torch.utils.data.DataLoader(
27 | dataset,
28 | shuffle=True,
29 | collate_fn=_collate,
30 | batch_size=batch_size,
31 | num_workers=4,
32 | drop_last=True,
33 | )
34 |
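
Regarding the TODO in `_collate` above, a possible torch-free version is sketched below. It assumes the dataset is kept in NumPy format (e.g. `with_format("numpy")`) rather than torch; stacking with `np.stack` on the host is usually enough, since `jnp.stack` would eagerly place every batch on an accelerator device.

```python
# Sketch: framework-free collation (assumes each sample field is already a NumPy array).
import numpy as np


def collate_numpy(samples):
    return {
        "pixel_values": np.stack([s["pixel_values"] for s in samples]).astype(np.float32),
        "input_ids": np.stack([s["input_ids"] for s in samples]),
    }
```
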
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import os
3 | from PIL import Image
4 | import requests
5 |
6 | from torchvision import transforms
7 |
8 | from transformers import ByT5Tokenizer
9 |
10 |
11 | def _prefilter(sample):
12 | image_url = sample["URL"]
13 | caption = sample["TEXT"]
14 | watermark_probability = sample["pwatermark"]
15 | unsafe_probability = sample["punsafe"]
16 | hash = sample["hash"]
17 |
18 | return (
19 | caption is not None
20 | and isinstance(caption, str)
21 | and image_url is not None
22 | and isinstance(image_url, str)
23 | and watermark_probability is not None
24 | and watermark_probability < 0.6
25 | and unsafe_probability is not None
26 | and unsafe_probability < 1.0
27 | and hash is not None
28 | )
29 |
30 |
31 | def _download_image(sample):
32 | is_ok = False
33 |
34 | image_url = sample["URL"]
35 |
36 | cached_image_image_file_path = os.path.join(
37 | "/data/image-cache", "%s.jpg" % hex(sample["hash"])
38 | )
39 |
40 | if os.path.isfile(cached_image_image_file_path):
41 | pass
42 | else:
43 | try:
44 | # get image data from url
45 | image_bytes = requests.get(image_url, stream=True, timeout=5).raw
46 |
47 | if image_bytes is not None:
48 | pil_image = Image.open(image_bytes)
49 |
50 | if pil_image.mode == "RGB":
51 | pil_rgb_image = pil_image
52 |
53 | else:
54 | # Deal with non RGB images
55 |                     if pil_image.mode == "RGBA":
56 |                         pil_rgba_image = pil_image
57 |                     else:
58 |                         pil_rgba_image = pil_image.convert("RGBA")
59 |
60 | pil_rgb_image = Image.alpha_composite(
61 | Image.new("RGBA", pil_image.size, (255, 255, 255)),
62 | pil_rgba_image,
63 | ).convert("RGB")
64 |
65 | is_ok = True
66 |
67 | pil_rgb_image.save(cached_image_image_file_path)
68 |
69 | except:
70 | with open(cached_image_image_file_path, mode="a"):
71 | pass
72 |
73 |     # save the image to disk but do not catch exceptions here: this must be allowed to fail, otherwise the mapper would run forever
74 | if is_ok:
75 | pil_rgb_image.save(cached_image_image_file_path)
76 |
77 | return is_ok
78 |
79 |
80 | def _filter_out_unprocessed(sample):
81 | cached_image_image_file_path = os.path.join(
82 | "/data/image-cache", "%s.jpg" % hex(sample["hash"])
83 | )
84 |
85 | if (
86 | os.path.isfile(cached_image_image_file_path)
87 | and os.stat(cached_image_image_file_path).st_size > 0
88 | ):
89 | try:
90 | Image.open(cached_image_image_file_path)
91 |
92 | return True
93 |
94 | except:
95 | pass
96 |
97 | return False
98 |
99 |
100 | def get_compute_intermediate_values_lambda():
101 | tokenizer = ByT5Tokenizer()
102 |
103 | image_transforms = transforms.Compose(
104 | [
105 | transforms.Resize(512, interpolation=transforms.InterpolationMode.LANCZOS),
106 | transforms.CenterCrop(512),
107 | transforms.ToTensor(),
108 | ]
109 | )
110 |
111 | def __get_pixel_values(image_hash):
112 | # compute file name
113 | cached_image_image_file_path = os.path.join(
114 | "/data/image-cache", "%s.jpg" % hex(image_hash)
115 | )
116 |
117 | # get image data from cache
118 | pil_rgb_image = Image.open(cached_image_image_file_path)
119 |
120 | transformed_image = image_transforms(pil_rgb_image)
121 |
122 | return transformed_image
123 |
124 | def __compute_intermediate_values_lambda(samples):
125 | samples["input_ids"] = tokenizer(
126 | text=samples["TEXT"],
127 | max_length=1024,
128 | padding="max_length",
129 | truncation=True,
130 | return_tensors="pt",
131 | ).input_ids
132 |
133 | samples["pixel_values"] = [
134 | __get_pixel_values(image_hash) for image_hash in samples["hash"]
135 | ]
136 |
137 | return samples
138 |
139 | return __compute_intermediate_values_lambda
140 |
141 |
142 | def setup_dataset(n):
143 | # loading the dataset
144 | dataset = (
145 | load_dataset(
146 | "parquet",
147 | data_files={
148 | # "train": "/data/laion-high-resolution-filtered-shuffled.snappy.parquet",
149 | # "train": "/data/laion-high-resolution-filtered-shuffled-processed-split.zstd.parquet",
150 | # "train": "/data/laion-high-resolution-filtered-shuffled-processed-split-byt5-vae.zstd.parquet",
151 | # "train": "/data/laion-high-resolution-filtered-shuffled-validated-10k.zstd.parquet",
152 | "train": "/data/laion-high-resolution-1M.zstd.parquet",
153 | },
154 | split="train[:%d]" % n,
155 | cache_dir="/data/cache",
156 | num_proc=32,
157 | )
158 | .with_format("torch")
159 | .map(
160 | get_compute_intermediate_values_lambda(),
161 | batched=True,
162 | batch_size=16,
163 | num_proc=32,
164 | )
165 | .select_columns(["input_ids", "pixel_values"])
166 | )
167 |
168 | return dataset
169 |
170 |
171 | def prepare_1m_dataset():
172 | # Gives 1267072 samples to be exact
173 | (
174 | load_dataset(
175 | "laion/laion-high-resolution",
176 | split="train",
177 | cache_dir="/data/cache",
178 | )
179 | .with_format("torch")
180 | .select_columns(["TEXT", "hash"])
181 | .filter(
182 | function=_filter_out_unprocessed,
183 | num_proc=96,
184 | )
185 | .to_parquet(
186 | "/data/laion-high-resolution-1M.zstd.parquet",
187 | batch_size=128,
188 | compression="ZSTD",
189 | )
190 | )
191 |
192 |
193 | if __name__ == "__main__":
194 | prepare_1m_dataset()
195 | # dataset = setup_dataset(64)
196 |
197 | # dataloader = setup_dataloader(dataset, 16)
198 | # for batch in dataloader:
199 | # print(batch["pixel_values"].shape)
200 |
--------------------------------------------------------------------------------
/dataset.sh:
--------------------------------------------------------------------------------
1 | export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=34359738368
2 | #export JAX_PLATFORMS=""
3 | export TF_CPP_MIN_LOG_LEVEL=2
4 | python3 ./dataset.py
--------------------------------------------------------------------------------
/docs/POSTMORTEM.md:
--------------------------------------------------------------------------------
1 | Thanks to Hugging Face and Google for providing, for 21 days, free compute worth more than 6,000 USD as well as a T4 GPU.
2 |
3 | What we did manage to do: see the struck-through items in the README and the "DONE" entries in the roadmap.
4 |
5 | We did not manage to..., nor did we get the in-training validation or post-training inference code to work.
--------------------------------------------------------------------------------
/docs/ROADMAP.md:
--------------------------------------------------------------------------------
1 | # CHARRED Roadmap
2 |
3 | ## Use cases
4 |
5 | Artistic rendition of missing, lost, or stolen textual artifacts (text-to-image), OCR (image-to-text), statistical reconstitution of damaged textual artifacts (image-to-image, inpainting to predict the missing characters, e.g. MARI).
6 |
7 | Low-resource languages and low-resource domains (e.g. perfumery).
8 |
9 | ## Training
10 |
11 | Are we generating the input embeddings correctly?
12 |
13 | FlaxT5PreTrainedModel.encode() does input_ids=jnp.array(input_ids, dtype="i4"): should those ids be uint32 Unicode code points or uint8 UTF-8 bytes?
14 | array = np.frombuffer(string.encode("UTF-8", errors="ignore"), dtype=np.uint8) + 3
15 | string = (ndarray - 3).tobytes().decode("UTF-8", errors="ignore")
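
A small sketch of the byte round-trip in question, assuming ByT5's convention of offsetting UTF-8 bytes by 3 to make room for the pad/eos/unk special ids (consistent with the 384-entry embedding table mentioned elsewhere in this repo):

```python
import numpy as np

BYTE_OFFSET = 3  # ids 0, 1, 2 are reserved for <pad>, </s>, <unk>


def text_to_byte_ids(text: str) -> np.ndarray:
    raw = np.frombuffer(text.encode("UTF-8", errors="ignore"), dtype=np.uint8)
    return raw.astype(np.int32) + BYTE_OFFSET  # widen before shifting; the model expects int32 ("i4")


def byte_ids_to_text(ids: np.ndarray) -> str:
    ids = ids[ids >= BYTE_OFFSET]  # drop padding/special ids
    return (ids - BYTE_OFFSET).astype(np.uint8).tobytes().decode("UTF-8", errors="ignore")


assert byte_ids_to_text(text_to_byte_ids("charré")) == "charré"
```
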
16 |
17 | Use T5x lib instead of "transformers"
18 |
19 | Should the shape of the latents in the VAE/UNet be bigger to accommodate more tokens?
20 |
21 | Would it be possible to have a vocabulary-free, raw UTF-8 byte, character-aware, decoder-only language model?
22 |
23 | Write the tests first: https://github.com/deepmind/chex
24 |
25 | VAE/U-Net hyperparameters to accommodate ByT5's character-awareness better
26 |
27 | Try to run original code
28 |
29 | 1. DONE: Implement JAX/FLAX SD 2.1 training pipeline with ByT5-Base instead of CLIP: https://github.com/patil-suraj/stable-diffusion-jax https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py https://huggingface.co/google/byt5-base https://huggingface.co/blog/stable_diffusion_jax
30 | 2. DONE: WandB monitoring
31 | 3. DONE: Implement Min-SNR loss rebalancing: https://arxiv.org/abs/2303.09556
32 | 4. DONE: Implement on-the-fly validation: https://huggingface.co/docs/diffusers/en/conceptual/evaluation
33 | 5. DONE: save checkpoints to disk
34 | 6. Get rid of the PyTorch DataLoader, Flax Linen and the HF libraries (transformers, diffusers, datasets), use JAX's new Array API, and write pure functions
35 | 7. Make the code independent of the device topology (it might be hardcoded to 8x TPU v4 at the moment)
36 | 8. Implement streaming from the network (instead of from the disk), mini-batching, and gradient accumulation with image aspect ratio and tokenized caption size bucketing (see the optax.MultiSteps sketch after this list). Cache the frozen models' caption text embeddings (ByT5) and image embeddings (VAE) in bfloat16 half precision, explore using ByT5 XXL float32 (51.6GB), XXL bfloat16 (26GB), or XL float32 (15GB), and discard anything unnecessary from the frozen models (e.g. the ByT5 decoder weights) to lower the memory requirements: https://github.com/google/jax/issues/1408 https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html https://huggingface.co/google/byt5-xxl https://github.com/google-research/t5x/blob/main/docs/models.md#byt5-checkpoints https://github.com/google-research/t5x/blob/main/t5x/scripts/convert_tf_checkpoint.py https://optax.readthedocs.io/en/latest/gradient_accumulation.html https://optax.readthedocs.io/en/latest/api.html#optax.MultiSteps
37 | 9. Better strategy to load and save checkpoints using JAX-native methods: https://flax.readthedocs.io/en/latest/api_reference/flax.training.html https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#save-checkpoints https://arxiv.org/abs/1604.06174
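
Referenced from item 8: a gradient-accumulation sketch using optax.MultiSteps. The accumulation factor is a placeholder, and the inner chain simply mirrors optimizer.py:

```python
import optax


def setup_accumulating_optimizer(learning_rate, max_grad_norm, accumulation_steps=4):
    # Same chain as optimizer.py, wrapped so that parameter updates are only applied
    # every `accumulation_steps` micro-batches (gradients are averaged in between).
    inner = optax.chain(
        optax.clip_by_global_norm(max_grad_norm),
        optax.adamw(learning_rate=learning_rate),
    )
    return optax.MultiSteps(inner, every_k_schedule=accumulation_steps)
```

The returned object exposes the usual `init`/`update` interface, so it should drop into `train_state.TrainState.create(..., tx=...)` like any other optax transformation.
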
38 |
39 | ## Inference
40 |
41 | 1. DONE: Implement JAX/FLAX text-to-image inference pipeline and Gradio demo with ByT5-Base instead of CLIP: https://huggingface.co/docs/diffusers/training/text2image https://github.com/patil-suraj/stable-diffusion-jax
42 | 2. Implement AUTOMATIC1111 and Gradio UIs: https://github.com/AUTOMATIC1111/stable-diffusion-webui
43 | 3. Load checkpoints using JAX-native methods (see the sketch after this list): https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html
44 | 4. Implement an OCR and document understanding inference pipeline with the ByT5 text decoder
45 | 5. Implement text encoding CPU offloading with int8 precision and GPU-accelerated U-Net prediction and VAE decoding with int8 precision https://jax.readthedocs.io/en/latest/multi_process.html https://github.com/TimDettmers/bitsandbytes https://huggingface.co/blog/hf-bitsandbytes-integration
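
Referenced from item 3: a minimal sketch of the Flax-native checkpointing being considered (directory, step and keep values are placeholders):

```python
from flax.training import checkpoints


def save_unet_checkpoint(ckpt_dir, state, step):
    # Expects an unreplicated, host-side train state (unreplicate() it first if needed).
    checkpoints.save_checkpoint(ckpt_dir=ckpt_dir, target=state, step=step, keep=3)


def restore_unet_checkpoint(ckpt_dir, state_template):
    # Restores into an object with the same pytree structure as the saved state.
    return checkpoints.restore_checkpoint(ckpt_dir=ckpt_dir, target=state_template)
```
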
46 |
47 | ## MLOps
48 |
49 | ### XLA, IREE, HLO, MLIR
50 |
51 | https://medium.com/@shivvidhyut/a-brief-introduction-to-distributed-training-with-gradient-descent-a4ba9faefcea
52 | https://www.kaggle.com/code/grez911/tutorial-efficient-gradient-descent-with-jax/notebook
53 | https://github.com/kingoflolz/mesh-transformer-jax
54 | https://github.com/kingoflolz/swarm-jax
55 | https://github.com/openxla/iree https://openxla.github.io/iree/
56 | https://github.com/openxla/xla/blob/main/xla/xla.proto
57 | https://github.com/openxla/xla/blob/main/xla/xla_data.proto
58 | https://github.com/openxla/xla/blob/main/xla/service/hlo.proto
59 | https://github.com/openxla/xla/tree/main/third_party/tsl/tsl/protobuf
60 | https://github.com/openxla/xla/blob/main/xla/pjrt/distributed/protocol.proto
61 | https://github.com/apache/tvm/
62 | https://github.com/openai/triton
63 | https://github.com/onnx/onnx-mlir
64 | https://github.com/plaidml/plaidml
65 | https://github.com/llvm/torch-mlir
66 | https://github.com/pytorch/pytorch/tree/main/torch/_dynamo
67 | https://github.com/pytorch/pytorch/tree/main/torch/_inductor
68 | https://github.com/llvm/llvm-project/tree/main/mlir/ https://mlir.llvm.org/
69 | https://github.com/tensorflow/mlir-hlo
70 | https://github.com/openxla/stablehlo
71 | https://github.com/llvm/torch-mlir
72 | https://research.google/pubs/pub48035/
73 | https://iq.opengenus.org/mlir-compiler-infrastructure/
74 | https://www.youtube.com/watch?v=Z8knnMYRPx0 https://mlir.llvm.org/OpenMeetings/2023-03-23-Nelli.pdf
75 | https://mlir.llvm.org/docs/Tutorials/Toy/
76 | https://mlir.llvm.org/getting_started/
77 | Production AOT with IREE over Java JNI/JNA/Panama https://github.com/openxla/iree https://github.com/iree-org/iree-jax https://jax.readthedocs.io/en/latest/aot.html https://jax.readthedocs.io/en/latest/_autosummary/jax.make_jaxpr.html https://jax.readthedocs.io/en/latest/_autosummary/jax.xla_computation.html https://github.com/openxla/stablehlo https://github.com/openxla/xla https://github.com/openxla/openxla-pjrt-plugin https://github.com/iml130/iree-template-cpp
78 | hlo/mlir compiler/interpreter/simulator/emulator in java: https://github.com/oracle/graal/tree/master/sulong https://github.com/oracle/graal/tree/master/visualizer https://github.com/oracle/graal/tree/master/truffle https://github.com/graalvm/simplelanguage https://github.com/graalvm/simpletools https://openjdk.org/jeps/442 https://openjdk.org/jeps/448
79 |
80 | ### Java pipeline
81 |
82 | https://vertx.io/docs/vertx-web-api-service/java/
83 | https://github.com/vert-x3/vertx-infinispan https://github.com/infinispan/infinispan
84 | https://github.com/eclipse-vertx/vertx-grpc
85 | https://github.com/vert-x3/vertx-service-discovery
86 | https://github.com/vert-x3/vertx-service-proxy
87 |
88 | https://kafka.apache.org/ https://github.com/provectus/kafka-ui
89 | https://github.com/vert-x3/vertx-kafka-client
90 | https://github.com/vert-x3/vertx-stomp https://github.com/stomp-js/stompjs https://activemq.apache.org/ https://developers.cloudflare.com/queues/
91 | https://github.com/apache/pulsar
92 | https://github.com/datastax/kafka-examples
93 | https://github.com/datastax/kafka-sink
94 | https://github.com/datastax/starlight-for-kafka
95 |
96 | https://github.com/apache/arrow/tree/main/java
97 | https://github.com/apache/thrift
98 | https://github.com/apache/avro
99 | https://github.com/apache/orc
100 | https://github.com/apache/parquet-mr
101 | https://github.com/msgpack/msgpack-java
102 | https://github.com/irmen/pickle
103 | https://github.com/jamesmudd/jhdf
104 |
105 | https://github.com/apache/iceberg
106 | https://github.com/eclipse/jnosql
107 | https://github.com/trinodb/trino
108 | https://github.com/apache/druid/
109 | https://github.com/apache/hudi
110 | https://github.com/delta-io/delta
111 | https://github.com/apache/pinot
112 |
113 | ### HA & Telemetry
114 |
115 | OpenTelemetry/Grafana monitoring instead of WandB, Perfetto or TensorBoard; attach JAX profiler artifacts
116 | https://github.com/resilience4j/resilience4j https://resilience4j.readme.io/docs/micrometer https://vertx.io/docs/vertx-micrometer-metrics/java/
117 | https://github.com/open-telemetry/opentelemetry-java-instrumentation
118 | https://github.com/grafana/JPProf https://jax.readthedocs.io/en/latest/device_memory_profiling.html https://jax.readthedocs.io/en/latest/_autosummary/jax.profiler.device_memory_profile.html https://github.com/google/pprof/tree/main/proto https://github.com/grafana/pyroscope https://github.com/grafana/otel-profiling-go https://github.com/grafana/metrictank https://github.com/open-telemetry/opentelemetry-collector-contrib/tree/main/extension/pprofextension
119 | https://jax.readthedocs.io/en/latest/profiling.html https://github.com/google/perfetto/tree/master/protos
120 | https://vertx.io/docs/vertx-opentelemetry/java/
121 | https://vertx.io/docs/vertx-micrometer-metrics/java/
122 |
123 | ## Dataset preprocessing
124 |
125 | Make the most of cheap Kubernetes clusters: https://github.com/murphye/cheap-gke-cluster
126 |
127 | 1. Synthetic data generation: HTML/SVG/CSS graphic/layout/typography
128 | 2. Dataset merging: synthetic data, LAION-HR: https://huggingface.co/datasets/laion/laion-high-resolution, WiT dataset: https://huggingface.co/datasets/google/wit https://huggingface.co/datasets/wikimedia/wit_base, handwritten and printed documents scans, graphemes-in-the-wild, etc. with language re-sampling to match ByT5's C4 training distribution so as not to loose the multilingual balance: https://huggingface.co/datasets/allenai/c4 https://huggingface.co/datasets/mc4. More image datasets: https://huggingface.co/datasets/facebook/winoground https://huggingface.co/datasets/huggan/wikiart https://huggingface.co/datasets/kakaobrain/coyo-700m https://github.com/unsplash/datasets https://huggingface.co/datasets/red_caps https://huggingface.co/datasets/gigant/oldbookillustrations
129 | 3. Complementary downloads from dataset URLs (mostly images) and JPEG XL archiving (see IIIF)
130 | 4. Deduplication of images with fingerprinting and of captions with sentence embeddings (all the sentence-transformers disappeared on May 8 2023)
131 | 5. Scene segmentation, document layout understanding and caption NER. Because NER is to text what segmentation is to a scene and what layout understanding is to a document, we need to annotate all of these to be able to detect captions-within-a-caption (captions that spell out text within the image, for instance) and also score captions based on how exhaustive is the "coverage" of the scene or document they describe: https://medium.com/nlplanet/a-full-guide-to-finetuning-t5-for-text2text-and-building-a-demo-with-streamlit-c72009631887 https://huggingface.co/docs/transformers/model_doc/flan-ul2 https://pytorch.org/text/main/tutorials/t5_demo.html https://towardsdatascience.com/guide-to-fine-tuning-text-generation-models-gpt-2-gpt-neo-and-t5-dc5de6b3bc5e https://programming-review.com/machine-learning/t5/ https://colab.research.google.com/drive/1syXmhEQ5s7C59zU8RtHVru0wAvMXTSQ8 https://github.com/ttengwang/Caption-Anything https://github.com/facebookresearch/segment-anything https://github.com/facebookresearch/detectron2 https://huggingface.co/datasets/joelito/lextreme ttps://registry.opendata.aws/lei/ https://huggingface.co/datasets/jfrenz/legalglue https://huggingface.co/datasets/super_glue https://huggingface.co/datasets/klue https://huggingface.co/datasets/NbAiLab/norne https://huggingface.co/datasets/indic_glue https://huggingface.co/datasets/acronym_identification https://huggingface.co/datasets/wikicorpus https://huggingface.co/datasets/multi_woz_v22 https://huggingface.co/datasets/wnut_17 https://huggingface.co/datasets/msra_ner https://huggingface.co/datasets/conll2012_ontonotesv5 https://huggingface.co/datasets/conllpp
132 | 6. Image aesthetics and caption exhaustiveness (based on #5) meaningfulness (CoLa) evaluation and filtering: https://github.com/google-research/google-research/tree/master/vila https://github.com/google-research/google-research/tree/master/musiq https://github.com/christophschuhmann/improved-aesthetic-predictor https://www.mdpi.com/2313-433X/9/2/30 https://paperswithcode.com/dataset/aesthetic-visual-analysis https://www.tandfonline.com/doi/full/10.1080/09540091.2022.2147902 https://github.com/bcmi/Awesome-Aesthetic-Evaluation-and-Cropping https://github.com/rmokady/CLIP_prefix_caption https://github.com/google-research-datasets/Image-Caption-Quality-Dataset https://github.com/gchhablani/multilingual-image-captioning https://ai.googleblog.com/2022/10/crossmodal-3600-multilingual-reference.html https://www.cl.uni-heidelberg.de/statnlpgroup/wikicaps/ https://huggingface.co/docs/transformers/main/tasks/image_captioning https://www.mdpi.com/2076-3417/13/4/2446 https://arxiv.org/abs/2201.12723 https://laion.ai/blog/laion-aesthetics/ https://github.com/JD-P/simulacra-aesthetic-captions
133 | 7. Bucketing and batching (similar caption lengths for padding and truncating, similar image ratios for up/downsampling; see the sketch after this list): https://github.com/NovelAI/novelai-aspect-ratio-bucketing
134 | 8. Images preprocessing with JAX-native methods: https://jax.readthedocs.io/en/latest/jax.image.html https://dm-pix.readthedocs.io/ https://github.com/4rtemi5/imax https://github.com/rolandgvc/flaxvision
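
Referenced from item 7: a tiny sketch of aspect-ratio bucket assignment (the bucket list itself is a placeholder; the NovelAI write-up linked above covers the full scheme, including per-bucket batching):

```python
# Sketch: assign an image to the aspect-ratio bucket closest to its own ratio.
BUCKETS = [(512, 512), (448, 576), (576, 448), (384, 640), (640, 384)]  # placeholder buckets


def assign_bucket(width: int, height: int) -> tuple:
    ratio = width / height
    return min(BUCKETS, key=lambda wh: abs(wh[0] / wh[1] - ratio))
```
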
135 |
136 | ## CharT5 (ByT5 v2)
137 |
138 | Pretrain a better ByT5 with innovations from the T5 family and other character-aware language transformer models:
139 |
140 | - Early character-aware language models: https://arxiv.org/abs/1508.06615 https://arxiv.org/abs/2011.01513
141 | - CANINE-C encoder only character-aware language model (https://arxiv.org/abs/2103.06874, https://github.com/google-research/language/tree/master/language/canine, https://huggingface.co/google/canine-c, https://huggingface.co/vicl/canine-c-finetuned-cola)
142 | - Switch/MOE https://arxiv.org/abs/2101.03961 https://github.com/google-research/t5x/tree/main/t5x/contrib/moe https://towardsdatascience.com/understanding-googles-switch-transformer-904b8bf29f66, https://huggingface.co/google/switch-c-2048, https://towardsdatascience.com/the-switch-transformer-59f3854c7050 https://arxiv.org/abs/2208.02813
143 | - FLAN/PALM/PALM-E/PALM 2/UL2 https://arxiv.org/abs/2210.11416 https://arxiv.org/abs/2301.13688 https://arxiv.org/abs/2109.01652 https://github.com/lucidrains/PaLM-jax https://github.com/conceptofmind/PaLM-flax https://github.com/google-research/t5x/tree/main/t5x/examples/decoder_only https://huggingface.co/docs/transformers/main/model_doc/flan-t5 https://huggingface.co/google/flan-ul2 https://arxiv.org/abs/2205.05131v3 https://github.com/google-research/google-research/tree/master/ul2 https://ai.googleblog.com/2022/10/ul2-20b-open-source-unified-language.html
144 | - T5 v1.1 https://arxiv.org/abs/2002.05202
145 | - CALM https://github.com/google-research/t5x/tree/main/t5x/contrib/calm https://arxiv.org/abs/2207.07061
146 | - FlashAttention https://github.com/HazyResearch/flash-attention https://arxiv.org/abs/2205.14135
147 | - T5x & Seqio https://arxiv.org/abs/2203.17189
148 | - LongT5 https://github.com/google-research/longt5
149 | - WT5 https://github.com/google-research/google-research/tree/master/wt5
150 | - NanoT5 https://github.com/PiotrNawrot/nanoT5
151 | - Tensor Considered Harmful https://nlp.seas.harvard.edu/NamedTensor https://github.com/stanford-crfm/levanter https://github.com/harvardnlp/NamedTensor
152 | - FasterTransformers https://github.com/NVIDIA/FasterTransformer
153 | - FlaxFormers https://github.com/google/flaxformer
154 |
155 | ## Beyond SD 2.1
156 |
157 | Integrate and port to JAX as much improvements and ideas from Imagen, SDXL, Deep Floyd, Big Vision, Vision Transformer, etc. as possible : https://github.com/lucidrains/imagen-pytorch https://github.com/deep-floyd/IF https://stable-diffusion-art.com/sdxl-beta/ https://huggingface.co/docs/diffusers/api/pipelines/if https://huggingface.co/spaces/DeepFloyd/IF https://huggingface.co/DeepFloyd/IF-I-XL-v1.0 https://huggingface.co/DeepFloyd/IF-II-L-v1.0 https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler https://huggingface.co/DeepFloyd/IF-notebooks/tree/main https://huggingface.co/blog/if https://huggingface.co/docs/diffusers/main/en/api/pipelines/if https://stability.ai/blog/deepfloyd-if-text-to-image-model https://deepfloyd.ai/ https://www.assemblyai.com/blog/how-imagen-actually-works/ https://www.youtube.com/watch?v=af6WPqvzjjk https://www.youtube.com/watch?v=xqDeAz0U-R4 https://www.assemblyai.com/blog/minimagen-build-your-own-imagen-text-to-image-model/ https://github.com/google-research/big_vision https://github.com/google-research/vision_transformer https://github.com/microsoft/unilm/tree/master/beit https://arxiv.org/abs/2106.04803 https://arxiv.org/abs/2210.01820 https://arxiv.org/abs/2103.15808 https://arxiv.org/abs/2201.10271 https://arxiv.org/abs/2209.15159 https://arxiv.org/abs/2303.14189 https://arxiv.org/abs/2010.11929 https://arxiv.org/abs/2208.10442 https://arxiv.org/abs/2012.12877 https://arxiv.org/abs/2111.06377v3 https://arxiv.org/abs/2107.06263 https://arxiv.org/abs/1906.00446 https://arxiv.org/abs/2110.04627 https://arxiv.org/abs/2208.06366 https://arxiv.org/abs/2302.00902 https://arxiv.org/abs/2212.03185 https://arxiv.org/abs/2212.07372 https://arxiv.org/abs/2209.09002 https://arxiv.org/abs/2301.00704 https://arxiv.org/abs/2211.09117 https://arxiv.org/abs/2302.05917
158 |
--------------------------------------------------------------------------------
/inference_jax.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 |
4 | from monitoring import wandb_close, wandb_inference_init, wandb_inference_log
5 |
6 | import numpy as np
7 | from PIL import Image
8 |
9 | from diffusers import (
10 | FlaxAutoencoderKL,
11 | FlaxDPMSolverMultistepScheduler,
12 | FlaxUNet2DConditionModel,
13 | )
14 | from transformers import ByT5Tokenizer, FlaxT5ForConditionalGeneration
15 |
16 |
17 | def get_inference_lambda(seed):
18 | tokenizer = ByT5Tokenizer()
19 |
20 | language_model = FlaxT5ForConditionalGeneration.from_pretrained(
21 | "google/byt5-base",
22 | dtype=jnp.float32,
23 | )
24 | text_encoder = language_model.encode
25 | text_encoder_params = language_model.params
26 | max_length = 1024
27 | tokenized_negative_prompt = tokenizer(
28 | "", padding="max_length", max_length=max_length, return_tensors="np"
29 | ).input_ids
30 | negative_prompt_text_encoder_hidden_states = text_encoder(
31 | tokenized_negative_prompt,
32 | params=text_encoder_params,
33 | train=False,
34 | )[0]
35 |
36 | scheduler = FlaxDPMSolverMultistepScheduler.from_config(
37 | config={
38 | "_diffusers_version": "0.16.0",
39 | "beta_end": 0.012,
40 | "beta_schedule": "scaled_linear",
41 | "beta_start": 0.00085,
42 | "clip_sample": False,
43 | "num_train_timesteps": 1000,
44 | "prediction_type": "v_prediction",
45 | "set_alpha_to_one": False,
46 | "skip_prk_steps": True,
47 | "steps_offset": 1,
48 | "trained_betas": None,
49 | }
50 | )
51 | timesteps = 20
52 | guidance_scale = jnp.array([7.5], dtype=jnp.bfloat16)
53 |
54 | unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
55 | "character-aware-diffusion/charred",
56 | dtype=jnp.bfloat16,
57 | )
58 |
59 | vae, vae_params = FlaxAutoencoderKL.from_pretrained(
60 | "flax/stable-diffusion-2-1",
61 | subfolder="vae",
62 | dtype=jnp.bfloat16,
63 | )
64 | vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
65 |
66 | image_width = image_height = 256
67 |
68 | # Generating latent shape
69 | latent_shape = (
70 | negative_prompt_text_encoder_hidden_states.shape[
71 | 0
72 | ], # TODO: if is this for the whole context (positive + negative prompts), we should multiply by two
73 | unet.in_channels,
74 | image_width // vae_scale_factor,
75 | image_height // vae_scale_factor,
76 | )
77 |
78 | def __tokenize_prompt(prompt: str):
79 | return tokenizer(
80 | text=prompt,
81 | max_length=1024,
82 | padding="max_length",
83 | truncation=True,
84 | return_tensors="jax",
85 | ).input_ids
86 |
87 | def __convert_image(image):
88 | # create PIL image from JAX tensor converted to numpy
89 | return Image.fromarray(np.asarray(image), mode="RGB")
90 |
91 | def __get_context(tokenized_prompt: jnp.array):
92 | # Get the text embedding
93 | text_encoder_hidden_states = text_encoder(
94 | tokenized_prompt,
95 | params=text_encoder_params,
96 | train=False,
97 | )[0]
98 |
99 | # context = empty negative prompt embedding + prompt embedding
100 | return jnp.concatenate(
101 | [negative_prompt_text_encoder_hidden_states, text_encoder_hidden_states]
102 | )
103 |
104 | def __predict_image(context: jnp.array):
105 | def ___timestep(step, step_args):
106 | latents, scheduler_state = step_args
107 |
108 | t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
109 |
110 | # For classifier-free guidance, we need to do two forward passes.
111 | # Here we concatenate the unconditional and text embeddings into a single batch
112 | # to avoid doing two forward passes
113 | latent_input = jnp.concatenate([latents] * 2)
114 |
115 | timestep = jnp.broadcast_to(t, latent_input.shape[0])
116 |
117 | scaled_latent_input = scheduler.scale_model_input(
118 | scheduler_state, latent_input, t
119 | )
120 |
121 | # predict the noise residual
122 | unet_prediction_sample = unet.apply(
123 | {"params": unet_params},
124 | jnp.array(scaled_latent_input),
125 | jnp.array(timestep, dtype=jnp.int32),
126 | context,
127 | ).sample
128 |
129 | # perform guidance
130 | unet_prediction_sample_uncond, unet_prediction_text = jnp.split(
131 | unet_prediction_sample, 2, axis=0
132 | )
133 | guided_unet_prediction_sample = (
134 | unet_prediction_sample_uncond
135 | + guidance_scale
136 | * (unet_prediction_text - unet_prediction_sample_uncond)
137 | )
138 |
139 | # compute the previous noisy sample x_t -> x_t-1
140 | latents, scheduler_state = scheduler.step(
141 | scheduler_state, guided_unet_prediction_sample, t, latents
142 | ).to_tuple()
143 |
144 | return latents, scheduler_state
145 |
146 | # initialize scheduler state
147 | initial_scheduler_state = scheduler.set_timesteps(
148 | scheduler.create_state(), num_inference_steps=timesteps, shape=latent_shape
149 | )
150 |
151 | # initialize latents
152 | initial_latents = (
153 | jax.random.normal(
154 | jax.random.PRNGKey(seed), shape=latent_shape, dtype=jnp.bfloat16
155 | )
156 | * initial_scheduler_state.init_noise_sigma
157 | )
158 |
159 | final_latents, _ = jax.lax.fori_loop(
160 | 0, timesteps, ___timestep, (initial_latents, initial_scheduler_state)
161 | )
162 |
163 | vae_output = vae.apply(
164 | {"params": vae_params},
165 | 1 / vae.config.scaling_factor * final_latents,
166 | method=vae.decode,
167 | ).sample
168 |
169 | # return 8 bit RGB image (width, height, rgb)
170 | return (
171 | ((vae_output / 2 + 0.5).transpose(0, 2, 3, 1).clip(0, 1) * 255)
172 | .round()
173 | .astype(jnp.uint8)[0]
174 | )
175 |
176 | jax_jit_compiled_accel_predict_image = jax.jit(__predict_image)
177 |
178 | jax_jit_compiled_cpu_get_context = jax.jit(
179 | __get_context, device=jax.devices(backend="cpu")[0]
180 | )
181 |
182 | return lambda prompt: __convert_image(
183 | jax_jit_compiled_accel_predict_image(
184 | jax_jit_compiled_cpu_get_context(__tokenize_prompt(prompt))
185 | )
186 | )
187 |
188 |
189 | if __name__ == "__main__":
190 | wandb_inference_init()
191 |
192 | generate_image_for_prompt = get_inference_lambda(87)
193 |
194 | prompts = [
195 | "a white car",
196 | "a running shoe",
197 | "a forest",
198 | "two people",
199 | "a happy cartoon cat",
200 | ]
201 |
202 | log = []
203 |
204 | for prompt in prompts:
205 | log.append({"prompt": prompt, "image": generate_image_for_prompt(prompt)})
206 |
207 | wandb_inference_log(log)
208 |
209 | wandb_close()
210 |
--------------------------------------------------------------------------------
/inference_jax.sh:
--------------------------------------------------------------------------------
1 | export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=34359738368
2 | #export JAX_PLATFORMS=""
3 | #export JAX_PLATFORMS="cpu"
4 | export TF_CPP_MIN_LOG_LEVEL=2
5 | python3 ./inference_jax.py
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import jax
3 |
4 | from diffusers import FlaxDDPMScheduler
5 |
6 | # Min-SNR
7 | snr_gamma = 5.0 # SNR weighting gamma to be used when rebalancing the loss with Min-SNR. Recommended value is 5.0.
8 |
9 |
10 | def compute_snr_loss_weights(noise_scheduler_state, timesteps):
11 | """
12 | Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
13 | """
14 | alphas_cumprod = noise_scheduler_state.common.alphas_cumprod
15 | sqrt_alphas_cumprod = alphas_cumprod ** 0.5
16 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
17 |
18 | alpha = sqrt_alphas_cumprod[timesteps]
19 | sigma = sqrt_one_minus_alphas_cumprod[timesteps]
20 |
21 | # Compute SNR.
22 | snr = jnp.array((alpha / sigma) ** 2)
23 |
24 | # Compute SNR loss weights
25 | return jnp.where(snr < snr_gamma, snr, jnp.ones_like(snr) * snr_gamma) / snr
26 |
27 |
28 | def get_vae_latent_distribution_samples(
29 | latents,
30 | sample_rng,
31 | noise_scheduler,
32 | noise_scheduler_state,
33 | ):
34 |
35 | # Sample noise that we'll add to the latents
36 | noise_rng, timestep_rng = jax.random.split(sample_rng)
37 | noise = jax.random.normal(noise_rng, latents.shape)
38 |
39 | # Sample a random timestep for each image
40 | timesteps = jax.random.randint(
41 | key=timestep_rng,
42 | shape=(latents.shape[0],),
43 | minval=0,
44 | maxval=noise_scheduler.config.num_train_timesteps,
45 | dtype=jnp.int32,
46 | )
47 |
48 | # Add noise to the latents according to the noise magnitude at each timestep
49 | # (this is the forward diffusion process)
50 | noisy_latents = noise_scheduler.add_noise(
51 | state=noise_scheduler_state,
52 | original_samples=latents,
53 | noise=noise,
54 | timesteps=timesteps,
55 | )
56 |
57 | return noisy_latents, timesteps, noise
58 |
59 |
60 | def get_cacheable_samples(
61 | text_encoder, text_encoder_params, input_ids, vae, vae_params, pixel_values, rng
62 | ):
63 |
64 | # Get the text embedding
65 | # TODO: Cache this
66 | # TODO: use t5x library
67 | text_encoder_hidden_states = text_encoder(
68 | input_ids,
69 | params=text_encoder_params,
70 | train=False,
71 | )[0]
72 |
73 | # Get the image embedding
74 | vae_outputs = vae.apply(
75 | {"params": vae_params},
76 | sample=pixel_values,
77 | deterministic=True,
78 | method=vae.encode,
79 | )
80 |
81 | # Sample the image embedding
82 | # TODO: Cache this
83 | image_latent_distribution_sampling = (
84 | # (NHWC) -> (NCHW)
85 | jnp.transpose(vae_outputs.latent_dist.sample(rng), (0, 3, 1, 2))
86 | * vae.config.scaling_factor
87 | )
88 |
89 | return text_encoder_hidden_states, image_latent_distribution_sampling
90 |
91 |
92 | def get_compute_losses_lambda(
93 | text_encoder, # <-- TODO: take this out of here
94 | text_encoder_params, # <-- TODO: take this out of here
95 | vae, # <-- TODO: take this out of here
96 | vae_params, # <-- TODO: take this out of here
97 | unet,
98 | ):
99 |
100 |     # Instantiate the training noise scheduler
101 | # TODO: write pure function
102 | noise_scheduler = FlaxDDPMScheduler(
103 | beta_start=0.00085,
104 | beta_end=0.012,
105 | beta_schedule="scaled_linear",
106 | prediction_type="epsilon",
107 | num_train_timesteps=1000,
108 | )
109 |
110 | def __compute_losses_lambda(
111 | state_params,
112 | batch,
113 | sample_rng,
114 | ):
115 |
116 | # TODO: take this out of here
117 | (
118 | text_encoder_hidden_states,
119 | image_latent_distribution_sampling,
120 | ) = get_cacheable_samples(
121 | text_encoder,
122 | text_encoder_params,
123 | batch["input_ids"],
124 | vae,
125 | vae_params,
126 | batch["pixel_values"],
127 | sample_rng,
128 | )
129 |
130 | # initialize scheduler state
131 | noise_scheduler_state = noise_scheduler.create_state()
132 |
133 | # Get the vae latent distribution samples
134 | (
135 | image_sampling_noisy_input,
136 | image_sampling_timesteps,
137 | noise,
138 | ) = get_vae_latent_distribution_samples(
139 | image_latent_distribution_sampling,
140 | sample_rng,
141 | noise_scheduler,
142 | noise_scheduler_state,
143 | )
144 |
145 | # Predict the noise residual and compute loss
146 | # TODO: write pure function
147 | unet_predictions = unet.apply(
148 | {"params": state_params},
149 | sample=image_sampling_noisy_input,
150 | timesteps=image_sampling_timesteps,
151 | encoder_hidden_states=text_encoder_hidden_states,
152 | train=True,
153 | ).sample
154 |
155 | # Compute each batch sample's loss from noisy target
156 | loss_tensors = (noise - unet_predictions) ** 2
157 |
158 | # Compute Min-SNR loss weights
159 | snr_loss_weights = compute_snr_loss_weights(
160 | noise_scheduler_state,
161 | image_sampling_timesteps,
162 | )
163 |
164 | # Get one loss scalar per batch sample
165 | losses = (
166 | loss_tensors.mean(
167 | axis=tuple(range(1, loss_tensors.ndim)),
168 | )
169 | * snr_loss_weights
170 | ) # Balance losses with Min-SNR
171 |
172 | # This must be an averaged scalar, otherwise, you get this:
173 | # TypeError: Gradient only defined for scalar-output functions. Output had shape: (8,).
174 | return losses.mean(axis=0)
175 |
176 | return __compute_losses_lambda
177 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # jax/flax
4 | import jax
5 | from flax.core.frozen_dict import unfreeze
6 | from flax.training import train_state
7 |
8 | from architecture import setup_model
9 |
10 | # internal code
11 | from args import parse_args
12 | from optimizer import setup_optimizer
13 | from training_loop import training_loop
14 | from monitoring import wandb_close, wandb_init
15 |
16 |
17 | def main():
18 | args = parse_args()
19 |
20 | # number of splits/partitions/devices/shards
21 | num_devices = jax.local_device_count()
22 |
23 | output_dir = args.output_dir
24 |
25 | load_pretrained = os.path.exists(output_dir) and os.path.isdir(output_dir)
26 |
27 | # Setup WandB for logging & tracking
28 | log_wandb = args.log_wandb
29 | if log_wandb:
30 | wandb_init(args, num_devices)
31 |
32 | # init random number generator
33 | seed = args.seed
34 | seed_rng = jax.random.PRNGKey(seed)
35 | rng, training_from_scratch_rng_params = jax.random.split(seed_rng)
36 | print("random generator setup...")
37 |
38 | # Pretrained/freezed and training model setup
39 | text_encoder, text_encoder_params, vae, vae_params, unet, unet_params = setup_model(
40 | seed,
41 | load_pretrained,
42 | output_dir,
43 | training_from_scratch_rng_params,
44 | )
45 | print("models setup...")
46 |
47 | # Optimization & scheduling setup
48 | optimizer = setup_optimizer(
49 | args.learning_rate,
50 | args.adam_beta1,
51 | args.adam_beta2,
52 | args.adam_epsilon,
53 | args.adam_weight_decay,
54 | args.max_grad_norm,
55 | )
56 | print("optimizer setup...")
57 |
58 | # Training state setup
59 | unet_training_state = train_state.TrainState.create(
60 | apply_fn=unet,
61 | params=unfreeze(unet_params),
62 | tx=optimizer,
63 | )
64 | print("training state initialized...")
65 |
66 | if log_wandb:
67 | get_validation_predictions = None # TODO: put validation here
68 | else:
69 | get_validation_predictions = None
70 |
71 | # JAX device data replication
72 | # replicated_state = replicate(unet_training_state)
73 | # NOTE: # These can't be replicated here, otherwise, you get this whenever they are used: flax.errors.ScopeParamShapeError: Initializer expected to generate shape (4, 384, 1536) but got shape (384, 1536) instead for parameter "embedding" in "/shared". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)
74 | # replicated_text_encoder_params = jax_utils.replicate(text_encoder_params)
75 | # replicated_vae_params = jax_utils.replicate(vae_params)
76 | # print("states & params replicated to TPUs...")
77 |
78 | # Train!
79 | print("Training loop init...")
80 | training_loop(
81 | text_encoder,
82 | text_encoder_params,
83 | vae,
84 | vae_params,
85 | unet,
86 | unet_training_state,
87 | rng,
88 | args.max_train_steps,
89 | args.num_train_epochs,
90 | args.train_batch_size,
91 | output_dir,
92 | log_wandb,
93 | get_validation_predictions,
94 | num_devices,
95 | )
96 | print("Training loop done...")
97 |
98 | if log_wandb:
99 | wandb_close()
100 |
101 |
102 | if __name__ == "__main__":
103 | main()
104 |
--------------------------------------------------------------------------------
/monitoring.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | import numpy as np
3 |
4 |
5 | def wandb_inference_init():
6 | wandb.init(
7 | entity="charred",
8 | project="charred-inference",
9 | job_type="inference",
10 | )
11 |
12 | print("WandB inference init...")
13 |
14 |
15 | def wandb_inference_log(log: list):
16 | wandb_log = []
17 |
18 | for entry in log:
19 | wandb_log.append(wandb.Image(entry["image"], caption=entry["prompt"]))
20 |
21 | wandb.log({"inference": wandb_log})
22 |
23 | print("WandB inference log...")
24 |
25 |
26 | def wandb_init(args, num_devices):
27 | wandb.init(
28 | entity="charred",
29 | project="charred",
30 | job_type="train",
31 | config=args,
32 | )
33 | wandb.config.update(
34 | {
35 | "num_devices": num_devices,
36 | }
37 | )
38 | wandb.define_metric("*", step_metric="step")
39 | wandb.define_metric("step", step_metric="walltime")
40 |
41 | print("WandB setup...")
42 |
43 |
44 | def wandb_close():
45 | wandb.finish()
46 |
47 | print("WandB closed...")
48 |
49 |
50 | def get_wandb_log_batch_lambda(
51 | get_predictions,
52 | ):
53 | def __wandb_log_batch(
54 | global_walltime,
55 | global_training_steps,
56 | delta_time,
57 | epoch,
58 | loss,
59 | unet_params,
60 | is_milestone,
61 | ):
62 |
63 | log_data = {
64 | "walltime": global_walltime,
65 | "step": global_training_steps,
66 | "batch_delta_time": delta_time,
67 | "epoch": epoch,
68 | "loss": loss.mean(),
69 | }
70 |
71 | if is_milestone and get_predictions is not None:
72 | log_data["validation"] = [
73 | wandb.Image(image, caption=prompt)
74 | for prompt, image in get_predictions(unet_params)
75 | ]
76 |
77 | wandb.log(
78 | data=log_data,
79 | commit=True,
80 | )
81 |
82 | return __wandb_log_batch
83 |
--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
1 | import optax
2 |
3 |
4 | def setup_optimizer(
5 | learning_rate,
6 | adam_beta1,
7 | adam_beta2,
8 | adam_epsilon,
9 | adam_weight_decay,
10 | max_grad_norm,
11 | ):
12 | constant_scheduler = optax.constant_schedule(learning_rate)
13 |
14 | adamw = optax.adamw(
15 | learning_rate=constant_scheduler,
16 | b1=adam_beta1,
17 | b2=adam_beta2,
18 | eps=adam_epsilon,
19 | weight_decay=adam_weight_decay,
20 | )
21 |
22 | return optax.chain(
23 | optax.clip_by_global_norm(max_grad_norm),
24 | adamw,
25 | )
26 |
--------------------------------------------------------------------------------
/repository.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from threading import Thread
4 |
5 | # hugging face
6 | from huggingface_hub import upload_folder
7 |
8 |
9 | def create_output_dir(output_dir):
10 | if output_dir is not None:
11 | os.makedirs(output_dir, exist_ok=True)
12 |
13 |
14 | def save_to_local_directory(
15 | output_dir,
16 | unet,
17 | unet_params,
18 | ):
19 | print("saving trained weights...")
20 | unet.save_pretrained(
21 | save_directory=output_dir,
22 | params=unet_params,
23 | )
24 | print("trained weights saved...")
25 |
26 |
27 | def save_to_repository(
28 | output_dir,
29 | unet,
30 | unet_params,
31 | repo_id,
32 | ):
33 | print("saving trained weights...")
34 | unet.save_pretrained(
35 | save_directory=output_dir,
36 | params=unet_params,
37 | )
38 | print("trained weights saved...")
39 |
40 | Thread(
41 | target=lambda: upload_to_repository(
42 |             output_dir,
43 |             repo_id,
44 | "End of training epoch.",
45 | )
46 | ).start()
47 |
48 |
49 | def upload_to_repository(
50 | output_dir,
51 | repo_id,
52 | commit_message,
53 | ):
54 | upload_folder(
55 | repo_id=repo_id,
56 | folder_path=output_dir,
57 | commit_message=commit_message,
58 | )
59 |
60 |
61 | if __name__ == "__main__":
62 | upload_to_repository(
63 | "/data/output.bak/000920",
64 | "character-aware-diffusion/charred",
65 |         "Latest training epoch version as of Apr 28 11:03PM UTC.",
66 | )
67 |
--------------------------------------------------------------------------------
/repository.sh:
--------------------------------------------------------------------------------
1 | export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=34359738368
2 | export TF_CPP_MIN_LOG_LEVEL=2
3 | python3 ./repository.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | jax[cuda12_local]
2 | flax
3 | optax
4 | diffusers
5 | transformers
6 | datasets
7 | --extra-index-url https://download.pytorch.org/whl/cpu
8 | torch
9 | torchvision
10 | wandb
11 | Pillow
--------------------------------------------------------------------------------
/training.sh:
--------------------------------------------------------------------------------
1 | export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=34359738368
2 | #export JAX_PLATFORMS=''
3 | python3 main.py
--------------------------------------------------------------------------------
/training_loop.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from jax import pmap
4 | from jax.random import split
5 | from flax.training.common_utils import shard
6 | from flax.jax_utils import replicate, unreplicate
7 | from jax.profiler import (
8 | start_trace,
9 | stop_trace,
10 | save_device_memory_profile,
11 | device_memory_profile,
12 | )
13 | from jaxlib.xla_extension import XlaRuntimeError
14 |
15 | from monitoring import get_wandb_log_batch_lambda
16 | from batch import setup_dataloader
17 | from dataset import setup_dataset
18 | from repository import save_to_local_directory
19 | from training_step import get_training_step_lambda
20 |
21 |
22 | def training_loop(
23 | text_encoder,
24 | text_encoder_params,
25 | vae,
26 | vae_params,
27 | unet,
28 | unet_training_state,
29 | rng,
30 | max_train_steps,
31 | num_train_epochs,
32 | train_batch_size,
33 | output_dir,
34 | log_wandb,
35 | get_validation_predictions,
36 | num_devices,
37 | ):
38 |
39 | # replication setup
40 | unet_training_state = replicate(unet_training_state)
41 | rng = split(rng, num_devices)
42 |
43 | # dataset setup
44 | train_dataset = setup_dataset(max_train_steps)
45 | print("dataset loaded...")
46 |
47 | # batch setup
48 | total_train_batch_size = train_batch_size * num_devices
49 | train_dataloader = setup_dataloader(train_dataset, total_train_batch_size)
50 | print("dataloader setup...")
51 |
52 | # milestone setup
53 | milestone_step_count = min(100_000, max_train_steps)
54 | print(f"milestone step count: {milestone_step_count}")
55 |
56 | # wandb setup
57 | if log_wandb:
58 | wandb_log_batch = get_wandb_log_batch_lambda(
59 | get_validation_predictions,
60 | )
61 |         print("wandb log batch setup...")
62 |
63 | # Create parallel version of the train step
64 | # TODO: Should we try "axis_size=num_devices" or "axis_size=total_train_batch_size" ?
65 | jax_pmapped_training_step = pmap(
66 | # cannot send these as static broadcasted arguments because they are not hashable
67 | # TODO: rewrite text_encoder, vae and unet as pure
68 | fun=get_training_step_lambda(
69 | text_encoder, text_encoder_params, vae, vae_params, unet
70 | ),
71 | axis_name="batch",
72 | in_axes=(0, 0, 0),
73 | out_axes=(0, 0, 0),
74 | static_broadcasted_argnums=(),
75 | # We cannot donate the "batch" argument. Otherwise, we get this:
76 | # /site-packages/jax/_src/interpreters/mlir.py:711: UserWarning: Some donated buffers were not usable: ShapedArray(int32[8,1024]), ShapedArray(float32[8,3,512,512]).
77 | # See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
78 | # warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")
79 | # donating rng and training state
80 | donate_argnums=(
81 | 0,
82 | 1,
83 | ),
84 | )
85 |
86 | # Epoch setup
87 | t0 = time.monotonic()
88 | global_training_steps = 0
89 | global_walltime = time.monotonic()
90 | is_compilation_step = True
91 | is_first_compiled_step = False
92 | loss = None
93 | for epoch in range(num_train_epochs):
94 |
95 | for batch in train_dataloader:
96 |
97 | # getting batch start time
98 | batch_walltime = time.monotonic()
99 |
100 | if is_compilation_step:
101 | print("computing compilation batch...")
102 | # TODO: fix this: 2023-05-05 16:34:23.937383: E external/xla/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace
103 | device_memory_profile()
104 | start_trace(
105 | log_dir="./profiling/compilation_step",
106 | create_perfetto_link=False,
107 | create_perfetto_trace=True,
108 | )
109 | elif is_first_compiled_step:
110 | print("computing first compiled batch...")
111 | device_memory_profile()
112 | start_trace(
113 | log_dir="./profiling/first_compiled_step",
114 | create_perfetto_link=False,
115 | create_perfetto_trace=True,
116 | )
117 |
118 | # training step
119 | # TODO: Fix this jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Error loading program: Attempting to allocate 1.28G. That was not possible. There are 785.61M free.; (0x0x0_HBM0): while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
120 | try:
121 |
122 | unet_training_state, rng, loss = jax_pmapped_training_step(
123 | unet_training_state,
124 | rng,
125 | shard(batch),
126 | )
127 |
128 | except XlaRuntimeError as e:
129 |
130 | if is_compilation_step:
131 | stop_trace()
132 | save_device_memory_profile(
133 | filename="./profiling/compilation_step/pprof_memory_profile_error.pb"
134 | )
135 | print("compilation batch error...")
136 | elif is_first_compiled_step:
137 | stop_trace()
138 | save_device_memory_profile(
139 | filename="./profiling/first_compiled_step/pprof_memory_profile_error.pb"
140 | )
141 | print("first compiled batch error...")
142 |
143 |                 raise
144 |
145 | # block until train step has completed
146 | loss.block_until_ready()
147 |
148 | if is_compilation_step:
149 | stop_trace()
150 | save_device_memory_profile(
151 | filename="./profiling/compilation_step/pprof_memory_profile.pb"
152 | )
153 | print("computed compilation batch...")
154 | elif is_first_compiled_step:
155 | stop_trace()
156 | save_device_memory_profile(
157 | filename="./profiling/first_compiled_step/pprof_memory_profile.pb"
158 | )
159 | print("computed first compiled batch...")
160 |
161 | global_training_steps += num_devices
162 |
163 | # checking if current batch is a milestone
164 |             is_milestone = (
165 |                 global_training_steps % milestone_step_count == 0
166 |             )
167 |
168 | if log_wandb:
169 | # TODO: is this correct? was only unreplicated before, with no averaging
170 | global_walltime = time.monotonic() - t0
171 | delta_time = time.monotonic() - batch_walltime
172 | wandb_log_batch(
173 | global_walltime,
174 | global_training_steps,
175 | delta_time,
176 | epoch,
177 | loss,
178 | unet_training_state.params,
179 | is_milestone,
180 | )
181 |
182 | if is_milestone:
183 | save_to_local_directory(
184 | f"{ output_dir }/{ str(global_training_steps).zfill(12) }",
185 | unet,
186 | # TODO: is this ok?
187 | # was: jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
188 | # then: jax.device_get(flax.jax_utils.unreplicate(state.params))
189 | # and then, also: jax.device_get(state.params)
190 | # and then, again: unreplicate(state.params)
191 | # Finally found a way to average along the splits/device/partition/shard axis: jax.tree_util.tree_map(f=lambda x: x.mean(axis=0), tree=unet_training_state.params),
192 | unreplicate(tree=unet_training_state.params),
193 | )
194 |
195 | if is_compilation_step:
196 | is_compilation_step = False
197 | is_first_compiled_step = True
198 | elif is_first_compiled_step:
199 | is_first_compiled_step = False
200 |
--------------------------------------------------------------------------------
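A minimal sketch of the device-axis bookkeeping the loop above relies on: `replicate` gives every local device its own copy of the training state, `shard` splits each batch along its leading axis, and `unreplicate` drops the device axis again before `save_to_local_directory` writes a checkpoint. Shapes below are illustrative.

```python
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate, unreplicate
from flax.training.common_utils import shard

n = jax.local_device_count()
params = {"w": jnp.ones((3,))}

replicated = replicate(params)  # every leaf gains a leading device axis of size n
batch = jnp.zeros((8 * n, 3))   # total batch = per-device batch (8) * n
sharded = shard(batch)          # (8 * n, 3) -> (n, 8, 3)

assert replicated["w"].shape == (n, 3)
assert sharded.shape == (n, 8, 3)
assert unreplicate(replicated)["w"].shape == (3,)
```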
/training_step.py:
--------------------------------------------------------------------------------
1 | import jax
2 |
3 | from loss import get_compute_losses_lambda
4 |
5 |
6 | def get_training_step_lambda(text_encoder, text_encoder_params, vae, vae_params, unet):
7 |
8 | # Get loss function lambda
9 | # TODO: Are we copying all this static data on every batch, here?
10 | # TODO: Solution #1: avoid copying the static data at every batch
11 |     # TODO: Solution #2: offload frozen-model computation to the CPU, at least for the text encoding
12 | # Compile loss function.
13 |     # NOTE: Can't compile this any higher up because jax.value_and_grad requires real-number (floating-point) dtypes for its differentiated arguments
14 | jax_loss_value_and_gradient = jax.value_and_grad(
15 | fun=get_compute_losses_lambda(
16 | text_encoder,
17 | text_encoder_params,
18 | vae,
19 | vae_params,
20 | unet,
21 | ),
22 | argnums=0,
23 | )
24 |
25 | def __training_step_lambda(
26 | state,
27 | rng,
28 | batch,
29 | ):
30 |
31 | # Split RNGs
32 | sample_rng, new_rng = jax.random.split(rng, 2)
33 |
34 | # Compute loss and gradients
35 | # TODO: why are we doing this here instead of in "value_and_grad" with "reduce_axes"?
36 | loss, grad = jax.lax.pmean(
37 | jax_loss_value_and_gradient(
38 | state.params,
39 | batch,
40 | sample_rng,
41 | ),
42 | axis_name="batch",
43 | )
44 |
45 | # Apply gradients to training state
46 | new_state = state.apply_gradients(grads=grad)
47 |
48 | return new_state, new_rng, loss
49 |
50 | return __training_step_lambda
51 |
--------------------------------------------------------------------------------
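A minimal sketch of what the `jax.lax.pmean(..., axis_name="batch")` call above does: inside a `pmap` that declares the same axis name, it averages a value (or a pytree such as `(loss, grads)`) across all participating devices, so every replica ends up applying identical gradients.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
per_device_loss = jnp.arange(n, dtype=jnp.float32)  # a different loss on each device

averaged = jax.pmap(
    lambda x: jax.lax.pmean(x, axis_name="batch"),
    axis_name="batch",
)(per_device_loss)

# every entry now holds the same cross-device mean
assert bool(jnp.all(averaged == per_device_loss.mean()))
```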
/validation.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | # jax / flax
4 | import jax
5 | import jax.numpy as jnp
6 | from flax.jax_utils import replicate
7 | from flax.training.common_utils import shard
8 |
9 | import numpy as np
10 | from PIL import Image
11 |
12 | from diffusers import (
13 | FlaxAutoencoderKL,
14 | FlaxDPMSolverMultistepScheduler,
15 | FlaxUNet2DConditionModel,
16 | )
17 | from transformers import ByT5Tokenizer
18 |
19 | from architecture import setup_model
20 |
21 |
22 | # TODO: try half-precision
23 |
24 | tokenized_prompt_max_length = 1024
25 |
26 |
27 | def tokenize_prompts(prompt: list[str]):
28 | return ByT5Tokenizer()(
29 | text=prompt,
30 | max_length=tokenized_prompt_max_length,
31 | padding="max_length",
32 | truncation=True,
33 | return_tensors="jax",
34 | ).input_ids
35 |
36 |
37 | def convert_images(images: jnp.ndarray):
38 |     # create PIL images from the JAX array converted to NumPy
39 | return [Image.fromarray(np.asarray(image), mode="RGB") for image in images]
40 |
41 |
42 | def get_validation_predictions_lambda(
43 | vae: FlaxAutoencoderKL,
44 | vae_params,
45 | unet: FlaxUNet2DConditionModel,
46 | ):
47 |
48 | scheduler = FlaxDPMSolverMultistepScheduler.from_config(
49 | config={
50 | "_diffusers_version": "0.16.0",
51 | "beta_end": 0.012,
52 | "beta_schedule": "scaled_linear",
53 | "beta_start": 0.00085,
54 | "clip_sample": False,
55 | "num_train_timesteps": 1000,
56 | "prediction_type": "v_prediction",
57 | "set_alpha_to_one": False,
58 | "skip_prk_steps": True,
59 | "steps_offset": 1,
60 | "trained_betas": None,
61 | }
62 | )
63 | timesteps = 20
64 |
65 | vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
66 |
67 | image_width = image_height = 256
68 |
69 | # Generating latent shape
70 | latent_shape = (
71 | 1536,
72 | unet.in_channels,
73 | image_width // vae_scale_factor,
74 | image_height // vae_scale_factor,
75 | )
76 |
77 | def __predict_images(seed, unet_params, encoded_prompts):
78 | def ___timestep(step, step_args):
79 | latents, scheduler_state = step_args
80 |
81 | t = jnp.asarray(scheduler_state.timesteps, dtype=jnp.int32)[step]
82 |
83 | timestep = jnp.array(jnp.broadcast_to(t, latents.shape[0]), dtype=jnp.int32)
84 |
85 | scaled_latent_input = jnp.array(
86 | scheduler.scale_model_input(scheduler_state, latents, t)
87 | )
88 |
89 | # predict the noise residual
90 | unet_prediction_sample = unet.apply(
91 | {"params": unet_params},
92 | sample=scaled_latent_input,
93 | timesteps=timestep,
94 | encoder_hidden_states=encoded_prompts,
95 | train=False,
96 | ).sample
97 |
98 | # compute the previous noisy sample x_t -> x_t-1
99 | return scheduler.step(
100 | scheduler_state, unet_prediction_sample, t, latents
101 | ).to_tuple()
102 |
103 | # initialize scheduler state
104 | initial_scheduler_state = scheduler.set_timesteps(
105 | scheduler.create_state(), num_inference_steps=timesteps, shape=latent_shape
106 | )
107 |
108 | # initialize latents
109 | initial_latents = (
110 | jax.random.normal(
111 | jax.random.PRNGKey(seed), shape=latent_shape, dtype=jnp.float32
112 | )
113 | * initial_scheduler_state.init_noise_sigma
114 | )
115 |
116 |         # get denoised latents
117 | final_latents, _ = jax.lax.fori_loop(
118 | 0, timesteps, ___timestep, (initial_latents, initial_scheduler_state)
119 | )
120 |
121 | # scale latents
122 | scaled_final_latents = 1 / vae.config.scaling_factor * final_latents
123 |
124 | # get image from latents
125 | vae_output = vae.apply(
126 | {"params": vae_params},
127 | latents=scaled_final_latents,
128 | deterministic=True,
129 | method=vae.decode,
130 | ).sample
131 |
132 | # return 8 bit RGB image (width, height, rgb)
133 | return (
134 | (
135 |                 (vae_output / 2 + 0.5)  # the VAE decodes to [-1, 1], so rescale to [0, 1]
136 | .transpose(
137 | 0, 2, 3, 1
138 | ) # (batch, channel, height, width) => (batch, height, width, channel)
139 | .clip(0, 1)
140 | * 255
141 | )
142 | .round()
143 | .astype(jnp.uint8)
144 | )
145 |
146 | return lambda seed, unet_params, encoded_prompts: __predict_images(
147 | seed, unet_params, encoded_prompts
148 | )
149 |
150 |
151 | if __name__ == "__main__":
152 | # Pretrained/freezed and training model setup
153 | text_encoder, text_encoder_params, vae, vae_params, unet, unet_params = setup_model(
154 | 43, # seed
155 | None, # dtype (defaults to float32)
156 | True, # load pre-trained
157 | "character-aware-diffusion/charred",
158 | None,
159 | )
160 | # validation prompts
161 | validation_prompts = [
162 | "a white car",
163 | "une voiture blanche",
164 | "a running shoe",
165 | "une chaussure de course",
166 | "a perfumer and his perfume organ",
167 | "un parfumeur et son orgue à parfums",
168 | "two people",
169 | "deux personnes",
170 | "a happy cartoon cat",
171 | "un dessin de chat heureux",
172 | "a city skyline",
173 | "un panorama urbain",
174 | "a Marilyn Monroe portrait",
175 | "un portrait de Marilyn Monroe",
176 | "a rainy day in London",
177 | "Londres sous la pluie",
178 | ]
179 |
180 | tokenized_prompts = tokenize_prompts(validation_prompts)
181 |
182 | encoded_prompts = text_encoder(
183 | tokenized_prompts,
184 | params=text_encoder_params,
185 | train=False,
186 | )[0]
187 |
188 | validation_predictions_lambda = get_validation_predictions_lambda(
189 | vae,
190 | vae_params,
191 | unet,
192 | )
193 |
194 | get_validation_predictions = jax.pmap(
195 | fun=validation_predictions_lambda,
196 | axis_name="encoded_prompts",
197 | donate_argnums=(),
198 | )
199 |
200 | image_predictions = get_validation_predictions(
201 | replicate(2), replicate(unet_params), shard(encoded_prompts)
202 | )
203 |
204 |     images = convert_images(image_predictions.reshape(-1, *image_predictions.shape[2:]))  # drop the leading device axis added by pmap
205 |
--------------------------------------------------------------------------------
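A note on the reshape before `convert_images` in the `__main__` block above: `jax.pmap` returns outputs with a leading device axis, so the prediction tensor has shape `(devices, batch, height, width, 3)` and must be flattened before each entry can become a PIL image. A minimal sketch with a NumPy stand-in for the pmapped output:

```python
import numpy as np
from PIL import Image

devices, batch, height, width = 2, 3, 256, 256
pmapped_output = np.zeros((devices, batch, height, width, 3), dtype=np.uint8)

# (devices, batch, H, W, 3) -> (devices * batch, H, W, 3)
flat = pmapped_output.reshape(-1, *pmapped_output.shape[2:])
images = [Image.fromarray(image, mode="RGB") for image in flat]
assert len(images) == devices * batch
```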
/validation.sh:
--------------------------------------------------------------------------------
1 | export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=34359738368
2 | #export JAX_PLATFORMS=""
3 | export TF_CPP_MIN_LOG_LEVEL=2
4 | python3 ./validation.py
--------------------------------------------------------------------------------