├── .gitignore
├── LICENCE.txt
├── README.md
├── assets
│   ├── LBM.jpg
│   ├── depth_normal.jpg
│   ├── object_removal.jpg
│   ├── relight.gif
│   ├── relight.jpg
│   ├── relight_2.gif
│   ├── shadow_control.gif
│   └── upscaler.jpg
├── examples
│   ├── inference
│   │   ├── gradio_demo.py
│   │   ├── inference.py
│   │   └── utils.py
│   └── training
│       ├── config
│       │   └── surface.yaml
│       └── train_lbm_surface.py
├── pyproject.toml
├── requirements.txt
├── src
│   └── lbm
│       ├── config.py
│       ├── data
│       │   ├── __init__.py
│       │   ├── datasets
│       │   │   ├── __init__.py
│       │   │   ├── collation_fn.py
│       │   │   ├── dataset.py
│       │   │   └── datasets_config.py
│       │   ├── filters
│       │   │   ├── __init__.py
│       │   │   ├── base.py
│       │   │   ├── filter_wrapper.py
│       │   │   ├── filters.py
│       │   │   └── filters_config.py
│       │   └── mappers
│       │       ├── __init__.py
│       │       ├── base.py
│       │       ├── mappers.py
│       │       ├── mappers_config.py
│       │       └── mappers_wrapper.py
│       ├── inference
│       │   ├── __init__.py
│       │   ├── inference.py
│       │   └── utils.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── base
│       │   │   ├── __init__.py
│       │   │   ├── base_model.py
│       │   │   └── model_config.py
│       │   ├── embedders
│       │   │   ├── __init__.py
│       │   │   ├── base
│       │   │   │   ├── __init__.py
│       │   │   │   ├── base_conditioner.py
│       │   │   │   └── base_conditioner_config.py
│       │   │   ├── conditioners_wrapper.py
│       │   │   └── latents_concat
│       │   │       ├── __init__.py
│       │   │       ├── latents_concat_embedder_config.py
│       │   │       └── latents_concat_embedder_model.py
│       │   ├── lbm
│       │   │   ├── __init__.py
│       │   │   ├── lbm_config.py
│       │   │   └── lbm_model.py
│       │   ├── unets
│       │   │   ├── __init__.py
│       │   │   └── unet.py
│       │   ├── utils.py
│       │   └── vae
│       │       ├── __init__.py
│       │       ├── autoencoderKL.py
│       │       └── autoencoderKL_config.py
│       └── trainer
│           ├── __init__.py
│           ├── loggers.py
│           ├── trainer.py
│           ├── training_config.py
│           └── utils.py
└── tests
    ├── README.md
    ├── requirements.txt
    ├── test_dataset
    │   ├── test_filters.py
    │   └── test_mappers.py
    ├── test_lbm
    │   └── test_lbm.py
    ├── test_unets
    │   └── test_unets_wrappers.py
    └── test_vaes
        └── test_autoencoder.py

/.gitignore:
--------------------------------------------------------------------------------
*.ipynb
*.png
*.pyc
*.gradio
*.safetensors
examples/inference/ckpts/*
examples/inference/examples/*
envs/*
checkpoints/*
*.sh
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Latent Bridge Matching (LBM)

This repository is the official implementation of the paper [LBM: Latent Bridge Matching for Fast Image-to-Image Translation](http://arxiv.org/abs/2503.07535).

*Figure: LBM Teaser*

A hosted demo is available as a Hugging Face Space (`jasperai/LBM_relighting`).

## Abstract
In this paper, we introduce Latent Bridge Matching (LBM), a new, versatile and scalable method that relies on Bridge Matching in a latent space to achieve fast image-to-image translation. We show that the method can reach state-of-the-art results for various image-to-image tasks using only a single inference step. In addition to its efficiency, we also demonstrate the versatility of the method across different image translation tasks such as object removal, normal and depth estimation, and object relighting. We also derive a conditional framework of LBM and demonstrate its effectiveness by tackling the tasks of controllable image relighting and shadow generation.
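For intuition, bridge-matching methods train on points sampled from a stochastic bridge between paired latents. This sketch assumes a common choice, the Brownian bridge `z_t = (1 - t) * z_target + t * z_source + sigma * sqrt(t * (1 - t)) * eps` with `eps ~ N(0, 1)`. That form is an assumption and is not taken from this repository's code. It pins `z_t` to the target latent at `t = 0` and to the source latent at `t = 1`, and that endpoint convention is also assumed.

The pure-Python sketch below draws one such sample from flat lists of floats. Keep the following in mind:
- `bridge_sample` is a hypothetical helper.
- `sigma` only plays a role analogous to `bridge_noise_sigma` in the training config.

```python
import math
import random
from typing import List, Optional


def bridge_sample(
    z_source: List[float],
    z_target: List[float],
    t: float,
    sigma: float,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Hypothetical helper: sample z_t on a Brownian bridge from z_target (t=0) to z_source (t=1).

    Assumed form: z_t = (1 - t) * z_target + t * z_source + sigma * sqrt(t * (1 - t)) * eps,
    with eps ~ N(0, 1) drawn independently for every coordinate.
    """
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")
    if len(z_source) != len(z_target):
        raise ValueError("z_source and z_target must have the same length")
    rng = rng if rng is not None else random.Random()
    # The noise scale vanishes at both endpoints, so the bridge is pinned there.
    std = sigma * math.sqrt(t * (1.0 - t))
    return [
        (1.0 - t) * zt + t * zs + std * rng.gauss(0.0, 1.0)
        for zs, zt in zip(z_source, z_target)
    ]


# Halfway along a noiseless bridge, z_t is the midpoint of the two latents.
print(bridge_sample([1.0, 2.0], [0.0, 0.0], t=0.5, sigma=0.0))  # [0.5, 1.0]
```

The noise term is what separates a bridge from a straight interpolation. With `sigma = 0`, the sample reduces to linear interpolation between the two latents.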

## License
This code is released under the **Creative Commons BY-NC 4.0 license**.

## Considered Use-cases
We validate the method on various use cases such as object relighting, image restoration, object removal, depth and normal map estimation, as well as controllable object relighting and shadow generation.

### Image Relighting 🔦

For object relighting, the source image is built by pasting the foreground onto the target background. The method translates this encoded source image into the desired relit target image.

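The paste step that builds the relighting source image is a standard mask composite: each output pixel is `mask * foreground + (1 - mask) * background`. The Gradio demo (`examples/inference/gradio_demo.py`) performs this step with PIL's `Image.composite`.

The pure-Python sketch below shows the same arithmetic on small grayscale grids. It assumes a soft mask with values in [0, 1], and `paste_foreground` is a hypothetical name.

```python
from typing import List

Grid = List[List[float]]


def paste_foreground(fg: Grid, bg: Grid, mask: Grid) -> Grid:
    """Hypothetical helper: alpha-composite fg over bg using a soft mask in [0, 1]."""
    if not (len(fg) == len(bg) == len(mask)):
        raise ValueError("fg, bg and mask must have the same height")
    out: Grid = []
    for fg_row, bg_row, m_row in zip(fg, bg, mask):
        if not (len(fg_row) == len(bg_row) == len(m_row)):
            raise ValueError("fg, bg and mask must have the same width")
        # mask = 1 keeps the foreground, mask = 0 keeps the background.
        out.append([m * f + (1.0 - m) * b for f, b, m in zip(fg_row, bg_row, m_row)])
    return out


fg = [[1.0, 1.0], [1.0, 1.0]]
bg = [[0.0, 0.0], [0.0, 0.0]]
mask = [[1.0, 0.0], [0.5, 0.25]]
print(paste_foreground(fg, bg, mask))  # [[1.0, 0.0], [0.5, 0.25]]
```

Fractional mask values, such as those near object edges from a segmentation model, blend the two images rather than producing a hard cut.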
### Image Restoration 🧹

In the context of image restoration, the method transports the distribution of degraded images to the distribution of clean images.

### Object Removal ✂️

For object removal, the model is trained to find a transport map from the masked images to the images without the objects.

### Controllable Image Relighting and Shadow Generation 🕹️

We also derive a conditional framework of LBM and demonstrate its effectiveness by tackling the tasks of controllable image relighting and shadow generation.

### Normals and Depth Maps Estimation 🗺️

Finally, we also consider common tasks such as normal and depth estimation, where the model translates an input image into a normal or depth map.
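Normal maps are commonly stored as ordinary 8-bit RGB images by mapping each component of a unit normal from [-1, 1] to [0, 255]. The sketch below illustrates that convention only. It is an assumption that the released checkpoints and training targets use exactly this encoding, and the helper names are hypothetical.

```python
import math
from typing import Tuple

Vec3 = Tuple[float, float, float]
RGB = Tuple[int, int, int]


def normal_to_rgb(n: Vec3) -> RGB:
    """Hypothetical helper: encode a normal as 8-bit RGB ([-1, 1] -> [0, 255])."""
    norm = math.sqrt(sum(c * c for c in n))
    if norm == 0.0:
        raise ValueError("normal must be non-zero")
    # Re-normalise to unit length, then map each component to [0, 255].
    return tuple(round((c / norm + 1.0) / 2.0 * 255.0) for c in n)


def rgb_to_normal(rgb: RGB) -> Vec3:
    """Hypothetical inverse: decode RGB back to [-1, 1] and re-normalise."""
    v = [c / 255.0 * 2.0 - 1.0 for c in rgb]
    norm = math.sqrt(sum(c * c for c in v))
    return tuple(c / norm for c in v)


print(normal_to_rgb((0.0, 0.0, 1.0)))  # (128, 128, 255)
```

The 8-bit quantisation makes the round trip lossy, which is why the decoder re-normalises its output.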

## Setup
To get up and running, first create a virtual environment with Python 3.10 or later and activate it.

### With venv
```bash
python3.10 -m venv envs/lbm
source envs/lbm/bin/activate
```

### With conda
```bash
conda create -n lbm python=3.10
conda activate lbm
```

Then install the required dependencies and the repository in editable mode:

```bash
pip install --upgrade pip
pip install -e .
```

## Inference

We provide in `examples/inference` a simple script to perform depth estimation, surface normal estimation and object relighting with the proposed method.

```bash
python examples/inference/inference.py \
  --model_name [depth|normals|relighting] \
  --source_image path_to_your_image.jpg \
  --output_path output_images
```

See the trained models on the HF Hub 🤗
- [Surface normals Checkpoint](https://huggingface.co/jasperai/LBM_normals)
- [Depth Checkpoint](https://huggingface.co/jasperai/LBM_depth)
- [Relighting Checkpoint](https://huggingface.co/jasperai/LBM_relighting)

## Local Gradio Demo
To run the local Gradio demo for object relighting, run the following command:
```bash
python examples/inference/gradio_demo.py
```
On first launch, it downloads the pretrained relighting model and the example images from the HF Hub.

## Training
We provide in `examples/training` an example script to train an LBM for surface normal prediction on [`hypersim`](https://github.com/apple/ml-hypersim). See [this guide](https://github.com/prs-eth/Marigold/blob/main/script/dataset_preprocess/hypersim/README.md) for data processing.

In `examples/training/config`, you will find the `yaml` configuration associated with the training script. The only thing you need to do is amend the `SHARDS_PATH_OR_URLS` section of the `yaml` so that the model is trained on your own data.
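The training script passes every entry of `train_shards` and `validation_shards` through `braceexpand.braceexpand` before shuffling, so one YAML entry can describe many `.tar` shards. The sketch below imitates that expansion for the common case of a single zero-padded numeric range.

This is a simplified stand-in, not a replacement for the library:
- `expand_shard_pattern` is a hypothetical helper.
- It does not handle comma lists or nested braces, which the real `braceexpand` package supports.
- The shard paths are made up for illustration.

```python
import re
from typing import List

# Matches one numeric brace range such as {000000..000127}.
_RANGE = re.compile(r"\{(\d+)\.\.(\d+)\}")


def expand_shard_pattern(pattern: str) -> List[str]:
    """Hypothetical helper: expand a single zero-padded numeric brace range."""
    match = _RANGE.search(pattern)
    if match is None:
        return [pattern]
    start_txt, end_txt = match.group(1), match.group(2)
    width = len(start_txt)  # keep the zero padding of the lower bound
    start, end = int(start_txt), int(end_txt)
    step = 1 if end >= start else -1
    return [
        pattern[: match.start()] + str(i).zfill(width) + pattern[match.end():]
        for i in range(start, end + step, step)
    ]


shards = expand_shard_pattern("pipe:cat /data/hypersim/train-{000000..000002}.tar")
print(len(shards), shards[0])  # 3 pipe:cat /data/hypersim/train-000000.tar
```

Expanding the patterns before shuffling lets the data pipeline spread individual shards across nodes and workers rather than treating a pattern as a single source.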

Please note that this package uses [`webdataset`](https://github.com/webdataset/webdataset) to handle the data stream. The URLs you provide should therefore point to `.tar` shards formatted according to the [`webdataset format`](https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format).

In this example, each sample in your `.tar` files must contain three files:
- a `jpg` file containing the source image;
- a `normal_aligned.png` file containing the target normals;
- a `mask.png` file containing a mask of the valid pixels.

These are the keys that `get_filter_mappers` in `train_lbm_surface.py` filters on.

```
sample = {
    "jpg": source_image,  # source image
    "normal_aligned.png": normals,  # target image
    "mask.png": mask,  # mask of valid pixels
}
```

To train the model, you can use the following command:

```bash
python examples/training/train_lbm_surface.py examples/training/config/surface.yaml
```

*Note*: Make sure to update the relevant sections of the `yaml` file to use your own data and to log the results to your own [WandB](https://wandb.ai/site) project.
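In the webdataset format, all files belonging to one sample share the same basename (the sample key) and differ only in their extension, which becomes the field name (`jpg`, `mask.png`, ...). The stdlib-only sketch below writes such a shard with `tarfile`, using the field names that `get_filter_mappers` in `train_lbm_surface.py` expects.

It assumes the images are already encoded as bytes. `write_shard` and the placeholder byte strings are hypothetical.

```python
import io
import tarfile
from typing import BinaryIO, Dict, Iterable, Tuple

# One sample = (sample key, {field name: encoded file bytes}).
Sample = Tuple[str, Dict[str, bytes]]


def write_shard(fileobj: BinaryIO, samples: Iterable[Sample]) -> None:
    """Hypothetical helper: write samples into a webdataset-style .tar shard.

    Each field becomes a tar member named "<key>.<field>", e.g. "000000.mask.png".
    """
    with tarfile.open(fileobj=fileobj, mode="w") as tar:
        for key, fields in samples:
            # Dots in the key would be misread as part of the field name.
            if "." in key:
                raise ValueError("sample keys must not contain dots")
            for field, payload in fields.items():
                info = tarfile.TarInfo(name=f"{key}.{field}")
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))


buffer = io.BytesIO()
write_shard(
    buffer,
    [("000000", {"jpg": b"<jpeg>", "normal_aligned.png": b"<png>", "mask.png": b"<png>"})],
)
buffer.seek(0)
with tarfile.open(fileobj=buffer, mode="r") as tar:
    print(tar.getnames())  # ['000000.jpg', '000000.normal_aligned.png', '000000.mask.png']
```

Because `tarfile` does not close a file object it did not open, the same in-memory buffer can be rewound and read back. In practice you would write to a `.tar` file on disk instead.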

## Citation
If you find this work useful or use it in your research, please consider citing us:
```bibtex
@article{chadebec2025lbm,
  title={LBM: Latent Bridge Matching for Fast Image-to-Image Translation},
  author={Clément Chadebec and Onur Tasar and Sanjeev Sreetharan and Benjamin Aubin},
  year={2025},
  journal = {arXiv preprint arXiv:2503.07535},
}
```
--------------------------------------------------------------------------------

/assets/LBM.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/LBM.jpg
--------------------------------------------------------------------------------

/assets/depth_normal.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/depth_normal.jpg
--------------------------------------------------------------------------------

/assets/object_removal.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/object_removal.jpg
--------------------------------------------------------------------------------

/assets/relight.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/relight.gif
--------------------------------------------------------------------------------

/assets/relight.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/relight.jpg
--------------------------------------------------------------------------------

/assets/relight_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/relight_2.gif
--------------------------------------------------------------------------------

/assets/shadow_control.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/shadow_control.gif
--------------------------------------------------------------------------------

/assets/upscaler.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/assets/upscaler.jpg
--------------------------------------------------------------------------------

/examples/inference/gradio_demo.py:
--------------------------------------------------------------------------------
import glob
import logging
import os
from copy import deepcopy

import gradio as gr
import numpy as np
import PIL
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from torchvision.transforms import ToPILImage, ToTensor
from transformers import AutoModelForImageSegmentation
from utils import extract_object, resize_and_center_crop

from lbm.inference import get_model

PATH = os.path.dirname(os.path.abspath(__file__))
os.environ["GRADIO_TEMP_DIR"] = ".gradio"


if not os.path.exists(os.path.join(PATH, "ckpts", "relighting")):
    logging.info(f"Downloading relighting LBM model from HF hub...")
    model = get_model(
        f"jasperai/LBM_relighting",
        save_dir=os.path.join(PATH, "ckpts", "relighting"),
        torch_dtype=torch.bfloat16,
        device="cuda",
    )
else:
    model_dir = os.path.join(PATH, "ckpts", "relighting")
    logging.info(f"Loading relighting LBM model from local...")
    model = get_model(
        model_dir,
        torch_dtype=torch.bfloat16,
        device="cuda",
    )

ASPECT_RATIOS = {
    str(512 / 2048): (512, 2048),
    str(1024 / 1024): (1024, 1024),
    str(2048 / 512): (2048, 512),
    str(896 / 1152): (896, 1152),
    str(1152 / 896): (1152, 896),
    str(512 / 1920): (512, 1920),
    str(640 / 1536): (640, 1536),
    str(768 / 1280): (768, 1280),
    str(1280 / 768): (1280, 768),
    str(1536 / 640): (1536, 640),
    str(1920 / 512): (1920, 512),
}

birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
).cuda()
image_size = (1024, 1024)

if not os.path.exists(os.path.join(PATH, "examples")):
    logging.info(f"Downloading backgrounds from HF hub...")
    _ = snapshot_download(
        "jasperai/LBM_relighting",
        repo_type="space",
        allow_patterns="*.jpg",
        local_dir=PATH,
    )


def evaluate(
    fg_image: PIL.Image.Image,
    bg_image: PIL.Image.Image,
    num_sampling_steps: int = 1,
):

    # PIL reports image sizes as (width, height)
    ori_w_bg, ori_h_bg = fg_image.size
    ar_bg = ori_w_bg / ori_h_bg
    closest_ar_bg = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - ar_bg))
    dimensions_bg = ASPECT_RATIOS[closest_ar_bg]

    _, fg_mask = extract_object(birefnet, deepcopy(fg_image))

    fg_image = resize_and_center_crop(fg_image, dimensions_bg[0], dimensions_bg[1])
    fg_mask = resize_and_center_crop(fg_mask, dimensions_bg[0], dimensions_bg[1])
    bg_image = resize_and_center_crop(bg_image, dimensions_bg[0], dimensions_bg[1])

    img_pasted = Image.composite(fg_image, bg_image, fg_mask)

    img_pasted_tensor = ToTensor()(img_pasted).unsqueeze(0) * 2 - 1
    batch = {
        "source_image": img_pasted_tensor.cuda().to(torch.bfloat16),
    }

    z_source = model.vae.encode(batch[model.source_key])

    output_image = model.sample(
        z=z_source,
        num_steps=num_sampling_steps,
        conditioner_inputs=batch,
        max_samples=1,
    ).clamp(-1, 1)

    output_image = (output_image[0].float().cpu() + 1) / 2
    output_image = ToPILImage()(output_image)

    # paste the output image on the background image
    output_image = Image.composite(output_image, bg_image, fg_mask)

    # Image.resize returns a new image, so assign the result; resize both
    # images back to the input resolution so the slider compares like with like
    img_pasted = img_pasted.resize((ori_w_bg, ori_h_bg))
    output_image = output_image.resize((ori_w_bg, ori_h_bg))

    return (np.array(img_pasted), np.array(output_image))


with gr.Blocks(title="LBM Object Relighting") as demo:
    gr.Markdown(
        f"""
        # Object Relighting with Latent Bridge Matching
        This is an interactive demo of [LBM: Latent Bridge Matching for Fast Image-to-Image Translation](https://arxiv.org/abs/2503.07535) *by Jasper Research*. This demo is based on the [LBM relighting checkpoint](https://huggingface.co/jasperai/LBM_relighting).
        """
    )
    gr.Markdown(
        """
        If you enjoy the space, please also promote *open-source* by giving a ⭐ to the Github Repo.
        """
    )

    with gr.Row():
        with gr.Column():
            with gr.Row():
                fg_image = gr.Image(
                    type="pil",
                    label="Input Image",
                    image_mode="RGB",
                    height=360,
                    # width=360,
                )
                bg_image = gr.Image(
                    type="pil",
                    label="Target Background",
                    image_mode="RGB",
                    height=360,
                    # width=360,
                )

            with gr.Row():
                submit_button = gr.Button("Relight", variant="primary")
            with gr.Row():
                num_inference_steps = gr.Slider(
                    minimum=1,
                    maximum=4,
                    value=1,
                    step=1,
                    label="Number of Inference Steps",
                )

            bg_gallery = gr.Gallery(
                # height=450,
                object_fit="contain",
                label="Background List",
                value=[
                    path
                    for path in glob.glob(
                        os.path.join(PATH, "examples/backgrounds/*.jpg")
                    )
                ],
                columns=5,
                allow_preview=False,
            )

        with gr.Column():
            output_slider = gr.ImageSlider(label="Composite vs LBM", type="numpy")
            output_slider.upload(
                fn=evaluate,
                inputs=[fg_image, bg_image, num_inference_steps],
                outputs=[output_slider],
            )

    submit_button.click(
        evaluate,
        inputs=[fg_image, bg_image, num_inference_steps],
        outputs=[output_slider],
    )

    with gr.Row():
        gr.Examples(
            fn=evaluate,
            examples=[
                [
                    os.path.join(PATH, "examples/foregrounds/2.jpg"),
                    os.path.join(PATH, "examples/backgrounds/14.jpg"),
                    1,
                ],
                [
                    os.path.join(PATH, "examples/foregrounds/10.jpg"),
                    os.path.join(PATH, "examples/backgrounds/4.jpg"),
                    1,
                ],
                [
                    os.path.join(PATH, "examples/foregrounds/11.jpg"),
                    os.path.join(PATH, "examples/backgrounds/24.jpg"),
                    1,
                ],
                [
                    os.path.join(PATH, "examples/foregrounds/19.jpg"),
                    os.path.join(PATH, "examples/backgrounds/3.jpg"),
                    1,
                ],
                [
                    os.path.join(PATH, "examples/foregrounds/4.jpg"),
                    os.path.join(PATH, "examples/backgrounds/6.jpg"),
                    1,
                ],
                [
                    os.path.join(PATH, "examples/foregrounds/14.jpg"),
                    os.path.join(PATH, "examples/backgrounds/22.jpg"),
                    1,
                ],
                [
                    os.path.join(PATH, "examples/foregrounds/12.jpg"),
                    os.path.join(PATH, "examples/backgrounds/1.jpg"),
                    1,
                ],
            ],
            inputs=[fg_image, bg_image, num_inference_steps],
            outputs=[output_slider],
            run_on_click=True,
        )

    def bg_gallery_selected(gal, evt: gr.SelectData):
        return gal[evt.index][0]

    bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=bg_image)

if __name__ == "__main__":

    demo.launch(share=True)
--------------------------------------------------------------------------------

/examples/inference/inference.py:
--------------------------------------------------------------------------------
import argparse
import logging
import os

import torch
from PIL import Image

from lbm.inference import evaluate, get_model

PATH = os.path.dirname(os.path.abspath(__file__))

logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser()
parser.add_argument("--source_image", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument("--num_inference_steps", type=int, default=1)
parser.add_argument(
    "--model_name",
    type=str,
    default="normals",
    choices=["normals", "depth", "relighting"],
)


args = parser.parse_args()


def main():
    # download the weights from HF hub
    if not os.path.exists(os.path.join(PATH, "ckpts", f"{args.model_name}")):
        logging.info(f"Downloading {args.model_name} LBM model from HF hub...")
        model = get_model(
            f"jasperai/LBM_{args.model_name}",
            save_dir=os.path.join(PATH, "ckpts", f"{args.model_name}"),
            torch_dtype=torch.bfloat16,
            device="cuda",
        )

    else:
        model_dir = os.path.join(PATH, "ckpts", f"{args.model_name}")
        logging.info(f"Loading {args.model_name} LBM model from local...")
        model = get_model(model_dir, torch_dtype=torch.bfloat16, device="cuda")

    source_image = Image.open(args.source_image).convert("RGB")

    output_image = evaluate(model, source_image, args.num_inference_steps)

    os.makedirs(args.output_path, exist_ok=True)

    source_image.save(os.path.join(args.output_path, "source_image.jpg"))
    output_image.save(os.path.join(args.output_path, "output_image.jpg"))


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------

/examples/inference/utils.py:
--------------------------------------------------------------------------------
import torch
from PIL import Image
from torchvision import transforms


def extract_object(birefnet, img):
    # Data settings
    image_size = (1024, 1024)
    transform_image = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    image = img
    input_images = transform_image(image).unsqueeze(0).cuda()

    # Prediction
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image.size)
    image = Image.composite(image, Image.new("RGB", image.size, (127, 127, 127)), mask)
    return image, mask


def resize_and_center_crop(image, target_width, target_height):
    original_width, original_height = image.size
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = image.resize((resized_width, resized_height), Image.LANCZOS)
    left = (resized_width - target_width) / 2
    top = (resized_height - target_height) / 2
    right = (resized_width + target_width) / 2
    bottom = (resized_height + target_height) / 2
    cropped_image = resized_image.crop((left, top, right, bottom))
    return cropped_image
--------------------------------------------------------------------------------

/examples/training/config/surface.yaml:
--------------------------------------------------------------------------------
# wandb
wandb_project: lbm-surface-flows
timestep_sampling: custom_timesteps
unet_input_channels: 4
vae_num_channels: 4
selected_timesteps: [250, 500, 750, 1000]
prob: [0.25, 0.25, 0.25, 0.25]
pixel_loss_type: lpips # l1 l2
pixel_loss_weight: 10.0
latent_loss_type: l2 # l1 l2
latent_loss_weight: 1.0
bridge_noise_sigma: 0.005
conditioning_images_keys: []
conditioning_masks_keys: []

# SHARDS_PATH_OR_URLS
train_shards:
  - pipe:cat PATH_TO_TRAIN_TARS

validation_shards:
  - pipe:cat PATH_TO_VAL_TARS

batch_size: 4
learning_rate: 4.0e-5 # decimal point needed for PyYAML to parse this as a float
optimizer: AdamW
num_steps: [1, 4]
log_interval: 500
resume_from_checkpoint: true
max_epochs: 50
save_interval: 5000
save_ckpt_path: ./checkpoints
--------------------------------------------------------------------------------

/examples/training/train_lbm_surface.py:
--------------------------------------------------------------------------------
import datetime
import logging
import os
import random
import re
import shutil
from typing import List, Optional

import braceexpand
import fire
import torch
import yaml
from diffusers import FlowMatchEulerDiscreteScheduler, StableDiffusionXLPipeline
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.resnet import ResnetBlock2D
from pytorch_lightning import Trainer, loggers
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.strategies import FSDPStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torchvision.transforms import InterpolationMode

from lbm.data.datasets import DataModule, DataModuleConfig
from lbm.data.filters import KeyFilter, KeyFilterConfig
from lbm.data.mappers import (
    KeyRenameMapper,
    KeyRenameMapperConfig,
    MapperWrapper,
    RescaleMapper,
    RescaleMapperConfig,
    TorchvisionMapper,
    TorchvisionMapperConfig,
)
from lbm.models.embedders import (
    ConditionerWrapper,
    LatentsConcatEmbedder,
    LatentsConcatEmbedderConfig,
)
from lbm.models.lbm import LBMConfig, LBMModel
from lbm.models.unets import DiffusersUNet2DCondWrapper
from lbm.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig
from lbm.trainer import TrainingConfig, TrainingPipeline
from lbm.trainer.loggers import WandbSampleLogger
from lbm.trainer.utils import StateDictAdapter


def get_model(
    backbone_signature: str = "stabilityai/stable-diffusion-xl-base-1.0",
    vae_num_channels: int = 4,
    unet_input_channels: int = 4,
    timestep_sampling: str = "log_normal",
    selected_timesteps: Optional[List[float]] = None,
    prob: Optional[List[float]] = None,
    conditioning_images_keys: Optional[List[str]] = [],
    conditioning_masks_keys: Optional[List[str]] = [],
    source_key: str = "source_image",
    target_key: str = "source_image_paste",
    mask_key: str = "mask",
    bridge_noise_sigma: float = 0.0,
    logit_mean: float = 0.0,
    logit_std: float = 1.0,
    pixel_loss_type: str = "lpips",
    latent_loss_type: str = "l2",
    latent_loss_weight: float = 1.0,
    pixel_loss_weight: float = 0.0,
):

    conditioners = []

    # Load pretrained model as base
    pipe = StableDiffusionXLPipeline.from_pretrained(
        backbone_signature,
        torch_dtype=torch.bfloat16,
    )

    ### Denoiser (UNet) ###
    # Get Architecture
    denoiser = DiffusersUNet2DCondWrapper(
        in_channels=unet_input_channels,  # Add downsampled_image
        out_channels=vae_num_channels,
        center_input_sample=False,
        flip_sin_to_cos=True,
        freq_shift=0,
        down_block_types=[
            "DownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
        ],
        mid_block_type="UNetMidBlock2DCrossAttn",
        up_block_types=["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
        only_cross_attention=False,
        block_out_channels=[320, 640, 1280],
        layers_per_block=2,
        downsample_padding=1,
        mid_block_scale_factor=1,
        dropout=0.0,
        act_fn="silu",
        norm_num_groups=32,
        norm_eps=1e-05,
        cross_attention_dim=[320, 640, 1280],
        transformer_layers_per_block=[1, 2, 10],
        reverse_transformer_layers_per_block=None,
        encoder_hid_dim=None,
        encoder_hid_dim_type=None,
        attention_head_dim=[5, 10, 20],
        num_attention_heads=None,
        dual_cross_attention=False,
        use_linear_projection=True,
        class_embed_type=None,
        addition_embed_type=None,
        addition_time_embed_dim=None,
        num_class_embeds=None,
        upcast_attention=None,
        resnet_time_scale_shift="default",
        resnet_skip_time_act=False,
        resnet_out_scale_factor=1.0,
        time_embedding_type="positional",
        time_embedding_dim=None,
        time_embedding_act_fn=None,
        timestep_post_act=None,
        time_cond_proj_dim=None,
        conv_in_kernel=3,
        conv_out_kernel=3,
        projection_class_embeddings_input_dim=None,
        attention_type="default",
        class_embeddings_concat=False,
        mid_block_only_cross_attention=None,
        cross_attention_norm=None,
        addition_embed_type_num_heads=64,
    ).to(torch.bfloat16)

    state_dict = pipe.unet.state_dict()

    del state_dict["add_embedding.linear_1.weight"]
    del state_dict["add_embedding.linear_1.bias"]
    del state_dict["add_embedding.linear_2.weight"]
    del state_dict["add_embedding.linear_2.bias"]

    # Adapt the shapes
    state_dict_adapter = StateDictAdapter()
    state_dict = state_dict_adapter(
        model_state_dict=denoiser.state_dict(),
        checkpoint_state_dict=state_dict,
        regex_keys=[
            r"class_embedding.linear_\d+.(weight|bias)",
            r"conv_in.weight",
            r"(down_blocks|up_blocks)\.\d+\.attentions\.\d+\.transformer_blocks\.\d+\.attn\d+\.(to_k|to_v)\.weight",
            r"mid_block\.attentions\.\d+\.transformer_blocks\.\d+\.attn\d+\.(to_k|to_v)\.weight",
        ],
        strategy="zeros",
    )

    denoiser.load_state_dict(state_dict, strict=True)

    del pipe

    if conditioning_images_keys != [] or conditioning_masks_keys != []:

        latents_concat_embedder_config = LatentsConcatEmbedderConfig(
            image_keys=conditioning_images_keys,
            mask_keys=conditioning_masks_keys,
        )
        latent_concat_embedder = LatentsConcatEmbedder(latents_concat_embedder_config)
        latent_concat_embedder.freeze()
        conditioners.append(latent_concat_embedder)

    # Wrap conditioners and set to device
    conditioner = ConditionerWrapper(
        conditioners=conditioners,
    )

    ## VAE ##
    # Get VAE model
    vae_config = AutoencoderKLDiffusersConfig(
        version=backbone_signature,
        subfolder="vae",
        tiling_size=(128, 128),
    )
    vae = AutoencoderKLDiffusers(vae_config)
    vae.freeze()
    vae.to(torch.bfloat16)

    # LBM Config
    config = LBMConfig(
        ucg_keys=None,
        source_key=source_key,
        target_key=target_key,
        mask_key=mask_key,
        latent_loss_weight=latent_loss_weight,
        latent_loss_type=latent_loss_type,
        pixel_loss_type=pixel_loss_type,
        pixel_loss_weight=pixel_loss_weight,
        timestep_sampling=timestep_sampling,
        logit_mean=logit_mean,
        logit_std=logit_std,
        selected_timesteps=selected_timesteps,
        prob=prob,
        bridge_noise_sigma=bridge_noise_sigma,
    )

    training_noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        backbone_signature,
        subfolder="scheduler",
    )
    sampling_noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        backbone_signature,
        subfolder="scheduler",
    )

    # LBM Model
    model = LBMModel(
        config,
        denoiser=denoiser,
        training_noise_scheduler=training_noise_scheduler,
        sampling_noise_scheduler=sampling_noise_scheduler,
        vae=vae,
        conditioner=conditioner,
    ).to(torch.bfloat16)

    return model


def get_filter_mappers():
    filters_mappers = [
        KeyFilter(KeyFilterConfig(keys=["jpg", "normal_aligned.png", "mask.png"])),
        MapperWrapper(
            [
                KeyRenameMapper(
                    KeyRenameMapperConfig(
                        key_map={
                            "jpg": "image",
                            "normal_aligned.png": "normal",
                            "mask.png": "mask",
                        }
                    )
                ),
                TorchvisionMapper(
                    TorchvisionMapperConfig(
                        key="image",
                        transforms=["ToTensor", "Resize"],
                        transforms_kwargs=[
                            {},
                            {
                                "size": (480, 640),
                                "interpolation": InterpolationMode.NEAREST_EXACT,
                            },
                        ],
                    )
                ),
                TorchvisionMapper(
                    TorchvisionMapperConfig(
                        key="normal",
                        transforms=["ToTensor", "Resize"],
                        transforms_kwargs=[
                            {},
                            {
                                "size": (480, 640),
                                "interpolation": InterpolationMode.NEAREST_EXACT,
                            },
                        ],
                    )
                ),
                TorchvisionMapper(
                    TorchvisionMapperConfig(
                        key="mask",
                        transforms=["ToTensor", "Resize", "Normalize"],
                        transforms_kwargs=[
                            {},
                            {
                                "size": (480, 640),
                                "interpolation": InterpolationMode.NEAREST_EXACT,
                            },
                            {"mean": 0.0, "std": 1.0},
                        ],
                    )
                ),
                RescaleMapper(RescaleMapperConfig(key="image")),
                RescaleMapper(RescaleMapperConfig(key="normal")),
            ],
        ),
    ]

    return filters_mappers


def get_data_module(
    train_shards: List[str],
    validation_shards: List[str],
    batch_size: int,
):

    # TRAIN
    train_filters_mappers = get_filter_mappers()

    # unbrace urls
    train_shards_path_or_urls_unbraced = []
    for train_shards_path_or_url in train_shards:
        train_shards_path_or_urls_unbraced.extend(
            braceexpand.braceexpand(train_shards_path_or_url)
        )

    # shuffle shards
    random.shuffle(train_shards_path_or_urls_unbraced)

    # data config
    data_config = DataModuleConfig(
        shards_path_or_urls=train_shards_path_or_urls_unbraced,
        decoder="pil",
        shuffle_before_split_by_node_buffer_size=20,
        shuffle_before_split_by_workers_buffer_size=20,
        shuffle_before_filter_mappers_buffer_size=20,
        shuffle_after_filter_mappers_buffer_size=20,
        per_worker_batch_size=batch_size,
        num_workers=min(10, len(train_shards_path_or_urls_unbraced)),
    )

    train_data_config = data_config

    # VALIDATION
    validation_filters_mappers = get_filter_mappers()

    # unbrace urls
    validation_shards_path_or_urls_unbraced = []
    for validation_shards_path_or_url in validation_shards:
        validation_shards_path_or_urls_unbraced.extend(
            braceexpand.braceexpand(validation_shards_path_or_url)
        )

    data_config = DataModuleConfig(
        shards_path_or_urls=validation_shards_path_or_urls_unbraced,
        decoder="pil",
        shuffle_before_split_by_node_buffer_size=10,
shuffle_before_split_by_workers_buffer_size=10, 334 | shuffle_before_filter_mappers_buffer_size=10, 335 | shuffle_after_filter_mappers_buffer_size=10, 336 | per_worker_batch_size=batch_size, 337 | num_workers=min(10, len(validation_shards_path_or_urls_unbraced)), 338 | ) 339 | 340 | validation_data_config = data_config 341 | 342 | # data module 343 | data_module = DataModule( 344 | train_config=train_data_config, 345 | train_filters_mappers=train_filters_mappers, 346 | eval_config=validation_data_config, 347 | eval_filters_mappers=validation_filters_mappers, 348 | ) 349 | 350 | return data_module 351 | 352 | 353 | def main( 354 | train_shards: List[str] = ["pipe:cat path/to/train/shards"], 355 | validation_shards: List[str] = ["pipe:cat path/to/validation/shards"], 356 | backbone_signature: str = "stabilityai/stable-diffusion-xl-base-1.0", 357 | vae_num_channels: int = 4, 358 | unet_input_channels: int = 4, 359 | source_key: str = "image", 360 | target_key: str = "normal", 361 | mask_key: str = "mask", 362 | wandb_project: str = "lbm-surface", 363 | batch_size: int = 8, 364 | num_steps: List[int] = [1, 2, 4], 365 | learning_rate: float = 5e-5, 366 | learning_rate_scheduler: str = None, 367 | learning_rate_scheduler_kwargs: dict = {}, 368 | optimizer: str = "AdamW", 369 | optimizer_kwargs: dict = {}, 370 | timestep_sampling: str = "uniform", 371 | logit_mean: float = 0.0, 372 | logit_std: float = 1.0, 373 | pixel_loss_type: str = "lpips", 374 | latent_loss_type: str = "l2", 375 | latent_loss_weight: float = 1.0, 376 | pixel_loss_weight: float = 0.0, 377 | selected_timesteps: List[float] = None, 378 | prob: List[float] = None, 379 | conditioning_images_keys: Optional[List[str]] = [], 380 | conditioning_masks_keys: Optional[List[str]] = [], 381 | config_yaml: dict = None, 382 | save_ckpt_path: str = "./checkpoints", 383 | log_interval: int = 100, 384 | resume_from_checkpoint: bool = True, 385 | max_epochs: int = 100, 386 | bridge_noise_sigma: float = 0.005, 387 |
save_interval: int = 1000, 388 | path_config: str = None, 389 | ): 390 | model = get_model( 391 | backbone_signature=backbone_signature, 392 | vae_num_channels=vae_num_channels, 393 | unet_input_channels=unet_input_channels, 394 | source_key=source_key, 395 | target_key=target_key, 396 | mask_key=mask_key, 397 | timestep_sampling=timestep_sampling, 398 | logit_mean=logit_mean, 399 | logit_std=logit_std, 400 | pixel_loss_type=pixel_loss_type, 401 | latent_loss_type=latent_loss_type, 402 | latent_loss_weight=latent_loss_weight, 403 | pixel_loss_weight=pixel_loss_weight, 404 | selected_timesteps=selected_timesteps, 405 | prob=prob, 406 | conditioning_images_keys=conditioning_images_keys, 407 | conditioning_masks_keys=conditioning_masks_keys, 408 | bridge_noise_sigma=bridge_noise_sigma, 409 | ) 410 | 411 | data_module = get_data_module( 412 | train_shards=train_shards, 413 | validation_shards=validation_shards, 414 | batch_size=batch_size, 415 | ) 416 | 417 | train_parameters = ["denoiser.*"] 418 | 419 | # Training Config 420 | training_config = TrainingConfig( 421 | learning_rate=learning_rate, 422 | lr_scheduler_name=learning_rate_scheduler, 423 | lr_scheduler_kwargs=learning_rate_scheduler_kwargs, 424 | log_keys=["image", "normal", "mask"], 425 | trainable_params=train_parameters, 426 | optimizer_name=optimizer, 427 | optimizer_kwargs=optimizer_kwargs, 428 | log_samples_model_kwargs={ 429 | "input_shape": None, 430 | "num_steps": num_steps, 431 | }, 432 | ) 433 | if ( 434 | os.path.exists(save_ckpt_path) 435 | and resume_from_checkpoint 436 | and "last.ckpt" in os.listdir(save_ckpt_path) 437 | ): 438 | start_ckpt = f"{save_ckpt_path}/last.ckpt" 439 | print(f"Resuming from checkpoint: {start_ckpt}") 440 | 441 | else: 442 | start_ckpt = None 443 | 444 | pipeline = TrainingPipeline(model=model, pipeline_config=training_config) 445 | 446 | pipeline.save_hyperparameters( 447 | { 448 | f"embedder_{i}": embedder.config.to_dict() 449 | for i, embedder in 
enumerate(model.conditioner.conditioners) 450 | } 451 | ) 452 | 453 | pipeline.save_hyperparameters( 454 | { 455 | "denoiser": model.denoiser.config, 456 | "vae": model.vae.config.to_dict(), 457 | "config_yaml": config_yaml, 458 | "training": training_config.to_dict(), 459 | "training_noise_scheduler": model.training_noise_scheduler.config, 460 | "sampling_noise_scheduler": model.sampling_noise_scheduler.config, 461 | } 462 | ) 463 | 464 | training_signature = ( 465 | datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 466 | + "-LBM-Surface" 467 | + f"{os.environ['SLURM_JOB_ID']}" 468 | + f"_{os.environ.get('SLURM_ARRAY_TASK_ID', 0)}" 469 | ) 470 | dir_path = f"{save_ckpt_path}/logs/{training_signature}" 471 | if os.environ["SLURM_PROCID"] == "0": 472 | os.makedirs(dir_path, exist_ok=True) 473 | if path_config is not None: 474 | shutil.copy(path_config, f"{save_ckpt_path}/config.yaml") 475 | run_name = training_signature 476 | 477 | # Ignore parameters unused during training 478 | ignore_states = [] 479 | for name, param in pipeline.model.named_parameters(): 480 | ignore = True 481 | for regex in ["denoiser."]: 482 | pattern = re.compile(regex) 483 | if re.match(pattern, name): 484 | ignore = False 485 | if ignore: 486 | ignore_states.append(param) 487 | 488 | # FSDP Strategy 489 | strategy = FSDPStrategy( 490 | auto_wrap_policy=ModuleWrapPolicy( 491 | [ 492 | UNet2DConditionModel, 493 | BasicTransformerBlock, 494 | ResnetBlock2D, 495 | torch.nn.Conv2d, 496 | ] 497 | ), 498 | activation_checkpointing_policy=ModuleWrapPolicy( 499 | [ 500 | BasicTransformerBlock, 501 | ResnetBlock2D, 502 | ] 503 | ), 504 | sharding_strategy="SHARD_GRAD_OP", 505 | ignored_states=ignore_states, 506 | ) 507 | 508 | trainer = Trainer( 509 | accelerator="gpu", 510 | devices=int(os.environ["SLURM_NPROCS"]) // int(os.environ["SLURM_NNODES"]), 511 | num_nodes=int(os.environ["SLURM_NNODES"]), 512 | strategy=strategy, 513 | default_root_dir="logs", 514 | logger=loggers.WandbLogger( 515 | 
project=wandb_project, offline=False, name=run_name, save_dir=save_ckpt_path 516 | ), 517 | callbacks=[ 518 | WandbSampleLogger(log_batch_freq=log_interval), 519 | LearningRateMonitor(logging_interval="step"), 520 | ModelCheckpoint( 521 | dirpath=save_ckpt_path, 522 | every_n_train_steps=save_interval, 523 | save_last=True, 524 | ), 525 | ], 526 | num_sanity_val_steps=0, 527 | precision="bf16-mixed", 528 | limit_val_batches=2, 529 | val_check_interval=1000, 530 | max_epochs=max_epochs, 531 | ) 532 | 533 | trainer.fit(pipeline, data_module, ckpt_path=start_ckpt) 534 | 535 | 536 | def main_from_config(path_config: str = None): 537 | with open(path_config, "r") as file: 538 | config = yaml.safe_load(file) 539 | logging.info( 540 | f"Running main with config: {yaml.dump(config, default_flow_style=False)}" 541 | ) 542 | main(**config, config_yaml=config, path_config=path_config) 543 | 544 | 545 | if __name__ == "__main__": 546 | fire.Fire(main_from_config) 547 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling", "hatch-requirements-txt"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "lbm" 7 | dynamic = ["dependencies", "optional-dependencies"] 8 | description = "LBM: Latent Bridge Matching for Fast Image-to-Image Translation" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | authors = [ 12 | { name = "Clement Chadebec", email = "clement.chadebec@jasper.ai" }, 13 | { name = "Benjamin Aubin", email = "benjamin.aubin@jasper.ai" }, 14 | ] 15 | maintainers = [ 16 | { name = "Clement Chadebec", email = "clement.chadebec@jasper.ai" }, 17 | ] 18 | classifiers = [ 19 | "Programming Language :: Python :: 3", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | "Programming Language :: Python :: 3.12", 23 | "License :: OSI 
Approved :: Apache Software License", 24 | "Operating System :: OS Independent", 25 | ] 26 | version = "0.1" 27 | 28 | [project.urls] 29 | Homepage = "https://github.com/gojasper/LBM" 30 | Repository = "https://github.com/gojasper/LBM" 31 | 32 | [tool.hatch.metadata] 33 | allow-direct-references = true 34 | 35 | [tool.hatch.metadata.hooks.requirements_txt] 36 | files = ["requirements.txt"] 37 | 38 | [tool.hatch.build.targets.wheel] 39 | packages = ["src/lbm"] 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.4.0 2 | diffusers==0.32.2 3 | torch==2.7.0 4 | torchvision>=0.20.0 5 | black==24.2.0 6 | einops==0.7.0 7 | fire>=0.5.0 8 | gradio==5.29.0 9 | isort==5.13.2 10 | lightning==2.5.0 11 | lpips==0.1.4 12 | opencv-python==4.9.0.80 13 | peft==0.9.0 14 | pydantic>=2.6.1 15 | scipy>=1.12.0 16 | sentencepiece>=0.2.0 17 | timm==0.9.16 18 | tokenizers>=0.15.2 19 | torch-fidelity>=0.3.0 20 | torchmetrics>=1.3.1 21 | transformers==4.42.3 22 | wandb==0.16.2 23 | webdataset>=0.2.86 24 | kornia==0.8.0 -------------------------------------------------------------------------------- /src/lbm/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import warnings 4 | from dataclasses import asdict, field 5 | from typing import Any, Dict, Union 6 | 7 | import yaml 8 | from pydantic import ValidationError 9 | from pydantic.dataclasses import dataclass 10 | from yaml import safe_load 11 | 12 | 13 | @dataclass 14 | class BaseConfig: 15 | """Base configuration class defining the loading and saving methods 16 | shared by all the configs""" 17 | 18 | name: str = field(init=False) 19 | 20 | def __post_init__(self): 21 | self.name = self.__class__.__name__ 22 | 23 | @classmethod 24 | def from_dict(cls, config_dict: Dict[str, Any]) -> "BaseConfig": 25 |
"""Creates a BaseConfig instance from a dictionary 26 | 27 | Args: 28 | config_dict (dict): The Python dictionary containing all the parameters 29 | 30 | Returns: 31 | :class:`BaseConfig`: The created instance 32 | """ 33 | try: 34 | config = cls(**config_dict) 35 | except (ValidationError, TypeError): 36 | raise 37 | return config 38 | 39 | @classmethod 40 | def _dict_from_json(cls, json_path: Union[str, os.PathLike]) -> Dict[str, Any]: 41 | try: 42 | with open(json_path) as f: 43 | try: 44 | config_dict = json.load(f) 45 | return config_dict 46 | 47 | except (TypeError, json.JSONDecodeError) as e: 48 | raise TypeError( 49 | f"File {json_path} could not be loaded as JSON.\n" 50 | f"Caught exception {type(e)} with message: " + str(e) 51 | ) from e 52 | 53 | except FileNotFoundError: 54 | raise FileNotFoundError( 55 | f"Config file not found. Please check path '{json_path}'" 56 | ) 57 | 58 | @classmethod 59 | def from_json(cls, json_path: str) -> "BaseConfig": 60 | """Creates a BaseConfig instance from a JSON config file 61 | 62 | Args: 63 | json_path (str): The path to the json file containing all the parameters 64 | 65 | Returns: 66 | :class:`BaseConfig`: The created instance 67 | """ 68 | config_dict = cls._dict_from_json(json_path) 69 | 70 | config_name = config_dict.pop("name") 71 | 72 | if cls.__name__ != config_name: 73 | warnings.warn( 74 | f"You are trying to load a " 75 | f"`{ cls.__name__}` while a " 76 | f"`{config_name}` is given."
77 | ) 78 | 79 | return cls.from_dict(config_dict) 80 | 81 | def to_dict(self) -> dict: 82 | """Transforms the object into a Python dictionary 83 | 84 | Returns: 85 | (dict): The dictionary containing all the parameters""" 86 | return asdict(self) 87 | 88 | def to_json_string(self): 89 | """Transforms the object into a JSON string 90 | 91 | Returns: 92 | (str): The JSON string containing all the parameters""" 93 | return json.dumps(self.to_dict()) 94 | 95 | def save_json(self, file_path: str): 96 | """Saves a ``.json`` file from the dataclass 97 | 98 | Args: 99 | file_path (str): path to the file 100 | """ 101 | with open(file_path, "w", encoding="utf-8") as fp: 102 | fp.write(self.to_json_string()) 103 | 104 | def save_yaml(self, file_path: str): 105 | """Saves a ``.yaml`` file from the dataclass 106 | 107 | Args: 108 | file_path (str): path to the file 109 | """ 110 | with open(file_path, "w", encoding="utf-8") as fp: 111 | yaml.dump(self.to_dict(), fp) 112 | 113 | @classmethod 114 | def from_yaml(cls, yaml_path: str) -> "BaseConfig": 115 | """Creates a BaseConfig instance from a YAML config file 116 | 117 | Args: 118 | yaml_path (str): The path to the yaml file containing all the parameters 119 | 120 | Returns: 121 | :class:`BaseConfig`: The created instance 122 | """ 123 | with open(yaml_path, "r") as f: 124 | try: 125 | config_dict = safe_load(f) 126 | except yaml.YAMLError as e: 127 | raise yaml.YAMLError( 128 | f"File {yaml_path} could not be loaded as YAML.\n" 129 | f"Caught exception {type(e)} with message: " + str(e) 130 | ) from e 131 | 132 | config_name = config_dict.pop("name") 133 | 134 | if cls.__name__ != config_name: 135 | warnings.warn( 136 | f"You are trying to load a " 137 | f"`{ cls.__name__}` while a " 138 | f"`{config_name}` is given."
139 | ) 140 | 141 | return cls.from_dict(config_dict) 142 | -------------------------------------------------------------------------------- /src/lbm/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains a collection of data-related classes and functions used to train the :mod:`lbm.models`. 3 | In a training loop, a batch of data is structured as a dictionary on which the modules :mod:`lbm.data.datasets` 4 | and :mod:`lbm.data.filters` perform several operations. 5 | 6 | 7 | Examples 8 | ######## 9 | 10 | Create a DataModule to train a model 11 | 12 | .. code-block:: python 13 | 14 | from lbm.data.datasets import DataModule, DataModuleConfig 15 | from lbm.data.filters import KeyFilter, KeyFilterConfig 16 | from lbm.data.mappers import KeyRenameMapper, KeyRenameMapperConfig 17 | 18 | # Create the filters and mappers 19 | filters_mappers = [ 20 | KeyFilter(KeyFilterConfig(keys=["jpg", "txt"])), 21 | KeyRenameMapper( 22 | KeyRenameMapperConfig(key_map={"jpg": "image", "txt": "text"}) 23 | ) 24 | ] 25 | 26 | # Create the DataModule 27 | data_module = DataModule( 28 | train_config=DataModuleConfig( 29 | shards_path_or_urls="your urls or paths", 30 | decoder="pil", 31 | shuffle_before_filter_mappers_buffer_size=100, 32 | per_worker_batch_size=32, 33 | num_workers=4, 34 | ), 35 | train_filters_mappers=filters_mappers, 36 | eval_config=DataModuleConfig( 37 | shards_path_or_urls="your urls or paths", 38 | decoder="pil", 39 | shuffle_before_filter_mappers_buffer_size=100, 40 | per_worker_batch_size=32, 41 | num_workers=4, 42 | ), 43 | eval_filters_mappers=filters_mappers, 44 | ) 45 | 46 | # This can then be passed to a :class:`pytorch_lightning.Trainer` to train a model 47 | 48 | 49 | 50 | 51 | 52 | The :mod:`lbm.data` module includes the following submodules: 53 | 54 | - :mod:`lbm.data.datasets`: a collection of :mod:`pytorch_lightning.LightningDataModule` used to train the models. In particular,
In particular, 55 | they can used to create the dataloaders and setup the data pipelines. 56 | - :mod:`cr.data.filters`: a collection of filters used apply filters on a training batch of data/ 57 | 58 | """ 59 | 60 | from .datasets import DataModule 61 | 62 | __all__ = ["DataModule"] 63 | -------------------------------------------------------------------------------- /src/lbm/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A collection of :mod:`pytorch_lightning.LightningDataModule` used to train the models. In particular, 3 | they can be used to create the dataloaders and setup the data pipelines. 4 | """ 5 | 6 | from .dataset import DataModule 7 | from .datasets_config import DataModuleConfig 8 | 9 | __all__ = ["DataModule", "DataModuleConfig"] 10 | -------------------------------------------------------------------------------- /src/lbm/data/datasets/collation_fn.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Union 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def custom_collation_fn( 8 | samples: List[Dict[str, Union[int, float, np.ndarray, torch.Tensor]]], 9 | combine_tensors: bool = True, 10 | combine_scalars: bool = True, 11 | ) -> dict: 12 | """ 13 | Collate function for PyTorch DataLoader. 14 | 15 | Args: 16 | samples(List[Dict[str, Union[int, float, np.ndarray, torch.Tensor]]]): List of samples. 17 | combine_tensors (bool): Whether to turn lists of tensors into a single tensor. 18 | combine_scalars (bool): Whether to turn lists of scalars into a single ndarray. 
19 | """ 20 | keys = set.intersection(*[set(sample.keys()) for sample in samples]) 21 | # gather the values of each shared key across all samples 22 | batched = {key: [sample[key] for sample in samples] for key in keys} 23 | 24 | result = {} 25 | for key, values in batched.items(): 26 | if combine_scalars and isinstance(values[0], (int, float)): 27 | result[key] = np.array(values) 28 | elif combine_tensors and isinstance(values[0], torch.Tensor): 29 | result[key] = torch.stack(values) 30 | elif combine_tensors and isinstance(values[0], np.ndarray): 31 | result[key] = np.array(values) 32 | else: 33 | # values that are not combined are kept as a plain list instead of being dropped 34 | result[key] = values 35 | 36 | del samples 37 | del batched 38 | return result 39 | -------------------------------------------------------------------------------- /src/lbm/data/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Union 2 | 3 | import pytorch_lightning as pl 4 | import webdataset as wds 5 | from webdataset import DataPipeline 6 | 7 | from ..filters import BaseFilter, FilterWrapper 8 | from ..mappers import BaseMapper, MapperWrapper 9 | from .collation_fn import custom_collation_fn 10 | from .datasets_config import DataModuleConfig 11 | 12 | 13 | class DataPipeline: 14 | """ 15 | DataPipeline class for creating a dataloader from a single configuration 16 | 17 | Args: 18 | 19 | config (DataModuleConfig): 20 | Configuration for the dataset 21 | 22 | filters_mappers (List[Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper]]): 23 | List of filters and mappers for the dataset. These will be sequentially applied. 24 | 25 | batched_filters_mappers (List[Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper]]): 26 | List of batched transforms for the dataset. These will be sequentially applied.
27 | """ 28 | 29 | def __init__( 30 | self, 31 | config: DataModuleConfig, 32 | filters_mappers: List[ 33 | Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper] 34 | ], 35 | batched_filters_mappers: List[ 36 | Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper] 37 | ] = None, 38 | ): 39 | self.config = config 40 | self.shards_path_or_urls = config.shards_path_or_urls 41 | self.filters_mappers = filters_mappers or [] 42 | self.batched_filters_mappers = batched_filters_mappers or [] 43 | 44 | if filters_mappers is None: 45 | filters_mappers = [] 46 | 47 | # set processing pipeline 48 | self.processing_pipeline = [wds.decode(config.decoder, handler=config.handler)] 49 | self.processing_pipeline.extend( 50 | self._add_filters_mappers( 51 | filters_mappers=filters_mappers, 52 | handler=config.handler, 53 | ) 54 | ) 55 | 56 | def _add_filters_mappers( 57 | self, 58 | filters_mappers: List[ 59 | Union[ 60 | FilterWrapper, 61 | MapperWrapper, 62 | ] 63 | ], 64 | handler: Callable = wds.warn_and_continue, 65 | ) -> List[Union[FilterWrapper, MapperWrapper]]: 66 | tmp_pipeline = [] 67 | for filter_mapper in filters_mappers: 68 | if isinstance(filter_mapper, FilterWrapper) or isinstance( 69 | filter_mapper, BaseFilter 70 | ): 71 | tmp_pipeline.append(wds.select(filter_mapper)) 72 | elif isinstance(filter_mapper, MapperWrapper) or isinstance( 73 | filter_mapper, BaseMapper 74 | ): 75 | tmp_pipeline.append(wds.map(filter_mapper, handler=handler)) 76 | elif callable(filter_mapper): # any other callable is applied as a mapper 77 | tmp_pipeline.append(wds.map(filter_mapper, handler=handler)) 78 | else: 79 | raise ValueError(f"Unknown type of filter/mapper: {type(filter_mapper)}") 80 | return tmp_pipeline 81 | 82 | def setup(self): 83 | pipeline = [wds.SimpleShardList(self.shards_path_or_urls)] 84 | 85 | # shuffle before split by node 86 | if self.config.shuffle_before_split_by_node_buffer_size is not None: 87 | pipeline.append( 88 | wds.shuffle( 89 | 
self.config.shuffle_before_split_by_node_buffer_size, 90 |
handler=self.config.handler, 91 | ) 92 | ) 93 | # split by node 94 | pipeline.append(wds.split_by_node) 95 | 96 | # shuffle before split by workers 97 | if self.config.shuffle_before_split_by_workers_buffer_size is not None: 98 | pipeline.append( 99 | wds.shuffle( 100 | self.config.shuffle_before_split_by_workers_buffer_size, 101 | handler=self.config.handler, 102 | ) 103 | ) 104 | # split by worker 105 | pipeline.extend( 106 | [ 107 | wds.split_by_worker, 108 | wds.tarfile_to_samples( 109 | handler=self.config.handler, 110 | rename_files=self.config.rename_files_fn, 111 | ), 112 | ] 113 | ) 114 | 115 | # shuffle before filter mappers 116 | if self.config.shuffle_before_filter_mappers_buffer_size is not None: 117 | pipeline.append( 118 | wds.shuffle( 119 | self.config.shuffle_before_filter_mappers_buffer_size, 120 | handler=self.config.handler, 121 | ) 122 | ) 123 | 124 | # apply filters and mappers 125 | pipeline.extend(self.processing_pipeline) 126 | 127 | # shuffle after filter mappers 128 | if self.config.shuffle_after_filter_mappers_buffer_size is not None: 129 | pipeline.append( 130 | wds.shuffle( 131 | self.config.shuffle_after_filter_mappers_buffer_size, 132 | handler=self.config.handler, 133 | ), 134 | ) 135 | 136 | # batching 137 | pipeline.append( 138 | wds.batched( 139 | self.config.per_worker_batch_size, 140 | collation_fn=custom_collation_fn, 141 | ) 142 | ) 143 | 144 | # apply batched transforms 145 | pipeline.extend( 146 | self._add_filters_mappers( 147 | filters_mappers=self.batched_filters_mappers, 148 | handler=self.config.handler, 149 | ) 150 | ) 151 | 152 | # create the data pipeline 153 | pipeline = wds.DataPipeline(*pipeline, handler=self.config.handler) 154 | 155 | # set the pipeline 156 | self.pipeline = pipeline 157 | 158 | def dataloader(self): 159 | # return the loader 160 | return wds.WebLoader( 161 | self.pipeline, 162 | batch_size=None, 163 | num_workers=self.config.num_workers, 164 | ) 165 | 166 | 167 | class 
DataModule(pl.LightningDataModule): 168 | """ 169 | Main DataModule class for creating the training and evaluation data loaders 170 | 171 | Args: 172 | 173 | train_config (DataModuleConfig): 174 | Configuration for the training dataset 175 | 176 | train_filters_mappers (List[Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper]]): 177 | List of filters and mappers for the training dataset. These will be sequentially applied. 178 | 179 | train_batched_filters_mappers (List[Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper]]): 180 | List of batched transforms for the training dataset. These will be sequentially applied. 181 | 182 | eval_config (DataModuleConfig): 183 | Configuration for the evaluation dataset 184 | 185 | eval_filters_mappers (List[Union[FilterWrapper, MapperWrapper]]): 186 | List of filters and mappers for the evaluation dataset. These will be sequentially applied. 187 | 188 | eval_batched_filters_mappers (List[Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper]]): 189 | List of batched transforms for the evaluation dataset. These will be sequentially applied.
190 | """ 191 | 192 | def __init__( 193 | self, 194 | train_config: DataModuleConfig, 195 | train_filters_mappers: List[ 196 | Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper] 197 | ] = None, 198 | train_batched_filters_mappers: List[ 199 | Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper] 200 | ] = None, 201 | eval_config: DataModuleConfig = None, 202 | eval_filters_mappers: List[Union[FilterWrapper, MapperWrapper]] = None, 203 | eval_batched_filters_mappers: List[ 204 | Union[BaseMapper, BaseFilter, FilterWrapper, MapperWrapper] 205 | ] = None, 206 | ): 207 | super().__init__() 208 | 209 | self.train_config = train_config 210 | self.train_filters_mappers = train_filters_mappers 211 | self.train_batched_filters_mappers = train_batched_filters_mappers 212 | 213 | self.eval_config = eval_config 214 | self.eval_filters_mappers = eval_filters_mappers 215 | self.eval_batched_filters_mappers = eval_batched_filters_mappers 216 | 217 | def setup(self, stage=None): 218 | """ 219 | Setup the data module and create the webdataset processing pipelines 220 | """ 221 | 222 | # train pipeline 223 | self.train_pipeline = DataPipeline( 224 | config=self.train_config, 225 | filters_mappers=self.train_filters_mappers, 226 | batched_filters_mappers=self.train_batched_filters_mappers, 227 | ) 228 | self.train_pipeline.setup() 229 | 230 | # eval pipeline 231 | if self.eval_config is not None: 232 | self.eval_pipeline = DataPipeline( 233 | config=self.eval_config, 234 | filters_mappers=self.eval_filters_mappers, 235 | batched_filters_mappers=self.eval_batched_filters_mappers, 236 | ) 237 | self.eval_pipeline.setup() 238 | 239 | def train_dataloader(self): 240 | return self.train_pipeline.dataloader() 241 | 242 | def val_dataloader(self): 243 | return self.eval_pipeline.dataloader() 244 | -------------------------------------------------------------------------------- /src/lbm/data/datasets/datasets_config.py: 
-------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional, Union 2 | 3 | import webdataset as wds 4 | from pydantic.dataclasses import dataclass 5 | 6 | from ...config import BaseConfig 7 | 8 | 9 | @dataclass 10 | class DataModuleConfig(BaseConfig): 11 | """ 12 | Configuration for the DataModule 13 | 14 | Args: 15 | 16 | shards_path_or_urls (Union[str, List[str]]): The path(s) or URL(s) of the shards. Defaults to None. 17 | per_worker_batch_size (int): The batch size produced by each data-loading worker. Defaults to 16. 18 | num_workers (int): The number of data-loading workers. Defaults to 1. 19 | shuffle_before_split_by_node_buffer_size (Optional[int]): Size of the shard-level shuffle buffer applied before splitting shards across nodes (None disables it). Defaults to 100. 20 | shuffle_before_split_by_workers_buffer_size (Optional[int]): Size of the shard-level shuffle buffer applied before splitting shards across workers (None disables it). Defaults to 100. 21 | shuffle_before_filter_mappers_buffer_size (Optional[int]): Size of the sample-level shuffle buffer applied before the filters and mappers (None disables it). Defaults to 1000. 22 | shuffle_after_filter_mappers_buffer_size (Optional[int]): Size of the sample-level shuffle buffer applied after the filters and mappers (None disables it). Defaults to 1000. 23 | decoder (str): The webdataset decoder to use. Defaults to "pil". 24 | handler (Callable): The exception handler used by the webdataset pipeline stages. Defaults to wds.warn_and_continue. 25 | rename_files_fn (Optional[Callable[[str], str]]): A callable used to rename the files read from the shards. Defaults to None.
26 | """ 27 | 28 | shards_path_or_urls: Union[str, List[str]] = None 29 | per_worker_batch_size: int = 16 30 | num_workers: int = 1 31 | shuffle_before_split_by_node_buffer_size: Optional[int] = 100 32 | shuffle_before_split_by_workers_buffer_size: Optional[int] = 100 33 | shuffle_before_filter_mappers_buffer_size: Optional[int] = 1000 34 | shuffle_after_filter_mappers_buffer_size: Optional[int] = 1000 35 | decoder: str = "pil" 36 | handler: Callable = wds.warn_and_continue 37 | rename_files_fn: Optional[Callable[[str], str]] = None 38 | 39 | def __post_init__(self): 40 | super().__post_init__() 41 | if self.rename_files_fn is not None and not callable(self.rename_files_fn): 42 | raise TypeError("rename_files_fn must be a callable") 43 | -------------------------------------------------------------------------------- /src/lbm/data/filters/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseFilter 2 | from .filter_wrapper import FilterWrapper 3 | from .filters import KeyFilter 4 | from .filters_config import BaseFilterConfig, KeyFilterConfig 5 | 6 | __all__ = [ 7 | "BaseFilter", 8 | "FilterWrapper", 9 | "KeyFilter", 10 | "BaseFilterConfig", 11 | "KeyFilterConfig", 12 | ] 13 | -------------------------------------------------------------------------------- /src/lbm/data/filters/base.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from .filters_config import BaseFilterConfig 4 | 5 | 6 | class BaseFilter: 7 | """ 8 | Base class for filters. This class should be subclassed to create a new filter.
9 | 10 | Args: 11 | 12 | config (BaseFilterConfig): 13 | Configuration for the filter 14 | """ 15 | 16 | def __init__(self, config: BaseFilterConfig): 17 | self.verbose = config.verbose 18 | 19 | def __call__(self, sample: Dict[str, Any]) -> bool: 20 | """This function should be implemented by the subclass""" 21 | raise NotImplementedError 22 | -------------------------------------------------------------------------------- /src/lbm/data/filters/filter_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Union 2 | 3 | from .base import BaseFilter 4 | 5 | 6 | class FilterWrapper: 7 | """ 8 | Wrapper for multiple filters. This class applies several filters to a batch of data. 9 | The filters are applied in the order they are passed to the wrapper, and evaluation stops at the first filter that rejects the data. 10 | 11 | Args: 12 | 13 | filters (List[BaseFilter]): 14 | List of filters to apply to the batch of data 15 | """ 16 | 17 | def __init__( 18 | self, 19 | filters: Union[List[BaseFilter], None] = None, 20 | ): 21 | self.filters = filters or [] 22 | 23 | def __call__(self, batch: Dict[str, Any]) -> bool: 24 | """ 25 | Apply all filters in order 26 | 27 | Args: 28 | 29 | batch: batch of data 30 | 31 | Returns: 32 | 33 | bool: True if every filter accepts the batch, False otherwise 34 | """ 35 | for filter_fn in self.filters: 36 | if not filter_fn(batch): 37 | return False 38 | return True 39 | -------------------------------------------------------------------------------- /src/lbm/data/filters/filters.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .base import BaseFilter 4 | from .filters_config import KeyFilterConfig 5 | 6 | logging.basicConfig(level=logging.INFO) 7 | 8 | 9 | class KeyFilter(BaseFilter): 10 | """ 11 | This filter checks if ALL the given keys are present in the sample 12 | 13 | Args: 14 | 15 | config (KeyFilterConfig): configuration for the filter 16 | """ 17 | 18 | def __init__(self,
config: KeyFilterConfig): 19 | super().__init__(config) 20 | keys = config.keys 21 | if isinstance(keys, str): 22 | keys = [keys] 23 | 24 | self.keys = set(keys) 25 | 26 | def __call__(self, batch: dict) -> bool: 27 | try: 28 | res = self.keys.issubset(set(batch.keys())) 29 | return res 30 | except Exception as e: 31 | if self.verbose: 32 | logging.error(f"Error in KeyFilter: {e}") 33 | return False 34 | -------------------------------------------------------------------------------- /src/lbm/data/filters/filters_config.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | from pydantic.dataclasses import dataclass 4 | 5 | from ...config import BaseConfig 6 | 7 | 8 | @dataclass 9 | class BaseFilterConfig(BaseConfig): 10 | """ 11 | Base configuration for filters 12 | 13 | Args: 14 | 15 | verbose (bool): 16 | If True, print debug information. Defaults to False""" 17 | 18 | verbose: bool = False 19 | 20 | 21 | @dataclass 22 | class KeyFilterConfig(BaseFilterConfig): 23 | """ 24 | This filter checks if the keys are present in a sample. 25 | 26 | Args: 27 | 28 | keys (Union[str, List[str]]): 29 | Key or list of keys to check. 
Defaults to "txt" 30 | """ 31 | 32 | keys: Union[str, List[str]] = "txt" 33 | -------------------------------------------------------------------------------- /src/lbm/data/mappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseMapper 2 | from .mappers import KeyRenameMapper, RescaleMapper, TorchvisionMapper 3 | from .mappers_config import ( 4 | KeyRenameMapperConfig, 5 | RescaleMapperConfig, 6 | TorchvisionMapperConfig, 7 | ) 8 | from .mappers_wrapper import MapperWrapper 9 | 10 | __all__ = [ 11 | "BaseMapper", 12 | "KeyRenameMapper", 13 | "RescaleMapper", 14 | "TorchvisionMapper", 15 | "KeyRenameMapperConfig", 16 | "RescaleMapperConfig", 17 | "TorchvisionMapperConfig", 18 | "MapperWrapper", 19 | ] 20 | -------------------------------------------------------------------------------- /src/lbm/data/mappers/base.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from .mappers_config import BaseMapperConfig 4 | 5 | 6 | class BaseMapper: 7 | """ 8 | Base class for the mappers used to modify the samples in the data pipeline. 9 | 10 | Args: 11 | 12 | config (BaseMapperConfig): 13 | Configuration for the mapper. 
14 | """ 15 | 16 | def __init__(self, config: BaseMapperConfig): 17 | self.config = config 18 | self.key = config.key 19 | 20 | if config.output_key is None: 21 | self.output_key = config.key 22 | else: 23 | self.output_key = config.output_key 24 | 25 | def map(self, batch: Dict[str, Any], *args, **kwargs) -> Dict[str, Any]: 26 | raise NotImplementedError 27 | -------------------------------------------------------------------------------- /src/lbm/data/mappers/mappers.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from torchvision import transforms 4 | 5 | from .base import BaseMapper 6 | from .mappers_config import ( 7 | KeyRenameMapperConfig, 8 | RescaleMapperConfig, 9 | TorchvisionMapperConfig, 10 | ) 11 | 12 | 13 | class KeyRenameMapper(BaseMapper): 14 | """ 15 | Rename keys in a sample according to a key map 16 | 17 | Args: 18 | 19 | config (KeyRenameMapperConfig): Configuration for the mapper 20 | 21 | Examples 22 | ######## 23 | 24 | 1. Rename keys in a sample according to a key map 25 | 26 | .. code-block:: python 27 | 28 | from lbm.data.mappers import KeyRenameMapper, KeyRenameMapperConfig 29 | 30 | config = KeyRenameMapperConfig( 31 | key_map={"old_key": "new_key"} 32 | ) 33 | 34 | mapper = KeyRenameMapper(config) 35 | 36 | sample = {"old_key": 1} 37 | new_sample = mapper(sample) 38 | print(new_sample) # {"new_key": 1} 39 | 40 | 2. Rename keys in a sample according to a key map and a condition key 41 | 42 | ..
code-block:: python 43 | 44 | from lbm.data.mappers import KeyRenameMapper, KeyRenameMapperConfig 45 | 46 | config = KeyRenameMapperConfig( 47 | key_map={"old_key": "new_key"}, 48 | condition_key="condition", 49 | condition_fn=lambda x: x == 1 50 | ) 51 | 52 | mapper = KeyRenameMapper(config) 53 | 54 | sample = {"old_key": 1, "condition": 1} 55 | new_sample = mapper(sample) 56 | print(new_sample) # {"condition": 1, "new_key": 1} 57 | 58 | sample = {"old_key": 1, "condition": 0} 59 | new_sample = mapper(sample) 60 | print(new_sample) # {"old_key": 1, "condition": 0} 61 | 62 | 63 | """ 64 | 65 | def __init__(self, config: KeyRenameMapperConfig): 66 | super().__init__(config) 67 | self.key_map = config.key_map 68 | self.condition_key = config.condition_key 69 | self.condition_fn = config.condition_fn 70 | self.else_key_map = config.else_key_map 71 | 72 | def __call__(self, batch: Dict[str, Any], *args, **kwargs): 73 | if self.condition_key is not None: 74 | condition_value = batch[self.condition_key] 75 | if self.condition_fn(condition_value): 76 | for old_key, new_key in self.key_map.items(): 77 | if old_key in batch: 78 | batch[new_key] = batch.pop(old_key) 79 | 80 | elif self.else_key_map is not None: 81 | for old_key, new_key in self.else_key_map.items(): 82 | if old_key in batch: 83 | batch[new_key] = batch.pop(old_key) 84 | 85 | else: 86 | for old_key, new_key in self.key_map.items(): 87 | if old_key in batch: 88 | batch[new_key] = batch.pop(old_key) 89 | return batch 90 | 91 | 92 | class TorchvisionMapper(BaseMapper): 93 | """ 94 | Apply torchvision transforms to a sample 95 | 96 | Args: 97 | 98 | config (TorchvisionMapperConfig): Configuration for the mapper 99 | """ 100 | 101 | def __init__(self, config: TorchvisionMapperConfig): 102 | super().__init__(config) 103 | chained_transforms = [] 104 | for transform, kwargs in zip(config.transforms, config.transforms_kwargs): 105 | transform = getattr(transforms, transform) 106 | chained_transforms.append(transform(**kwargs)) 107 | self.transforms
= transforms.Compose(chained_transforms) 108 | 109 | def __call__(self, batch: Dict[str, Any], *args, **kwargs) -> Dict[str, Any]: 110 | if self.key in batch: 111 | batch[self.output_key] = self.transforms(batch[self.key]) 112 | return batch 113 | 114 | 115 | class RescaleMapper(BaseMapper): 116 | """ 117 | Rescale a sample from [0, 1] to [-1, 1] 118 | 119 | Args: 120 | 121 | config (RescaleMapperConfig): Configuration for the mapper 122 | """ 123 | 124 | def __init__(self, config: RescaleMapperConfig): 125 | super().__init__(config) 126 | 127 | def __call__(self, batch: Dict[str, Any], *args, **kwargs) -> Dict[str, Any]: 128 | if isinstance(batch[self.key], list): 129 | tmp = [] 130 | for image in batch[self.key]: 131 | tmp.append(2 * image - 1) 132 | batch[self.output_key] = tmp 133 | else: 134 | batch[self.output_key] = 2 * batch[self.key] - 1 135 | return batch 136 | -------------------------------------------------------------------------------- /src/lbm/data/mappers/mappers_config.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional 2 | 3 | from pydantic.dataclasses import dataclass 4 | 5 | from ...config import BaseConfig 6 | 7 | 8 | @dataclass 9 | class BaseMapperConfig(BaseConfig): 10 | """ 11 | Base configuration for mappers. 12 | 13 | Args: 14 | 15 | verbose (bool): 16 | If True, print debug information. Defaults to False 17 | 18 | key (Optional[str]): 19 | Key to apply the mapper to. Defaults to None 20 | 21 | output_key (Optional[str]): 22 | Key to store the output of the mapper.
Defaults to None 23 | """ 24 | 25 | verbose: bool = False 26 | key: Optional[str] = None 27 | output_key: Optional[str] = None 28 | 29 | 30 | @dataclass 31 | class KeyRenameMapperConfig(BaseMapperConfig): 32 | """ 33 | Rename keys in a sample according to a key map 34 | 35 | Args: 36 | 37 | key_map (Dict[str, str]): Dictionary with the old keys as keys and the new keys as values 38 | condition_key (Optional[str]): Key whose value is passed to condition_fn. Defaults to None 39 | condition_fn (Optional[Callable[[Any], bool]]): Predicate applied to the value of condition_key; the key map 40 | is applied only when it returns True. Defaults to None. 41 | else_key_map (Optional[Dict[str, str]]): Dictionary with the old keys as keys and the new keys as values, 42 | used when the condition is not met. Defaults to None, *i.e.* the keys are left unchanged. 43 | """ 44 | 45 | key_map: Optional[Dict[str, str]] = None 46 | condition_key: Optional[str] = None 47 | condition_fn: Optional[Callable[[Any], bool]] = None 48 | else_key_map: Optional[Dict[str, str]] = None 49 | 50 | def __post_init__(self): 51 | super().__post_init__() 52 | assert self.key_map is not None, "key_map should be provided" 53 | assert all( 54 | isinstance(old_key, str) and isinstance(new_key, str) 55 | for old_key, new_key in self.key_map.items() 56 | ), "key_map should be a dictionary with string keys and values" 57 | if self.condition_key is not None: 58 | assert self.condition_fn is not None, "condition_fn should be provided" 59 | assert callable(self.condition_fn), "condition_fn should be callable" 60 | if self.condition_fn is not None: 61 | assert self.condition_key is not None, "condition_key should be provided" 62 | assert isinstance( 63 | self.condition_key, str 64 | ), "condition_key should be a string" 65 | if self.else_key_map is not None: 66 | assert all( 67 | isinstance(old_key, str) and isinstance(new_key, str) 68 | for old_key, new_key in self.else_key_map.items() 69 | ), "else_key_map should be a dictionary with string keys and values" 70 | 71 | 72
| @dataclass 73 | class TorchvisionMapperConfig(BaseMapperConfig): 74 | """ 75 | Apply torchvision transforms to a sample 76 | 77 | Args: 78 | 79 | key (str): Key to apply the transforms to 80 | transforms (List[str]): Names of the torchvision.transforms classes to apply, in order (e.g. "ToTensor") 81 | transforms_kwargs (List[Dict[str, Any]]): Keyword arguments for each transform, in the same order as transforms 82 | """ 83 | 84 | key: str = "image" 85 | transforms: Optional[List[str]] = None 86 | transforms_kwargs: Optional[List[Dict[str, Any]]] = None 87 | 88 | def __post_init__(self): 89 | super().__post_init__() 90 | if self.transforms is None: 91 | self.transforms = [] 92 | if self.transforms_kwargs is None: 93 | self.transforms_kwargs = [] 94 | assert len(self.transforms) == len( 95 | self.transforms_kwargs 96 | ), "Number of transforms and kwargs should be the same" 97 | 98 | 99 | @dataclass 100 | class RescaleMapperConfig(BaseMapperConfig): 101 | """ 102 | Rescale a sample from [0, 1] to [-1, 1] 103 | 104 | Args: 105 | 106 | key (str): Key to rescale 107 | """ 108 | 109 | key: str = "image" 110 | -------------------------------------------------------------------------------- /src/lbm/data/mappers/mappers_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Union 2 | 3 | from .base import BaseMapper 4 | 5 | 6 | class MapperWrapper: 7 | """ 8 | Wrapper that applies several mappers to a batch in sequence.
9 | 10 | Args: 11 | 12 | mappers (Union[List[BaseMapper], None]): List of mappers to apply to the batch, in order. Defaults to None, *i.e.* the batch is returned unchanged 13 | """ 14 | 15 | def __init__( 16 | self, 17 | mappers: Union[List[BaseMapper], None] = None, 18 | ): 19 | self.mappers = mappers if mappers is not None else [] 20 | 21 | def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]: 22 | """ 23 | Forward pass through all mappers 24 | 25 | Args: 26 | 27 | batch (Dict[str, Any]): batch of data 28 | """ 29 | for mapper in self.mappers: 30 | batch = mapper(batch) 31 | return batch 32 | -------------------------------------------------------------------------------- /src/lbm/inference/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference import evaluate 2 | from .utils import get_model 3 | 4 | __all__ = ["evaluate", "get_model"] 5 | -------------------------------------------------------------------------------- /src/lbm/inference/inference.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import PIL.Image 4 | import torch 5 | from torchvision.transforms import ToPILImage, ToTensor 6 | 7 | from lbm.models.lbm import LBMModel 8 | 9 | logging.basicConfig(level=logging.INFO) 10 | logger = logging.getLogger(__name__) 11 | # Supported resolutions, keyed by str(width / height) and stored as (width, height) 12 | ASPECT_RATIOS = { 13 | str(512 / 2048): (512, 2048), 14 | str(1024 / 1024): (1024, 1024), 15 | str(2048 / 512): (2048, 512), 16 | str(896 / 1152): (896, 1152), 17 | str(1152 / 896): (1152, 896), 18 | str(512 / 1920): (512, 1920), 19 | str(640 / 1536): (640, 1536), 20 | str(768 / 1280): (768, 1280), 21 | str(1280 / 768): (1280, 768), 22 | str(1536 / 640): (1536, 640), 23 | str(1920 / 512): (1920, 512), 24 | } 25 | 26 | 27 | @torch.no_grad() 28 | def evaluate( 29 | model: LBMModel, 30 | source_image: PIL.Image.Image, 31 | num_sampling_steps: int = 1, 32 | ): 33 | """ 34 | Evaluate the model on an image coming from the source distribution and generate a new image from the target distribution.
35 | 36 | Args: 37 | model (LBMModel): The model to evaluate. 38 | source_image (PIL.Image.Image): The source image to evaluate the model on. 39 | num_sampling_steps (int): The number of sampling steps to use for the model. 40 | 41 | Returns: 42 | PIL.Image.Image: The generated image, resized back to the size of the source image. 43 | """ 44 | # PIL reports (width, height); snap to the closest supported aspect ratio 45 | ori_w_bg, ori_h_bg = source_image.size 46 | ar_bg = ori_w_bg / ori_h_bg 47 | closest_ar_bg = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - ar_bg)) 48 | source_dimensions = ASPECT_RATIOS[closest_ar_bg] 49 | 50 | source_image = source_image.resize(source_dimensions) 51 | 52 | img_pasted_tensor = ToTensor()(source_image).unsqueeze(0) * 2 - 1 53 | batch = { 54 | "source_image": img_pasted_tensor.cuda().to(torch.bfloat16), 55 | } 56 | 57 | z_source = model.vae.encode(batch[model.source_key]) 58 | 59 | output_image = model.sample( 60 | z=z_source, 61 | num_steps=num_sampling_steps, 62 | conditioner_inputs=batch, 63 | max_samples=1, 64 | ).clamp(-1, 1) 65 | 66 | output_image = (output_image[0].float().cpu() + 1) / 2 67 | output_image = ToPILImage()(output_image) 68 | output_image = output_image.resize((ori_w_bg, ori_h_bg)) 69 | 70 | return output_image 71 | -------------------------------------------------------------------------------- /src/lbm/inference/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import List, Optional 4 | 5 | import torch 6 | import yaml 7 | from diffusers import FlowMatchEulerDiscreteScheduler 8 | from huggingface_hub import snapshot_download 9 | from safetensors.torch import load_file 10 | 11 | from lbm.models.embedders import ( 12 | ConditionerWrapper, 13 | LatentsConcatEmbedder, 14 | LatentsConcatEmbedderConfig, 15 | ) 16 | from lbm.models.lbm import LBMConfig, LBMModel 17 | from lbm.models.unets import DiffusersUNet2DCondWrapper 18 | from lbm.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig 19 | 20 | 21 | def get_model( 22 | model_dir: str, 23 |
save_dir: Optional[str] = None, 24 | torch_dtype: torch.dtype = torch.bfloat16, 25 | device: str = "cuda", 26 | ) -> LBMModel: 27 | """Load the model from a local directory, downloading it from the HuggingFace Hub first if the path does not exist locally 28 | 29 | Args: 30 | model_dir (str): The path to the model directory containing the model weights and config; can be a local path or a HuggingFace Hub repository id 31 | save_dir (Optional[str]): The local path to save the model if downloading from HuggingFace Hub. Defaults to None. 32 | torch_dtype (torch.dtype): The torch dtype to use for the model. Defaults to torch.bfloat16. 33 | device (str): The device to use for the model. Defaults to "cuda". 34 | 35 | Returns: 36 | LBMModel: The loaded model 37 | """ 38 | if not os.path.exists(model_dir): 39 | local_dir = snapshot_download( 40 | model_dir, 41 | local_dir=save_dir, 42 | ) 43 | model_dir = local_dir 44 | 45 | model_files = os.listdir(model_dir) 46 | 47 | if len(model_files) == 0: 48 | raise ValueError("No model files found in the model directory") 49 | 50 | # check yaml config file is present 51 | yaml_file = [f for f in model_files if f.endswith(".yaml")] 52 | if len(yaml_file) == 0: 53 | raise ValueError("No yaml file found in the model directory.") 54 | 55 | # check a safetensors or ckpt weights file is present 56 | safetensors_files = sorted([f for f in model_files if f.endswith(".safetensors")]) 57 | ckpt_files = sorted([f for f in model_files if f.endswith(".ckpt")]) 58 | if len(safetensors_files) == 0 and len(ckpt_files) == 0: 59 | raise ValueError("No safetensors or ckpt file found in the model directory") 60 | 61 | with open(os.path.join(model_dir, yaml_file[0]), "r") as f: 62 | config = yaml.safe_load(f) 63 | 64 | model = _get_model_from_config(**config, torch_dtype=torch_dtype) 65 | 66 | if len(safetensors_files) > 0: 67 | logging.info(f"Loading safetensors file: {safetensors_files[-1]}") 68 | sd = load_file(os.path.join(model_dir, safetensors_files[-1])) 69 |
model.load_state_dict(sd, strict=True) 70 | elif len(ckpt_files) > 0: 71 | logging.info(f"Loading ckpt file: {ckpt_files[-1]}") 72 | sd = torch.load( 73 | os.path.join(model_dir, ckpt_files[-1]), 74 | map_location="cpu", 75 | )["state_dict"] 76 | sd = {k[6:]: v for k, v in sd.items() if k.startswith("model.")} 77 | model.load_state_dict( 78 | sd, 79 | strict=True, 80 | ) 81 | model.to(device).to(torch_dtype) 82 | 83 | model.eval() 84 | 85 | return model 86 | 87 | 88 | def _get_model_from_config( 89 | backbone_signature: str = "stabilityai/stable-diffusion-xl-base-1.0", 90 | vae_num_channels: int = 4, 91 | unet_input_channels: int = 4, 92 | timestep_sampling: str = "log_normal", 93 | selected_timesteps: Optional[List[float]] = None, 94 | prob: Optional[List[float]] = None, 95 | conditioning_images_keys: Optional[List[str]] = [], 96 | conditioning_masks_keys: Optional[List[str]] = [], 97 | source_key: str = "source_image", 98 | target_key: str = "source_image_paste", 99 | bridge_noise_sigma: float = 0.0, 100 | logit_mean: float = 0.0, 101 | logit_std: float = 1.0, 102 | pixel_loss_type: str = "lpips", 103 | latent_loss_type: str = "l2", 104 | latent_loss_weight: float = 1.0, 105 | pixel_loss_weight: float = 0.0, 106 | torch_dtype: torch.dtype = torch.bfloat16, 107 | **kwargs, 108 | ): 109 | 110 | conditioners = [] 111 | 112 | denoiser = DiffusersUNet2DCondWrapper( 113 | in_channels=unet_input_channels, # Add downsampled_image 114 | out_channels=vae_num_channels, 115 | center_input_sample=False, 116 | flip_sin_to_cos=True, 117 | freq_shift=0, 118 | down_block_types=[ 119 | "DownBlock2D", 120 | "CrossAttnDownBlock2D", 121 | "CrossAttnDownBlock2D", 122 | ], 123 | mid_block_type="UNetMidBlock2DCrossAttn", 124 | up_block_types=["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"], 125 | only_cross_attention=False, 126 | block_out_channels=[320, 640, 1280], 127 | layers_per_block=2, 128 | downsample_padding=1, 129 | mid_block_scale_factor=1, 130 | dropout=0.0, 131 | 
act_fn="silu", 132 | norm_num_groups=32, 133 | norm_eps=1e-05, 134 | cross_attention_dim=[320, 640, 1280], 135 | transformer_layers_per_block=[1, 2, 10], 136 | reverse_transformer_layers_per_block=None, 137 | encoder_hid_dim=None, 138 | encoder_hid_dim_type=None, 139 | attention_head_dim=[5, 10, 20], 140 | num_attention_heads=None, 141 | dual_cross_attention=False, 142 | use_linear_projection=True, 143 | class_embed_type=None, 144 | addition_embed_type=None, 145 | addition_time_embed_dim=None, 146 | num_class_embeds=None, 147 | upcast_attention=None, 148 | resnet_time_scale_shift="default", 149 | resnet_skip_time_act=False, 150 | resnet_out_scale_factor=1.0, 151 | time_embedding_type="positional", 152 | time_embedding_dim=None, 153 | time_embedding_act_fn=None, 154 | timestep_post_act=None, 155 | time_cond_proj_dim=None, 156 | conv_in_kernel=3, 157 | conv_out_kernel=3, 158 | projection_class_embeddings_input_dim=None, 159 | attention_type="default", 160 | class_embeddings_concat=False, 161 | mid_block_only_cross_attention=None, 162 | cross_attention_norm=None, 163 | addition_embed_type_num_heads=64, 164 | ).to(torch_dtype) 165 | 166 | if conditioning_images_keys != [] or conditioning_masks_keys != []: 167 | 168 | latents_concat_embedder_config = LatentsConcatEmbedderConfig( 169 | image_keys=conditioning_images_keys, 170 | mask_keys=conditioning_masks_keys, 171 | ) 172 | latent_concat_embedder = LatentsConcatEmbedder(latents_concat_embedder_config) 173 | latent_concat_embedder.freeze() 174 | conditioners.append(latent_concat_embedder) 175 | 176 | # Wrap conditioners and set to device 177 | conditioner = ConditionerWrapper( 178 | conditioners=conditioners, 179 | ) 180 | 181 | ## VAE ## 182 | # Get VAE model 183 | vae_config = AutoencoderKLDiffusersConfig( 184 | version=backbone_signature, 185 | subfolder="vae", 186 | tiling_size=(128, 128), 187 | ) 188 | vae = AutoencoderKLDiffusers(vae_config).to(torch_dtype) 189 | vae.freeze() 190 | vae.to(torch_dtype) 191 | 192 | 
## Diffusion Model ## 193 | # Get diffusion model 194 | config = LBMConfig( 195 | source_key=source_key, 196 | target_key=target_key, 197 | latent_loss_weight=latent_loss_weight, 198 | latent_loss_type=latent_loss_type, 199 | pixel_loss_type=pixel_loss_type, 200 | pixel_loss_weight=pixel_loss_weight, 201 | timestep_sampling=timestep_sampling, 202 | logit_mean=logit_mean, 203 | logit_std=logit_std, 204 | selected_timesteps=selected_timesteps, 205 | prob=prob, 206 | bridge_noise_sigma=bridge_noise_sigma, 207 | ) 208 | 209 | sampling_noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( 210 | backbone_signature, 211 | subfolder="scheduler", 212 | ) 213 | 214 | model = LBMModel( 215 | config, 216 | denoiser=denoiser, 217 | sampling_noise_scheduler=sampling_noise_scheduler, 218 | vae=vae, 219 | conditioner=conditioner, 220 | ).to(torch_dtype) 221 | 222 | return model 223 | -------------------------------------------------------------------------------- /src/lbm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gojasper/LBM/b41e14461c90924b83022a95abd05379abba5c4e/src/lbm/models/__init__.py -------------------------------------------------------------------------------- /src/lbm/models/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from .model_config import ModelConfig 3 | 4 | __all__ = ["BaseModel", "ModelConfig"] 5 | -------------------------------------------------------------------------------- /src/lbm/models/base/base_model.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .model_config import ModelConfig 7 | 8 | 9 | class BaseModel(nn.Module): 10 | def __init__(self, config: ModelConfig): 11 | nn.Module.__init__(self) 12 | self.config = config 13 | 
self.input_key = config.input_key 14 | self.device = torch.device("cpu") 15 | self.dtype = torch.float32 16 | 17 | def on_fit_start(self, device: torch.device | None = None, *args, **kwargs): 18 | """Called when the training starts 19 | 20 | Args: 21 | device (Optional[torch.device], optional): The device to use. Useful to set 22 | relevant parameters on the model and embedder to the right device only 23 | once at the start of the training. Defaults to None. 24 | """ 25 | if device is not None: 26 | self.device = device 27 | self.to(self.device) 28 | 29 | def forward(self, batch: Dict[str, Any], *args, **kwargs): 30 | raise NotImplementedError("forward method is not implemented") 31 | 32 | def freeze(self): 33 | """Freeze the model: switch to eval mode and disable gradients for all parameters""" 34 | self.eval() 35 | for param in self.parameters(): 36 | param.requires_grad = False 37 | 38 | def to(self, *args, **kwargs): 39 | device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs) 40 | self = super().to( 41 | device=device, 42 | dtype=dtype, 43 | non_blocking=non_blocking, 44 | ) 45 | 46 | if device is not None: 47 | self.device = device 48 | if dtype is not None: 49 | self.dtype = dtype 50 | return self 51 | 52 | def compute_metrics(self, batch: Dict[str, Any], *args, **kwargs): 53 | """Compute the metrics""" 54 | return {} 55 | 56 | def sample(self, batch: Dict[str, Any], *args, **kwargs): 57 | """Sample from the model""" 58 | return {} 59 | 60 | def log_samples(self, batch: Dict[str, Any], *args, **kwargs): 61 | """Log the samples""" 62 | return None 63 | 64 | def on_train_batch_end(self, batch: Dict[str, Any], *args, **kwargs): 65 | """Hook called after an optimization step has been performed on a batch.""" 66 | pass 67 | -------------------------------------------------------------------------------- /src/lbm/models/base/model_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ...config import BaseConfig 4 | 5 | 6 |
@dataclass 7 | class ModelConfig(BaseConfig): 8 | input_key: str = "image" 9 | -------------------------------------------------------------------------------- /src/lbm/models/embedders/__init__.py: -------------------------------------------------------------------------------- 1 | from .conditioners_wrapper import ConditionerWrapper 2 | from .latents_concat import LatentsConcatEmbedder, LatentsConcatEmbedderConfig 3 | 4 | __all__ = ["LatentsConcatEmbedder", "LatentsConcatEmbedderConfig", "ConditionerWrapper"] 5 | -------------------------------------------------------------------------------- /src/lbm/models/embedders/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_conditioner import BaseConditioner 2 | from .base_conditioner_config import BaseConditionerConfig 3 | 4 | __all__ = ["BaseConditioner", "BaseConditionerConfig"] 5 | -------------------------------------------------------------------------------- /src/lbm/models/embedders/base/base_conditioner.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Union 2 | 3 | import torch 4 | 5 | from ...base.base_model import BaseModel 6 | from .base_conditioner_config import BaseConditionerConfig 7 | 8 | DIM2CONDITIONING = { 9 | 2: "vector", 10 | 3: "crossattn", 11 | 4: "concat", 12 | } 13 | 14 | 15 | class BaseConditioner(BaseModel): 16 | """This is the base class for all the conditioners. It abstracts the conditioning process 17 | 18 | Args: 19 | 20 | config (BaseConditionerConfig): The configuration of the conditioner 21 | 22 | Examples 23 | ######## 24 | 25 | To use the conditioner, you can import the class and use it as follows: 26 | 27 | ..
code-block:: python 28 | 29 | from lbm.models.embedders.base import BaseConditioner, BaseConditionerConfig 30 | 31 | # Create the conditioner config 32 | config = BaseConditionerConfig( 33 | input_key="text", # The key for the input 34 | unconditional_conditioning_rate=0.3, # Drops the conditioning with 30% probability during training 35 | ) 36 | 37 | # Create the conditioner 38 | conditioner = BaseConditioner(config) 39 | """ 40 | 41 | def __init__(self, config: BaseConditionerConfig): 42 | BaseModel.__init__(self, config) 43 | self.config = config 44 | self.input_key = config.input_key 45 | self.dim2outputkey = DIM2CONDITIONING 46 | self.ucg_rate = config.unconditional_conditioning_rate 47 | 48 | def forward( 49 | self, batch: Dict[str, Any], force_zero_embedding: bool = False, *args, **kwargs 50 | ): 51 | """ 52 | Forward pass of the embedder. 53 | 54 | Args: 55 | 56 | batch (Dict[str, Any]): A dictionary containing the input data. 57 | force_zero_embedding (bool): Whether to force zero embedding. 58 | This will return an embedding with all entries set to 0. Defaults to False. 59 | """ 60 | raise NotImplementedError("Forward pass must be implemented in child class") 61 | -------------------------------------------------------------------------------- /src/lbm/models/embedders/base/base_conditioner_config.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from pydantic.dataclasses import dataclass 4 | 5 | from ....config import BaseConfig 6 | 7 | 8 | @dataclass 9 | class BaseConditionerConfig(BaseConfig): 10 | """This is the BaseConditionerConfig class which defines the parameters shared by all conditioners 11 | 12 | Args: 13 | 14 | input_key (str): The key for the input. Defaults to "text". 15 | unconditional_conditioning_rate (float): Drops the conditioning with this probability during training. Defaults to 0.0.
16 | """ 17 | 18 | input_key: str = "text" 19 | unconditional_conditioning_rate: float = 0.0 20 | 21 | def __post_init__(self): 22 | super().__post_init__() 23 | 24 | assert ( 25 | self.unconditional_conditioning_rate >= 0.0 26 | and self.unconditional_conditioning_rate <= 1.0 27 | ), "Unconditional conditioning rate should be between 0 and 1" 28 | -------------------------------------------------------------------------------- /src/lbm/models/embedders/conditioners_wrapper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Dict, List, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .base import BaseConditioner 8 | 9 | KEY2CATDIM = { 10 | "vector": 1, 11 | "crossattn": 2, 12 | "concat": 1, 13 | } 14 | 15 | 16 | class ConditionerWrapper(nn.Module): 17 | """ 18 | Wrapper for conditioners. This class applies multiple conditioners in a single forward pass and concatenates outputs of the same conditioning type. 19 | 20 | Args: 21 | 22 | conditioners (List[BaseConditioner]): List of conditioners to apply in the forward pass.
23 | """ 24 | 25 | def __init__( 26 | self, 27 | conditioners: Union[List[BaseConditioner], None] = None, 28 | ): 29 | nn.Module.__init__(self) 30 | self.conditioners = nn.ModuleList(conditioners) 31 | self.device = torch.device("cpu") 32 | self.dtype = torch.float32 33 | 34 | def conditioner_sanity_check(self, ucg_keys: Union[List[str], None] = None): 35 | """Check that every key in ucg_keys is the input_key of one of the conditioners""" 36 | cond_input_keys = set(conditioner.input_key for conditioner in self.conditioners) 37 | ucg_keys = ucg_keys if ucg_keys is not None else [] 38 | 39 | assert all(key in cond_input_keys for key in ucg_keys) 40 | 41 | def on_fit_start(self, device: torch.device | None = None, *args, **kwargs): 42 | """Called when the training starts""" 43 | for conditioner in self.conditioners: 44 | conditioner.on_fit_start(device=device, *args, **kwargs) 45 | 46 | def forward( 47 | self, 48 | batch: Dict[str, Any], 49 | ucg_keys: Union[List[str], None] = None, 50 | set_ucg_rate_zero=False, 51 | *args, 52 | **kwargs, 53 | ): 54 | """ 55 | Forward pass through all conditioners 56 | 57 | Args: 58 | 59 | batch: batch of data 60 | ucg_keys: keys for unconditional conditioning. This forces zero conditioning in all the 61 | conditioners whose input_key is in ucg_keys 62 | set_ucg_rate_zero: if True, disable the random conditioning dropout (ucg rate) for all the conditioners; the ones in ucg_keys are still zeroed 63 | 64 | Returns: 65 | 66 | Dict[str, Any]: A dictionary with the single key "cond", whose value is a dictionary mapping each conditioning type 67 | ("vector", "crossattn" or "concat") to its conditioning tensor; outputs of the same type are concatenated.
68 | """ 69 | if ucg_keys is None: 70 | ucg_keys = [] 71 | wrapper_outputs = dict(cond={}) 72 | for conditioner in self.conditioners: 73 | if conditioner.input_key in ucg_keys: 74 | force_zero_embedding = True 75 | elif conditioner.ucg_rate > 0 and not set_ucg_rate_zero: 76 | force_zero_embedding = bool(torch.rand(1) < conditioner.ucg_rate) 77 | else: 78 | force_zero_embedding = False 79 | 80 | conditioner_output = conditioner.forward( 81 | batch, force_zero_embedding=force_zero_embedding, *args, **kwargs 82 | ) 83 | logging.debug( 84 | f"conditioner:{conditioner.__class__.__name__}, input_key:{conditioner.input_key}, force_ucg_zero_embedding:{force_zero_embedding}" 85 | ) 86 | for key in conditioner_output: 87 | logging.debug( 88 | f"conditioner_output:{key}:{conditioner_output[key].shape}" 89 | ) 90 | if key in wrapper_outputs["cond"]: 91 | wrapper_outputs["cond"][key] = torch.cat( 92 | [wrapper_outputs["cond"][key], conditioner_output[key]], 93 | KEY2CATDIM[key], 94 | ) 95 | else: 96 | wrapper_outputs["cond"][key] = conditioner_output[key] 97 | 98 | return wrapper_outputs 99 | 100 | def to(self, *args, **kwargs): 101 | """ 102 | Move all conditioners to device and dtype 103 | """ 104 | device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs) 105 | self = super().to(device=device, dtype=dtype, non_blocking=non_blocking) 106 | for conditioner in self.conditioners: 107 | conditioner.to(device=device, dtype=dtype, non_blocking=non_blocking) 108 | 109 | if device is not None: 110 | self.device = device 111 | if dtype is not None: 112 | self.dtype = dtype 113 | 114 | return self 115 | -------------------------------------------------------------------------------- /src/lbm/models/embedders/latents_concat/__init__.py: -------------------------------------------------------------------------------- 1 | from .latents_concat_embedder_config import LatentsConcatEmbedderConfig 2 | from .latents_concat_embedder_model import LatentsConcatEmbedder 3 | 4 | 
__all__ = ["LatentsConcatEmbedder", "LatentsConcatEmbedderConfig"] 5 | -------------------------------------------------------------------------------- /src/lbm/models/embedders/latents_concat/latents_concat_embedder_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field 2 | from typing import List, Union 3 | 4 | from pydantic.dataclasses import dataclass 5 | 6 | from ..base import BaseConditionerConfig 7 | 8 | 9 | @dataclass 10 | class LatentsConcatEmbedderConfig(BaseConditionerConfig): 11 | """ 12 | Configs for the LatentsConcatEmbedder embedder 13 | 14 | Args: 15 | image_keys (Union[List[str], None]): Keys of the images to compute the VAE embeddings 16 | mask_keys (Union[List[str], None]): Keys of the masks to resize 17 | """ 18 | 19 | image_keys: Union[List[str], None] = field(default_factory=lambda: ["image"]) 20 | mask_keys: Union[List[str], None] = field(default_factory=lambda: ["mask"]) 21 | 22 | def __post_init__(self): 23 | super().__post_init__() 24 | 25 | # Make sure that at least one of the image_keys or mask_keys is provided 26 | assert (self.image_keys is not None) or ( 27 | self.mask_keys is not None 28 | ), "At least one of the image_keys or mask_keys must be provided." 
29 | 30 | self.image_keys = self.image_keys if self.image_keys is not None else [] 31 | self.mask_keys = self.mask_keys if self.mask_keys is not None else [] 32 | -------------------------------------------------------------------------------- /src/lbm/models/embedders/latents_concat/latents_concat_embedder_model.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import torch 4 | import torchvision.transforms.functional as F 5 | 6 | from lbm.models.vae import AutoencoderKLDiffusers 7 | 8 | from ..base import BaseConditioner 9 | from .latents_concat_embedder_config import LatentsConcatEmbedderConfig 10 | 11 | 12 | class LatentsConcatEmbedder(BaseConditioner): 13 | """ 14 | Class computing VAE embeddings from given images and resizing the masks to the latent resolution. 15 | The outputs are then concatenated to the noisy sample in the latent space. 16 | 17 | Args: 18 | config (LatentsConcatEmbedderConfig): Configs to create the embedder 19 | """ 20 | 21 | def __init__(self, config: LatentsConcatEmbedderConfig): 22 | BaseConditioner.__init__(self, config) 23 | 24 | def forward( 25 | self, batch: Dict[str, Any], vae: AutoencoderKLDiffusers, *args, **kwargs 26 | ) -> dict: 27 | """ 28 | Args: 29 | batch (dict): A batch of images to be processed by this embedder. In the batch, 30 | the images must range between [-1, 1] and the masks between [0, 1]. 31 | vae (AutoencoderKLDiffusers): VAE used to encode the images into latents 32 | 33 | Returns: 34 | output (dict): A single-entry dict mapping the output key for the tensor rank (see `dim2outputkey`) to the concatenated masks and latents 35 | """ 36 | 37 | # Check that all images and masks share the same spatial size 38 | dims_list = [] 39 | for image_key in self.config.image_keys: 40 | dims_list.append(batch[image_key].shape[-2:]) 41 | for mask_key in self.config.mask_keys: 42 | dims_list.append(batch[mask_key].shape[-2:]) 43 | assert all( 44 | dims == dims_list[0] for dims in dims_list 45 | ), "All images and masks must have the same dimensions."
46 | 47 | # Find the latent dimensions 48 | if len(self.config.image_keys) > 0: 49 | latent_dims = ( 50 | batch[self.config.image_keys[0]].shape[-2] // vae.downsampling_factor, 51 | batch[self.config.image_keys[0]].shape[-1] // vae.downsampling_factor, 52 | ) 53 | else: 54 | latent_dims = ( 55 | batch[self.config.mask_keys[0]].shape[-2] // vae.downsampling_factor, 56 | batch[self.config.mask_keys[0]].shape[-1] // vae.downsampling_factor, 57 | ) 58 | 59 | outputs = [] 60 | 61 | # Resize the masks and concat them 62 | for mask_key in self.config.mask_keys: 63 | curr_latents = F.resize( 64 | batch[mask_key], 65 | size=latent_dims, 66 | interpolation=F.InterpolationMode.BILINEAR, 67 | ) 68 | outputs.append(curr_latents) 69 | 70 | # Compute VAE embeddings from the images 71 | for image_key in self.config.image_keys: 72 | vae_embs = vae.encode(batch[image_key]) 73 | outputs.append(vae_embs) 74 | 75 | # Concat all the outputs 76 | outputs = torch.concat(outputs, dim=1) 77 | 78 | outputs = {self.dim2outputkey[outputs.dim()]: outputs} 79 | 80 | return outputs 81 | -------------------------------------------------------------------------------- /src/lbm/models/lbm/__init__.py: -------------------------------------------------------------------------------- 1 | from .lbm_config import LBMConfig 2 | from .lbm_model import LBMModel 3 | 4 | __all__ = ["LBMModel", "LBMConfig"] 5 | -------------------------------------------------------------------------------- /src/lbm/models/lbm/lbm_config.py: -------------------------------------------------------------------------------- 1 | from typing import List, Literal, Optional, Tuple 2 | 3 | from pydantic.dataclasses import dataclass 4 | 5 | from ..base import ModelConfig 6 | 7 | 8 | @dataclass 9 | class LBMConfig(ModelConfig): 10 | """This is the Config for LBM Model class which defines all the useful parameters to be used in the model. 11 | 12 | Args: 13 | 14 | source_key (str): 15 | Key for the source image. 
Defaults to "source_image" 16 | 17 | target_key (str): 18 | Key for the target image. Defaults to "target_image" 19 | 20 | mask_key (Optional[str]): 21 | Key for the mask showing the valid pixels. Defaults to None 22 | 23 | latent_loss_type (str): 24 | Latent loss type to use. Defaults to "l2". Choices are "l2", "l1" 25 | 26 | pixel_loss_type (str): 27 | Pixel loss type to use. Defaults to "l2". Choices are "l2", "l1", "lpips" 28 | 29 | pixel_loss_max_size (int): 30 | Maximum size of the image for pixel loss. 31 | The image will be randomly cropped to this size to reduce the decoding cost. Defaults to 512 32 | 33 | pixel_loss_weight (float): 34 | Weight of the pixel loss. Defaults to 0.0 35 | 36 | timestep_sampling (str): 37 | Timestep sampling to use. Defaults to "uniform". Choices are "uniform", "log_normal", "custom_timesteps" 38 | 39 | input_key (str): 40 | Key for the input. Defaults to "image" 41 | 42 | controlnet_input_key (str): 43 | Key for the controlnet conditioning. Defaults to "controlnet_conditioning" 44 | 45 | adapter_input_key (str): 46 | Key for the adapter conditioning. Defaults to "adapter_conditioning" 47 | 48 | ucg_keys (Optional[List[str]]): 49 | List of keys for which we enforce zero conditioning during classifier-free guidance. Defaults to None 50 | 51 | prediction_type (str): 52 | Type of prediction to use. Defaults to "epsilon". Choices are "epsilon", "v_prediction", "flow" 53 | 54 | logit_mean (Optional[float]): 55 | Mean of the logit for the log normal distribution. Defaults to 0.0 56 | 57 | logit_std (Optional[float]): 58 | Standard deviation of the logit for the log normal distribution. Defaults to 1.0 59 | 60 | guidance_scale (Optional[float]): 61 | The guidance scale. Useful for fine-tuning guidance-distilled diffusion models. Defaults to None 62 | 63 | selected_timesteps (Optional[List[float]]): 64 | List of selected timesteps to be sampled from if using `custom_timesteps` timestep sampling.
Defaults to None 65 | 66 | prob (Optional[List[float]]): 67 | List of probabilities for the selected timesteps if using `custom_timesteps` timestep sampling. Defaults to None 68 | """ 69 | 70 | source_key: str = "source_image" 71 | target_key: str = "target_image" 72 | mask_key: Optional[str] = None 73 | latent_loss_weight: float = 1.0 74 | latent_loss_type: Literal["l2", "l1"] = "l2" 75 | pixel_loss_type: Literal["l2", "l1", "lpips"] = "l2" 76 | pixel_loss_max_size: int = 512 77 | pixel_loss_weight: float = 0.0 78 | timestep_sampling: Literal["uniform", "log_normal", "custom_timesteps"] = "uniform" 79 | logit_mean: Optional[float] = 0.0 80 | logit_std: Optional[float] = 1.0 81 | selected_timesteps: Optional[List[float]] = None 82 | prob: Optional[List[float]] = None 83 | bridge_noise_sigma: float = 0.001 84 | 85 | def __post_init__(self): 86 | super().__post_init__() 87 | if self.timestep_sampling == "log_normal": 88 | assert isinstance(self.logit_mean, float) and isinstance( 89 | self.logit_std, float 90 | ), "logit_mean and logit_std should be float for log_normal timestep sampling" 91 | 92 | if self.timestep_sampling == "custom_timesteps": 93 | assert isinstance(self.selected_timesteps, list) and isinstance( 94 | self.prob, list 95 | ), "selected_timesteps and prob should be lists for custom_timesteps timestep sampling" 96 | assert len(self.selected_timesteps) == len( 97 | self.prob 98 | ), "selected_timesteps and prob should have the same length for custom_timesteps timestep sampling" 99 | assert ( 100 | abs(sum(self.prob) - 1.0) < 1e-6 101 | ), "prob should sum to 1 for custom_timesteps timestep sampling" 102 | -------------------------------------------------------------------------------- /src/lbm/models/lbm/lbm_model.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import lpips 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from 
diffusers.schedulers import
FlowMatchEulerDiscreteScheduler 8 | from tqdm import tqdm 9 | 10 | from ..base.base_model import BaseModel 11 | from ..embedders import ConditionerWrapper 12 | from ..unets import DiffusersUNet2DCondWrapper, DiffusersUNet2DWrapper 13 | from ..vae import AutoencoderKLDiffusers 14 | from .lbm_config import LBMConfig 15 | 16 | 17 | class LBMModel(BaseModel): 18 | """The Latent Bridge Matching (LBM) model. 19 | 20 | Args: 21 | 22 | config (LBMConfig): 23 | Configuration for the model 24 | 25 | denoiser (Union[DiffusersUNet2DWrapper, DiffusersUNet2DCondWrapper]): 26 | Denoiser to use for the diffusion model. Defaults to None 27 | 28 | training_noise_scheduler (FlowMatchEulerDiscreteScheduler): 29 | Noise scheduler to use for training. Defaults to None 30 | 31 | sampling_noise_scheduler (FlowMatchEulerDiscreteScheduler): 32 | Noise scheduler to use for sampling. Defaults to None 33 | 34 | vae (AutoencoderKLDiffusers): 35 | VAE to use for the diffusion model. Defaults to None 36 | 37 | conditioner (ConditionerWrapper): 38 | Conditioner to use for the diffusion model.
Defaults to None 39 | """ 40 | 41 | @classmethod 42 | def load_from_config(cls, config: LBMConfig): 43 | return cls(config=config) 44 | 45 | def __init__( 46 | self, 47 | config: LBMConfig, 48 | denoiser: Union[ 49 | DiffusersUNet2DWrapper, 50 | DiffusersUNet2DCondWrapper, 51 | ] = None, 52 | training_noise_scheduler: FlowMatchEulerDiscreteScheduler = None, 53 | sampling_noise_scheduler: FlowMatchEulerDiscreteScheduler = None, 54 | vae: AutoencoderKLDiffusers = None, 55 | conditioner: ConditionerWrapper = None, 56 | ): 57 | BaseModel.__init__(self, config) 58 | 59 | self.vae = vae 60 | self.denoiser = denoiser 61 | self.conditioner = conditioner 62 | self.sampling_noise_scheduler = sampling_noise_scheduler 63 | self.training_noise_scheduler = training_noise_scheduler 64 | self.timestep_sampling = config.timestep_sampling 65 | self.latent_loss_type = config.latent_loss_type 66 | self.latent_loss_weight = config.latent_loss_weight 67 | self.pixel_loss_type = config.pixel_loss_type 68 | self.pixel_loss_max_size = config.pixel_loss_max_size 69 | self.pixel_loss_weight = config.pixel_loss_weight 70 | self.logit_mean = config.logit_mean 71 | self.logit_std = config.logit_std 72 | self.prob = config.prob 73 | self.selected_timesteps = config.selected_timesteps 74 | self.source_key = config.source_key 75 | self.target_key = config.target_key 76 | self.mask_key = config.mask_key 77 | self.bridge_noise_sigma = config.bridge_noise_sigma 78 | 79 | self.num_iterations = nn.Parameter( 80 | torch.tensor(0, dtype=torch.float32), requires_grad=False 81 | ) 82 | if self.pixel_loss_type == "lpips" and self.pixel_loss_weight > 0: 83 | self.lpips_loss = lpips.LPIPS(net="vgg") 84 | 85 | else: 86 | self.lpips_loss = None 87 | 88 | def on_fit_start(self, device: torch.device | None = None, *args, **kwargs): 89 | """Called when the training starts""" 90 | super().on_fit_start(device=device, *args, **kwargs) 91 | if self.vae is not None: 92 | self.vae.on_fit_start(device=device, *args, 
**kwargs) 93 | if self.conditioner is not None: 94 | self.conditioner.on_fit_start(device=device, *args, **kwargs) 95 | 96 | def forward(self, batch: Dict[str, Any], step=0, batch_idx=0, *args, **kwargs): 97 | 98 | self.num_iterations += 1 99 | 100 | # Get inputs/latents 101 | if self.vae is not None: 102 | vae_inputs = batch[self.target_key] 103 | z = self.vae.encode(vae_inputs) 104 | downsampling_factor = self.vae.downsampling_factor 105 | else: 106 | z = batch[self.target_key] 107 | downsampling_factor = 1 108 | 109 | if self.mask_key in batch: 110 | valid_mask = batch[self.mask_key].bool()[:, 0, :, :].unsqueeze(1) 111 | invalid_mask = ~valid_mask 112 | valid_mask_for_latent = ~torch.max_pool2d( 113 | invalid_mask.float(), 114 | downsampling_factor, 115 | downsampling_factor, 116 | ).bool() 117 | valid_mask_for_latent = valid_mask_for_latent.repeat((1, z.shape[1], 1, 1)) 118 | 119 | else: 120 | valid_mask = torch.ones_like(batch[self.target_key]).bool() 121 | valid_mask_for_latent = torch.ones_like(z).bool() 122 | 123 | source_image = batch[self.source_key] 124 | source_image = torch.nn.functional.interpolate( 125 | source_image, 126 | size=batch[self.target_key].shape[-2:], 127 | mode="bilinear", 128 | align_corners=False, 129 | ).to(z.dtype) 130 | if self.vae is not None: 131 | z_source = self.vae.encode(source_image) 132 | 133 | else: 134 | z_source = source_image 135 | 136 | # Get conditionings 137 | conditioning = self._get_conditioning(batch, *args, **kwargs) 138 | 139 | # Sample a timestep 140 | timestep = self._timestep_sampling(n_samples=z.shape[0], device=z.device) 141 | sigmas = None 142 | 143 | # Create interpolant 144 | sigmas = self._get_sigmas( 145 | self.training_noise_scheduler, timestep, n_dim=4, device=z.device 146 | ) 147 | noisy_sample = ( 148 | sigmas * z_source 149 | + (1.0 - sigmas) * z 150 | + self.bridge_noise_sigma 151 | * (sigmas * (1.0 - sigmas)) ** 0.5 152 | * torch.randn_like(z) 153 | ) 154 | 155 | for i, t in enumerate(timestep): 
156 | if t.item() == self.training_noise_scheduler.timesteps[0]: 157 | noisy_sample[i] = z_source[i] 158 | 159 | # Predict noise level using denoiser 160 | prediction = self.denoiser( 161 | sample=noisy_sample, 162 | timestep=timestep, 163 | conditioning=conditioning, 164 | *args, 165 | **kwargs, 166 | ) 167 | 168 | target = z_source - z 169 | denoised_sample = noisy_sample - prediction * sigmas 170 | target_pixels = batch[self.target_key] 171 | 172 | # Compute loss 173 | if self.latent_loss_weight > 0: 174 | loss = self.latent_loss(prediction, target.detach(), valid_mask_for_latent) 175 | latent_recon_loss = loss.mean() 176 | 177 | else: 178 | loss = torch.zeros(z.shape[0], device=z.device) 179 | latent_recon_loss = torch.zeros_like(loss) 180 | 181 | if self.pixel_loss_weight > 0: 182 | denoised_sample = self._predicted_x_0( 183 | model_output=prediction, 184 | sample=noisy_sample, 185 | sigmas=sigmas, 186 | ) 187 | pixel_loss = self.pixel_loss( 188 | denoised_sample, target_pixels.detach(), valid_mask 189 | ) 190 | loss += self.pixel_loss_weight * pixel_loss 191 | 192 | else: 193 | pixel_loss = torch.zeros_like(latent_recon_loss) 194 | 195 | return { 196 | "loss": loss.mean(), 197 | "latent_recon_loss": latent_recon_loss, 198 | "pixel_recon_loss": pixel_loss.mean(), 199 | "predicted_hr": denoised_sample, 200 | "noisy_sample": noisy_sample, 201 | } 202 | 203 | def latent_loss(self, prediction, model_input, valid_latent_mask): 204 | if self.latent_loss_type == "l2": 205 | return torch.mean( 206 | ( 207 | (prediction * valid_latent_mask - model_input * valid_latent_mask) 208 | ** 2 209 | ).reshape(model_input.shape[0], -1), 210 | 1, 211 | ) 212 | elif self.latent_loss_type == "l1": 213 | return torch.mean( 214 | torch.abs( 215 | prediction * valid_latent_mask - model_input * valid_latent_mask 216 | ).reshape(model_input.shape[0], -1), 217 | 1, 218 | ) 219 | else: 220 | raise NotImplementedError( 221 | f"Loss type {self.latent_loss_type} not implemented" 222 | ) 223 
| 224 | def pixel_loss(self, prediction, model_input, valid_mask): 225 | 226 | latent_crop = self.pixel_loss_max_size // self.vae.downsampling_factor 227 | input_crop = self.pixel_loss_max_size 228 | 229 | crop_h = max((prediction.shape[2] - latent_crop), 0) 230 | crop_w = max((prediction.shape[3] - latent_crop), 0) 231 | 232 | input_crop_h = max((model_input.shape[2] - self.pixel_loss_max_size), 0) 233 | input_crop_w = max((model_input.shape[3] - self.pixel_loss_max_size), 0) 234 | 235 | # image random cropping 236 | if crop_h == 0: 237 | offset_h = 0 238 | else: 239 | offset_h = torch.randint(0, crop_h, (1,)).item() 240 | 241 | if crop_w == 0: 242 | offset_w = 0 243 | else: 244 | offset_w = torch.randint(0, crop_w, (1,)).item() 245 | input_offset_h = offset_h * self.vae.downsampling_factor 246 | input_offset_w = offset_w * self.vae.downsampling_factor 247 | 248 | prediction = prediction[ 249 | :, 250 | :, 251 | crop_h 252 | - offset_h : min(crop_h - offset_h + latent_crop, prediction.shape[2]), 253 | crop_w 254 | - offset_w : min(crop_w - offset_w + latent_crop, prediction.shape[3]), 255 | ] 256 | 257 | model_input = model_input[ 258 | :, 259 | :, 260 | input_crop_h 261 | - input_offset_h : min( 262 | input_crop_h - input_offset_h + input_crop, model_input.shape[2] 263 | ), 264 | input_crop_w 265 | - input_offset_w : min( 266 | input_crop_w - input_offset_w + input_crop, model_input.shape[3] 267 | ), 268 | ] 269 | 270 | valid_mask = valid_mask[ 271 | :, 272 | :, 273 | input_crop_h 274 | - input_offset_h : min( 275 | input_crop_h - input_offset_h + input_crop, valid_mask.shape[2] 276 | ), 277 | input_crop_w 278 | - input_offset_w : min( 279 | input_crop_w - input_offset_w + input_crop, valid_mask.shape[3] 280 | ), 281 | ] 282 | 283 | decoded_prediction = self.vae.decode(prediction).clamp(-1, 1) 284 | 285 | if self.pixel_loss_type == "l2": 286 | return torch.mean( 287 | ( 288 | (decoded_prediction * valid_mask - model_input * valid_mask) ** 2 289 | 
).reshape(model_input.shape[0], -1), 290 | 1, 291 | ) 292 | 293 | elif self.pixel_loss_type == "l1": 294 | return torch.mean( 295 | torch.abs( 296 | decoded_prediction * valid_mask - model_input * valid_mask 297 | ).reshape(model_input.shape[0], -1), 298 | 1, 299 | ) 300 | 301 | elif self.pixel_loss_type == "lpips": 302 | return self.lpips_loss( 303 | decoded_prediction * valid_mask, model_input * valid_mask 304 | ).mean() 305 | 306 | def _get_conditioning( 307 | self, 308 | batch: Dict[str, Any], 309 | ucg_keys: List[str] = None, 310 | set_ucg_rate_zero=False, 311 | *args, 312 | **kwargs, 313 | ): 314 | """ 315 | Get the conditionings 316 | """ 317 | if self.conditioner is not None: 318 | return self.conditioner( 319 | batch, 320 | ucg_keys=ucg_keys, 321 | set_ucg_rate_zero=set_ucg_rate_zero, 322 | vae=self.vae, 323 | *args, 324 | **kwargs, 325 | ) 326 | else: 327 | return None 328 | 329 | def _timestep_sampling(self, n_samples=1, device="cpu"): 330 | if self.timestep_sampling == "uniform": 331 | idx = torch.randint( 332 | 0, 333 | self.training_noise_scheduler.config.num_train_timesteps, 334 | (n_samples,), 335 | device="cpu", 336 | ) 337 | return self.training_noise_scheduler.timesteps[idx].to(device=device) 338 | 339 | elif self.timestep_sampling == "log_normal": 340 | u = torch.normal( 341 | mean=self.logit_mean, 342 | std=self.logit_std, 343 | size=(n_samples,), 344 | device="cpu", 345 | ) 346 | u = torch.nn.functional.sigmoid(u) 347 | indices = ( 348 | u * self.training_noise_scheduler.config.num_train_timesteps 349 | ).long() 350 | return self.training_noise_scheduler.timesteps[indices].to(device=device) 351 | 352 | elif self.timestep_sampling == "custom_timesteps": 353 | idx = np.random.choice(len(self.selected_timesteps), n_samples, p=self.prob) 354 | 355 | return torch.tensor( 356 | self.selected_timesteps, device=device, dtype=torch.long 357 | )[idx] 358 | 359 | def _predicted_x_0( 360 | self, 361 | model_output, 362 | sample, 363 | sigmas=None, 364 | 
): 365 | """ 366 | Predict x_0, the original clean sample, from the model output: x_0 = sample - model_output * sigmas. 367 | """ 368 | pred_x_0 = sample - model_output * sigmas 369 | return pred_x_0 370 | 371 | def _get_sigmas( 372 | self, scheduler, timesteps, n_dim=4, dtype=torch.float32, device="cpu" 373 | ): 374 | sigmas = scheduler.sigmas.to(device=device, dtype=dtype) 375 | schedule_timesteps = scheduler.timesteps.to(device) 376 | timesteps = timesteps.to(device) 377 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] 378 | 379 | sigma = sigmas[step_indices].flatten() 380 | while len(sigma.shape) < n_dim: 381 | sigma = sigma.unsqueeze(-1) 382 | return sigma 383 | 384 | @torch.no_grad() 385 | def sample( 386 | self, 387 | z: torch.Tensor, 388 | num_steps: int = 20, 389 | conditioner_inputs: Optional[Dict[str, Any]] = None, 390 | max_samples: Optional[int] = None, 391 | verbose: bool = False, 392 | ): 393 | self.sampling_noise_scheduler.set_timesteps( 394 | sigmas=np.linspace(1, 1 / num_steps, num_steps) 395 | ) 396 | 397 | sample = z 398 | 399 | # Get conditioning 400 | conditioning = self._get_conditioning( 401 | conditioner_inputs, set_ucg_rate_zero=True, device=z.device 402 | ) 403 | 404 | # If max_samples parameter is provided, limit the number of samples 405 | if max_samples is not None: 406 | sample = sample[:max_samples] 407 | 408 | if conditioning: 409 | conditioning["cond"] = { 410 | k: v[:max_samples] for k, v in conditioning["cond"].items() 411 | } 412 | 413 | for i, t in tqdm( 414 | enumerate(self.sampling_noise_scheduler.timesteps), disable=not verbose 415 | ): 416 | if hasattr(self.sampling_noise_scheduler, "scale_model_input"): 417 | denoiser_input = self.sampling_noise_scheduler.scale_model_input( 418 | sample, t 419 | ) 420 | 421 | else: 422 | denoiser_input = sample 423 | 424 | # Predict the bridge velocity (z_source - z) with the denoiser, using the conditionings 425 | pred = self.denoiser( 426 | 
sample=denoiser_input, 427 |
timestep=t.to(z.device).repeat(denoiser_input.shape[0]), 428 | conditioning=conditioning, 429 | ) 430 | 431 | # Make one step on the reverse diffusion process 432 | sample = self.sampling_noise_scheduler.step( 433 | pred, t, sample, return_dict=False 434 | )[0] 435 | if i < len(self.sampling_noise_scheduler.timesteps) - 1: 436 | timestep = ( 437 | self.sampling_noise_scheduler.timesteps[i + 1] 438 | .to(z.device) 439 | .repeat(sample.shape[0]) 440 | ) 441 | sigmas = self._get_sigmas( 442 | self.sampling_noise_scheduler, timestep, n_dim=4, device=z.device 443 | ) 444 | sample = sample + self.bridge_noise_sigma * ( 445 | sigmas * (1.0 - sigmas) 446 | ) ** 0.5 * torch.randn_like(sample) 447 | sample = sample.to(z.dtype) 448 | 449 | if self.vae is not None: 450 | decoded_sample = self.vae.decode(sample) 451 | 452 | else: 453 | decoded_sample = sample 454 | 455 | return decoded_sample 456 | 457 | def log_samples( 458 | self, 459 | batch: Dict[str, Any], 460 | input_shape: Optional[Tuple[int, int, int]] = None, 461 | max_samples: Optional[int] = None, 462 | num_steps: Union[int, List[int]] = 20, 463 | ): 464 | if isinstance(num_steps, int): 465 | num_steps = [num_steps] 466 | 467 | logs = {} 468 | 469 | N = max_samples if max_samples is not None else len(batch[self.source_key]) 470 | 471 | batch = {k: v[:N] for k, v in batch.items()} 472 | 473 | # infer input shape based on VAE configuration if not passed 474 | if input_shape is None: 475 | if self.vae is not None: 476 | # get input pixel size of the vae 477 | input_shape = batch[self.target_key].shape[2:] 478 | # rescale to latent size 479 | input_shape = ( 480 | self.vae.latent_channels, 481 | input_shape[0] // self.vae.downsampling_factor, 482 | input_shape[1] // self.vae.downsampling_factor, 483 | ) 484 | else: 485 | raise ValueError( 486 | "input_shape must be passed when no VAE is used in the model" 487 | ) 488 | 489 | for num_step in num_steps: 490 | source_image = batch[self.source_key] 491 | source_image = 
torch.nn.functional.interpolate( 492 | source_image, 493 | size=batch[self.target_key].shape[2:], 494 | mode="bilinear", 495 | align_corners=False, 496 | ).to(dtype=self.dtype) 497 | if self.vae is not None: 498 | z = self.vae.encode(source_image) 499 | 500 | else: 501 | z = source_image 502 | 503 | with torch.autocast(dtype=self.dtype, device_type="cuda"): 504 | logs[f"samples_{num_step}_steps"] = self.sample( 505 | z, 506 | num_steps=num_step, 507 | conditioner_inputs=batch, 508 | max_samples=N, 509 | ) 510 | 511 | return logs 512 | -------------------------------------------------------------------------------- /src/lbm/models/unets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains a collection of U-Net models. 3 | The :mod:`lbm.models.unets` module includes the following classes: 4 | 5 | - :class:`DiffusersUNet2DWrapper`: A wrapper around diffusers' UNet2DModel. 6 | - :class:`DiffusersUNet2DCondWrapper`: A wrapper around diffusers' UNet2DConditionModel, with conditional input.
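Example: a hypothetical, dependency-free sketch of how the wrappers in `unet.py` consume a conditioning dictionary. It is not part of lbm. Following `DiffusersUNet2DCondWrapper.forward`, the entries of `conditioning["cond"]` are routed as follows. `vector` becomes `class_labels` and `crossattn` becomes `encoder_hidden_states`. `concat` is concatenated to the sample along the channel dimension. Because of that concatenation, the wrapped UNet's `in_channels` must equal the sample channels plus the concatenated channels. The 4-channel latent used below is an assumption for illustration.

```python
from typing import Dict, Optional, Sequence


def expected_unet_in_channels(sample_channels: int, concat_channels: Sequence[int]) -> int:
    # The wrappers call torch.cat([sample, concat], dim=1), so the first
    # convolution of the UNet must accept sample + concat channels.
    return sample_channels + sum(concat_channels)


def route_conditioning(conditioning: Optional[Dict[str, dict]]) -> Dict[str, object]:
    # Mirrors how DiffusersUNet2DCondWrapper reads the "cond" sub-dictionary;
    # missing entries become None, as with dict.get in the wrapper.
    cond = (conditioning or {}).get("cond", {})
    return {
        "class_labels": cond.get("vector"),
        "encoder_hidden_states": cond.get("crossattn"),
        "concat": cond.get("concat"),
    }


# Assumed: 4-channel noisy latent plus one 4-channel source latent.
print(expected_unet_in_channels(4, [4]))  # 8
print(route_conditioning(None))
```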
7 | """ 8 | 9 | from .unet import DiffusersUNet2DCondWrapper, DiffusersUNet2DWrapper 10 | 11 | __all__ = [ 12 | "DiffusersUNet2DWrapper", 13 | "DiffusersUNet2DCondWrapper", 14 | ] 15 | -------------------------------------------------------------------------------- /src/lbm/models/unets/unet.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union 2 | 3 | import torch 4 | from diffusers.models import UNet2DConditionModel, UNet2DModel 5 | 6 | 7 | class DiffusersUNet2DWrapper(UNet2DModel): 8 | """ 9 | Wrapper for the UNet2DModel from diffusers 10 | 11 | See diffusers' UNet2DModel for more details 12 | """ 13 | 14 | def __init__(self, *args, **kwargs): 15 | UNet2DModel.__init__(self, *args, **kwargs) 16 | 17 | def forward( 18 | self, 19 | sample: torch.Tensor, 20 | timestep: Union[torch.Tensor, float, int], 21 | conditioning: Dict[str, torch.Tensor] = None, 22 | *args, 23 | **kwargs, 24 | ): 25 | """ 26 | The forward pass of the model 27 | 28 | Args: 29 | 30 | sample (torch.Tensor): The input sample 31 | timestep (Union[torch.Tensor, float, int]): The diffusion timestep(s) 32 | """ 33 | if conditioning is not None: 34 | class_labels = conditioning["cond"].get("vector", None) 35 | concat = conditioning["cond"].get("concat", None) 36 | 37 | else: 38 | class_labels = None 39 | concat = None 40 | 41 | if concat is not None: 42 | sample = torch.cat([sample, concat], dim=1) 43 | 44 | return super().forward(sample, timestep, class_labels).sample 45 | 46 | def freeze(self): 47 | """ 48 | Freeze the model 49 | """ 50 | self.eval() 51 | for param in self.parameters(): 52 | param.requires_grad = False 53 | 54 | 55 | class DiffusersUNet2DCondWrapper(UNet2DConditionModel): 56 | """ 57 | Wrapper for the UNet2DConditionModel from diffusers 58 | 59 | See diffusers' UNet2DConditionModel for more details 60 | """ 61 | 62 | def __init__(self, *args, **kwargs): 63 | 
UNet2DConditionModel.__init__(self, *args,
**kwargs) 64 | # BaseModel.__init__(self, config=ModelConfig()) 65 | 66 | def forward( 67 | self, 68 | sample: torch.Tensor, 69 | timestep: Union[torch.Tensor, float, int], 70 | conditioning: Dict[str, torch.Tensor], 71 | ip_adapter_cond_embedding: Optional[List[torch.Tensor]] = None, 72 | down_block_additional_residuals: torch.Tensor = None, 73 | mid_block_additional_residual: torch.Tensor = None, 74 | down_intrablock_additional_residuals: torch.Tensor = None, 75 | *args, 76 | **kwargs, 77 | ): 78 | """ 79 | The forward pass of the model 80 | 81 | Args: 82 | 83 | sample (torch.Tensor): The input sample 84 | timestep (Union[torch.Tensor, float, int]): The diffusion timestep(s) 85 | conditioning (Dict[str, torch.Tensor]): The conditioning data 86 | down_block_additional_residuals (List[torch.Tensor]): Residuals for the down blocks. 87 | These residuals typically are used for the controlnet. 88 | mid_block_additional_residual (torch.Tensor): Residual for the mid block. 89 | These residuals typically are used for the controlnet. 90 | down_intrablock_additional_residuals (List[torch.Tensor]): Residuals for the down intrablocks. 91 | These residuals typically are used for the T2I adapters. 92 | """ 93 | 94 | assert isinstance(conditioning, dict), "conditioning must be a dictionary" 95 | # assert "crossattn" in conditioning["cond"], "crossattn must be in conditionings" 96 | 97 | class_labels = conditioning["cond"].get("vector", None) 98 | crossattn = conditioning["cond"].get("crossattn", None) 99 | concat = conditioning["cond"].get("concat", None) 100 | 101 | # concat conditioning 102 | if concat is not None: 103 | sample = torch.cat([sample, concat], dim=1) 104 | 105 | # down_intrablock_additional_residuals needs to be cloned, since the UNet modifies it in place 106 | if down_intrablock_additional_residuals is not None: 107 | down_intrablock_additional_residuals_clone = [ 108 | curr_residuals.clone() 109 | for curr_residuals in 
down_intrablock_additional_residuals 110 | ] 111 | else: 112 | down_intrablock_additional_residuals_clone = None 113 | 114 | # Check diffusers.models.embeddings.py > MultiIPAdapterImageProjectionLayer > forward() for implementation 115 | # Expected format: List[torch.Tensor] of shape (batch_size, num_image_embeds, embed_dim) 116 | # with length = number of ip_adapters loaded in the ip_adapter_wrapper 117 | if ip_adapter_cond_embedding is not None: 118 | added_cond_kwargs = { 119 | "image_embeds": [ 120 | ip_adapter_embedding.unsqueeze(1) 121 | for ip_adapter_embedding in ip_adapter_cond_embedding 122 | ] 123 | } 124 | else: 125 | added_cond_kwargs = None 126 | 127 | return ( 128 | super() 129 | .forward( 130 | sample=sample, 131 | timestep=timestep, 132 | encoder_hidden_states=crossattn, 133 | class_labels=class_labels, 134 | added_cond_kwargs=added_cond_kwargs, 135 | down_block_additional_residuals=down_block_additional_residuals, 136 | mid_block_additional_residual=mid_block_additional_residual, 137 | down_intrablock_additional_residuals=down_intrablock_additional_residuals_clone, 138 | ) 139 | .sample 140 | ) 141 | 142 | def freeze(self): 143 | """ 144 | Freeze the model 145 | """ 146 | self.eval() 147 | for
param in self.parameters(): 148 | param.requires_grad = False 149 | -------------------------------------------------------------------------------- /src/lbm/models/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from copy import deepcopy 4 | from typing import List, Tuple 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | TILING_METHODS = ["average", "gaussian", "linear"] 10 | 11 | 12 | class Tiler: 13 | def get_tiles( 14 | self, 15 | input: torch.Tensor, 16 | tile_size: tuple, 17 | overlap_size: tuple, 18 | scale: int = 1, 19 | out_channels: int = 3, 20 | ) -> List[List[torch.tensor]]: 21 | """Split the input into a grid of (possibly overlapping) tiles 22 | Args: 23 | input (torch.Tensor): input tensor of shape (batch_size, channels, height, width) 24 | tile_size (tuple): tile size as (height, width) 25 | overlap_size (tuple): overlap size as (height, width) 26 | scale (int): scaling factor of the output with respect to the input 27 | out_channels (int): number of output channels 28 | Returns: 29 | List[List[torch.Tensor]]: rows of tiles 30 | """ 31 | # assert isinstance(scale, int) 32 | assert ( 33 | overlap_size[0] < tile_size[0] 34 | ), f"Overlap size {overlap_size} must be smaller than tile size {tile_size}" 35 | assert ( 36 | overlap_size[1] < tile_size[1] 37 | ), f"Overlap size {overlap_size} must be smaller than tile size {tile_size}" 38 | 39 | B, C, H, W = input.shape 40 | tile_size_H, tile_size_W = tile_size 41 | 42 | # sets overlap to 0 if the input is smaller than the tile size (i.e.
no overlap) 43 | overlap_H, overlap_W = ( 44 | overlap_size[0] if H > tile_size_H else 0, 45 | overlap_size[1] if W > tile_size_W else 0, 46 | ) 47 | 48 | self.output_overlap_size = ( 49 | int(overlap_H * scale), 50 | int(overlap_W * scale), 51 | ) 52 | self.tile_size = tile_size 53 | self.output_tile_size = ( 54 | int(tile_size_H * scale), 55 | int(tile_size_W * scale), 56 | ) 57 | self.output_shape = ( 58 | B, 59 | out_channels, 60 | int(H * scale), 61 | int(W * scale), 62 | ) 63 | tiles = [] 64 | logging.debug(f"(Tiler) Input shape: {(B, C, H, W)}") 65 | logging.debug(f"(Tiler) Output shape: {self.output_shape}") 66 | logging.debug(f"(Tiler) Tile size: {(tile_size_H, tile_size_W)}") 67 | logging.debug(f"(Tiler) Overlap size: {(overlap_H, overlap_W)}") 68 | # loop over all tiles in the image with overlap 69 | for i in range(0, H, tile_size_H - overlap_H): 70 | row = [] 71 | for j in range(0, W, tile_size_W - overlap_W): 72 | tile = deepcopy( 73 | input[ 74 | :, 75 | :, 76 | i : i + tile_size_H, 77 | j : j + tile_size_W, 78 | ] 79 | ) 80 | row.append(tile) 81 | tiles.append(row) 82 | return tiles 83 | 84 | def merge_tiles( 85 | self, tiles: List[List[torch.tensor]], tiling_method: str = "gaussian" 86 | ) -> torch.tensor: 87 | """Merge processed tiles into a single output, blending the overlapping regions 88 | Args: 89 | tiles (List[List[torch.Tensor]]): rows of processed tiles, as returned by get_tiles 90 | tiling_method (str): tiling method. Can be "average", "gaussian" or "linear" 91 | Returns: 92 | torch.Tensor: output image 93 | """ 94 | if tiling_method == "average": 95 | return self._average_merge_tiles(tiles) 96 | elif tiling_method == "gaussian": 97 | return self._gaussian_merge_tiles(tiles) 98 | elif tiling_method == "linear": 99 | return self._linear_merge_tiles(tiles) 100 | else: 101 | raise ValueError( 102 | f"Unknown tiling method {tiling_method}.
Available methods are {TILING_METHODS}" 103 | ) 104 | 105 | def _average_merge_tiles(self, tiles: List[List[torch.tensor]]) -> torch.tensor: 106 | """Merge tiles by uniformly averaging the overlapping regions 107 | Args: 108 | tiles (List[List[torch.tensor]]): processed tiles, in the grid layout returned by get_tiles 109 | Returns: 110 | torch.tensor: output image 111 | """ 112 | 113 | output = torch.zeros(self.output_shape) 114 | 115 | # weights to store multiplicity 116 | weights = torch.zeros(self.output_shape) 117 | 118 | _, _, output_H, output_W = self.output_shape 119 | output_overlap_size_H, output_overlap_size_W = self.output_overlap_size 120 | output_tile_size_H, output_tile_size_W = self.output_tile_size 121 | 122 | for id_i, i in enumerate( 123 | range( 124 | 0, 125 | output_H, 126 | output_tile_size_H - output_overlap_size_H, 127 | ) 128 | ): 129 | for id_j, j in enumerate( 130 | range( 131 | 0, 132 | output_W, 133 | output_tile_size_W - output_overlap_size_W, 134 | ) 135 | ): 136 | output[ 137 | :, 138 | :, 139 | i : i + output_tile_size_H, 140 | j : j + output_tile_size_W, 141 | ] += ( 142 | tiles[id_i][id_j] * 1 143 | ) 144 | weights[ 145 | :, 146 | :, 147 | i : i + output_tile_size_H, 148 | j : j + output_tile_size_W, 149 | ] += 1 150 | 151 | # output holds the sum of every tile covering each pixel, so divide by the 152 | # number of covering tiles (1, 2 or 4 when the overlap is at most half the tile size) 153 | output = output / weights 154 | return output 155 | 156 | def _gaussian_weights( 157 | self, tile_width: int, tile_height: int, nbatches: int, channels: int 158 | ): 159 | """Generates a gaussian mask of weights for tile contributions.
160 | 161 | Args: 162 | tile_width (int): width of the tile 163 | tile_height (int): height of the tile 164 | nbatches (int): batch size 165 | channels (int): number of channels 166 | Returns: 167 | torch.tensor: float64 weights of shape (nbatches, channels, tile_height, tile_width) 168 | """ 169 | import numpy as np 170 | from numpy import exp, pi, sqrt 171 | 172 | latent_width = tile_width 173 | latent_height = tile_height 174 | 175 | var = 0.01 176 | midpoint = ( 177 | latent_width - 1 178 | ) / 2 # -1 because index goes from 0 to latent_width - 1 179 | x_probs = [ 180 | exp( 181 | -(x - midpoint) 182 | * (x - midpoint) 183 | / (latent_width * latent_width) 184 | / (2 * var) 185 | ) 186 | / sqrt(2 * pi * var) 187 | for x in range(latent_width) 188 | ] 189 | midpoint = (latent_height - 1) / 2 # centred like the x-axis so the mask is symmetric 190 | y_probs = [ 191 | exp( 192 | -(y - midpoint) 193 | * (y - midpoint) 194 | / (latent_height * latent_height) 195 | / (2 * var) 196 | ) 197 | / sqrt(2 * pi * var) 198 | for y in range(latent_height) 199 | ] 200 | 201 | weights = np.outer(y_probs, x_probs) 202 | return torch.tile( 203 | torch.tensor(weights, device="cpu"), (nbatches, channels, 1, 1) 204 | ) 205 | 206 | def _gaussian_merge_tiles(self, tiles: List[List[torch.tensor]]) -> torch.tensor: 207 | """Merge tiles with a weighted average, each tile weighted by a 2D Gaussian mask centred on it 208 | Args: 209 | tiles (List[List[torch.tensor]]): processed tiles, in the grid layout returned by get_tiles 210 | Returns: 211 | torch.tensor: output image 212 | """ 213 | B, output_C, output_H, output_W = self.output_shape 214 | output_overlap_size_H, output_overlap_size_W = self.output_overlap_size 215 | output_tile_size_H, output_tile_size_W = self.output_tile_size 216 | 217 | output = torch.zeros(self.output_shape) 218 | # weights to store multiplicity 219 | weights = torch.zeros(self.output_shape) 220 | 221 | for id_i, i in enumerate( 222 | range( 223 | 0, 224 | output_H, 225 | 
output_tile_size_H - output_overlap_size_H, 226 | ) 227 | ): 228 | for id_j, j in enumerate( 229 | range( 230 | 0, 231 | output_W, 232 | output_tile_size_W - output_overlap_size_W, 233 | )
234 | ): 235 | w = self._gaussian_weights( 236 | tiles[id_i][id_j].shape[3], 237 | tiles[id_i][id_j].shape[2], 238 | B, 239 | output_C, 240 | ) 241 | output[ 242 | :, 243 | :, 244 | i : i + output_tile_size_H, 245 | j : j + output_tile_size_W, 246 | ] += ( 247 | tiles[id_i][id_j] * w 248 | ) 249 | weights[ 250 | :, 251 | :, 252 | i : i + output_tile_size_H, 253 | j : j + output_tile_size_W, 254 | ] += w 255 | 256 | # normalise the weighted sum by the total Gaussian weight at each pixel 257 | output = output / weights 258 | return output 259 | 260 | def _blend_v( 261 | self, a: torch.Tensor, b: torch.Tensor, blend_extent: int 262 | ) -> torch.Tensor: 263 | blend_extent = min(a.shape[2], b.shape[2], blend_extent) 264 | for y in range(blend_extent): 265 | b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[ 266 | :, :, y, : 267 | ] * (y / blend_extent) 268 | return b 269 | 270 | def _blend_h( 271 | self, a: torch.Tensor, b: torch.Tensor, blend_extent: int 272 | ) -> torch.Tensor: 273 | blend_extent = min(a.shape[3], b.shape[3], blend_extent) 274 | for x in range(blend_extent): 275 | b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[ 276 | :, :, :, x 277 | ] * (x / blend_extent) 278 | return b 279 | 280 | def _linear_merge_tiles(self, tiles: List[List[torch.tensor]]) -> torch.Tensor: 281 | """Merge tiles by linearly blending the overlapping regions 282 | Args: 283 | tiles (List[List[torch.tensor]]): List of processed tiles 284 | Returns: 285 | torch.Tensor: output image 286 | """ 287 | output_overlap_size_H, output_overlap_size_W = self.output_overlap_size 288 | output_tile_size_H, output_tile_size_W = self.output_tile_size 289 | 290 | res_rows = [] 291 | tiles_copy = deepcopy(tiles) 292 | 293 | # Cut the right and bottom overlap region 294 | limit_i = output_tile_size_H - output_overlap_size_H 295 | limit_j = output_tile_size_W - output_overlap_size_W 296 | for i, tile_row in enumerate(tiles_copy): 297 | 
res_row = [] 298 | for j, tile in
enumerate(tile_row): 299 | tile_val = tile 300 | if j > 0: 301 | tile_val = self._blend_h( 302 | tile_row[j - 1], tile, output_overlap_size_W 303 | ) 304 | tiles_copy[i][j] = tile_val 305 | if i > 0: 306 | tile_val = self._blend_v( 307 | tiles_copy[i - 1][j], tile_val, output_overlap_size_H 308 | ) 309 | tiles_copy[i][j] = tile_val 310 | res_row.append(tile_val[:, :, :limit_i, :limit_j]) 311 | res_rows.append(torch.cat(res_row, dim=3)) 312 | output = torch.cat(res_rows, dim=2) 313 | return output 314 | 315 | 316 | def extract_into_tensor( 317 | a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int, ...] 318 | ) -> torch.Tensor: 319 | """ 320 | Extracts values from a tensor into a new tensor using indices from another tensor. 321 | 322 | :param a: the tensor to extract values from. 323 | :param t: the tensor containing the indices. 324 | :param x_shape: the shape of the tensor to extract values into. 325 | :return: a new tensor containing the extracted values. 326 | """ 327 | 328 | b, *_ = t.shape 329 | out = a.gather(-1, t) 330 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 331 | 332 | 333 | def pad(x: torch.Tensor, base_h: int, base_w: int) -> torch.Tensor: 334 | """ 335 | Pads a tensor to the nearest multiple of base_h and base_w. 336 | 337 | :param x: the tensor to pad. 338 | :param base_h: the base height. 339 | :param base_w: the base width. 340 | :return: the padded tensor. 
341 | """ 342 | h, w = x.shape[-2:] 343 | h_ = math.ceil(h / base_h) * base_h 344 | w_ = math.ceil(w / base_w) * base_w 345 | if w_ != w: 346 | x = F.pad(x, (0, abs(w_ - w), 0, 0)) 347 | if h_ != h: 348 | x = F.pad(x, (0, 0, 0, abs(h_ - h))) 349 | return x 350 | 351 | 352 | def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: 353 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 354 | dims_to_append = target_dims - x.ndim 355 | if dims_to_append < 0: 356 | raise ValueError( 357 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" 358 | ) 359 | return x[(...,) + (None,) * dims_to_append] 360 | 361 | 362 | @torch.no_grad() 363 | def update_ema( 364 | target_params: List[torch.Tensor], 365 | source_params: List[torch.Tensor], 366 | rate: float = 0.99, 367 | ): 368 | """ 369 | Update target parameters to be closer to those of source parameters using 370 | an exponential moving average. 371 | 372 | :param target_params: the target parameter sequence. 373 | :param source_params: the source parameter sequence. 374 | :param rate: the EMA rate (closer to 1 means slower). 
375 | """ 376 | for targ, src in zip(target_params, source_params): 377 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 378 | -------------------------------------------------------------------------------- /src/lbm/models/vae/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoencoderKL import AutoencoderKLDiffusers 2 | from .autoencoderKL_config import AutoencoderKLDiffusersConfig 3 | 4 | __all__ = ["AutoencoderKLDiffusers", "AutoencoderKLDiffusersConfig"] 5 | -------------------------------------------------------------------------------- /src/lbm/models/vae/autoencoderKL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.models import AutoencoderKL 3 | 4 | from ..base.base_model import BaseModel 5 | from ..utils import Tiler, pad 6 | from .autoencoderKL_config import AutoencoderKLDiffusersConfig 7 | 8 | 9 | class AutoencoderKLDiffusers(BaseModel): 10 | """This is the VAE class used to work with latent models 11 | 12 | Args: 13 | 14 | config (AutoencoderKLDiffusersConfig): The config class which defines all the required parameters. 
15 | """ 16 | 17 | def __init__(self, config: AutoencoderKLDiffusersConfig): 18 | BaseModel.__init__(self, config) 19 | self.config = config 20 | self.vae_model = AutoencoderKL.from_pretrained( 21 | config.version, 22 | subfolder=config.subfolder, 23 | revision=config.revision, 24 | ) 25 | self.tiling_size = config.tiling_size 26 | self.tiling_overlap = config.tiling_overlap 27 | 28 | # get downsampling factor 29 | self._get_properties() 30 | 31 | @torch.no_grad() 32 | def _get_properties(self): 33 | self.has_shift_factor = ( 34 | hasattr(self.vae_model.config, "shift_factor") 35 | and self.vae_model.config.shift_factor is not None 36 | ) 37 | self.shift_factor = ( 38 | self.vae_model.config.shift_factor if self.has_shift_factor else 0 39 | ) 40 | 41 | # set latent channels 42 | self.latent_channels = self.vae_model.config.latent_channels 43 | self.has_latents_mean = ( 44 | hasattr(self.vae_model.config, "latents_mean") 45 | and self.vae_model.config.latents_mean is not None 46 | ) 47 | self.has_latents_std = ( 48 | hasattr(self.vae_model.config, "latents_std") 49 | and self.vae_model.config.latents_std is not None 50 | ) 51 | self.latents_mean = self.vae_model.config.latents_mean 52 | self.latents_std = self.vae_model.config.latents_std 53 | 54 | x = torch.randn(1, self.vae_model.config.in_channels, 32, 32) 55 | z = self.encode(x) 56 | 57 | # set downsampling factor 58 | self.downsampling_factor = int(x.shape[2] / z.shape[2]) 59 | 60 | def encode(self, x: torch.tensor, batch_size: int = 8): 61 | latents = [] 62 | for i in range(0, x.shape[0], batch_size): 63 | latents.append( 64 | self.vae_model.encode(x[i : i + batch_size]).latent_dist.sample() 65 | ) 66 | latents = torch.cat(latents, dim=0) 67 | latents = (latents - self.shift_factor) * self.vae_model.config.scaling_factor 68 | 69 | return latents 70 | 71 | def decode(self, z: torch.tensor): 72 | 73 | if self.has_latents_mean and self.has_latents_std: 74 | latents_mean = ( 75 | torch.tensor(self.latents_mean) 76 
| .view(1, self.latent_channels, 1, 1) 77 | .to(z.device, z.dtype) 78 | ) 79 | latents_std = ( 80 | torch.tensor(self.latents_std) 81 | .view(1, self.latent_channels, 1, 1) 82 | .to(z.device, z.dtype) 83 | ) 84 | z = z * latents_std / self.vae_model.config.scaling_factor + latents_mean 85 | else: 86 | z = z / self.vae_model.config.scaling_factor + self.shift_factor 87 | 88 | use_tiling = ( 89 | z.shape[2] > self.tiling_size[0] or z.shape[3] > self.tiling_size[1] 90 | ) 91 | 92 | if use_tiling: 93 | samples = [] 94 | for i in range(z.shape[0]): 95 | 96 | z_i = z[i].unsqueeze(0) 97 | 98 | tiler = Tiler() 99 | tiles = tiler.get_tiles( 100 | input=z_i, 101 | tile_size=self.tiling_size, 102 | overlap_size=self.tiling_overlap, 103 | scale=self.downsampling_factor, 104 | out_channels=3, 105 | ) 106 | 107 | for i, tile_row in enumerate(tiles): 108 | for j, tile in enumerate(tile_row): 109 | tile_shape = tile.shape 110 | # pad tile to inference size if tile is smaller than inference size 111 | tile = pad( 112 | tile, 113 | base_h=self.tiling_size[0], 114 | base_w=self.tiling_size[1], 115 | ) 116 | tile_decoded = self.vae_model.decode(tile).sample 117 | tiles[i][j] = ( 118 | tile_decoded[ 119 | 0, 120 | :, 121 | : int(tile_shape[2] * self.downsampling_factor), 122 | : int(tile_shape[3] * self.downsampling_factor), 123 | ] 124 | .cpu() 125 | .unsqueeze(0) 126 | ) 127 | 128 | # merge tiles 129 | samples.append(tiler.merge_tiles(tiles=tiles)) 130 | 131 | samples = torch.cat(samples, dim=0) 132 | 133 | else: 134 | samples = self.vae_model.decode(z).sample 135 | 136 | return samples 137 | -------------------------------------------------------------------------------- /src/lbm/models/vae/autoencoderKL_config.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from pydantic.dataclasses import dataclass 4 | 5 | from ..base import ModelConfig 6 | 7 | 8 | @dataclass 9 | class 
AutoencoderKLDiffusersConfig(ModelConfig): 10 | """Configuration class for AutoencoderKLDiffusers, defining all the parameters needed to instantiate the model. 11 | 12 | Args: 13 | 14 | version (str): The Hugging Face Hub model id (or local path) to load the VAE from. Defaults to "stabilityai/sdxl-vae". 15 | subfolder (str): The subfolder holding the VAE weights when it is part of a larger model repository. Defaults to "". 16 | revision (str): The revision of the model. Defaults to "main". 17 | input_key (str): The key of the input data in the batch. Defaults to "image". 18 | tiling_size (Tuple[int, int]): Latent tile size (height, width) used by decode; decoding is tiled when the latent is larger. Defaults to (64, 64). 19 | tiling_overlap (Tuple[int, int]): Overlap (height, width) between neighbouring latent tiles. Defaults to (16, 16). 20 | """ 21 | 22 | version: str = "stabilityai/sdxl-vae" 23 | subfolder: str = "" 24 | revision: str = "main" 25 | input_key: str = "image" 26 | tiling_size: Tuple[int, int] = (64, 64) 27 | tiling_overlap: Tuple[int, int] = (16, 16) 28 | -------------------------------------------------------------------------------- /src/lbm/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the training pipeline and the training configuration along with all relevant parts 3 | of the training pipeline such as loggers and callbacks. 4 | 5 | The :mod:`lbm.trainer` module includes the following submodules: 6 | 7 | - :mod:`lbm.trainer.trainer`: the main training pipeline class. 8 | - :mod:`lbm.trainer.training_config`: the configuration for the training pipeline. 9 | - :mod:`lbm.trainer.loggers`: the callbacks for logging samples to wandb or TensorBoard. 10 | 11 | 12 | Examples 13 | ######## 14 | 15 | Train a model using the training pipeline 16 | 17 | ..
code-block:: python 18 | 19 | from pytorch_lightning import Trainer 20 | 21 | from lbm.data.datasets import DataModule, DataModuleConfig 22 | from lbm.trainer import TrainingPipeline, TrainingConfig 23 | 24 | # Create a model to train (placeholder: any BaseModel subclass) 25 | model = DummyModel() 26 | 27 | # Create a training configuration 28 | config = TrainingConfig( 29 | experiment_id="test", 30 | optimizer_name="AdamW", 31 | optimizer_kwargs={}, 32 | learning_rate=1e-3, 33 | lr_scheduler_name=None, 34 | lr_scheduler_kwargs={}, 35 | trainable_params=["./*"], 36 | log_keys="txt", 37 | log_samples_model_kwargs={ 38 | "max_samples": 8, 39 | "num_steps": 20, 40 | "input_shape": (4, 32, 32), 41 | "guidance_scale": 7.5, 42 | } 43 | ) 44 | 45 | # Create a training pipeline 46 | pipeline = TrainingPipeline(model=model, pipeline_config=config) 47 | 48 | # Create a DataModule 49 | data_module = DataModule( 50 | train_config=DataModuleConfig( 51 | shards_path_or_urls="your urls or paths", 52 | decoder="pil", 53 | shuffle_buffer_size=100, 54 | per_worker_batch_size=32, 55 | num_workers=4, 56 | ), 57 | train_filters_mappers=your_mappers_and_filters, 58 | eval_config=DataModuleConfig( 59 | shards_path_or_urls="your urls or paths", 60 | decoder="pil", 61 | shuffle_buffer_size=100, 62 | per_worker_batch_size=32, 63 | num_workers=4, 64 | ), 65 | eval_filters_mappers=your_mappers_and_filters, 66 | ) 67 | 68 | # Create a trainer 69 | trainer = Trainer( 70 | accelerator="cuda", 71 | max_epochs=1, 72 | devices=1, 73 | log_every_n_steps=1, 74 | default_root_dir="your dir", 75 | max_steps=2, 76 | ) 77 | 78 | # Train the model 79 | trainer.fit(pipeline, data_module) 80 | """ 81 | 82 | from .trainer import TrainingPipeline 83 | from .training_config import TrainingConfig 84 | 85 | __all__ = ["TrainingPipeline", "TrainingConfig"] 86 | -------------------------------------------------------------------------------- /src/lbm/trainer/loggers.py:
-------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from typing import Any, Dict, List, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | import wandb 8 | from PIL import Image, ImageDraw, ImageFont 9 | from pytorch_lightning import Trainer 10 | from pytorch_lightning.callbacks import Callback 11 | from pytorch_lightning.utilities import rank_zero_only 12 | from torchvision.utils import make_grid 13 | 14 | from ..trainer import TrainingPipeline 15 | 16 | logging.basicConfig(level=logging.INFO) 17 | 18 | 19 | def create_grid_texts( 20 | texts: List[str], 21 | n_cols: int = 4, 22 | image_size: Tuple[int, int] = (512, 512), 23 | font_size: int = 40, 24 | margin: int = 5, 25 | offset: int = 5, 26 | ) -> Image.Image: 27 | """ 28 | Create a grid of white images containing the given texts. 29 | 30 | Args: 31 | texts (List[str]): List of strings to be drawn on images. 32 | n_cols (int): Number of columns in the grid. 33 | image_size (tuple): Size (width, height) of each text panel in the grid. 34 | font_size (int): Font size of the text. 35 | margin (int): Horizontal margin around the text. 36 | offset (int): Vertical position of the first line; following lines are spaced by font_size.
37 | 38 | Returns: 39 | PIL.Image.Image: a single image laying out one text panel per entry in a grid with n_cols columns 40 | """ 41 | 42 | images = [] 43 | font = ImageFont.load_default(size=font_size) 44 | 45 | for text in texts: 46 | img = Image.new("RGB", image_size, color="white") 47 | draw = ImageDraw.Draw(img) 48 | margin_ = margin 49 | offset_ = offset 50 | for line in wrap_text( 51 | text=text, draw=draw, max_width=image_size[0] - 2 * margin_, font=font 52 | ): 53 | draw.text((margin_, offset_), line, font=font, fill="black") 54 | offset_ += font_size 55 | images.append(img) 56 | 57 | # create a pil grid 58 | n_rows = math.ceil(len(images) / n_cols) 59 | grid = Image.new( 60 | "RGB", (n_cols * image_size[0], n_rows * image_size[1]), color="white" 61 | ) 62 | for i, img in enumerate(images): 63 | grid.paste(img, (i % n_cols * image_size[0], i // n_cols * image_size[1])) 64 | 65 | return grid 66 | 67 | 68 | def wrap_text( 69 | text: str, draw: ImageDraw.Draw, max_width: int, font: ImageFont 70 | ) -> List[str]: 71 | """ 72 | Wrap text so that each line fits within max_width when drawn. 73 | Lines are broken character by character (not at word boundaries) as soon as the next character would exceed max_width. 74 | 75 | Args: 76 | text (str): The text to be wrapped. 77 | draw (ImageDraw.Draw): The draw object to calculate text size. 78 | max_width (int): The maximum width for the wrapped text. 79 | font (ImageFont): The font used for the text. 80 | 81 | Returns: 82 | List[str]: List of wrapped lines. 83 | """ 84 | lines = [] 85 | current_line = "" 86 | for letter in text: 87 | if draw.textbbox((0, 0), current_line + letter, font=font)[2] <= max_width: 88 | current_line += letter 89 | else: 90 | lines.append(current_line) 91 | current_line = letter 92 | lines.append(current_line) 93 | return lines 94 | 95 | 96 | class WandbSampleLogger(Callback): 97 | """ 98 | Logger for logging samples to wandb. This logger is used to log images, text, and metrics to wandb. 
99 | 100 | Args: 101 | log_batch_freq (int): The frequency of logging samples to wandb. Default is 100.
102 | """ 103 | 104 | def __init__(self, log_batch_freq: int = 100): 105 | super().__init__() 106 | self.log_batch_freq = log_batch_freq 107 | 108 | def on_train_batch_end( 109 | self, 110 | trainer: Trainer, 111 | pl_module: TrainingPipeline, 112 | outputs: Dict[str, Any], 113 | batch: Any, 114 | batch_idx: int, 115 | ) -> None: 116 | self.log_samples(trainer, pl_module, outputs, batch, batch_idx, split="train") 117 | self._process_logs(trainer, outputs, split="train") 118 | 119 | def on_validation_batch_end( 120 | self, 121 | trainer: Trainer, 122 | pl_module: TrainingPipeline, 123 | outputs: Dict[str, Any], 124 | batch: Any, 125 | batch_idx: int, 126 | ) -> None: 127 | self.log_samples(trainer, pl_module, outputs, batch, batch_idx, split="val") 128 | self._process_logs(trainer, outputs, split="val") 129 | 130 | @rank_zero_only 131 | @torch.no_grad() 132 | def log_samples( 133 | self, 134 | trainer: Trainer, 135 | pl_module: TrainingPipeline, 136 | outputs: Dict[str, Any], 137 | batch: Dict[str, Any], 138 | batch_idx: int, 139 | split: str = "train", 140 | ) -> None: 141 | if hasattr(pl_module, "log_samples"): 142 | if batch_idx % self.log_batch_freq == 0: 143 | is_training = pl_module.training 144 | if is_training: 145 | pl_module.eval() 146 | 147 | logs = pl_module.log_samples(batch) 148 | logs = self._process_logs(trainer, logs, split=split) 149 | 150 | if is_training: 151 | pl_module.train() 152 | else: 153 | logging.warning( 154 | "log_samples method not found in LightningModule. Skipping sample logging."
155 | ) 156 | 157 | @rank_zero_only 158 | def _process_logs( 159 | self, trainer, logs: Dict[str, Any], rescale=True, split="train" 160 | ) -> Dict[str, Any]: 161 | for key, value in logs.items(): 162 | if isinstance(value, torch.Tensor): 163 | value = value.detach().cpu() 164 | if value.dim() == 4: 165 | images = value 166 | if rescale: 167 | images = (images + 1.0) / 2.0 168 | grid = make_grid(images, nrow=4) 169 | grid = grid.permute(1, 2, 0) 170 | grid = grid.mul(255).clamp(0, 255).to(torch.uint8) 171 | logs[key] = grid.numpy() 172 | trainer.logger.experiment.log( 173 | {f"{key}/{split}": [wandb.Image(Image.fromarray(logs[key]))]}, 174 | step=trainer.global_step, 175 | ) 176 | 177 | # Scalar tensor 178 | if value.dim() == 1 or value.dim() == 0: 179 | value = value.float().numpy() 180 | trainer.logger.experiment.log( 181 | {f"{key}/{split}": value}, step=trainer.global_step 182 | ) 183 | 184 | # list of string (e.g. text) 185 | if isinstance(value, list): 186 | if isinstance(value[0], str): 187 | pil_image_texts = create_grid_texts(value) 188 | wandb_image = wandb.Image(pil_image_texts) 189 | trainer.logger.experiment.log( 190 | {f"{key}/{split}": [wandb_image]}, 191 | step=trainer.global_step, 192 | ) 193 | 194 | # dict of tensors (e.g. metrics) 195 | if isinstance(value, dict): 196 | for k, v in value.items(): 197 | if isinstance(v, torch.Tensor): 198 | value[k] = v.detach().cpu().numpy() 199 | trainer.logger.experiment.log( 200 | {f"{key}/{split}": value}, step=trainer.global_step 201 | ) 202 | 203 | if isinstance(value, int) or isinstance(value, float): 204 | trainer.logger.experiment.log( 205 | {f"{key}/{split}": value}, step=trainer.global_step 206 | ) 207 | 208 | return logs 209 | 210 | 211 | class TensorBoardSampleLogger(Callback): 212 | """ 213 | Logger for logging samples to tensorboard. This logger is used to log images, text, and metrics to tensorboard. 214 | 215 | Args: 216 | log_batch_freq (int): The frequency of logging samples to tensorboard. 
Default is 100. 217 | """ 218 | 219 | def __init__(self, log_batch_freq: int = 100): 220 | super().__init__() 221 | self.log_batch_freq = log_batch_freq 222 | 223 | def on_train_batch_end( 224 | self, 225 | trainer: Trainer, 226 | pl_module: TrainingPipeline, 227 | outputs: Dict[str, Any], 228 | batch: Any, 229 | batch_idx: int, 230 | ) -> None: 231 | self.log_samples(trainer, pl_module, outputs, batch, batch_idx, split="train") 232 | self._process_logs(trainer, outputs, split="train") 233 | 234 | def on_validation_batch_end( 235 | self, 236 | trainer: Trainer, 237 | pl_module: TrainingPipeline, 238 | outputs: Dict[str, Any], 239 | batch: Any, 240 | batch_idx: int, 241 | ) -> None: 242 | self.log_samples(trainer, pl_module, outputs, batch, batch_idx, split="val") 243 | self._process_logs(trainer, outputs, split="val") 244 | 245 | @rank_zero_only 246 | @torch.no_grad() 247 | def log_samples( 248 | self, 249 | trainer: Trainer, 250 | pl_module: TrainingPipeline, 251 | outputs: Dict[str, Any], 252 | batch: Dict[str, Any], 253 | batch_idx: int, 254 | split: str = "train", 255 | ) -> None: 256 | if hasattr(pl_module, "log_samples"): 257 | if batch_idx % self.log_batch_freq == 0: 258 | is_training = pl_module.training 259 | if is_training: 260 | pl_module.eval() 261 | 262 | logs = pl_module.log_samples(batch) 263 | logs = self._process_logs(trainer, logs, split=split) 264 | 265 | if is_training: 266 | pl_module.train() 267 | else: 268 | logging.warning( 269 | "log_samples method not found in LightningModule. Skipping sample logging."
270 | ) 271 | 272 | @rank_zero_only 273 | def _process_logs( 274 | self, trainer, logs: Dict[str, Any], rescale=True, split="train" 275 | ) -> Dict[str, Any]: 276 | for key, value in logs.items(): 277 | if isinstance(value, torch.Tensor): 278 | value = value.detach().cpu() 279 | if value.dim() == 4: 280 | images = value 281 | if rescale: 282 | images = (images + 1.0) / 2.0 283 | grid = make_grid(images, nrow=4) 284 | # grid = grid.permute(1, 2, 0) 285 | grid = grid.mul(255).clamp(0, 255).to(torch.uint8) 286 | logs[key] = grid.numpy() 287 | trainer.logger.experiment.add_image( 288 | f"{key}/{split}", 289 | logs[key], 290 | trainer.global_step, 291 | ) 292 | 293 | # Scalar tensor 294 | if value.dim() == 1 or value.dim() == 0: 295 | value = value.float().numpy() 296 | trainer.logger.experiment.add_scalar( 297 | f"{key}/{split}", value, trainer.global_step 298 | ) 299 | 300 | # list of string (e.g. text) 301 | if isinstance(value, list): 302 | if isinstance(value[0], str): 303 | pil_image_texts = create_grid_texts(value) 304 | trainer.logger.experiment.add_image( 305 | f"{key}/{split}", 306 | np.transpose(np.array(pil_image_texts), (2, 0, 1)), 307 | trainer.global_step, 308 | ) 309 | 310 | # dict of tensors (e.g. 
metrics) 311 | if isinstance(value, dict): 312 | for k, v in value.items(): 313 | if isinstance(v, torch.Tensor): 314 | value[k] = v.detach().cpu().numpy() 315 | trainer.logger.experiment.add_scalars( 316 | f"{key}/{split}", value, trainer.global_step 317 | ) 318 | 319 | if isinstance(value, int) or isinstance(value, float): 320 | trainer.logger.experiment.add_scalar( 321 | f"{key}/{split}", value, trainer.global_step 322 | ) 323 | 324 | return logs 325 | -------------------------------------------------------------------------------- /src/lbm/trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | import re 4 | import time 5 | from typing import Any, Dict 6 | 7 | import pytorch_lightning as pl 8 | import torch 9 | 10 | from ..models.base.base_model import BaseModel 11 | from .training_config import TrainingConfig 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | 15 | 16 | class TrainingPipeline(pl.LightningModule): 17 | """ 18 | Main Training Pipeline class 19 | 20 | Args: 21 | 22 | model (BaseModel): The model to train 23 | pipeline_config (TrainingConfig): The configuration for the training pipeline 24 | verbose (bool): Whether to print logs in the console. Default is False. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | model: BaseModel, 30 | pipeline_config: TrainingConfig, 31 | verbose: bool = False, 32 | **kwargs, 33 | ): 34 | super().__init__() 35 | 36 | self.model = model 37 | self.pipeline_config = pipeline_config 38 | self.log_samples_model_kwargs = pipeline_config.log_samples_model_kwargs 39 | 40 | # save hyperparameters. 41 | self.save_hyperparameters(ignore="model") 42 | self.save_hyperparameters({"model_config": model.config.to_dict()}) 43 | 44 | # logger. 45 | self.verbose = verbose 46 | 47 | # setup logging.
48 | log_keys = pipeline_config.log_keys 49 | 50 | if isinstance(log_keys, str): 51 | log_keys = [log_keys] 52 | 53 | if log_keys is None: 54 | log_keys = [] 55 | 56 | self.log_keys = log_keys 57 | 58 | def on_fit_start(self) -> None: 59 | self.model.on_fit_start(device=self.device) 60 | if self.global_rank == 0: 61 | self.timer = time.perf_counter() 62 | 63 | def on_train_batch_end( 64 | self, outputs: Dict[str, Any], batch: Any, batch_idx: int 65 | ) -> None: 66 | if self.global_rank == 0: 67 | logging.debug("on_train_batch_end") 68 | self.model.on_train_batch_end(batch) 69 | 70 | average_time_frequency = 10 71 | if self.global_rank == 0 and batch_idx % average_time_frequency == 0: 72 | delta = time.perf_counter() - self.timer 73 | logging.info( 74 | f"Batch {batch_idx}: average time per batch {delta / (batch_idx + 1):.3f} seconds" 75 | ) 76 | 77 | def configure_optimizers(self) -> torch.optim.Optimizer: 78 | """ 79 | Setup optimizers and learning rate schedulers. 80 | """ 81 | optimizers = [] 82 | lr = self.pipeline_config.learning_rate 83 | param_list = [] 84 | n_params = 0 85 | param_list_ = {"params": []} 86 | for name, param in self.model.named_parameters(): 87 | for regex in self.pipeline_config.trainable_params: 88 | pattern = re.compile(regex) 89 | if re.match(pattern, name): 90 | if param.requires_grad: 91 | param_list_["params"].append(param) 92 | n_params += param.numel() 93 | 94 | param_list.append(param_list_) 95 | 96 | logging.info(f"Number of parameters passed to the optimizer: {n_params}") 97 | 98 | optimizer_cls = getattr( 99 | importlib.import_module("torch.optim"), 100 | self.pipeline_config.optimizer_name, 101 | ) 102 | optimizer = optimizer_cls( 103 | param_list, lr=lr, **self.pipeline_config.optimizer_kwargs 104 | ) 105 | optimizers.append(optimizer) 106 | 107 | self.optims = optimizers 108 | 109 | 110 | for name, param in self.model.named_parameters(): 111 | set_grad_false = True 112 | for regex in
self.pipeline_config.trainable_params: 113 | pattern = re.compile(regex) 114 | if re.match(pattern, name): 115 | if param.requires_grad: 116 | set_grad_false = False 117 | if set_grad_false: 118 | param.requires_grad = False 119 | 120 | num_trainable_params = sum( 121 | p.numel() for p in self.model.parameters() if p.requires_grad 122 | ) 123 | 124 | logging.info(f"Number of trainable parameters: {num_trainable_params}") 125 | 126 | schedulers_config = self.configure_lr_schedulers() 127 | 128 | if schedulers_config is None: 129 | return optimizers 130 | 131 | return optimizers, [ 132 | schedulers_config_ for schedulers_config_ in schedulers_config 133 | ] 134 | 135 | def configure_lr_schedulers(self): 136 | schedulers_config = [] 137 | if self.pipeline_config.lr_scheduler_name is None: 138 | scheduler = None 139 | schedulers_config.append(scheduler) 140 | else: 141 | scheduler_cls = getattr( 142 | importlib.import_module("torch.optim.lr_scheduler"), 143 | self.pipeline_config.lr_scheduler_name, 144 | ) 145 | scheduler = scheduler_cls( 146 | self.optims[0], 147 | **self.pipeline_config.lr_scheduler_kwargs, 148 | ) 149 | lr_scheduler_config = { 150 | "scheduler": scheduler, 151 | "interval": self.pipeline_config.lr_scheduler_interval, 152 | "monitor": "val_loss", 153 | "frequency": self.pipeline_config.lr_scheduler_frequency, 154 | } 155 | schedulers_config.append(lr_scheduler_config) 156 | 157 | if all([scheduler is None for scheduler in schedulers_config]): 158 | return None 159 | 160 | return schedulers_config 161 | 162 | def training_step(self, train_batch: Dict[str, Any], batch_idx: int) -> dict: 163 | model_output = self.model(train_batch) 164 | loss = model_output["loss"] 165 | logging.info(f"loss: {loss}") 166 | return { 167 | "loss": loss, 168 | "batch_idx": batch_idx, 169 | } 170 | 171 | def validation_step(self, val_batch: Dict[str, Any], val_idx: int) -> dict: 172 | loss = self.model(val_batch, device=self.device)["loss"] 173 | 174 | metrics = 
self.model.compute_metrics(val_batch) 175 | 176 | return {"loss": loss, "metrics": metrics} 177 | 178 | def log_samples(self, batch: Dict[str, Any]): 179 | logging.debug("log_samples") 180 | logs = self.model.log_samples( 181 | batch, 182 | **self.log_samples_model_kwargs, 183 | ) 184 | 185 | if logs is None: 186 | logs = {} 187 | # N = number of generated samples, used to truncate the logged inputs 188 | N = min(logs[key].shape[0] for key in logs) if logs else 0 189 | 190 | # Log inputs 191 | if self.log_keys is not None: 192 | for key in self.log_keys: 193 | if key in batch: 194 | if N > 0: 195 | logs[key] = batch[key][:N] 196 | else: 197 | logs[key] = batch[key] 198 | 199 | return logs 200 | -------------------------------------------------------------------------------- /src/lbm/trainer/training_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field 2 | from typing import List, Literal, Optional, Union 3 | 4 | from pydantic.dataclasses import dataclass 5 | 6 | from ..config import BaseConfig 7 | 8 | 9 | @dataclass 10 | class TrainingConfig(BaseConfig): 11 | """ 12 | Configuration for the training pipeline 13 | 14 | Args: 15 | 16 | experiment_id (str): 17 | The experiment id for the training run. If not provided, a random id will be generated. 18 | optimizer_name (str): 19 | The optimizer to use. Default is "AdamW". Choices are "Adam", "AdamW", "Adadelta", "Adagrad", "RMSprop", "SGD" 20 | optimizer_kwargs (Dict[str, Any]): 21 | The optimizer kwargs. Default is {} 22 | learning_rate (float): 23 | The learning rate to use. Default is 1e-3 24 | lr_scheduler_name (str): 25 | The learning rate scheduler to use. Default is None. Choices are "StepLR", "CosineAnnealingLR", 26 | "CosineAnnealingWarmRestarts", "ReduceLROnPlateau", "ExponentialLR" 27 | lr_scheduler_kwargs (Dict[str, Any]): 28 | The learning rate scheduler kwargs. Default is {} 29 | lr_scheduler_interval (str): 30 | The learning rate scheduler interval. 
Default is "step".
Choices are "step", "epoch" 31 | lr_scheduler_frequency (int): 32 | The learning rate scheduler frequency. Default is 1 33 | metrics (List[str]): 34 | The metrics to use. Default is None 35 | tracking_metrics (List[str]): 36 | The metrics to track. Default is None 37 | backup_every (int): 38 | The frequency at which to back up the model. Default is 50. 39 | trainable_params (List[str]): 40 | Regexes indicating the parameters to train, matched with re.match from the start of each parameter name. 41 | Default is ["./*"] (i.e. all parameters are trainable) 42 | log_keys (Union[str, List[str]]): 43 | The keys to log when sampling from the model. Default is "txt" 44 | log_samples_model_kwargs (Dict[str, Any]): 45 | The kwargs for logging samples from the model. Default is { 46 | "max_samples": 4, 47 | "num_steps": 20, 48 | "input_shape": None, 49 | } 50 | """ 51 | 52 | experiment_id: Optional[str] = None 53 | optimizer_name: Literal[ 54 | "Adam", "AdamW", "Adadelta", "Adagrad", "RMSprop", "SGD" 55 | ] = field(default_factory=lambda: "AdamW") 56 | optimizer_kwargs: Optional[dict] = field(default_factory=lambda: {}) 57 | learning_rate: float = field(default_factory=lambda: 1e-3) 58 | lr_scheduler_name: Optional[ 59 | Literal[ 60 | "StepLR", 61 | "CosineAnnealingLR", 62 | "CosineAnnealingWarmRestarts", 63 | "ReduceLROnPlateau", 64 | "ExponentialLR", 65 | None, 66 | ] 67 | ] = None 68 | lr_scheduler_kwargs: Optional[dict] = field(default_factory=lambda: {}) 69 | lr_scheduler_interval: Optional[Literal["step", "epoch", None]] = "step" 70 | lr_scheduler_frequency: Optional[int] = 1 71 | metrics: Optional[List[str]] = None 72 | tracking_metrics: Optional[List[str]] = None 73 | backup_every: int = 50 74 | trainable_params: List[str] = field(default_factory=lambda: ["./*"]) 75 | log_keys: Optional[Union[str, List[str]]] = "txt" 76 | log_samples_model_kwargs: Optional[dict] = field( 77 | default_factory=lambda: { 78 | "max_samples": 4, 79 | "num_steps": 20, 80 | "input_shape": None, 81 | } 82 | ) 83 |
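To make the `trainable_params` semantics concrete, here is a small, dependency-free sketch. It assumes only what `configure_optimizers` shows: each pattern is applied with `re.match`, which anchors at the start of the parameter name but does not require matching the whole name. The sketch selects each parameter once even when several patterns match it. The helper `select_trainable` and the parameter names below are hypothetical and not part of the repository.

```python
import re
from typing import Iterable, List


def select_trainable(param_names: Iterable[str], patterns: List[str]) -> List[str]:
    """Hypothetical helper: return the names matched by at least one pattern, each once.

    Uses re.match semantics: a pattern must match at the *start* of the name,
    but does not have to cover the whole name.
    """
    compiled = [re.compile(p) for p in patterns]
    return [name for name in param_names if any(rx.match(name) for rx in compiled)]


# Hypothetical parameter names, for illustration only
names = [
    "vae.encoder.conv_in.weight",
    "denoiser.conv_in.weight",
    "denoiser.conv_in.bias",
]

# Default "./*": "." consumes one character and "/*" matches zero slashes, so every non-empty name matches
print(select_trainable(names, ["./*"]))
# Overlapping patterns still select each parameter only once
print(select_trainable(names, [r"denoiser\.", r"denoiser\.conv_in"]))
# re.match is anchored at the start of the name, so a bare "conv_in" selects nothing
print(select_trainable(names, [r"conv_in"]))
```

To target a sub-module anywhere in a parameter name, the pattern needs a leading `.*` (for example `r".*conv_in"`), because `re.match` never skips characters at the start of the string.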
-------------------------------------------------------------------------------- /src/lbm/trainer/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | import time 5 | from typing import Dict, List, Literal, Optional, Tuple 6 | 7 | import torch 8 | 9 | 10 | class StateDictAdapter: 11 | """ 12 | StateDictAdapter for adapting a checkpoint state dict to the shapes of a model state dict. 13 | 14 | This class will iterate over all keys in the checkpoint state dict and filter them by a list of regex keys. 15 | For each matching key, the class will adapt the checkpoint tensor to the shape of the corresponding model tensor. 16 | Depending on the target size, the class will either append missing blocks or truncate the tensor. 17 | When adding missing blocks, the class will use a strategy to fill the missing blocks: either adding zeros or normal random values. 18 | 19 | Example: 20 | 21 | ``` 22 | adapter = StateDictAdapter() 23 | new_state_dict = adapter( 24 | model_state_dict=model.state_dict(), 25 | checkpoint_state_dict=state_dict, 26 | regex_keys=[ 27 | r"class_embedding.linear_1.weight", 28 | r"conv_in.weight", 29 | r"(down_blocks|up_blocks)\.\d+\.attentions\.\d+\.transformer_blocks\.\d+\.attn\d+\.(to_k|to_v)\.weight", 30 | r"mid_block\.attentions\.\d+\.transformer_blocks\.\d+\.attn\d+\.(to_k|to_v)\.weight" 31 | ] 32 | ) 33 | ``` 34 | 35 | Args: 36 | model_state_dict (Dict[str, torch.Tensor]): The model state dict. 37 | checkpoint_state_dict (Dict[str, torch.Tensor]): The checkpoint state dict. 38 | regex_keys (Optional[List[str]]): A list of regex keys to adapt the checkpoint state dict. Defaults to None. 39 | Passing an explicit list of regexes drastically reduces latency, since otherwise every checkpoint key is matched against every model key. 40 | If None, every key of the model state dict is used as a regex, so the checkpoint keys matching a model key are adapted. 
41 | strategy (Literal["zeros", "normal"], optional): The strategy to fill the missing blocks. Defaults to "normal".
42 | 43 | """ 44 | 45 | def _create_block( 46 | self, 47 | shape: List[int], 48 | strategy: Literal["zeros", "normal"], 49 | input: torch.Tensor = None, 50 | ): 51 | if strategy == "zeros": 52 | return torch.zeros(shape) if input is None else input.new_zeros(shape) 53 | elif strategy == "normal": 54 | if input is not None: 55 | mean = input.mean().item() 56 | std = input.std().item() 57 | return (torch.randn(shape) * std + mean).to(input) # match the source dtype and device 58 | else: 59 | return torch.randn(shape) 60 | else: 61 | raise ValueError(f"Unknown strategy {strategy}") 62 | 63 | def __call__( 64 | self, 65 | model_state_dict: Dict[str, torch.Tensor], 66 | checkpoint_state_dict: Dict[str, torch.Tensor], 67 | regex_keys: Optional[List[str]] = None, 68 | strategy: Literal["zeros", "normal"] = "normal", 69 | ): 70 | start = time.perf_counter() 71 | # if no regex keys are provided, we use all keys in the model state dict 72 | if regex_keys is None: 73 | regex_keys = list(model_state_dict.keys()) 74 | 75 | # iterate over all keys in the checkpoint state dict 76 | for checkpoint_key in list(checkpoint_state_dict.keys()): 77 | # iterate over all regex keys (checkpoint keys absent from the model are skipped) 78 | for regex_key in regex_keys: 79 | if checkpoint_key in model_state_dict and re.match(regex_key, checkpoint_key): 80 | dst_shape = model_state_dict[checkpoint_key].shape 81 | src_shape = checkpoint_state_dict[checkpoint_key].shape 82 | 83 | ## Sizes adapter 84 | # if the ranks of the shapes differ, we need to unsqueeze or squeeze the tensor 85 | if len(dst_shape) != len(src_shape): 86 | # in the case [a] vs [a, b] -> unsqueeze [a, 1] 87 | if len(src_shape) == 1: 88 | checkpoint_state_dict[checkpoint_key] = ( 89 | checkpoint_state_dict[checkpoint_key].unsqueeze(1) 90 | ) 91 | logging.info( 92 | f"Unsqueeze {checkpoint_key}: {src_shape} -> {checkpoint_state_dict[checkpoint_key].shape}" 93 | ) 94 | # in the case [a, b] vs [a] -> squeeze [a] 95 | elif len(dst_shape) == 1: 
96 | checkpoint_state_dict[checkpoint_key] = ( 97 | checkpoint_state_dict[checkpoint_key][:, 0] 98 | ) 99 | logging.info( 100 | f"Squeeze {checkpoint_key}: {src_shape} ->
{checkpoint_state_dict[checkpoint_key].shape}" 101 | ) 102 | # in the other cases, raise an error 103 | else: 104 | raise ValueError( 105 | f"Shapes of {checkpoint_key} are different: {dst_shape} != {src_shape}" 106 | ) 107 | 108 | # update the shapes 109 | dst_shape = model_state_dict[checkpoint_key].shape 110 | src_shape = checkpoint_state_dict[checkpoint_key].shape 111 | assert len(dst_shape) == len( 112 | src_shape 113 | ), f"Shapes of {checkpoint_key} are different: {dst_shape} != {src_shape}" 114 | 115 | ## Shapes adapter 116 | # modify the checkpoint state dict only if the shapes are different 117 | if dst_shape != src_shape: 118 | # create a copy of the tensor 119 | tmp = torch.clone(checkpoint_state_dict[checkpoint_key]) 120 | 121 | # iterate over all dimensions 122 | for i in range(len(dst_shape)): 123 | if dst_shape[i] != src_shape[i]: 124 | diff = dst_shape[i] - src_shape[i] 125 | 126 | # if the difference is greater than 0, we need to add missing blocks 127 | if diff > 0: 128 | missing_shape = list(tmp.shape) 129 | missing_shape[i] = diff 130 | missing = self._create_block( 131 | shape=missing_shape, 132 | strategy=strategy, 133 | input=tmp, 134 | ) 135 | tmp = torch.cat((tmp, missing), dim=i) 136 | logging.info( 137 | f"Adapting {checkpoint_key} with strategy:{strategy} from shape {src_shape} to {dst_shape}" 138 | ) 139 | # if the difference is less than 0, we need to cut the block 140 | else: 141 | tmp = tmp.narrow(i, 0, dst_shape[i]) 142 | logging.info( 143 | f"Adapting {checkpoint_key} by narrowing from shape {src_shape} to {dst_shape}" 144 | ) 145 | 146 | checkpoint_state_dict[checkpoint_key] = tmp 147 | end = time.perf_counter() 148 | logging.info(f"StateDictAdapter took {end-start:.2f} seconds") 149 | return checkpoint_state_dict 150 | 151 | 152 | class StateDictRenamer: 153 | """ 154 | StateDictRenamer for renaming keys in a checkpoint state dict. 
155 | This class iterates over the entries of a rename dict and renames the matching keys of the checkpoint state dict; old keys that are missing are skipped with a warning. 156 | 157 | Example: 158 | 159 | ``` 160 | renamer = StateDictRenamer() 161 | new_state_dict = renamer( 162 | checkpoint_state_dict=state_dict, 163 | rename_dict={ 164 | "add_embedding.linear_1.weight": "class_embedding.linear_1.weight", 165 | "add_embedding.linear_1.bias": "class_embedding.linear_1.bias", 166 | "add_embedding.linear_2.weight": "class_embedding.linear_2.weight", 167 | "add_embedding.linear_2.bias": "class_embedding.linear_2.bias", 168 | } 169 | ) 170 | ``` 171 | 172 | Args: 173 | 174 | checkpoint_state_dict (Dict[str, torch.Tensor]): The checkpoint state dict. 175 | rename_dict (Dict[str, str]): The dictionary mapping old keys to new keys. 176 | """ 177 | 178 | def __call__( 179 | self, 180 | checkpoint_state_dict: Dict[str, torch.Tensor], 181 | rename_dict: Dict[str, str], 182 | ) -> Dict[str, torch.Tensor]: 183 | for old_key, new_key in rename_dict.items(): 184 | if old_key not in checkpoint_state_dict: 185 | logging.warning(f"Key {old_key} not found in checkpoint state dict") 186 | continue 187 | else: 188 | assert ( 189 | new_key not in checkpoint_state_dict 190 | ), f"Key {new_key} already exists in checkpoint state dict" 191 | checkpoint_state_dict[new_key] = checkpoint_state_dict.pop(old_key) 192 | logging.info(f"Renaming {old_key} to {new_key}") 193 | return checkpoint_state_dict 194 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Tests 2 | 3 | ## Setup 4 | 5 | ```shell 6 | pip3 install -r requirements.txt 7 | ``` 8 | 9 | ## Run the tests 10 | 11 | ```shell 12 | python3 -m pytest .
13 | ``` -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest -------------------------------------------------------------------------------- /tests/test_dataset/test_filters.py: -------------------------------------------------------------------------------- 1 | from lbm.data.filters import FilterWrapper, KeyFilter, KeyFilterConfig 2 | 3 | 4 | class TestKeyFilter: 5 | def test_key_filter(self): 6 | filter = KeyFilter(KeyFilterConfig(keys=["a", "b"])) 7 | assert filter({"a": 1, "b": 2, "c": 3}) 8 | assert not filter({"a": 1}) 9 | assert not filter({"b": 2}) 10 | assert not filter({"c": 3}) 11 | 12 | def test_key_filter_single_key(self): 13 | filter = KeyFilter(KeyFilterConfig(keys="a")) 14 | assert filter({"a": 1, "b": 2}) 15 | assert not filter({"b": 2}) 16 | 17 | 18 | class TestFilterWrapper: 19 | def test_filter_wrapper_single_key_filter(self): 20 | filter = FilterWrapper( 21 | [ 22 | KeyFilter(KeyFilterConfig(keys=["a", "b"])), 23 | KeyFilter(KeyFilterConfig(keys="c")), 24 | ] 25 | ) 26 | assert filter({"a": 1, "b": 2, "c": 3}) 27 | assert not filter({"a": 1}) 28 | assert not filter({"b": 2}) 29 | assert not filter({"c": 3}) 30 | 31 | def test_filter_wrapper(self): 32 | filter = FilterWrapper( 33 | [ 34 | KeyFilter(KeyFilterConfig(keys=["a", "b"])), 35 | KeyFilter(KeyFilterConfig(keys=["a", "c"])), 36 | ] 37 | ) 38 | assert filter({"a": 1, "b": 2, "c": 3}) 39 | assert not filter({"a": 1}) 40 | assert not filter({"b": 2}) 41 | assert not filter({"c": 3}) 42 | -------------------------------------------------------------------------------- /tests/test_dataset/test_mappers.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pytest 4 | import torch 5 | from PIL import Image 6 | 7 | from lbm.data.mappers import ( 8 | KeyRenameMapper, 9 | KeyRenameMapperConfig, 10 |
MapperWrapper, 11 | RescaleMapper, 12 | RescaleMapperConfig, 13 | TorchvisionMapper, 14 | TorchvisionMapperConfig, 15 | ) 16 | 17 | 18 | class TestKeyRenameMapper: 19 | @pytest.fixture() 20 | def dummy_batch(self): 21 | return {"image": 1, "text": 2, "label": "dummy_label"} 22 | 23 | @pytest.fixture() 24 | def mapper(self): 25 | return KeyRenameMapper( 26 | KeyRenameMapperConfig( 27 | key_map={"image": "image_tensor", "text": "text_tensor"} 28 | ) 29 | ) 30 | 31 | def test_mapper(self, mapper, dummy_batch): 32 | output_data = mapper(dummy_batch) 33 | assert output_data["image_tensor"] == 1 34 | assert output_data["text_tensor"] == 2 35 | assert output_data["label"] == "dummy_label" 36 | assert "image" not in output_data 37 | assert "text" not in output_data 38 | 39 | 40 | class TestKeyRenameMapperWithCondition: 41 | @pytest.fixture(params=[1, 2]) 42 | def dummy_batch(self, request): 43 | return {"image": 1, "text": 2, "label": request.param} 44 | 45 | @pytest.fixture(params=[{"image": "image_not_met", "text": "text_not_met"}, None]) 46 | def else_key_map(self, request): 47 | return request.param 48 | 49 | @pytest.fixture() 50 | def mapper(self, else_key_map): 51 | return KeyRenameMapper( 52 | KeyRenameMapperConfig( 53 | key_map={"image": "image_tensor", "text": "text_tensor"}, 54 | condition_key="label", 55 | condition_fn=lambda x: x == 1, 56 | else_key_map=else_key_map, 57 | ) 58 | ) 59 | 60 | def test_mapper(self, mapper, dummy_batch, else_key_map): 61 | output_data = mapper(dummy_batch) 62 | if dummy_batch["label"] == 1: 63 | assert output_data["image_tensor"] == 1 64 | assert output_data["text_tensor"] == 2 65 | assert output_data["label"] == 1 66 | assert "image" not in output_data 67 | assert "text" not in output_data 68 | elif else_key_map is not None: 69 | assert output_data["image_not_met"] == 1 70 | assert output_data["text_not_met"] == 2 71 | assert output_data["label"] == 2 72 | assert "image" not in output_data 73 | assert "text" not in output_data 74 
| else: 75 | assert output_data["image"] == 1 76 | assert output_data["text"] == 2 77 | assert output_data["label"] == 2 78 | assert "image_tensor" not in output_data 79 | assert "text_tensor" not in output_data 80 | 81 | 82 | class TestMapperWrapper: 83 | @pytest.fixture() 84 | def dummy_batch(self): 85 | return {"image": 1, "text": 2, "label": "dummy_label"} 86 | 87 | @pytest.fixture() 88 | def mapper(self): 89 | return MapperWrapper( 90 | mappers=[ 91 | KeyRenameMapper( 92 | KeyRenameMapperConfig( 93 | key_map={"image": "image_tensor", "text": "text_tensor"} 94 | ) 95 | ), 96 | KeyRenameMapper( 97 | KeyRenameMapperConfig( 98 | key_map={ 99 | "image_tensor": "image_array", 100 | "text_tensor": "text_array", 101 | } 102 | ) 103 | ), 104 | ] 105 | ) 106 | 107 | def test_mapper(self, mapper, dummy_batch): 108 | output_data = mapper(dummy_batch) 109 | assert output_data["image_array"] == 1 110 | assert output_data["text_array"] == 2 111 | assert output_data["label"] == "dummy_label" 112 | assert "image" not in output_data 113 | assert "text" not in output_data 114 | assert "image_tensor" not in output_data 115 | assert "text_tensor" not in output_data 116 | 117 | 118 | class TestTorchvisionMapper: 119 | @pytest.fixture() 120 | def dummy_batch(self): 121 | return { 122 | "image": torch.randn( 123 | 3, 124 | 256, 125 | 256, 126 | ), 127 | "text": 2, 128 | "label": "dummy_label", 129 | } 130 | 131 | @pytest.fixture() 132 | def mapper(self): 133 | return TorchvisionMapper( 134 | TorchvisionMapperConfig( 135 | key="image", 136 | transforms=["CenterCrop", "ToPILImage"], 137 | transforms_kwargs=[{"size": 224}, {}], 138 | ) 139 | ) 140 | 141 | def test_mapper(self, mapper, dummy_batch): 142 | output_data = mapper(dummy_batch) 143 | assert output_data["image"].size == (224, 224) 144 | assert isinstance(output_data["image"], Image.Image) 145 | assert output_data["text"] == 2 146 | assert output_data["label"] == "dummy_label" 147 | 148 | @pytest.fixture() 149 | def 
mapper_with_output_key(self): 150 | return TorchvisionMapper( 151 | TorchvisionMapperConfig( 152 | key="image", 153 | output_key="image_transformed", 154 | transforms=["CenterCrop", "ToPILImage"], 155 | transforms_kwargs=[{"size": 224}, {}], 156 | ) 157 | ) 158 | 159 | def test_mapper_with_output_key(self, mapper_with_output_key, dummy_batch): 160 | output_data = mapper_with_output_key(dummy_batch) 161 | assert output_data["image_transformed"].size == (224, 224) 162 | assert isinstance(output_data["image_transformed"], Image.Image) 163 | assert isinstance(output_data["image"], torch.Tensor) 164 | assert output_data["image"].size() == (3, 256, 256) 165 | assert output_data["text"] == 2 166 | assert output_data["label"] == "dummy_label" 167 | 168 | 169 | class TestRescaleMapper: 170 | @pytest.fixture() 171 | def dummy_batch(self): 172 | return { 173 | "image": torch.rand( 174 | 3, 175 | 256, 176 | 256, 177 | ), 178 | "text": 2, 179 | "label": "dummy_label", 180 | } 181 | 182 | @pytest.fixture() 183 | def mapper(self): 184 | return RescaleMapper( 185 | RescaleMapperConfig( 186 | input_key="image", 187 | output_key="image", 188 | ) 189 | ) 190 | 191 | def test_mapper(self, mapper, dummy_batch): 192 | output_data = mapper(dummy_batch) 193 | assert torch.all(output_data["image"] <= 1) 194 | assert torch.all(output_data["image"] >= -1) 195 | assert output_data["text"] == 2 196 | assert output_data["label"] == "dummy_label" 197 | -------------------------------------------------------------------------------- /tests/test_lbm/test_lbm.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import pytest 4 | import torch 5 | import torch.nn as nn 6 | from diffusers import FlowMatchEulerDiscreteScheduler 7 | 8 | from lbm.models.embedders import ConditionerWrapper 9 | from lbm.models.lbm import LBMConfig, LBMModel 10 | from lbm.models.unets import DiffusersUNet2DCondWrapper 11 | from lbm.models.vae import
AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig 12 | 13 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 14 | 15 | 16 | class TestLBM: 17 | @pytest.fixture() 18 | def denoiser(self): 19 | return DiffusersUNet2DCondWrapper( 20 | in_channels=4, # VAE channels 21 | out_channels=4, # VAE channels 22 | up_block_types=["CrossAttnUpBlock2D"], 23 | down_block_types=[ 24 | "CrossAttnDownBlock2D", 25 | ], 26 | cross_attention_dim=[320], 27 | block_out_channels=[320], 28 | transformer_layers_per_block=[1], 29 | attention_head_dim=[5], 30 | norm_num_groups=32, 31 | ) 32 | 33 | @pytest.fixture() 34 | def conditioner(self): 35 | return ConditionerWrapper([]) 36 | 37 | @pytest.fixture() 38 | def vae(self): 39 | return AutoencoderKLDiffusers(AutoencoderKLDiffusersConfig()) 40 | 41 | @pytest.fixture() 42 | def sampling_noise_scheduler(self): 43 | return FlowMatchEulerDiscreteScheduler() 44 | 45 | @pytest.fixture() 46 | def training_noise_scheduler(self): 47 | return FlowMatchEulerDiscreteScheduler() 48 | 49 | @pytest.fixture() 50 | def model_config(self): 51 | return LBMConfig( 52 | source_key="source_image", 53 | target_key="target_image", 54 | ) 55 | 56 | @pytest.fixture() 57 | def model_input(self): 58 | return { 59 | "source_image": torch.randn(2, 3, 256, 256).to(DEVICE), 60 | "target_image": torch.randn(2, 3, 256, 256).to(DEVICE), 61 | } 62 | 63 | @pytest.fixture() 64 | def model( 65 | self, 66 | model_config, 67 | denoiser, 68 | vae, 69 | sampling_noise_scheduler, 70 | training_noise_scheduler, 71 | conditioner, 72 | ): 73 | return LBMModel( 74 | config=model_config, 75 | denoiser=denoiser, 76 | vae=vae, 77 | sampling_noise_scheduler=sampling_noise_scheduler, 78 | training_noise_scheduler=training_noise_scheduler, 79 | conditioner=conditioner, 80 | ).to(DEVICE) 81 | 82 | @torch.no_grad() 83 | def test_model_forward(self, model, model_input): 84 | model_output = model( 85 | model_input, 86 | ) 87 | assert model_output["loss"] > 0.0 88 | 89 | def 
test_optimizers(self, model, model_input): 90 | optimizer = torch.optim.Adam(model.denoiser.parameters(), lr=1e-4) 91 | 92 | model.train() 93 | model_init = deepcopy(model) 94 | optimizer.zero_grad() 95 | loss = model(model_input)["loss"] 96 | loss.backward() 97 | optimizer.step() 98 | assert not torch.equal( 99 | torch.cat([p.flatten() for p in model.denoiser.parameters()]), 100 | torch.cat([p.flatten() for p in model_init.denoiser.parameters()]), 101 | ) 102 | -------------------------------------------------------------------------------- /tests/test_unets/test_unets_wrappers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from lbm.models.unets import DiffusersUNet2DCondWrapper, DiffusersUNet2DWrapper 5 | 6 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 7 | 8 | 9 | class TestDiffusersUNet2DWrapper: 10 | # simulates class conditioning 11 | @pytest.fixture(params=[None, torch.randint(256, (2,)).to(DEVICE)]) 12 | def conditioning(self, request): 13 | if request.param is not None: 14 | return {"cond": {"vector": request.param}} 15 | return None 16 | 17 | # simulates a latent sample 18 | @pytest.fixture() 19 | def sample(self): 20 | return torch.rand(2, 6, 32, 32).to(DEVICE) 21 | 22 | # simulates a timestep 23 | @pytest.fixture( 24 | params=[10.0, torch.randint(1000, (2,), dtype=torch.float).to(DEVICE), 3] 25 | ) 26 | def timesteps(self, request): 27 | return request.param 28 | 29 | def test_unet2d_wrapper(self, sample, timesteps, conditioning): 30 | unet = DiffusersUNet2DWrapper( 31 | sample_size=sample.shape[2:], 32 | in_channels=sample.shape[1], 33 | out_channels=3, 34 | num_class_embeds=256 if conditioning else None, 35 | ).to(DEVICE) 36 | output = unet(sample, timesteps, conditioning) 37 | assert output.shape == ( 38 | sample.shape[0], 39 | 3, 40 | sample.shape[2], 41 | sample.shape[3], 42 | ) 43 | 44 | 45 | class TestDiffusersUNet2DCondWrapper: 46 | # simulates class 
conditioning 47 | @pytest.fixture(params=[None, torch.randn(2, 256).to(DEVICE)]) 48 | def vector_conditioning(self, request): 49 | if request.param is not None: 50 | return {"vector": request.param} 51 | return None 52 | 53 | # simulates cross-attention conditioning (always needed for the conditional UNet2D, see diffusers/models/unet.py) 54 | @pytest.fixture() 55 | def crossattn_conditioning(self): 56 | return {"crossattn": torch.randn(2, 12, 123).to(DEVICE)} 57 | 58 | # simulates concat conditioning 59 | @pytest.fixture(params=[None, torch.randn(2, 2, 32, 32).to(DEVICE)]) 60 | def concat_conditioning(self, request): 61 | if request.param is not None: 62 | return {"concat": request.param} 63 | return None 64 | 65 | @pytest.fixture() 66 | def conditioning( 67 | self, vector_conditioning, crossattn_conditioning, concat_conditioning 68 | ): 69 | cond = dict(cond=crossattn_conditioning) 70 | if vector_conditioning is not None: 71 | cond["cond"].update(vector_conditioning) 72 | if concat_conditioning is not None: 73 | cond["cond"].update(concat_conditioning) 74 | return cond 75 | 76 | # simulates a latent sample 77 | @pytest.fixture() 78 | def sample(self): 79 | return torch.rand(2, 6, 32, 32).to(DEVICE) 80 | 81 | # simulates a timestep 82 | @pytest.fixture( 83 | params=[10.0, torch.randint(1000, (2,), dtype=torch.float).to(DEVICE), 3] 84 | ) 85 | def timesteps(self, request): 86 | return request.param 87 | 88 | def test_unet2d_cond_wrapper(self, sample, timesteps, conditioning): 89 | # for concat 90 | in_channels = ( 91 | sample.shape[1] + conditioning["cond"]["concat"].shape[1] 92 | if conditioning["cond"].get("concat", None) is not None 93 | else sample.shape[1] 94 | ) 95 | 96 | # for vector 97 | class_embed_type = ( 98 | "projection" if conditioning["cond"].get("vector") is not None else None 99 | ) 100 | projection_class_embeddings_input_dim = ( 101 | conditioning["cond"]["vector"].shape[1] 102 | if conditioning["cond"].get("vector") is not None 103 | else None 104 | ) 105
| 106 | # for crossattn 107 | cross_attention_dim = ( 108 | conditioning["cond"]["crossattn"].shape[2] 109 | if conditioning["cond"].get("crossattn") is not None 110 | else 1280 111 | ) 112 | 113 | unet = DiffusersUNet2DCondWrapper( 114 | sample_size=sample.shape[2:], 115 | in_channels=in_channels, 116 | out_channels=3, 117 | class_embed_type=class_embed_type, 118 | projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, 119 | cross_attention_dim=cross_attention_dim, 120 | ).to(DEVICE) 121 | output = unet(sample, timesteps, conditioning) 122 | assert output.shape == ( 123 | sample.shape[0], 124 | 3, 125 | sample.shape[2], 126 | sample.shape[3], 127 | ) 128 | -------------------------------------------------------------------------------- /tests/test_vaes/test_autoencoder.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from lbm.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig 5 | 6 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 7 | 8 | 9 | class TestAutoencoderKLDiffusers: 10 | @pytest.fixture( 11 | params=[ 12 | dict(), 13 | dict( 14 | version="stabilityai/stable-diffusion-xl-base-1.0", 15 | subfolder="vae", 16 | ), 17 | ] 18 | ) 19 | def model_config(self, request): 20 | return AutoencoderKLDiffusersConfig( 21 | **request.param, tiling_size=(16, 16), tiling_overlap=(8, 8), batch_size=1 22 | ) 23 | 24 | @pytest.fixture() 25 | def model(self, model_config): 26 | return AutoencoderKLDiffusers(model_config).to(DEVICE) 27 | 28 | def test_model_initialization(self, model, model_config): 29 | assert model.config == model_config 30 | 31 | def test_encode(self, model): 32 | x = torch.randn(2, 3, 32, 32).to(DEVICE) 33 | z = model.encode(x) 34 | assert z.shape == (2, 4, 4, 4) 35 | 36 | def test_decode(self, model): 37 | z = torch.randn(2, 4, 4, 4).to(DEVICE) 38 | x = model.decode(z) 39 | assert x.shape == (2, 3, 32, 32) 40 | 41 | def 
test_decode_tiling(self, model): 42 | z = torch.randn(2, 4, 32, 32).to(DEVICE) 43 | x = model.decode(z) 44 | assert x.shape == (2, 3, 256, 256) 45 | --------------------------------------------------------------------------------
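The shape logic of `StateDictAdapter.__call__` in `src/lbm/trainer/utils.py` can be summarised without tensors. The sketch below is a hypothetical, dependency-free helper (`adaptation_plan` is not part of the repository) that only lists the operations the adapter would apply to one checkpoint tensor, assuming the rules visible in the source: a rank fix first (`[a]` becomes `[a, 1]`, or dimension 1 is dropped when the target is 1-D), then, for each mismatched dimension, padding with a new block (zeros or normal noise in the real adapter) when the target is larger, or narrowing to the first entries when it is smaller.

```python
from typing import List, Sequence


def adaptation_plan(src_shape: Sequence[int], dst_shape: Sequence[int]) -> List[str]:
    """Hypothetical helper: describe how a checkpoint tensor of shape src_shape
    would be turned into dst_shape under StateDictAdapter-style rules."""
    src = list(src_shape)
    dst = list(dst_shape)
    steps: List[str] = []

    # Rank adapter: only [a] -> [a, 1] and "drop dimension 1" are supported
    if len(src) != len(dst):
        if len(src) == 1:
            src = src + [1]
            steps.append("unsqueeze dim 1")
        elif len(dst) == 1:
            src = src[:1] + src[2:]  # mirrors tensor[:, 0]
            steps.append("squeeze dim 1")
    if len(src) != len(dst):
        raise ValueError(f"Cannot adapt shape {tuple(src_shape)} to {tuple(dst_shape)}")

    # Shape adapter: each mismatched dimension is padded or narrowed independently
    for dim, (s, d) in enumerate(zip(src, dst)):
        if d > s:
            steps.append(f"pad dim {dim} by {d - s}")  # new block appended at the end
        elif d < s:
            steps.append(f"narrow dim {dim} to {d}")  # keep the first d entries
    return steps


# e.g. widening a conv_in weight from 4 to 8 input channels, and back
print(adaptation_plan((320, 4, 3, 3), (320, 8, 3, 3)))
print(adaptation_plan((320, 8, 3, 3), (320, 4, 3, 3)))
print(adaptation_plan((4,), (4, 3)))
```

A 1-D tensor of shape `(4,)` mapped to `(4, 3)` is therefore first unsqueezed to `(4, 1)` and then padded along dimension 1 by 2, which matches the order of the "Sizes adapter" and "Shapes adapter" stages in the source.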