├── .gitignore
├── LICENCE.txt
├── README.md
├── assets
│   ├── LBM.jpg
│   ├── depth_normal.jpg
│   ├── object_removal.jpg
│   ├── relight.gif
│   ├── relight.jpg
│   ├── relight_2.gif
│   ├── shadow_control.gif
│   └── upscaler.jpg
├── examples
│   ├── inference
│   │   ├── gradio_demo.py
│   │   ├── inference.py
│   │   └── utils.py
│   └── training
│       ├── config
│       │   └── surface.yaml
│       └── train_lbm_surface.py
├── pyproject.toml
├── requirements.txt
├── src
│   └── lbm
│       ├── config.py
│       ├── data
│       │   ├── __init__.py
│       │   ├── datasets
│       │   │   ├── __init__.py
│       │   │   ├── collation_fn.py
│       │   │   ├── dataset.py
│       │   │   └── datasets_config.py
│       │   ├── filters
│       │   │   ├── __init__.py
│       │   │   ├── base.py
│       │   │   ├── filter_wrapper.py
│       │   │   ├── filters.py
│       │   │   └── filters_config.py
│       │   └── mappers
│       │       ├── __init__.py
│       │       ├── base.py
│       │       ├── mappers.py
│       │       ├── mappers_config.py
│       │       └── mappers_wrapper.py
│       ├── inference
│       │   ├── __init__.py
│       │   ├── inference.py
│       │   └── utils.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── base
│       │   │   ├── __init__.py
│       │   │   ├── base_model.py
│       │   │   └── model_config.py
│       │   ├── embedders
│       │   │   ├── __init__.py
│       │   │   ├── base
│       │   │   │   ├── __init__.py
│       │   │   │   ├── base_conditioner.py
│       │   │   │   └── base_conditioner_config.py
│       │   │   ├── conditioners_wrapper.py
│       │   │   └── latents_concat
│       │   │       ├── __init__.py
│       │   │       ├── latents_concat_embedder_config.py
│       │   │       └── latents_concat_embedder_model.py
│       │   ├── lbm
│       │   │   ├── __init__.py
│       │   │   ├── lbm_config.py
│       │   │   └── lbm_model.py
│       │   ├── unets
│       │   │   ├── __init__.py
│       │   │   └── unet.py
│       │   ├── utils.py
│       │   └── vae
│       │       ├── __init__.py
│       │       ├── autoencoderKL.py
│       │       └── autoencoderKL_config.py
│       └── trainer
│           ├── __init__.py
│           ├── loggers.py
│           ├── trainer.py
│           ├── training_config.py
│           └── utils.py
└── tests
    ├── README.md
    ├── requirements.txt
    ├── test_dataset
    │   ├── test_filters.py
    │   └── test_mappers.py
    ├── test_lbm
    │   └── test_lbm.py
    ├── test_unets
    │   └── test_unets_wrappers.py
    └── test_vaes
        └── test_autoencoder.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.ipynb
2 | *.png
3 | *.pyc
4 | *.gradio
5 | *.safetensors
6 | examples/inference/ckpts/*
7 | examples/inference/examples/*
8 | envs/*
9 | checkpoints/*
10 | *.sh
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Latent Bridge Matching (LBM)
2 |
3 | This repository is the official implementation of the paper [LBM: Latent Bridge Matching for Fast Image-to-Image Translation](http://arxiv.org/abs/2503.07535).
4 |
42 | ## Abstract
43 | In this paper, we introduce Latent Bridge Matching (LBM), a new, versatile and scalable method that relies on Bridge Matching in a latent space to achieve fast image-to-image translation. We show that the method can reach state-of-the-art results for various image-to-image tasks using only a single inference step. In addition to its efficiency, we also demonstrate the versatility of the method across different image translation tasks such as object removal, normal and depth estimation, and object relighting. We also derive a conditional framework of LBM and demonstrate its effectiveness by tackling the tasks of controllable image relighting and shadow generation.
44 |
49 | ## License
50 | This code is released under the **Creative Commons BY-NC 4.0 license**.
51 |
52 | ## Considered Use-cases
53 | We validate the method on various use-cases such as object relighting, image restoration, object removal, depth and normal map estimation, as well as controllable object relighting and shadow generation.
54 |
55 | Image Relighting 🔦
56 |
57 | For object relighting, the method translates the encoded source image, created by pasting the foreground onto the target background, into the desired relit target image.
58 |
64 | Image Restoration 🧹
65 |
66 | In the context of image restoration, the method transports the distribution of degraded images to the distribution of clean images.
67 |
73 | Object Removal ✂️
74 | For object removal, the model is trained to find a transport map from masked images to images without the masked objects.
75 |
80 | Controllable Image Relighting and Shadow Generation 🕹️
81 |
82 | We also derive a conditional framework of LBM and demonstrate its effectiveness on the tasks of controllable image relighting and shadow generation.
83 |
89 | Normal and Depth Map Estimation 🗺️
90 |
91 | Finally, we also consider common tasks such as normal and depth estimation, where the model translates an input image into a normal or depth map.
92 |
100 | ## Setup
101 | To get up and running, first create and activate a virtual environment with Python 3.10 or later.
102 |
103 | ### With venv
104 | ```bash
105 | python3.10 -m venv envs/lbm
106 | source envs/lbm/bin/activate
107 | ```
108 |
109 | ### With conda
110 | ```bash
111 | conda create -n lbm python=3.10
112 | conda activate lbm
113 | ```
114 |
115 | Then install the required dependencies and the repo in editable mode:
116 |
117 | ```bash
118 | pip install --upgrade pip
119 | pip install -e .
120 | ```
121 |
122 | ## Inference
123 |
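124 | We provide in `examples/inference` a simple script to perform depth estimation, surface normal estimation, or object relighting with the proposed method (`--num_inference_steps` sets the number of sampling steps and defaults to 1).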
125 |
126 | ```bash
127 | python examples/inference/inference.py \
128 | --model_name [depth|normals|relighting] \
129 | --source_image path_to_your_image.jpg \
130 | --output_path output_images
131 | ```
132 |
133 | The trained checkpoints are available on the HF Hub 🤗:
134 | - [Surface normals Checkpoint](https://huggingface.co/jasperai/LBM_normals)
135 | - [Depth Checkpoint](https://huggingface.co/jasperai/LBM_depth)
136 | - [Relighting Checkpoint](https://huggingface.co/jasperai/LBM_relighting)
137 |
138 | ## Local Gradio Demo
139 | To run the Gradio demo locally, use the following command:
140 | ```bash
141 | python examples/inference/gradio_demo.py
142 | ```
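143 | It downloads the pretrained relighting model, the BiRefNet segmentation model used to extract the foreground, and the example images from the HF Hub. The script loads the models on `cuda`, so a CUDA-capable GPU is required.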
144 |
145 | ## Training
146 | In `examples/training`, we provide an example script to train an LBM for surface normal prediction on [`hypersim`](https://github.com/apple/ml-hypersim); see [this guide](https://github.com/prs-eth/Marigold/blob/main/script/dataset_preprocess/hypersim/README.md) for data processing.
147 |
148 | In `examples/training/config`, you will find the `yaml` configuration associated with the training script. To train the model on your own data, you only need to amend the `SHARDS_PATH_OR_URLS` section of the `yaml`.
149 |
150 | Please note that this package uses [`webdataset`](https://github.com/webdataset/webdataset) to handle the data stream, so your shards must follow the [`webdataset format`](https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format). In particular, for this example, each sample in your `.tar` files must contain a `jpg` file with the source image, a `normal_aligned.png` file with the target normals (the key that `train_lbm_surface.py` filters on), and a `mask.png` file with a mask of the valid pixels:
151 |
152 | ```python
153 | sample = {
154 |     "jpg": source_image,
155 |     "normal_aligned.png": normals,  # target image
156 |     "mask.png": mask,  # mask of valid pixels
157 | }
158 | ```
159 |
160 | To train the model, you can use the following command:
161 |
162 | ```bash
163 | python examples/training/train_lbm_surface.py examples/training/config/surface.yaml
164 | ```
165 |
166 | *Note*: Make sure to update the relevant section of the `yaml` file to use your own data and log the results on your own [WandB](https://wandb.ai/site).
167 |
168 | ## Citation
169 | If you find this work useful or use it in your research, please consider citing us:
170 | ```bibtex
171 | @article{chadebec2025lbm,
172 | title={LBM: Latent Bridge Matching for Fast Image-to-Image Translation},
173 | author={Clément Chadebec and Onur Tasar and Sanjeev Sreetharan and Benjamin Aubin},
174 | year={2025},
175 | journal = {arXiv preprint arXiv:2503.07535},
176 | }
177 | ```
178 |
--------------------------------------------------------------------------------
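As a rough illustration of the shard layout described in the README's Training section, the sketch below writes samples into a `.tar` file with only Python's standard `tarfile` module. It assumes the webdataset convention that files sharing a basename (the part of the filename before the first dot) form one sample. The zero-padded key scheme and the `write_shard` / `list_members` helpers are hypothetical, and the payloads are placeholder bytes rather than real JPEG/PNG encodings. The `normal_aligned.png` extension matches the key filtered by `get_filter_mappers` in `train_lbm_surface.py`.

```python
import io
import os
import tarfile
import tempfile


def write_shard(path, samples):
    """Write samples to a tar shard using webdataset-style file names.

    Each sample is a dict mapping an extension (e.g. "jpg") to raw bytes.
    All files of one sample share the same key, so a webdataset reader
    groups them back together by the part of the name before the first dot.
    """
    with tarfile.open(path, "w") as tar:
        for index, sample in enumerate(samples):
            key = f"{index:06d}"  # hypothetical key scheme: zero-padded index
            for extension, payload in sample.items():
                info = tarfile.TarInfo(name=f"{key}.{extension}")
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))


def list_members(path):
    """Return the member names of a tar shard in archive order."""
    with tarfile.open(path, "r") as tar:
        return tar.getnames()


if __name__ == "__main__":
    # Placeholder bytes; real shards would hold encoded JPEG/PNG files.
    dummy = {
        "jpg": b"<jpeg bytes>",
        "normal_aligned.png": b"<png bytes>",
        "mask.png": b"<png bytes>",
    }
    shard_path = os.path.join(tempfile.mkdtemp(), "train-000000.tar")
    write_shard(shard_path, [dummy, dummy])
    print(list_members(shard_path))
```

A shard written this way could then be listed in `train_shards` as `pipe:cat /path/to/train-000000.tar`; since the training script brace-expands every entry, several shards can be referenced with a single brace pattern.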
/assets/LBM.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/LBM.jpg
--------------------------------------------------------------------------------
/assets/depth_normal.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/depth_normal.jpg
--------------------------------------------------------------------------------
/assets/object_removal.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/object_removal.jpg
--------------------------------------------------------------------------------
/assets/relight.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/relight.gif
--------------------------------------------------------------------------------
/assets/relight.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/relight.jpg
--------------------------------------------------------------------------------
/assets/relight_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/relight_2.gif
--------------------------------------------------------------------------------
/assets/shadow_control.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/shadow_control.gif
--------------------------------------------------------------------------------
/assets/upscaler.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/upscaler.jpg
--------------------------------------------------------------------------------
/examples/inference/gradio_demo.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import logging
3 | import os
4 | from copy import deepcopy
5 |
6 | import gradio as gr
7 | import numpy as np
8 | import PIL
9 | import torch
10 | from huggingface_hub import snapshot_download
11 | from PIL import Image
12 | from torchvision.transforms import ToPILImage, ToTensor
13 | from transformers import AutoModelForImageSegmentation
14 | from utils import extract_object, resize_and_center_crop
15 |
16 | from lbm.inference import get_model
17 |
18 | PATH = os.path.dirname(os.path.abspath(__file__))
19 | os.environ["GRADIO_TEMP_DIR"] = ".gradio"
20 |
21 |
22 | if not os.path.exists(os.path.join(PATH, "ckpts", "relighting")):
23 |     logging.info("Downloading relighting LBM model from HF hub...")
24 |     model = get_model(
25 |         "jasperai/LBM_relighting",
26 |         save_dir=os.path.join(PATH, "ckpts", "relighting"),
27 |         torch_dtype=torch.bfloat16,
28 |         device="cuda",
29 |     )
30 | else:
31 |     model_dir = os.path.join(PATH, "ckpts", "relighting")
32 |     logging.info("Loading relighting LBM model from local...")
33 |     model = get_model(
34 |         model_dir,
35 |         torch_dtype=torch.bfloat16,
36 |         device="cuda",
37 |     )
38 |
39 | ASPECT_RATIOS = {
40 |     str(512 / 2048): (512, 2048),
41 |     str(1024 / 1024): (1024, 1024),
42 |     str(2048 / 512): (2048, 512),
43 |     str(896 / 1152): (896, 1152),
44 |     str(1152 / 896): (1152, 896),
45 |     str(512 / 1920): (512, 1920),
46 |     str(640 / 1536): (640, 1536),
47 |     str(768 / 1280): (768, 1280),
48 |     str(1280 / 768): (1280, 768),
49 |     str(1536 / 640): (1536, 640),
50 |     str(1920 / 512): (1920, 512),
51 | }
52 |
53 | birefnet = AutoModelForImageSegmentation.from_pretrained(
54 |     "ZhengPeng7/BiRefNet", trust_remote_code=True
55 | ).cuda()
56 | image_size = (1024, 1024)
57 |
58 | if not os.path.exists(os.path.join(PATH, "examples")):
59 |     logging.info("Downloading backgrounds from HF hub...")
60 |     _ = snapshot_download(
61 |         "jasperai/LBM_relighting",
62 |         repo_type="space",
63 |         allow_patterns="*.jpg",
64 |         local_dir=PATH,
65 |     )
66 |
67 |
68 | def evaluate(
69 |     fg_image: PIL.Image.Image,
70 |     bg_image: PIL.Image.Image,
71 |     num_sampling_steps: int = 1,
72 | ):
73 |
74 |     ori_w, ori_h = fg_image.size  # PIL reports (width, height)
75 |     ar = ori_w / ori_h
76 |     closest_ar = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - ar))
77 |     dimensions = ASPECT_RATIOS[closest_ar]  # (width, height) bucket
78 |
79 |     _, fg_mask = extract_object(birefnet, deepcopy(fg_image))
80 |
81 |     fg_image = resize_and_center_crop(fg_image, dimensions[0], dimensions[1])
82 |     fg_mask = resize_and_center_crop(fg_mask, dimensions[0], dimensions[1])
83 |     bg_image = resize_and_center_crop(bg_image, dimensions[0], dimensions[1])
84 |
85 |     img_pasted = Image.composite(fg_image, bg_image, fg_mask)
86 |     # ToTensor yields values in [0, 1]; rescale them to [-1, 1]
87 |     img_pasted_tensor = ToTensor()(img_pasted).unsqueeze(0) * 2 - 1
88 |     batch = {
89 |         "source_image": img_pasted_tensor.cuda().to(torch.bfloat16),
90 |     }
91 |
92 |     z_source = model.vae.encode(batch[model.source_key])
93 |
94 |     output_image = model.sample(
95 |         z=z_source,
96 |         num_steps=int(num_sampling_steps),
97 |         conditioner_inputs=batch,
98 |         max_samples=1,
99 |     ).clamp(-1, 1)
100 |
101 |     output_image = (output_image[0].float().cpu() + 1) / 2
102 |     output_image = ToPILImage()(output_image)
103 |
104 |     # paste the output image on the background image
105 |     output_image = Image.composite(output_image, bg_image, fg_mask)
106 |     # PIL's resize returns a new image, so the result must be reassigned
107 |     output_image = output_image.resize((ori_w, ori_h))
108 |
109 |     return (np.array(img_pasted.resize((ori_w, ori_h))), np.array(output_image))
110 |
111 |
112 | with gr.Blocks(title="LBM Object Relighting") as demo:
113 |     gr.Markdown(
114 |         f"""
115 | # Object Relighting with Latent Bridge Matching
116 | This is an interactive demo of [LBM: Latent Bridge Matching for Fast Image-to-Image Translation](https://arxiv.org/abs/2503.07535) *by Jasper Research*. This demo is based on the [LBM relighting checkpoint](https://huggingface.co/jasperai/LBM_relighting).
117 | """
118 |     )
119 |     gr.Markdown(
120 |         """
121 | If you enjoy the space, please also promote *open-source* by giving a ⭐ to the Github Repo.
122 | """
123 |     )
124 |
125 |     with gr.Row():
126 |         with gr.Column():
127 |             with gr.Row():
128 |                 fg_image = gr.Image(
129 |                     type="pil",
130 |                     label="Input Image",
131 |                     image_mode="RGB",
132 |                     height=360,
133 |                     # width=360,
134 |                 )
135 |                 bg_image = gr.Image(
136 |                     type="pil",
137 |                     label="Target Background",
138 |                     image_mode="RGB",
139 |                     height=360,
140 |                     # width=360,
141 |                 )
142 |
143 |             with gr.Row():
144 |                 submit_button = gr.Button("Relight", variant="primary")
145 |             with gr.Row():
146 |                 num_inference_steps = gr.Slider(
147 |                     minimum=1,
148 |                     maximum=4,
149 |                     value=1,
150 |                     step=1,
151 |                     label="Number of Inference Steps",
152 |                 )
153 |
154 |             bg_gallery = gr.Gallery(
155 |                 # height=450,
156 |                 object_fit="contain",
157 |                 label="Background List",
158 |                 value=[
159 |                     path
160 |                     for path in glob.glob(
161 |                         os.path.join(PATH, "examples/backgrounds/*.jpg")
162 |                     )
163 |                 ],
164 |                 columns=5,
165 |                 allow_preview=False,
166 |             )
167 |
168 |         with gr.Column():
169 |             output_slider = gr.ImageSlider(label="Composite vs LBM", type="numpy")
170 |             output_slider.upload(
171 |                 fn=evaluate,
172 |                 inputs=[fg_image, bg_image, num_inference_steps],
173 |                 outputs=[output_slider],
174 |             )
175 |
176 |     submit_button.click(
177 |         evaluate,
178 |         inputs=[fg_image, bg_image, num_inference_steps],
179 |         outputs=[output_slider],
180 |     )
181 |
182 |     with gr.Row():
183 |         gr.Examples(
184 |             fn=evaluate,
185 |             examples=[
186 |                 [
187 |                     os.path.join(PATH, "examples/foregrounds/2.jpg"),
188 |                     os.path.join(PATH, "examples/backgrounds/14.jpg"),
189 |                     1,
190 |                 ],
191 |                 [
192 |                     os.path.join(PATH, "examples/foregrounds/10.jpg"),
193 |                     os.path.join(PATH, "examples/backgrounds/4.jpg"),
194 |                     1,
195 |                 ],
196 |                 [
197 |                     os.path.join(PATH, "examples/foregrounds/11.jpg"),
198 |                     os.path.join(PATH, "examples/backgrounds/24.jpg"),
199 |                     1,
200 |                 ],
201 |                 [
202 |                     os.path.join(PATH, "examples/foregrounds/19.jpg"),
203 |                     os.path.join(PATH, "examples/backgrounds/3.jpg"),
204 |                     1,
205 |                 ],
206 |                 [
207 |                     os.path.join(PATH, "examples/foregrounds/4.jpg"),
208 |                     os.path.join(PATH, "examples/backgrounds/6.jpg"),
209 |                     1,
210 |                 ],
211 |                 [
212 |                     os.path.join(PATH, "examples/foregrounds/14.jpg"),
213 |                     os.path.join(PATH, "examples/backgrounds/22.jpg"),
214 |                     1,
215 |                 ],
216 |                 [
217 |                     os.path.join(PATH, "examples/foregrounds/12.jpg"),
218 |                     os.path.join(PATH, "examples/backgrounds/1.jpg"),
219 |                     1,
220 |                 ],
221 |             ],
222 |             inputs=[fg_image, bg_image, num_inference_steps],
223 |             outputs=[output_slider],
224 |             run_on_click=True,
225 |         )
226 |
227 |     def bg_gallery_selected(gal, evt: gr.SelectData):
228 |         return gal[evt.index][0]
229 |
230 |     bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=bg_image)
231 |
232 | if __name__ == "__main__":
233 |
234 |     demo.launch(share=True)
235 |
--------------------------------------------------------------------------------
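`evaluate` in `gradio_demo.py` snaps the foreground's width/height ratio to the nearest entry of `ASPECT_RATIOS` before cropping. Below is a minimal standalone sketch of that selection, assuming the same buckets but stored as `(width, height)` tuples instead of string keys; the names `BUCKETS` and `closest_bucket` are hypothetical.

```python
# (width, height) buckets, copied from ASPECT_RATIOS in gradio_demo.py.
BUCKETS = [
    (512, 2048),
    (1024, 1024),
    (2048, 512),
    (896, 1152),
    (1152, 896),
    (512, 1920),
    (640, 1536),
    (768, 1280),
    (1280, 768),
    (1536, 640),
    (1920, 512),
]


def closest_bucket(width, height, buckets=BUCKETS):
    """Return the (width, height) bucket whose aspect ratio is closest to width / height."""
    ratio = width / height
    # min() keeps the first bucket on ties, like min() over the dict keys in the demo.
    return min(buckets, key=lambda wh: abs(wh[0] / wh[1] - ratio))


if __name__ == "__main__":
    print(closest_bucket(1920, 1080))  # a 16:9 landscape photo -> (1280, 768)
```

Computing the ratio from the tuple avoids the `str(...)` / `float(...)` round trip used in the demo while selecting the same bucket. Every bucket holds roughly one megapixel (about 0.98 to 1.05 million pixels), so the choice mostly changes the shape of the crop, not the compute cost.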
/examples/inference/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 |
5 | import torch
6 | from PIL import Image
7 |
8 | from lbm.inference import evaluate, get_model
9 |
10 | PATH = os.path.dirname(os.path.abspath(__file__))
11 |
12 | logging.basicConfig(level=logging.INFO)
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--source_image", type=str, required=True)
16 | parser.add_argument("--output_path", type=str, required=True)
17 | parser.add_argument("--num_inference_steps", type=int, default=1)
18 | parser.add_argument(
19 |     "--model_name",
20 |     type=str,
21 |     default="normals",
22 |     choices=["normals", "depth", "relighting"],
23 | )
24 |
25 |
26 | args = parser.parse_args()
27 |
28 |
29 | def main():
30 |     # download the weights from HF hub
31 |     if not os.path.exists(os.path.join(PATH, "ckpts", f"{args.model_name}")):
32 |         logging.info(f"Downloading {args.model_name} LBM model from HF hub...")
33 |         model = get_model(
34 |             f"jasperai/LBM_{args.model_name}",
35 |             save_dir=os.path.join(PATH, "ckpts", f"{args.model_name}"),
36 |             torch_dtype=torch.bfloat16,
37 |             device="cuda",
38 |         )
39 |
40 |     else:
41 |         model_dir = os.path.join(PATH, "ckpts", f"{args.model_name}")
42 |         logging.info(f"Loading {args.model_name} LBM model from local...")
43 |         model = get_model(model_dir, torch_dtype=torch.bfloat16, device="cuda")
44 |
45 |     source_image = Image.open(args.source_image).convert("RGB")
46 |
47 |     output_image = evaluate(model, source_image, args.num_inference_steps)
48 |
49 |     os.makedirs(args.output_path, exist_ok=True)
50 |
51 |     source_image.save(os.path.join(args.output_path, "source_image.jpg"))
52 |     output_image.save(os.path.join(args.output_path, "output_image.jpg"))
53 |
54 |
55 | if __name__ == "__main__":
56 |     main()
57 |
--------------------------------------------------------------------------------
/examples/inference/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | from torchvision import transforms
4 |
5 |
6 | def extract_object(birefnet, img):
7 |     # Data settings
8 |     image_size = (1024, 1024)
9 |     transform_image = transforms.Compose(
10 |         [
11 |             transforms.Resize(image_size),
12 |             transforms.ToTensor(),
13 |             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
14 |         ]
15 |     )
16 |
17 |     image = img
18 |     input_images = transform_image(image).unsqueeze(0).cuda()
19 |
20 |     # Prediction
21 |     with torch.no_grad():
22 |         preds = birefnet(input_images)[-1].sigmoid().cpu()
23 |     pred = preds[0].squeeze()
24 |     pred_pil = transforms.ToPILImage()(pred)
25 |     mask = pred_pil.resize(image.size)
26 |     image = Image.composite(image, Image.new("RGB", image.size, (127, 127, 127)), mask)
27 |     return image, mask
28 |
29 |
30 | def resize_and_center_crop(image, target_width, target_height):
31 |     original_width, original_height = image.size
32 |     scale_factor = max(target_width / original_width, target_height / original_height)
33 |     resized_width = int(round(original_width * scale_factor))
34 |     resized_height = int(round(original_height * scale_factor))
35 |     resized_image = image.resize((resized_width, resized_height), Image.LANCZOS)
36 |     left = (resized_width - target_width) / 2
37 |     top = (resized_height - target_height) / 2
38 |     right = (resized_width + target_width) / 2
39 |     bottom = (resized_height + target_height) / 2
40 |     cropped_image = resized_image.crop((left, top, right, bottom))
41 |     return cropped_image
42 |
--------------------------------------------------------------------------------
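`resize_and_center_crop` in `utils.py` first scales the image so it covers the target size in both dimensions, then cuts out a centered window. The hypothetical helper `center_crop_box` below reproduces only that arithmetic, without Pillow, which makes the geometry easy to check by hand.

```python
def center_crop_box(orig_w, orig_h, target_w, target_h):
    """Mirror the arithmetic of resize_and_center_crop without touching pixels.

    Returns the intermediate resized (width, height) and the
    (left, top, right, bottom) crop box applied to the resized image.
    """
    # Scale so the resized image covers the target in both dimensions.
    scale = max(target_w / orig_w, target_h / orig_h)
    resized_w = int(round(orig_w * scale))
    resized_h = int(round(orig_h * scale))
    # Center the target window inside the resized image.
    left = (resized_w - target_w) / 2
    top = (resized_h - target_h) / 2
    return (resized_w, resized_h), (left, top, left + target_w, top + target_h)


if __name__ == "__main__":
    print(center_crop_box(1920, 1080, 1280, 768))
    # -> ((1365, 768), (42.5, 0.0, 1322.5, 768.0))
```

With a 1920x1080 input and the 1280x768 bucket, the image is scaled to 1365x768 and 42.5 pixels are trimmed from each side horizontally, so the crop box can contain half-pixel offsets.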
/examples/training/config/surface.yaml:
--------------------------------------------------------------------------------
1 | # wandb
2 | wandb_project: lbm-surface-flows
3 | timestep_sampling: custom_timesteps
4 | unet_input_channels: 4
5 | vae_num_channels: 4
6 | selected_timesteps: [250, 500, 750, 1000]
7 | prob: [0.25, 0.25, 0.25, 0.25]
8 | pixel_loss_type: lpips # l1 l2
9 | pixel_loss_weight: 10.0
10 | latent_loss_type: l2 # l1 l2
11 | latent_loss_weight: 1.0
12 | bridge_noise_sigma: 0.005
13 | conditioning_images_keys: []
14 | conditioning_masks_keys: []
15 |
16 | # SHARDS_PATH_OR_URLS
17 | train_shards:
18 | - pipe:cat PATH_TO_TRAIN_TARS
19 |
20 | validation_shards:
21 | - pipe:cat PATH_TO_VAL_TARS
22 |
23 | batch_size: 4
24 | learning_rate: 4.0e-5  # keep the decimal point: PyYAML reads "4e-5" as a string, not a float
25 | optimizer: AdamW
26 | num_steps: [1, 4]
27 | log_interval: 500
28 | resume_from_checkpoint: true
29 | max_epochs: 50
30 | save_interval: 5000
31 | save_ckpt_path: ./checkpoints
32 |
--------------------------------------------------------------------------------
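The `timestep_sampling: custom_timesteps`, `selected_timesteps`, `prob`, and `bridge_noise_sigma` entries in `surface.yaml` hint at how training pairs are drawn. The sketch below is an assumption about their meaning, not the code of `LBMModel`: it picks a timestep from `selected_timesteps` with probabilities `prob`, maps it to t = timestep / 1000, and builds a Brownian-bridge sample between a target latent (at t = 0) and a source latent (at t = 1). The 1000-step scale, the direction of the bridge, and the helper names `sample_timestep` / `bridge_sample` are all hypothetical, and latents are plain Python lists for readability.

```python
import math
import random


def sample_timestep(selected_timesteps, probs, rng=random):
    """Draw one discrete timestep, assuming `custom_timesteps` means sampling
    from `selected_timesteps` with probabilities `prob`."""
    return rng.choices(selected_timesteps, weights=probs, k=1)[0]


def bridge_sample(z_target, z_source, t, sigma, rng=random):
    """Brownian-bridge interpolation between two latents (flat lists of floats).

    z_t = (1 - t) * z_target + t * z_source + sigma * sqrt(t * (1 - t)) * eps

    The noise scale sqrt(t * (1 - t)) vanishes at both ends, so t = 0 returns
    z_target and t = 1 returns z_source exactly.
    """
    noise_scale = sigma * math.sqrt(t * (1.0 - t))
    return [
        (1.0 - t) * a + t * b + noise_scale * rng.gauss(0.0, 1.0)
        for a, b in zip(z_target, z_source)
    ]


if __name__ == "__main__":
    rng = random.Random(0)
    # Values from surface.yaml; dividing by 1000 maps them into (0, 1].
    timestep = sample_timestep([250, 500, 750, 1000], [0.25, 0.25, 0.25, 0.25], rng)
    z_t = bridge_sample([0.0, 0.0], [1.0, -1.0], timestep / 1000, sigma=0.005, rng=rng)
    print(timestep, z_t)
```

Under these assumptions, the small `bridge_noise_sigma` of 0.005 keeps intermediate latents very close to the straight line between the two endpoints, and the `selected_timesteps` entry of 1000 would correspond to starting from the pure source latent.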
/examples/training/train_lbm_surface.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import os
4 | import random
5 | import re
6 | import shutil
7 | from typing import List, Optional
8 |
9 | import braceexpand
10 | import fire
11 | import torch
12 | import yaml
13 | from diffusers import FlowMatchEulerDiscreteScheduler, StableDiffusionXLPipeline
14 | from diffusers.models import UNet2DConditionModel
15 | from diffusers.models.attention import BasicTransformerBlock
16 | from diffusers.models.resnet import ResnetBlock2D
17 | from pytorch_lightning import Trainer, loggers
18 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
19 | from pytorch_lightning.strategies import FSDPStrategy
20 | from torch.distributed.fsdp.wrap import ModuleWrapPolicy
21 | from torchvision.transforms import InterpolationMode
22 |
23 | from lbm.data.datasets import DataModule, DataModuleConfig
24 | from lbm.data.filters import KeyFilter, KeyFilterConfig
25 | from lbm.data.mappers import (
26 |     KeyRenameMapper,
27 |     KeyRenameMapperConfig,
28 |     MapperWrapper,
29 |     RescaleMapper,
30 |     RescaleMapperConfig,
31 |     TorchvisionMapper,
32 |     TorchvisionMapperConfig,
33 | )
34 | from lbm.models.embedders import (
35 |     ConditionerWrapper,
36 |     LatentsConcatEmbedder,
37 |     LatentsConcatEmbedderConfig,
38 | )
39 | from lbm.models.lbm import LBMConfig, LBMModel
40 | from lbm.models.unets import DiffusersUNet2DCondWrapper
41 | from lbm.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig
42 | from lbm.trainer import TrainingConfig, TrainingPipeline
43 | from lbm.trainer.loggers import WandbSampleLogger
44 | from lbm.trainer.utils import StateDictAdapter
45 |
46 |
47 | def get_model(
48 |     backbone_signature: str = "stabilityai/stable-diffusion-xl-base-1.0",
49 |     vae_num_channels: int = 4,
50 |     unet_input_channels: int = 4,
51 |     timestep_sampling: str = "log_normal",
52 |     selected_timesteps: Optional[List[float]] = None,
53 |     prob: Optional[List[float]] = None,
54 |     conditioning_images_keys: Optional[List[str]] = [],
55 |     conditioning_masks_keys: Optional[List[str]] = [],
56 |     source_key: str = "source_image",
57 |     target_key: str = "source_image_paste",
58 |     mask_key: str = "mask",
59 |     bridge_noise_sigma: float = 0.0,
60 |     logit_mean: float = 0.0,
61 |     logit_std: float = 1.0,
62 |     pixel_loss_type: str = "lpips",
63 |     latent_loss_type: str = "l2",
64 |     latent_loss_weight: float = 1.0,
65 |     pixel_loss_weight: float = 0.0,
66 | ):
67 |
68 |     conditioners = []
69 |
70 |     # Load pretrained model as base
71 |     pipe = StableDiffusionXLPipeline.from_pretrained(
72 |         backbone_signature,
73 |         torch_dtype=torch.bfloat16,
74 |     )
75 |
76 |     ### UNet ###
77 |     # Get Architecture
78 |     denoiser = DiffusersUNet2DCondWrapper(
79 |         in_channels=unet_input_channels,  # Add downsampled_image
80 |         out_channels=vae_num_channels,
81 |         center_input_sample=False,
82 |         flip_sin_to_cos=True,
83 |         freq_shift=0,
84 |         down_block_types=[
85 |             "DownBlock2D",
86 |             "CrossAttnDownBlock2D",
87 |             "CrossAttnDownBlock2D",
88 |         ],
89 |         mid_block_type="UNetMidBlock2DCrossAttn",
90 |         up_block_types=["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
91 |         only_cross_attention=False,
92 |         block_out_channels=[320, 640, 1280],
93 |         layers_per_block=2,
94 |         downsample_padding=1,
95 |         mid_block_scale_factor=1,
96 |         dropout=0.0,
97 |         act_fn="silu",
98 |         norm_num_groups=32,
99 |         norm_eps=1e-05,
100 |         cross_attention_dim=[320, 640, 1280],
101 |         transformer_layers_per_block=[1, 2, 10],
102 |         reverse_transformer_layers_per_block=None,
103 |         encoder_hid_dim=None,
104 |         encoder_hid_dim_type=None,
105 |         attention_head_dim=[5, 10, 20],
106 |         num_attention_heads=None,
107 |         dual_cross_attention=False,
108 |         use_linear_projection=True,
109 |         class_embed_type=None,
110 |         addition_embed_type=None,
111 |         addition_time_embed_dim=None,
112 |         num_class_embeds=None,
113 |         upcast_attention=None,
114 |         resnet_time_scale_shift="default",
115 |         resnet_skip_time_act=False,
116 |         resnet_out_scale_factor=1.0,
117 |         time_embedding_type="positional",
118 |         time_embedding_dim=None,
119 |         time_embedding_act_fn=None,
120 |         timestep_post_act=None,
121 |         time_cond_proj_dim=None,
122 |         conv_in_kernel=3,
123 |         conv_out_kernel=3,
124 |         projection_class_embeddings_input_dim=None,
125 |         attention_type="default",
126 |         class_embeddings_concat=False,
127 |         mid_block_only_cross_attention=None,
128 |         cross_attention_norm=None,
129 |         addition_embed_type_num_heads=64,
130 |     ).to(torch.bfloat16)
131 |
132 |     state_dict = pipe.unet.state_dict()
133 |
134 |     del state_dict["add_embedding.linear_1.weight"]
135 |     del state_dict["add_embedding.linear_1.bias"]
136 |     del state_dict["add_embedding.linear_2.weight"]
137 |     del state_dict["add_embedding.linear_2.bias"]
138 |
139 |     # Adapt the shapes
140 |     state_dict_adapter = StateDictAdapter()
141 |     state_dict = state_dict_adapter(
142 |         model_state_dict=denoiser.state_dict(),
143 |         checkpoint_state_dict=state_dict,
144 |         regex_keys=[
145 |             r"class_embedding.linear_\d+.(weight|bias)",
146 |             r"conv_in.weight",
147 |             r"(down_blocks|up_blocks)\.\d+\.attentions\.\d+\.transformer_blocks\.\d+\.attn\d+\.(to_k|to_v)\.weight",
148 |             r"mid_block\.attentions\.\d+\.transformer_blocks\.\d+\.attn\d+\.(to_k|to_v)\.weight",
149 |         ],
150 |         strategy="zeros",
151 |     )
152 |
153 |     denoiser.load_state_dict(state_dict, strict=True)
154 |
155 |     del pipe
156 |
157 |     if conditioning_images_keys != [] or conditioning_masks_keys != []:
158 |
159 |         latents_concat_embedder_config = LatentsConcatEmbedderConfig(
160 |             image_keys=conditioning_images_keys,
161 |             mask_keys=conditioning_masks_keys,
162 |         )
163 |         latent_concat_embedder = LatentsConcatEmbedder(latents_concat_embedder_config)
164 |         latent_concat_embedder.freeze()
165 |         conditioners.append(latent_concat_embedder)
166 |
167 |     # Wrap conditioners and set to device
168 |     conditioner = ConditionerWrapper(
169 |         conditioners=conditioners,
170 |     )
171 |
172 |     ## VAE ##
173 |     # Get VAE model
174 |     vae_config = AutoencoderKLDiffusersConfig(
175 |         version=backbone_signature,
176 |         subfolder="vae",
177 |         tiling_size=(128, 128),
178 |     )
179 |     vae = AutoencoderKLDiffusers(vae_config)
180 |     vae.freeze()
181 |     vae.to(torch.bfloat16)
182 |
183 |     # LBM Config
184 |     config = LBMConfig(
185 |         ucg_keys=None,
186 |         source_key=source_key,
187 |         target_key=target_key,
188 |         mask_key=mask_key,
189 |         latent_loss_weight=latent_loss_weight,
190 |         latent_loss_type=latent_loss_type,
191 |         pixel_loss_type=pixel_loss_type,
192 |         pixel_loss_weight=pixel_loss_weight,
193 |         timestep_sampling=timestep_sampling,
194 |         logit_mean=logit_mean,
195 |         logit_std=logit_std,
196 |         selected_timesteps=selected_timesteps,
197 |         prob=prob,
198 |         bridge_noise_sigma=bridge_noise_sigma,
199 |     )
200 |
201 |     training_noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
202 |         backbone_signature,
203 |         subfolder="scheduler",
204 |     )
205 |     sampling_noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
206 |         backbone_signature,
207 |         subfolder="scheduler",
208 |     )
209 |
210 |     # LBM Model
211 |     model = LBMModel(
212 |         config,
213 |         denoiser=denoiser,
214 |         training_noise_scheduler=training_noise_scheduler,
215 |         sampling_noise_scheduler=sampling_noise_scheduler,
216 |         vae=vae,
217 |         conditioner=conditioner,
218 |     ).to(torch.bfloat16)
219 |
220 |     return model
221 |
222 |
223 | def get_filter_mappers():
224 |     filters_mappers = [
225 |         KeyFilter(KeyFilterConfig(keys=["jpg", "normal_aligned.png", "mask.png"])),
226 |         MapperWrapper(
227 |             [
228 |                 KeyRenameMapper(
229 |                     KeyRenameMapperConfig(
230 |                         key_map={
231 |                             "jpg": "image",
232 |                             "normal_aligned.png": "normal",
233 |                             "mask.png": "mask",
234 |                         }
235 |                     )
236 |                 ),
237 |                 TorchvisionMapper(
238 |                     TorchvisionMapperConfig(
239 |                         key="image",
240 |                         transforms=["ToTensor", "Resize"],
241 |                         transforms_kwargs=[
242 |                             {},
243 |                             {
244 |                                 "size": (480, 640),
245 |                                 "interpolation": InterpolationMode.NEAREST_EXACT,
246 |                             },
247 |                         ],
248 |                     )
249 |                 ),
250 |                 TorchvisionMapper(
251 |                     TorchvisionMapperConfig(
252 |                         key="normal",
253 |                         transforms=["ToTensor", "Resize"],
254 |                         transforms_kwargs=[
255 |                             {},
256 |                             {
257 |                                 "size": (480, 640),
258 |                                 "interpolation": InterpolationMode.NEAREST_EXACT,
259 |                             },
260 |                         ],
261 |                     )
262 |                 ),
263 |                 TorchvisionMapper(
264 |                     TorchvisionMapperConfig(
265 |                         key="mask",
266 |                         transforms=["ToTensor", "Resize", "Normalize"],
267 |                         transforms_kwargs=[
268 |                             {},
269 |                             {
270 |                                 "size": (480, 640),
271 |                                 "interpolation": InterpolationMode.NEAREST_EXACT,
272 |                             },
273 |                             {"mean": 0.0, "std": 1.0},
274 |                         ],
275 |                     )
276 |                 ),
277 |                 RescaleMapper(RescaleMapperConfig(key="image")),
278 |                 RescaleMapper(RescaleMapperConfig(key="normal")),
279 |             ],
280 |         ),
281 |     ]
282 |
283 |     return filters_mappers
284 |
285 |
286 | def get_data_module(
287 | train_shards: List[str],
288 | validation_shards: List[str],
289 | batch_size: int,
290 | ):
291 |
292 | # TRAIN
293 | train_filters_mappers = get_filter_mappers()
294 |
295 | # unbrace urls
296 | train_shards_path_or_urls_unbraced = []
297 | for train_shards_path_or_url in train_shards:
298 | train_shards_path_or_urls_unbraced.extend(
299 | braceexpand.braceexpand(train_shards_path_or_url)
300 | )
301 |
302 | # shuffle shards
303 | random.shuffle(train_shards_path_or_urls_unbraced)
304 |
305 | # data config
306 | data_config = DataModuleConfig(
307 | shards_path_or_urls=train_shards_path_or_urls_unbraced,
308 | decoder="pil",
309 | shuffle_before_split_by_node_buffer_size=20,
310 | shuffle_before_split_by_workers_buffer_size=20,
311 | shuffle_before_filter_mappers_buffer_size=20,
312 | shuffle_after_filter_mappers_buffer_size=20,
313 | per_worker_batch_size=batch_size,
314 | num_workers=min(10, len(train_shards_path_or_urls_unbraced)),
315 | )
316 |
317 | train_data_config = data_config
318 |
319 | # VALIDATION
320 | validation_filters_mappers = get_filter_mappers()
321 |
322 | # expand brace patterns in the shard urls into one url per shard
323 | validation_shards_path_or_urls_unbraced = []
324 | for validation_shards_path_or_url in validation_shards:
325 | validation_shards_path_or_urls_unbraced.extend(
326 | braceexpand.braceexpand(validation_shards_path_or_url)
327 | )
328 |
329 | data_config = DataModuleConfig(
330 | shards_path_or_urls=validation_shards_path_or_urls_unbraced,
331 | decoder="pil",
332 | shuffle_before_split_by_node_buffer_size=10,
333 | shuffle_before_split_by_workers_buffer_size=10,
334 | shuffle_before_filter_mappers_buffer_size=10,
335 | shuffle_after_filter_mappers_buffer_size=10,
336 | per_worker_batch_size=batch_size,
337 | num_workers=min(10, len(validation_shards_path_or_urls_unbraced)),
338 | )
339 |
340 | validation_data_config = data_config
341 |
342 | # data module
343 | data_module = DataModule(
344 | train_config=train_data_config,
345 | train_filters_mappers=train_filters_mappers,
346 | eval_config=validation_data_config,
347 | eval_filters_mappers=validation_filters_mappers,
348 | )
349 |
350 | return data_module
351 |
352 |
353 | def main(
354 | train_shards: List[str] = ["pipe:cat path/to/train/shards"],
355 | validation_shards: List[str] = ["pipe:cat path/to/validation/shards"],
356 | backbone_signature: str = "stabilityai/stable-diffusion-xl-base-1.0",
357 | vae_num_channels: int = 4,
358 | unet_input_channels: int = 4,
359 | source_key: str = "image",
360 | target_key: str = "normal",
361 | mask_key: str = "mask",
362 | wandb_project: str = "lbm-surface",
363 | batch_size: int = 8,
364 | num_steps: List[int] = [1, 2, 4],
365 | learning_rate: float = 5e-5,
366 | learning_rate_scheduler: str = None,
367 | learning_rate_scheduler_kwargs: dict = {},
368 | optimizer: str = "AdamW",
369 | optimizer_kwargs: dict = {},
370 | timestep_sampling: str = "uniform",
371 | logit_mean: float = 0.0,
372 | logit_std: float = 1.0,
373 | pixel_loss_type: str = "lpips",
374 | latent_loss_type: str = "l2",
375 | latent_loss_weight: float = 1.0,
376 | pixel_loss_weight: float = 0.0,
377 | selected_timesteps: List[float] = None,
378 | prob: List[float] = None,
379 | conditioning_images_keys: Optional[List[str]] = [],
380 | conditioning_masks_keys: Optional[List[str]] = [],
381 | config_yaml: dict = None,
382 | save_ckpt_path: str = "./checkpoints",
383 | log_interval: int = 100,
384 | resume_from_checkpoint: bool = True,
385 | max_epochs: int = 100,
386 | bridge_noise_sigma: float = 0.005,
387 | save_interval: int = 1000,
388 | path_config: str = None,
389 | ):
390 | model = get_model(
391 | backbone_signature=backbone_signature,
392 | vae_num_channels=vae_num_channels,
393 | unet_input_channels=unet_input_channels,
394 | source_key=source_key,
395 | target_key=target_key,
396 | mask_key=mask_key,
397 | timestep_sampling=timestep_sampling,
398 | logit_mean=logit_mean,
399 | logit_std=logit_std,
400 | pixel_loss_type=pixel_loss_type,
401 | latent_loss_type=latent_loss_type,
402 | latent_loss_weight=latent_loss_weight,
403 | pixel_loss_weight=pixel_loss_weight,
404 | selected_timesteps=selected_timesteps,
405 | prob=prob,
406 | conditioning_images_keys=conditioning_images_keys,
407 | conditioning_masks_keys=conditioning_masks_keys,
408 | bridge_noise_sigma=bridge_noise_sigma,
409 | )
410 |
411 | data_module = get_data_module(
412 | train_shards=train_shards,
413 | validation_shards=validation_shards,
414 | batch_size=batch_size,
415 | )
416 |
417 | train_parameters = ["denoiser.*"]
418 |
419 | # Training Config
420 | training_config = TrainingConfig(
421 | learning_rate=learning_rate,
422 | lr_scheduler_name=learning_rate_scheduler,
423 | lr_scheduler_kwargs=learning_rate_scheduler_kwargs,
424 | log_keys=["image", "normal", "mask"],
425 | trainable_params=train_parameters,
426 | optimizer_name=optimizer,
427 | optimizer_kwargs=optimizer_kwargs,
428 | log_samples_model_kwargs={
429 | "input_shape": None,
430 | "num_steps": num_steps,
431 | },
432 | )
433 | if (
434 | os.path.exists(save_ckpt_path)
435 | and resume_from_checkpoint
436 | and "last.ckpt" in os.listdir(save_ckpt_path)
437 | ):
438 | start_ckpt = f"{save_ckpt_path}/last.ckpt"
439 | print(f"Resuming from checkpoint: {start_ckpt}")
440 |
441 | else:
442 | start_ckpt = None
443 |
444 | pipeline = TrainingPipeline(model=model, pipeline_config=training_config)
445 |
446 | pipeline.save_hyperparameters(
447 | {
448 | f"embedder_{i}": embedder.config.to_dict()
449 | for i, embedder in enumerate(model.conditioner.conditioners)
450 | }
451 | )
452 |
453 | pipeline.save_hyperparameters(
454 | {
455 | "denoiser": model.denoiser.config,
456 | "vae": model.vae.config.to_dict(),
457 | "config_yaml": config_yaml,
458 | "training": training_config.to_dict(),
459 | "training_noise_scheduler": model.training_noise_scheduler.config,
460 | "sampling_noise_scheduler": model.sampling_noise_scheduler.config,
461 | }
462 | )
463 |
464 | training_signature = (
465 | datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
466 | + "-LBM-Surface"
467 | + f"{os.environ['SLURM_JOB_ID']}"
468 | + f"_{os.environ.get('SLURM_ARRAY_TASK_ID', 0)}"
469 | )
470 | dir_path = f"{save_ckpt_path}/logs/{training_signature}"
471 | if os.environ["SLURM_PROCID"] == "0":
472 | os.makedirs(dir_path, exist_ok=True)
473 | if path_config is not None:
474 | shutil.copy(path_config, f"{save_ckpt_path}/config.yaml")
475 | run_name = training_signature
476 |
477 | # Frozen (non-denoiser) parameters are passed to FSDP as ignored states so they are not sharded
478 | ignore_states = []
479 | for name, param in pipeline.model.named_parameters():
480 | ignore = True
481 | for regex in ["denoiser."]:
482 | pattern = re.compile(regex)
483 | if re.match(pattern, name):
484 | ignore = False
485 | if ignore:
486 | ignore_states.append(param)
487 |
488 | # FSDP Strategy
489 | strategy = FSDPStrategy(
490 | auto_wrap_policy=ModuleWrapPolicy(
491 | [
492 | UNet2DConditionModel,
493 | BasicTransformerBlock,
494 | ResnetBlock2D,
495 | torch.nn.Conv2d,
496 | ]
497 | ),
498 | activation_checkpointing_policy=ModuleWrapPolicy(
499 | [
500 | BasicTransformerBlock,
501 | ResnetBlock2D,
502 | ]
503 | ),
504 | sharding_strategy="SHARD_GRAD_OP",
505 | ignored_states=ignore_states,
506 | )
507 |
508 | trainer = Trainer(
509 | accelerator="gpu",
510 | devices=int(os.environ["SLURM_NPROCS"]) // int(os.environ["SLURM_NNODES"]),
511 | num_nodes=int(os.environ["SLURM_NNODES"]),
512 | strategy=strategy,
513 | default_root_dir="logs",
514 | logger=loggers.WandbLogger(
515 | project=wandb_project, offline=False, name=run_name, save_dir=save_ckpt_path
516 | ),
517 | callbacks=[
518 | WandbSampleLogger(log_batch_freq=log_interval),
519 | LearningRateMonitor(logging_interval="step"),
520 | ModelCheckpoint(
521 | dirpath=save_ckpt_path,
522 | every_n_train_steps=save_interval,
523 | save_last=True,
524 | ),
525 | ],
526 | num_sanity_val_steps=0,
527 | precision="bf16-mixed",
528 | limit_val_batches=2,
529 | val_check_interval=1000,
530 | max_epochs=max_epochs,
531 | )
532 |
533 | trainer.fit(pipeline, data_module, ckpt_path=start_ckpt)
534 |
535 |
536 | def main_from_config(path_config: str = None):
537 | with open(path_config, "r") as file:
538 | config = yaml.safe_load(file)
539 | logging.info(
540 | f"Running main with config: {yaml.dump(config, default_flow_style=False)}"
541 | )
542 | main(**config, config_yaml=config, path_config=path_config)
543 |
544 |
545 | if __name__ == "__main__":
546 | fire.Fire(main_from_config)
547 |
--------------------------------------------------------------------------------
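In `get_data_module`, each entry of `train_shards` / `validation_shards` is brace-expanded into one URL per shard, and `num_workers` is capped at the number of shards. The sketch below illustrates why, assuming that shards are dealt to nodes and dataloader workers round-robin (my understanding of `wds.split_by_node` and `wds.split_by_worker`): any worker beyond the shard count receives nothing to read. `expand_numeric_braces` and `split_round_robin` are hypothetical helpers; the first handles only a single zero-padded `{start..end}` range, far less than `braceexpand.braceexpand`.

```python
import re
from typing import List


def expand_numeric_braces(pattern: str) -> List[str]:
    """Hypothetical helper: expand one zero-padded `{start..end}` range.

    A much smaller stand-in for `braceexpand.braceexpand`, which also
    handles comma lists, nested braces and several ranges per pattern.
    """
    match = re.search(r"\{(\d+)\.\.(\d+)\}", pattern)
    if match is None:
        return [pattern]
    start, end = match.group(1), match.group(2)
    width = len(start)  # keep the zero padding, e.g. 000..003
    return [
        pattern[: match.start()] + str(i).zfill(width) + pattern[match.end() :]
        for i in range(int(start), int(end) + 1)
    ]


def split_round_robin(shards: List[str], index: int, count: int) -> List[str]:
    """Assumed assignment: every `count`-th shard, starting at `index`."""
    return shards[index::count]


shards = expand_numeric_braces("pipe:cat train-{000..003}.tar")
num_workers = min(10, len(shards))  # same cap as in get_data_module
per_worker = [split_round_robin(shards, w, num_workers) for w in range(num_workers)]
```

With four shards the cap yields four workers holding one shard each; an uncapped pool of ten would leave six workers idle. For the validation loader, the cap is only meaningful when it is computed from the validation shard list itself.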
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling", "hatch-requirements-txt"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "lbm"
7 | dynamic = ["dependencies", "optional-dependencies"]
8 | description = "LBM: Latent Bridge Matching for Fast Image-to-Image Translation"
9 | readme = "README.md"
10 | requires-python = ">=3.10"
11 | authors = [
12 | { name = "Clement Chadebec", email = "clement.chadebec@jasper.ai" },
13 | { name = "Benjamin Aubin", email = "benjamin.aubin@jasper.ai" },
14 | ]
15 | maintainers = [
16 | { name = "Clement Chadebec", email = "clement.chadebec@jasper.ai" },
17 | ]
18 | classifiers = [
19 | "Programming Language :: Python :: 3",
20 | "Programming Language :: Python :: 3.10",
21 | "Programming Language :: Python :: 3.11",
22 | "Programming Language :: Python :: 3.12",
23 | "License :: OSI Approved :: Apache Software License",
24 | "Operating System :: OS Independent",
25 | ]
26 | version = "0.1"
27 |
28 | [project.urls]
29 | Homepage = "https://github.com/gojasper/LBM"
30 | Repository = "https://github.com/gojasper/LBM"
31 |
32 | [tool.hatch.metadata]
33 | allow-direct-references = true
34 |
35 | [tool.hatch.metadata.hooks.requirements_txt]
36 | files = ["requirements.txt"]
37 |
38 | [tool.hatch.build.targets.wheel]
39 | packages = ["src/lbm"]
40 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.4.0
2 | diffusers==0.32.2
3 | torch==2.7.0
4 | torchvision>=0.20.0
5 | black==24.2.0
6 | einops==0.7.0
7 | fire>=0.5.0
8 | gradio==5.29.0
9 | isort==5.13.2
10 | lightning==2.5.0
11 | lpips==0.1.4
12 | opencv-python==4.9.0.80
13 | peft==0.9.0
14 | pydantic>=2.6.1
15 | scipy>=1.12.0
16 | sentencepiece>=0.2.0
17 | timm==0.9.16
18 | tokenizers>=0.15.2
19 | torch-fidelity>=0.3.0
20 | torchmetrics>=1.3.1
21 | transformers==4.42.3
22 | wandb==0.16.2
23 | webdataset>=0.2.86
24 | kornia==0.8.0
--------------------------------------------------------------------------------
/src/lbm/config.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import warnings
4 | from dataclasses import asdict, field
5 | from typing import Any, Dict, Union
6 |
7 | import yaml
8 | from pydantic import ValidationError
9 | from pydantic.dataclasses import dataclass
10 | from yaml import safe_load
11 |
12 |
13 | @dataclass
14 | class BaseConfig:
15 | """Base class for all configs. It defines the methods used to create configs from dicts,
16 | JSON and YAML files and to save them back to these formats"""
17 |
18 | name: str = field(init=False)
19 |
20 | def __post_init__(self):
21 | self.name = self.__class__.__name__
22 |
23 | @classmethod
24 | def from_dict(cls, config_dict: Dict[str, Any]) -> "BaseConfig":
25 | """Creates a BaseConfig instance from a dictionary
26 |
27 | Args:
28 | config_dict (dict): The Python dictionary containing all the parameters
29 |
30 | Returns:
31 | :class:`BaseConfig`: The created instance
32 | """
33 | try:
34 | config = cls(**config_dict)
35 | except (ValidationError, TypeError) as e:
36 | raise e
37 | return config
38 |
39 | @classmethod
40 | def _dict_from_json(cls, json_path: Union[str, os.PathLike]) -> Dict[str, Any]:
41 | try:
42 | with open(json_path) as f:
43 | try:
44 | config_dict = json.load(f)
45 | return config_dict
46 |
47 | except (TypeError, json.JSONDecodeError) as e:
48 | raise TypeError(
49 | f"File {json_path} could not be loaded. Is it valid JSON? \n"
50 | f"Caught exception {type(e)} with message: " + str(e)
51 | ) from e
52 |
53 | except FileNotFoundError:
54 | raise FileNotFoundError(
55 | f"Config file not found. Please check path '{json_path}'"
56 | )
57 |
58 | @classmethod
59 | def from_json(cls, json_path: str) -> "BaseConfig":
60 | """Creates a BaseConfig instance from a JSON config file
61 |
62 | Args:
63 | json_path (str): The path to the json file containing all the parameters
64 |
65 | Returns:
66 | :class:`BaseConfig`: The created instance
67 | """
68 | config_dict = cls._dict_from_json(json_path)
69 |
70 | config_name = config_dict.pop("name")
71 |
72 | if cls.__name__ != config_name:
73 | warnings.warn(
74 | f"You are trying to load a "
75 | f"`{ cls.__name__}` while a "
76 | f"`{config_name}` is given."
77 | )
78 |
79 | return cls.from_dict(config_dict)
80 |
81 | def to_dict(self) -> dict:
82 | """Transforms object into a Python dictionary
83 |
84 | Returns:
85 | (dict): The dictionary containing all the parameters"""
86 | return asdict(self)
87 |
88 | def to_json_string(self):
89 | """Transforms object into a JSON string
90 |
91 | Returns:
92 | (str): The JSON str containing all the parameters"""
93 | return json.dumps(self.to_dict())
94 |
95 | def save_json(self, file_path: str):
96 | """Saves a ``.json`` file from the dataclass
97 |
98 | Args:
99 | file_path (str): path to the file
100 | """
101 | with open(file_path, "w", encoding="utf-8") as fp:
102 | fp.write(self.to_json_string())
103 |
104 | def save_yaml(self, file_path: str):
105 | """Saves a ``.yaml`` file from the dataclass
106 |
107 | Args:
108 | file_path (str): path to the file
109 | """
110 | with open(file_path, "w", encoding="utf-8") as fp:
111 | yaml.dump(self.to_dict(), fp)
112 |
113 | @classmethod
114 | def from_yaml(cls, yaml_path: str) -> "BaseConfig":
115 | """Creates a BaseConfig instance from a YAML config file
116 |
117 | Args:
118 | yaml_path (str): The path to the yaml file containing all the parameters
119 |
120 | Returns:
121 | :class:`BaseConfig`: The created instance
122 | """
123 | with open(yaml_path, "r") as f:
124 | try:
125 | config_dict = safe_load(f)
126 | except yaml.YAMLError as e:
127 | raise yaml.YAMLError(
128 | f"File {yaml_path} could not be loaded. Is it valid YAML? \n"
129 | f"Caught exception {type(e)} with message: " + str(e)
130 | ) from e
131 |
132 | config_name = config_dict.pop("name")
133 |
134 | if cls.__name__ != config_name:
135 | warnings.warn(
136 | f"You are trying to load a "
137 | f"`{ cls.__name__}` while a "
138 | f"`{config_name}` is given."
139 | )
140 |
141 | return cls.from_dict(config_dict)
142 |
--------------------------------------------------------------------------------
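`BaseConfig` stores the concrete class name in `name` so that `from_json`/`from_yaml` can warn when a file saved from one config class is loaded into another. Below is a minimal sketch of that round trip using the standard-library `dataclasses` module; the real class is a pydantic dataclass, which additionally validates field types. `SketchConfig` and `FilterSketchConfig` are illustrative names, not part of the package.

```python
import json
import warnings
from dataclasses import asdict, dataclass, field


@dataclass
class SketchConfig:
    """Stdlib stand-in for BaseConfig (illustrative only)."""

    name: str = field(init=False)

    def __post_init__(self):
        # record the concrete class name so a saved file says what it was
        self.name = self.__class__.__name__

    def to_json_string(self) -> str:
        return json.dumps(asdict(self))

    @classmethod
    def from_json_string(cls, text: str) -> "SketchConfig":
        config_dict = json.loads(text)
        # `name` is not an __init__ argument, so it must be removed first
        config_name = config_dict.pop("name")
        if cls.__name__ != config_name:
            warnings.warn(f"Loading a `{cls.__name__}` while a `{config_name}` is given.")
        return cls(**config_dict)


@dataclass
class FilterSketchConfig(SketchConfig):
    verbose: bool = False
    keys: str = "txt"
```

Because `name` is declared with `init=False`, it has to be popped from the loaded dict before calling the constructor, which is what the `config_dict.pop("name")` calls in `from_json` and `from_yaml` do.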
/src/lbm/data/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains a collection of data-related classes and functions used to train the :mod:`lbm.models`.
3 | In a training loop, a batch of data is structured as a dictionary on which the modules :mod:`lbm.data.datasets`,
4 | :mod:`lbm.data.filters` and :mod:`lbm.data.mappers` perform several operations.
5 |
6 |
7 | Examples
8 | ########
9 |
10 | Create a DataModule to train a model
11 |
12 | .. code-block:: python
13 |
14 | from lbm.data.datasets import DataModule, DataModuleConfig
15 | from lbm.data.filters import KeyFilter, KeyFilterConfig
16 | from lbm.data.mappers import KeyRenameMapper, KeyRenameMapperConfig
17 |
18 | # Create the filters and mappers (filters see the raw keys, before renaming)
19 | filters_mappers = [
20 | KeyFilter(KeyFilterConfig(keys=["jpg", "txt"])),
21 | KeyRenameMapper(
22 | KeyRenameMapperConfig(key_map={"jpg": "image", "txt": "text"})
23 | )
24 | ]
25 |
26 | # Create the DataModule
27 | data_module = DataModule(
28 | train_config=DataModuleConfig(
29 | shards_path_or_urls="your urls or paths",
30 | decoder="pil",
31 | shuffle_before_filter_mappers_buffer_size=100,
32 | per_worker_batch_size=32,
33 | num_workers=4,
34 | ),
35 | train_filters_mappers=filters_mappers,
36 | eval_config=DataModuleConfig(
37 | shards_path_or_urls="your urls or paths",
38 | decoder="pil",
39 | shuffle_before_filter_mappers_buffer_size=100,
40 | per_worker_batch_size=32,
41 | num_workers=4,
42 | ),
43 | eval_filters_mappers=filters_mappers,
44 | )
45 |
46 | # This can then be passed to a :class:`pytorch_lightning.Trainer` to train a model
47 |
48 |
49 |
50 |
51 |
52 | The :mod:`lbm.data` module includes the following submodules:
53 |
54 | - :mod:`lbm.data.datasets`: a collection of :class:`pytorch_lightning.LightningDataModule` used to train the models. In particular,
55 | they can be used to create the dataloaders and set up the data pipelines.
56 | - :mod:`lbm.data.filters`: a collection of filters used to select which samples are kept in the data pipeline.
57 | - :mod:`lbm.data.mappers`: a collection of mappers used to transform the samples (key renaming, torchvision transforms, rescaling).
58 | """
59 |
60 | from .datasets import DataModule
61 |
62 | __all__ = ["DataModule"]
63 |
--------------------------------------------------------------------------------
/src/lbm/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | A collection of :mod:`pytorch_lightning.LightningDataModule` used to train the models. In particular,
3 | they can be used to create the dataloaders and setup the data pipelines.
4 | """
5 |
6 | from .dataset import DataModule
7 | from .datasets_config import DataModuleConfig
8 |
9 | __all__ = ["DataModule", "DataModuleConfig"]
10 |
--------------------------------------------------------------------------------
/src/lbm/data/datasets/collation_fn.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Union
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def custom_collation_fn(
8 | samples: List[Dict[str, Union[int, float, np.ndarray, torch.Tensor]]],
9 | combine_tensors: bool = True,
10 | combine_scalars: bool = True,
11 | ) -> dict:
12 | """
13 | Collate function for PyTorch DataLoader.
14 |
15 | Args:
16 | samples (List[Dict[str, Union[int, float, np.ndarray, torch.Tensor]]]): List of samples to collate.
17 | combine_tensors (bool): Whether to turn lists of tensors into a single tensor.
18 | combine_scalars (bool): Whether to turn lists of scalars into a single ndarray.
19 | """
20 | keys = set.intersection(*[set(sample.keys()) for sample in samples])
21 | # gather, for every shared key, the values of all samples
22 | batched = {key: [s[key] for s in samples] for key in keys}
23 |
24 |
25 | result = {}
26 | for key, values in batched.items():
27 | if combine_scalars and isinstance(values[0], (int, float)):
28 | result[key] = np.array(values)
29 | elif combine_tensors and isinstance(values[0], torch.Tensor):
30 | result[key] = torch.stack(values)
31 | elif combine_tensors and isinstance(values[0], np.ndarray):
32 | result[key] = np.array(values)
33 | else:
34 | # anything not combined (e.g. strings, PIL images) is kept as a list
35 | result[key] = list(values)
36 |
37 | del samples
38 | del batched
39 | return result
40 |
--------------------------------------------------------------------------------
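The first step of `custom_collation_fn` keeps only the keys that every sample in the batch shares, so a key missing from a single sample disappears from the whole batch. The stdlib-only sketch below reproduces just that step; `collate_sketch` is a hypothetical name, and stacking into `torch.Tensor` or `np.ndarray` is left out.

```python
from typing import Any, Dict, List


def collate_sketch(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Keep only the keys shared by every sample and gather their values."""
    shared_keys = set.intersection(*(set(sample) for sample in samples))
    return {key: [sample[key] for sample in samples] for key in shared_keys}


batch = collate_sketch(
    [
        {"image": "img0", "mask": "m0", "caption": "a"},
        {"image": "img1", "mask": "m1"},  # no caption: the key is dropped
    ]
)
```

This is one reason the training script places a `KeyFilter` on `jpg`, `normal_aligned.png` and `mask.png` before any mapper: incomplete samples are dropped up front instead of silently removing a modality from a batch. Note that `set.intersection` needs at least one argument, so an empty sample list raises a `TypeError`, as it would in the original function.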
/src/lbm/data/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Union
2 |
3 | import pytorch_lightning as pl
4 | import webdataset as wds
5 | from webdataset import DataPipeline
6 |
7 | from ..filters import BaseFilter, FilterWrapper
8 | from ..mappers import BaseMapper, MapperWrapper
9 | from .collation_fn import custom_collation_fn
10 | from .datasets_config import DataModuleConfig
11 |
12 |
13 | class DataPipeline:
14 | """
15 | DataPipeline class for creating a dataloader from a single configuration
16 |
17 | Args:
18 |
19 | config (DataModuleConfig):
20 | Configuration for the dataset
21 |
22 | filters_mappers (List[Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper]]):
23 | List of filters and mappers for the dataset. These will be sequentially applied.
24 |
25 | batched_filters_mappers (List[Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper]]):
26 | List of batched transforms for the dataset. These will be sequentially applied.
27 | """
28 |
29 | def __init__(
30 | self,
31 | config: DataModuleConfig,
32 | filters_mappers: List[
33 | Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper]
34 | ],
35 | batched_filters_mappers: List[
36 | Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper]
37 | ] = None,
38 | ):
39 | self.config = config
40 | self.shards_path_or_urls = config.shards_path_or_urls
41 | self.filters_mappers = filters_mappers
42 | self.batched_filters_mappers = batched_filters_mappers or []
43 |
44 | if filters_mappers is None:
45 | filters_mappers = []
46 |
47 | # set processing pipeline
48 | self.processing_pipeline = [wds.decode(config.decoder, handler=config.handler)]
49 | self.processing_pipeline.extend(
50 | self._add_filters_mappers(
51 | filters_mappers=filters_mappers,
52 | handler=config.handler,
53 | )
54 | )
55 |
56 | def _add_filters_mappers(
57 | self,
58 | filters_mappers: List[
59 | Union[
60 | FilterWrapper,
61 | MapperWrapper,
62 | ]
63 | ],
64 | handler: Callable = wds.warn_and_continue,
65 | ) -> List[Union[FilterWrapper, MapperWrapper]]:
66 | tmp_pipeline = []
67 | for filter_mapper in filters_mappers:
68 | if isinstance(filter_mapper, FilterWrapper) or isinstance(
69 | filter_mapper, BaseFilter
70 | ):
71 | tmp_pipeline.append(wds.select(filter_mapper))
72 | elif isinstance(filter_mapper, MapperWrapper) or isinstance(
73 | filter_mapper, BaseMapper
74 | ):
75 | tmp_pipeline.append(wds.map(filter_mapper, handler=handler))
76 | elif callable(filter_mapper):  # any other callable is treated as a mapper
77 | tmp_pipeline.append(wds.map(filter_mapper, handler=handler))
78 | else:
79 | raise ValueError("Unknown type of filter/mapper")
80 | return tmp_pipeline
81 |
82 | def setup(self):
83 | pipeline = [wds.SimpleShardList(self.shards_path_or_urls)]
84 |
85 | # shuffle before split by node
86 | if self.config.shuffle_before_split_by_node_buffer_size is not None:
87 | pipeline.append(
88 | wds.shuffle(
89 | self.config.shuffle_before_split_by_node_buffer_size,
90 | handler=self.config.handler,
91 | )
92 | )
93 | # split by node
94 | pipeline.append(wds.split_by_node)
95 |
96 | # shuffle before split by workers
97 | if self.config.shuffle_before_split_by_workers_buffer_size is not None:
98 | pipeline.append(
99 | wds.shuffle(
100 | self.config.shuffle_before_split_by_workers_buffer_size,
101 | handler=self.config.handler,
102 | )
103 | )
104 | # split by worker
105 | pipeline.extend(
106 | [
107 | wds.split_by_worker,
108 | wds.tarfile_to_samples(
109 | handler=self.config.handler,
110 | rename_files=self.config.rename_files_fn,
111 | ),
112 | ]
113 | )
114 |
115 | # shuffle before filter mappers
116 | if self.config.shuffle_before_filter_mappers_buffer_size is not None:
117 | pipeline.append(
118 | wds.shuffle(
119 | self.config.shuffle_before_filter_mappers_buffer_size,
120 | handler=self.config.handler,
121 | )
122 | )
123 |
124 | # apply filters and mappers
125 | pipeline.extend(self.processing_pipeline)
126 |
127 | # shuffle after filter mappers
128 | if self.config.shuffle_after_filter_mappers_buffer_size is not None:
129 | pipeline.append(
130 | wds.shuffle(
131 | self.config.shuffle_after_filter_mappers_buffer_size,
132 | handler=self.config.handler,
133 | ),
134 | )
135 |
136 | # batching
137 | pipeline.append(
138 | wds.batched(
139 | self.config.per_worker_batch_size,
140 | collation_fn=custom_collation_fn,
141 | )
142 | )
143 |
144 | # apply batched transforms
145 | pipeline.extend(
146 | self._add_filters_mappers(
147 | filters_mappers=self.batched_filters_mappers,
148 | handler=self.config.handler,
149 | )
150 | )
151 |
152 | # create the data pipeline
153 | pipeline = wds.DataPipeline(*pipeline, handler=self.config.handler)
154 |
155 | # set the pipeline
156 | self.pipeline = pipeline
157 |
158 | def dataloader(self):
159 | # return the loader
160 | return wds.WebLoader(
161 | self.pipeline,
162 | batch_size=None,
163 | num_workers=self.config.num_workers,
164 | )
165 |
166 |
167 | class DataModule(pl.LightningDataModule):
168 | """
169 | Main DataModule class for creating data loaders and training/evaluating models
170 |
171 | Args:
172 |
173 | train_config (DataModuleConfig):
174 | Configuration for the training dataset
175 |
176 | train_filters_mappers (List[Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper]]):
177 | List of filters and mappers for the training dataset. These will be sequentially applied.
178 |
179 | train_batched_filters_mappers (List[Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper]]):
180 | List of batched transforms for the training dataset. These will be sequentially applied.
181 |
182 | eval_config (DataModuleConfig):
183 | Configuration for the evaluation dataset
184 |
185 | eval_filters_mappers (List[Union[FilterWrapper, MapperWrapper]]):
186 | List of filters and mappers for the evaluation dataset. These will be sequentially applied.
187 |
188 | eval_batched_filters_mappers (List[Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper]]):
189 | List of batched transforms for the evaluation dataset. These will be sequentially applied.
190 | """
191 |
192 | def __init__(
193 | self,
194 | train_config: DataModuleConfig,
195 | train_filters_mappers: List[
196 | Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper]
197 | ] = None,
198 | train_batched_filters_mappers: List[
199 | Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper]
200 | ] = None,
201 | eval_config: DataModuleConfig = None,
202 | eval_filters_mappers: List[Union[FilterWrapper, MapperWrapper]] = None,
203 | eval_batched_filters_mappers: List[
204 | Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper]
205 | ] = None,
206 | ):
207 | super().__init__()
208 |
209 | self.train_config = train_config
210 | self.train_filters_mappers = train_filters_mappers
211 | self.train_batched_filters_mappers = train_batched_filters_mappers
212 |
213 | self.eval_config = eval_config
214 | self.eval_filters_mappers = eval_filters_mappers
215 | self.eval_batched_filters_mappers = eval_batched_filters_mappers
216 |
217 | def setup(self, stage=None):
218 | """
219 | Setup the data module and create the webdataset processing pipelines
220 | """
221 |
222 | # train pipeline
223 | self.train_pipeline = DataPipeline(
224 | config=self.train_config,
225 | filters_mappers=self.train_filters_mappers,
226 | batched_filters_mappers=self.train_batched_filters_mappers,
227 | )
228 | self.train_pipeline.setup()
229 |
230 | # eval pipeline
231 | if self.eval_config is not None:
232 | self.eval_pipeline = DataPipeline(
233 | config=self.eval_config,
234 | filters_mappers=self.eval_filters_mappers,
235 | batched_filters_mappers=self.eval_batched_filters_mappers,
236 | )
237 | self.eval_pipeline.setup()
238 |
239 | def train_dataloader(self):
240 | return self.train_pipeline.dataloader()
241 |
242 | def val_dataloader(self):
243 | return self.eval_pipeline.dataloader()
244 |
--------------------------------------------------------------------------------
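`DataPipeline.setup` chains webdataset stages in a fixed order: shard list, shuffle, split by node, shuffle, split by worker, tar decoding into samples, shuffle, `wds.decode` plus the filters (`wds.select`) and mappers (`wds.map`), shuffle, `wds.batched` and finally the batched transforms. The generator sketch below mimics only the select / map / batch part in plain Python, to show that filters see the raw tar keys while later stages see the renamed ones. The helper names are hypothetical and the error handler is omitted.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, List

Sample = Dict[str, Any]


def select(samples: Iterable[Sample], keep: Callable[[Sample], bool]) -> Iterator[Sample]:
    """Plays the role of wds.select: yield only the samples the filter accepts."""
    return (s for s in samples if keep(s))


def apply_map(samples: Iterable[Sample], fn: Callable[[Sample], Sample]) -> Iterator[Sample]:
    """Plays the role of wds.map, without the error handler."""
    return (fn(s) for s in samples)


def batched(samples: Iterable[Sample], batch_size: int) -> Iterator[List[Sample]]:
    """Group samples into lists of `batch_size`; a smaller last batch is kept."""
    batch: List[Sample] = []
    for s in samples:
        batch.append(s)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


# five complete samples followed by one that lacks a mask
raw = [{"jpg": i, "mask.png": i} for i in range(5)] + [{"jpg": 99}]
stream = select(raw, lambda s: {"jpg", "mask.png"}.issubset(s))  # raw keys
stream = apply_map(stream, lambda s: {"image": s["jpg"], "mask": s["mask.png"]})
batches = list(batched(stream, batch_size=2))
```

This sketch always keeps the last, smaller batch; check the webdataset documentation for how `wds.batched` treats it in your version.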
/src/lbm/data/datasets/datasets_config.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Union
2 |
3 | import webdataset as wds
4 | from pydantic.dataclasses import dataclass
5 |
6 | from ...config import BaseConfig
7 |
8 |
9 | @dataclass
10 | class DataModuleConfig(BaseConfig):
11 | """
12 | Configuration for the DataModule
13 |
14 | Args:
15 |
16 | shards_path_or_urls (Union[str, List[str]]): The path or url to the shards. Defaults to None.
17 | per_worker_batch_size (int): The batch size produced by each dataloader worker. Defaults to 16.
18 | num_workers (int): The number of workers to use. Defaults to 1.
19 | shuffle_before_split_by_node_buffer_size (Optional[int]): The buffer size for the shuffle before split by node. Defaults to 100.
20 | shuffle_before_split_by_workers_buffer_size (Optional[int]): The buffer size for the shuffle before split by workers. Defaults to 100.
21 | shuffle_before_filter_mappers_buffer_size (Optional[int]): The buffer size for the shuffle before filter mappers. Defaults to 1000.
22 | shuffle_after_filter_mappers_buffer_size (Optional[int]): The buffer size for the shuffle after filter mappers. Defaults to 1000.
23 | decoder (str): The decoder to use. Defaults to "pil".
24 | handler (Callable): A callable that handles exceptions raised in the pipeline. Defaults to wds.warn_and_continue.
25 | rename_files_fn (Optional[Callable[[str], str]]): A callable to rename the files. Defaults to None.
26 | """
27 |
28 | shards_path_or_urls: Union[str, List[str]] = None
29 | per_worker_batch_size: int = 16
30 | num_workers: int = 1
31 | shuffle_before_split_by_node_buffer_size: Optional[int] = 100
32 | shuffle_before_split_by_workers_buffer_size: Optional[int] = 100
33 | shuffle_before_filter_mappers_buffer_size: Optional[int] = 1000
34 | shuffle_after_filter_mappers_buffer_size: Optional[int] = 1000
35 | decoder: str = "pil"
36 | handler: Callable = wds.warn_and_continue
37 | rename_files_fn: Optional[Callable[[str], str]] = None
38 |
39 | def __post_init__(self):
40 | super().__post_init__()
41 | if self.rename_files_fn is not None:
42 | assert callable(self.rename_files_fn), "rename_files_fn must be a callable"
43 |
--------------------------------------------------------------------------------
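The four `shuffle_*_buffer_size` fields control buffered shuffles: a stream cannot be shuffled globally without holding it in memory, so a fixed-size buffer is filled and a random element is emitted each time it is full. Setting a field to `None` skips that stage entirely in `DataPipeline.setup`. Below is a minimal sketch of the general technique with a hypothetical `buffered_shuffle` helper; it is not a copy of `wds.shuffle`, whose internals differ.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Approximately shuffle a stream while holding at most `buffer_size` items."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            # emit a random element: swap it to the end, then pop in O(1)
            index = rng.randrange(len(buffer))
            buffer[index], buffer[-1] = buffer[-1], buffer[index]
            yield buffer.pop()
    # the stream is exhausted: drain what is left in random order
    rng.shuffle(buffer)
    yield from buffer
```

An element can appear at most `buffer_size - 1` positions earlier than in the input, so the small buffers used in the training script (20 for training, 10 for validation) only mix nearby samples; the shard-level `random.shuffle` in `get_data_module` provides the coarse mixing.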
/src/lbm/data/filters/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseFilter
2 | from .filter_wrapper import FilterWrapper
3 | from .filters import KeyFilter
4 | from .filters_config import BaseFilterConfig, KeyFilterConfig
5 |
6 | __all__ = [
7 | "BaseFilter",
8 | "FilterWrapper",
9 | "KeyFilter",
10 | "BaseFilterConfig",
11 | "KeyFilterConfig",
12 | ]
13 |
--------------------------------------------------------------------------------
/src/lbm/data/filters/base.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | from .filters_config import BaseFilterConfig
4 |
5 |
6 | class BaseFilter:
7 | """
8 | Base class for filters. This class should be subclassed to create a new filter.
9 |
10 | Args:
11 |
12 | config (BaseFilterConfig):
13 | Configuration for the filter
14 | """
15 |
16 | def __init__(self, config: BaseFilterConfig):
17 | self.verbose = config.verbose
18 |
19 | def __call__(self, sample: Dict[str, Any]) -> bool:
20 | """This function should be implemented by the subclass"""
21 | raise NotImplementedError
22 |
--------------------------------------------------------------------------------
/src/lbm/data/filters/filter_wrapper.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Union
2 |
3 | from .base import BaseFilter
4 |
5 |
6 | class FilterWrapper:
7 | """
8 | Wrapper for multiple filters. This class applies multiple filters to a batch of data.
9 | The filters are applied in the order they are passed to the wrapper.
10 |
11 | Args:
12 |
13 | filters (List[BaseFilter]):
14 | List of filters to apply to the batch of data
15 | """
16 |
17 | def __init__(
18 | self,
19 | filters: Union[List[BaseFilter], None] = None,
20 | ):
21 | self.filters = filters or []
22 |
23 | def __call__(self, batch: Dict[str, Any]) -> bool:
24 | """
25 | Apply all filters in order and stop at the first one that rejects the data
26 |
27 | Args:
28 |
29 | batch: batch of data
30 | """
31 | for sample_filter in self.filters:
32 | # a single rejection is enough to drop the data
33 | if not sample_filter(batch):
34 | return False
35 | # every filter accepted the data
36 | return True
37 |
--------------------------------------------------------------------------------
/src/lbm/data/filters/filters.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from .base import BaseFilter
4 | from .filters_config import KeyFilterConfig
5 |
6 | logging.basicConfig(level=logging.INFO)
7 |
8 |
9 | class KeyFilter(BaseFilter):
10 | """
11 | This filter checks if ALL the given keys are present in the sample
12 |
13 | Args:
14 |
15 | config (KeyFilterConfig): configuration for the filter
16 | """
17 |
18 | def __init__(self, config: KeyFilterConfig):
19 | super().__init__(config)
20 | keys = config.keys
21 | if isinstance(keys, str):
22 | keys = [keys]
23 |
24 | self.keys = set(keys)
25 |
26 | def __call__(self, batch: dict) -> bool:
27 | try:
28 | res = self.keys.issubset(set(batch.keys()))
29 | return res
30 | except Exception as e:
31 | if self.verbose:
32 | logging.error(f"Error in KeyFilter: {e}")
33 | return False
34 |
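`KeyFilter` reduces to a set-inclusion test after a single key is normalized to a list. The sketch below is standalone and reproduces that logic. `make_key_check` is a hypothetical name.

```python
from typing import Any, Dict, List, Union


def make_key_check(keys: Union[str, List[str]]):
    """Return a predicate that is True when every key in `keys` is in the sample."""
    # A bare string is treated as a single key, not as a sequence of characters.
    if isinstance(keys, str):
        keys = [keys]
    required = set(keys)

    def check(sample: Dict[str, Any]) -> bool:
        return required.issubset(sample.keys())

    return check


check_txt = make_key_check("txt")
check_pair = make_key_check(["image", "mask"])

print(check_txt({"txt": "a cat", "image": None}))  # True
print(check_pair({"image": None}))  # False: "mask" is missing
```

The string check matters: without it, `set("txt")` would be `{"t", "x"}`, and the filter would look for single-character keys.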
--------------------------------------------------------------------------------
/src/lbm/data/filters/filters_config.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 | from pydantic.dataclasses import dataclass
4 |
5 | from ...config import BaseConfig
6 |
7 |
8 | @dataclass
9 | class BaseFilterConfig(BaseConfig):
10 | """
11 | Base configuration for filters
12 |
13 | Args:
14 |
15 | verbose (bool):
16 | If True, print debug information. Defaults to False"""
17 |
18 | verbose: bool = False
19 |
20 |
21 | @dataclass
22 | class KeyFilterConfig(BaseFilterConfig):
23 | """
24 | Configuration for KeyFilter, which checks that all the given keys are present in a sample.
25 |
26 | Args:
27 |
28 | keys (Union[str, List[str]]):
29 | Key or list of keys to check. Defaults to "txt"
30 | """
31 |
32 | keys: Union[str, List[str]] = "txt"
33 |
--------------------------------------------------------------------------------
/src/lbm/data/mappers/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseMapper
2 | from .mappers import KeyRenameMapper, RescaleMapper, TorchvisionMapper
3 | from .mappers_config import (
4 | KeyRenameMapperConfig,
5 | RescaleMapperConfig,
6 | TorchvisionMapperConfig,
7 | )
8 | from .mappers_wrapper import MapperWrapper
9 |
10 | __all__ = [
11 | "BaseMapper",
12 | "KeyRenameMapper",
13 | "RescaleMapper",
14 | "TorchvisionMapper",
15 | "KeyRenameMapperConfig",
16 | "RescaleMapperConfig",
17 | "TorchvisionMapperConfig",
18 | "MapperWrapper",
19 | ]
20 |
--------------------------------------------------------------------------------
/src/lbm/data/mappers/base.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | from .mappers_config import BaseMapperConfig
4 |
5 |
6 | class BaseMapper:
7 | """
8 | Base class for the mappers used to modify the samples in the data pipeline.
9 |
10 | Args:
11 |
12 | config (BaseMapperConfig):
13 | Configuration for the mapper.
14 | """
15 |
16 | def __init__(self, config: BaseMapperConfig):
17 | self.config = config
18 | self.key = config.key
19 |
20 | if config.output_key is None:
21 | self.output_key = config.key
22 | else:
23 | self.output_key = config.output_key
24 |
25 | def __call__(self, batch: Dict[str, Any], *args, **kwargs) -> Dict[str, Any]:
26 | raise NotImplementedError
27 |
--------------------------------------------------------------------------------
/src/lbm/data/mappers/mappers.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | from torchvision import transforms
4 |
5 | from .base import BaseMapper
6 | from .mappers_config import (
7 | KeyRenameMapperConfig,
8 | RescaleMapperConfig,
9 | TorchvisionMapperConfig,
10 | )
11 |
12 |
13 | class KeyRenameMapper(BaseMapper):
14 | """
15 | Rename keys in a sample according to a key map
16 |
17 | Args:
18 |
19 | config (KeyRenameMapperConfig): Configuration for the mapper
20 |
21 | Examples
22 | ########
23 |
24 | 1. Rename keys in a sample according to a key map
25 |
26 | .. code-block:: python
27 |
28 | from lbm.data.mappers import KeyRenameMapper, KeyRenameMapperConfig
29 |
30 | config = KeyRenameMapperConfig(
31 | key_map={"old_key": "new_key"}
32 | )
33 |
34 | mapper = KeyRenameMapper(config)
35 |
36 | sample = {"old_key": 1}
37 | new_sample = mapper(sample)
38 | print(new_sample) # {"new_key": 1}
39 |
40 | 2. Rename keys in a sample according to a key map and a condition key
41 |
42 | .. code-block:: python
43 |
44 | from lbm.data.mappers import KeyRenameMapper, KeyRenameMapperConfig
45 |
46 | config = KeyRenameMapperConfig(
47 | key_map={"old_key": "new_key"},
48 | condition_key="condition",
49 | condition_fn=lambda x: x == 1
50 | )
51 |
52 | mapper = KeyRenameMapper(config)
53 |
54 | sample = {"old_key": 1, "condition": 1}
55 | new_sample = mapper(sample)
56 | print(new_sample) # {"condition": 1, "new_key": 1}
57 |
58 | sample = {"old_key": 1, "condition": 0}
59 | new_sample = mapper(sample)
60 | print(new_sample) # {"old_key": 1, "condition": 0}
61 |
62 |
63 | """
64 |
65 | def __init__(self, config: KeyRenameMapperConfig):
66 | super().__init__(config)
67 | self.key_map = config.key_map
68 | self.condition_key = config.condition_key
69 | self.condition_fn = config.condition_fn
70 | self.else_key_map = config.else_key_map
71 |
72 | def __call__(self, batch: Dict[str, Any], *args, **kwargs) -> Dict[str, Any]:
73 | if self.condition_key is not None:
74 | condition_key = batch[self.condition_key]
75 | if self.condition_fn(condition_key):
76 | for old_key, new_key in self.key_map.items():
77 | if old_key in batch:
78 | batch[new_key] = batch.pop(old_key)
79 |
80 | elif self.else_key_map is not None:
81 | for old_key, new_key in self.else_key_map.items():
82 | if old_key in batch:
83 | batch[new_key] = batch.pop(old_key)
84 |
85 | else:
86 | for old_key, new_key in self.key_map.items():
87 | if old_key in batch:
88 | batch[new_key] = batch.pop(old_key)
89 | return batch
90 |
91 |
92 | class TorchvisionMapper(BaseMapper):
93 | """
94 | Apply torchvision transforms to a sample
95 |
96 | Args:
97 |
98 | config (TorchvisionMapperConfig): Configuration for the mapper
99 | """
100 |
101 | def __init__(self, config: TorchvisionMapperConfig):
102 | super().__init__(config)
103 | chained_transforms = []
104 | for transform, kwargs in zip(config.transforms, config.transforms_kwargs):
105 | transform = getattr(transforms, transform)
106 | chained_transforms.append(transform(**kwargs))
107 | self.transforms = transforms.Compose(chained_transforms)
108 |
109 | def __call__(self, batch: Dict[str, Any], *args, **kwargs) -> Dict[str, Any]:
110 | if self.key in batch:
111 | batch[self.output_key] = self.transforms(batch[self.key])
112 | return batch
113 |
114 |
115 | class RescaleMapper(BaseMapper):
116 | """
117 | Rescale a sample from [0, 1] to [-1, 1]
118 |
119 | Args:
120 |
121 | config (RescaleMapperConfig): Configuration for the mapper
122 | """
123 |
124 | def __init__(self, config: RescaleMapperConfig):
125 | super().__init__(config)
126 |
127 | def __call__(self, batch: Dict[str, Any], *args, **kwargs) -> Dict[str, Any]:
128 | if isinstance(batch[self.key], list):
129 | tmp = []
130 | for image in batch[self.key]:
131 | tmp.append(2 * image - 1)
132 | batch[self.output_key] = tmp
133 | else:
134 | batch[self.output_key] = 2 * batch[self.key] - 1
135 | return batch
136 |
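The branching in `KeyRenameMapper.__call__` follows one rule:

- With no condition key, `key_map` is always applied.
- Otherwise `key_map` is applied when the condition holds.
- When the condition fails, `else_key_map` is applied if given; if not, the sample is left untouched.

The sketch below expresses that rule without dependencies. `rename_keys` is a hypothetical name.

```python
from typing import Any, Callable, Dict, Optional


def rename_keys(
    sample: Dict[str, Any],
    key_map: Dict[str, str],
    condition_key: Optional[str] = None,
    condition_fn: Optional[Callable[[Any], bool]] = None,
    else_key_map: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """Rename keys in place, optionally depending on a condition on another key."""
    if condition_key is None:
        mapping = key_map
    elif condition_fn(sample[condition_key]):
        mapping = key_map
    else:
        # When the condition fails and no fallback map is given, nothing is renamed.
        mapping = else_key_map or {}

    for old_key, new_key in mapping.items():
        if old_key in sample:
            sample[new_key] = sample.pop(old_key)
    return sample


print(rename_keys({"old_key": 1}, {"old_key": "new_key"}))
# {'new_key': 1}
print(rename_keys({"old_key": 1, "condition": 0}, {"old_key": "new_key"},
                  condition_key="condition", condition_fn=lambda x: x == 1))
# {'old_key': 1, 'condition': 0}
```

Like the mapper, the sketch silently skips keys missing from the sample. It raises a KeyError if `condition_key` itself is absent.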
--------------------------------------------------------------------------------
/src/lbm/data/mappers/mappers_config.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, List, Optional
2 |
3 | from pydantic.dataclasses import dataclass
4 |
5 | from ...config import BaseConfig
6 |
7 |
8 | @dataclass
9 | class BaseMapperConfig(BaseConfig):
10 | """
11 | Base configuration for mappers.
12 |
13 | Args:
14 |
15 | verbose (bool):
16 | If True, print debug information. Defaults to False
17 |
18 | key (Optional[str]):
19 | Key to apply the mapper to. Defaults to None
20 |
21 | output_key (Optional[str]):
22 | Key to store the output of the mapper. Defaults to None
23 | """
24 |
25 | verbose: bool = False
26 | key: Optional[str] = None
27 | output_key: Optional[str] = None
28 |
29 |
30 | @dataclass
31 | class KeyRenameMapperConfig(BaseMapperConfig):
32 | """
33 | Rename keys in a sample according to a key map
34 |
35 | Args:
36 |
37 | key_map (Dict[str, str]): Dictionary mapping old keys to new keys. Must be provided
38 | condition_key (Optional[str]): Key whose value is passed to condition_fn. Defaults to None
39 | condition_fn (Optional[Callable[[Any], bool]]): Predicate on batch[condition_key]; key_map is
40 | applied only when it returns True. Defaults to None.
41 | else_key_map (Optional[Dict[str, str]]): Dictionary mapping old keys to new keys, applied
42 | when the condition is not met. Defaults to None, i.e. the original keys are kept.
43 | """
44 |
45 | key_map: Optional[Dict[str, str]] = None
46 | condition_key: Optional[str] = None
47 | condition_fn: Optional[Callable[[Any], bool]] = None
48 | else_key_map: Optional[Dict[str, str]] = None
49 |
50 | def __post_init__(self):
51 | super().__post_init__()
52 | assert self.key_map is not None, "key_map should be provided"
53 | assert all(
54 | isinstance(old_key, str) and isinstance(new_key, str)
55 | for old_key, new_key in self.key_map.items()
56 | ), "key_map should be a dictionary with string keys and values"
57 | if self.condition_key is not None:
58 | assert self.condition_fn is not None, "condition_fn should be provided"
59 | assert callable(self.condition_fn), "condition_fn should be callable"
60 | if self.condition_fn is not None:
61 | assert self.condition_key is not None, "condition_key should be provided"
62 | assert isinstance(
63 | self.condition_key, str
64 | ), "condition_key should be a string"
65 | if self.else_key_map is not None:
66 | assert all(
67 | isinstance(old_key, str) and isinstance(new_key, str)
68 | for old_key, new_key in self.else_key_map.items()
69 | ), "else_key_map should be a dictionary with string keys and values"
70 |
71 |
72 | @dataclass
73 | class TorchvisionMapperConfig(BaseMapperConfig):
74 | """
75 | Apply torchvision transforms to a sample
76 |
77 | Args:
78 |
79 | key (str): Key to apply the transforms to
80 | transforms (List[str]): Names of classes in torchvision.transforms to apply, in order
81 | transforms_kwargs (List[Dict[str, Any]]): Keyword arguments for each transform, one dict per entry in transforms
82 | """
83 |
84 | key: str = "image"
85 | transforms: Optional[List[str]] = None
86 | transforms_kwargs: Optional[List[Dict[str, Any]]] = None
87 |
88 | def __post_init__(self):
89 | super().__post_init__()
90 | if self.transforms is None:
91 | self.transforms = []
92 | if self.transforms_kwargs is None:
93 | self.transforms_kwargs = []
94 | assert len(self.transforms) == len(
95 | self.transforms_kwargs
96 | ), "Number of transforms and kwargs should be same"
97 |
98 |
99 | @dataclass
100 | class RescaleMapperConfig(BaseMapperConfig):
101 | """
102 | Rescale a sample from [0, 1] to [-1, 1]
103 |
104 | Args:
105 |
106 | key (str): Key to rescale
107 | """
108 |
109 | key: str = "image"
110 |
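`RescaleMapper` applies the affine map x -> 2x - 1, which sends [0, 1] to [-1, 1]. `evaluate` in `lbm/inference/inference.py` undoes it with (x + 1) / 2 before converting back to a PIL image. The sketch below shows both directions on plain floats rather than tensors, handling lists the same way the mapper does. `to_signed` and `to_unit` are hypothetical names.

```python
from typing import List, Union

Value = Union[float, List[float]]


def to_signed(x: Value) -> Value:
    """Map values from [0, 1] to [-1, 1] (what RescaleMapper does)."""
    if isinstance(x, list):
        return [2 * v - 1 for v in x]
    return 2 * x - 1


def to_unit(x: Value) -> Value:
    """Inverse map from [-1, 1] back to [0, 1] (as done after sampling)."""
    if isinstance(x, list):
        return [(v + 1) / 2 for v in x]
    return (x + 1) / 2


print(to_signed([0.0, 0.5, 1.0]))  # [-1.0, 0.0, 1.0]
print(to_unit(to_signed(0.25)))  # 0.25
```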
--------------------------------------------------------------------------------
/src/lbm/data/mappers/mappers_wrapper.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Union
2 |
3 | from .base import BaseMapper
4 |
5 |
6 | class MapperWrapper:
7 | """
8 | Wrapper that applies several mappers in sequence, each one receiving the output of the previous one.
9 |
10 | Args:
11 |
12 | mappers (Union[List[BaseMapper], None]): List of mappers to apply to the batch
13 | """
14 |
15 | def __init__(
16 | self,
17 | mappers: Union[List[BaseMapper], None] = None,
18 | ):
19 | self.mappers = mappers if mappers is not None else []
20 |
21 | def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]:
22 | """
23 | Forward pass through all mappers
24 |
25 | Args:
26 |
27 | batch (Dict[str, Any]): batch of data
28 | """
29 | for mapper in self.mappers:
30 | batch = mapper(batch)
31 | return batch
32 |
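`MapperWrapper` is a left-to-right composition: each mapper receives the batch returned by the previous one. Order therefore matters whenever a later mapper reads a key that an earlier one wrote. The sketch below uses `functools.reduce` and assumes mappers are plain callables. `compose_mappers`, `rename` and `rescale` are hypothetical names.

```python
from functools import reduce
from typing import Any, Callable, Dict, Iterable, Optional

Mapper = Callable[[Dict[str, Any]], Dict[str, Any]]


def compose_mappers(mappers: Optional[Iterable[Mapper]] = None) -> Mapper:
    """Return a single mapper that applies `mappers` left to right."""
    chain = list(mappers or [])
    return lambda batch: reduce(lambda acc, m: m(acc), chain, batch)


def rename(batch: Dict[str, Any]) -> Dict[str, Any]:
    batch["image"] = batch.pop("jpg")
    return batch


def rescale(batch: Dict[str, Any]) -> Dict[str, Any]:
    batch["image"] = 2 * batch["image"] - 1
    return batch


pipeline = compose_mappers([rename, rescale])
print(pipeline({"jpg": 0.75}))  # {'image': 0.5}
```

Reversing the order (`[rescale, rename]`) would raise a KeyError, because `rescale` reads `image` before `rename` has created it.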
--------------------------------------------------------------------------------
/src/lbm/inference/__init__.py:
--------------------------------------------------------------------------------
1 | from .inference import evaluate
2 | from .utils import get_model
3 |
4 | __all__ = ["evaluate", "get_model"]
5 |
--------------------------------------------------------------------------------
/src/lbm/inference/inference.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import PIL
4 | import torch
5 | from torchvision.transforms import ToPILImage, ToTensor
6 |
7 | from lbm.models.lbm import LBMModel
8 |
9 | logging.basicConfig(level=logging.INFO)
10 | logger = logging.getLogger(__name__)
11 |
12 | ASPECT_RATIOS = {
13 | str(512 / 2048): (512, 2048),
14 | str(1024 / 1024): (1024, 1024),
15 | str(2048 / 512): (2048, 512),
16 | str(896 / 1152): (896, 1152),
17 | str(1152 / 896): (1152, 896),
18 | str(512 / 1920): (512, 1920),
19 | str(640 / 1536): (640, 1536),
20 | str(768 / 1280): (768, 1280),
21 | str(1280 / 768): (1280, 768),
22 | str(1536 / 640): (1536, 640),
23 | str(1920 / 512): (1920, 512),
24 | }
25 |
26 |
27 | @torch.no_grad()
28 | def evaluate(
29 | model: LBMModel,
30 | source_image: PIL.Image.Image,
31 | num_sampling_steps: int = 1,
32 | ):
33 | """
34 | Evaluate the model on an image coming from the source distribution and generate a new image from the target distribution.
35 |
36 | Args:
37 | model (LBMModel): The model to evaluate.
38 | source_image (PIL.Image.Image): The source image to evaluate the model on.
39 | num_sampling_steps (int): The number of sampling steps to use for the model.
40 |
41 | Returns:
42 | PIL.Image.Image: The generated image.
43 | """
44 |
45 | ori_w_bg, ori_h_bg = source_image.size # PIL size is (width, height)
46 | ar_bg = ori_w_bg / ori_h_bg
47 | closest_ar_bg = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - ar_bg))
48 | source_dimensions = ASPECT_RATIOS[closest_ar_bg]
49 |
50 | source_image = source_image.resize(source_dimensions)
51 |
52 | img_pasted_tensor = ToTensor()(source_image).unsqueeze(0) * 2 - 1
53 | batch = {
54 | "source_image": img_pasted_tensor.cuda().to(torch.bfloat16),
55 | }
56 |
57 | z_source = model.vae.encode(batch[model.source_key])
58 |
59 | output_image = model.sample(
60 | z=z_source,
61 | num_steps=num_sampling_steps,
62 | conditioner_inputs=batch,
63 | max_samples=1,
64 | ).clamp(-1, 1)
65 |
66 | output_image = (output_image[0].float().cpu() + 1) / 2
67 | output_image = ToPILImage()(output_image)
68 | output_image = output_image.resize((ori_w_bg, ori_h_bg)) # back to the input size
69 |
70 | return output_image
71 |
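`evaluate` snaps the input to the closest supported resolution. It compares the image's width/height ratio with the keys of `ASPECT_RATIOS`, which are `str(width / height)` of each bucket. PIL's `Image.size` is `(width, height)`. The sketch below reproduces that lookup without dependencies. `closest_resolution` and the reduced `BUCKETS` table are illustrative assumptions.

```python
from typing import Dict, Tuple

# Subset of the buckets in ASPECT_RATIOS, keyed by width / height.
BUCKETS: Dict[float, Tuple[int, int]] = {
    512 / 2048: (512, 2048),
    896 / 1152: (896, 1152),
    1024 / 1024: (1024, 1024),
    1152 / 896: (1152, 896),
    2048 / 512: (2048, 512),
}


def closest_resolution(width: int, height: int) -> Tuple[int, int]:
    """Return the (width, height) bucket whose aspect ratio is nearest the input's."""
    ratio = width / height
    best = min(BUCKETS, key=lambda r: abs(r - ratio))
    return BUCKETS[best]


print(closest_resolution(1920, 1080))  # (1152, 896)
print(closest_resolution(600, 800))  # (896, 1152)
```

Float keys avoid the str/float round trip of the original dictionary. The nearest-ratio selection is otherwise the same.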
--------------------------------------------------------------------------------
/src/lbm/inference/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from typing import List, Optional
4 |
5 | import torch
6 | import yaml
7 | from diffusers import FlowMatchEulerDiscreteScheduler
8 | from huggingface_hub import snapshot_download
9 | from safetensors.torch import load_file
10 |
11 | from lbm.models.embedders import (
12 | ConditionerWrapper,
13 | LatentsConcatEmbedder,
14 | LatentsConcatEmbedderConfig,
15 | )
16 | from lbm.models.lbm import LBMConfig, LBMModel
17 | from lbm.models.unets import DiffusersUNet2DCondWrapper
18 | from lbm.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig
19 |
20 |
21 | def get_model(
22 | model_dir: str,
23 | save_dir: Optional[str] = None,
24 | torch_dtype: torch.dtype = torch.bfloat16,
25 | device: str = "cuda",
26 | ) -> LBMModel:
27 | """Load the model from a local directory, downloading it from the HuggingFace Hub first if the path does not exist locally
28 |
29 | Args:
30 | model_dir (str): Model directory containing the weights and the yaml config; either a local path or a HuggingFace Hub repository id
31 | save_dir (Optional[str]): The local path to save the model if downloading from HuggingFace Hub. Defaults to None.
32 | torch_dtype (torch.dtype): The torch dtype to use for the model. Defaults to torch.bfloat16.
33 | device (str): The device to use for the model. Defaults to "cuda".
34 |
35 | Returns:
36 | LBMModel: The loaded model
37 | """
38 | if not os.path.exists(model_dir):
39 | local_dir = snapshot_download(
40 | model_dir,
41 | local_dir=save_dir,
42 | )
43 | model_dir = local_dir
44 |
45 | model_files = os.listdir(model_dir)
46 |
47 | if len(model_files) == 0:
48 | raise ValueError("No model files found in the model directory")
49 |
50 | # check yaml config file is present
51 | yaml_file = [f for f in model_files if f.endswith(".yaml")]
52 | if len(yaml_file) == 0:
53 | raise ValueError("No yaml file found in the model directory.")
54 |
55 | # check a safetensors or ckpt weights file is present
56 | safetensors_files = sorted([f for f in model_files if f.endswith(".safetensors")])
57 | ckpt_files = sorted([f for f in model_files if f.endswith(".ckpt")])
58 | if len(safetensors_files) == 0 and len(ckpt_files) == 0:
59 | raise ValueError("No safetensors or ckpt file found in the model directory")
60 |
61 | with open(os.path.join(model_dir, yaml_file[0]), "r") as f:
62 | config = yaml.safe_load(f)
63 |
64 | model = _get_model_from_config(**config, torch_dtype=torch_dtype)
65 |
66 | if len(safetensors_files) > 0:
67 | logging.info(f"Loading safetensors file: {safetensors_files[-1]}")
68 | sd = load_file(os.path.join(model_dir, safetensors_files[-1]))
69 | model.load_state_dict(sd, strict=True)
70 | elif len(ckpt_files) > 0:
71 | logging.info(f"Loading ckpt file: {ckpt_files[-1]}")
72 | sd = torch.load(
73 | os.path.join(model_dir, ckpt_files[-1]),
74 | map_location="cpu",
75 | )["state_dict"]
76 | sd = {k[6:]: v for k, v in sd.items() if k.startswith("model.")}
77 | model.load_state_dict(
78 | sd,
79 | strict=True,
80 | )
81 | model.to(device).to(torch_dtype)
82 |
83 | model.eval()
84 |
85 | return model
86 |
87 |
88 | def _get_model_from_config(
89 | backbone_signature: str = "stabilityai/stable-diffusion-xl-base-1.0",
90 | vae_num_channels: int = 4,
91 | unet_input_channels: int = 4,
92 | timestep_sampling: str = "log_normal",
93 | selected_timesteps: Optional[List[float]] = None,
94 | prob: Optional[List[float]] = None,
95 | conditioning_images_keys: Optional[List[str]] = [],
96 | conditioning_masks_keys: Optional[List[str]] = [],
97 | source_key: str = "source_image",
98 | target_key: str = "source_image_paste",
99 | bridge_noise_sigma: float = 0.0,
100 | logit_mean: float = 0.0,
101 | logit_std: float = 1.0,
102 | pixel_loss_type: str = "lpips",
103 | latent_loss_type: str = "l2",
104 | latent_loss_weight: float = 1.0,
105 | pixel_loss_weight: float = 0.0,
106 | torch_dtype: torch.dtype = torch.bfloat16,
107 | **kwargs,
108 | ):
109 |
110 | conditioners = []
111 |
112 | denoiser = DiffusersUNet2DCondWrapper(
113 | in_channels=unet_input_channels, # Add downsampled_image
114 | out_channels=vae_num_channels,
115 | center_input_sample=False,
116 | flip_sin_to_cos=True,
117 | freq_shift=0,
118 | down_block_types=[
119 | "DownBlock2D",
120 | "CrossAttnDownBlock2D",
121 | "CrossAttnDownBlock2D",
122 | ],
123 | mid_block_type="UNetMidBlock2DCrossAttn",
124 | up_block_types=["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
125 | only_cross_attention=False,
126 | block_out_channels=[320, 640, 1280],
127 | layers_per_block=2,
128 | downsample_padding=1,
129 | mid_block_scale_factor=1,
130 | dropout=0.0,
131 | act_fn="silu",
132 | norm_num_groups=32,
133 | norm_eps=1e-05,
134 | cross_attention_dim=[320, 640, 1280],
135 | transformer_layers_per_block=[1, 2, 10],
136 | reverse_transformer_layers_per_block=None,
137 | encoder_hid_dim=None,
138 | encoder_hid_dim_type=None,
139 | attention_head_dim=[5, 10, 20],
140 | num_attention_heads=None,
141 | dual_cross_attention=False,
142 | use_linear_projection=True,
143 | class_embed_type=None,
144 | addition_embed_type=None,
145 | addition_time_embed_dim=None,
146 | num_class_embeds=None,
147 | upcast_attention=None,
148 | resnet_time_scale_shift="default",
149 | resnet_skip_time_act=False,
150 | resnet_out_scale_factor=1.0,
151 | time_embedding_type="positional",
152 | time_embedding_dim=None,
153 | time_embedding_act_fn=None,
154 | timestep_post_act=None,
155 | time_cond_proj_dim=None,
156 | conv_in_kernel=3,
157 | conv_out_kernel=3,
158 | projection_class_embeddings_input_dim=None,
159 | attention_type="default",
160 | class_embeddings_concat=False,
161 | mid_block_only_cross_attention=None,
162 | cross_attention_norm=None,
163 | addition_embed_type_num_heads=64,
164 | ).to(torch_dtype)
165 |
166 | if conditioning_images_keys != [] or conditioning_masks_keys != []:
167 |
168 | latents_concat_embedder_config = LatentsConcatEmbedderConfig(
169 | image_keys=conditioning_images_keys,
170 | mask_keys=conditioning_masks_keys,
171 | )
172 | latent_concat_embedder = LatentsConcatEmbedder(latents_concat_embedder_config)
173 | latent_concat_embedder.freeze()
174 | conditioners.append(latent_concat_embedder)
175 |
176 | # Wrap all conditioners in a single module
177 | conditioner = ConditionerWrapper(
178 | conditioners=conditioners,
179 | )
180 |
181 | ## VAE ##
182 | # Get VAE model
183 | vae_config = AutoencoderKLDiffusersConfig(
184 | version=backbone_signature,
185 | subfolder="vae",
186 | tiling_size=(128, 128),
187 | )
188 | vae = AutoencoderKLDiffusers(vae_config).to(torch_dtype)
189 | vae.freeze()
190 | vae.to(torch_dtype)
191 |
192 | ## Diffusion Model ##
193 | # Get diffusion model
194 | config = LBMConfig(
195 | source_key=source_key,
196 | target_key=target_key,
197 | latent_loss_weight=latent_loss_weight,
198 | latent_loss_type=latent_loss_type,
199 | pixel_loss_type=pixel_loss_type,
200 | pixel_loss_weight=pixel_loss_weight,
201 | timestep_sampling=timestep_sampling,
202 | logit_mean=logit_mean,
203 | logit_std=logit_std,
204 | selected_timesteps=selected_timesteps,
205 | prob=prob,
206 | bridge_noise_sigma=bridge_noise_sigma,
207 | )
208 |
209 | sampling_noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
210 | backbone_signature,
211 | subfolder="scheduler",
212 | )
213 |
214 | model = LBMModel(
215 | config,
216 | denoiser=denoiser,
217 | sampling_noise_scheduler=sampling_noise_scheduler,
218 | vae=vae,
219 | conditioner=conditioner,
220 | ).to(torch_dtype)
221 |
222 | return model
223 |
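`get_model` picks the last weights file in sorted order and prefers `.safetensors` over `.ckpt`. For a `.ckpt` file, it keeps only the `state_dict` entries under the `model.` prefix and strips that prefix (`k[6:]`, since `len("model.") == 6`). This layout is presumably written by the trainer, which stores the LBM model under a `model` attribute; that is an assumption. The sketch below shows both steps on plain dictionaries. `pick_weights_file` and `strip_prefix` are hypothetical names.

```python
from typing import Dict, List, Optional


def pick_weights_file(files: List[str]) -> Optional[str]:
    """Prefer the last .safetensors file in sorted order, else the last .ckpt file."""
    safetensors = sorted(f for f in files if f.endswith(".safetensors"))
    ckpts = sorted(f for f in files if f.endswith(".ckpt"))
    if safetensors:
        return safetensors[-1]
    if ckpts:
        return ckpts[-1]
    return None


def strip_prefix(state_dict: Dict[str, object], prefix: str = "model.") -> Dict[str, object]:
    """Keep entries whose key starts with `prefix` and remove the prefix."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


print(pick_weights_file(["config.yaml", "step=100.ckpt", "step=200.ckpt"]))
# step=200.ckpt
print(strip_prefix({"model.vae.w": 1, "optimizer.lr": 2}))
# {'vae.w': 1}
```

Sorting is lexicographic, so `step=1000.ckpt` sorts before `step=200.ckpt`. Zero-padded checkpoint names avoid loading an older file by mistake.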
--------------------------------------------------------------------------------
/src/lbm/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/src/lbm/models/__init__.py
--------------------------------------------------------------------------------
/src/lbm/models/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_model import BaseModel
2 | from .model_config import ModelConfig
3 |
4 | __all__ = ["BaseModel", "ModelConfig"]
5 |
--------------------------------------------------------------------------------
/src/lbm/models/base/base_model.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .model_config import ModelConfig
7 |
8 |
9 | class BaseModel(nn.Module):
10 | def __init__(self, config: ModelConfig):
11 | nn.Module.__init__(self)
12 | self.config = config
13 | self.input_key = config.input_key
14 | self.device = torch.device("cpu")
15 | self.dtype = torch.float32
16 |
17 | def on_fit_start(self, device: torch.device | None = None, *args, **kwargs):
18 | """Called when the training starts
19 |
20 | Args:
21 | device (Optional[torch.device], optional): The device to use. Usefull to set
22 | relevant parameters on the model and embedder to the right device only
23 | once at the start of the training. Defaults to None.
24 | """
25 | if device is not None:
26 | self.device = device
27 | self.to(self.device)
28 |
29 | def forward(self, batch: Dict[str, Any], *args, **kwargs):
30 | raise NotImplementedError("forward method is not implemented")
31 |
32 | def freeze(self):
33 | """Freeze the model"""
34 | self.eval()
35 | for param in self.parameters():
36 | param.requires_grad = False
37 |
38 | def to(self, *args, **kwargs):
39 | device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
40 | self = super().to(
41 | device=device,
42 | dtype=dtype,
43 | non_blocking=non_blocking,
44 | )
45 |
46 | if device is not None:
47 | self.device = device
48 | if dtype is not None:
49 | self.dtype = dtype
50 | return self
51 |
52 | def compute_metrics(self, batch: Dict[str, Any], *args, **kwargs):
53 | """Compute the metrics"""
54 | return {}
55 |
56 | def sample(self, batch: Dict[str, Any], *args, **kwargs):
57 | """Sample from the model"""
58 | return {}
59 |
60 | def log_samples(self, batch: Dict[str, Any], *args, **kwargs):
61 | """Log the samples"""
62 | return None
63 |
64 | def on_train_batch_end(self, batch: Dict[str, Any], *args, **kwargs):
65 | """Called after an optimization step has been performed on a batch."""
66 | pass
67 |
--------------------------------------------------------------------------------
/src/lbm/models/base/model_config.py:
--------------------------------------------------------------------------------
1 | from pydantic.dataclasses import dataclass
2 |
3 | from ...config import BaseConfig
4 |
5 |
6 | @dataclass
7 | class ModelConfig(BaseConfig):
8 | input_key: str = "image"
9 |
--------------------------------------------------------------------------------
/src/lbm/models/embedders/__init__.py:
--------------------------------------------------------------------------------
1 | from .conditioners_wrapper import ConditionerWrapper
2 | from .latents_concat import LatentsConcatEmbedder, LatentsConcatEmbedderConfig
3 |
4 | __all__ = ["LatentsConcatEmbedder", "LatentsConcatEmbedderConfig", "ConditionerWrapper"]
5 |
--------------------------------------------------------------------------------
/src/lbm/models/embedders/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_conditioner import BaseConditioner
2 | from .base_conditioner_config import BaseConditionerConfig
3 |
4 | __all__ = ["BaseConditioner", "BaseConditionerConfig"]
5 |
--------------------------------------------------------------------------------
/src/lbm/models/embedders/base/base_conditioner.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Union
2 |
3 | import torch
4 |
5 | from ...base.base_model import BaseModel
6 | from .base_conditioner_config import BaseConditionerConfig
7 |
8 | DIM2CONDITIONING = {
9 | 2: "vector",
10 | 3: "crossattn",
11 | 4: "concat",
12 | }
13 |
14 |
15 | class BaseConditioner(BaseModel):
16 | """This is the base class for all the conditioners. It abstracts the conditioning process
17 |
18 | Args:
19 |
20 | config (BaseConditionerConfig): The configuration of the conditioner
21 |
22 | Examples
23 | ########
24 |
25 | To use the conditioner, you can import the class and use it as follows:
26 |
27 | .. code-block:: python
28 |
29 | from lbm.models.embedders.base import BaseConditioner, BaseConditionerConfig
30 |
31 | # Create the conditioner config
32 | config = BaseConditionerConfig(
33 | input_key="text", # The key for the input
34 | unconditional_conditioning_rate=0.3, # Drops the conditioning with 30% probability during training
35 | )
36 |
37 | # Create the conditioner
38 | conditioner = BaseConditioner(config)
39 | """
40 |
41 | def __init__(self, config: BaseConditionerConfig):
42 | BaseModel.__init__(self, config)
43 | self.config = config
44 | self.input_key = config.input_key
45 | self.dim2outputkey = DIM2CONDITIONING
46 | self.ucg_rate = config.unconditional_conditioning_rate
47 |
48 | def forward(
49 | self, batch: Dict[str, Any], force_zero_embedding: bool = False, *args, **kwargs
50 | ):
51 | """
52 | Forward pass of the embedder.
53 |
54 | Args:
55 |
56 | batch (Dict[str, Any]): A dictionary containing the input data.
57 | force_zero_embedding (bool): Whether to force zero embedding.
58 | This will return an embedding with all entries set to 0. Defaults to False.
59 | """
60 | raise NotImplementedError("Forward pass must be implemented in child class")
61 |
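`DIM2CONDITIONING` names the conditioning type from the number of dimensions of an embedding:

- 2-D `(batch, features)` is a vector.
- 3-D `(batch, tokens, features)` is a cross-attention sequence.
- 4-D `(batch, channels, height, width)` is a latent to concatenate.

These axis meanings are the usual conventions and are an assumption here. `KEY2CATDIM` in `conditioners_wrapper.py` then gives the axis along which outputs of the same type are concatenated. The sketch below works on shape tuples only. `output_key_and_cat_dim` and `concat_shape` are hypothetical names.

```python
from typing import Tuple

DIM2CONDITIONING = {2: "vector", 3: "crossattn", 4: "concat"}
KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}


def output_key_and_cat_dim(shape: Tuple[int, ...]) -> Tuple[str, int]:
    """Map an embedding shape to its conditioning key and concatenation axis."""
    key = DIM2CONDITIONING[len(shape)]
    return key, KEY2CATDIM[key]


def concat_shape(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
    """Shape obtained by concatenating two same-type embeddings along the cat axis."""
    key, dim = output_key_and_cat_dim(a)
    assert output_key_and_cat_dim(b)[0] == key, "embeddings must share a type"
    return a[:dim] + (a[dim] + b[dim],) + a[dim + 1:]


# For example, a 4-channel latent concatenated with a 1-channel mask latent.
print(concat_shape((2, 4, 128, 128), (2, 1, 128, 128)))  # (2, 5, 128, 128)
```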
--------------------------------------------------------------------------------
/src/lbm/models/embedders/base/base_conditioner_config.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | from pydantic.dataclasses import dataclass
4 |
5 | from ....config import BaseConfig
6 |
7 |
8 | @dataclass
9 | class BaseConditionerConfig(BaseConfig):
10 | """This is the BaseConditionerConfig class, which defines the parameters shared by all conditioners
11 |
12 | Args:
13 |
14 | input_key (str): The key for the input. Defaults to "text".
15 | unconditional_conditioning_rate (float): Drops the conditioning with this probability during training. Defaults to 0.0.
16 | """
17 |
18 | input_key: str = "text"
19 | unconditional_conditioning_rate: float = 0.0
20 |
21 | def __post_init__(self):
22 | super().__post_init__()
23 |
24 | assert (
25 | self.unconditional_conditioning_rate >= 0.0
26 | and self.unconditional_conditioning_rate <= 1.0
27 | ), "Unconditional conditioning rate should be between 0 and 1"
28 |
--------------------------------------------------------------------------------
/src/lbm/models/embedders/conditioners_wrapper.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any, Dict, List, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from .base import BaseConditioner
8 |
9 | KEY2CATDIM = {
10 | "vector": 1,
11 | "crossattn": 2,
12 | "concat": 1,
13 | }
14 |
15 |
16 | class ConditionerWrapper(nn.Module):
17 | """
18 | Wrapper for conditioners. This class applies multiple conditioners in a single forward pass and merges their outputs.
19 |
20 | Args:
21 |
22 | conditioners (List[BaseConditioner]): List of conditioners to apply in the forward pass.
23 | """
24 |
25 | def __init__(
26 | self,
27 | conditioners: Union[List[BaseConditioner], None] = None,
28 | ):
29 | nn.Module.__init__(self)
30 | self.conditioners = nn.ModuleList(conditioners)
31 | self.device = torch.device("cpu")
32 | self.dtype = torch.float32
33 |
34 | def conditioner_sanity_check(self, ucg_keys: List[str]):
35 | cond_input_keys = set()
36 | for conditioner in self.conditioners:
37 | cond_input_keys.add(conditioner.input_key)
38 | # every ucg key must match the input_key of some conditioner
39 | assert all(key in cond_input_keys for key in ucg_keys)
40 |
41 | def on_fit_start(self, device: torch.device | None = None, *args, **kwargs):
42 | """Called when the training starts"""
43 | for conditioner in self.conditioners:
44 | conditioner.on_fit_start(device=device, *args, **kwargs)
45 |
46 | def forward(
47 | self,
48 | batch: Dict[str, Any],
49 | ucg_keys: List[str] = None,
50 | set_ucg_rate_zero=False,
51 | *args,
52 | **kwargs,
53 | ):
54 | """
55 | Forward pass through all conditioners
56 |
57 | Args:
58 |
59 | batch: batch of data
60 | ucg_keys: input keys to treat as unconditional. This forces a zero embedding in all the
61 | conditioners whose input_key is in ucg_keys
62 | set_ucg_rate_zero: disable random conditioning dropout for all the conditioners; those in ucg_keys are still zeroed
63 |
64 | Returns:
65 |
66 | Dict[str, Any]: The output of the conditioner. The output of the conditioner is a dictionary with the main key "cond" and value
67 | is a dictionary with the keys as the type of conditioning and the value as the conditioning tensor.
68 | """
69 | if ucg_keys is None:
70 | ucg_keys = []
71 | wrapper_outputs = dict(cond={})
72 | for conditioner in self.conditioners:
73 | if conditioner.input_key in ucg_keys:
74 | force_zero_embedding = True
75 | elif conditioner.ucg_rate > 0 and not set_ucg_rate_zero:
76 | force_zero_embedding = bool(torch.rand(1) < conditioner.ucg_rate)
77 | else:
78 | force_zero_embedding = False
79 |
80 | conditioner_output = conditioner.forward(
81 | batch, force_zero_embedding=force_zero_embedding, *args, **kwargs
82 | )
83 | logging.debug(
84 | f"conditioner:{conditioner.__class__.__name__}, input_key:{conditioner.input_key}, force_ucg_zero_embedding:{force_zero_embedding}"
85 | )
86 | for key in conditioner_output:
87 | logging.debug(
88 | f"conditioner_output:{key}:{conditioner_output[key].shape}"
89 | )
90 | if key in wrapper_outputs["cond"]:
91 | wrapper_outputs["cond"][key] = torch.cat(
92 | [wrapper_outputs["cond"][key], conditioner_output[key]],
93 | KEY2CATDIM[key],
94 | )
95 | else:
96 | wrapper_outputs["cond"][key] = conditioner_output[key]
97 |
98 | return wrapper_outputs
99 |
100 | def to(self, *args, **kwargs):
101 | """
102 | Move all conditioners to device and dtype
103 | """
104 | device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
105 | self = super().to(device=device, dtype=dtype, non_blocking=non_blocking)
106 | for conditioner in self.conditioners:
107 | conditioner.to(device=device, dtype=dtype, non_blocking=non_blocking)
108 |
109 | if device is not None:
110 | self.device = device
111 | if dtype is not None:
112 | self.dtype = dtype
113 |
114 | return self
115 |
--------------------------------------------------------------------------------
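The per-conditioner ucg decision in `ConditionerWrapper.forward` and the merge of outputs that share a key can be sketched standalone. This is a minimal sketch, not the module's code: conditioners are replaced by plain dicts of tensors, and the `KEY2CATDIM` values below are hypothetical, since the real mapping is defined outside this excerpt.

```python
from typing import Dict, List

import torch

# Hypothetical mapping; the real KEY2CATDIM is defined elsewhere in the module.
KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}


def force_zero(input_key: str, ucg_rate: float, ucg_keys: List[str], set_ucg_rate_zero: bool) -> bool:
    """Mirror of the three-way ucg decision in ConditionerWrapper.forward."""
    if input_key in ucg_keys:
        return True  # explicitly requested unconditional branch
    if ucg_rate > 0 and not set_ucg_rate_zero:
        return bool(torch.rand(1) < ucg_rate)  # random conditioning dropout during training
    return False


def merge(outputs: List[Dict[str, torch.Tensor]]) -> Dict[str, Dict[str, torch.Tensor]]:
    """Concatenate tensors that share a conditioning key along KEY2CATDIM[key]."""
    merged: Dict[str, torch.Tensor] = {}
    for out in outputs:
        for key, value in out.items():
            if key in merged:
                merged[key] = torch.cat([merged[key], value], KEY2CATDIM[key])
            else:
                merged[key] = value
    return {"cond": merged}


a = {"concat": torch.zeros(2, 1, 8, 8)}
b = {"concat": torch.ones(2, 4, 8, 8)}
print(merge([a, b])["cond"]["concat"].shape)  # torch.Size([2, 5, 8, 8])
print(force_zero("image", 0.5, ["image"], True))  # True
print(force_zero("image", 0.5, [], True))  # False
```

With `set_ucg_rate_zero=True`, which `LBMModel.sample` passes, only conditioners listed in `ucg_keys` are zeroed, so sampling is not affected by random conditioning dropout.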
/src/lbm/models/embedders/latents_concat/__init__.py:
--------------------------------------------------------------------------------
1 | from .latents_concat_embedder_config import LatentsConcatEmbedderConfig
2 | from .latents_concat_embedder_model import LatentsConcatEmbedder
3 |
4 | __all__ = ["LatentsConcatEmbedder", "LatentsConcatEmbedderConfig"]
5 |
--------------------------------------------------------------------------------
/src/lbm/models/embedders/latents_concat/latents_concat_embedder_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import field
2 | from typing import List, Union
3 |
4 | from pydantic.dataclasses import dataclass
5 |
6 | from ..base import BaseConditionerConfig
7 |
8 |
9 | @dataclass
10 | class LatentsConcatEmbedderConfig(BaseConditionerConfig):
11 | """
12 | Configs for the LatentsConcatEmbedder embedder
13 |
14 | Args:
15 | image_keys (Union[List[str], None]): Keys of the images to encode with the VAE
16 | mask_keys (Union[List[str], None]): Keys of the masks to resize to the latent resolution
17 | """
18 |
19 | image_keys: Union[List[str], None] = field(default_factory=lambda: ["image"])
20 | mask_keys: Union[List[str], None] = field(default_factory=lambda: ["mask"])
21 |
22 | def __post_init__(self):
23 | super().__post_init__()
24 |
25 | # Make sure that at least one of the image_keys or mask_keys is provided
26 | assert (self.image_keys is not None) or (
27 | self.mask_keys is not None
28 | ), "At least one of the image_keys or mask_keys must be provided."
29 |
30 | self.image_keys = self.image_keys if self.image_keys is not None else []
31 | self.mask_keys = self.mask_keys if self.mask_keys is not None else []
32 |
--------------------------------------------------------------------------------
/src/lbm/models/embedders/latents_concat/latents_concat_embedder_model.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | import torch
4 | import torchvision.transforms.functional as F
5 |
6 | from lbm.models.vae import AutoencoderKLDiffusers
7 |
8 | from ..base import BaseConditioner
9 | from .latents_concat_embedder_config import LatentsConcatEmbedderConfig
10 |
11 |
12 | class LatentsConcatEmbedder(BaseConditioner):
13 | """
14 | Class computing VAE embeddings from given images and resizing the masks.
15 | The outputs are then concatenated to the noise in the latent space.
16 |
17 | Args:
18 | config (LatentsConcatEmbedderConfig): Configs to create the embedder
19 | """
20 |
21 | def __init__(self, config: LatentsConcatEmbedderConfig):
22 | BaseConditioner.__init__(self, config)
23 |
24 | def forward(
25 | self, batch: Dict[str, Any], vae: AutoencoderKLDiffusers, *args, **kwargs
26 | ) -> dict:
27 | """
28 | Args:
29 | batch (dict): A batch of images to be processed by this embedder. In the batch,
30 | the images must range between [-1, 1] and the masks range between [0, 1].
31 | vae (AutoencoderKLDiffusers): VAE
32 |
33 | Returns:
34 | output (dict): single-entry dict mapping self.dim2outputkey[ndim] to the concatenated mask and image-latent tensor
35 | """
36 |
37 | # Check that all images and masks share the same spatial size
38 | dims_list = []
39 | for image_key in self.config.image_keys:
40 | dims_list.append(batch[image_key].shape[-2:])
41 | for mask_key in self.config.mask_keys:
42 | dims_list.append(batch[mask_key].shape[-2:])
43 | assert all(
44 | dims == dims_list[0] for dims in dims_list
45 | ), "All images and masks must have the same dimensions."
46 |
47 | # Find the latent dimensions
48 | if len(self.config.image_keys) > 0:
49 | latent_dims = (
50 | batch[self.config.image_keys[0]].shape[-2] // vae.downsampling_factor,
51 | batch[self.config.image_keys[0]].shape[-1] // vae.downsampling_factor,
52 | )
53 | else:
54 | latent_dims = (
55 | batch[self.config.mask_keys[0]].shape[-2] // vae.downsampling_factor,
56 | batch[self.config.mask_keys[0]].shape[-1] // vae.downsampling_factor,
57 | )
58 |
59 | outputs = []
60 |
61 | # Resize the masks and concat them
62 | for mask_key in self.config.mask_keys:
63 | curr_latents = F.resize(
64 | batch[mask_key],
65 | size=latent_dims,
66 | interpolation=F.InterpolationMode.BILINEAR,
67 | )
68 | outputs.append(curr_latents)
69 |
70 | # Compute VAE embeddings from the images
71 | for image_key in self.config.image_keys:
72 | vae_embs = vae.encode(batch[image_key])
73 | outputs.append(vae_embs)
74 |
75 | # Concat all the outputs
76 | outputs = torch.concat(outputs, dim=1)
77 |
78 | outputs = {self.dim2outputkey[outputs.dim()]: outputs}
79 |
80 | return outputs
81 |
--------------------------------------------------------------------------------
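To make the output layout of `LatentsConcatEmbedder.forward` concrete, the sketch below reproduces it with a hypothetical `StubVAE` (8x downsampling, 4 latent channels) in place of `AutoencoderKLDiffusers`; only `encode` and `downsampling_factor` are needed. It uses `torch.nn.functional.interpolate` instead of torchvision's `F.resize`, a swap that yields the same shapes but may differ numerically (for instance if antialiasing is applied when downsampling).

```python
import torch
import torch.nn.functional as F


class StubVAE:
    """Hypothetical stand-in for AutoencoderKLDiffusers: 8x downsampling, 4 latent channels."""

    downsampling_factor = 8

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # Not a real VAE: average-pool a grey version of the image and tile it to 4 channels.
        pooled = F.avg_pool2d(x.mean(dim=1, keepdim=True), self.downsampling_factor)
        return pooled.repeat(1, 4, 1, 1)


def latents_concat(batch, vae, image_keys=("image",), mask_keys=("mask",)):
    """Resize masks to the latent grid, encode images, and concatenate along channels."""
    dims = {tuple(batch[k].shape[-2:]) for k in (*image_keys, *mask_keys)}
    assert len(dims) == 1, "All images and masks must have the same dimensions."
    ref = batch[(list(image_keys) + list(mask_keys))[0]]
    factor = vae.downsampling_factor
    latent_dims = (ref.shape[-2] // factor, ref.shape[-1] // factor)
    outputs = [
        F.interpolate(batch[k], size=latent_dims, mode="bilinear", align_corners=False)
        for k in mask_keys
    ]
    outputs += [vae.encode(batch[k]) for k in image_keys]
    return torch.cat(outputs, dim=1)  # masks first, then image latents


batch = {"image": torch.rand(2, 3, 64, 64) * 2 - 1, "mask": torch.ones(2, 1, 64, 64)}
out = latents_concat(batch, StubVAE())
print(out.shape)  # torch.Size([2, 5, 8, 8])
```

The channel order (masks first, then one latent block per image key) matters because the denoiser sees this tensor concatenated to its noisy input.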
/src/lbm/models/lbm/__init__.py:
--------------------------------------------------------------------------------
1 | from .lbm_config import LBMConfig
2 | from .lbm_model import LBMModel
3 |
4 | __all__ = ["LBMModel", "LBMConfig"]
5 |
--------------------------------------------------------------------------------
/src/lbm/models/lbm/lbm_config.py:
--------------------------------------------------------------------------------
1 | from typing import List, Literal, Optional, Tuple
2 |
3 | from pydantic.dataclasses import dataclass
4 |
5 | from ..base import ModelConfig
6 |
7 |
8 | @dataclass
9 | class LBMConfig(ModelConfig):
10 | """This is the Config for LBM Model class which defines all the useful parameters to be used in the model.
11 |
12 | Args:
13 |
14 | source_key (str):
15 | Key for the source image. Defaults to "source_image"
16 |
17 | target_key (str):
18 | Key for the target image. Defaults to "target_image"
19 |
20 | mask_key (Optional[str]):
21 | Key for the mask showing the valid pixels. Defaults to None
22 |
23 | latent_loss_type (str):
24 | Loss type to use. Defaults to "l2". Choices are "l2", "l1"
25 |
26 | pixel_loss_type (str):
27 | Pixel loss type to use. Defaults to "l2". Choices are "l2", "l1", "lpips"
28 |
29 | pixel_loss_max_size (int):
30 | Maximum size of the image for pixel loss.
31 | The image will be cropped to this size to reduce decoding computation cost. Defaults to 512
32 |
33 | pixel_loss_weight (float):
34 | Weight of the pixel loss. Defaults to 0.0
35 |
36 | timestep_sampling (str):
37 | Timestep sampling to use. Defaults to "uniform". Choices are "uniform", "log_normal", "custom_timesteps"
38 |
39 | input_key (str):
40 | Key for the input. Defaults to "image"
41 |
42 | controlnet_input_key (str):
43 | Key for the controlnet conditioning. Defaults to "controlnet_conditioning"
44 |
45 | adapter_input_key (str):
46 | Key for the adapter conditioning. Defaults to "adapter_conditioning"
47 |
48 | ucg_keys (Optional[List[str]]):
49 | List of keys for which we enforce zero_conditioning during Classifier-free guidance. Defaults to None
50 |
51 | prediction_type (str):
52 | Type of prediction to use. Defaults to "epsilon". Choices are "epsilon", "v_prediction", "flow"
53 |
54 | logit_mean (Optional[float]):
55 | Mean of the normal distribution passed through a sigmoid for "log_normal" (logit-normal) timestep sampling. Defaults to 0.0
56 |
57 | logit_std (Optional[float]):
58 | Standard deviation of that normal distribution for "log_normal" timestep sampling. Defaults to 1.0
59 |
60 | guidance_scale (Optional[float]):
61 | The guidance scale. Useful for fine-tuning guidance-distilled diffusion models. Defaults to None
62 |
63 | selected_timesteps (Optional[List[float]]):
64 | List of selected timesteps to be sampled from if using `custom_timesteps` timestep sampling. Defaults to None
65 |
66 | prob (Optional[List[float]]):
67 | List of probabilities for the selected timesteps if using `custom_timesteps` timestep sampling. Defaults to None
68 | """
69 |
70 | source_key: str = "source_image"
71 | target_key: str = "target_image"
72 | mask_key: Optional[str] = None
73 | latent_loss_weight: float = 1.0
74 | latent_loss_type: Literal["l2", "l1"] = "l2"
75 | pixel_loss_type: Literal["l2", "l1", "lpips"] = "l2"
76 | pixel_loss_max_size: int = 512
77 | pixel_loss_weight: float = 0.0
78 | timestep_sampling: Literal["uniform", "log_normal", "custom_timesteps"] = "uniform"
79 | logit_mean: Optional[float] = 0.0
80 | logit_std: Optional[float] = 1.0
81 | selected_timesteps: Optional[List[float]] = None
82 | prob: Optional[List[float]] = None
83 | bridge_noise_sigma: float = 0.001
84 |
85 | def __post_init__(self):
86 | super().__post_init__()
87 | if self.timestep_sampling == "log_normal":
88 | assert isinstance(self.logit_mean, float) and isinstance(
89 | self.logit_std, float
90 | ), "logit_mean and logit_std should be float for log_normal timestep sampling"
91 |
92 | if self.timestep_sampling == "custom_timesteps":
93 | assert isinstance(self.selected_timesteps, list) and isinstance(
94 | self.prob, list
95 | ), "timesteps and prob should be list for custom_timesteps timestep sampling"
96 | assert len(self.selected_timesteps) == len(
97 | self.prob
98 | ), "timesteps and prob should be of same length for custom_timesteps timestep sampling"
99 | assert (
100 | abs(sum(self.prob) - 1.0) < 1e-6
101 | ), "prob should sum to 1 for custom_timesteps timestep sampling"
102 |
--------------------------------------------------------------------------------
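The `timestep_sampling` modes differ only in how indices into the scheduler's timesteps are drawn. Below is a sketch of the two non-uniform modes, assuming 1000 training timesteps. The function names are illustrative. The `clamp` is an addition not present in the original code: it guards the rare case where a float32 sigmoid rounds to exactly 1. The sum check uses a tolerance because exact float equality rejects lists such as `[0.1] * 10`.

```python
import math

import numpy as np
import torch


def sample_logit_normal_indices(n_samples, num_train_timesteps=1000, logit_mean=0.0, logit_std=1.0):
    """Indices drawn as in timestep_sampling="log_normal" (a logit-normal law)."""
    u = torch.normal(mean=logit_mean, std=logit_std, size=(n_samples,))
    u = torch.sigmoid(u)  # squash to (0, 1)
    indices = (u * num_train_timesteps).long()
    # Guard added in this sketch: a float32 sigmoid can round to exactly 1.0.
    return indices.clamp(max=num_train_timesteps - 1)


def sample_custom_timesteps(selected_timesteps, prob, n_samples, rng=None):
    """Timesteps drawn as in timestep_sampling="custom_timesteps"."""
    # Tolerant check: exact float equality would reject valid lists such as [0.1] * 10.
    if not math.isclose(sum(prob), 1.0, abs_tol=1e-6):
        raise ValueError("prob should sum to 1")
    rng = rng if rng is not None else np.random.default_rng()
    idx = rng.choice(len(selected_timesteps), size=n_samples, p=prob)
    return torch.tensor(selected_timesteps, dtype=torch.long)[torch.as_tensor(idx)]


print(sum([0.1] * 10) == 1)  # False
torch.manual_seed(0)
indices = sample_logit_normal_indices(8)
timesteps = sample_custom_timesteps([250, 500, 750], [0.2, 0.3, 0.5], 4, rng=np.random.default_rng(0))
```

With `logit_mean=0.0`, the sigmoid output is centred on 0.5, so the sampled indices concentrate around the middle of the schedule rather than at its ends.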
/src/lbm/models/lbm/lbm_model.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Tuple, Union
2 |
3 | import lpips
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
8 | from tqdm import tqdm
9 |
10 | from ..base.base_model import BaseModel
11 | from ..embedders import ConditionerWrapper
12 | from ..unets import DiffusersUNet2DCondWrapper, DiffusersUNet2DWrapper
13 | from ..vae import AutoencoderKLDiffusers
14 | from .lbm_config import LBMConfig
15 |
16 |
17 | class LBMModel(BaseModel):
18 | """This is the LBM class which defines the model.
19 |
20 | Args:
21 |
22 | config (LBMConfig):
23 | Configuration for the model
24 |
25 | denoiser (Union[DiffusersUNet2DWrapper, DiffusersUNet2DCondWrapper]):
26 | Denoiser to use for the diffusion model. Defaults to None
27 |
28 | training_noise_scheduler (FlowMatchEulerDiscreteScheduler):
29 | Noise scheduler to use for training. Defaults to None
30 |
31 | sampling_noise_scheduler (FlowMatchEulerDiscreteScheduler):
32 | Noise scheduler to use for sampling. Defaults to None
33 |
34 | vae (AutoencoderKLDiffusers):
35 | VAE to use for the diffusion model. Defaults to None
36 |
37 | conditioner (ConditionerWrapper):
38 | Conditioner to use for the diffusion model. Defaults to None
39 | """
40 |
41 | @classmethod
42 | def load_from_config(cls, config: LBMConfig):
43 | return cls(config=config)
44 |
45 | def __init__(
46 | self,
47 | config: LBMConfig,
48 | denoiser: Union[
49 | DiffusersUNet2DWrapper,
50 | DiffusersUNet2DCondWrapper,
51 | ] = None,
52 | training_noise_scheduler: FlowMatchEulerDiscreteScheduler = None,
53 | sampling_noise_scheduler: FlowMatchEulerDiscreteScheduler = None,
54 | vae: AutoencoderKLDiffusers = None,
55 | conditioner: ConditionerWrapper = None,
56 | ):
57 | BaseModel.__init__(self, config)
58 |
59 | self.vae = vae
60 | self.denoiser = denoiser
61 | self.conditioner = conditioner
62 | self.sampling_noise_scheduler = sampling_noise_scheduler
63 | self.training_noise_scheduler = training_noise_scheduler
64 | self.timestep_sampling = config.timestep_sampling
65 | self.latent_loss_type = config.latent_loss_type
66 | self.latent_loss_weight = config.latent_loss_weight
67 | self.pixel_loss_type = config.pixel_loss_type
68 | self.pixel_loss_max_size = config.pixel_loss_max_size
69 | self.pixel_loss_weight = config.pixel_loss_weight
70 | self.logit_mean = config.logit_mean
71 | self.logit_std = config.logit_std
72 | self.prob = config.prob
73 | self.selected_timesteps = config.selected_timesteps
74 | self.source_key = config.source_key
75 | self.target_key = config.target_key
76 | self.mask_key = config.mask_key
77 | self.bridge_noise_sigma = config.bridge_noise_sigma
78 |
79 | self.num_iterations = nn.Parameter(
80 | torch.tensor(0, dtype=torch.float32), requires_grad=False
81 | )
82 | if self.pixel_loss_type == "lpips" and self.pixel_loss_weight > 0:
83 | self.lpips_loss = lpips.LPIPS(net="vgg")
84 |
85 | else:
86 | self.lpips_loss = None
87 |
88 | def on_fit_start(self, device: torch.device | None = None, *args, **kwargs):
89 | """Called when the training starts"""
90 | super().on_fit_start(device=device, *args, **kwargs)
91 | if self.vae is not None:
92 | self.vae.on_fit_start(device=device, *args, **kwargs)
93 | if self.conditioner is not None:
94 | self.conditioner.on_fit_start(device=device, *args, **kwargs)
95 |
96 | def forward(self, batch: Dict[str, Any], step=0, batch_idx=0, *args, **kwargs):
97 |
98 | self.num_iterations += 1
99 |
100 | # Get inputs/latents
101 | if self.vae is not None:
102 | vae_inputs = batch[self.target_key]
103 | z = self.vae.encode(vae_inputs)
104 | downsampling_factor = self.vae.downsampling_factor
105 | else:
106 | z = batch[self.target_key]
107 | downsampling_factor = 1
108 |
109 | if self.mask_key in batch:
110 | valid_mask = batch[self.mask_key].bool()[:, 0, :, :].unsqueeze(1)
111 | invalid_mask = ~valid_mask
112 | valid_mask_for_latent = ~torch.max_pool2d(
113 | invalid_mask.float(),
114 | downsampling_factor,
115 | downsampling_factor,
116 | ).bool()
117 | valid_mask_for_latent = valid_mask_for_latent.repeat((1, z.shape[1], 1, 1))
118 |
119 | else:
120 | valid_mask = torch.ones_like(batch[self.target_key]).bool()
121 | valid_mask_for_latent = torch.ones_like(z).bool()
122 |
123 | source_image = batch[self.source_key]
124 | source_image = torch.nn.functional.interpolate(
125 | source_image,
126 | size=batch[self.target_key].shape[-2:],
127 | mode="bilinear",
128 | align_corners=False,
129 | ).to(z.dtype)
130 | if self.vae is not None:
131 | z_source = self.vae.encode(source_image)
132 |
133 | else:
134 | z_source = source_image
135 |
136 | # Get conditionings
137 | conditioning = self._get_conditioning(batch, *args, **kwargs)
138 |
139 | # Sample a timestep
140 | timestep = self._timestep_sampling(n_samples=z.shape[0], device=z.device)
141 | sigmas = None
142 |
143 | # Create interpolant
144 | sigmas = self._get_sigmas(
145 | self.training_noise_scheduler, timestep, n_dim=4, device=z.device
146 | )
147 | noisy_sample = (
148 | sigmas * z_source
149 | + (1.0 - sigmas) * z
150 | + self.bridge_noise_sigma
151 | * (sigmas * (1.0 - sigmas)) ** 0.5
152 | * torch.randn_like(z)
153 | )
154 |
155 | for i, t in enumerate(timestep):
156 | if t.item() == self.training_noise_scheduler.timesteps[0]:
157 | noisy_sample[i] = z_source[i]
158 |
159 | # Predict the bridge velocity (z_source - z) with the denoiser
160 | prediction = self.denoiser(
161 | sample=noisy_sample,
162 | timestep=timestep,
163 | conditioning=conditioning,
164 | *args,
165 | **kwargs,
166 | )
167 |
168 | target = z_source - z
169 | denoised_sample = noisy_sample - prediction * sigmas
170 | target_pixels = batch[self.target_key]
171 |
172 | # Compute loss
173 | if self.latent_loss_weight > 0:
174 | loss = self.latent_loss_weight * self.latent_loss(prediction, target.detach(), valid_mask_for_latent)
175 | latent_recon_loss = loss.mean()
176 |
177 | else:
178 | loss = torch.zeros(z.shape[0], device=z.device)
179 | latent_recon_loss = torch.zeros_like(loss)
180 |
181 | if self.pixel_loss_weight > 0:
182 | denoised_sample = self._predicted_x_0(
183 | model_output=prediction,
184 | sample=noisy_sample,
185 | sigmas=sigmas,
186 | )
187 | pixel_loss = self.pixel_loss(
188 | denoised_sample, target_pixels.detach(), valid_mask
189 | )
190 | loss += self.pixel_loss_weight * pixel_loss
191 |
192 | else:
193 | pixel_loss = torch.zeros_like(latent_recon_loss)
194 |
195 | return {
196 | "loss": loss.mean(),
197 | "latent_recon_loss": latent_recon_loss,
198 | "pixel_recon_loss": pixel_loss.mean(),
199 | "predicted_hr": denoised_sample,
200 | "noisy_sample": noisy_sample,
201 | }
202 |
203 | def latent_loss(self, prediction, model_input, valid_latent_mask):
204 | if self.latent_loss_type == "l2":
205 | return torch.mean(
206 | (
207 | (prediction * valid_latent_mask - model_input * valid_latent_mask)
208 | ** 2
209 | ).reshape(model_input.shape[0], -1),
210 | 1,
211 | )
212 | elif self.latent_loss_type == "l1":
213 | return torch.mean(
214 | torch.abs(
215 | prediction * valid_latent_mask - model_input * valid_latent_mask
216 | ).reshape(model_input.shape[0], -1),
217 | 1,
218 | )
219 | else:
220 | raise NotImplementedError(
221 | f"Loss type {self.latent_loss_type} not implemented"
222 | )
223 |
224 | def pixel_loss(self, prediction, model_input, valid_mask):
225 |
226 | latent_crop = self.pixel_loss_max_size // self.vae.downsampling_factor
227 | input_crop = self.pixel_loss_max_size
228 |
229 | crop_h = max((prediction.shape[2] - latent_crop), 0)
230 | crop_w = max((prediction.shape[3] - latent_crop), 0)
231 |
232 | input_crop_h = max((model_input.shape[2] - self.pixel_loss_max_size), 0)
233 | input_crop_w = max((model_input.shape[3] - self.pixel_loss_max_size), 0)
234 |
235 | # image random cropping
236 | if crop_h == 0:
237 | offset_h = 0
238 | else:
239 | offset_h = torch.randint(0, crop_h, (1,)).item()
240 |
241 | if crop_w == 0:
242 | offset_w = 0
243 | else:
244 | offset_w = torch.randint(0, crop_w, (1,)).item()
245 | input_offset_h = offset_h * self.vae.downsampling_factor
246 | input_offset_w = offset_w * self.vae.downsampling_factor
247 |
248 | prediction = prediction[
249 | :,
250 | :,
251 | crop_h
252 | - offset_h : min(crop_h - offset_h + latent_crop, prediction.shape[2]),
253 | crop_w
254 | - offset_w : min(crop_w - offset_w + latent_crop, prediction.shape[3]),
255 | ]
256 |
257 | model_input = model_input[
258 | :,
259 | :,
260 | input_crop_h
261 | - input_offset_h : min(
262 | input_crop_h - input_offset_h + input_crop, model_input.shape[2]
263 | ),
264 | input_crop_w
265 | - input_offset_w : min(
266 | input_crop_w - input_offset_w + input_crop, model_input.shape[3]
267 | ),
268 | ]
269 |
270 | valid_mask = valid_mask[
271 | :,
272 | :,
273 | input_crop_h
274 | - input_offset_h : min(
275 | input_crop_h - input_offset_h + input_crop, valid_mask.shape[2]
276 | ),
277 | input_crop_w
278 | - input_offset_w : min(
279 | input_crop_w - input_offset_w + input_crop, valid_mask.shape[3]
280 | ),
281 | ]
282 |
283 | decoded_prediction = self.vae.decode(prediction).clamp(-1, 1)
284 |
285 | if self.pixel_loss_type == "l2":
286 | return torch.mean(
287 | (
288 | (decoded_prediction * valid_mask - model_input * valid_mask) ** 2
289 | ).reshape(model_input.shape[0], -1),
290 | 1,
291 | )
292 |
293 | elif self.pixel_loss_type == "l1":
294 | return torch.mean(
295 | torch.abs(
296 | decoded_prediction * valid_mask - model_input * valid_mask
297 | ).reshape(model_input.shape[0], -1),
298 | 1,
299 | )
300 |
301 | elif self.pixel_loss_type == "lpips":
302 | return self.lpips_loss(
303 | decoded_prediction * valid_mask, model_input * valid_mask
304 | ).mean()
305 |
306 | def _get_conditioning(
307 | self,
308 | batch: Dict[str, Any],
309 | ucg_keys: List[str] = None,
310 | set_ucg_rate_zero=False,
311 | *args,
312 | **kwargs,
313 | ):
314 | """
315 | Get the conditionings
316 | """
317 | if self.conditioner is not None:
318 | return self.conditioner(
319 | batch,
320 | ucg_keys=ucg_keys,
321 | set_ucg_rate_zero=set_ucg_rate_zero,
322 | vae=self.vae,
323 | *args,
324 | **kwargs,
325 | )
326 | else:
327 | return None
328 |
329 | def _timestep_sampling(self, n_samples=1, device="cpu"):
330 | if self.timestep_sampling == "uniform":
331 | idx = torch.randint(
332 | 0,
333 | self.training_noise_scheduler.config.num_train_timesteps,
334 | (n_samples,),
335 | device="cpu",
336 | )
337 | return self.training_noise_scheduler.timesteps[idx].to(device=device)
338 |
339 | elif self.timestep_sampling == "log_normal":
340 | u = torch.normal(
341 | mean=self.logit_mean,
342 | std=self.logit_std,
343 | size=(n_samples,),
344 | device="cpu",
345 | )
346 | u = torch.nn.functional.sigmoid(u)
347 | indices = (
348 | u * self.training_noise_scheduler.config.num_train_timesteps
349 | ).long()
350 | return self.training_noise_scheduler.timesteps[indices].to(device=device)
351 |
352 | elif self.timestep_sampling == "custom_timesteps":
353 | idx = np.random.choice(len(self.selected_timesteps), n_samples, p=self.prob)
354 |
355 | return torch.tensor(
356 | self.selected_timesteps, device=device, dtype=torch.long
357 | )[idx]
358 |
359 | def _predicted_x_0(
360 | self,
361 | model_output,
362 | sample,
363 | sigmas=None,
364 | ):
365 | """
366 | Predict x_0, the original denoised sample, from the model output and the sigmas: x_0 = sample - sigmas * model_output.
367 | """
368 | pred_x_0 = sample - model_output * sigmas
369 | return pred_x_0
370 |
371 | def _get_sigmas(
372 | self, scheduler, timesteps, n_dim=4, dtype=torch.float32, device="cpu"
373 | ):
374 | sigmas = scheduler.sigmas.to(device=device, dtype=dtype)
375 | schedule_timesteps = scheduler.timesteps.to(device)
376 | timesteps = timesteps.to(device)
377 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
378 |
379 | sigma = sigmas[step_indices].flatten()
380 | while len(sigma.shape) < n_dim:
381 | sigma = sigma.unsqueeze(-1)
382 | return sigma
383 |
384 | @torch.no_grad()
385 | def sample(
386 | self,
387 | z: torch.Tensor,
388 | num_steps: int = 20,
389 | conditioner_inputs: Optional[Dict[str, Any]] = None,
390 | max_samples: Optional[int] = None,
391 | verbose: bool = False,
392 | ):
393 | self.sampling_noise_scheduler.set_timesteps(
394 | sigmas=np.linspace(1, 1 / num_steps, num_steps)
395 | )
396 |
397 | sample = z
398 |
399 | # Get conditioning
400 | conditioning = self._get_conditioning(
401 | conditioner_inputs, set_ucg_rate_zero=True, device=z.device
402 | )
403 |
404 | # If max_samples parameter is provided, limit the number of samples
405 | if max_samples is not None:
406 | sample = sample[:max_samples]
407 |
408 | if conditioning:
409 | conditioning["cond"] = {
410 | k: v[:max_samples] for k, v in conditioning["cond"].items()
411 | }
412 |
413 | for i, t in tqdm(
414 | enumerate(self.sampling_noise_scheduler.timesteps), disable=not verbose
415 | ):
416 | if hasattr(self.sampling_noise_scheduler, "scale_model_input"):
417 | denoiser_input = self.sampling_noise_scheduler.scale_model_input(
418 | sample, t
419 | )
420 |
421 | else:
422 | denoiser_input = sample
423 |
424 | # Predict the velocity with the denoiser, given the conditionings
425 | pred = self.denoiser(
426 | sample=denoiser_input,
427 | timestep=t.to(z.device).repeat(denoiser_input.shape[0]),
428 | conditioning=conditioning,
429 | )
430 |
431 | # Make one step on the reverse diffusion process
432 | sample = self.sampling_noise_scheduler.step(
433 | pred, t, sample, return_dict=False
434 | )[0]
435 | if i < len(self.sampling_noise_scheduler.timesteps) - 1:
436 | timestep = (
437 | self.sampling_noise_scheduler.timesteps[i + 1]
438 | .to(z.device)
439 | .repeat(sample.shape[0])
440 | )
441 | sigmas = self._get_sigmas(
442 | self.sampling_noise_scheduler, timestep, n_dim=4, device=z.device
443 | )
444 | sample = sample + self.bridge_noise_sigma * (
445 | sigmas * (1.0 - sigmas)
446 | ) ** 0.5 * torch.randn_like(sample)
447 | sample = sample.to(z.dtype)
448 |
449 | if self.vae is not None:
450 | decoded_sample = self.vae.decode(sample)
451 |
452 | else:
453 | decoded_sample = sample
454 |
455 | return decoded_sample
456 |
457 | def log_samples(
458 | self,
459 | batch: Dict[str, Any],
460 | input_shape: Optional[Tuple[int, int, int]] = None,
461 | max_samples: Optional[int] = None,
462 | num_steps: Union[int, List[int]] = 20,
463 | ):
464 | if isinstance(num_steps, int):
465 | num_steps = [num_steps]
466 |
467 | logs = {}
468 |
469 | N = max_samples if max_samples is not None else len(batch[self.source_key])
470 |
471 | batch = {k: v[:N] for k, v in batch.items()}
472 |
473 | # infer input shape based on VAE configuration if not passed
474 | if input_shape is None:
475 | if self.vae is not None:
476 | # get input pixel size of the vae
477 | input_shape = batch[self.target_key].shape[2:]
478 | # rescale to latent size
479 | input_shape = (
480 | self.vae.latent_channels,
481 | input_shape[0] // self.vae.downsampling_factor,
482 | input_shape[1] // self.vae.downsampling_factor,
483 | )
484 | else:
485 | raise ValueError(
486 | "input_shape must be passed when no VAE is used in the model"
487 | )
488 |
489 | for num_step in num_steps:
490 | source_image = batch[self.source_key]
491 | source_image = torch.nn.functional.interpolate(
492 | source_image,
493 | size=batch[self.target_key].shape[2:],
494 | mode="bilinear",
495 | align_corners=False,
496 | ).to(dtype=self.dtype)
497 | if self.vae is not None:
498 | z = self.vae.encode(source_image)
499 |
500 | else:
501 | z = source_image
502 |
503 | with torch.autocast(dtype=self.dtype, device_type="cuda"):
504 | logs[f"samples_{num_step}_steps"] = self.sample(
505 | z,
506 | num_steps=num_step,
507 | conditioner_inputs=batch,
508 | max_samples=N,
509 | )
510 |
511 | return logs
512 |
--------------------------------------------------------------------------------
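The training target in `LBMModel.forward` follows from the interpolant x_t = sigma * z_source + (1 - sigma) * z + s * sqrt(sigma * (1 - sigma)) * eps, with s = `bridge_noise_sigma`: the denoiser regresses v = z_source - z, so x_0 = x_t - sigma * v. The sketch below does three things, and the function names in it are illustrative:

- It checks this algebra with an oracle velocity and s = 0.
- It integrates from sigma = 1 to 0 with the Euler update x <- x + (sigma_next - sigma) * v. This is assumed to match `FlowMatchEulerDiscreteScheduler.step` with the scheduler shift left at 1, and the extra bridge noise added between sampling steps is omitted.
- It reproduces the max-pool trick that marks a latent cell valid only when every pixel it covers is valid.

```python
import torch


def bridge_interpolant(z_source, z_target, sigma, bridge_noise_sigma=0.001, noise=None):
    """x_t = sigma*z_source + (1-sigma)*z_target + s*sqrt(sigma*(1-sigma))*eps, as in LBMModel.forward."""
    if noise is None:
        noise = torch.randn_like(z_target)
    bridge_std = bridge_noise_sigma * (sigma * (1.0 - sigma)) ** 0.5  # zero at sigma=0 and sigma=1
    return sigma * z_source + (1.0 - sigma) * z_target + bridge_std * noise


def euler_bridge_sample(z_source, velocity_fn, num_steps=4):
    """Euler steps from sigma=1 (source) to sigma=0 (target); between-step bridge noise omitted."""
    # Same grid as np.linspace(1, 1 / num_steps, num_steps), plus the terminal sigma=0.
    sigmas = torch.cat([torch.linspace(1.0, 1.0 / num_steps, num_steps), torch.zeros(1)])
    x = z_source
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        x = x + (sigma_next - sigma) * velocity_fn(x, sigma)
    return x


def latent_valid_mask(valid_mask, downsampling_factor, latent_channels):
    """A latent cell is valid only if every pixel it covers is valid."""
    invalid = (~valid_mask.bool()).float()
    pooled = torch.max_pool2d(invalid, downsampling_factor, downsampling_factor)
    return (~pooled.bool()).repeat(1, latent_channels, 1, 1)


torch.manual_seed(0)
z_src, z_tgt = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)


def oracle(x, sigma):
    return z_src - z_tgt  # perfect velocity prediction (the training target)


sigma = torch.full((2, 1, 1, 1), 0.3)
x_t = bridge_interpolant(z_src, z_tgt, sigma, bridge_noise_sigma=0.0)
x0_hat = x_t - sigma * oracle(x_t, sigma)  # same formula as _predicted_x_0
```

Because the bridge noise term vanishes at both endpoints, the sample at sigma = 1 is exactly the source latent, which is why `forward` can overwrite `noisy_sample[i]` with `z_source[i]` at the first scheduler timestep.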
/src/lbm/models/unets/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains a collection of U-Net models.
3 | The :mod:`lbm.models.unets` module includes the following classes:
4 |
5 | - :class:`DiffusersUNet2DWrapper`: A 2D U-Net model for diffusers.
6 | - :class:`DiffusersUNet2DCondWrapper`: A 2D U-Net model for diffusers with conditional input.
7 | """
8 |
9 | from .unet import DiffusersUNet2DCondWrapper, DiffusersUNet2DWrapper
10 |
11 | __all__ = [
12 | "DiffusersUNet2DWrapper",
13 | "DiffusersUNet2DCondWrapper",
14 | ]
15 |
--------------------------------------------------------------------------------
/src/lbm/models/unets/unet.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Union
2 |
3 | import torch
4 | from diffusers.models import UNet2DConditionModel, UNet2DModel
5 |
6 |
7 | class DiffusersUNet2DWrapper(UNet2DModel):
8 | """
9 | Wrapper for the UNet2DModel from diffusers
10 |
11 | See diffusers' UNet2DModel for more details
12 | """
13 |
14 | def __init__(self, *args, **kwargs):
15 | UNet2DModel.__init__(self, *args, **kwargs)
16 |
17 | def forward(
18 | self,
19 | sample: torch.Tensor,
20 | timestep: Union[torch.Tensor, float, int],
21 | conditioning: Dict[str, torch.Tensor] = None,
22 | *args,
23 | **kwargs,
24 | ):
25 | """
26 | The forward pass of the model
27 |
28 | Args:
29 |
30 | sample (torch.Tensor): The input sample
31 | timestep (Union[torch.Tensor, float, int]): The current diffusion timestep(s)
32 | """
33 | if conditioning is not None:
34 | class_labels = conditioning["cond"].get("vector", None)
35 | concat = conditioning["cond"].get("concat", None)
36 |
37 | else:
38 | class_labels = None
39 | concat = None
40 |
41 | if concat is not None:
42 | sample = torch.cat([sample, concat], dim=1)
43 |
44 | return super().forward(sample, timestep, class_labels).sample
45 |
46 | def freeze(self):
47 | """
48 | Freeze the model
49 | """
50 | self.eval()
51 | for param in self.parameters():
52 | param.requires_grad = False
53 |
54 |
55 | class DiffusersUNet2DCondWrapper(UNet2DConditionModel):
56 | """
57 | Wrapper for the UNet2DConditionModel from diffusers
58 |
59 | See diffusers' Unet2DConditionModel for more details
60 | """
61 |
62 | def __init__(self, *args, **kwargs):
63 | UNet2DConditionModel.__init__(self, *args, **kwargs)
64 | # BaseModel.__init__(self, config=ModelConfig())
65 |
66 | def forward(
67 | self,
68 | sample: torch.Tensor,
69 | timestep: Union[torch.Tensor, float, int],
70 | conditioning: Dict[str, torch.Tensor],
71 | ip_adapter_cond_embedding: Optional[List[torch.Tensor]] = None,
72 | down_block_additional_residuals: torch.Tensor = None,
73 | mid_block_additional_residual: torch.Tensor = None,
74 | down_intrablock_additional_residuals: torch.Tensor = None,
75 | *args,
76 | **kwargs,
77 | ):
78 | """
79 | The forward pass of the model
80 |
81 | Args:
82 |
83 | sample (torch.Tensor): The input sample
84 | timestep (Union[torch.Tensor, float, int]): The current diffusion timestep(s)
85 | conditioning (Dict[str, torch.Tensor]): The conditioning data
86 | down_block_additional_residuals (List[torch.Tensor]): Residuals for the down blocks.
87 | These residuals are typically used for ControlNet.
88 | mid_block_additional_residual (torch.Tensor): Residual for the mid block.
89 | This residual is typically used for ControlNet.
90 | down_intrablock_additional_residuals (List[torch.Tensor]): Residuals for the down intrablocks.
91 | These residuals are typically used for T2I-Adapters.
92 | """
93 |
94 | assert isinstance(conditioning, dict), "conditionings must be a dictionary"
95 | # assert "crossattn" in conditioning["cond"], "crossattn must be in conditionings"
96 |
97 | class_labels = conditioning["cond"].get("vector", None)
98 | crossattn = conditioning["cond"].get("crossattn", None)
99 | concat = conditioning["cond"].get("concat", None)
100 |
101 | # concat conditioning
102 | if concat is not None:
103 | sample = torch.cat([sample, concat], dim=1)
104 |
105 | # down_intrablock_additional_residuals needs to be cloned, since unet will modify it
106 | if down_intrablock_additional_residuals is not None:
107 | down_intrablock_additional_residuals_clone = [
108 | curr_residuals.clone()
109 | for curr_residuals in down_intrablock_additional_residuals
110 | ]
111 | else:
112 | down_intrablock_additional_residuals_clone = None
113 |
114 | # Check diffusers.models.embeddings.py > MultiIPAdapterImageProjectionLayer > forward() for implementation
115 | # Expected format: List[torch.Tensor] of shape (batch_size, num_image_embeds, embed_dim)
116 | # with length = number of ip_adapters loaded in the ip_adapter_wrapper
117 | if ip_adapter_cond_embedding is not None:
118 | added_cond_kwargs = {
119 | "image_embeds": [
120 | ip_adapter_embedding.unsqueeze(1)
121 | for ip_adapter_embedding in ip_adapter_cond_embedding
122 | ]
123 | }
124 | else:
125 | added_cond_kwargs = None
126 |
127 | return (
128 | super()
129 | .forward(
130 | sample=sample,
131 | timestep=timestep,
132 | encoder_hidden_states=crossattn,
133 | class_labels=class_labels,
134 | added_cond_kwargs=added_cond_kwargs,
135 | down_block_additional_residuals=down_block_additional_residuals,
136 | mid_block_additional_residual=mid_block_additional_residual,
137 | down_intrablock_additional_residuals=down_intrablock_additional_residuals_clone,
138 | )
139 | .sample
140 | )
141 |
142 | def freeze(self):
143 | """
144 | Freeze the model
145 | """
146 | self.eval()
147 | for param in self.parameters():
148 | param.requires_grad = False
149 |
--------------------------------------------------------------------------------
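Both wrappers concatenate `conditioning["cond"]["concat"]` to `sample` along the channel dimension before calling diffusers. The wrapped UNet must therefore be constructed with `in_channels` equal to the latent channels plus the concatenated channels. The sketch below shows that bookkeeping; `ConcatDenoiserSketch` is hypothetical and uses a single `Conv2d` in place of the UNet, ignoring the timestep.

```python
import torch
import torch.nn as nn


class ConcatDenoiserSketch(nn.Module):
    """Hypothetical stand-in for DiffusersUNet2DWrapper: one conv instead of a UNet."""

    def __init__(self, latent_channels: int, concat_channels: int):
        super().__init__()
        # The first layer must accept the noisy latents plus the concatenated conditioning.
        self.net = nn.Conv2d(latent_channels + concat_channels, latent_channels, kernel_size=3, padding=1)

    def forward(self, sample, timestep, conditioning=None):
        concat = None if conditioning is None else conditioning["cond"].get("concat")
        if concat is not None:
            sample = torch.cat([sample, concat], dim=1)  # same channel-wise concat as the wrappers
        return self.net(sample)  # timestep ignored in this sketch


model = ConcatDenoiserSketch(latent_channels=4, concat_channels=5)
x = torch.randn(2, 4, 8, 8)
cond = {"cond": {"concat": torch.randn(2, 5, 8, 8)}}
pred = model(x, torch.tensor([500, 500]), cond)
```

For example, 4-channel noisy latents plus a 5-channel concat require `in_channels=9`. Calling the model without the concat tensor then fails with a channel-mismatch error.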
/src/lbm/models/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | from copy import deepcopy
4 | from typing import List, Tuple
5 |
6 | import torch
7 | import torch.nn.functional as F
8 |
9 | TILING_METHODS = ["average", "gaussian", "linear"]
10 |
11 |
12 | class Tiler:
13 | def get_tiles(
14 | self,
15 | input: torch.Tensor,
16 | tile_size: tuple,
17 | overlap_size: tuple,
18 | scale: int = 1,
19 | out_channels: int = 3,
20 | ) -> List[List[torch.Tensor]]:
21 | """Split the input into a grid of (possibly overlapping) tiles
22 | Args:
23 | input (torch.Tensor): input array of shape (batch_size, channels, height, width)
24 | tile_size (tuple): tile size
25 | overlap_size (tuple): overlap size
26 | scale (int): scaling factor of the output wrt input
27 | out_channels (int): number of output channels
28 | Returns:
29 | List[List[torch.Tensor]]: rows of tiles; tiles on the bottom and right edges may be smaller than tile_size
30 | """
31 | # assert isinstance(scale, int)
32 | assert (
33 | overlap_size[0] < tile_size[0]
34 | ), f"Overlap size {overlap_size} must be smaller than tile size {tile_size}"
35 | assert (
36 | overlap_size[1] < tile_size[1]
37 | ), f"Overlap size {overlap_size} must be smaller than tile size {tile_size}"
38 |
39 | B, C, H, W = input.shape
40 | tile_size_H, tile_size_W = tile_size
41 |
42 | # set the overlap to 0 along any dimension that fits in a single tile (no overlap needed)
43 | overlap_H, overlap_W = (
44 | overlap_size[0] if H > tile_size_H else 0,
45 | overlap_size[1] if W > tile_size_W else 0,
46 | )
47 |
48 | self.output_overlap_size = (
49 | int(overlap_H * scale),
50 | int(overlap_W * scale),
51 | )
52 | self.tile_size = tile_size
53 | self.output_tile_size = (
54 | int(tile_size_H * scale),
55 | int(tile_size_W * scale),
56 | )
57 | self.output_shape = (
58 | B,
59 | out_channels,
60 | int(H * scale),
61 | int(W * scale),
62 | )
63 | tiles = []
64 | logging.debug(f"(Tiler) Input shape: {(B, C, H, W)}")
65 | logging.debug(f"(Tiler) Output shape: {self.output_shape}")
66 | logging.debug(f"(Tiler) Tile size: {(tile_size_H, tile_size_W)}")
67 | logging.debug(f"(Tiler) Overlap size: {(overlap_H, overlap_W)}")
68 | # loop over all tiles in the image with overlap
69 | for i in range(0, H, tile_size_H - overlap_H):
70 | row = []
71 | for j in range(0, W, tile_size_W - overlap_W):
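72 | # .clone() copies only this tile; deepcopy() of a sliced view would also copy
73 | # the full underlying storage of `input`
74 | tile = input[
75 | :,
76 | :,
77 | i : i + tile_size_H,
78 | j : j + tile_size_W,
79 | ].clone()
80 | row.append(tile)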
81 | tiles.append(row)
82 | return tiles
83 |
84 | def merge_tiles(
85 | self, tiles: List[List[torch.Tensor]], tiling_method: str = "gaussian"
86 | ) -> torch.Tensor:
87 | """Merge processed tiles into a single image, blending the overlapping regions
88 | Args:
89 | tiles (List[List[torch.Tensor]]): processed tiles, indexed as tiles[row][col]
90 | tiling_method (str): blending method. One of "average", "gaussian" or "linear"
91 | Returns:
92 | torch.Tensor: output image
93 | """
94 | if tiling_method == "average":
95 | return self._average_merge_tiles(tiles)
96 | elif tiling_method == "gaussian":
97 | return self._gaussian_merge_tiles(tiles)
98 | elif tiling_method == "linear":
99 | return self._linear_merge_tiles(tiles)
100 | else:
101 | raise ValueError(
102 | f"Unknown tiling method {tiling_method}. Available methods are {TILING_METHODS}"
103 | )
104 |
105 | def _average_merge_tiles(self, tiles: List[List[torch.Tensor]]) -> torch.Tensor:
106 | """Merge tiles by averaging the overlapping regions
107 | Args:
108 | tiles (List[List[torch.Tensor]]): processed tiles, indexed as tiles[row][col]
109 | Returns:
110 | torch.Tensor: output image
111 | """
112 |
113 | output = torch.zeros(self.output_shape)
114 |
115 | # weights to store multiplicity
116 | weights = torch.zeros(self.output_shape)
117 |
118 | _, _, output_H, output_W = self.output_shape
119 | output_overlap_size_H, output_overlap_size_W = self.output_overlap_size
120 | output_tile_size_H, output_tile_size_W = self.output_tile_size
121 |
122 | for id_i, i in enumerate(
123 | range(
124 | 0,
125 | output_H,
126 | output_tile_size_H - output_overlap_size_H,
127 | )
128 | ):
129 | for id_j, j in enumerate(
130 | range(
131 | 0,
132 | output_W,
133 | output_tile_size_W - output_overlap_size_W,
134 | )
135 | ):
136 | output[
137 | :,
138 | :,
139 | i : i + output_tile_size_H,
140 | j : j + output_tile_size_W,
141 | ] += (
142 | tiles[id_i][id_j] * 1
143 | )
144 | weights[
145 | :,
146 | :,
147 | i : i + output_tile_size_H,
148 | j : j + output_tile_size_W,
149 | ] += 1
150 |
151 | # each output pixel is the sum of every tile covering it, so divide by the
152 | # per-pixel tile count (1 outside overlaps, larger where tiles overlap)
153 | output = output / weights
154 | return output
155 |
156 | def _gaussian_weights(
157 | self, tile_width: int, tile_height: int, nbatches: int, channels: int
158 | ):
159 | """Generates a gaussian mask of weights for tile contributions.
160 |
161 | Args:
162 | tile_width (int): width of the tile
163 | tile_height (int): height of the tile
164 | nbatches (int): number of batches
165 | channels (int): number of channels
166 | Returns:
167 | torch.Tensor: weights of shape (nbatches, channels, tile_height, tile_width)
168 | """
169 | import numpy as np
170 | from numpy import exp, pi, sqrt
171 |
172 | latent_width = tile_width
173 | latent_height = tile_height
174 |
175 | var = 0.01
176 | midpoint = (
177 | latent_width - 1
178 | ) / 2 # -1 because index goes from 0 to latent_width - 1
179 | x_probs = [
180 | exp(
181 | -(x - midpoint)
182 | * (x - midpoint)
183 | / (latent_width * latent_width)
184 | / (2 * var)
185 | )
186 | / sqrt(2 * pi * var)
187 | for x in range(latent_width)
188 | ]
189 | midpoint = (latent_height - 1) / 2  # -1 for the same reason as for latent_width above
190 | y_probs = [
191 | exp(
192 | -(y - midpoint)
193 | * (y - midpoint)
194 | / (latent_height * latent_height)
195 | / (2 * var)
196 | )
197 | / sqrt(2 * pi * var)
198 | for y in range(latent_height)
199 | ]
200 |
201 | weights = np.outer(y_probs, x_probs)
202 | return torch.tile(
203 | torch.tensor(weights, device="cpu"), (nbatches, channels, 1, 1)
204 | )
205 |
206 | def _gaussian_merge_tiles(self, tiles: List[List[torch.Tensor]]) -> torch.Tensor:
207 | """Merge tiles by Gaussian-weighted averaging of the overlapping regions
208 | Args:
209 | tiles (List[List[torch.Tensor]]): processed tiles, indexed as tiles[row][col]
210 | Returns:
211 | torch.Tensor: output image
212 | """
213 | B, output_C, output_H, output_W = self.output_shape
214 | output_overlap_size_H, output_overlap_size_W = self.output_overlap_size
215 | output_tile_size_H, output_tile_size_W = self.output_tile_size
216 |
217 | output = torch.zeros(self.output_shape)
218 | # weights to store multiplicity
219 | weights = torch.zeros(self.output_shape)
220 |
221 | for id_i, i in enumerate(
222 | range(
223 | 0,
224 | output_H,
225 | output_tile_size_H - output_overlap_size_H,
226 | )
227 | ):
228 | for id_j, j in enumerate(
229 | range(
230 | 0,
231 | output_W,
232 | output_tile_size_W - output_overlap_size_W,
233 | )
234 | ):
235 | w = self._gaussian_weights(
236 | tiles[id_i][id_j].shape[3],
237 | tiles[id_i][id_j].shape[2],
238 | B,
239 | output_C,
240 | )
241 | output[
242 | :,
243 | :,
244 | i : i + output_tile_size_H,
245 | j : j + output_tile_size_W,
246 | ] += (
247 | tiles[id_i][id_j] * w
248 | )
249 | weights[
250 | :,
251 | :,
252 | i : i + output_tile_size_H,
253 | j : j + output_tile_size_W,
254 | ] += w
255 |
256 | # normalize by the accumulated Gaussian weights (per-pixel weighted average)
257 | output = output / weights
258 | return output
259 |
260 | def _blend_v(
261 | self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
262 | ) -> torch.Tensor:
263 | blend_extent = min(a.shape[2], b.shape[2], blend_extent)
264 | for y in range(blend_extent):
265 | b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[
266 | :, :, y, :
267 | ] * (y / blend_extent)
268 | return b
269 |
270 | def _blend_h(
271 | self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
272 | ) -> torch.Tensor:
273 | blend_extent = min(a.shape[3], b.shape[3], blend_extent)
274 | for x in range(blend_extent):
275 | b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[
276 | :, :, :, x
277 | ] * (x / blend_extent)
278 | return b
279 |
280 | def _linear_merge_tiles(self, tiles: List[List[torch.Tensor]]) -> torch.Tensor:
281 | """Merge tiles by linearly blending the overlapping regions
282 | Args:
283 | tiles (List[List[torch.Tensor]]): processed tiles, indexed as tiles[row][col]
284 | Returns:
285 | torch.Tensor: output image
286 | """
287 | output_overlap_size_H, output_overlap_size_W = self.output_overlap_size
288 | output_tile_size_H, output_tile_size_W = self.output_tile_size
289 |
290 | res_rows = []
291 | tiles_copy = deepcopy(tiles)
292 |
293 | # Cut the right and bottom overlap region
294 | limit_i = output_tile_size_H - output_overlap_size_H
295 | limit_j = output_tile_size_W - output_overlap_size_W
296 | for i, tile_row in enumerate(tiles_copy):
297 | res_row = []
298 | for j, tile in enumerate(tile_row):
299 | tile_val = tile
300 | if j > 0:
301 | tile_val = self._blend_h(
302 | tile_row[j - 1], tile, output_overlap_size_W
303 | )
304 | tiles_copy[i][j] = tile_val
305 | if i > 0:
306 | tile_val = self._blend_v(
307 | tiles_copy[i - 1][j], tile_val, output_overlap_size_H
308 | )
309 | tiles_copy[i][j] = tile_val
310 | res_row.append(tile_val[:, :, :limit_i, :limit_j])
311 | res_rows.append(torch.cat(res_row, dim=3))
312 | output = torch.cat(res_rows, dim=2)
313 | return output
314 |
315 |
316 | def extract_into_tensor(
317 | a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int, ...]
318 | ) -> torch.Tensor:
319 | """
320 | Gathers the values of `a` at the indices `t` (along the last dimension) and reshapes them so they broadcast against a tensor of shape `x_shape`.
321 |
322 | :param a: the tensor to extract values from.
323 | :param t: the tensor containing the indices.
324 | :param x_shape: the shape of the tensor the result is broadcast against.
325 | :return: a tensor of shape (batch_size, 1, ..., 1) with len(x_shape) dimensions.
326 | """
327 |
328 | b, *_ = t.shape
329 | out = a.gather(-1, t)
330 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
331 |
332 |
333 | def pad(x: torch.Tensor, base_h: int, base_w: int) -> torch.Tensor:
334 | """
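335 | Zero-pads a tensor on the right and bottom so that its height and width become the smallest multiples of base_h and base_w that are not less than the current size.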
336 |
337 | :param x: the tensor to pad.
338 | :param base_h: the base height.
339 | :param base_w: the base width.
340 | :return: the padded tensor.
341 | """
342 | h, w = x.shape[-2:]
343 | h_ = math.ceil(h / base_h) * base_h
344 | w_ = math.ceil(w / base_w) * base_w
345 | if w_ != w:
346 | x = F.pad(x, (0, abs(w_ - w), 0, 0))
347 | if h_ != h:
348 | x = F.pad(x, (0, 0, 0, abs(h_ - h)))
349 | return x
350 |
351 |
352 | def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
353 | """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
354 | dims_to_append = target_dims - x.ndim
355 | if dims_to_append < 0:
356 | raise ValueError(
357 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
358 | )
359 | return x[(...,) + (None,) * dims_to_append]
360 |
361 |
362 | @torch.no_grad()
363 | def update_ema(
364 | target_params: List[torch.Tensor],
365 | source_params: List[torch.Tensor],
366 | rate: float = 0.99,
367 | ):
368 | """
369 | Update target parameters to be closer to those of source parameters using
370 | an exponential moving average.
371 |
372 | :param target_params: the target parameter sequence.
373 | :param source_params: the source parameter sequence.
374 | :param rate: the EMA rate (closer to 1 means slower).
375 | """
376 | for targ, src in zip(target_params, source_params):
377 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
378 |
--------------------------------------------------------------------------------
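The 1-D arithmetic behind `Tiler.get_tiles` and `Tiler._average_merge_tiles` can be checked without PyTorch. The sketch below uses hypothetical helpers (`tile_starts`, `coverage_counts`, not part of the package) and assumes the same stepping rule as `get_tiles`: tiles start every `tile_size - overlap` positions, the overlap is dropped when a single tile covers the axis, and the average merge divides each position by the number of tiles covering it.

```python
from typing import List


def tile_starts(length: int, tile_size: int, overlap: int) -> List[int]:
    """Start offsets of the tiles along one axis, mirroring Tiler.get_tiles."""
    # As in get_tiles, no overlap is used when one tile covers the whole axis
    if length <= tile_size:
        overlap = 0
    step = tile_size - overlap  # must be > 0, hence overlap < tile_size
    return list(range(0, length, step))


def coverage_counts(length: int, tile_size: int, overlap: int) -> List[int]:
    """Number of tiles covering each position; the average merge divides by this."""
    counts = [0] * length
    for start in tile_starts(length, tile_size, overlap):
        # the last tile is clipped at the border, like tensor slicing
        for pos in range(start, min(start + tile_size, length)):
            counts[pos] += 1
    return counts


starts = tile_starts(length=160, tile_size=64, overlap=16)
print(starts)  # [0, 48, 96, 144]
counts = coverage_counts(160, 64, 16)
print(max(counts))  # 2
```

With these numbers the last tile starts at 144 and is only 16 positions wide, lying entirely inside the overlap of its predecessor, because the loop stops only when the start offset reaches `length`. In 2-D, the count at a pixel is the product of its row and column counts, so pixels lying in an overlap along both axes are divided by 4 here.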
/src/lbm/models/vae/__init__.py:
--------------------------------------------------------------------------------
1 | from .autoencoderKL import AutoencoderKLDiffusers
2 | from .autoencoderKL_config import AutoencoderKLDiffusersConfig
3 |
4 | __all__ = ["AutoencoderKLDiffusers", "AutoencoderKLDiffusersConfig"]
5 |
--------------------------------------------------------------------------------
/src/lbm/models/vae/autoencoderKL.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers.models import AutoencoderKL
3 |
4 | from ..base.base_model import BaseModel
5 | from ..utils import Tiler, pad
6 | from .autoencoderKL_config import AutoencoderKLDiffusersConfig
7 |
8 |
9 | class AutoencoderKLDiffusers(BaseModel):
10 | """Wrapper around diffusers' AutoencoderKL that encodes images to scaled latents and decodes them back (tiled for large latents)
11 |
12 | Args:
13 |
14 | config (AutoencoderKLDiffusersConfig): The config class which defines all the required parameters.
15 | """
16 |
17 | def __init__(self, config: AutoencoderKLDiffusersConfig):
18 | BaseModel.__init__(self, config)
19 | self.config = config
20 | self.vae_model = AutoencoderKL.from_pretrained(
21 | config.version,
22 | subfolder=config.subfolder,
23 | revision=config.revision,
24 | )
25 | self.tiling_size = config.tiling_size
26 | self.tiling_overlap = config.tiling_overlap
27 |
28 | # infer latent properties (shift/scale factors, latents mean/std, downsampling factor)
29 | self._get_properties()
30 |
31 | @torch.no_grad()
32 | def _get_properties(self):
33 | self.has_shift_factor = (
34 | hasattr(self.vae_model.config, "shift_factor")
35 | and self.vae_model.config.shift_factor is not None
36 | )
37 | self.shift_factor = (
38 | self.vae_model.config.shift_factor if self.has_shift_factor else 0
39 | )
40 |
41 | # set latent channels
42 | self.latent_channels = self.vae_model.config.latent_channels
43 | self.has_latents_mean = (
44 | hasattr(self.vae_model.config, "latents_mean")
45 | and self.vae_model.config.latents_mean is not None
46 | )
47 | self.has_latents_std = (
48 | hasattr(self.vae_model.config, "latents_std")
49 | and self.vae_model.config.latents_std is not None
50 | )
51 | self.latents_mean = getattr(self.vae_model.config, "latents_mean", None)
52 | self.latents_std = getattr(self.vae_model.config, "latents_std", None)
53 |
54 | x = torch.randn(1, self.vae_model.config.in_channels, 32, 32)
55 | z = self.encode(x)
56 |
57 | # set downsampling factor
58 | self.downsampling_factor = int(x.shape[2] / z.shape[2])
59 |
60 | def encode(self, x: torch.tensor, batch_size: int = 8):
61 | latents = []
62 | for i in range(0, x.shape[0], batch_size):
63 | latents.append(
64 | self.vae_model.encode(x[i : i + batch_size]).latent_dist.sample()
65 | )
66 | latents = torch.cat(latents, dim=0)
67 | latents = (latents - self.shift_factor) * self.vae_model.config.scaling_factor
68 |
69 | return latents
70 |
71 | def decode(self, z: torch.tensor):
72 |
73 | if self.has_latents_mean and self.has_latents_std:
74 | latents_mean = (
75 | torch.tensor(self.latents_mean)
76 | .view(1, self.latent_channels, 1, 1)
77 | .to(z.device, z.dtype)
78 | )
79 | latents_std = (
80 | torch.tensor(self.latents_std)
81 | .view(1, self.latent_channels, 1, 1)
82 | .to(z.device, z.dtype)
83 | )
84 | z = z * latents_std / self.vae_model.config.scaling_factor + latents_mean
85 | else:
86 | z = z / self.vae_model.config.scaling_factor + self.shift_factor
87 |
88 | use_tiling = (
89 | z.shape[2] > self.tiling_size[0] or z.shape[3] > self.tiling_size[1]
90 | )
91 |
92 | if use_tiling:
93 | samples = []
94 | for i in range(z.shape[0]):
95 |
96 | z_i = z[i].unsqueeze(0)
97 |
98 | tiler = Tiler()
99 | tiles = tiler.get_tiles(
100 | input=z_i,
101 | tile_size=self.tiling_size,
102 | overlap_size=self.tiling_overlap,
103 | scale=self.downsampling_factor,
104 | out_channels=self.vae_model.config.out_channels,
105 | )
106 |
107 | for row_idx, tile_row in enumerate(tiles):
108 | for col_idx, tile in enumerate(tile_row):
109 | tile_shape = tile.shape
110 | # zero-pad border tiles up to the full tile size before decoding
111 | tile = pad(
112 | tile,
113 | base_h=self.tiling_size[0],
114 | base_w=self.tiling_size[1],
115 | )
116 | tile_decoded = self.vae_model.decode(tile).sample
117 | tiles[row_idx][col_idx] = (
118 | tile_decoded[
119 | 0,
120 | :,
121 | : int(tile_shape[2] * self.downsampling_factor),
122 | : int(tile_shape[3] * self.downsampling_factor),
123 | ]
124 | .cpu()
125 | .unsqueeze(0)
126 | )
127 |
128 | # merge tiles
129 | samples.append(tiler.merge_tiles(tiles=tiles))
130 |
131 | samples = torch.cat(samples, dim=0)
132 |
133 | else:
134 | samples = self.vae_model.decode(z).sample
135 |
136 | return samples
137 |
--------------------------------------------------------------------------------
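`decode` maps latents back with `z * latents_std / scaling_factor + latents_mean` when the VAE config defines `latents_mean` and `latents_std`, and with `z / scaling_factor + shift_factor` otherwise. The scalar sketch below (hypothetical helper names; plain floats stand in for the per-channel tensors) checks that each branch is inverted by the matching forward normalization. `encode` above applies only the shift/scale form, so `encode` and `decode` are exact inverses only for VAEs without `latents_mean`/`latents_std`.

```python
from typing import Optional


def to_model_space(raw: float, scaling: float, shift: float = 0.0,
                   mean: Optional[float] = None, std: Optional[float] = None) -> float:
    """Raw VAE latent -> normalized latent (the inverse of from_model_space)."""
    if mean is not None and std is not None:
        return (raw - mean) * scaling / std
    return (raw - shift) * scaling


def from_model_space(z: float, scaling: float, shift: float = 0.0,
                     mean: Optional[float] = None, std: Optional[float] = None) -> float:
    """Normalized latent -> raw VAE latent, mirroring the two branches of decode."""
    if mean is not None and std is not None:
        return z * std / scaling + mean
    return z / scaling + shift


raw = 1.7
# shift/scale branch (the only normalization encode currently applies)
z = to_model_space(raw, scaling=0.5, shift=0.1)
print(round(from_model_space(z, scaling=0.5, shift=0.1), 6))  # 1.7
# latents mean/std branch
z = to_model_space(raw, scaling=0.5, mean=0.3, std=2.0)
print(round(from_model_space(z, scaling=0.5, mean=0.3, std=2.0), 6))  # 1.7
```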
/src/lbm/models/vae/autoencoderKL_config.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | from pydantic.dataclasses import dataclass
4 |
5 | from ..base import ModelConfig
6 |
7 |
8 | @dataclass
9 | class AutoencoderKLDiffusersConfig(ModelConfig):
10 | """Configuration for AutoencoderKLDiffusers: defines all the parameters used to instantiate the model.
11 |
12 | Args:
13 |
14 | version (str): The version of the model. Defaults to "stabilityai/sdxl-vae".
15 | subfolder (str): The subfolder of the model if loaded from another model. Defaults to "".
16 | revision (str): The revision of the model. Defaults to "main".
17 | input_key (str): The key of the input data in the batch. Defaults to "image".
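18 | tiling_size (Tuple[int, int]): Latent-space tile size (height, width); latents larger than this along either axis are decoded tile by tile. Defaults to (64, 64).
19 | tiling_overlap (Tuple[int, int]): Overlap between adjacent latent tiles (height, width). Defaults to (16, 16).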
20 | """
21 |
22 | version: str = "stabilityai/sdxl-vae"
23 | subfolder: str = ""
24 | revision: str = "main"
25 | input_key: str = "image"
26 | tiling_size: Tuple[int, int] = (64, 64)
27 | tiling_overlap: Tuple[int, int] = (16, 16)
28 |
--------------------------------------------------------------------------------
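With the defaults above, `decode` switches to tiling once the latent exceeds 64 along either axis. The sketch below (hypothetical `decode_plan` and `tiles_along` helpers) spells out the resulting numbers. It assumes a downsampling factor of 8, as for SD/SDXL-style VAEs; the model itself measures this factor at init by encoding a 32x32 input.

```python
import math
from typing import Tuple


def tiles_along(length: int, tile: int, overlap: int) -> int:
    """Number of tiles Tiler.get_tiles produces along one axis."""
    step = tile - (overlap if length > tile else 0)
    return math.ceil(length / step)  # == len(range(0, length, step))


def decode_plan(latent_hw: Tuple[int, int],
                tiling_size: Tuple[int, int] = (64, 64),
                tiling_overlap: Tuple[int, int] = (16, 16),
                downsampling_factor: int = 8) -> dict:
    h, w = latent_hw
    return {
        # same condition as in AutoencoderKLDiffusers.decode
        "use_tiling": h > tiling_size[0] or w > tiling_size[1],
        "grid": (tiles_along(h, tiling_size[0], tiling_overlap[0]),
                 tiles_along(w, tiling_size[1], tiling_overlap[1])),
        # latent sizes scaled to pixels, as Tiler does with `scale`
        "pixel_tile": tuple(s * downsampling_factor for s in tiling_size),
        "pixel_overlap": tuple(o * downsampling_factor for o in tiling_overlap),
        "output_hw": (h * downsampling_factor, w * downsampling_factor),
    }


plan = decode_plan((128, 96))  # e.g. a 1024x768 image with an 8x VAE
print(plan["use_tiling"], plan["grid"], plan["pixel_tile"])  # True (3, 2) (512, 512)
```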
/src/lbm/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains the training pipeline and the training configuration along with all relevant parts
3 | of the training pipeline such as loggers and callbacks.
4 |
5 | The :mod:`lbm.trainer` module includes the following submodules:
6 |
7 | - :mod:`lbm.trainer.trainer`: the main training pipeline class.
8 | - :mod:`lbm.trainer.training_config`: the configuration for the training pipeline.
9 | - :mod:`lbm.trainer.loggers`: callbacks for logging samples to wandb or TensorBoard.
10 |
11 |
12 | Examples
13 | ########
14 |
15 | Train a model using the training pipeline
16 |
17 | .. code-block:: python
18 |
19 | from lbm.trainer import TrainingPipeline, TrainingConfig
20 | from pytorch_lightning import Trainer
21 | from lbm.data.datasets import DataModule, DataModuleConfig
22 |
23 |
24 | # Create a model to train
25 | model = DummyModel()
26 |
27 | # Create a training configuration
28 | config = TrainingConfig(
29 | experiment_id="test",
30 | optimizer_name="AdamW",
31 | optimizer_kwargs={},
32 | learning_rate=1e-3,
33 | lr_scheduler_name=None,
34 | lr_scheduler_kwargs={},
35 | trainable_params=[".*"],
36 | log_keys="txt",
37 | log_samples_model_kwargs={
38 | "max_samples": 8,
39 | "num_steps": 20,
40 | "input_shape": (4, 32, 32),
41 | "guidance_scale": 7.5,
42 | }
43 | )
44 |
45 | # Create a training pipeline
46 | pipeline = TrainingPipeline(model=model, pipeline_config=config)
47 |
48 | # Create a DataModule
49 | data_module = DataModule(
50 | train_config=DataModuleConfig(
51 | shards_path_or_urls="your urls or paths",
52 | decoder="pil",
53 | shuffle_buffer_size=100,
54 | per_worker_batch_size=32,
55 | num_workers=4,
56 | ),
57 | train_filters_mappers=your_mappers_and_filters,
58 | eval_config=DataModuleConfig(
59 | shards_path_or_urls="your urls or paths",
60 | decoder="pil",
61 | shuffle_buffer_size=100,
62 | per_worker_batch_size=32,
63 | num_workers=4,
64 | ),
65 | eval_filters_mappers=your_mappers_and_filters,
66 | )
67 |
68 | # Create a trainer
69 | trainer = Trainer(
70 | accelerator="cuda",
71 | max_epochs=1,
72 | devices=1,
73 | log_every_n_steps=1,
74 | default_root_dir="your dir",
75 | max_steps=2,
76 | )
77 |
78 | # Train the model
79 | trainer.fit(pipeline, data_module)
80 | """
81 |
82 | from .trainer import TrainingPipeline
83 | from .training_config import TrainingConfig
84 |
85 | __all__ = ["TrainingPipeline", "TrainingConfig"]
86 |
--------------------------------------------------------------------------------
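`TrainingPipeline.configure_optimizers` (in `trainer.py`) keeps a parameter trainable when `re.match` succeeds for at least one pattern in `trainable_params`. Because `re.match` anchors only at the start of the name, a pattern acts as a prefix match unless it starts with `.*`. The sketch below reproduces that selection with made-up parameter names (hypothetical, for illustration only).

```python
import re
from typing import Iterable, List


def select_trainable(names: Iterable[str], patterns: List[str]) -> List[str]:
    """Names kept trainable: re.match (anchored at the start) against any pattern."""
    compiled = [re.compile(p) for p in patterns]
    # any() keeps each name at most once, even if several patterns match
    return [name for name in names if any(p.match(name) for p in compiled)]


# hypothetical parameter names
names = [
    "denoiser.conv_in.weight",
    "denoiser.down_blocks.0.resnets.0.conv1.weight",
    "vae.vae_model.encoder.conv_in.weight",
]
print(select_trainable(names, [".*"]))          # all three names
print(select_trainable(names, ["denoiser"]))    # the two denoiser.* names
print(select_trainable(names, ["conv_in"]))     # []: re.match does not search mid-string
print(select_trainable(names, [r".*conv_in"]))  # both conv_in weights
```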
/src/lbm/trainer/loggers.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | from typing import Any, Dict, List, Tuple
4 |
5 | import numpy as np
6 | import torch
7 | import wandb
8 | from PIL import Image, ImageDraw, ImageFont
9 | from pytorch_lightning import Trainer
10 | from pytorch_lightning.callbacks import Callback
11 | from pytorch_lightning.utilities import rank_zero_only
12 | from torchvision.utils import make_grid
13 |
14 | from ..trainer import TrainingPipeline
15 |
16 | logging.basicConfig(level=logging.INFO)
17 |
18 |
19 | def create_grid_texts(
20 | texts: List[str],
21 | n_cols: int = 4,
22 | image_size: Tuple[int, int] = (512, 512),
23 | font_size: int = 40,
24 | margin: int = 5,
25 | offset: int = 5,
26 | ) -> Image.Image:
27 | """
28 | Create a grid of white images containing the given texts.
29 |
30 | Args:
31 | texts (List[str]): List of strings to be drawn on images.
32 | n_cols (int): Number of columns in the grid.
33 | image_size (tuple): Size of the generated images (width, height).
34 | font_size (int): Font size of the text.
35 | margin (int): Margin around the text.
36 | offset (int): Offset between lines.
37 |
38 | Returns:
39 | PIL.Image: a single image with the rendered texts arranged in a grid
40 | """
41 |
42 | images = []
43 | font = ImageFont.load_default(size=font_size)
44 |
45 | for text in texts:
46 | img = Image.new("RGB", image_size, color="white")
47 | draw = ImageDraw.Draw(img)
48 | margin_ = margin
49 | offset_ = offset
50 | for line in wrap_text(
51 | text=text, draw=draw, max_width=image_size[0] - 2 * margin_, font=font
52 | ):
53 | draw.text((margin_, offset_), line, font=font, fill="black")
54 | offset_ += font_size
55 | images.append(img)
56 |
57 | # create a pil grid
58 | n_rows = math.ceil(len(images) / n_cols)
59 | grid = Image.new(
60 | "RGB", (n_cols * image_size[0], n_rows * image_size[1]), color="white"
61 | )
62 | for i, img in enumerate(images):
63 | grid.paste(img, (i % n_cols * image_size[0], i // n_cols * image_size[1]))
64 |
65 | return grid
66 |
67 |
68 | def wrap_text(
69 | text: str, draw: ImageDraw.Draw, max_width: int, font: ImageFont
70 | ) -> List[str]:
71 | """
72 | Wrap text character by character to fit within a specified width when drawn.
73 | A new line is started whenever appending the next character would exceed max_width.
74 |
75 | Args:
76 | text (str): The text to be wrapped.
77 | draw (ImageDraw.Draw): The draw object to calculate text size.
78 | max_width (int): The maximum width for the wrapped text.
79 | font (ImageFont): The font used for the text.
80 |
81 | Returns:
82 | List[str]: List of wrapped lines.
83 | """
84 | lines = []
85 | current_line = ""
86 | for letter in text:
87 | if draw.textbbox((0, 0), current_line + letter, font=font)[2] <= max_width:
88 | current_line += letter
89 | else:
90 | lines.append(current_line)
91 | current_line = letter
92 | lines.append(current_line)
93 | return lines
94 |
95 |
96 | class WandbSampleLogger(Callback):
97 | """
98 | Logger for logging samples to wandb. This logger is used to log images, text, and metrics to wandb.
99 |
100 | Args:
101 | log_batch_freq (int): The frequency of logging samples to wandb. Default is 100.
102 | """
103 |
104 | def __init__(self, log_batch_freq: int = 100):
105 | super().__init__()
106 | self.log_batch_freq = log_batch_freq
107 |
108 | def on_train_batch_end(
109 | self,
110 | trainer: Trainer,
111 | pl_module: TrainingPipeline,
112 | outputs: Dict[str, Any],
113 | batch: Any,
114 | batch_idx: int,
115 | ) -> None:
116 | self.log_samples(trainer, pl_module, outputs, batch, batch_idx, split="train")
117 | self._process_logs(trainer, outputs, split="train")
118 |
119 | def on_validation_batch_end(
120 | self,
121 | trainer: Trainer,
122 | pl_module: TrainingPipeline,
123 | outputs: Dict[str, Any],
124 | batch: Any,
125 | batch_idx: int,
126 | ) -> None:
127 | self.log_samples(trainer, pl_module, outputs, batch, batch_idx, split="val")
128 | self._process_logs(trainer, outputs, split="val")
129 |
130 | @rank_zero_only
131 | @torch.no_grad()
132 | def log_samples(
133 | self,
134 | trainer: Trainer,
135 | pl_module: TrainingPipeline,
136 | outputs: Dict[str, Any],
137 | batch: Dict[str, Any],
138 | batch_idx: int,
139 | split: str = "train",
140 | ) -> None:
141 | if hasattr(pl_module, "log_samples"):
142 | if batch_idx % self.log_batch_freq == 0:
143 | is_training = pl_module.training
144 | if is_training:
145 | pl_module.eval()
146 |
147 | logs = pl_module.log_samples(batch)
148 | logs = self._process_logs(trainer, logs, split=split)
149 |
150 | if is_training:
151 | pl_module.train()
152 | else:
153 | logging.warning(
154 | "log_samples method not found in LightningModule. Skipping sample logging."
155 | )
156 |
157 | @rank_zero_only
158 | def _process_logs(
159 | self, trainer, logs: Dict[str, Any], rescale=True, split="train"
160 | ) -> Dict[str, Any]:
161 | for key, value in logs.items():
162 | if isinstance(value, torch.Tensor):
163 | value = value.detach().cpu()
164 | if value.dim() == 4:
165 | images = value
166 | if rescale:
167 | images = (images + 1.0) / 2.0
168 | grid = make_grid(images, nrow=4)
169 | grid = grid.permute(1, 2, 0)
170 | grid = grid.mul(255).clamp(0, 255).to(torch.uint8)
171 | logs[key] = grid.numpy()
172 | trainer.logger.experiment.log(
173 | {f"{key}/{split}": [wandb.Image(Image.fromarray(logs[key]))]},
174 | step=trainer.global_step,
175 | )
176 |
177 | # Scalar tensor
178 | if value.dim() == 1 or value.dim() == 0:
179 | value = value.float().numpy()
180 | trainer.logger.experiment.log(
181 | {f"{key}/{split}": value}, step=trainer.global_step
182 | )
183 |
184 | # list of strings (e.g. text)
185 | if isinstance(value, list) and len(value) > 0:
186 | if isinstance(value[0], str):
187 | pil_image_texts = create_grid_texts(value)
188 | wandb_image = wandb.Image(pil_image_texts)
189 | trainer.logger.experiment.log(
190 | {f"{key}/{split}": [wandb_image]},
191 | step=trainer.global_step,
192 | )
193 |
194 | # dict of tensors (e.g. metrics)
195 | if isinstance(value, dict):
196 | for k, v in value.items():
197 | if isinstance(v, torch.Tensor):
198 | value[k] = v.detach().cpu().numpy()
199 | trainer.logger.experiment.log(
200 | {f"{key}/{split}": value}, step=trainer.global_step
201 | )
202 |
203 | if isinstance(value, int) or isinstance(value, float):
204 | trainer.logger.experiment.log(
205 | {f"{key}/{split}": value}, step=trainer.global_step
206 | )
207 |
208 | return logs
209 |
210 |
211 | class TensorBoardSampleLogger(Callback):
212 | """
213 | Logger for logging samples to tensorboard. This logger is used to log images, text, and metrics to tensorboard.
214 |
215 | Args:
216 | log_batch_freq (int): The frequency of logging samples to tensorboard. Default is 100.
217 | """
218 |
219 | def __init__(self, log_batch_freq: int = 100):
220 | super().__init__()
221 | self.log_batch_freq = log_batch_freq
222 |
223 | def on_train_batch_end(
224 | self,
225 | trainer: Trainer,
226 | pl_module: TrainingPipeline,
227 | outputs: Dict[str, Any],
228 | batch: Any,
229 | batch_idx: int,
230 | ) -> None:
231 | self.log_samples(trainer, pl_module, outputs, batch, batch_idx, split="train")
232 | self._process_logs(trainer, outputs, split="train")
233 |
234 | def on_validation_batch_end(
235 | self,
236 | trainer: Trainer,
237 | pl_module: TrainingPipeline,
238 | outputs: Dict[str, Any],
239 | batch: Any,
240 | batch_idx: int,
241 | ) -> None:
242 | self.log_samples(trainer, pl_module, outputs, batch, batch_idx, split="val")
243 | self._process_logs(trainer, outputs, split="val")
244 |
245 | @rank_zero_only
246 | @torch.no_grad()
247 | def log_samples(
248 | self,
249 | trainer: Trainer,
250 | pl_module: TrainingPipeline,
251 | outputs: Dict[str, Any],
252 | batch: Dict[str, Any],
253 | batch_idx: int,
254 | split: str = "train",
255 | ) -> None:
256 | if hasattr(pl_module, "log_samples"):
257 | if batch_idx % self.log_batch_freq == 0:
258 | is_training = pl_module.training
259 | if is_training:
260 | pl_module.eval()
261 |
262 | logs = pl_module.log_samples(batch)
263 | logs = self._process_logs(trainer, logs, split=split)
264 |
265 | if is_training:
266 | pl_module.train()
267 | else:
268 | logging.warning(
269 | "log_samples method not found in LightningModule. Skipping sample logging."
270 | )
271 |
272 | @rank_zero_only
273 | def _process_logs(
274 | self, trainer, logs: Dict[str, Any], rescale=True, split="train"
275 | ) -> Dict[str, Any]:
276 | for key, value in logs.items():
277 | if isinstance(value, torch.Tensor):
278 | value = value.detach().cpu()
279 | if value.dim() == 4:
280 | images = value
281 | if rescale:
282 | images = (images + 1.0) / 2.0
283 | grid = make_grid(images, nrow=4)
284 | # no permute: add_image expects CHW by default, which is what make_grid returns
285 | grid = grid.mul(255).clamp(0, 255).to(torch.uint8)
286 | logs[key] = grid.numpy()
287 | trainer.logger.experiment.add_image(
288 | f"{key}/{split}",
289 | logs[key],
290 | trainer.global_step,
291 | )
292 |
293 | # Scalar tensor
294 | if value.dim() == 1 or value.dim() == 0:
295 | value = value.float().numpy()
296 | trainer.logger.experiment.add_scalar(
297 | f"{key}/{split}", value, trainer.global_step
298 | )
299 |
300 | # list of strings (e.g. text)
301 | if isinstance(value, list) and len(value) > 0:
302 | if isinstance(value[0], str):
303 | pil_image_texts = create_grid_texts(value)
304 | trainer.logger.experiment.add_image(
305 | f"{key}/{split}",
306 | np.transpose(np.array(pil_image_texts), (2, 0, 1)),
307 | trainer.global_step,
308 | )
309 |
310 | # dict of tensors (e.g. metrics)
311 | if isinstance(value, dict):
312 | for k, v in value.items():
313 | if isinstance(v, torch.Tensor):
314 | value[k] = v.detach().cpu().numpy()
315 | trainer.logger.experiment.add_scalars(
316 | f"{key}/{split}", value, trainer.global_step
317 | )
318 |
319 | if isinstance(value, int) or isinstance(value, float):
320 | trainer.logger.experiment.add_scalar(
321 | f"{key}/{split}", value, trainer.global_step
322 | )
323 |
324 | return logs
325 |
--------------------------------------------------------------------------------
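`wrap_text` breaks lines character by character, so words can be split mid-word, and `create_grid_texts` pastes the i-th text image at column `i % n_cols`, row `i // n_cols`. The sketch below reproduces both with hypothetical helpers, using a fixed `char_width` in pixels per character as an assumed stand-in for the font measurement done by `draw.textbbox`.

```python
import math
from typing import List, Tuple


def wrap_fixed_width(text: str, max_width: int, char_width: int = 10) -> List[str]:
    """Same loop as wrap_text, with len(line) * char_width standing in for textbbox."""
    lines, current = [], ""
    for letter in text:
        if len(current + letter) * char_width <= max_width:
            current += letter
        else:
            lines.append(current)
            current = letter
    lines.append(current)
    return lines


def grid_position(index: int, n_cols: int, image_size: Tuple[int, int]) -> Tuple[int, int]:
    """Top-left corner where create_grid_texts pastes the index-th image."""
    return (index % n_cols * image_size[0], index // n_cols * image_size[1])


def grid_size(n_images: int, n_cols: int, image_size: Tuple[int, int]) -> Tuple[int, int]:
    """(width, height) of the grid canvas built by create_grid_texts."""
    n_rows = math.ceil(n_images / n_cols)
    return (n_cols * image_size[0], n_rows * image_size[1])


print(wrap_fixed_width("a relit scene", max_width=50))  # ['a rel', 'it sc', 'ene']
print(grid_size(6, 4, (512, 512)))  # (2048, 1024)
```

One quirk carries over from the original loop: a single character wider than `max_width` first emits an empty line.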
/src/lbm/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import logging
3 | import re
4 | import time
5 | from typing import Any, Dict
6 |
7 | import pytorch_lightning as pl
8 | import torch
9 |
10 | from ..models.base.base_model import BaseModel
11 | from .training_config import TrainingConfig
12 |
13 | logging.basicConfig(level=logging.INFO)
14 |
15 |
16 | class TrainingPipeline(pl.LightningModule):
17 | """
18 | Main Training Pipeline class
19 |
20 | Args:
21 |
22 | model (BaseModel): The model to train
23 | pipeline_config (TrainingConfig): The configuration for the training pipeline
24 | verbose (bool): Whether to print logs in the console. Default is False.
25 | """
26 |
27 | def __init__(
28 | self,
29 | model: BaseModel,
30 | pipeline_config: TrainingConfig,
31 | verbose: bool = False,
32 | **kwargs,
33 | ):
34 | super().__init__()
35 |
36 | self.model = model
37 | self.pipeline_config = pipeline_config
38 | self.log_samples_model_kwargs = pipeline_config.log_samples_model_kwargs
39 |
40 | # save hyperparameters.
41 | self.save_hyperparameters(ignore="model")
42 | self.save_hyperparameters({"model_config": model.config.to_dict()})
43 |
44 | # logger.
45 | self.verbose = verbose
46 |
47 | # setup logging.
48 | log_keys = pipeline_config.log_keys
49 |
50 | if isinstance(log_keys, str):
51 | log_keys = [log_keys]
52 |
53 | if log_keys is None:
54 | log_keys = []
55 |
56 | self.log_keys = log_keys
57 |
58 | def on_fit_start(self) -> None:
59 | self.model.on_fit_start(device=self.device)
60 | if self.global_rank == 0:
61 | self.timer = time.perf_counter()
62 |
63 | def on_train_batch_end(
64 | self, outputs: Dict[str, Any], batch: Any, batch_idx: int
65 | ) -> None:
66 | if self.global_rank == 0:
67 | logging.debug("on_train_batch_end")
68 | self.model.on_train_batch_end(batch)
69 |
70 | average_time_frequency = 10
71 | if self.global_rank == 0 and batch_idx % average_time_frequency == 0:
72 | delta = time.perf_counter() - self.timer
73 | logging.info(
74 | f"Average time per batch (batch {batch_idx}): {delta / (batch_idx + 1):.3f} s"
75 | )
76 |
77 | def configure_optimizers(self) -> torch.optim.Optimizer:
78 | """
79 | Setup optimizers and learning rate schedulers.
80 | """
81 | optimizers = []
82 | lr = self.pipeline_config.learning_rate
83 | param_list = []
84 | n_params = 0
85 | param_list_ = {"params": []}
86 |         trainable_params = self.pipeline_config.trainable_params
87 |         for name, param in self.model.named_parameters():
88 |             # add each parameter at most once, even if several regexes match its name
89 |             if any(re.match(regex, name) for regex in trainable_params):
90 |                 if param.requires_grad:
91 |                     param_list_["params"].append(param)
92 |                     n_params += param.numel()
93 |
94 | param_list.append(param_list_)
95 |
96 |         logging.info(f"Number of parameters passed to the optimizer: {n_params}")
97 |
98 | optimizer_cls = getattr(
99 | importlib.import_module("torch.optim"),
100 | self.pipeline_config.optimizer_name,
101 | )
102 | optimizer = optimizer_cls(
103 | param_list, lr=lr, **self.pipeline_config.optimizer_kwargs
104 | )
105 | optimizers.append(optimizer)
106 |
107 |         self.optims = optimizers
108 |
109 |         # freeze every parameter that does not match any trainable regex
110 | for name, param in self.model.named_parameters():
111 | set_grad_false = True
112 | for regex in self.pipeline_config.trainable_params:
113 | pattern = re.compile(regex)
114 | if re.match(pattern, name):
115 | if param.requires_grad:
116 | set_grad_false = False
117 | if set_grad_false:
118 | param.requires_grad = False
119 |
120 | num_trainable_params = sum(
121 | p.numel() for p in self.model.parameters() if p.requires_grad
122 | )
123 |
124 | logging.info(f"Number of trainable parameters: {num_trainable_params}")
125 |
126 | schedulers_config = self.configure_lr_schedulers()
127 |
128 | if schedulers_config is None:
129 | return optimizers
130 |
131 | return optimizers, [
132 | schedulers_config_ for schedulers_config_ in schedulers_config
133 | ]
134 |
135 | def configure_lr_schedulers(self):
136 | schedulers_config = []
137 | if self.pipeline_config.lr_scheduler_name is None:
138 | scheduler = None
139 | schedulers_config.append(scheduler)
140 | else:
141 | scheduler_cls = getattr(
142 | importlib.import_module("torch.optim.lr_scheduler"),
143 | self.pipeline_config.lr_scheduler_name,
144 | )
145 | scheduler = scheduler_cls(
146 | self.optims[0],
147 | **self.pipeline_config.lr_scheduler_kwargs,
148 | )
149 | lr_scheduler_config = {
150 | "scheduler": scheduler,
151 | "interval": self.pipeline_config.lr_scheduler_interval,
152 | "monitor": "val_loss",
153 | "frequency": self.pipeline_config.lr_scheduler_frequency,
154 | }
155 | schedulers_config.append(lr_scheduler_config)
156 |
157 | if all([scheduler is None for scheduler in schedulers_config]):
158 | return None
159 |
160 | return schedulers_config
161 |
162 | def training_step(self, train_batch: Dict[str, Any], batch_idx: int) -> dict:
163 | model_output = self.model(train_batch)
164 | loss = model_output["loss"]
165 | logging.info(f"loss: {loss}")
166 | return {
167 | "loss": loss,
168 | "batch_idx": batch_idx,
169 | }
170 |
171 | def validation_step(self, val_batch: Dict[str, Any], val_idx: int) -> dict:
172 | loss = self.model(val_batch, device=self.device)["loss"]
173 |
174 | metrics = self.model.compute_metrics(val_batch)
175 |
176 | return {"loss": loss, "metrics": metrics}
177 |
178 | def log_samples(self, batch: Dict[str, Any]):
179 | logging.debug("log_samples")
180 | logs = self.model.log_samples(
181 | batch,
182 | **self.log_samples_model_kwargs,
183 | )
184 |
185 |         if logs is None:
186 |             logs = {}
187 |         # number of samples logged by the model (0 if it logged nothing)
188 |         N = min((value.shape[0] for value in logs.values()), default=0)
189 |
190 |         # Log inputs, truncated to N samples when the model logged any
191 |         for key in self.log_keys:
192 |             if key not in batch:
193 |                 continue
194 |             if N > 0:
195 |                 logs[key] = batch[key][:N]
196 |             else:
197 |                 logs[key] = batch[key]
198 |
199 |         return logs
200 |
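The parameter selection in `configure_optimizers` above can be illustrated without a model. The sketch below is a hypothetical, pure-Python stand-in (`select_trainable` is not part of `lbm`) that applies the same rule to plain parameter names: a name is selected when `re.match` succeeds for at least one regex in `trainable_params`, and each name is selected only once. Because `re.match` anchors at the start of the string, a pattern has to match the beginning of the name.

```python
import re
from typing import Iterable, List


def select_trainable(names: Iterable[str], regexes: List[str]) -> List[str]:
    """Hypothetical helper: keep each name that matches at least one regex.

    re.match anchors at the start of the name, so a pattern has to match
    the name's prefix (prefix it with ".*" to match anywhere).
    """
    return [name for name in names if any(re.match(rx, name) for rx in regexes)]


# illustrative parameter names, not taken from a real LBM checkpoint
names = [
    "vae.encoder.conv_in.weight",
    "denoiser.conv_in.weight",
    "denoiser.conv_in.bias",
]

# the default ["./*"] matches any non-empty name, so everything is selected
print(select_trainable(names, ["./*"]))
# overlapping patterns still select each parameter only once
print(select_trainable(names, [r"denoiser\.", r"denoiser\.conv_in"]))
# "conv_in" selects nothing because no name starts with it
print(select_trainable(names, ["conv_in"]))
```

Selecting each name once matters because passing the same tensor twice to a `torch.optim` optimizer is reported as a duplicate parameter.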
--------------------------------------------------------------------------------
/src/lbm/trainer/training_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import field
2 | from typing import List, Literal, Optional, Union
3 |
4 | from pydantic.dataclasses import dataclass
5 |
6 | from ..config import BaseConfig
7 |
8 |
9 | @dataclass
10 | class TrainingConfig(BaseConfig):
11 | """
12 | Configuration for the training pipeline
13 |
14 | Args:
15 |
16 | experiment_id (str):
17 | The experiment id for the training run. If not provided, a random id will be generated.
18 | optimizer_name (str):
19 |             The optimizer to use. Default is "AdamW". Choices are "Adam", "AdamW", "Adadelta", "Adagrad", "RMSprop", "SGD".
20 |         optimizer_kwargs (Optional[dict]):
21 |             Keyword arguments passed to the optimizer constructor. Default is {}.
22 |         learning_rate (float):
23 |             The learning rate to use. Default is 1e-3.
24 |         lr_scheduler_name (Optional[str]):
25 |             The learning rate scheduler to use. Default is None (no scheduler). Choices are "StepLR", "CosineAnnealingLR",
26 |             "CosineAnnealingWarmRestarts", "ReduceLROnPlateau", "ExponentialLR".
27 |         lr_scheduler_kwargs (Optional[dict]):
28 |             Keyword arguments passed to the scheduler constructor. Default is {}.
29 |         lr_scheduler_interval (Optional[str]):
30 |             Whether the scheduler is stepped per "step" or per "epoch". Default is "step".
31 |         lr_scheduler_frequency (Optional[int]):
32 |             Number of intervals between two scheduler steps. Default is 1.
33 |         metrics (Optional[List[str]]):
34 |             The metrics to use. Default is None.
35 |         tracking_metrics (Optional[List[str]]):
36 |             The metrics to track. Default is None.
37 |         backup_every (int):
38 |             How often to back up the model. Default is 50.
39 |         trainable_params (List[str]):
40 |             Regexes matched with re.match against parameter names to select the parameters to train.
41 |             Default is ["./*"], which matches every parameter name (i.e. all parameters are trainable).
42 |         log_keys (Optional[Union[str, List[str]]]):
43 |             The keys to log when sampling from the model. Default is "txt".
44 | log_samples_model_kwargs (Dict[str, Any]):
45 | The kwargs for logging samples from the model. Default is {
46 | "max_samples": 4,
47 | "num_steps": 20,
48 | "input_shape": None,
49 | }
50 | """
51 |
52 | experiment_id: Optional[str] = None
53 | optimizer_name: Literal[
54 | "Adam", "AdamW", "Adadelta", "Adagrad", "RMSprop", "SGD"
55 | ] = field(default_factory=lambda: "AdamW")
56 | optimizer_kwargs: Optional[dict] = field(default_factory=lambda: {})
57 | learning_rate: float = field(default_factory=lambda: 1e-3)
58 | lr_scheduler_name: Optional[
59 | Literal[
60 | "StepLR",
61 | "CosineAnnealingLR",
62 | "CosineAnnealingWarmRestarts",
63 | "ReduceLROnPlateau",
64 | "ExponentialLR",
65 | None,
66 | ]
67 | ] = None
68 | lr_scheduler_kwargs: Optional[dict] = field(default_factory=lambda: {})
69 | lr_scheduler_interval: Optional[Literal["step", "epoch", None]] = "step"
70 | lr_scheduler_frequency: Optional[int] = 1
71 | metrics: Optional[List[str]] = None
72 | tracking_metrics: Optional[List[str]] = None
73 | backup_every: int = 50
74 | trainable_params: List[str] = field(default_factory=lambda: ["./*"])
75 | log_keys: Optional[Union[str, List[str]]] = "txt"
76 | log_samples_model_kwargs: Optional[dict] = field(
77 | default_factory=lambda: {
78 | "max_samples": 4,
79 | "num_steps": 20,
80 | "input_shape": None,
81 | }
82 | )
83 |
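`TrainingConfig` gives its dict and list fields a `default_factory` rather than a literal default. A minimal sketch with the standard-library `dataclasses` module (not `pydantic.dataclasses`, which `TrainingConfig` actually uses) shows the reason: each instance gets its own fresh container, and a bare `{}` default is refused when the class is defined. `SketchTrainingConfig` is a hypothetical, trimmed-down stand-in.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class SketchTrainingConfig:
    """Hypothetical, trimmed-down stand-in for TrainingConfig."""

    optimizer_name: str = "AdamW"
    optimizer_kwargs: Optional[dict] = field(default_factory=dict)
    trainable_params: List[str] = field(default_factory=lambda: ["./*"])


a = SketchTrainingConfig()
b = SketchTrainingConfig()
a.optimizer_kwargs["weight_decay"] = 0.01
a.trainable_params.append(r"denoiser\.")
print(b.optimizer_kwargs, b.trainable_params)  # b keeps its own defaults


def define_with_shared_default():
    """A bare mutable default is rejected when the class is created."""

    @dataclass
    class Shared:
        kwargs: dict = {}

    return Shared


try:
    define_with_shared_default()
except ValueError as err:
    print("rejected:", err)
```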
--------------------------------------------------------------------------------
/src/lbm/trainer/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import re
4 | import time
5 | from typing import Dict, List, Literal, Optional, Tuple
6 |
7 | import torch
8 |
9 |
10 | class StateDictAdapter:
11 | """
12 |     StateDictAdapter for adapting a checkpoint state dict to the tensor shapes of a model state dict.
13 |
14 |     This class iterates over all keys in the checkpoint state dict and filters them with a list of regex keys.
15 |     For each matching key, it reshapes the checkpoint tensor to the shape of the corresponding model tensor.
16 |     Along each dimension where the sizes differ, it either appends missing entries or truncates the tensor.
17 |     Missing entries are filled with a strategy: zeros, or normal random values with the mean and std of the checkpoint tensor.
18 |
19 | Example:
20 |
21 | ```
22 | adapter = StateDictAdapter()
23 | new_state_dict = adapter(
24 | model_state_dict=model.state_dict(),
25 | checkpoint_state_dict=state_dict,
26 | regex_keys=[
27 | r"class_embedding.linear_1.weight",
28 | r"conv_in.weight",
29 | r"(down_blocks|up_blocks)\.\d+\.attentions\.\d+\.transformer_blocks\.\d+\.attn\d+\.(to_k|to_v)\.weight",
30 | r"mid_block\.attentions\.\d+\.transformer_blocks\.\d+\.attn\d+\.(to_k|to_v)\.weight"
31 | ]
32 | )
33 | ```
34 |
35 | Args:
36 | model_state_dict (Dict[str, torch.Tensor]): The model state dict.
37 | checkpoint_state_dict (Dict[str, torch.Tensor]): The checkpoint state dict.
38 |         regex_keys (Optional[List[str]]): Regexes selecting the checkpoint keys to adapt. Defaults to None.
39 |             If None, every key of the model state dict is used as a regex, so each checkpoint key is
40 |             matched against every model key; passing a short list of regexes is therefore much faster.
41 | strategy (Literal["zeros", "normal"], optional): The strategy to fill the missing blocks. Defaults to "normal".
42 |
43 | """
44 |
45 | def _create_block(
46 | self,
47 | shape: List[int],
48 | strategy: Literal["zeros", "normal"],
49 | input: torch.Tensor = None,
50 | ):
51 | if strategy == "zeros":
52 | return torch.zeros(shape)
53 | elif strategy == "normal":
54 | if input is not None:
55 | mean = input.mean().item()
56 | std = input.std().item()
57 | return torch.randn(shape) * std + mean
58 | else:
59 | return torch.randn(shape)
60 | else:
61 | raise ValueError(f"Unknown strategy {strategy}")
62 |
63 | def __call__(
64 | self,
65 | model_state_dict: Dict[str, torch.Tensor],
66 | checkpoint_state_dict: Dict[str, torch.Tensor],
67 | regex_keys: Optional[List[str]] = None,
68 | strategy: Literal["zeros", "normal"] = "normal",
69 | ):
70 | start = time.perf_counter()
71 | # if no regex keys are provided, we use all keys in the model state dict
72 | if regex_keys is None:
73 | regex_keys = list(model_state_dict.keys())
74 |
75 | # iterate over all keys in the checkpoint state dict
76 | for checkpoint_key in list(checkpoint_state_dict.keys()):
77 | # iterate over all regex keys
78 | for regex_key in regex_keys:
79 |                 if checkpoint_key in model_state_dict and re.match(regex_key, checkpoint_key):
80 | dst_shape = model_state_dict[checkpoint_key].shape
81 | src_shape = checkpoint_state_dict[checkpoint_key].shape
82 |
83 | ## Sizes adapter
84 | # if length of shapes are different, we need to unsqueeze or squeeze the tensor
85 | if len(dst_shape) != len(src_shape):
86 | # in the case [a] vs [a, b] -> unsqueeze [a, 1]
87 | if len(src_shape) == 1:
88 | checkpoint_state_dict[checkpoint_key] = (
89 | checkpoint_state_dict[checkpoint_key].unsqueeze(1)
90 | )
91 | logging.info(
92 | f"Unsqueeze {checkpoint_key}: {src_shape} -> {checkpoint_state_dict[checkpoint_key].shape}"
93 | )
94 | # in the case [a, b] vs [a] -> squeeze [a]
95 | elif len(dst_shape) == 1:
96 | checkpoint_state_dict[checkpoint_key] = (
97 | checkpoint_state_dict[checkpoint_key][:, 0]
98 | )
99 | logging.info(
100 | f"Squeeze {checkpoint_key}: {src_shape} -> {checkpoint_state_dict[checkpoint_key].shape}"
101 | )
102 | # in the other cases, raise an error
103 | else:
104 | raise ValueError(
105 | f"Shapes of {checkpoint_key} are different: {dst_shape} != {src_shape}"
106 | )
107 |
108 | # update the shapes
109 | dst_shape = model_state_dict[checkpoint_key].shape
110 | src_shape = checkpoint_state_dict[checkpoint_key].shape
111 | assert len(dst_shape) == len(
112 | src_shape
113 | ), f"Shapes of {checkpoint_key} are different: {dst_shape} != {src_shape}"
114 |
115 | ## Shapes adapter
116 | # modify the checkpoint state dict only if the shapes are different
117 | if dst_shape != src_shape:
118 | # create a copy of the tensor
119 | tmp = torch.clone(checkpoint_state_dict[checkpoint_key])
120 |
121 | # iterate over all dimensions
122 | for i in range(len(dst_shape)):
123 | if dst_shape[i] != src_shape[i]:
124 | diff = dst_shape[i] - src_shape[i]
125 |
126 | # if the difference is greater than 0, we need to add missing blocks
127 | if diff > 0:
128 | missing_shape = list(tmp.shape)
129 | missing_shape[i] = diff
130 | missing = self._create_block(
131 | shape=missing_shape,
132 | strategy=strategy,
133 | input=tmp,
134 | )
135 |                                     tmp = torch.cat((tmp, missing.to(tmp)), dim=i)
136 | logging.info(
137 | f"Adapting {checkpoint_key} with strategy:{strategy} from shape {src_shape} to {dst_shape}"
138 | )
139 | # if the difference is less than 0, we need to cut the block
140 | else:
141 | tmp = tmp.narrow(i, 0, dst_shape[i])
142 | logging.info(
143 | f"Adapting {checkpoint_key} by narrowing from shape {src_shape} to {dst_shape}"
144 | )
145 |
146 | checkpoint_state_dict[checkpoint_key] = tmp
147 | end = time.perf_counter()
148 | logging.info(f"StateDictAdapter took {end-start:.2f} seconds")
149 | return checkpoint_state_dict
150 |
151 |
152 | class StateDictRenamer:
153 | """
154 | StateDictRenamer for renaming keys in a checkpoint state dict.
155 | This class will iterate over all keys in the checkpoint state dict and rename them according to a rename dict.
156 |
157 | Example:
158 |
159 | ```
160 | renamer = StateDictRenamer()
161 | new_state_dict = renamer(
162 | checkpoint_state_dict=state_dict,
163 | rename_dict={
164 | "add_embedding.linear_1.weight": "class_embedding.linear_1.weight",
165 | "add_embedding.linear_1.bias": "class_embedding.linear_1.bias",
166 | "add_embedding.linear_2.weight": "class_embedding.linear_2.weight",
167 | "add_embedding.linear_2.bias": "class_embedding.linear_2.bias",
168 | }
169 | )
170 | ```
171 |
172 | Args:
173 |
174 | checkpoint_state_dict (Dict[str, torch.Tensor]): The checkpoint state dict.
175 | rename_dict (Dict[str, str]): The dictionary mapping the old keys to new keys
176 | """
177 |
178 | def __call__(
179 | self,
180 | checkpoint_state_dict: Dict[str, torch.Tensor],
181 | rename_dict: Dict[str, str],
182 | ) -> Dict[str, torch.Tensor]:
183 | for old_key, new_key in rename_dict.items():
184 | if old_key not in checkpoint_state_dict:
185 | logging.warning(f"Key {old_key} not found in checkpoint state dict")
186 | continue
187 | else:
188 | assert (
189 | new_key not in checkpoint_state_dict
190 | ), f"Key {new_key} already exists in checkpoint state dict"
191 | checkpoint_state_dict[new_key] = checkpoint_state_dict.pop(old_key)
192 | logging.info(f"Renaming {old_key} to {new_key}")
193 | return checkpoint_state_dict
194 |
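The shape logic of `StateDictAdapter.__call__` can be traced without tensors. The hypothetical `adaptation_plan` helper below (not part of `lbm`) follows the same two phases on plain shape tuples: first the rank adapter ([a] becomes [a, 1]; a source is reduced with `[:, 0]` when the target is 1-D), then, for every dimension whose size differs, padding (appending entries that the chosen strategy would fill) or narrowing (keeping the leading entries). It only reports the steps and builds no tensors.

```python
from typing import List, Optional, Sequence, Tuple

# (operation, dimension, amount): amount is the number of appended entries
# for "pad", the kept size for "narrow", and None for rank changes
Step = Tuple[str, int, Optional[int]]


def adaptation_plan(src_shape: Sequence[int], dst_shape: Sequence[int]) -> List[Step]:
    """Hypothetical helper listing the reshaping steps StateDictAdapter would take."""
    src, dst = list(src_shape), list(dst_shape)
    steps: List[Step] = []

    # rank adapter: [a] -> [a, 1] (unsqueeze), or [a, b, ...] -> [a, ...] ([:, 0])
    if len(src) != len(dst):
        if len(src) == 1:
            src = src + [1]
            steps.append(("unsqueeze", 1, None))
        elif len(dst) == 1:
            src = [src[0]] + src[2:]
            steps.append(("squeeze", 1, None))
        else:
            raise ValueError(f"cannot adapt {tuple(src_shape)} to {tuple(dst_shape)}")
    if len(src) != len(dst):
        raise ValueError(f"ranks still differ: {src} vs {dst}")

    # shape adapter: pad or narrow each dimension whose size differs
    for dim, (s, d) in enumerate(zip(src, dst)):
        if d > s:
            steps.append(("pad", dim, d - s))
        elif d < s:
            steps.append(("narrow", dim, d))
    return steps


# e.g. a conv weight receiving 4 extra input channels (illustrative shapes)
print(adaptation_plan((320, 4, 3, 3), (320, 8, 3, 3)))
# a 1-D tensor reused where the model expects a 2-D one
print(adaptation_plan((1280,), (1280, 2)))
```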
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | # Tests
2 |
3 | ## Setup
4 |
5 | ```shell
6 | pip3 install -r requirements.txt
7 | ```
8 |
9 | ## Run the tests
10 |
11 | ```shell
12 | python3 -m pytest .
13 | ```
--------------------------------------------------------------------------------
/tests/requirements.txt:
--------------------------------------------------------------------------------
1 | pytest
--------------------------------------------------------------------------------
/tests/test_dataset/test_filters.py:
--------------------------------------------------------------------------------
1 | from lbm.data.filters import FilterWrapper, KeyFilter, KeyFilterConfig
2 |
3 |
4 | class TestKeyFilter:
5 | def test_key_filter(self):
6 | filter = KeyFilter(KeyFilterConfig(keys=["a", "b"]))
7 | assert filter({"a": 1, "b": 2, "c": 3})
8 | assert not filter({"a": 1})
9 | assert not filter({"b": 2})
10 | assert not filter({"c": 3})
11 |
12 | def test_key_filter_single_key(self):
13 | filter = KeyFilter(KeyFilterConfig(keys="a"))
14 | assert filter({"a": 1, "b": 2})
15 | assert not filter({"b": 2})
16 |
17 |
18 | class TestFilterWrapper:
19 |     def test_filter_wrapper_disjoint_keys(self):
20 |         filter = FilterWrapper(
21 |             [
22 |                 KeyFilter(KeyFilterConfig(keys=["a", "b"])),
23 |                 KeyFilter(KeyFilterConfig(keys="c")),
24 |             ]
25 |         )
26 |         assert filter({"a": 1, "b": 2, "c": 3})
27 | assert not filter({"a": 1})
28 | assert not filter({"b": 2})
29 | assert not filter({"c": 3})
30 |
31 | def test_filter_wrapper(self):
32 | filter = FilterWrapper(
33 | [
34 | KeyFilter(KeyFilterConfig(keys=["a", "b"])),
35 | KeyFilter(KeyFilterConfig(keys=["a", "c"])),
36 | ]
37 | )
38 | assert filter({"a": 1, "b": 2, "c": 3})
39 | assert not filter({"a": 1})
40 | assert not filter({"b": 2})
41 | assert not filter({"c": 3})
42 |
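These tests encode two rules: `KeyFilter` accepts a sample only when every configured key is present, and `FilterWrapper` accepts it only when every wrapped filter does. The sketch below restates those rules as plain predicates; `key_filter` and `all_of` are hypothetical names, and the all-filters-must-pass behaviour of `FilterWrapper` is inferred from the tests rather than from its source.

```python
from typing import Any, Callable, Dict, Iterable, List, Union

Sample = Dict[str, Any]
Predicate = Callable[[Sample], bool]


def key_filter(keys: Union[str, List[str]]) -> Predicate:
    """Hypothetical stand-in for KeyFilter: keep samples containing every key."""
    required = [keys] if isinstance(keys, str) else list(keys)
    return lambda sample: all(key in sample for key in required)


def all_of(filters: Iterable[Predicate]) -> Predicate:
    """Hypothetical stand-in for FilterWrapper: every filter must accept the sample."""
    filters = list(filters)
    return lambda sample: all(f(sample) for f in filters)


keep = all_of([key_filter(["a", "b"]), key_filter("c")])
print(keep({"a": 1, "b": 2, "c": 3}))  # True: both filters accept
print(keep({"a": 1, "b": 2}))  # False: the second filter needs "c"
```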
--------------------------------------------------------------------------------
/tests/test_dataset/test_mappers.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import pytest
4 | import torch
5 | from PIL import Image
6 |
7 | from lbm.data.mappers import (
8 | KeyRenameMapper,
9 | KeyRenameMapperConfig,
10 | MapperWrapper,
11 | RescaleMapper,
12 | RescaleMapperConfig,
13 | TorchvisionMapper,
14 | TorchvisionMapperConfig,
15 | )
16 |
17 |
18 | class TestKeyRenameMapper:
19 | @pytest.fixture()
20 | def dummy_batch(self):
21 | return {"image": 1, "text": 2, "label": "dummy_label"}
22 |
23 | @pytest.fixture()
24 | def mapper(self):
25 | return KeyRenameMapper(
26 | KeyRenameMapperConfig(
27 | key_map={"image": "image_tensor", "text": "text_tensor"}
28 | )
29 | )
30 |
31 | def test_mapper(self, mapper, dummy_batch):
32 | output_data = mapper(dummy_batch)
33 | assert output_data["image_tensor"] == 1
34 | assert output_data["text_tensor"] == 2
35 | assert output_data["label"] == "dummy_label"
36 | assert "image" not in output_data
37 | assert "text" not in output_data
38 |
39 |
40 | class TestKeyRenameMapperWithCondition:
41 | @pytest.fixture(params=[1, 2])
42 | def dummy_batch(self, request):
43 | return {"image": 1, "text": 2, "label": request.param}
44 |
45 | @pytest.fixture(params=[{"image": "image_not_met", "text": "text_not_met"}, None])
46 | def else_key_map(self, request):
47 | return request.param
48 |
49 | @pytest.fixture()
50 | def mapper(self, else_key_map):
51 | return KeyRenameMapper(
52 | KeyRenameMapperConfig(
53 | key_map={"image": "image_tensor", "text": "text_tensor"},
54 | condition_key="label",
55 | condition_fn=lambda x: x == 1,
56 | else_key_map=else_key_map,
57 | )
58 | )
59 |
60 | def test_mapper(self, mapper, dummy_batch, else_key_map):
61 | output_data = mapper(dummy_batch)
62 | if dummy_batch["label"] == 1:
63 | assert output_data["image_tensor"] == 1
64 | assert output_data["text_tensor"] == 2
65 | assert output_data["label"] == 1
66 | assert "image" not in output_data
67 | assert "text" not in output_data
68 | elif else_key_map is not None:
69 | assert output_data["image_not_met"] == 1
70 | assert output_data["text_not_met"] == 2
71 | assert output_data["label"] == 2
72 | assert "image" not in output_data
73 | assert "text" not in output_data
74 | else:
75 | assert output_data["image"] == 1
76 | assert output_data["text"] == 2
77 | assert output_data["label"] == 2
78 | assert "image_tensor" not in output_data
79 | assert "text_tensor" not in output_data
80 |
81 |
82 | class TestMapperWrapper:
83 | @pytest.fixture()
84 | def dummy_batch(self):
85 | return {"image": 1, "text": 2, "label": "dummy_label"}
86 |
87 | @pytest.fixture()
88 | def mapper(self):
89 | return MapperWrapper(
90 | mappers=[
91 | KeyRenameMapper(
92 | KeyRenameMapperConfig(
93 | key_map={"image": "image_tensor", "text": "text_tensor"}
94 | )
95 | ),
96 | KeyRenameMapper(
97 | KeyRenameMapperConfig(
98 | key_map={
99 | "image_tensor": "image_array",
100 | "text_tensor": "text_array",
101 | }
102 | )
103 | ),
104 | ]
105 | )
106 |
107 | def test_mapper(self, mapper, dummy_batch):
108 | output_data = mapper(dummy_batch)
109 | assert output_data["image_array"] == 1
110 | assert output_data["text_array"] == 2
111 | assert output_data["label"] == "dummy_label"
112 | assert "image" not in output_data
113 | assert "text" not in output_data
114 | assert "image_tensor" not in output_data
115 | assert "text_tensor" not in output_data
116 |
117 |
118 | class TestTorchvisionMapper:
119 | @pytest.fixture()
120 | def dummy_batch(self):
121 | return {
122 | "image": torch.randn(
123 | 3,
124 | 256,
125 | 256,
126 | ),
127 | "text": 2,
128 | "label": "dummy_label",
129 | }
130 |
131 | @pytest.fixture()
132 | def mapper(self):
133 | return TorchvisionMapper(
134 | TorchvisionMapperConfig(
135 | key="image",
136 | transforms=["CenterCrop", "ToPILImage"],
137 | transforms_kwargs=[{"size": 224}, {}],
138 | )
139 | )
140 |
141 | def test_mapper(self, mapper, dummy_batch):
142 | output_data = mapper(dummy_batch)
143 | assert output_data["image"].size == (224, 224)
144 | assert isinstance(output_data["image"], Image.Image)
145 | assert output_data["text"] == 2
146 | assert output_data["label"] == "dummy_label"
147 |
148 | @pytest.fixture()
149 | def mapper_with_output_key(self):
150 | return TorchvisionMapper(
151 | TorchvisionMapperConfig(
152 | key="image",
153 | output_key="image_transformed",
154 | transforms=["CenterCrop", "ToPILImage"],
155 | transforms_kwargs=[{"size": 224}, {}],
156 | )
157 | )
158 |
159 |     def test_mapper_with_output_key(self, mapper_with_output_key, dummy_batch):
160 | output_data = mapper_with_output_key(dummy_batch)
161 | assert output_data["image_transformed"].size == (224, 224)
162 | assert isinstance(output_data["image_transformed"], Image.Image)
163 | assert isinstance(output_data["image"], torch.Tensor)
164 | assert output_data["image"].size() == (3, 256, 256)
165 | assert output_data["text"] == 2
166 | assert output_data["label"] == "dummy_label"
167 |
168 |
169 | class TestRescaleMapper:
170 | @pytest.fixture()
171 | def dummy_batch(self):
172 | return {
173 | "image": torch.rand(
174 | 3,
175 | 256,
176 | 256,
177 | ),
178 | "text": 2,
179 | "label": "dummy_label",
180 | }
181 |
182 | @pytest.fixture()
183 | def mapper(self):
184 | return RescaleMapper(
185 | RescaleMapperConfig(
186 | input_key="image",
187 | output_key="image",
188 | )
189 | )
190 |
191 | def test_mapper(self, mapper, dummy_batch):
192 | output_data = mapper(dummy_batch)
193 | assert torch.all(output_data["image"] <= 1)
194 | assert torch.all(output_data["image"] >= -1)
195 | assert output_data["text"] == 2
196 | assert output_data["label"] == "dummy_label"
197 |
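`TestKeyRenameMapperWithCondition` pins down how a conditional rename behaves: when `condition_fn(sample[condition_key])` holds, `key_map` is applied; otherwise `else_key_map` is applied, or nothing is renamed if it is `None`. The sketch below reproduces that behaviour in plain Python. `rename_keys` is hypothetical and not the `KeyRenameMapper` implementation; in particular it returns a copy, which may differ from how the real mapper treats its input.

```python
from typing import Any, Callable, Dict, Optional


def rename_keys(
    sample: Dict[str, Any],
    key_map: Dict[str, str],
    condition_key: Optional[str] = None,
    condition_fn: Optional[Callable[[Any], bool]] = None,
    else_key_map: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """Hypothetical stand-in for KeyRenameMapper, following the tests above."""
    if condition_key is not None and condition_fn is not None:
        if not condition_fn(sample[condition_key]):
            # condition not met: use the fallback map, or rename nothing
            key_map = else_key_map or {}
    out = dict(sample)  # leave the input sample untouched
    for old_key, new_key in key_map.items():
        if old_key in out:
            out[new_key] = out.pop(old_key)
    return out


def is_one(value: Any) -> bool:
    return value == 1


key_map = {"image": "image_tensor", "text": "text_tensor"}
else_key_map = {"image": "image_not_met", "text": "text_not_met"}

print(rename_keys({"image": 1, "text": 2, "label": 1}, key_map, "label", is_one, else_key_map))
print(rename_keys({"image": 1, "text": 2, "label": 2}, key_map, "label", is_one, else_key_map))
print(rename_keys({"image": 1, "text": 2, "label": 2}, key_map, "label", is_one))
```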
--------------------------------------------------------------------------------
/tests/test_lbm/test_lbm.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import pytest
4 | import torch
5 | import torch.nn as nn
6 | from diffusers import FlowMatchEulerDiscreteScheduler
7 |
8 | from lbm.models.embedders import ConditionerWrapper
9 | from lbm.models.lbm import LBMConfig, LBMModel
10 | from lbm.models.unets import DiffusersUNet2DCondWrapper
11 | from lbm.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig
12 |
13 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14 |
15 |
16 | class TestLBM:
17 | @pytest.fixture()
18 | def denoiser(self):
19 | return DiffusersUNet2DCondWrapper(
20 | in_channels=4, # VAE channels
21 | out_channels=4, # VAE channels
22 | up_block_types=["CrossAttnUpBlock2D"],
23 | down_block_types=[
24 | "CrossAttnDownBlock2D",
25 | ],
26 | cross_attention_dim=[320],
27 | block_out_channels=[320],
28 | transformer_layers_per_block=[1],
29 | attention_head_dim=[5],
30 | norm_num_groups=32,
31 | )
32 |
33 | @pytest.fixture()
34 | def conditioner(self):
35 | return ConditionerWrapper([])
36 |
37 | @pytest.fixture()
38 | def vae(self):
39 | return AutoencoderKLDiffusers(AutoencoderKLDiffusersConfig())
40 |
41 | @pytest.fixture()
42 | def sampling_noise_scheduler(self):
43 | return FlowMatchEulerDiscreteScheduler()
44 |
45 | @pytest.fixture()
46 | def training_noise_scheduler(self):
47 | return FlowMatchEulerDiscreteScheduler()
48 |
49 | @pytest.fixture()
50 | def model_config(self):
51 | return LBMConfig(
52 | source_key="source_image",
53 | target_key="target_image",
54 | )
55 |
56 | @pytest.fixture()
57 | def model_input(self):
58 | return {
59 | "source_image": torch.randn(2, 3, 256, 256).to(DEVICE),
60 | "target_image": torch.randn(2, 3, 256, 256).to(DEVICE),
61 | }
62 |
63 | @pytest.fixture()
64 | def model(
65 | self,
66 | model_config,
67 | denoiser,
68 | vae,
69 | sampling_noise_scheduler,
70 | training_noise_scheduler,
71 | conditioner,
72 | ):
73 | return LBMModel(
74 | config=model_config,
75 | denoiser=denoiser,
76 | vae=vae,
77 | sampling_noise_scheduler=sampling_noise_scheduler,
78 | training_noise_scheduler=training_noise_scheduler,
79 | conditioner=conditioner,
80 | ).to(DEVICE)
81 |
82 | @torch.no_grad()
83 | def test_model_forward(self, model, model_input):
84 | model_output = model(
85 | model_input,
86 | )
87 | assert model_output["loss"] > 0.0
88 |
89 | def test_optimizers(self, model, model_input):
90 | optimizer = torch.optim.Adam(model.denoiser.parameters(), lr=1e-4)
91 |
92 | model.train()
93 | model_init = deepcopy(model)
94 | optimizer.zero_grad()
95 | loss = model(model_input)["loss"]
96 | loss.backward()
97 | optimizer.step()
98 | assert not torch.equal(
99 | torch.cat([p.flatten() for p in model.denoiser.parameters()]),
100 | torch.cat([p.flatten() for p in model_init.denoiser.parameters()]),
101 | )
102 |
--------------------------------------------------------------------------------
/tests/test_unets/test_unets_wrappers.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from lbm.models.unets import DiffusersUNet2DCondWrapper, DiffusersUNet2DWrapper
5 |
6 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
7 |
8 |
9 | class TestDiffusersUNet2DWrapper:
10 | # simulates class conditioning
11 | @pytest.fixture(params=[None, torch.randint(256, (2,)).to(DEVICE)])
12 | def conditioning(self, request):
13 | if request.param is not None:
14 | return {"cond": {"vector": request.param}}
15 | return None
16 |
17 | # simulates a latent sample
18 | @pytest.fixture()
19 | def sample(self):
20 | return torch.rand(2, 6, 32, 32).to(DEVICE)
21 |
22 | # simulates a timestep
23 | @pytest.fixture(
24 | params=[10.0, torch.randint(1000, (2,), dtype=torch.float).to(DEVICE), 3]
25 | )
26 | def timesteps(self, request):
27 | return request.param
28 |
29 | def test_unet2d_wrapper(self, sample, timesteps, conditioning):
30 | unet = DiffusersUNet2DWrapper(
31 | sample_size=sample.shape[2:],
32 | in_channels=sample.shape[1],
33 | out_channels=3,
34 | num_class_embeds=256 if conditioning else None,
35 | ).to(DEVICE)
36 | output = unet(sample, timesteps, conditioning)
37 | assert output.shape == (
38 | sample.shape[0],
39 | 3,
40 | sample.shape[2],
41 | sample.shape[3],
42 | )
43 |
44 |
45 | class TestDiffusersUNet2DCondWrapper:
46 | # simulates class conditioning
47 | @pytest.fixture(params=[None, torch.randn(2, 256).to(DEVICE)])
48 | def vector_conditioning(self, request):
49 | if request.param is not None:
50 | return {"vector": request.param}
51 | return None
52 |
53 |     # simulates crossattn conditioning (always needed for the conditional UNet2D, see diffusers/models/unet.py)
54 | @pytest.fixture()
55 | def crossattn_conditioning(self):
56 | return {"crossattn": torch.randn(2, 12, 123).to(DEVICE)}
57 |
58 | # simulates concat conditioning
59 | @pytest.fixture(params=[None, torch.randn(2, 2, 32, 32).to(DEVICE)])
60 | def concat_conditioning(self, request):
61 | if request.param is not None:
62 | return {"concat": request.param}
63 | return None
64 |
65 | @pytest.fixture()
66 | def conditioning(
67 | self, vector_conditioning, crossattn_conditioning, concat_conditioning
68 | ):
69 | cond = dict(cond=crossattn_conditioning)
70 | if vector_conditioning is not None:
71 | cond["cond"].update(vector_conditioning)
72 | if concat_conditioning is not None:
73 | cond["cond"].update(concat_conditioning)
74 | return cond
75 |
76 | # simulates a latent sample
77 | @pytest.fixture()
78 | def sample(self):
79 | return torch.rand(2, 6, 32, 32).to(DEVICE)
80 |
81 | # simulates a timestep
82 | @pytest.fixture(
83 | params=[10.0, torch.randint(1000, (2,), dtype=torch.float).to(DEVICE), 3]
84 | )
85 | def timesteps(self, request):
86 | return request.param
87 |
88 | def test_unet2d_cond_wrapper(self, sample, timesteps, conditioning):
89 | # for concat
90 | in_channels = (
91 | sample.shape[1] + conditioning["cond"]["concat"].shape[1]
92 | if conditioning["cond"].get("concat", None) is not None
93 | else sample.shape[1]
94 | )
95 |
96 | # for vector
97 | class_embed_type = (
98 | "projection" if conditioning["cond"].get("vector") is not None else None
99 | )
100 | projection_class_embeddings_input_dim = (
101 | conditioning["cond"]["vector"].shape[1]
102 | if conditioning["cond"].get("vector") is not None
103 | else None
104 | )
105 |
106 | # for crossattn
107 | cross_attention_dim = (
108 | conditioning["cond"]["crossattn"].shape[2]
109 | if conditioning["cond"].get("crossattn") is not None
110 | else 1280
111 | )
112 |
113 | unet = DiffusersUNet2DCondWrapper(
114 | sample_size=sample.shape[2:],
115 | in_channels=in_channels,
116 | out_channels=3,
117 | class_embed_type=class_embed_type,
118 | projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
119 | cross_attention_dim=cross_attention_dim,
120 | ).to(DEVICE)
121 | output = unet(sample, timesteps, conditioning)
122 | assert output.shape == (
123 | sample.shape[0],
124 | 3,
125 | sample.shape[2],
126 | sample.shape[3],
127 | )
128 |
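`test_unet2d_cond_wrapper` derives the UNet constructor arguments from the conditioning shapes. The derivation, restated as a hypothetical helper (`cond_unet_kwargs` is not part of `lbm` or diffusers), is: concat conditioning adds its channels to `in_channels`, a vector condition switches on a `"projection"` class embedding sized by its second dimension, and the last dimension of the cross-attention tensor sets `cross_attention_dim` (1280 is the test's fallback). The channel-axis concatenation is inferred from the test's arithmetic.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]


def cond_unet_kwargs(sample_shape: Shape, cond_shapes: Dict[str, Optional[Shape]]) -> dict:
    """Hypothetical helper reproducing the constructor arguments derived in the test."""
    concat = cond_shapes.get("concat")
    vector = cond_shapes.get("vector")
    crossattn = cond_shapes.get("crossattn")
    return {
        "sample_size": sample_shape[2:],
        # concat conditioning is stacked with the sample along the channel axis
        "in_channels": sample_shape[1] + (concat[1] if concat else 0),
        # a vector condition is fed through a projection class embedding
        "class_embed_type": "projection" if vector else None,
        "projection_class_embeddings_input_dim": vector[1] if vector else None,
        # cross-attention width is the last dimension of the crossattn tensor
        "cross_attention_dim": crossattn[2] if crossattn else 1280,
    }


kwargs = cond_unet_kwargs(
    (2, 6, 32, 32),
    {"concat": (2, 2, 32, 32), "vector": (2, 256), "crossattn": (2, 12, 123)},
)
print(kwargs)
```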
--------------------------------------------------------------------------------
/tests/test_vaes/test_autoencoder.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from lbm.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig
5 |
6 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
7 |
8 |
9 | class TestAutoencoderKLDiffusers:
10 | @pytest.fixture(
11 | params=[
12 | dict(),
13 | dict(
14 | version="stabilityai/stable-diffusion-xl-base-1.0",
15 | subfolder="vae",
16 | ),
17 | ]
18 | )
19 | def model_config(self, request):
20 | return AutoencoderKLDiffusersConfig(
21 | **request.param, tiling_size=(16, 16), tiling_overlap=(8, 8), batch_size=1
22 | )
23 |
24 | @pytest.fixture()
25 | def model(self, model_config):
26 | return AutoencoderKLDiffusers(model_config).to(DEVICE)
27 |
28 | def test_model_initialization(self, model, model_config):
29 | assert model.config == model_config
30 |
31 | def test_encode(self, model):
32 | x = torch.randn(2, 3, 32, 32).to(DEVICE)
33 | z = model.encode(x)
34 | assert z.shape == (2, 4, 4, 4)
35 |
36 | def test_decode(self, model):
37 | z = torch.randn(2, 4, 4, 4).to(DEVICE)
38 | x = model.decode(z)
39 | assert x.shape == (2, 3, 32, 32)
40 |
41 | def test_decode_tiling(self, model):
42 | z = torch.randn(2, 4, 32, 32).to(DEVICE)
43 | x = model.decode(z)
44 | assert x.shape == (2, 3, 256, 256)
45 |
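The expected shapes in these tests follow from a VAE with 4 latent channels and a spatial factor of 8, the values the assertions imply (32 becomes 4 on encoding, and a 32x32 latent decodes to 256x256 even when tiling is enabled). The helpers below are a hypothetical sketch of that arithmetic under those assumptions; this sketch only accepts image sizes divisible by the factor.

```python
from typing import Tuple

Shape4D = Tuple[int, int, int, int]


def latent_shape(image_shape: Shape4D, latent_channels: int = 4, factor: int = 8) -> Shape4D:
    """Shape of encode(x) for a (B, C, H, W) image, assuming a fixed downsampling factor."""
    batch, _, height, width = image_shape
    if height % factor or width % factor:
        raise ValueError(f"this sketch needs H and W divisible by {factor}, got {height}x{width}")
    return (batch, latent_channels, height // factor, width // factor)


def decoded_shape(latent: Shape4D, image_channels: int = 3, factor: int = 8) -> Shape4D:
    """Shape of decode(z) for a (B, C, h, w) latent; tiling is expected not to change it."""
    batch, _, height, width = latent
    return (batch, image_channels, height * factor, width * factor)


print(latent_shape((2, 3, 32, 32)))
print(decoded_shape((2, 4, 32, 32)))
```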
--------------------------------------------------------------------------------