├── .gitattributes ├── README.md ├── __init__.py ├── bci ├── sampling_and_metrics.py └── script.py ├── datasets └── readme.txt ├── improved_consistency_model_conditional.py ├── irvi ├── consistency_models2.py └── script.py ├── llvip ├── checkpoints │ └── llvip │ │ ├── config.json │ │ └── model.pt ├── sampling_and_metrics.py └── script.py ├── lolv1 ├── sampling_and_metrics.py └── script.py ├── lolv2 ├── sampling_and_metrics.py ├── script.py └── switch_between_real_and_synthetic.txt ├── requirements.txt └── sid ├── sampling_and_metrics.py └── script.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pt filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Conditional Consistency Models 2 | 3 | Welcome to the official repository for Conditional Consistency Models (CCM). This repository hosts the implementation and evaluation of consistency models tailored for various datasets, including **IRVI**, **BCI**, **LLVIP**, **LOLv1**, **LOLv2**, and **SID**. 4 | 5 | This repository contains code for training and evaluating models. 6 | 7 | ## Table of Contents 8 | 9 | 1. [Requirements and Setup](#requirements-and-setup) 10 | 2. [Directory Structure](#directory-structure) 11 | 3. [Datasets](#datasets) 12 | 4. [Usage](#usage) 13 | 14 | 15 | ## Requirements and Setup 16 | 17 | ### 1. Clone the Repository 18 | ```bash 19 | cd Conditional-Consistency-Models 20 | ``` 21 | 22 | ### 2. Create a Conda Environment 23 | Create and activate a new conda environment named ccm: 24 | ```bash 25 | conda create -n ccm python=3.10 -y 26 | conda activate ccm 27 | ``` 28 | 29 | ### 3. Install Dependencies 30 | Install all required Python packages: 31 | ```bash 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ## Directory Structure 36 | 37 | Ensure your project directory is organized as follows: 38 | 39 | ``` 40 | Conditional-Consistency-Models/ 41 | ├── datasets/ # Folder for datasets 42 | │ ├── bci/ 43 | │ ├── llvip/ 44 | │ ├── lolv1/ 45 | │ ├── lolv2/ 46 | │ ├── sid/ 47 | │ ├── irvi/ 48 | ├── bci/ # Folder for BCI model scripts and metrics 49 | │ ├── script.py 50 | │ ├── sampling_and_metric.py 51 | │ ├── checkpoints/bci/ # Pre-trained BCI model weights 52 | │ ├── model.json 53 | │ ├── model.pt 54 | ├── llvip/ 55 | ├── lolv1/ 56 | ├── lolv2/ 57 | ├── sid/ 58 | ├── irvi/ 59 | ├── improved_consistency_model_conditional.py 60 | ├── README.md 61 | ├── requirements.txt 62 | ``` 63 | 64 | ## Datasets 65 | 66 | ### Download Datasets 67 | 68 | Use the links below to download datasets for each model. Once downloaded, extract and place them inside the `datasets/` directory. 69 | 70 | - BCI Dataset: [https://bupt-ai-cz.github.io/BCI/](URL) 71 | - LLVIP Dataset: [https://drive.google.com/file/d/1VTlT3Y7e1h-Zsne4zahjx5q0TK2ClMVv/view](URL) 72 | - LOLv1 Dataset: [https://drive.google.com/file/d/1L-kqSQyrmMueBh_ziWoPFhfsAh50h20H/view](URL) 73 | - LOLv2 Dataset: [https://drive.google.com/file/d/1Ou9EljYZW8o5dbDCf9R34FS8Pd8kEp2U/view](URL) 74 | - SID Dataset: [https://drive.google.com/drive/folders/1eQ-5Z303sbASEvsgCBSDbhijzLTWQJtR](URL) 75 | - IRVI Dataset: [https://drive.google.com/file/d/1ZcJ0EfF5n_uqtsLc7-8hJgTcr2zHSXY3/view](URL) 76 | 77 | ## Usage 78 | 79 | ### 1. Training 80 | 81 | To train a model on a specific dataset, use the corresponding `script.py` inside the dataset's folder. 
82 | 83 | For example, to train the BCI model: 84 | ```bash 85 | cd .. 86 | CUDA_VISIBLE_DEVICES=0 python -m bci.script 87 | ``` 88 | 89 | ### 2. Evaluation (Metrics) 90 | 91 | To evaluate the model and calculate PSNR/SSIM metrics, use the `sampling_and_metric.py` script. 92 | 93 | Example: Evaluate the BCI model: 94 | ```bash 95 | cd .. 96 | CUDA_VISIBLE_DEVICES=0 python -m bci.sampling_and_metric 97 | ``` 98 | 99 | Ensure the dataset is correctly placed, and pre-trained checkpoints are available before running the evaluation. 100 | 101 | ### Results 102 | 103 | Evaluation metrics (PSNR, SSIM) and generated images are saved in the corresponding output folder. 104 | 105 | For any issues, feel free to open an issue on this repository. 106 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amilbhagat/Conditional-Consistency-Models/473bc618e520afc2a133e2b9eabe13b1220afab2/__init__.py -------------------------------------------------------------------------------- /bci/sampling_and_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import torchvision.transforms as T 4 | import torchvision.transforms.functional as TF 5 | from improved_consistency_model_conditional import ConsistencySamplingAndEditing 6 | from bci.script import UNet 7 | import os 8 | from skimage.metrics import peak_signal_noise_ratio as psnr 9 | from skimage.metrics import structural_similarity as ssim 10 | import numpy as np 11 | 12 | 13 | ##In this context visible = HE, infrared = IHC (metrics was borrowed from llvip, hence the nomenclature) 14 | # =============================== 15 | # Configuration and Setup 16 | # =============================== 17 | 18 | # ------------------------------- 19 | # 1. Model Loading 20 | # ------------------------------- 21 | 22 | # Path to the pre-trained model checkpoint 23 | model_path = "checkpoints/bci" 24 | 25 | # Load the pre-trained UNet model 26 | model = UNet.from_pretrained(model_path) 27 | 28 | # Select device: GPU if available, else CPU 29 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 30 | model = model.to(device).eval() 31 | 32 | # ------------------------------- 33 | # 2. Image Transformations 34 | # ------------------------------- 35 | 36 | # Transformation pipeline: Convert PIL Image to Tensor and normalize to [-1, 1] 37 | transform = T.Compose([ 38 | T.ToTensor(), # Converts PIL Image to Tensor and scales to [0, 1] 39 | T.Lambda(lambda x: (x * 2) - 1), # Normalize to [-1, 1] 40 | ]) 41 | 42 | # Inverse transformation: Denormalize to [0, 1] and clamp 43 | inverse_transform = T.Compose([ 44 | T.Lambda(lambda x: (x + 1) / 2), # Denormalize to [0, 1] 45 | T.Lambda(lambda x: x.clamp(0, 1)) # Clamp values to [0, 1] 46 | ]) 47 | 48 | # ------------------------------- 49 | # 3. Define Image Folders 50 | # ------------------------------- 51 | 52 | visible_folder = "../datasets/bci/HE/test" # Folder containing visible (HE) images 53 | infrared_folder = "../datasets/bci/IHC/test" # Folder containing infrared (IHC) images 54 | 55 | # ------------------------------- 56 | # 4. 
Sampling Instance and Sigma Schedule 57 | # ------------------------------- 58 | 59 | # Initialize the consistency sampling and editing instance 60 | consistency_sampling = ConsistencySamplingAndEditing() 61 | 62 | # Define the sigma schedule for noise levels 63 | sigmas = [80.0, 40.0, 20.0, 10.0, 5.0, 2.5, 1.25, 0.625, 0.3125, 0.15625, 0.078125] 64 | 65 | # ------------------------------- 66 | # 5. Metrics Accumulators and Results Folder 67 | # ------------------------------- 68 | 69 | # Initialize accumulators for PSNR and SSIM metrics 70 | total_psnr = 0.0 71 | total_ssim = 0.0 72 | num_images = 0 73 | 74 | # Create results folder to save concatenated images 75 | results_folder = "results_bci" 76 | os.makedirs(results_folder, exist_ok=True) 77 | 78 | # ------------------------------- 79 | # 6. Define Resize Transforms 80 | # ------------------------------- 81 | 82 | # Resize down to 512x512 with bicubic interpolation and anti-aliasing 83 | resize_down = T.Resize((512, 512), interpolation=T.InterpolationMode.BICUBIC, antialias=True) 84 | 85 | # Resize up to 1024x1024 with bicubic interpolation 86 | resize_up = T.Resize((1024, 1024), interpolation=T.InterpolationMode.BICUBIC) 87 | 88 | # =============================== 89 | # Main Processing Loop 90 | # =============================== 91 | 92 | # Iterate through all visible images 93 | for idx, visible_image_name in enumerate(os.listdir(visible_folder), start=1): 94 | # Construct full paths to visible and infrared images 95 | visible_image_path = os.path.join(visible_folder, visible_image_name) 96 | infrared_image_path = os.path.join(infrared_folder, visible_image_name) 97 | 98 | # Check if the corresponding infrared image exists 99 | if not os.path.exists(infrared_image_path): 100 | print(f"[{idx}] Infrared image {infrared_image_path} not found, skipping.") 101 | continue 102 | 103 | try: 104 | # Load images and convert to RGB 105 | visible_image = Image.open(visible_image_path).convert("RGB") 106 | infrared_image = Image.open(infrared_image_path).convert("RGB") 107 | except Exception as e: 108 | print(f"[{idx}] Error loading images: {e}, skipping {visible_image_name}") 109 | continue 110 | 111 | # Apply transformations: Convert images to tensors and normalize 112 | visible_tensor = transform(visible_image).unsqueeze(0).to(device) # Shape: [1, 3, 1024, 1024] 113 | infrared_tensor = transform(infrared_image).unsqueeze(0).to(device) # Shape: [1, 3, 1024, 1024] 114 | 115 | # Resize infrared image to 512x512 116 | infrared_resized = TF.resize(infrared_tensor, [512, 512], interpolation=TF.InterpolationMode.BICUBIC, antialias=True) # Shape: [1, 3, 512, 512] 117 | 118 | # Resize visible image to 512x512 for model input 119 | visible_resized = TF.resize(visible_tensor, [512, 512], interpolation=TF.InterpolationMode.BICUBIC, antialias=True) # Shape: [1, 3, 512, 512] 120 | 121 | # Add Gaussian noise to the resized infrared image 122 | noise = torch.randn_like(infrared_resized) * sigmas[0] # Sigma = 80.0 123 | noisy_infrared_tensor = infrared_resized + noise 124 | 125 | try: 126 | # Generate denoised infrared image using the model 127 | with torch.no_grad(): 128 | generated_infrared_tensor = consistency_sampling( 129 | model=model, 130 | y=noisy_infrared_tensor, 131 | v=visible_resized, 132 | sigmas=sigmas, 133 | start_from_y=True, 134 | add_initial_noise=False, 135 | clip_denoised=True, 136 | verbose=False, 137 | ) 138 | except Exception as e: 139 | print(f"[{idx}] Error during model inference: {e}, skipping {visible_image_name}") 140 | continue 
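    # Note on the sampling call above: because `start_from_y=True` and
    # `add_initial_noise=False`, the manually noised IHC tensor is used directly
    # as the starting sample, and the decreasing `sigmas` list then drives the
    # multistep consistency denoising. The resized HE (visible) image is passed
    # as the conditioning input `v` at every denoising step.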
141 | 142 | # Denormalize the generated infrared tensor to [0, 1] 143 | generated_infrared_denorm = inverse_transform(generated_infrared_tensor.squeeze(0).cpu()).clamp(0, 1).numpy().transpose(1, 2, 0) # [512, 512, 3] 144 | 145 | # Upsize the generated infrared image back to 1024x1024 146 | generated_infrared_resized_pil = resize_up(T.ToPILImage()(generated_infrared_denorm)).convert("RGB") # PIL Image 147 | generated_infrared_resized = T.ToTensor()(generated_infrared_resized_pil).numpy().transpose(1, 2, 0) # [1024, 1024, 3] 148 | 149 | # Convert generated infrared image to uint8 for saving 150 | generated_image_save = (generated_infrared_resized * 255).astype(np.uint8) 151 | 152 | # Original visible and infrared images are already at 1024x1024 153 | visible_image_save = visible_image 154 | infrared_image_save = infrared_image 155 | 156 | # ------------------------------- 157 | # Concatenate Images for Visualization 158 | # ------------------------------- 159 | 160 | # Calculate concatenated image dimensions 161 | concatenated_width = visible_image_save.width + infrared_image_save.width + generated_image_save.shape[1] # 1024 + 1024 + 1024 = 3072 162 | concatenated_height = max(visible_image_save.height, infrared_image_save.height, generated_image_save.shape[0]) # 1024 163 | 164 | # Create a new blank image for concatenation 165 | concatenated_image = Image.new("RGB", (concatenated_width, concatenated_height)) 166 | 167 | # Paste the original visible image 168 | concatenated_image.paste(visible_image_save, (0, 0)) 169 | 170 | # Paste the original infrared image next to the visible image 171 | concatenated_image.paste(infrared_image_save, (visible_image_save.width, 0)) 172 | 173 | # Paste the generated infrared image next to the original infrared image 174 | concatenated_image.paste(Image.fromarray(generated_image_save), (visible_image_save.width + infrared_image_save.width, 0)) 175 | 176 | # Save the concatenated image 177 | concatenated_image_path = os.path.join(results_folder, f"concatenated_{visible_image_name}") 178 | concatenated_image.save(concatenated_image_path) 179 | 180 | # ------------------------------- 181 | # Calculate PSNR and SSIM Metrics 182 | # ------------------------------- 183 | 184 | # Convert original infrared image to tensor and denormalize to [0, 1] for metric calculation 185 | infrared_original_tensor = transform(infrared_image).unsqueeze(0).to(device) # Shape: [1, 3, 1024, 1024] 186 | infrared_original_denorm = inverse_transform(infrared_original_tensor.squeeze(0).cpu()).clamp(0, 1).numpy().transpose(1, 2, 0) # [1024, 1024, 3] 187 | 188 | # Ensure generated_infrared_resized is in [0,1] range 189 | generated_infrared_resized = generated_infrared_resized.clip(0, 1) 190 | 191 | # Calculate PSNR between original and generated infrared images 192 | psnr_value = psnr(infrared_original_denorm, generated_infrared_resized, data_range=1.0) 193 | 194 | # Calculate SSIM between original and generated infrared images 195 | ssim_value = ssim(infrared_original_denorm, generated_infrared_resized, data_range=1.0, multichannel=True, win_size=3, gaussian_weights=True, sigma=1.5) 196 | 197 | # Accumulate metrics 198 | total_psnr += psnr_value 199 | total_ssim += ssim_value 200 | num_images += 1 201 | 202 | # Print metrics for the current image 203 | print(f"[{idx}] Image: {visible_image_name} | PSNR: {psnr_value:.2f} | SSIM: {ssim_value:.4f}") 204 | 205 | # =============================== 206 | # Final Metrics Calculation 207 | # =============================== 208 | 209 | # Calculate 
and print average metrics 210 | if num_images > 0: 211 | avg_psnr = total_psnr / num_images 212 | avg_ssim = total_ssim / num_images 213 | 214 | print(f"\nProcessed {num_images} images.") 215 | print(f"Average PSNR: {avg_psnr:.2f}") 216 | print(f"Average SSIM: {avg_ssim:.4f}") 217 | 218 | # Save metrics to a text file 219 | with open("metrics.txt", "a") as f: 220 | f.write(f"Processed {num_images} images.\n") 221 | f.write(f"Average PSNR: {avg_psnr:.2f}\n") 222 | f.write(f"Average SSIM: {avg_ssim:.4f}\n") 223 | else: 224 | print("No images were processed.") 225 | -------------------------------------------------------------------------------- /bci/script.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import asdict, dataclass 4 | from typing import Any, Callable, List, Optional, Tuple, Union 5 | 6 | import torch 7 | from torchvision.transforms import functional as TF 8 | import math 9 | from einops import rearrange 10 | from einops.layers.torch import Rearrange 11 | from lightning import LightningDataModule, LightningModule, Trainer, seed_everything 12 | from lightning.pytorch.callbacks import LearningRateMonitor 13 | from lightning.pytorch.loggers import TensorBoardLogger 14 | from matplotlib import pyplot as plt 15 | from torch import Tensor, nn 16 | from torch.nn import functional as F 17 | from torch.utils.data import DataLoader 18 | from torchinfo import summary 19 | from torchvision import transforms as T 20 | from torchvision.datasets import ImageFolder 21 | from torchvision.utils import make_grid 22 | 23 | from improved_consistency_model_conditional import ( 24 | ConsistencySamplingAndEditing, 25 | ImprovedConsistencyTraining, 26 | pseudo_huber_loss, 27 | update_ema_model_, 28 | ) 29 | 30 | 31 | from torch.utils.data import Dataset 32 | from torchvision import transforms as T 33 | import os 34 | from PIL import Image 35 | import torch 36 | from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint 37 | 38 | class PairedDataset(Dataset): 39 | def __init__( 40 | self, 41 | HE_dir: str, 42 | IHC_dir: str, 43 | transform: Optional[Callable] = None, 44 | crop_size: Tuple[int, int] = (256, 256), 45 | resize_size: Tuple[int, int] = (128, 128), 46 | ): 47 | 48 | self.HE_dir = HE_dir 49 | self.IHC_dir = IHC_dir 50 | self.HE_images = sorted(os.listdir(HE_dir)) 51 | self.IHC_images = sorted(os.listdir(IHC_dir)) 52 | self.transform = transform 53 | self.crop_size = crop_size 54 | self.resize_size = resize_size 55 | def __len__(self) -> int: 56 | return len(self.HE_images) 57 | 58 | def __getitem__(self, index: int) -> Optional[Tuple[Tensor, Tensor]]: 59 | HE_path = os.path.join(self.HE_dir, self.HE_images[index]) 60 | IHC_path = os.path.join(self.IHC_dir, self.IHC_images[index]) 61 | 62 | HE_image = Image.open(HE_path).convert("RGB") 63 | IHC_image = Image.open(IHC_path).convert("RGB") 64 | 65 | if HE_image.size != IHC_image.size: 66 | print(f"Skipping image pair at index {index} due to mismatched sizes") 67 | return None 68 | 69 | if torch.rand(1).item() > 0.5: 70 | HE_image = TF.hflip(HE_image) 71 | IHC_image = TF.hflip(IHC_image) 72 | 73 | i, j, h, w = T.RandomCrop.get_params(HE_image, output_size=self.crop_size) 74 | HE_image = TF.crop(HE_image, i, j, h, w) 75 | IHC_image = TF.crop(IHC_image, i, j, h, w) 76 | 77 | if self.transform: 78 | HE_image = self.transform(HE_image) 79 | IHC_image = self.transform(IHC_image) 80 | 81 | return HE_image, IHC_image 82 | 83 | from dataclasses import 
dataclass 84 | from typing import Tuple 85 | 86 | @dataclass 87 | class ImageDataModuleConfig: 88 | data_dir: str = "dataset/BCI" # Path to the dataset directory 89 | image_size_crop: Tuple[int, int] = (256, 256) # Size for random cropping 90 | image_size_resize: Tuple[int, int] = (128, 128) # Resize to 128x128 91 | batch_size: int = 4 # Number of images in each batch 92 | num_workers: int = 8 # Number of worker threads for data loading 93 | pin_memory: bool = True # Whether to pin memory in data loader 94 | persistent_workers: bool = True # Keep workers alive between epochs 95 | 96 | from torch.utils.data import DataLoader 97 | from lightning.pytorch import LightningDataModule 98 | 99 | class LLVIPDataModule(LightningDataModule): 100 | def __init__(self, config: ImageDataModuleConfig) -> None: 101 | super().__init__() 102 | self.config = config 103 | 104 | def setup(self, stage: str = None) -> None: 105 | # Define transforms excluding cropping and resizing 106 | self.transform = T.Compose([ 107 | T.ToTensor(), 108 | T.Lambda(lambda x: (x * 2) - 1), # Normalize to [-1, 1] 109 | ]) 110 | 111 | self.dataset = PairedDataset( 112 | HE_dir=os.path.join(self.config.data_dir, "HE/train"), 113 | IHC_dir=os.path.join(self.config.data_dir, "IHC/train"), 114 | transform=self.transform, 115 | crop_size=self.config.image_size_crop, 116 | resize_size=self.config.image_size_resize 117 | ) 118 | 119 | def train_dataloader(self) -> DataLoader: 120 | return DataLoader( 121 | self.dataset, 122 | batch_size=self.config.batch_size, 123 | shuffle=True, 124 | num_workers=self.config.num_workers, 125 | pin_memory=self.config.pin_memory, 126 | persistent_workers=self.config.persistent_workers, 127 | ) 128 | 129 | def GroupNorm(channels: int) -> nn.GroupNorm: 130 | return nn.GroupNorm(num_groups=min(32, channels // 4), num_channels=channels) 131 | 132 | 133 | class SelfAttention(nn.Module): 134 | def __init__( 135 | self, 136 | in_channels: int, 137 | out_channels: int, 138 | n_heads: int = 8, 139 | dropout: float = 0.3, 140 | ) -> None: 141 | super().__init__() 142 | 143 | self.dropout = dropout 144 | 145 | self.qkv_projection = nn.Sequential( 146 | GroupNorm(in_channels), 147 | nn.Conv2d(in_channels, 3 * in_channels, kernel_size=1, bias=False), 148 | Rearrange("b (i h d) x y -> i b h (x y) d", i=3, h=n_heads), 149 | ) 150 | self.output_projection = nn.Sequential( 151 | Rearrange("b h l d -> b l (h d)"), 152 | nn.Linear(in_channels, out_channels, bias=False), 153 | Rearrange("b l d -> b d l"), 154 | GroupNorm(out_channels), 155 | nn.Dropout1d(dropout), 156 | ) 157 | self.residual_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1) 158 | 159 | def forward(self, x: Tensor) -> Tensor: 160 | q, k, v = self.qkv_projection(x).unbind(dim=0) 161 | 162 | output = F.scaled_dot_product_attention( 163 | q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=False 164 | ) 165 | output = self.output_projection(output) 166 | output = rearrange(output, "b c (x y) -> b c x y", x=x.shape[-2], y=x.shape[-1]) 167 | 168 | return output + self.residual_projection(x) 169 | 170 | 171 | class UNetBlock(nn.Module): 172 | def __init__( 173 | self, 174 | in_channels: int, 175 | out_channels: int, 176 | noise_level_channels: int, 177 | dropout: float = 0.3, 178 | ) -> None: 179 | super().__init__() 180 | 181 | self.input_projection = nn.Sequential( 182 | GroupNorm(in_channels), 183 | nn.SiLU(), 184 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding="same"), 185 | nn.Dropout2d(dropout), 186 | ) 187 | 
self.noise_level_projection = nn.Sequential( 188 | nn.SiLU(), 189 | nn.Conv2d(noise_level_channels, out_channels, kernel_size=1), 190 | ) 191 | self.output_projection = nn.Sequential( 192 | GroupNorm(out_channels), 193 | nn.SiLU(), 194 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding="same"), 195 | nn.Dropout2d(dropout), 196 | ) 197 | self.residual_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1) 198 | 199 | def forward(self, x: Tensor, noise_level: Tensor) -> Tensor: 200 | h = self.input_projection(x) 201 | h = h + self.noise_level_projection(noise_level) 202 | 203 | return self.output_projection(h) + self.residual_projection(x) 204 | 205 | 206 | class UNetBlockWithSelfAttention(nn.Module): 207 | def __init__( 208 | self, 209 | in_channels: int, 210 | out_channels: int, 211 | noise_level_channels: int, 212 | n_heads: int = 8, 213 | dropout: float = 0.3, 214 | ) -> None: 215 | super().__init__() 216 | 217 | self.unet_block = UNetBlock( 218 | in_channels, out_channels, noise_level_channels, dropout 219 | ) 220 | self.self_attention = SelfAttention( 221 | out_channels, out_channels, n_heads, dropout 222 | ) 223 | 224 | def forward(self, x: Tensor, noise_level: Tensor) -> Tensor: 225 | return self.self_attention(self.unet_block(x, noise_level)) 226 | 227 | 228 | class Downsample(nn.Module): 229 | def __init__(self, channels: int) -> None: 230 | super().__init__() 231 | 232 | self.projection = nn.Sequential( 233 | Rearrange("b c (h ph) (w pw) -> b (c ph pw) h w", ph=2, pw=2), 234 | nn.Conv2d(4 * channels, channels, kernel_size=1), 235 | ) 236 | 237 | def forward(self, x: Tensor) -> Tensor: 238 | return self.projection(x) 239 | 240 | 241 | class Upsample(nn.Module): 242 | def __init__(self, channels: int) -> None: 243 | super().__init__() 244 | 245 | self.projection = nn.Sequential( 246 | nn.Upsample(scale_factor=2.0, mode="nearest"), 247 | nn.Conv2d(channels, channels, kernel_size=3, padding="same"), 248 | ) 249 | 250 | def forward(self, x: Tensor) -> Tensor: 251 | return self.projection(x) 252 | 253 | 254 | class NoiseLevelEmbedding(nn.Module): 255 | def __init__(self, channels: int, scale: float = 0.02) -> None: 256 | super().__init__() 257 | 258 | self.W = nn.Parameter(torch.randn(channels // 2) * scale, requires_grad=False) 259 | 260 | self.projection = nn.Sequential( 261 | nn.Linear(channels, 4 * channels), 262 | nn.SiLU(), 263 | nn.Linear(4 * channels, channels), 264 | Rearrange("b c -> b c () ()"), 265 | ) 266 | 267 | def forward(self, x: Tensor) -> Tensor: 268 | h = x[:, None] * self.W[None, :] * 2 * torch.pi 269 | h = torch.cat([torch.sin(h), torch.cos(h)], dim=-1) 270 | 271 | return self.projection(h) 272 | 273 | 274 | @dataclass 275 | class UNetConfig: 276 | channels: int = 3 277 | noise_level_channels: int = 256 278 | noise_level_scale: float = 0.02 279 | n_heads: int = 8 280 | top_blocks_channels: Tuple[int, ...] = (128, 128) 281 | top_blocks_n_blocks_per_resolution: Tuple[int, ...] = (2, 2) 282 | top_blocks_has_resampling: Tuple[bool, ...] = (True, True) 283 | top_blocks_dropout: Tuple[float, ...] = (0.0, 0.0) 284 | mid_blocks_channels: Tuple[int, ...] = (256, 512) 285 | mid_blocks_n_blocks_per_resolution: Tuple[int, ...] = (4, 4) 286 | mid_blocks_has_resampling: Tuple[bool, ...] = (True, False) 287 | mid_blocks_dropout: Tuple[float, ...] 
= (0.0, 0.3) 288 | 289 | 290 | class UNet(nn.Module): 291 | def __init__(self, config: UNetConfig) -> None: 292 | super().__init__() 293 | 294 | self.config = config 295 | 296 | self.input_projection = nn.Conv2d( 297 | config.channels * 2, 298 | config.top_blocks_channels[0], 299 | kernel_size=3, 300 | padding="same", 301 | ) 302 | self.noise_level_embedding = NoiseLevelEmbedding( 303 | config.noise_level_channels, config.noise_level_scale 304 | ) 305 | self.top_encoder_blocks = self._make_encoder_blocks( 306 | self.config.top_blocks_channels + self.config.mid_blocks_channels[:1], 307 | self.config.top_blocks_n_blocks_per_resolution, 308 | self.config.top_blocks_has_resampling, 309 | self.config.top_blocks_dropout, 310 | self._make_top_block, 311 | ) 312 | self.mid_encoder_blocks = self._make_encoder_blocks( 313 | self.config.mid_blocks_channels + self.config.mid_blocks_channels[-1:], 314 | self.config.mid_blocks_n_blocks_per_resolution, 315 | self.config.mid_blocks_has_resampling, 316 | self.config.mid_blocks_dropout, 317 | self._make_mid_block, 318 | ) 319 | self.mid_decoder_blocks = self._make_decoder_blocks( 320 | self.config.mid_blocks_channels + self.config.mid_blocks_channels[-1:], 321 | self.config.mid_blocks_n_blocks_per_resolution, 322 | self.config.mid_blocks_has_resampling, 323 | self.config.mid_blocks_dropout, 324 | self._make_mid_block, 325 | ) 326 | self.top_decoder_blocks = self._make_decoder_blocks( 327 | self.config.top_blocks_channels + self.config.mid_blocks_channels[:1], 328 | self.config.top_blocks_n_blocks_per_resolution, 329 | self.config.top_blocks_has_resampling, 330 | self.config.top_blocks_dropout, 331 | self._make_top_block, 332 | ) 333 | self.output_projection = nn.Conv2d( 334 | config.top_blocks_channels[0], 335 | config.channels, 336 | kernel_size=3, 337 | padding="same", 338 | ) 339 | 340 | def forward(self, x: Tensor, noise_level: Tensor, v: Tensor) -> Tensor: 341 | x = torch.cat([x, v], dim = 1) 342 | h = self.input_projection(x) 343 | noise_level = self.noise_level_embedding(noise_level) 344 | 345 | top_encoder_embeddings = [] 346 | for block in self.top_encoder_blocks: 347 | if isinstance(block, UNetBlock): 348 | h = block(h, noise_level) 349 | top_encoder_embeddings.append(h) 350 | else: 351 | h = block(h) 352 | 353 | mid_encoder_embeddings = [] 354 | for block in self.mid_encoder_blocks: 355 | if isinstance(block, UNetBlockWithSelfAttention): 356 | h = block(h, noise_level) 357 | mid_encoder_embeddings.append(h) 358 | else: 359 | h = block(h) 360 | 361 | for block in self.mid_decoder_blocks: 362 | if isinstance(block, UNetBlockWithSelfAttention): 363 | h = torch.cat((h, mid_encoder_embeddings.pop()), dim=1) 364 | h = block(h, noise_level) 365 | else: 366 | h = block(h) 367 | 368 | for block in self.top_decoder_blocks: 369 | if isinstance(block, UNetBlock): 370 | h = torch.cat((h, top_encoder_embeddings.pop()), dim=1) 371 | h = block(h, noise_level) 372 | else: 373 | h = block(h) 374 | 375 | output = self.output_projection(h) 376 | 377 | return output 378 | 379 | def _make_encoder_blocks( 380 | self, 381 | channels: Tuple[int, ...], 382 | n_blocks_per_resolution: Tuple[int, ...], 383 | has_resampling: Tuple[bool, ...], 384 | dropout: Tuple[float, ...], 385 | block_fn: Callable[[], nn.Module], 386 | ) -> nn.ModuleList: 387 | blocks = nn.ModuleList() 388 | 389 | channel_pairs = list(zip(channels[:-1], channels[1:])) 390 | for idx, (in_channels, out_channels) in enumerate(channel_pairs): 391 | for _ in range(n_blocks_per_resolution[idx]): 392 | 
blocks.append(block_fn(in_channels, out_channels, dropout[idx])) 393 | in_channels = out_channels 394 | 395 | if has_resampling[idx]: 396 | blocks.append(Downsample(out_channels)) 397 | 398 | return blocks 399 | 400 | def _make_decoder_blocks( 401 | self, 402 | channels: Tuple[int, ...], 403 | n_blocks_per_resolution: Tuple[int, ...], 404 | has_resampling: Tuple[bool, ...], 405 | dropout: Tuple[float, ...], 406 | block_fn: Callable[[], nn.Module], 407 | ) -> nn.ModuleList: 408 | blocks = nn.ModuleList() 409 | 410 | channel_pairs = list(zip(channels[:-1], channels[1:]))[::-1] 411 | for idx, (out_channels, in_channels) in enumerate(channel_pairs): 412 | if has_resampling[::-1][idx]: 413 | blocks.append(Upsample(in_channels)) 414 | 415 | inner_blocks = [] 416 | for _ in range(n_blocks_per_resolution[::-1][idx]): 417 | inner_blocks.append( 418 | block_fn(in_channels * 2, out_channels, dropout[::-1][idx]) 419 | ) 420 | out_channels = in_channels 421 | blocks.extend(inner_blocks[::-1]) 422 | 423 | return blocks 424 | 425 | def _make_top_block( 426 | self, in_channels: int, out_channels: int, dropout: float 427 | ) -> UNetBlock: 428 | return UNetBlock( 429 | in_channels, 430 | out_channels, 431 | self.config.noise_level_channels, 432 | dropout, 433 | ) 434 | 435 | def _make_mid_block( 436 | self, 437 | in_channels: int, 438 | out_channels: int, 439 | dropout: float, 440 | ) -> UNetBlockWithSelfAttention: 441 | return UNetBlockWithSelfAttention( 442 | in_channels, 443 | out_channels, 444 | self.config.noise_level_channels, 445 | self.config.n_heads, 446 | dropout, 447 | ) 448 | 449 | def save_pretrained(self, pretrained_path: str) -> None: 450 | os.makedirs(pretrained_path, exist_ok=True) 451 | 452 | with open(os.path.join(pretrained_path, "config.json"), mode="w") as f: 453 | json.dump(asdict(self.config), f) 454 | 455 | torch.save(self.state_dict(), os.path.join(pretrained_path, "model.pt")) 456 | 457 | @classmethod 458 | def from_pretrained(cls, pretrained_path: str) -> "UNet": 459 | with open(os.path.join(pretrained_path, "config.json"), mode="r") as f: 460 | config_dict = json.load(f) 461 | config = UNetConfig(**config_dict) 462 | 463 | model = cls(config) 464 | 465 | state_dict = torch.load( 466 | os.path.join(pretrained_path, "model.pt"), map_location=torch.device("cpu") 467 | ) 468 | model.load_state_dict(state_dict) 469 | 470 | return model 471 | 472 | 473 | @dataclass 474 | class LitImprovedConsistencyModelConfig: 475 | ema_decay_rate: float = 0.99993 476 | lr: float = 1e-4 477 | betas: Tuple[float, float] = (0.9, 0.995) 478 | lr_scheduler_start_factor: float = 1e-5 479 | lr_scheduler_iters: int = 10_000 480 | sample_every_n_steps: int = 10_000 481 | num_samples: int = 8 482 | sampling_sigmas: Tuple[Tuple[int, ...], ...] 
= ( 483 | (80,), 484 | (80.0, 0.661), 485 | (80.0, 24.4, 5.84, 0.9, 0.661), 486 | ) 487 | 488 | 489 | class LitImprovedConsistencyModel(LightningModule): 490 | def __init__( 491 | self, 492 | consistency_training: ImprovedConsistencyTraining, 493 | consistency_sampling: ConsistencySamplingAndEditing, 494 | model: UNet, 495 | ema_model: UNet, 496 | config: LitImprovedConsistencyModelConfig, 497 | ) -> None: 498 | super().__init__() 499 | 500 | self.consistency_training = consistency_training 501 | self.consistency_sampling = consistency_sampling 502 | self.model = model 503 | self.ema_model = ema_model 504 | self.config = config 505 | 506 | # Freeze the EMA model and set it to eval mode 507 | for param in self.ema_model.parameters(): 508 | param.requires_grad = False 509 | self.ema_model = self.ema_model.eval() 510 | 511 | def training_step(self, batch: Union[Tensor, List[Tensor]], batch_idx: int) -> None: 512 | # if isinstance(batch, list): 513 | # batch = batch[0] 514 | 515 | HE_images, IHC_images = batch # Unpack the batch 516 | 517 | output = self.consistency_training( 518 | self.model, 519 | IHC_images, 520 | HE_images, 521 | self.global_step, 522 | self.trainer.max_steps 523 | ) 524 | 525 | loss = ( 526 | pseudo_huber_loss(output.predicted, output.target) * output.loss_weights 527 | ).mean() 528 | 529 | self.log_dict({"train_loss": loss, "num_timesteps": output.num_timesteps}) 530 | 531 | return loss 532 | 533 | def on_train_batch_end( 534 | self, outputs: Any, batch: Union[Tensor, List[Tensor]], batch_idx: int 535 | ) -> None: 536 | update_ema_model_(self.model, self.ema_model, self.config.ema_decay_rate) 537 | 538 | if ( 539 | (self.global_step + 1) % self.config.sample_every_n_steps == 0 540 | ) or self.global_step == 0: 541 | self.__sample_and_log_samples(batch) 542 | 543 | def configure_optimizers(self): 544 | opt = torch.optim.Adam( 545 | self.model.parameters(), lr=self.config.lr, betas=self.config.betas 546 | ) 547 | sched = torch.optim.lr_scheduler.LinearLR( 548 | opt, 549 | start_factor=self.config.lr_scheduler_start_factor, 550 | total_iters=self.config.lr_scheduler_iters, 551 | ) 552 | sched = {"scheduler": sched, "interval": "step", "frequency": 1} 553 | 554 | return [opt], [sched] 555 | 556 | @torch.no_grad() 557 | def __sample_and_log_samples(self, batch: Union[Tensor, List[Tensor]]) -> None: 558 | if isinstance(batch, list): 559 | batch = batch[0] 560 | 561 | # Ensure the number of samples does not exceed the batch size 562 | num_samples = min(self.config.num_samples, batch.shape[0]) 563 | noise = torch.randn_like(batch[:num_samples]) 564 | 565 | # Log ground truth samples 566 | self.__log_images( 567 | batch[:num_samples].detach().clone(), f"ground_truth", self.global_step 568 | ) 569 | 570 | for sigmas in self.config.sampling_sigmas: 571 | samples = self.consistency_sampling( 572 | self.ema_model, noise, sigmas, clip_denoised=True, verbose=True 573 | ) 574 | samples = samples.clamp(min=-1.0, max=1.0) 575 | 576 | # Generated samples 577 | self.__log_images( 578 | samples, 579 | f"generated_samples-sigmas={sigmas}", 580 | self.global_step, 581 | ) 582 | 583 | @torch.no_grad() 584 | def __log_images(self, images: Tensor, title: str, global_step: int) -> None: 585 | images = images.detach().float() 586 | 587 | grid = make_grid( 588 | images.clamp(-1.0, 1.0), value_range=(-1.0, 1.0), normalize=True 589 | ) 590 | self.logger.experiment.add_image(title, grid, global_step) 591 | 592 | 593 | @dataclass 594 | class TrainingConfig: 595 | image_dm_config: ImageDataModuleConfig 
596 | unet_config: UNetConfig 597 | consistency_training: ImprovedConsistencyTraining 598 | consistency_sampling: ConsistencySamplingAndEditing 599 | lit_icm_config: LitImprovedConsistencyModelConfig 600 | trainer: Trainer 601 | model_ckpt_path: str = "checkpoints/bci" 602 | seed: int = 42 603 | 604 | def run_training(config: TrainingConfig) -> None: 605 | # Set seed 606 | seed_everything(config.seed) 607 | 608 | # Create data module 609 | dm = LLVIPDataModule(config.image_dm_config) 610 | dm.setup() 611 | print("DataModule setup complete.") 612 | 613 | # Create model and its EMA 614 | model = UNet(config.unet_config) 615 | ema_model = UNet(config.unet_config) 616 | ema_model.load_state_dict(model.state_dict()) 617 | 618 | # Create lightning module 619 | lit_icm = LitImprovedConsistencyModel( 620 | config.consistency_training, 621 | config.consistency_sampling, 622 | model, 623 | ema_model, 624 | config.lit_icm_config, 625 | ) 626 | 627 | print("Lightning module created.") 628 | 629 | # Run training 630 | print("Starting training...") 631 | config.trainer.fit(lit_icm, datamodule=dm) 632 | print("Training completed.") 633 | 634 | # Save model 635 | lit_icm.model.save_pretrained(config.model_ckpt_path) 636 | print("Model saved.") 637 | 638 | # Main function 639 | def main(): 640 | # Define the checkpoint callback 641 | checkpoint_callback = ModelCheckpoint( 642 | dirpath="checkpoints_bci", 643 | filename="{epoch}-{step}", 644 | save_top_k=-1, # Save all checkpoints 645 | every_n_epochs=20, # Adjust as needed 646 | ) 647 | 648 | # Set up the logger 649 | logger = TensorBoardLogger("logs", name="bci_icm") 650 | 651 | training_config = TrainingConfig( 652 | image_dm_config=ImageDataModuleConfig(data_dir="../datasets/bci"), 653 | unet_config=UNetConfig(), 654 | consistency_training=ImprovedConsistencyTraining(final_timesteps=11), 655 | consistency_sampling=ConsistencySamplingAndEditing(), 656 | lit_icm_config=LitImprovedConsistencyModelConfig( 657 | sample_every_n_steps=2100000, lr_scheduler_iters=1000 658 | ), 659 | trainer=Trainer( 660 | max_steps=100, 661 | precision="16", 662 | log_every_n_steps=10, 663 | logger=logger, 664 | callbacks=[ 665 | LearningRateMonitor(logging_interval="step"), 666 | checkpoint_callback, # Add the checkpoint callback here 667 | ], 668 | ), 669 | ) 670 | run_training(training_config) 671 | 672 | if __name__ == "__main__": 673 | main() -------------------------------------------------------------------------------- /datasets/readme.txt: -------------------------------------------------------------------------------- 1 | download all the datasets here by the respective paths: 2 | 3 | LLVIP: 4 | datasets/LLVIP 5 | [datasets/LLVIP/visible , datasets/LLVIP/infrared] 6 | 7 | BCI: 8 | datasets/bci 9 | [datasets/bci/HE , datasets/bci/IHC] 10 | 11 | LOLv1: 12 | datasets/lolv1 13 | [datasets/lolv1/ours485/ , datasets/lolv1/eval15/] 14 | 15 | LOLv2: 16 | datasets/LOL-v2/Real_captured 17 | datasets/LOL-v2/Synthetic 18 | [lolv2/LOL-v2/Real_captured/Test , lolv2/LOL-v2/Real_captured/Train 19 | lolv2/LOL-v2/Synthetic/Test , lolv2/LOL-v2/Synthetic/Test] 20 | 21 | SID: 22 | datasets/sid/Sony 23 | [datasets/sid/Sony/long , datasets/sid/Sony/long , datasets/sid/Sony_test_list.txt , datasets/sid/Sony_train_list.txt , sdatasets/sid/Sony_val_list.txt] -------------------------------------------------------------------------------- /irvi/consistency_models2.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import 
dataclass 3 | from typing import Any, Callable, Iterable, Optional, Tuple, Union 4 | 5 | import torch 6 | from torch import Tensor, nn 7 | from tqdm.auto import tqdm 8 | 9 | from typing import Iterator 10 | 11 | from torch import Tensor, nn 12 | 13 | 14 | def pad_dims_like(x: Tensor, other: Tensor) -> Tensor: 15 | """Pad dimensions of tensor `x` to match the shape of tensor `other`. 16 | 17 | Parameters 18 | ---------- 19 | x : Tensor 20 | Tensor to be padded. 21 | other : Tensor 22 | Tensor whose shape will be used as reference for padding. 23 | 24 | Returns 25 | ------- 26 | Tensor 27 | Padded tensor with the same shape as other. 28 | """ 29 | ndim = other.ndim - x.ndim 30 | return x.view(*x.shape, *((1,) * ndim)) 31 | 32 | 33 | def _update_ema_weights( 34 | ema_weight_iter: Iterator[Tensor], 35 | online_weight_iter: Iterator[Tensor], 36 | ema_decay_rate: float, 37 | ) -> None: 38 | for ema_weight, online_weight in zip(ema_weight_iter, online_weight_iter): 39 | if ema_weight.data is None: 40 | ema_weight.data.copy_(online_weight.data) 41 | else: 42 | ema_weight.data.lerp_(online_weight.data, 1.0 - ema_decay_rate) 43 | 44 | 45 | def update_ema_model_( 46 | ema_model: nn.Module, online_model: nn.Module, ema_decay_rate: float 47 | ) -> nn.Module: 48 | """Updates weights of a moving average model with an online/source model. 49 | 50 | Parameters 51 | ---------- 52 | ema_model : nn.Module 53 | Moving average model. 54 | online_model : nn.Module 55 | Online or source model. 56 | ema_decay_rate : float 57 | Parameter that controls by how much the moving average weights are changed. 58 | 59 | Returns 60 | ------- 61 | nn.Module 62 | Updated moving average model. 63 | """ 64 | # Update parameters 65 | _update_ema_weights( 66 | ema_model.parameters(), online_model.parameters(), ema_decay_rate 67 | ) 68 | # Update buffers 69 | _update_ema_weights(ema_model.buffers(), online_model.buffers(), ema_decay_rate) 70 | 71 | return ema_model 72 | 73 | 74 | def pad_dims_like(x: Tensor, other: Tensor) -> Tensor: 75 | """Pad dimensions of tensor `x` to match the shape of tensor `other`. 76 | 77 | Parameters 78 | ---------- 79 | x : Tensor 80 | Tensor to be padded. 81 | other : Tensor 82 | Tensor whose shape will be used as reference for padding. 83 | 84 | Returns 85 | ------- 86 | Tensor 87 | Padded tensor with the same shape as other. 88 | """ 89 | ndim = other.ndim - x.ndim 90 | return x.view(*x.shape, *((1,) * ndim)) 91 | 92 | 93 | def timesteps_schedule( 94 | current_training_step: int, 95 | total_training_steps: int, 96 | initial_timesteps: int = 2, 97 | final_timesteps: int = 150, 98 | ) -> int: 99 | """Implements the proposed timestep discretization schedule. 100 | 101 | Parameters 102 | ---------- 103 | current_training_step : int 104 | Current step in the training loop. 105 | total_training_steps : int 106 | Total number of steps the model will be trained for. 107 | initial_timesteps : int, default=2 108 | Timesteps at the start of training. 109 | final_timesteps : int, default=150 110 | Timesteps at the end of training. 111 | 112 | Returns 113 | ------- 114 | int 115 | Number of timesteps at the current point in training. 
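    Examples
    --------
    Illustrative values with the default schedule (computed from the formula
    below, not from a stored run):

    >>> timesteps_schedule(0, 100_000)
    2
    >>> timesteps_schedule(100_000, 100_000)
    151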
116 | """ 117 | num_timesteps = (final_timesteps + 1) ** 2 - initial_timesteps**2 118 | num_timesteps = current_training_step * num_timesteps / total_training_steps 119 | num_timesteps = math.ceil(math.sqrt(num_timesteps + initial_timesteps**2) - 1) 120 | 121 | return num_timesteps + 1 122 | 123 | 124 | def improved_timesteps_schedule( 125 | current_training_step: int, 126 | total_training_steps: int, 127 | initial_timesteps: int = 10, 128 | final_timesteps: int = 1280, 129 | ) -> int: 130 | """Implements the improved timestep discretization schedule. 131 | 132 | Parameters 133 | ---------- 134 | current_training_step : int 135 | Current step in the training loop. 136 | total_training_steps : int 137 | Total number of steps the model will be trained for. 138 | initial_timesteps : int, default=2 139 | Timesteps at the start of training. 140 | final_timesteps : int, default=150 141 | Timesteps at the end of training. 142 | 143 | Returns 144 | ------- 145 | int 146 | Number of timesteps at the current point in training. 147 | 148 | References 149 | ---------- 150 | [1] [Improved Techniques For Consistency Training](https://arxiv.org/pdf/2310.14189.pdf) 151 | """ 152 | total_training_steps_prime = math.floor( 153 | total_training_steps 154 | / (math.log2(math.floor(final_timesteps / initial_timesteps)) + 1) 155 | ) 156 | num_timesteps = initial_timesteps * math.pow( 157 | 2, math.floor(current_training_step / total_training_steps_prime) 158 | ) 159 | num_timesteps = min(num_timesteps, final_timesteps) + 1 160 | 161 | return num_timesteps 162 | 163 | 164 | def ema_decay_rate_schedule( 165 | num_timesteps: int, initial_ema_decay_rate: float = 0.95, initial_timesteps: int = 2 166 | ) -> float: 167 | """Implements the proposed EMA decay rate schedule. 168 | 169 | Parameters 170 | ---------- 171 | num_timesteps : int 172 | Number of timesteps at the current point in training. 173 | initial_ema_decay_rate : float, default=0.95 174 | EMA rate at the start of training. 175 | initial_timesteps : int, default=2 176 | Timesteps at the start of training. 177 | 178 | Returns 179 | ------- 180 | float 181 | EMA decay rate at the current point in training. 182 | """ 183 | return math.exp( 184 | (initial_timesteps * math.log(initial_ema_decay_rate)) / num_timesteps 185 | ) 186 | 187 | 188 | def karras_schedule( 189 | num_timesteps: int, 190 | sigma_min: float = 0.002, 191 | sigma_max: float = 80.0, 192 | rho: float = 7.0, 193 | device: torch.device = None, 194 | ) -> Tensor: 195 | """Implements the karras schedule that controls the standard deviation of 196 | noise added. 197 | 198 | Parameters 199 | ---------- 200 | num_timesteps : int 201 | Number of timesteps at the current point in training. 202 | sigma_min : float, default=0.002 203 | Minimum standard deviation. 204 | sigma_max : float, default=80.0 205 | Maximum standard deviation 206 | rho : float, default=7.0 207 | Schedule hyper-parameter. 208 | device : torch.device, default=None 209 | Device to generate the schedule/sigmas/boundaries/ts on. 210 | 211 | Returns 212 | ------- 213 | Tensor 214 | Generated schedule/sigmas/boundaries/ts. 
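    Examples
    --------
    A short illustrative schedule with the default parameters (values rounded,
    not from a stored run)::

        sigmas = karras_schedule(5)
        # increasing noise levels, approximately
        # tensor([0.002, 0.17, 2.5, 17.5, 80.0])

    The schedule is increasing; ``ConsistencySamplingAndEditing`` expects its
    ``sigmas`` argument in decreasing order.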
215 | """ 216 | rho_inv = 1.0 / rho 217 | # Clamp steps to 1 so that we don't get nans 218 | steps = torch.arange(num_timesteps, device=device) / max(num_timesteps - 1, 1) 219 | sigmas = sigma_min**rho_inv + steps * ( 220 | sigma_max**rho_inv - sigma_min**rho_inv 221 | ) 222 | sigmas = sigmas**rho 223 | 224 | return sigmas 225 | 226 | 227 | def lognormal_timestep_distribution( 228 | num_samples: int, 229 | sigmas: Tensor, 230 | mean: float = -1.1, 231 | std: float = 2.0, 232 | ) -> Tensor: 233 | """Draws timesteps from a lognormal distribution. 234 | 235 | Parameters 236 | ---------- 237 | num_samples : int 238 | Number of samples to draw. 239 | sigmas : Tensor 240 | Standard deviations of the noise. 241 | mean : float, default=-1.1 242 | Mean of the lognormal distribution. 243 | std : float, default=2.0 244 | Standard deviation of the lognormal distribution. 245 | 246 | Returns 247 | ------- 248 | Tensor 249 | Timesteps drawn from the lognormal distribution. 250 | 251 | References 252 | ---------- 253 | [1] [Improved Techniques For Consistency Training](https://arxiv.org/pdf/2310.14189.pdf) 254 | """ 255 | pdf = torch.erf((torch.log(sigmas[1:]) - mean) / (std * math.sqrt(2))) - torch.erf( 256 | (torch.log(sigmas[:-1]) - mean) / (std * math.sqrt(2)) 257 | ) 258 | pdf = pdf / pdf.sum() 259 | 260 | timesteps = torch.multinomial(pdf, num_samples, replacement=True) 261 | 262 | return timesteps 263 | 264 | 265 | def improved_loss_weighting(sigmas: Tensor) -> Tensor: 266 | """Computes the weighting for the consistency loss. 267 | 268 | Parameters 269 | ---------- 270 | sigmas : Tensor 271 | Standard deviations of the noise. 272 | 273 | Returns 274 | ------- 275 | Tensor 276 | Weighting for the consistency loss. 277 | 278 | References 279 | ---------- 280 | [1] [Improved Techniques For Consistency Training](https://arxiv.org/pdf/2310.14189.pdf) 281 | """ 282 | return 1 / (sigmas[1:] - sigmas[:-1]) 283 | 284 | 285 | def pseudo_huber_loss(input: Tensor, target: Tensor) -> Tensor: 286 | """Computes the pseudo huber loss. 287 | 288 | Parameters 289 | ---------- 290 | input : Tensor 291 | Input tensor. 292 | target : Tensor 293 | Target tensor. 294 | 295 | Returns 296 | ------- 297 | Tensor 298 | Pseudo huber loss. 299 | """ 300 | c = 0.00054 * math.sqrt(math.prod(input.shape[1:])) 301 | return torch.sqrt((input - target) ** 2 + c**2) - c 302 | 303 | 304 | def skip_scaling( 305 | sigma: Tensor, sigma_data: float = 0.5, sigma_min: float = 0.002 306 | ) -> Tensor: 307 | """Computes the scaling value for the residual connection. 308 | 309 | Parameters 310 | ---------- 311 | sigma : Tensor 312 | Current standard deviation of the noise. 313 | sigma_data : float, default=0.5 314 | Standard deviation of the data. 315 | sigma_min : float, default=0.002 316 | Minimum standard deviation of the noise from the karras schedule. 317 | 318 | Returns 319 | ------- 320 | Tensor 321 | Scaling value for the residual connection. 322 | """ 323 | return sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2) 324 | 325 | 326 | def output_scaling( 327 | sigma: Tensor, sigma_data: float = 0.5, sigma_min: float = 0.002 328 | ) -> Tensor: 329 | """Computes the scaling value for the model's output. 330 | 331 | Parameters 332 | ---------- 333 | sigma : Tensor 334 | Current standard deviation of the noise. 335 | sigma_data : float, default=0.5 336 | Standard deviation of the data. 337 | sigma_min : float, default=0.002 338 | Minimum standard deviation of the noise from the karras schedule. 
339 | 340 | Returns 341 | ------- 342 | Tensor 343 | Scaling value for the model's output. 344 | """ 345 | return (sigma_data * (sigma - sigma_min)) / (sigma_data**2 + sigma**2) ** 0.5 346 | 347 | 348 | def model_forward_wrapper( 349 | model: nn.Module, 350 | x: Tensor, 351 | sigma: Tensor, 352 | v: Tensor, 353 | sigma_data: float = 0.5, 354 | sigma_min: float = 0.002, 355 | **kwargs: Any 356 | ) -> Tensor: 357 | """Wrapper for the model call to ensure that the residual connection and scaling 358 | for the residual and output values are applied. 359 | 360 | Parameters 361 | ---------- 362 | model : nn.Module 363 | Model to call. 364 | x : Tensor 365 | Input to the model, e.g: the noisy samples. 366 | sigma : Tensor 367 | Standard deviation of the noise. Normally referred to as t. 368 | sigma_data : float, default=0.5 369 | Standard deviation of the data. 370 | sigma_min : float, default=0.002 371 | Minimum standard deviation of the noise. 372 | **kwargs : Any 373 | Extra arguments to be passed during the model call. 374 | 375 | Returns 376 | ------- 377 | Tensor 378 | Scaled output from the model with the residual connection applied. 379 | """ 380 | c_skip = skip_scaling(sigma, sigma_data, sigma_min) 381 | c_out = output_scaling(sigma, sigma_data, sigma_min) 382 | 383 | # Pad dimensions as broadcasting will not work 384 | c_skip = pad_dims_like(c_skip, x) 385 | c_out = pad_dims_like(c_out, x) 386 | 387 | return c_skip * x + c_out * model(x, sigma, v, **kwargs) 388 | 389 | 390 | @dataclass 391 | class ConsistencyTrainingOutput: 392 | """Type of the output of the (Improved)ConsistencyTraining.__call__ method. 393 | 394 | Attributes 395 | ---------- 396 | predicted : Tensor 397 | Predicted values. 398 | target : Tensor 399 | Target values. 400 | num_timesteps : int 401 | Number of timesteps at the current point in training from the timestep discretization schedule. 402 | sigmas : Tensor 403 | Standard deviations of the noise. 404 | loss_weights : Optional[Tensor], default=None 405 | Weighting for the Improved Consistency Training loss. 406 | """ 407 | 408 | predicted: Tensor 409 | target: Tensor 410 | num_timesteps: int 411 | sigmas: Tensor 412 | loss_weights: Optional[Tensor] = None 413 | 414 | 415 | class ConsistencyTraining: 416 | """Implements the Consistency Training algorithm proposed in the paper. 417 | 418 | Parameters 419 | ---------- 420 | sigma_min : float, default=0.002 421 | Minimum standard deviation of the noise. 422 | sigma_max : float, default=80.0 423 | Maximum standard deviation of the noise. 424 | rho : float, default=7.0 425 | Schedule hyper-parameter. 426 | sigma_data : float, default=0.5 427 | Standard deviation of the data. 428 | initial_timesteps : int, default=2 429 | Schedule timesteps at the start of training. 430 | final_timesteps : int, default=150 431 | Schedule timesteps at the end of training. 
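    Notes
    -----
    This variant does not take a conditioning input; the training scripts in
    this repository use ``ImprovedConsistencyTraining`` (below), whose
    ``__call__`` also receives the conditioning tensor ``v``.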
432 | """ 433 | 434 | def __init__( 435 | self, 436 | sigma_min: float = 0.002, 437 | sigma_max: float = 80.0, 438 | rho: float = 7.0, 439 | sigma_data: float = 0.5, 440 | initial_timesteps: int = 2, 441 | final_timesteps: int = 150, 442 | ) -> None: 443 | self.sigma_min = sigma_min 444 | self.sigma_max = sigma_max 445 | self.rho = rho 446 | self.sigma_data = sigma_data 447 | self.initial_timesteps = initial_timesteps 448 | self.final_timesteps = final_timesteps 449 | 450 | def __call__( 451 | self, 452 | student_model: nn.Module, 453 | teacher_model: nn.Module, 454 | x: Tensor, 455 | current_training_step: int, 456 | total_training_steps: int, 457 | **kwargs: Any, 458 | ) -> ConsistencyTrainingOutput: 459 | """Runs one step of the consistency training algorithm. 460 | 461 | Parameters 462 | ---------- 463 | student_model : nn.Module 464 | Model that is being trained. 465 | teacher_model : nn.Module 466 | An EMA of the student model. 467 | x : Tensor 468 | Clean data. 469 | current_training_step : int 470 | Current step in the training loop. 471 | total_training_steps : int 472 | Total number of steps in the training loop. 473 | **kwargs : Any 474 | Additional keyword arguments to be passed to the models. 475 | 476 | Returns 477 | ------- 478 | ConsistencyTrainingOutput 479 | The predicted and target values for computing the loss as well as sigmas (noise levels). 480 | """ 481 | num_timesteps = timesteps_schedule( 482 | current_training_step, 483 | total_training_steps, 484 | self.initial_timesteps, 485 | self.final_timesteps, 486 | ) 487 | sigmas = karras_schedule( 488 | num_timesteps, self.sigma_min, self.sigma_max, self.rho, x.device 489 | ) 490 | noise = torch.randn_like(x) 491 | 492 | timesteps = torch.randint(0, num_timesteps - 1, (x.shape[0],), device=x.device) 493 | 494 | current_sigmas = sigmas[timesteps] 495 | next_sigmas = sigmas[timesteps + 1] 496 | 497 | next_noisy_x = x + pad_dims_like(next_sigmas, x) * noise 498 | next_x = model_forward_wrapper( 499 | student_model, 500 | next_noisy_x, 501 | next_sigmas, 502 | self.sigma_data, 503 | self.sigma_min, 504 | **kwargs, 505 | ) 506 | 507 | with torch.no_grad(): 508 | current_noisy_x = x + pad_dims_like(current_sigmas, x) * noise 509 | current_x = model_forward_wrapper( 510 | teacher_model, 511 | current_noisy_x, 512 | current_sigmas, 513 | self.sigma_data, 514 | self.sigma_min, 515 | **kwargs, 516 | ) 517 | 518 | return ConsistencyTrainingOutput(next_x, current_x, num_timesteps, sigmas) 519 | 520 | 521 | class ImprovedConsistencyTraining: 522 | """Implements the Improved Consistency Training algorithm. 523 | 524 | Parameters 525 | ---------- 526 | sigma_min : float, default=0.002 527 | Minimum standard deviation of the noise. 528 | sigma_max : float, default=80.0 529 | Maximum standard deviation of the noise. 530 | rho : float, default=7.0 531 | Schedule hyper-parameter. 532 | sigma_data : float, default=0.5 533 | Standard deviation of the data. 534 | initial_timesteps : int, default=10 535 | Schedule timesteps at the start of training. 536 | final_timesteps : int, default=1280 537 | Schedule timesteps at the end of training. 538 | lognormal_mean : float, default=-1.1 539 | Mean of the lognormal timestep distribution. 540 | lognormal_std : float, default=2.0 541 | Standard deviation of the lognormal timestep distribution. 
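    Examples
    --------
    Minimal sketch of one training step, mirroring the usage in the training
    scripts (``model``, ``target_images`` and ``cond_images`` are placeholder
    names)::

        training = ImprovedConsistencyTraining(final_timesteps=1280)
        out = training(model, target_images, cond_images, step, max_steps)
        loss = (pseudo_huber_loss(out.predicted, out.target) * out.loss_weights).mean()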
542 | """ 543 | 544 | def __init__( 545 | self, 546 | sigma_min: float = 0.002, 547 | sigma_max: float = 80.0, 548 | rho: float = 7.0, 549 | sigma_data: float = 0.5, 550 | initial_timesteps: int = 10, 551 | final_timesteps: int = 1280, 552 | lognormal_mean: float = -1.1, 553 | lognormal_std: float = 2.0, 554 | ) -> None: 555 | self.sigma_min = sigma_min 556 | self.sigma_max = sigma_max 557 | self.rho = rho 558 | self.sigma_data = sigma_data 559 | self.initial_timesteps = initial_timesteps 560 | self.final_timesteps = final_timesteps 561 | self.lognormal_mean = lognormal_mean 562 | self.lognormal_std = lognormal_std 563 | 564 | def __call__( 565 | self, 566 | model: nn.Module, 567 | x: Tensor, 568 | v: Tensor, 569 | current_training_step: int, 570 | total_training_steps: int, 571 | **kwargs: Any, 572 | ) -> ConsistencyTrainingOutput: 573 | """Runs one step of the improved consistency training algorithm. 574 | 575 | Parameters 576 | ---------- 577 | model : nn.Module 578 | Both teacher and student model. 579 | teacher_model : nn.Module 580 | Teacher model. 581 | x : Tensor 582 | Clean data. 583 | current_training_step : int 584 | Current step in the training loop. 585 | total_training_steps : int 586 | Total number of steps in the training loop. 587 | **kwargs : Any 588 | Additional keyword arguments to be passed to the models. 589 | 590 | Returns 591 | ------- 592 | ConsistencyTrainingOutput 593 | The predicted and target values for computing the loss, sigmas (noise levels) as well as the loss weights. 594 | """ 595 | 596 | num_timesteps = improved_timesteps_schedule( 597 | current_training_step, 598 | total_training_steps, 599 | self.initial_timesteps, 600 | self.final_timesteps, 601 | ) 602 | sigmas = karras_schedule( 603 | num_timesteps, self.sigma_min, self.sigma_max, self.rho, x.device 604 | ) 605 | noise = torch.randn_like(x) 606 | 607 | timesteps = lognormal_timestep_distribution( 608 | x.shape[0], sigmas, self.lognormal_mean, self.lognormal_std 609 | ) 610 | 611 | current_sigmas = sigmas[timesteps] 612 | next_sigmas = sigmas[timesteps + 1] 613 | 614 | # Add noise to infrared images 615 | next_noisy_x = x + pad_dims_like(next_sigmas, x) * noise 616 | # Concatenate with visible images 617 | # input_next = torch.cat([next_noisy_x, v], dim=1) 618 | 619 | 620 | next_x = model_forward_wrapper( 621 | model, 622 | next_noisy_x, 623 | next_sigmas, 624 | v, 625 | self.sigma_data, 626 | self.sigma_min, 627 | **kwargs, 628 | ) 629 | 630 | with torch.no_grad(): 631 | current_noisy_x = x + pad_dims_like(current_sigmas, x) * noise 632 | # input_current = torch.cat([current_noisy_x, v], dim=1) 633 | 634 | current_x = model_forward_wrapper( 635 | model, 636 | current_noisy_x, 637 | current_sigmas, 638 | v, 639 | self.sigma_data, 640 | self.sigma_min, 641 | **kwargs, 642 | ) 643 | 644 | loss_weights = pad_dims_like(improved_loss_weighting(sigmas)[timesteps], next_x) 645 | 646 | return ConsistencyTrainingOutput( 647 | next_x, current_x, num_timesteps, sigmas, loss_weights 648 | ) 649 | 650 | 651 | class ConsistencySamplingAndEditing: 652 | """Implements the Consistency Sampling and Zero-Shot Editing algorithms. 653 | 654 | Parameters 655 | ---------- 656 | sigma_min : float, default=0.002 657 | Minimum standard deviation of the noise. 658 | sigma_data : float, default=0.5 659 | Standard deviation of the data. 
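    Examples
    --------
    Conditional sampling sketch following the evaluation scripts
    (``model``, ``noisy_target`` and ``condition`` are placeholder names)::

        sampler = ConsistencySamplingAndEditing()
        sigmas = [80.0, 40.0, 20.0, 10.0, 5.0, 2.5, 1.25, 0.625,
                  0.3125, 0.15625, 0.078125]
        sample = sampler(
            model, noisy_target, condition, sigmas,
            start_from_y=True, add_initial_noise=False, clip_denoised=True,
        )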
660 | """ 661 | 662 | def __init__(self, sigma_min: float = 0.002, sigma_data: float = 0.5) -> None: 663 | self.sigma_min = sigma_min 664 | self.sigma_data = sigma_data 665 | 666 | def __call__( 667 | self, 668 | model: nn.Module, 669 | y: Tensor, 670 | v: Tensor, 671 | sigmas: Iterable[Union[Tensor, float]], 672 | mask: Optional[Tensor] = None, 673 | transform_fn: Callable[[Tensor], Tensor] = lambda x: x, 674 | inverse_transform_fn: Callable[[Tensor], Tensor] = lambda x: x, 675 | start_from_y: bool = False, 676 | add_initial_noise: bool = True, 677 | clip_denoised: bool = False, 678 | verbose: bool = False, 679 | **kwargs: Any, 680 | ) -> Tensor: 681 | """Runs the sampling/zero-shot editing loop. 682 | 683 | With the default parameters the function performs consistency sampling. 684 | 685 | Parameters 686 | ---------- 687 | model : nn.Module 688 | Model to sample from. 689 | y : Tensor 690 | Reference sample e.g: a masked image or noise. 691 | sigmas : Iterable[Union[Tensor, float]] 692 | Decreasing standard deviations of the noise. 693 | mask : Tensor, default=None 694 | A mask of zeros and ones with ones indicating where to edit. By 695 | default the whole sample will be edited. This is useful for sampling. 696 | transform_fn : Callable[[Tensor], Tensor], default=lambda x: x 697 | An invertible linear transformation. Defaults to the identity function. 698 | inverse_transform_fn : Callable[[Tensor], Tensor], default=lambda x: x 699 | Inverse of the linear transformation. Defaults to the identity function. 700 | start_from_y : bool, default=False 701 | Whether to use y as an initial sample and add noise to it instead of starting 702 | from random gaussian noise. This is useful for tasks like style transfer. 703 | add_initial_noise : bool, default=True 704 | Whether to add noise at the start of the schedule. Useful for tasks like interpolation 705 | where noise will alerady be added in advance. 706 | clip_denoised : bool, default=False 707 | Whether to clip denoised values to [-1, 1] range. 708 | verbose : bool, default=False 709 | Whether to display the progress bar. 710 | **kwargs : Any 711 | Additional keyword arguments to be passed to the model. 712 | 713 | Returns 714 | ------- 715 | Tensor 716 | Edited/sampled sample. 
717 | """ 718 | # Set mask to all ones which is useful for sampling and style transfer 719 | if mask is None: 720 | mask = torch.ones_like(y) 721 | 722 | # Use y as an initial sample which is useful for tasks like style transfer 723 | # and interpolation where we want to use content from the reference sample 724 | x = y if start_from_y else torch.zeros_like(y) 725 | 726 | # Sample at the end of the schedule 727 | y = self.__mask_transform(x, y, mask, transform_fn, inverse_transform_fn) 728 | # For tasks like interpolation where noise will already be added in advance we 729 | # can skip the noising process 730 | x = y + sigmas[0] * torch.randn_like(y) if add_initial_noise else y 731 | # input_x = torch.cat([x, v], dim=1) 732 | sigma = torch.full((x.shape[0],), sigmas[0], dtype=x.dtype, device=x.device) 733 | x = model_forward_wrapper( 734 | model, x, sigma, v, self.sigma_data, self.sigma_min, **kwargs 735 | ) 736 | if clip_denoised: 737 | x = x.clamp(min=-1.0, max=1.0) 738 | x = self.__mask_transform(x, y, mask, transform_fn, inverse_transform_fn) 739 | 740 | # Progressively denoise the sample and skip the first step as it has already 741 | # been run 742 | for sigma_value in sigmas[1:]: 743 | sigma = torch.full((x.shape[0],), sigma_value, dtype=x.dtype, device=x.device) 744 | x = x + pad_dims_like( 745 | (sigma**2 - self.sigma_min**2) ** 0.5, x 746 | ) * torch.randn_like(x) 747 | # input_x = torch.cat([x, v], dim=1) 748 | x = model_forward_wrapper( 749 | model, x, sigma, v, self.sigma_data, self.sigma_min, **kwargs 750 | ) 751 | if clip_denoised: 752 | x = x.clamp(min=-1.0, max=1.0) 753 | x = self.__mask_transform(x, y, mask, transform_fn, inverse_transform_fn) 754 | 755 | return x 756 | 757 | def interpolate( 758 | self, 759 | model: nn.Module, 760 | a: Tensor, 761 | b: Tensor, 762 | ab_ratio: float, 763 | sigmas: Iterable[Union[Tensor, float]], 764 | clip_denoised: bool = False, 765 | verbose: bool = False, 766 | **kwargs: Any, 767 | ) -> Tensor: 768 | """Runs the interpolation loop. 769 | 770 | Parameters 771 | ---------- 772 | model : nn.Module 773 | Model to sample from. 774 | a : Tensor 775 | First reference sample. 776 | b : Tensor 777 | Second refernce sample. 778 | ab_ratio : float 779 | Ratio of the first reference sample to the second reference sample. 780 | clip_denoised : bool, default=False 781 | Whether to clip denoised values to [-1, 1] range. 782 | verbose : bool, default=False 783 | Whether to display the progress bar. 784 | **kwargs : Any 785 | Additional keyword arguments to be passed to the model. 786 | 787 | Returns 788 | ------- 789 | Tensor 790 | Intepolated sample. 
791 | """ 792 | # Obtain latent samples from the initial samples 793 | a = a + sigmas[0] * torch.randn_like(a) 794 | b = b + sigmas[0] * torch.randn_like(b) 795 | 796 | # Perform spherical linear interpolation of the latents 797 | omega = torch.arccos(torch.sum((a / a.norm(p=2)) * (b / b.norm(p=2)))) 798 | a = torch.sin(ab_ratio * omega) / torch.sin(omega) * a 799 | b = torch.sin((1 - ab_ratio) * omega) / torch.sin(omega) * b 800 | ab = a + b 801 | 802 | # Denoise the interpolated latents 803 | return self( 804 | model, 805 | ab, 806 | sigmas, 807 | start_from_y=True, 808 | add_initial_noise=False, 809 | clip_denoised=clip_denoised, 810 | verbose=verbose, 811 | **kwargs, 812 | ) 813 | 814 | def __mask_transform( 815 | self, 816 | x: Tensor, 817 | y: Tensor, 818 | mask: Tensor, 819 | transform_fn: Callable[[Tensor], Tensor] = lambda x: x, 820 | inverse_transform_fn: Callable[[Tensor], Tensor] = lambda x: x, 821 | ) -> Tensor: 822 | return inverse_transform_fn(transform_fn(y) * (1.0 - mask) + x * mask) 823 | -------------------------------------------------------------------------------- /llvip/checkpoints/llvip/config.json: -------------------------------------------------------------------------------- 1 | {"channels": 3, "noise_level_channels": 256, "noise_level_scale": 0.02, "n_heads": 8, "top_blocks_channels": [128, 128], "top_blocks_n_blocks_per_resolution": [2, 2], "top_blocks_has_resampling": [true, true], "top_blocks_dropout": [0.0, 0.0], "mid_blocks_channels": [256, 512], "mid_blocks_n_blocks_per_resolution": [4, 4], "mid_blocks_has_resampling": [true, false], "mid_blocks_dropout": [0.0, 0.3]} -------------------------------------------------------------------------------- /llvip/checkpoints/llvip/model.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:748b26b21e3308076ef2a8e1c14638b5323fbe9c8e8ee163c4442fddd9e388d2 3 | size 515824214 4 | -------------------------------------------------------------------------------- /llvip/sampling_and_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from PIL import Image 4 | import torchvision.transforms as T 5 | from improved_consistency_model_conditional import ConsistencySamplingAndEditing 6 | from llvip.script import UNet 7 | import numpy as np 8 | from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim 9 | 10 | # Load the pretrained UNet model 11 | model_path = "checkpoints/llvip512x512_128x128/" 12 | model = UNet.from_pretrained(model_path) 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | model = model.to(device).eval() 15 | 16 | # Define image transformation: Resize to 256x256 and normalize to [-1, 1] 17 | transform = T.Compose([ 18 | T.Resize((256, 256)), 19 | T.ToTensor(), 20 | T.Lambda(lambda x: (x * 2) - 1), # Normalize pixel values 21 | ]) 22 | 23 | # Input data directories for visible and infrared test sets 24 | visible_folder = "../datasets/LLVIP/visible/test/" 25 | infrared_folder = "../datasets/LLVIP/infrared/test/" 26 | 27 | # Create the consistency sampling instance for generating images 28 | consistency_sampling = ConsistencySamplingAndEditing() 29 | 30 | # Define the noise sigma schedule 31 | sigmas = [80.0, 40.0, 20.0, 10.0, 5.0, 2.5, 1.25, 0.625, 0.3125, 0.15625, 0.078125] 32 | 33 | # Initialize accumulators for PSNR and SSIM metrics 34 | total_psnr = 0.0 35 | total_ssim = 0.0 36 | 
num_images = 0 37 | 38 | # Directory to save generated results 39 | results_folder = "results_llvip_128x128" 40 | os.makedirs(results_folder, exist_ok=True) 41 | 42 | # Loop through all visible images in the test set 43 | for idx, visible_image_name in enumerate(os.listdir(visible_folder), start=1): 44 | visible_image_path = os.path.join(visible_folder, visible_image_name) 45 | infrared_image_path = os.path.join(infrared_folder, visible_image_name) 46 | 47 | if not os.path.exists(infrared_image_path): 48 | print(f"Infrared image {infrared_image_path} not found, skipping.") 49 | continue 50 | 51 | try: 52 | # Load visible and infrared images 53 | visible_image = Image.open(visible_image_path).convert("RGB") 54 | infrared_image = Image.open(infrared_image_path).convert("RGB") 55 | except Exception as e: 56 | print(f"Error loading images: {e}, skipping {visible_image_name}") 57 | continue 58 | 59 | # Apply transformations to resize and normalize images 60 | visible_tensor = transform(visible_image).unsqueeze(0).to(device) 61 | infrared_tensor = transform(infrared_image).unsqueeze(0).to(device) 62 | 63 | # Add noise to the infrared image 64 | noise = torch.randn_like(infrared_tensor) * sigmas[0] 65 | noisy_infrared_tensor = infrared_tensor + noise 66 | 67 | try: 68 | # Generate the infrared image using consistency sampling 69 | with torch.no_grad(): 70 | generated_infrared_tensor = consistency_sampling( 71 | model=model, 72 | y=noisy_infrared_tensor, 73 | v=visible_tensor, 74 | sigmas=sigmas, 75 | start_from_y=True, 76 | add_initial_noise=False, 77 | clip_denoised=True, 78 | verbose=False, 79 | ) 80 | except Exception as e: 81 | print(f"Error during model inference: {e}, skipping {visible_image_name}") 82 | continue 83 | 84 | # Denormalize tensors to convert to valid image range [0, 1] 85 | visible_denorm = ((visible_tensor.squeeze(0).cpu() + 1) / 2).clamp(0, 1).numpy().transpose(1, 2, 0) 86 | infrared_denorm = ((infrared_tensor.squeeze(0).cpu() + 1) / 2).clamp(0, 1).numpy().transpose(1, 2, 0) 87 | generated_infrared_denorm = ((generated_infrared_tensor.squeeze(0).cpu() + 1) / 2).clamp(0, 1).numpy().transpose(1, 2, 0) 88 | 89 | # Convert to uint8 format for saving 90 | visible_image_save = (visible_denorm * 255).astype(np.uint8) 91 | infrared_image_save = (infrared_denorm * 255).astype(np.uint8) 92 | generated_image_save = (generated_infrared_denorm * 255).astype(np.uint8) 93 | 94 | # Save generated and ground truth images for reference 95 | base_name, ext = os.path.splitext(visible_image_name) 96 | Image.fromarray(generated_image_save).save(os.path.join(results_folder, f"generated_{base_name}{ext}")) 97 | Image.fromarray(infrared_image_save).save(os.path.join(results_folder, f"groundtruth_{base_name}{ext}")) 98 | Image.fromarray(visible_image_save).save(os.path.join(results_folder, f"visible_{base_name}{ext}")) 99 | 100 | # Calculate PSNR and SSIM metrics 101 | psnr_value = psnr(infrared_denorm, generated_infrared_denorm, data_range=1.0) 102 | ssim_value = ssim(infrared_denorm, generated_infrared_denorm, data_range=1.0, multichannel=True, win_size=3) 103 | 104 | # Accumulate metrics 105 | total_psnr += psnr_value 106 | total_ssim += ssim_value 107 | num_images += 1 108 | 109 | print(f"Image {idx}: PSNR = {psnr_value:.2f}, SSIM = {ssim_value:.4f}") 110 | 111 | # Compute and save average metrics 112 | if num_images > 0: 113 | avg_psnr = total_psnr / num_images 114 | avg_ssim = total_ssim / num_images 115 | 116 | print(f"Average PSNR: {avg_psnr:.2f}") 117 | print(f"Average SSIM: {avg_ssim:.4f}") 
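# Note (reference): with data_range=1.0 the PSNR averaged above is
# 10 * log10(1 / MSE) on [0, 1]-normalised images, and SSIM is computed with
# a 3x3 window and averaged over the colour channels (multichannel=True).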
118 | 119 | with open("metrics.txt", "w") as f: 120 | f.write(f"Average PSNR: {avg_psnr:.2f}\n") 121 | f.write(f"Average SSIM: {avg_ssim:.4f}\n") 122 | else: 123 | print("No images processed.") 124 | -------------------------------------------------------------------------------- /llvip/script.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import asdict, dataclass 4 | from typing import Any, Callable, List, Optional, Tuple, Union 5 | 6 | import torch 7 | from torchvision.transforms import functional as TF 8 | import math 9 | from einops import rearrange 10 | from einops.layers.torch import Rearrange 11 | from lightning import LightningDataModule, LightningModule, Trainer, seed_everything 12 | from lightning.pytorch.callbacks import LearningRateMonitor 13 | from lightning.pytorch.loggers import TensorBoardLogger 14 | from matplotlib import pyplot as plt 15 | from torch import Tensor, nn 16 | from torch.nn import functional as F 17 | from torch.utils.data import DataLoader 18 | from torchinfo import summary 19 | from torchvision import transforms as T 20 | from torchvision.datasets import ImageFolder 21 | from torchvision.utils import make_grid 22 | 23 | from improved_consistency_model_conditional import ( 24 | ConsistencySamplingAndEditing, 25 | ImprovedConsistencyTraining, 26 | pseudo_huber_loss, 27 | update_ema_model_, 28 | ) 29 | 30 | 31 | from torch.utils.data import Dataset 32 | from torchvision import transforms as T 33 | import os 34 | from PIL import Image 35 | import torch 36 | from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint 37 | from lightning.pytorch import LightningDataModule 38 | 39 | 40 | class PairedDataset(Dataset): 41 | def __init__( 42 | self, 43 | visible_dir: str, 44 | infrared_dir: str, 45 | transform: Optional[Callable] = None, 46 | crop_size: Tuple[int, int] = (512, 512), 47 | resize_size: Tuple[int, int] = (256, 256) 48 | ): 49 | self.visible_dir = visible_dir 50 | self.infrared_dir = infrared_dir 51 | self.visible_images = sorted(os.listdir(visible_dir)) 52 | self.infrared_images = sorted(os.listdir(infrared_dir)) 53 | self.transform = transform 54 | self.crop_size = crop_size 55 | self.resize_size = resize_size 56 | 57 | def __len__(self) -> int: 58 | return len(self.visible_images) 59 | 60 | def __getitem__(self, index: int) -> Optional[Tuple[Tensor, Tensor]]: 61 | visible_path = os.path.join(self.visible_dir, self.visible_images[index]) 62 | infrared_path = os.path.join(self.infrared_dir, self.infrared_images[index]) 63 | 64 | visible_image = Image.open(visible_path).convert("RGB") 65 | infrared_image = Image.open(infrared_path).convert("RGB") 66 | 67 | if visible_image.size != infrared_image.size: 68 | print(f"Skipping image pair at index {index} due to mismatched sizes") 69 | return None 70 | 71 | # Perform synchronized random horizontal flip 72 | if torch.rand(1).item() > 0.5: 73 | visible_image = TF.hflip(visible_image) 74 | infrared_image = TF.hflip(infrared_image) 75 | 76 | # Perform synchronized random crop 77 | i, j, h, w = T.RandomCrop.get_params(visible_image, output_size=self.crop_size) 78 | visible_image = TF.crop(visible_image, i, j, h, w) 79 | infrared_image = TF.crop(infrared_image, i, j, h, w) 80 | 81 | # Resize to desired size 82 | visible_image = TF.resize(visible_image, self.resize_size) 83 | infrared_image = TF.resize(infrared_image, self.resize_size) 84 | 85 | if self.transform: 86 | visible_image = 
self.transform(visible_image) 87 | infrared_image = self.transform(infrared_image) 88 | 89 | return visible_image, infrared_image 90 | 91 | 92 | @dataclass 93 | class ImageDataModuleConfig: 94 | data_dir: str = "dataset/LLVIP" # Path to the dataset directory 95 | image_size_crop: Tuple[int, int] = (512, 512) 96 | image_size_resize: Tuple[int, int] = (128, 128) # Resize to 128x128 97 | batch_size: int = 4 # Number of images in each batch 98 | num_workers: int = 8 # Number of worker threads for data loading 99 | pin_memory: bool = True # Whether to pin memory in data loader 100 | persistent_workers: bool = True # Keep workers alive between epochs 101 | 102 | 103 | class LLVIPDataModule(LightningDataModule): 104 | def __init__(self, config: ImageDataModuleConfig) -> None: 105 | super().__init__() 106 | self.config = config 107 | 108 | def setup(self, stage: str = None) -> None: 109 | # Define transforms excluding cropping and resizing 110 | self.transform = T.Compose([ 111 | T.ToTensor(), 112 | T.Lambda(lambda x: (x * 2) - 1), # Normalize to [-1, 1] 113 | ]) 114 | 115 | self.dataset = PairedDataset( 116 | visible_dir="datasets/LLVIP/visible/train", 117 | infrared_dir="datasets/LLVIP/infrared/train", 118 | transform=self.transform, 119 | crop_size=self.config.image_size_crop, 120 | resize_size=self.config.image_size_resize 121 | ) 122 | 123 | 124 | def train_dataloader(self) -> DataLoader: 125 | return DataLoader( 126 | self.dataset, 127 | batch_size=self.config.batch_size, 128 | shuffle=True, 129 | num_workers=self.config.num_workers, 130 | pin_memory=self.config.pin_memory, 131 | persistent_workers=self.config.persistent_workers, 132 | ) 133 | 134 | 135 | 136 | def GroupNorm(channels: int) -> nn.GroupNorm: 137 | return nn.GroupNorm(num_groups=min(32, channels // 4), num_channels=channels) 138 | 139 | 140 | class SelfAttention(nn.Module): 141 | def __init__( 142 | self, 143 | in_channels: int, 144 | out_channels: int, 145 | n_heads: int = 8, 146 | dropout: float = 0.3, 147 | ) -> None: 148 | super().__init__() 149 | 150 | self.dropout = dropout 151 | 152 | self.qkv_projection = nn.Sequential( 153 | GroupNorm(in_channels), 154 | nn.Conv2d(in_channels, 3 * in_channels, kernel_size=1, bias=False), 155 | Rearrange("b (i h d) x y -> i b h (x y) d", i=3, h=n_heads), 156 | ) 157 | self.output_projection = nn.Sequential( 158 | Rearrange("b h l d -> b l (h d)"), 159 | nn.Linear(in_channels, out_channels, bias=False), 160 | Rearrange("b l d -> b d l"), 161 | GroupNorm(out_channels), 162 | nn.Dropout1d(dropout), 163 | ) 164 | self.residual_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1) 165 | 166 | def forward(self, x: Tensor) -> Tensor: 167 | q, k, v = self.qkv_projection(x).unbind(dim=0) 168 | 169 | output = F.scaled_dot_product_attention( 170 | q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=False 171 | ) 172 | output = self.output_projection(output) 173 | output = rearrange(output, "b c (x y) -> b c x y", x=x.shape[-2], y=x.shape[-1]) 174 | 175 | return output + self.residual_projection(x) 176 | 177 | 178 | class UNetBlock(nn.Module): 179 | def __init__( 180 | self, 181 | in_channels: int, 182 | out_channels: int, 183 | noise_level_channels: int, 184 | dropout: float = 0.3, 185 | ) -> None: 186 | super().__init__() 187 | 188 | self.input_projection = nn.Sequential( 189 | GroupNorm(in_channels), 190 | nn.SiLU(), 191 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding="same"), 192 | nn.Dropout2d(dropout), 193 | ) 194 | self.noise_level_projection = 
nn.Sequential( 195 | nn.SiLU(), 196 | nn.Conv2d(noise_level_channels, out_channels, kernel_size=1), 197 | ) 198 | self.output_projection = nn.Sequential( 199 | GroupNorm(out_channels), 200 | nn.SiLU(), 201 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding="same"), 202 | nn.Dropout2d(dropout), 203 | ) 204 | self.residual_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1) 205 | 206 | def forward(self, x: Tensor, noise_level: Tensor) -> Tensor: 207 | h = self.input_projection(x) 208 | h = h + self.noise_level_projection(noise_level) 209 | 210 | return self.output_projection(h) + self.residual_projection(x) 211 | 212 | 213 | class UNetBlockWithSelfAttention(nn.Module): 214 | def __init__( 215 | self, 216 | in_channels: int, 217 | out_channels: int, 218 | noise_level_channels: int, 219 | n_heads: int = 8, 220 | dropout: float = 0.3, 221 | ) -> None: 222 | super().__init__() 223 | 224 | self.unet_block = UNetBlock( 225 | in_channels, out_channels, noise_level_channels, dropout 226 | ) 227 | self.self_attention = SelfAttention( 228 | out_channels, out_channels, n_heads, dropout 229 | ) 230 | 231 | def forward(self, x: Tensor, noise_level: Tensor) -> Tensor: 232 | return self.self_attention(self.unet_block(x, noise_level)) 233 | 234 | 235 | class Downsample(nn.Module): 236 | def __init__(self, channels: int) -> None: 237 | super().__init__() 238 | 239 | self.projection = nn.Sequential( 240 | Rearrange("b c (h ph) (w pw) -> b (c ph pw) h w", ph=2, pw=2), 241 | nn.Conv2d(4 * channels, channels, kernel_size=1), 242 | ) 243 | 244 | def forward(self, x: Tensor) -> Tensor: 245 | return self.projection(x) 246 | 247 | 248 | class Upsample(nn.Module): 249 | def __init__(self, channels: int) -> None: 250 | super().__init__() 251 | 252 | self.projection = nn.Sequential( 253 | nn.Upsample(scale_factor=2.0, mode="nearest"), 254 | nn.Conv2d(channels, channels, kernel_size=3, padding="same"), 255 | ) 256 | 257 | def forward(self, x: Tensor) -> Tensor: 258 | return self.projection(x) 259 | 260 | 261 | class NoiseLevelEmbedding(nn.Module): 262 | def __init__(self, channels: int, scale: float = 0.02) -> None: 263 | super().__init__() 264 | 265 | self.W = nn.Parameter(torch.randn(channels // 2) * scale, requires_grad=False) 266 | 267 | self.projection = nn.Sequential( 268 | nn.Linear(channels, 4 * channels), 269 | nn.SiLU(), 270 | nn.Linear(4 * channels, channels), 271 | Rearrange("b c -> b c () ()"), 272 | ) 273 | 274 | def forward(self, x: Tensor) -> Tensor: 275 | h = x[:, None] * self.W[None, :] * 2 * torch.pi 276 | h = torch.cat([torch.sin(h), torch.cos(h)], dim=-1) 277 | 278 | return self.projection(h) 279 | 280 | 281 | @dataclass 282 | class UNetConfig: 283 | channels: int = 3 284 | noise_level_channels: int = 256 285 | noise_level_scale: float = 0.02 286 | n_heads: int = 8 287 | top_blocks_channels: Tuple[int, ...] = (128, 128) 288 | top_blocks_n_blocks_per_resolution: Tuple[int, ...] = (2, 2) 289 | top_blocks_has_resampling: Tuple[bool, ...] = (True, True) 290 | top_blocks_dropout: Tuple[float, ...] = (0.0, 0.0) 291 | mid_blocks_channels: Tuple[int, ...] = (256, 512) 292 | mid_blocks_n_blocks_per_resolution: Tuple[int, ...] = (4, 4) 293 | mid_blocks_has_resampling: Tuple[bool, ...] = (True, False) 294 | mid_blocks_dropout: Tuple[float, ...] 
= (0.0, 0.3) 295 | 296 | 297 | class UNet(nn.Module): 298 | def __init__(self, config: UNetConfig) -> None: 299 | super().__init__() 300 | 301 | self.config = config 302 | 303 | self.input_projection = nn.Conv2d( 304 | config.channels * 2, 305 | config.top_blocks_channels[0], 306 | kernel_size=3, 307 | padding="same", 308 | ) 309 | self.noise_level_embedding = NoiseLevelEmbedding( 310 | config.noise_level_channels, config.noise_level_scale 311 | ) 312 | self.top_encoder_blocks = self._make_encoder_blocks( 313 | self.config.top_blocks_channels + self.config.mid_blocks_channels[:1], 314 | self.config.top_blocks_n_blocks_per_resolution, 315 | self.config.top_blocks_has_resampling, 316 | self.config.top_blocks_dropout, 317 | self._make_top_block, 318 | ) 319 | self.mid_encoder_blocks = self._make_encoder_blocks( 320 | self.config.mid_blocks_channels + self.config.mid_blocks_channels[-1:], 321 | self.config.mid_blocks_n_blocks_per_resolution, 322 | self.config.mid_blocks_has_resampling, 323 | self.config.mid_blocks_dropout, 324 | self._make_mid_block, 325 | ) 326 | self.mid_decoder_blocks = self._make_decoder_blocks( 327 | self.config.mid_blocks_channels + self.config.mid_blocks_channels[-1:], 328 | self.config.mid_blocks_n_blocks_per_resolution, 329 | self.config.mid_blocks_has_resampling, 330 | self.config.mid_blocks_dropout, 331 | self._make_mid_block, 332 | ) 333 | self.top_decoder_blocks = self._make_decoder_blocks( 334 | self.config.top_blocks_channels + self.config.mid_blocks_channels[:1], 335 | self.config.top_blocks_n_blocks_per_resolution, 336 | self.config.top_blocks_has_resampling, 337 | self.config.top_blocks_dropout, 338 | self._make_top_block, 339 | ) 340 | self.output_projection = nn.Conv2d( 341 | config.top_blocks_channels[0], 342 | config.channels, 343 | kernel_size=3, 344 | padding="same", 345 | ) 346 | 347 | def forward(self, x: Tensor, noise_level: Tensor, v: Tensor) -> Tensor: 348 | x = torch.cat([x, v], dim = 1) 349 | h = self.input_projection(x) 350 | noise_level = self.noise_level_embedding(noise_level) 351 | 352 | top_encoder_embeddings = [] 353 | for block in self.top_encoder_blocks: 354 | if isinstance(block, UNetBlock): 355 | h = block(h, noise_level) 356 | top_encoder_embeddings.append(h) 357 | else: 358 | h = block(h) 359 | 360 | mid_encoder_embeddings = [] 361 | for block in self.mid_encoder_blocks: 362 | if isinstance(block, UNetBlockWithSelfAttention): 363 | h = block(h, noise_level) 364 | mid_encoder_embeddings.append(h) 365 | else: 366 | h = block(h) 367 | 368 | for block in self.mid_decoder_blocks: 369 | if isinstance(block, UNetBlockWithSelfAttention): 370 | h = torch.cat((h, mid_encoder_embeddings.pop()), dim=1) 371 | h = block(h, noise_level) 372 | else: 373 | h = block(h) 374 | 375 | for block in self.top_decoder_blocks: 376 | if isinstance(block, UNetBlock): 377 | h = torch.cat((h, top_encoder_embeddings.pop()), dim=1) 378 | h = block(h, noise_level) 379 | else: 380 | h = block(h) 381 | 382 | output = self.output_projection(h) 383 | 384 | return output 385 | 386 | def _make_encoder_blocks( 387 | self, 388 | channels: Tuple[int, ...], 389 | n_blocks_per_resolution: Tuple[int, ...], 390 | has_resampling: Tuple[bool, ...], 391 | dropout: Tuple[float, ...], 392 | block_fn: Callable[[], nn.Module], 393 | ) -> nn.ModuleList: 394 | blocks = nn.ModuleList() 395 | 396 | channel_pairs = list(zip(channels[:-1], channels[1:])) 397 | for idx, (in_channels, out_channels) in enumerate(channel_pairs): 398 | for _ in range(n_blocks_per_resolution[idx]): 399 | 
blocks.append(block_fn(in_channels, out_channels, dropout[idx])) 400 | in_channels = out_channels 401 | 402 | if has_resampling[idx]: 403 | blocks.append(Downsample(out_channels)) 404 | 405 | return blocks 406 | 407 | def _make_decoder_blocks( 408 | self, 409 | channels: Tuple[int, ...], 410 | n_blocks_per_resolution: Tuple[int, ...], 411 | has_resampling: Tuple[bool, ...], 412 | dropout: Tuple[float, ...], 413 | block_fn: Callable[[], nn.Module], 414 | ) -> nn.ModuleList: 415 | blocks = nn.ModuleList() 416 | 417 | channel_pairs = list(zip(channels[:-1], channels[1:]))[::-1] 418 | for idx, (out_channels, in_channels) in enumerate(channel_pairs): 419 | if has_resampling[::-1][idx]: 420 | blocks.append(Upsample(in_channels)) 421 | 422 | inner_blocks = [] 423 | for _ in range(n_blocks_per_resolution[::-1][idx]): 424 | inner_blocks.append( 425 | block_fn(in_channels * 2, out_channels, dropout[::-1][idx]) 426 | ) 427 | out_channels = in_channels 428 | blocks.extend(inner_blocks[::-1]) 429 | 430 | return blocks 431 | 432 | def _make_top_block( 433 | self, in_channels: int, out_channels: int, dropout: float 434 | ) -> UNetBlock: 435 | return UNetBlock( 436 | in_channels, 437 | out_channels, 438 | self.config.noise_level_channels, 439 | dropout, 440 | ) 441 | 442 | def _make_mid_block( 443 | self, 444 | in_channels: int, 445 | out_channels: int, 446 | dropout: float, 447 | ) -> UNetBlockWithSelfAttention: 448 | return UNetBlockWithSelfAttention( 449 | in_channels, 450 | out_channels, 451 | self.config.noise_level_channels, 452 | self.config.n_heads, 453 | dropout, 454 | ) 455 | 456 | def save_pretrained(self, pretrained_path: str) -> None: 457 | os.makedirs(pretrained_path, exist_ok=True) 458 | 459 | with open(os.path.join(pretrained_path, "config.json"), mode="w") as f: 460 | json.dump(asdict(self.config), f) 461 | 462 | torch.save(self.state_dict(), os.path.join(pretrained_path, "model.pt")) 463 | 464 | @classmethod 465 | def from_pretrained(cls, pretrained_path: str) -> "UNet": 466 | with open(os.path.join(pretrained_path, "config.json"), mode="r") as f: 467 | config_dict = json.load(f) 468 | config = UNetConfig(**config_dict) 469 | 470 | model = cls(config) 471 | 472 | state_dict = torch.load( 473 | os.path.join(pretrained_path, "model.pt"), map_location=torch.device("cpu") 474 | ) 475 | model.load_state_dict(state_dict) 476 | 477 | return model 478 | 479 | @dataclass 480 | class LitImprovedConsistencyModelConfig: 481 | ema_decay_rate: float = 0.99993 482 | lr: float = 1e-4 483 | betas: Tuple[float, float] = (0.9, 0.995) 484 | lr_scheduler_start_factor: float = 1e-5 485 | lr_scheduler_iters: int = 10_000 486 | sample_every_n_steps: int = 10_000 487 | num_samples: int = 8 488 | sampling_sigmas: Tuple[Tuple[int, ...], ...] 
= ( 489 | (80,), 490 | (80.0, 0.661), 491 | (80.0, 24.4, 5.84, 0.9, 0.661), 492 | ) 493 | 494 | 495 | class LitImprovedConsistencyModel(LightningModule): 496 | def __init__( 497 | self, 498 | consistency_training: ImprovedConsistencyTraining, 499 | consistency_sampling: ConsistencySamplingAndEditing, 500 | model: UNet, 501 | ema_model: UNet, 502 | config: LitImprovedConsistencyModelConfig, 503 | ) -> None: 504 | super().__init__() 505 | 506 | self.consistency_training = consistency_training 507 | self.consistency_sampling = consistency_sampling 508 | self.model = model 509 | self.ema_model = ema_model 510 | self.config = config 511 | 512 | # Freeze the EMA model and set it to eval mode 513 | for param in self.ema_model.parameters(): 514 | param.requires_grad = False 515 | self.ema_model = self.ema_model.eval() 516 | 517 | def training_step(self, batch: Union[Tensor, List[Tensor]], batch_idx: int) -> None: 518 | # if isinstance(batch, list): 519 | # batch = batch[0] 520 | 521 | visible_images, infrared_images = batch # Unpack the batch 522 | 523 | output = self.consistency_training( 524 | self.model, 525 | infrared_images, 526 | visible_images, # Pass visible images to the training function 527 | self.global_step, 528 | self.trainer.max_steps 529 | ) 530 | 531 | loss = ( 532 | pseudo_huber_loss(output.predicted, output.target) * output.loss_weights 533 | ).mean() 534 | 535 | self.log_dict({"train_loss": loss, "num_timesteps": output.num_timesteps}) 536 | 537 | return loss 538 | 539 | def on_train_batch_end( 540 | self, outputs: Any, batch: Union[Tensor, List[Tensor]], batch_idx: int 541 | ) -> None: 542 | update_ema_model_(self.model, self.ema_model, self.config.ema_decay_rate) 543 | 544 | if ( 545 | (self.global_step + 1) % self.config.sample_every_n_steps == 0 546 | ) or self.global_step == 0: 547 | self.__sample_and_log_samples(batch) 548 | 549 | def configure_optimizers(self): 550 | opt = torch.optim.Adam( 551 | self.model.parameters(), lr=self.config.lr, betas=self.config.betas 552 | ) 553 | sched = torch.optim.lr_scheduler.LinearLR( 554 | opt, 555 | start_factor=self.config.lr_scheduler_start_factor, 556 | total_iters=self.config.lr_scheduler_iters, 557 | ) 558 | sched = {"scheduler": sched, "interval": "step", "frequency": 1} 559 | 560 | return [opt], [sched] 561 | 562 | @torch.no_grad() 563 | def __sample_and_log_samples(self, batch: Union[Tensor, List[Tensor]]) -> None: 564 | if isinstance(batch, list): 565 | batch = batch[0] 566 | 567 | # Ensure the number of samples does not exceed the batch size 568 | num_samples = min(self.config.num_samples, batch.shape[0]) 569 | noise = torch.randn_like(batch[:num_samples]) 570 | 571 | # Log ground truth samples 572 | self.__log_images( 573 | batch[:num_samples].detach().clone(), f"ground_truth", self.global_step 574 | ) 575 | 576 | for sigmas in self.config.sampling_sigmas: 577 | samples = self.consistency_sampling( 578 | self.ema_model, noise, sigmas, clip_denoised=True, verbose=True 579 | ) 580 | samples = samples.clamp(min=-1.0, max=1.0) 581 | 582 | # Generated samples 583 | self.__log_images( 584 | samples, 585 | f"generated_samples-sigmas={sigmas}", 586 | self.global_step, 587 | ) 588 | 589 | @torch.no_grad() 590 | def __log_images(self, images: Tensor, title: str, global_step: int) -> None: 591 | images = images.detach().float() 592 | 593 | grid = make_grid( 594 | images.clamp(-1.0, 1.0), value_range=(-1.0, 1.0), normalize=True 595 | ) 596 | self.logger.experiment.add_image(title, grid, global_step) 597 | 598 | 599 | @dataclass 600 | 
class TrainingConfig: 601 | image_dm_config: ImageDataModuleConfig 602 | unet_config: UNetConfig 603 | consistency_training: ImprovedConsistencyTraining 604 | consistency_sampling: ConsistencySamplingAndEditing 605 | lit_icm_config: LitImprovedConsistencyModelConfig 606 | trainer: Trainer 607 | model_ckpt_path: str = "checkpoints/llvip512x512_128x128" 608 | seed: int = 42 609 | 610 | def run_training(config: TrainingConfig) -> None: 611 | # Set seed 612 | seed_everything(config.seed) 613 | 614 | # Create data module 615 | dm = LLVIPDataModule(config.image_dm_config) 616 | dm.setup() 617 | print("DataModule setup complete.") 618 | 619 | # Create model and its EMA 620 | model = UNet(config.unet_config) 621 | ema_model = UNet(config.unet_config) 622 | ema_model.load_state_dict(model.state_dict()) 623 | 624 | # Create lightning module 625 | lit_icm = LitImprovedConsistencyModel( 626 | config.consistency_training, 627 | config.consistency_sampling, 628 | model, 629 | ema_model, 630 | config.lit_icm_config, 631 | ) 632 | 633 | print("Lightning module created.") 634 | 635 | # Run training 636 | print("Starting training...") 637 | config.trainer.fit(lit_icm, datamodule=dm) #add ckpt_path to load checkpoints 638 | print("Training completed.") 639 | 640 | # Save model 641 | lit_icm.model.save_pretrained(config.model_ckpt_path) 642 | print("Model saved.") 643 | 644 | # Main function 645 | def main(): 646 | # Define the checkpoint callback 647 | checkpoint_callback = ModelCheckpoint( 648 | dirpath="checkpoints_512x512_128x128", 649 | filename="{epoch}-{step}", 650 | save_top_k=-1, # Save all checkpoints 651 | every_n_epochs=20, # Adjust as needed 652 | ) 653 | 654 | # Set up the logger 655 | logger = TensorBoardLogger("logs", name="icm") 656 | 657 | training_config = TrainingConfig( 658 | image_dm_config = ImageDataModuleConfig(data_dir="../datasets/LLVIP"), 659 | unet_config=UNetConfig(), 660 | consistency_training=ImprovedConsistencyTraining(final_timesteps=11), 661 | consistency_sampling=ConsistencySamplingAndEditing(), 662 | lit_icm_config=LitImprovedConsistencyModelConfig( 663 | sample_every_n_steps=2100000, lr_scheduler_iters=1000 664 | ), 665 | trainer=Trainer( 666 | max_steps=200000, 667 | precision="16", 668 | log_every_n_steps=10, 669 | logger=logger, 670 | callbacks=[ 671 | LearningRateMonitor(logging_interval="step"), 672 | checkpoint_callback, # Add the checkpoint callback here 673 | ], 674 | ), 675 | ) 676 | run_training(training_config) 677 | 678 | if __name__ == "__main__": 679 | main() -------------------------------------------------------------------------------- /lolv1/sampling_and_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image, ImageOps 3 | import torchvision.transforms as T 4 | import torchvision.transforms.functional as TF 5 | from improved_consistency_model_conditional import ConsistencySamplingAndEditing 6 | from lolv1.script import UNet 7 | import os 8 | from skimage.metrics import peak_signal_noise_ratio as calculate_psnr 9 | from skimage.metrics import structural_similarity as calculate_ssim 10 | import numpy as np 11 | import random 12 | 13 | # ======================= 14 | # Configuration Parameters 15 | # ======================= 16 | 17 | # Path to the trained LOLv2 model 18 | model_path = "checkpoints/lolv1_128x128" # Update this path if different 19 | 20 | # Dataset paths 21 | visible_folder = "../datasets/lolv1/eval15/low/" # Visible (Low Exposure) images 22 | infrared_folder = 
"../datasets/lolv1/eval15/high/" # Infrared (Normal Exposure) images 23 | 24 | # Output directories 25 | output_folder = "lolv1_128x128" # Directory to save generated images 26 | metrics_file = "metrics_lolv1.txt" # File to save PSNR and SSIM metrics 27 | 28 | # Transformation after cropping and resizing 29 | transform = T.Compose([ 30 | T.ToTensor(), 31 | T.Lambda(lambda x: (x * 2) - 1), # Normalize to [-1, 1] 32 | ]) 33 | 34 | # Sigma schedule for noise addition (adjust based on training) 35 | sigmas = [80.0, 40.0, 20.0, 10.0, 5.0, 2.5, 1.25, 0.625, 0.3125, 0.15625, 0.078125] 36 | 37 | # Number of images to process (set to None to process all images) 38 | num_images_to_process = None # Set to an integer value to limit the number of images 39 | 40 | # Seed for reproducibility 41 | random_seed = 28 42 | 43 | # ======================= 44 | # Initialize Model and Device 45 | # ======================= 46 | 47 | # Load the trained UNet model 48 | model = UNet.from_pretrained(model_path) 49 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 50 | model = model.to(device).eval() 51 | 52 | # Initialize the consistency sampling instance 53 | consistency_sampling = ConsistencySamplingAndEditing() 54 | 55 | # ======================= 56 | # Utility Functions 57 | # ======================= 58 | 59 | def denormalize(tensor): 60 | """ 61 | Denormalize a tensor from [-1, 1] to [0, 1]. 62 | """ 63 | return (tensor + 1) / 2 64 | 65 | def calculate_metrics(reference, generated): 66 | """ 67 | Calculate PSNR and SSIM between two images. 68 | 69 | Args: 70 | reference (numpy.ndarray): Reference image array. 71 | generated (numpy.ndarray): Generated image array. 72 | 73 | Returns: 74 | tuple: PSNR and SSIM values. 75 | """ 76 | # Ensure the images are in the range [0, 1] 77 | reference = reference.astype(np.float32) / 255.0 78 | generated = generated.astype(np.float32) / 255.0 79 | 80 | # Calculate PSNR 81 | psnr_value = calculate_psnr(reference, generated, data_range=1.0) 82 | 83 | # Calculate SSIM 84 | ssim_value = calculate_ssim(reference, generated, data_range=1.0, multichannel=True, win_size=3) 85 | 86 | return psnr_value, ssim_value 87 | 88 | def get_infrared_image_name(visible_image_name): 89 | """ 90 | Given a visible image name like 'lowXXXX.png', return the corresponding 91 | infrared image name like 'normalXXXX.png'. 92 | 93 | Args: 94 | visible_image_name (str): Filename of the visible image. 95 | 96 | Returns: 97 | str: Corresponding infrared image filename. 98 | 99 | Raises: 100 | ValueError: If the visible image does not start with 'low'. 101 | """ 102 | if visible_image_name.lower().startswith('low'): 103 | return visible_image_name 104 | else: 105 | raise ValueError(f"Unexpected visible image prefix in {visible_image_name}") 106 | 107 | # ======================= 108 | # Main Processing Function 109 | # ======================= 110 | 111 | def process_image(visible_image_name): 112 | """ 113 | Process a single image pair: apply random crop, resize, generate denoised infrared, 114 | and calculate PSNR and SSIM. 115 | 116 | Args: 117 | visible_image_name (str): Filename of the visible image. 
118 | """ 119 | global total_psnr, total_ssim, num_processed_images 120 | 121 | # try: 122 | # Map to corresponding infrared image 123 | #infrared_image_name = get_infrared_image_name(visible_image_name) 124 | #except ValueError as ve: 125 | #print(f"Skipping {visible_image_name}: {ve}") 126 | #return 127 | 128 | # Construct full image paths 129 | visible_image_path = os.path.join(visible_folder, visible_image_name) 130 | infrared_image_path = os.path.join(infrared_folder, visible_image_name) 131 | 132 | # Check existence of infrared image 133 | if not os.path.exists(infrared_image_path): 134 | print(f"Infrared image {infrared_image_path} not found, skipping.") 135 | return 136 | 137 | # Load images 138 | try: 139 | visible_image = Image.open(visible_image_path).convert("RGB") 140 | infrared_image = Image.open(infrared_image_path).convert("RGB") 141 | except Exception as e: 142 | print(f"Error loading images {visible_image_name} and/or {infrared_image_name}: {e}, skipping.") 143 | return 144 | 145 | 146 | # Crop both images 147 | visible_cropped = visible_image 148 | infrared_cropped = infrared_image 149 | 150 | # Resize to 128x128 151 | # visible_resized = visible_cropped.resize((128, 128), Image.BICUBIC) 152 | # infrared_resized = infrared_cropped.resize((128, 128), Image.BICUBIC) 153 | 154 | # Apply transformations 155 | visible_tensor = transform(visible_cropped).unsqueeze(0).to(device) 156 | infrared_tensor = transform(infrared_cropped).unsqueeze(0).to(device) 157 | 158 | # Add noise to the infrared image 159 | max_sigma = sigmas[0] # Highest sigma value 160 | noise = torch.randn_like(infrared_tensor) * max_sigma 161 | noisy_infrared_tensor = noise 162 | 163 | # Generate the infrared image starting from the noisy infrared image 164 | try: 165 | with torch.no_grad(): 166 | generated_infrared_tensor = consistency_sampling( 167 | model=model, 168 | y=noisy_infrared_tensor, 169 | v=visible_tensor, 170 | sigmas=sigmas, 171 | start_from_y=True, 172 | add_initial_noise=False, 173 | clip_denoised=True, 174 | verbose=False, # Set verbose=False to reduce output 175 | ) 176 | except Exception as e: 177 | print(f"Error during model inference for {visible_image_name}: {e}, skipping.") 178 | return 179 | 180 | # Denormalize tensors 181 | generated_infrared_denorm = denormalize(generated_infrared_tensor.squeeze(0).cpu()) 182 | 183 | # Convert tensors to PIL images 184 | generated_infrared_pil = TF.to_pil_image(generated_infrared_denorm) 185 | 186 | # Reference infrared image (already resized to 128x128) 187 | reference_infrared_pil = infrared_cropped 188 | 189 | # Convert images to numpy arrays for metric calculation 190 | reference_image_np = np.array(reference_infrared_pil) 191 | generated_image_np = np.array(generated_infrared_pil) 192 | 193 | # Calculate PSNR and SSIM 194 | psnr_value, ssim_value = calculate_metrics(reference_image_np, generated_image_np) 195 | 196 | # Accumulate PSNR and SSIM 197 | total_psnr += psnr_value 198 | total_ssim += ssim_value 199 | num_processed_images += 1 200 | 201 | # Print PSNR and SSIM for the current image 202 | print(f"Image : PSNR = {psnr_value:.2f}, SSIM = {ssim_value:.4f}") 203 | 204 | # Save the generated infrared image for visual inspection 205 | output_filename_infrared = f"generated_infrared_{visible_image_name}" 206 | output_path_infrared = os.path.join(output_folder, output_filename_infrared) 207 | generated_infrared_pil.save(output_path_infrared) 208 | print(f"Saved generated infrared image to {output_path_infrared}\n") 209 | 210 | # 
======================= 211 | # Processing All Images 212 | # ======================= 213 | 214 | def main(): 215 | global total_psnr, total_ssim, num_processed_images 216 | total_psnr = 0.0 217 | total_ssim = 0.0 218 | num_processed_images = 0 219 | 220 | # Get a list of all images in the visible folder 221 | visible_images = os.listdir(visible_folder) 222 | image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff') 223 | visible_images = [img for img in visible_images if img.lower().endswith(image_extensions)] 224 | 225 | # Optionally limit the number of images to process 226 | if num_images_to_process is not None: 227 | selected_visible_images = random.sample(visible_images, min(num_images_to_process, len(visible_images))) 228 | else: 229 | selected_visible_images = visible_images 230 | 231 | print(f"Processing {len(selected_visible_images)} images...\n") 232 | 233 | for idx, visible_image_name in enumerate(selected_visible_images, start=1): 234 | print(f"Processing image {idx}/{len(selected_visible_images)}: {visible_image_name}") 235 | process_image(visible_image_name) 236 | 237 | # Calculate and print average PSNR and SSIM 238 | if num_processed_images > 0: 239 | average_psnr = total_psnr / num_processed_images 240 | average_ssim = total_ssim / num_processed_images 241 | print(f"\nProcessed {num_processed_images} images.") 242 | print(f"Average PSNR: {average_psnr:.2f}") 243 | print(f"Average SSIM: {average_ssim:.4f}") 244 | 245 | # Save metrics to a text file 246 | with open(metrics_file, "a") as f: 247 | f.write(f"Processed {num_processed_images} images.\n") 248 | f.write(f"Average PSNR: {average_psnr:.2f}\n") 249 | f.write(f"Average SSIM: {average_ssim:.4f}\n\n") 250 | print(f"Saved metrics to {metrics_file}") 251 | else: 252 | print("No images were processed.") 253 | 254 | if __name__ == "__main__": 255 | main() 256 | -------------------------------------------------------------------------------- /lolv1/script.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import asdict, dataclass 4 | from typing import Any, Callable, List, Optional, Tuple, Union 5 | 6 | import torch 7 | from torchvision.transforms import functional as TF 8 | import math 9 | from einops import rearrange 10 | from einops.layers.torch import Rearrange 11 | from lightning import LightningDataModule, LightningModule, Trainer, seed_everything 12 | from lightning.pytorch.callbacks import LearningRateMonitor 13 | from lightning.pytorch.loggers import TensorBoardLogger 14 | from matplotlib import pyplot as plt 15 | from torch import Tensor, nn 16 | from torch.nn import functional as F 17 | from torch.utils.data import DataLoader 18 | from torchinfo import summary 19 | from torchvision import transforms as T 20 | from torchvision.datasets import ImageFolder 21 | from torchvision.utils import make_grid 22 | 23 | from improved_consistency_model_conditional import ( 24 | ConsistencySamplingAndEditing, 25 | ImprovedConsistencyTraining, 26 | pseudo_huber_loss, 27 | update_ema_model_, 28 | ) 29 | 30 | 31 | from torch.utils.data import Dataset 32 | from torchvision import transforms as T 33 | import os 34 | from PIL import Image 35 | import torch 36 | from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint 37 | 38 | class PairedDataset(Dataset): 39 | def __init__( 40 | self, 41 | visible_dir: str, 42 | infrared_dir: str, 43 | transform: Optional[Callable] = None, 44 | crop_size: Tuple[int, int] = (128, 128), 45 | 
resize_size: Tuple[int, int] = (128, 128) 46 | ): 47 | self.visible_dir = visible_dir 48 | self.infrared_dir = infrared_dir 49 | self.visible_images = sorted(os.listdir(visible_dir)) 50 | self.infrared_images = sorted(os.listdir(infrared_dir)) 51 | self.transform = transform 52 | self.crop_size = crop_size 53 | self.resize_size = resize_size 54 | 55 | def __len__(self) -> int: 56 | return len(self.visible_images) 57 | 58 | def __getitem__(self, index: int) -> Optional[Tuple[Tensor, Tensor]]: 59 | visible_path = os.path.join(self.visible_dir, self.visible_images[index]) 60 | infrared_path = os.path.join(self.infrared_dir, self.infrared_images[index]) 61 | 62 | visible_image = Image.open(visible_path).convert("RGB") 63 | infrared_image = Image.open(infrared_path).convert("RGB") 64 | 65 | if visible_image.size != infrared_image.size: 66 | print(f"Skipping image pair at index {index} due to mismatched sizes") 67 | return None 68 | 69 | # Perform synchronized random horizontal flip 70 | if torch.rand(1).item() > 0.5: 71 | visible_image = TF.hflip(visible_image) 72 | infrared_image = TF.hflip(infrared_image) 73 | 74 | # Perform synchronized random crop 75 | i, j, h, w = T.RandomCrop.get_params(visible_image, output_size=self.crop_size) 76 | visible_image = TF.crop(visible_image, i, j, h, w) 77 | infrared_image = TF.crop(infrared_image, i, j, h, w) 78 | 79 | if self.transform: 80 | visible_image = self.transform(visible_image) 81 | infrared_image = self.transform(infrared_image) 82 | 83 | return visible_image, infrared_image 84 | 85 | from dataclasses import dataclass 86 | from typing import Tuple 87 | 88 | @dataclass 89 | class ImageDataModuleConfig: 90 | data_dir: str = "datasets/lolv1" # Path to the dataset directory 91 | image_size_crop: Tuple[int, int] = (128, 128) # Size for random cropping 92 | image_size_resize: Tuple[int, int] = (128, 128) # Resize to 128x128 93 | batch_size: int = 34 # Number of images in each batch 94 | num_workers: int = 6 # Number of worker threads for data loading 95 | pin_memory: bool = True # Whether to pin memory in data loader 96 | persistent_workers: bool = True # Keep workers alive between epochs 97 | 98 | from torch.utils.data import DataLoader 99 | from lightning.pytorch import LightningDataModule 100 | 101 | class LLVIPDataModule(LightningDataModule): 102 | def __init__(self, config: ImageDataModuleConfig) -> None: 103 | super().__init__() 104 | self.config = config 105 | 106 | def setup(self, stage: str = None) -> None: 107 | # Define transforms excluding cropping and resizing 108 | self.transform = T.Compose([ 109 | T.ToTensor(), 110 | T.Lambda(lambda x: (x * 2) - 1), # Normalize to [-1, 1] 111 | ]) 112 | 113 | self.dataset = PairedDataset( 114 | visible_dir=os.path.join(self.config.data_dir, "our485/low"), 115 | infrared_dir=os.path.join(self.config.data_dir, "our485/high"), 116 | transform=self.transform, 117 | crop_size=self.config.image_size_crop, 118 | resize_size=self.config.image_size_resize 119 | ) 120 | 121 | def train_dataloader(self) -> DataLoader: 122 | return DataLoader( 123 | self.dataset, 124 | batch_size=self.config.batch_size, 125 | shuffle=True, 126 | num_workers=self.config.num_workers, 127 | pin_memory=self.config.pin_memory, 128 | persistent_workers=self.config.persistent_workers, 129 | ) 130 | 131 | 132 | def GroupNorm(channels: int) -> nn.GroupNorm: 133 | return nn.GroupNorm(num_groups=min(32, channels // 4), num_channels=channels) 134 | 135 | 136 | class SelfAttention(nn.Module): 137 | def __init__( 138 | self, 139 | in_channels: 
int, 140 | out_channels: int, 141 | n_heads: int = 8, 142 | dropout: float = 0.3, 143 | ) -> None: 144 | super().__init__() 145 | 146 | self.dropout = dropout 147 | 148 | self.qkv_projection = nn.Sequential( 149 | GroupNorm(in_channels), 150 | nn.Conv2d(in_channels, 3 * in_channels, kernel_size=1, bias=False), 151 | Rearrange("b (i h d) x y -> i b h (x y) d", i=3, h=n_heads), 152 | ) 153 | self.output_projection = nn.Sequential( 154 | Rearrange("b h l d -> b l (h d)"), 155 | nn.Linear(in_channels, out_channels, bias=False), 156 | Rearrange("b l d -> b d l"), 157 | GroupNorm(out_channels), 158 | nn.Dropout1d(dropout), 159 | ) 160 | self.residual_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1) 161 | 162 | def forward(self, x: Tensor) -> Tensor: 163 | q, k, v = self.qkv_projection(x).unbind(dim=0) 164 | 165 | output = F.scaled_dot_product_attention( 166 | q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=False 167 | ) 168 | output = self.output_projection(output) 169 | output = rearrange(output, "b c (x y) -> b c x y", x=x.shape[-2], y=x.shape[-1]) 170 | 171 | return output + self.residual_projection(x) 172 | 173 | 174 | class UNetBlock(nn.Module): 175 | def __init__( 176 | self, 177 | in_channels: int, 178 | out_channels: int, 179 | noise_level_channels: int, 180 | dropout: float = 0.3, 181 | ) -> None: 182 | super().__init__() 183 | 184 | self.input_projection = nn.Sequential( 185 | GroupNorm(in_channels), 186 | nn.SiLU(), 187 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding="same"), 188 | nn.Dropout2d(dropout), 189 | ) 190 | self.noise_level_projection = nn.Sequential( 191 | nn.SiLU(), 192 | nn.Conv2d(noise_level_channels, out_channels, kernel_size=1), 193 | ) 194 | self.output_projection = nn.Sequential( 195 | GroupNorm(out_channels), 196 | nn.SiLU(), 197 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding="same"), 198 | nn.Dropout2d(dropout), 199 | ) 200 | self.residual_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1) 201 | 202 | def forward(self, x: Tensor, noise_level: Tensor) -> Tensor: 203 | h = self.input_projection(x) 204 | h = h + self.noise_level_projection(noise_level) 205 | 206 | return self.output_projection(h) + self.residual_projection(x) 207 | 208 | 209 | class UNetBlockWithSelfAttention(nn.Module): 210 | def __init__( 211 | self, 212 | in_channels: int, 213 | out_channels: int, 214 | noise_level_channels: int, 215 | n_heads: int = 8, 216 | dropout: float = 0.3, 217 | ) -> None: 218 | super().__init__() 219 | 220 | self.unet_block = UNetBlock( 221 | in_channels, out_channels, noise_level_channels, dropout 222 | ) 223 | self.self_attention = SelfAttention( 224 | out_channels, out_channels, n_heads, dropout 225 | ) 226 | 227 | def forward(self, x: Tensor, noise_level: Tensor) -> Tensor: 228 | return self.self_attention(self.unet_block(x, noise_level)) 229 | 230 | 231 | class Downsample(nn.Module): 232 | def __init__(self, channels: int) -> None: 233 | super().__init__() 234 | 235 | self.projection = nn.Sequential( 236 | Rearrange("b c (h ph) (w pw) -> b (c ph pw) h w", ph=2, pw=2), 237 | nn.Conv2d(4 * channels, channels, kernel_size=1), 238 | ) 239 | 240 | def forward(self, x: Tensor) -> Tensor: 241 | return self.projection(x) 242 | 243 | 244 | class Upsample(nn.Module): 245 | def __init__(self, channels: int) -> None: 246 | super().__init__() 247 | 248 | self.projection = nn.Sequential( 249 | nn.Upsample(scale_factor=2.0, mode="nearest"), 250 | nn.Conv2d(channels, channels, kernel_size=3, 
padding="same"), 251 | ) 252 | 253 | def forward(self, x: Tensor) -> Tensor: 254 | return self.projection(x) 255 | 256 | 257 | class NoiseLevelEmbedding(nn.Module): 258 | def __init__(self, channels: int, scale: float = 0.02) -> None: 259 | super().__init__() 260 | 261 | self.W = nn.Parameter(torch.randn(channels // 2) * scale, requires_grad=False) 262 | 263 | self.projection = nn.Sequential( 264 | nn.Linear(channels, 4 * channels), 265 | nn.SiLU(), 266 | nn.Linear(4 * channels, channels), 267 | Rearrange("b c -> b c () ()"), 268 | ) 269 | 270 | def forward(self, x: Tensor) -> Tensor: 271 | h = x[:, None] * self.W[None, :] * 2 * torch.pi 272 | h = torch.cat([torch.sin(h), torch.cos(h)], dim=-1) 273 | 274 | return self.projection(h) 275 | 276 | 277 | @dataclass 278 | class UNetConfig: 279 | channels: int = 3 280 | noise_level_channels: int = 256 281 | noise_level_scale: float = 0.02 282 | n_heads: int = 8 283 | top_blocks_channels: Tuple[int, ...] = (128, 128) 284 | top_blocks_n_blocks_per_resolution: Tuple[int, ...] = (2, 2) 285 | top_blocks_has_resampling: Tuple[bool, ...] = (True, True) 286 | top_blocks_dropout: Tuple[float, ...] = (0.0, 0.0) 287 | mid_blocks_channels: Tuple[int, ...] = (256, 512) 288 | mid_blocks_n_blocks_per_resolution: Tuple[int, ...] = (4, 4) 289 | mid_blocks_has_resampling: Tuple[bool, ...] = (True, False) 290 | mid_blocks_dropout: Tuple[float, ...] = (0.0, 0.3) 291 | 292 | 293 | class UNet(nn.Module): 294 | def __init__(self, config: UNetConfig) -> None: 295 | super().__init__() 296 | 297 | self.config = config 298 | 299 | self.input_projection = nn.Conv2d( 300 | config.channels * 2, 301 | config.top_blocks_channels[0], 302 | kernel_size=3, 303 | padding="same", 304 | ) 305 | self.noise_level_embedding = NoiseLevelEmbedding( 306 | config.noise_level_channels, config.noise_level_scale 307 | ) 308 | self.top_encoder_blocks = self._make_encoder_blocks( 309 | self.config.top_blocks_channels + self.config.mid_blocks_channels[:1], 310 | self.config.top_blocks_n_blocks_per_resolution, 311 | self.config.top_blocks_has_resampling, 312 | self.config.top_blocks_dropout, 313 | self._make_top_block, 314 | ) 315 | self.mid_encoder_blocks = self._make_encoder_blocks( 316 | self.config.mid_blocks_channels + self.config.mid_blocks_channels[-1:], 317 | self.config.mid_blocks_n_blocks_per_resolution, 318 | self.config.mid_blocks_has_resampling, 319 | self.config.mid_blocks_dropout, 320 | self._make_mid_block, 321 | ) 322 | self.mid_decoder_blocks = self._make_decoder_blocks( 323 | self.config.mid_blocks_channels + self.config.mid_blocks_channels[-1:], 324 | self.config.mid_blocks_n_blocks_per_resolution, 325 | self.config.mid_blocks_has_resampling, 326 | self.config.mid_blocks_dropout, 327 | self._make_mid_block, 328 | ) 329 | self.top_decoder_blocks = self._make_decoder_blocks( 330 | self.config.top_blocks_channels + self.config.mid_blocks_channels[:1], 331 | self.config.top_blocks_n_blocks_per_resolution, 332 | self.config.top_blocks_has_resampling, 333 | self.config.top_blocks_dropout, 334 | self._make_top_block, 335 | ) 336 | self.output_projection = nn.Conv2d( 337 | config.top_blocks_channels[0], 338 | config.channels, 339 | kernel_size=3, 340 | padding="same", 341 | ) 342 | 343 | def forward(self, x: Tensor, noise_level: Tensor, v: Tensor) -> Tensor: 344 | x = torch.cat([x, v], dim = 1) 345 | h = self.input_projection(x) 346 | noise_level = self.noise_level_embedding(noise_level) 347 | 348 | top_encoder_embeddings = [] 349 | for block in self.top_encoder_blocks: 350 | if 
isinstance(block, UNetBlock): 351 | h = block(h, noise_level) 352 | top_encoder_embeddings.append(h) 353 | else: 354 | h = block(h) 355 | 356 | mid_encoder_embeddings = [] 357 | for block in self.mid_encoder_blocks: 358 | if isinstance(block, UNetBlockWithSelfAttention): 359 | h = block(h, noise_level) 360 | mid_encoder_embeddings.append(h) 361 | else: 362 | h = block(h) 363 | 364 | for block in self.mid_decoder_blocks: 365 | if isinstance(block, UNetBlockWithSelfAttention): 366 | h = torch.cat((h, mid_encoder_embeddings.pop()), dim=1) 367 | h = block(h, noise_level) 368 | else: 369 | h = block(h) 370 | 371 | for block in self.top_decoder_blocks: 372 | if isinstance(block, UNetBlock): 373 | h = torch.cat((h, top_encoder_embeddings.pop()), dim=1) 374 | h = block(h, noise_level) 375 | else: 376 | h = block(h) 377 | 378 | output = self.output_projection(h) 379 | 380 | return output 381 | 382 | def _make_encoder_blocks( 383 | self, 384 | channels: Tuple[int, ...], 385 | n_blocks_per_resolution: Tuple[int, ...], 386 | has_resampling: Tuple[bool, ...], 387 | dropout: Tuple[float, ...], 388 | block_fn: Callable[[], nn.Module], 389 | ) -> nn.ModuleList: 390 | blocks = nn.ModuleList() 391 | 392 | channel_pairs = list(zip(channels[:-1], channels[1:])) 393 | for idx, (in_channels, out_channels) in enumerate(channel_pairs): 394 | for _ in range(n_blocks_per_resolution[idx]): 395 | blocks.append(block_fn(in_channels, out_channels, dropout[idx])) 396 | in_channels = out_channels 397 | 398 | if has_resampling[idx]: 399 | blocks.append(Downsample(out_channels)) 400 | 401 | return blocks 402 | 403 | def _make_decoder_blocks( 404 | self, 405 | channels: Tuple[int, ...], 406 | n_blocks_per_resolution: Tuple[int, ...], 407 | has_resampling: Tuple[bool, ...], 408 | dropout: Tuple[float, ...], 409 | block_fn: Callable[[], nn.Module], 410 | ) -> nn.ModuleList: 411 | blocks = nn.ModuleList() 412 | 413 | channel_pairs = list(zip(channels[:-1], channels[1:]))[::-1] 414 | for idx, (out_channels, in_channels) in enumerate(channel_pairs): 415 | if has_resampling[::-1][idx]: 416 | blocks.append(Upsample(in_channels)) 417 | 418 | inner_blocks = [] 419 | for _ in range(n_blocks_per_resolution[::-1][idx]): 420 | inner_blocks.append( 421 | block_fn(in_channels * 2, out_channels, dropout[::-1][idx]) 422 | ) 423 | out_channels = in_channels 424 | blocks.extend(inner_blocks[::-1]) 425 | 426 | return blocks 427 | 428 | def _make_top_block( 429 | self, in_channels: int, out_channels: int, dropout: float 430 | ) -> UNetBlock: 431 | return UNetBlock( 432 | in_channels, 433 | out_channels, 434 | self.config.noise_level_channels, 435 | dropout, 436 | ) 437 | 438 | def _make_mid_block( 439 | self, 440 | in_channels: int, 441 | out_channels: int, 442 | dropout: float, 443 | ) -> UNetBlockWithSelfAttention: 444 | return UNetBlockWithSelfAttention( 445 | in_channels, 446 | out_channels, 447 | self.config.noise_level_channels, 448 | self.config.n_heads, 449 | dropout, 450 | ) 451 | 452 | def save_pretrained(self, pretrained_path: str) -> None: 453 | os.makedirs(pretrained_path, exist_ok=True) 454 | 455 | with open(os.path.join(pretrained_path, "config.json"), mode="w") as f: 456 | json.dump(asdict(self.config), f) 457 | 458 | torch.save(self.state_dict(), os.path.join(pretrained_path, "model.pt")) 459 | 460 | @classmethod 461 | def from_pretrained(cls, pretrained_path: str) -> "UNet": 462 | with open(os.path.join(pretrained_path, "config.json"), mode="r") as f: 463 | config_dict = json.load(f) 464 | config = UNetConfig(**config_dict) 
465 | 466 | model = cls(config) 467 | 468 | state_dict = torch.load( 469 | os.path.join(pretrained_path, "model.pt"), map_location=torch.device("cpu") 470 | ) 471 | model.load_state_dict(state_dict) 472 | 473 | return model 474 | 475 | 476 | # summary(UNet(UNetConfig()), input_size=((1, 6, 32, 32), (1,))) 477 | 478 | 479 | @dataclass 480 | class LitImprovedConsistencyModelConfig: 481 | ema_decay_rate: float = 0.99993 482 | lr: float = 1e-4 483 | betas: Tuple[float, float] = (0.9, 0.995) 484 | lr_scheduler_start_factor: float = 1e-5 485 | lr_scheduler_iters: int = 10_000 486 | sample_every_n_steps: int = 10_000 487 | num_samples: int = 8 488 | sampling_sigmas: Tuple[Tuple[int, ...], ...] = ( 489 | (80,), 490 | (80.0, 0.661), 491 | (80.0, 24.4, 5.84, 0.9, 0.661), 492 | ) 493 | 494 | 495 | class LitImprovedConsistencyModel(LightningModule): 496 | def __init__( 497 | self, 498 | consistency_training: ImprovedConsistencyTraining, 499 | consistency_sampling: ConsistencySamplingAndEditing, 500 | model: UNet, 501 | ema_model: UNet, 502 | config: LitImprovedConsistencyModelConfig, 503 | ) -> None: 504 | super().__init__() 505 | 506 | self.consistency_training = consistency_training 507 | self.consistency_sampling = consistency_sampling 508 | self.model = model 509 | self.ema_model = ema_model 510 | self.config = config 511 | 512 | # Freeze the EMA model and set it to eval mode 513 | for param in self.ema_model.parameters(): 514 | param.requires_grad = False 515 | self.ema_model = self.ema_model.eval() 516 | 517 | def training_step(self, batch: Union[Tensor, List[Tensor]], batch_idx: int) -> None: 518 | # if isinstance(batch, list): 519 | # batch = batch[0] 520 | 521 | visible_images, infrared_images = batch # Unpack the batch 522 | 523 | output = self.consistency_training( 524 | self.model, 525 | infrared_images, 526 | visible_images, # Pass visible images to the training function 527 | self.global_step, 528 | self.trainer.max_steps 529 | ) 530 | 531 | loss = ( 532 | pseudo_huber_loss(output.predicted, output.target) * output.loss_weights 533 | ).mean() 534 | 535 | self.log_dict({"train_loss": loss, "num_timesteps": output.num_timesteps}) 536 | 537 | return loss 538 | 539 | def on_train_batch_end( 540 | self, outputs: Any, batch: Union[Tensor, List[Tensor]], batch_idx: int 541 | ) -> None: 542 | update_ema_model_(self.model, self.ema_model, self.config.ema_decay_rate) 543 | 544 | if ( 545 | (self.global_step + 1) % self.config.sample_every_n_steps == 0 546 | ) or self.global_step == 0: 547 | self.__sample_and_log_samples(batch) 548 | 549 | def configure_optimizers(self): 550 | opt = torch.optim.Adam( 551 | self.model.parameters(), lr=self.config.lr, betas=self.config.betas 552 | ) 553 | sched = torch.optim.lr_scheduler.LinearLR( 554 | opt, 555 | start_factor=self.config.lr_scheduler_start_factor, 556 | total_iters=self.config.lr_scheduler_iters, 557 | ) 558 | sched = {"scheduler": sched, "interval": "step", "frequency": 1} 559 | 560 | return [opt], [sched] 561 | 562 | @torch.no_grad() 563 | def __sample_and_log_samples(self, batch: Union[Tensor, List[Tensor]]) -> None: 564 | if isinstance(batch, list): 565 | batch = batch[0] 566 | 567 | # Ensure the number of samples does not exceed the batch size 568 | num_samples = min(self.config.num_samples, batch.shape[0]) 569 | noise = torch.randn_like(batch[:num_samples]) 570 | 571 | # Log ground truth samples 572 | self.__log_images( 573 | batch[:num_samples].detach().clone(), f"ground_truth", self.global_step 574 | ) 575 | 576 | for sigmas in 
self.config.sampling_sigmas: 577 | samples = self.consistency_sampling( 578 | self.ema_model, noise, sigmas, clip_denoised=True, verbose=True 579 | ) 580 | samples = samples.clamp(min=-1.0, max=1.0) 581 | 582 | # Generated samples 583 | self.__log_images( 584 | samples, 585 | f"generated_samples-sigmas={sigmas}", 586 | self.global_step, 587 | ) 588 | 589 | @torch.no_grad() 590 | def __log_images(self, images: Tensor, title: str, global_step: int) -> None: 591 | images = images.detach().float() 592 | 593 | grid = make_grid( 594 | images.clamp(-1.0, 1.0), value_range=(-1.0, 1.0), normalize=True 595 | ) 596 | self.logger.experiment.add_image(title, grid, global_step) 597 | 598 | 599 | @dataclass 600 | class TrainingConfig: 601 | image_dm_config: ImageDataModuleConfig 602 | unet_config: UNetConfig 603 | consistency_training: ImprovedConsistencyTraining 604 | consistency_sampling: ConsistencySamplingAndEditing 605 | lit_icm_config: LitImprovedConsistencyModelConfig 606 | trainer: Trainer 607 | model_ckpt_path: str = "checkpoints/lolv1_128x128" 608 | seed: int = 42 609 | 610 | def run_training(config: TrainingConfig) -> None: 611 | # Set seed 612 | seed_everything(config.seed) 613 | 614 | # Create data module 615 | dm = LLVIPDataModule(config.image_dm_config) 616 | dm.setup() 617 | print("DataModule setup complete.") 618 | 619 | # Create model and its EMA 620 | model = UNet(config.unet_config) 621 | ema_model = UNet(config.unet_config) 622 | ema_model.load_state_dict(model.state_dict()) 623 | 624 | # Create lightning module 625 | lit_icm = LitImprovedConsistencyModel( 626 | config.consistency_training, 627 | config.consistency_sampling, 628 | model, 629 | ema_model, 630 | config.lit_icm_config, 631 | ) 632 | 633 | print("Lightning module created.") 634 | 635 | # Run training 636 | print("Starting training...") 637 | config.trainer.fit(lit_icm, datamodule=dm) 638 | print("Training completed.") 639 | 640 | # Save model 641 | lit_icm.model.save_pretrained(config.model_ckpt_path) 642 | print("Model saved.") 643 | 644 | # Main function 645 | def main(): 646 | # Define the checkpoint callback 647 | checkpoint_callback = ModelCheckpoint( 648 | dirpath="checkpoints_lolv1", 649 | filename="{epoch}-{step}", 650 | save_top_k=-1, # Save all checkpoints 651 | every_n_epochs=100, # Adjust as needed 652 | ) 653 | 654 | # Set up the logger 655 | logger = TensorBoardLogger("logs", name="icm") 656 | 657 | training_config = TrainingConfig( 658 | image_dm_config=ImageDataModuleConfig(data_dir="../datasets/lolv1"), 659 | unet_config=UNetConfig(), 660 | consistency_training=ImprovedConsistencyTraining(final_timesteps=11), 661 | consistency_sampling=ConsistencySamplingAndEditing(), 662 | lit_icm_config=LitImprovedConsistencyModelConfig( 663 | sample_every_n_steps=2100000, lr_scheduler_iters=1000 664 | ), 665 | trainer=Trainer( 666 | max_steps=200000, 667 | precision="16", 668 | log_every_n_steps=10, 669 | logger=logger, 670 | callbacks=[ 671 | LearningRateMonitor(logging_interval="step"), 672 | checkpoint_callback, # Add the checkpoint callback here 673 | ], 674 | ), 675 | ) 676 | run_training(training_config) 677 | 678 | if __name__ == "__main__": 679 | main() -------------------------------------------------------------------------------- /lolv2/sampling_and_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image, ImageOps 3 | import torchvision.transforms as T 4 | import torchvision.transforms.functional as TF 5 | from 
improved_consistency_model_conditional import ConsistencySamplingAndEditing 6 | from lolv2.script import UNet 7 | import os 8 | from skimage.metrics import peak_signal_noise_ratio as calculate_psnr 9 | from skimage.metrics import structural_similarity as calculate_ssim 10 | import numpy as np 11 | import random 12 | 13 | # ======================= 14 | # Configuration Parameters 15 | # ======================= 16 | 17 | # Path to the trained LOLv2 model 18 | model_path = "checkpoints/lolv2_real/" # Update this path if different 19 | 20 | # Dataset paths 21 | visible_folder = "../datasets/LOL-v2/Real_captured/Test/Low/" # Visible (Low Exposure) images 22 | infrared_folder = "../datasets/LOL-v2/Real_captured/Test/Normal/" # Infrared (Normal Exposure) images 23 | 24 | # Output directories 25 | output_folder = "results_lolv2_real" # Directory to save generated images 26 | metrics_file = "metrics_lolv2_real.txt" # File to save PSNR and SSIM metrics 27 | 28 | # Transformation after cropping and resizing 29 | transform = T.Compose([ 30 | T.ToTensor(), 31 | T.Lambda(lambda x: (x * 2) - 1), # Normalize to [-1, 1] 32 | ]) 33 | 34 | # Sigma schedule for noise addition (adjust based on training) 35 | sigmas = [80.0, 40.0, 20.0, 10.0, 5.0, 2.5, 1.25, 0.625, 0.3125, 0.15625, 0.078125] 36 | 37 | # Number of images to process (set to None to process all images) 38 | num_images_to_process = None # Set to an integer value to limit the number of images 39 | 40 | # Seed for reproducibility 41 | random_seed = 28 42 | 43 | # ======================= 44 | # Initialize Model and Device 45 | # ======================= 46 | 47 | # Load the trained UNet model 48 | model = UNet.from_pretrained(model_path) 49 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 50 | model = model.to(device).eval() 51 | 52 | # Initialize the consistency sampling instance 53 | consistency_sampling = ConsistencySamplingAndEditing() 54 | 55 | # ======================= 56 | # Utility Functions 57 | # ======================= 58 | 59 | def denormalize(tensor): 60 | """ 61 | Denormalize a tensor from [-1, 1] to [0, 1]. 62 | """ 63 | return (tensor + 1) / 2 64 | 65 | def calculate_metrics(reference, generated): 66 | """ 67 | Calculate PSNR and SSIM between two images. 68 | 69 | Args: 70 | reference (numpy.ndarray): Reference image array. 71 | generated (numpy.ndarray): Generated image array. 72 | 73 | Returns: 74 | tuple: PSNR and SSIM values. 75 | """ 76 | # Ensure the images are in the range [0, 1] 77 | reference = reference.astype(np.float32) / 255.0 78 | generated = generated.astype(np.float32) / 255.0 79 | 80 | # Calculate PSNR 81 | psnr_value = calculate_psnr(reference, generated, data_range=1.0) 82 | 83 | # Calculate SSIM 84 | ssim_value = calculate_ssim(reference, generated, data_range=1.0, multichannel=True, win_size=3) 85 | 86 | return psnr_value, ssim_value 87 | 88 | def get_infrared_image_name(visible_image_name): 89 | """ 90 | Given a visible image name like 'lowXXXX.png', return the corresponding 91 | infrared image name like 'normalXXXX.png'. 92 | 93 | Args: 94 | visible_image_name (str): Filename of the visible image. 95 | 96 | Returns: 97 | str: Corresponding infrared image filename. 98 | 99 | Raises: 100 | ValueError: If the visible image does not start with 'low'. 
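    Example (illustrative filename):
        >>> get_infrared_image_name("low00690.png")
        'normal00690.png'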
101 | """ 102 | prefix = 'low' 103 | replacement = 'normal' 104 | 105 | if visible_image_name.lower().startswith(prefix): 106 | # Find the length of the prefix to replace it accurately 107 | prefix_length = len(prefix) 108 | # Replace the prefix with 'normal', preserving the original case after the prefix 109 | infrared_image_name = replacement + visible_image_name[prefix_length:] 110 | return infrared_image_name 111 | else: 112 | raise ValueError(f"Unexpected visible image prefix in '{visible_image_name}'. Expected to start with '{prefix}'.") 113 | 114 | # ======================= 115 | # Main Processing Function 116 | # ======================= 117 | 118 | def process_image(visible_image_name): 119 | """ 120 | Process a single image pair: apply random crop, resize, generate denoised infrared, 121 | and calculate PSNR and SSIM. 122 | 123 | Args: 124 | visible_image_name (str): Filename of the visible image. 125 | """ 126 | global total_psnr, total_ssim, num_processed_images 127 | 128 | try: 129 | infrared_image_name = get_infrared_image_name(visible_image_name) 130 | except ValueError as ve: 131 | print(f"Skipping {visible_image_name}: {ve}") 132 | return 133 | 134 | # Construct full image paths 135 | visible_image_path = os.path.join(visible_folder, visible_image_name) 136 | infrared_image_path = os.path.join(infrared_folder, infrared_image_name) 137 | 138 | # Check existence of infrared image 139 | if not os.path.exists(infrared_image_path): 140 | print(f"Infrared image {infrared_image_path} not found, skipping.") 141 | return 142 | 143 | # Load images 144 | try: 145 | visible_image = Image.open(visible_image_path).convert("RGB") 146 | infrared_image = Image.open(infrared_image_path).convert("RGB") 147 | except Exception as e: 148 | print(f"Error loading images {visible_image_name} and/or {infrared_image_name}: {e}, skipping.") 149 | return 150 | 151 | 152 | # Crop both images 153 | visible_cropped = visible_image 154 | infrared_cropped = infrared_image 155 | 156 | # Resize to 128x128 157 | # visible_resized = visible_cropped.resize((128, 128), Image.BICUBIC) 158 | # infrared_resized = infrared_cropped.resize((128, 128), Image.BICUBIC) 159 | 160 | # Apply transformations 161 | visible_tensor = transform(visible_cropped).unsqueeze(0).to(device) 162 | infrared_tensor = transform(infrared_cropped).unsqueeze(0).to(device) 163 | 164 | # Add noise to the infrared image 165 | max_sigma = sigmas[0] # Highest sigma value 166 | noise = torch.randn_like(infrared_tensor) * max_sigma 167 | noisy_infrared_tensor = noise 168 | 169 | # Generate the infrared image starting from the noisy infrared image 170 | try: 171 | with torch.no_grad(): 172 | generated_infrared_tensor = consistency_sampling( 173 | model=model, 174 | y=noisy_infrared_tensor, 175 | v=visible_tensor, 176 | sigmas=sigmas, 177 | start_from_y=True, 178 | add_initial_noise=False, 179 | clip_denoised=True, 180 | verbose=False, # Set verbose=False to reduce output 181 | ) 182 | except Exception as e: 183 | print(f"Error during model inference for {visible_image_name}: {e}, skipping.") 184 | return 185 | 186 | # Denormalize tensors 187 | generated_infrared_denorm = denormalize(generated_infrared_tensor.squeeze(0).cpu()) 188 | 189 | # Convert tensors to PIL images 190 | generated_infrared_pil = TF.to_pil_image(generated_infrared_denorm) 191 | 192 | # Reference infrared image (already resized to 128x128) 193 | reference_infrared_pil = infrared_cropped 194 | 195 | # Convert images to numpy arrays for metric calculation 196 | reference_image_np = 
np.array(reference_infrared_pil) 197 | generated_image_np = np.array(generated_infrared_pil) 198 | 199 | # Calculate PSNR and SSIM 200 | psnr_value, ssim_value = calculate_metrics(reference_image_np, generated_image_np) 201 | 202 | # Accumulate PSNR and SSIM 203 | total_psnr += psnr_value 204 | total_ssim += ssim_value 205 | num_processed_images += 1 206 | 207 | # Print PSNR and SSIM for the current image 208 | print(f"Image : PSNR = {psnr_value:.2f}, SSIM = {ssim_value:.4f}") 209 | 210 | # Save the generated infrared image for visual inspection 211 | output_filename_infrared = f"generated_infrared_{visible_image_name}" 212 | output_path_infrared = os.path.join(output_folder, output_filename_infrared) 213 | generated_infrared_pil.save(output_path_infrared) 214 | print(f"Saved generated infrared image to {output_path_infrared}\n") 215 | 216 | # ======================= 217 | # Processing All Images 218 | # ======================= 219 | 220 | def main(): 221 | global total_psnr, total_ssim, num_processed_images 222 | total_psnr = 0.0 223 | total_ssim = 0.0 224 | num_processed_images = 0 225 | 226 | # Get a list of all images in the visible folder 227 | visible_images = os.listdir(visible_folder) 228 | image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff') 229 | visible_images = [img for img in visible_images if img.lower().endswith(image_extensions)] 230 | 231 | # Optionally limit the number of images to process 232 | if num_images_to_process is not None: 233 | selected_visible_images = random.sample(visible_images, min(num_images_to_process, len(visible_images))) 234 | else: 235 | selected_visible_images = visible_images 236 | 237 | print(f"Processing {len(selected_visible_images)} images...\n") 238 | 239 | for idx, visible_image_name in enumerate(selected_visible_images, start=1): 240 | print(f"Processing image {idx}/{len(selected_visible_images)}: {visible_image_name}") 241 | process_image(visible_image_name) 242 | 243 | # Calculate and print average PSNR and SSIM 244 | if num_processed_images > 0: 245 | average_psnr = total_psnr / num_processed_images 246 | average_ssim = total_ssim / num_processed_images 247 | print(f"\nProcessed {num_processed_images} images.") 248 | print(f"Average PSNR: {average_psnr:.2f}") 249 | print(f"Average SSIM: {average_ssim:.4f}") 250 | 251 | # Save metrics to a text file 252 | with open(metrics_file, "a") as f: 253 | f.write(f"Processed {num_processed_images} images.\n") 254 | f.write(f"Average PSNR: {average_psnr:.2f}\n") 255 | f.write(f"Average SSIM: {average_ssim:.4f}\n\n") 256 | print(f"Saved metrics to {metrics_file}") 257 | else: 258 | print("No images were processed.") 259 | 260 | if __name__ == "__main__": 261 | main() 262 | -------------------------------------------------------------------------------- /lolv2/script.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import asdict, dataclass 4 | from typing import Any, Callable, List, Optional, Tuple, Union 5 | 6 | import torch 7 | from torchvision.transforms import functional as TF 8 | import math 9 | from einops import rearrange 10 | from einops.layers.torch import Rearrange 11 | from lightning import LightningDataModule, LightningModule, Trainer, seed_everything 12 | from lightning.pytorch.callbacks import LearningRateMonitor 13 | from lightning.pytorch.loggers import TensorBoardLogger 14 | from matplotlib import pyplot as plt 15 | from torch import Tensor, nn 16 | from torch.nn import functional as F 17 | 
from torch.utils.data import DataLoader 18 | from torchinfo import summary 19 | from torchvision import transforms as T 20 | from torchvision.datasets import ImageFolder 21 | from torchvision.utils import make_grid 22 | 23 | from improved_consistency_model_conditional import ( 24 | ConsistencySamplingAndEditing, 25 | ImprovedConsistencyTraining, 26 | pseudo_huber_loss, 27 | update_ema_model_, 28 | ) 29 | 30 | 31 | from torch.utils.data import Dataset 32 | from torchvision import transforms as T 33 | import os 34 | from PIL import Image 35 | import torch 36 | from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint 37 | 38 | class PairedDataset(Dataset): 39 | def __init__( 40 | self, 41 | visible_dir: str, 42 | infrared_dir: str, 43 | transform: Optional[Callable] = None, 44 | crop_size: Tuple[int, int] = (128, 128), 45 | resize_size: Tuple[int, int] = (128, 128) 46 | ): 47 | self.visible_dir = visible_dir 48 | self.infrared_dir = infrared_dir 49 | self.visible_images = sorted(os.listdir(visible_dir)) 50 | self.infrared_images = sorted(os.listdir(infrared_dir)) 51 | self.transform = transform 52 | self.crop_size = crop_size 53 | self.resize_size = resize_size 54 | 55 | def __len__(self) -> int: 56 | return len(self.visible_images) 57 | 58 | def __getitem__(self, index: int) -> Optional[Tuple[Tensor, Tensor]]: 59 | visible_path = os.path.join(self.visible_dir, self.visible_images[index]) 60 | infrared_path = os.path.join(self.infrared_dir, self.infrared_images[index]) 61 | 62 | visible_image = Image.open(visible_path).convert("RGB") 63 | infrared_image = Image.open(infrared_path).convert("RGB") 64 | 65 | if visible_image.size != infrared_image.size: 66 | print(f"Skipping image pair at index {index} due to mismatched sizes") 67 | return None 68 | 69 | # Perform synchronized random horizontal flip 70 | if torch.rand(1).item() > 0.5: 71 | visible_image = TF.hflip(visible_image) 72 | infrared_image = TF.hflip(infrared_image) 73 | 74 | # Perform synchronized random crop 75 | i, j, h, w = T.RandomCrop.get_params(visible_image, output_size=self.crop_size) 76 | visible_image = TF.crop(visible_image, i, j, h, w) 77 | infrared_image = TF.crop(infrared_image, i, j, h, w) 78 | 79 | # Resize to desired size 80 | # visible_image = TF.resize(visible_image, self.resize_size) 81 | # infrared_image = TF.resize(infrared_image, self.resize_size) 82 | 83 | if self.transform: 84 | visible_image = self.transform(visible_image) 85 | infrared_image = self.transform(infrared_image) 86 | 87 | return visible_image, infrared_image 88 | 89 | from dataclasses import dataclass 90 | from typing import Tuple 91 | 92 | @dataclass 93 | class ImageDataModuleConfig: 94 | data_dir: str = "datasets/LOL-v2" # Path to the dataset directory 95 | image_size_crop: Tuple[int, int] = (128, 128) # Size for random cropping 96 | image_size_resize: Tuple[int, int] = (128, 128) # Resize to 128x128 97 | batch_size: int = 30 # Number of images in each batch 98 | num_workers: int = 8 # Number of worker threads for data loading 99 | pin_memory: bool = True # Whether to pin memory in data loader 100 | persistent_workers: bool = True # Keep workers alive between epochs 101 | 102 | from torch.utils.data import DataLoader 103 | from lightning.pytorch import LightningDataModule 104 | 105 | class LLVIPDataModule(LightningDataModule): 106 | def __init__(self, config: ImageDataModuleConfig) -> None: 107 | super().__init__() 108 | self.config = config 109 | 110 | def setup(self, stage: str = None) -> None: 111 | # Define transforms 
excluding cropping and resizing 112 | self.transform = T.Compose([ 113 | T.ToTensor(), 114 | T.Lambda(lambda x: (x * 2) - 1), # Normalize to [-1, 1] 115 | ]) 116 | 117 | self.dataset = PairedDataset( 118 | visible_dir=os.path.join(self.config.data_dir, "Real_captured/Train/Low"), 119 | infrared_dir=os.path.join(self.config.data_dir, "Real_captured/Train/Normal"), 120 | transform=self.transform, 121 | crop_size=self.config.image_size_crop, 122 | resize_size=self.config.image_size_resize 123 | ) 124 | 125 | def train_dataloader(self) -> DataLoader: 126 | return DataLoader( 127 | self.dataset, 128 | batch_size=self.config.batch_size, 129 | shuffle=True, 130 | num_workers=self.config.num_workers, 131 | pin_memory=self.config.pin_memory, 132 | persistent_workers=self.config.persistent_workers, 133 | ) 134 | 135 | def GroupNorm(channels: int) -> nn.GroupNorm: 136 | return nn.GroupNorm(num_groups=min(32, channels // 4), num_channels=channels) 137 | 138 | 139 | class SelfAttention(nn.Module): 140 | def __init__( 141 | self, 142 | in_channels: int, 143 | out_channels: int, 144 | n_heads: int = 8, 145 | dropout: float = 0.3, 146 | ) -> None: 147 | super().__init__() 148 | 149 | self.dropout = dropout 150 | 151 | self.qkv_projection = nn.Sequential( 152 | GroupNorm(in_channels), 153 | nn.Conv2d(in_channels, 3 * in_channels, kernel_size=1, bias=False), 154 | Rearrange("b (i h d) x y -> i b h (x y) d", i=3, h=n_heads), 155 | ) 156 | self.output_projection = nn.Sequential( 157 | Rearrange("b h l d -> b l (h d)"), 158 | nn.Linear(in_channels, out_channels, bias=False), 159 | Rearrange("b l d -> b d l"), 160 | GroupNorm(out_channels), 161 | nn.Dropout1d(dropout), 162 | ) 163 | self.residual_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1) 164 | 165 | def forward(self, x: Tensor) -> Tensor: 166 | q, k, v = self.qkv_projection(x).unbind(dim=0) 167 | 168 | output = F.scaled_dot_product_attention( 169 | q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=False 170 | ) 171 | output = self.output_projection(output) 172 | output = rearrange(output, "b c (x y) -> b c x y", x=x.shape[-2], y=x.shape[-1]) 173 | 174 | return output + self.residual_projection(x) 175 | 176 | 177 | class UNetBlock(nn.Module): 178 | def __init__( 179 | self, 180 | in_channels: int, 181 | out_channels: int, 182 | noise_level_channels: int, 183 | dropout: float = 0.3, 184 | ) -> None: 185 | super().__init__() 186 | 187 | self.input_projection = nn.Sequential( 188 | GroupNorm(in_channels), 189 | nn.SiLU(), 190 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding="same"), 191 | nn.Dropout2d(dropout), 192 | ) 193 | self.noise_level_projection = nn.Sequential( 194 | nn.SiLU(), 195 | nn.Conv2d(noise_level_channels, out_channels, kernel_size=1), 196 | ) 197 | self.output_projection = nn.Sequential( 198 | GroupNorm(out_channels), 199 | nn.SiLU(), 200 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding="same"), 201 | nn.Dropout2d(dropout), 202 | ) 203 | self.residual_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1) 204 | 205 | def forward(self, x: Tensor, noise_level: Tensor) -> Tensor: 206 | h = self.input_projection(x) 207 | h = h + self.noise_level_projection(noise_level) 208 | 209 | return self.output_projection(h) + self.residual_projection(x) 210 | 211 | 212 | class UNetBlockWithSelfAttention(nn.Module): 213 | def __init__( 214 | self, 215 | in_channels: int, 216 | out_channels: int, 217 | noise_level_channels: int, 218 | n_heads: int = 8, 219 | dropout: float = 0.3, 220 | 
) -> None: 221 | super().__init__() 222 | 223 | self.unet_block = UNetBlock( 224 | in_channels, out_channels, noise_level_channels, dropout 225 | ) 226 | self.self_attention = SelfAttention( 227 | out_channels, out_channels, n_heads, dropout 228 | ) 229 | 230 | def forward(self, x: Tensor, noise_level: Tensor) -> Tensor: 231 | return self.self_attention(self.unet_block(x, noise_level)) 232 | 233 | 234 | class Downsample(nn.Module): 235 | def __init__(self, channels: int) -> None: 236 | super().__init__() 237 | 238 | self.projection = nn.Sequential( 239 | Rearrange("b c (h ph) (w pw) -> b (c ph pw) h w", ph=2, pw=2), 240 | nn.Conv2d(4 * channels, channels, kernel_size=1), 241 | ) 242 | 243 | def forward(self, x: Tensor) -> Tensor: 244 | return self.projection(x) 245 | 246 | 247 | class Upsample(nn.Module): 248 | def __init__(self, channels: int) -> None: 249 | super().__init__() 250 | 251 | self.projection = nn.Sequential( 252 | nn.Upsample(scale_factor=2.0, mode="nearest"), 253 | nn.Conv2d(channels, channels, kernel_size=3, padding="same"), 254 | ) 255 | 256 | def forward(self, x: Tensor) -> Tensor: 257 | return self.projection(x) 258 | 259 | 260 | class NoiseLevelEmbedding(nn.Module): 261 | def __init__(self, channels: int, scale: float = 0.02) -> None: 262 | super().__init__() 263 | 264 | self.W = nn.Parameter(torch.randn(channels // 2) * scale, requires_grad=False) 265 | 266 | self.projection = nn.Sequential( 267 | nn.Linear(channels, 4 * channels), 268 | nn.SiLU(), 269 | nn.Linear(4 * channels, channels), 270 | Rearrange("b c -> b c () ()"), 271 | ) 272 | 273 | def forward(self, x: Tensor) -> Tensor: 274 | h = x[:, None] * self.W[None, :] * 2 * torch.pi 275 | h = torch.cat([torch.sin(h), torch.cos(h)], dim=-1) 276 | 277 | return self.projection(h) 278 | 279 | 280 | @dataclass 281 | class UNetConfig: 282 | channels: int = 3 283 | noise_level_channels: int = 256 284 | noise_level_scale: float = 0.02 285 | n_heads: int = 8 286 | top_blocks_channels: Tuple[int, ...] = (128, 128) 287 | top_blocks_n_blocks_per_resolution: Tuple[int, ...] = (2, 2) 288 | top_blocks_has_resampling: Tuple[bool, ...] = (True, True) 289 | top_blocks_dropout: Tuple[float, ...] = (0.0, 0.0) 290 | mid_blocks_channels: Tuple[int, ...] = (256, 512) 291 | mid_blocks_n_blocks_per_resolution: Tuple[int, ...] = (4, 4) 292 | mid_blocks_has_resampling: Tuple[bool, ...] = (True, False) 293 | mid_blocks_dropout: Tuple[float, ...] 
= (0.0, 0.3) 294 | 295 | 296 | class UNet(nn.Module): 297 | def __init__(self, config: UNetConfig) -> None: 298 | super().__init__() 299 | 300 | self.config = config 301 | 302 | self.input_projection = nn.Conv2d( 303 | config.channels * 2, 304 | config.top_blocks_channels[0], 305 | kernel_size=3, 306 | padding="same", 307 | ) 308 | self.noise_level_embedding = NoiseLevelEmbedding( 309 | config.noise_level_channels, config.noise_level_scale 310 | ) 311 | self.top_encoder_blocks = self._make_encoder_blocks( 312 | self.config.top_blocks_channels + self.config.mid_blocks_channels[:1], 313 | self.config.top_blocks_n_blocks_per_resolution, 314 | self.config.top_blocks_has_resampling, 315 | self.config.top_blocks_dropout, 316 | self._make_top_block, 317 | ) 318 | self.mid_encoder_blocks = self._make_encoder_blocks( 319 | self.config.mid_blocks_channels + self.config.mid_blocks_channels[-1:], 320 | self.config.mid_blocks_n_blocks_per_resolution, 321 | self.config.mid_blocks_has_resampling, 322 | self.config.mid_blocks_dropout, 323 | self._make_mid_block, 324 | ) 325 | self.mid_decoder_blocks = self._make_decoder_blocks( 326 | self.config.mid_blocks_channels + self.config.mid_blocks_channels[-1:], 327 | self.config.mid_blocks_n_blocks_per_resolution, 328 | self.config.mid_blocks_has_resampling, 329 | self.config.mid_blocks_dropout, 330 | self._make_mid_block, 331 | ) 332 | self.top_decoder_blocks = self._make_decoder_blocks( 333 | self.config.top_blocks_channels + self.config.mid_blocks_channels[:1], 334 | self.config.top_blocks_n_blocks_per_resolution, 335 | self.config.top_blocks_has_resampling, 336 | self.config.top_blocks_dropout, 337 | self._make_top_block, 338 | ) 339 | self.output_projection = nn.Conv2d( 340 | config.top_blocks_channels[0], 341 | config.channels, 342 | kernel_size=3, 343 | padding="same", 344 | ) 345 | 346 | def forward(self, x: Tensor, noise_level: Tensor, v: Tensor) -> Tensor: 347 | x = torch.cat([x, v], dim = 1) 348 | h = self.input_projection(x) 349 | noise_level = self.noise_level_embedding(noise_level) 350 | 351 | top_encoder_embeddings = [] 352 | for block in self.top_encoder_blocks: 353 | if isinstance(block, UNetBlock): 354 | h = block(h, noise_level) 355 | top_encoder_embeddings.append(h) 356 | else: 357 | h = block(h) 358 | 359 | mid_encoder_embeddings = [] 360 | for block in self.mid_encoder_blocks: 361 | if isinstance(block, UNetBlockWithSelfAttention): 362 | h = block(h, noise_level) 363 | mid_encoder_embeddings.append(h) 364 | else: 365 | h = block(h) 366 | 367 | for block in self.mid_decoder_blocks: 368 | if isinstance(block, UNetBlockWithSelfAttention): 369 | h = torch.cat((h, mid_encoder_embeddings.pop()), dim=1) 370 | h = block(h, noise_level) 371 | else: 372 | h = block(h) 373 | 374 | for block in self.top_decoder_blocks: 375 | if isinstance(block, UNetBlock): 376 | h = torch.cat((h, top_encoder_embeddings.pop()), dim=1) 377 | h = block(h, noise_level) 378 | else: 379 | h = block(h) 380 | 381 | output = self.output_projection(h) 382 | 383 | return output 384 | 385 | def _make_encoder_blocks( 386 | self, 387 | channels: Tuple[int, ...], 388 | n_blocks_per_resolution: Tuple[int, ...], 389 | has_resampling: Tuple[bool, ...], 390 | dropout: Tuple[float, ...], 391 | block_fn: Callable[[], nn.Module], 392 | ) -> nn.ModuleList: 393 | blocks = nn.ModuleList() 394 | 395 | channel_pairs = list(zip(channels[:-1], channels[1:])) 396 | for idx, (in_channels, out_channels) in enumerate(channel_pairs): 397 | for _ in range(n_blocks_per_resolution[idx]): 398 | 
blocks.append(block_fn(in_channels, out_channels, dropout[idx])) 399 | in_channels = out_channels 400 | 401 | if has_resampling[idx]: 402 | blocks.append(Downsample(out_channels)) 403 | 404 | return blocks 405 | 406 | def _make_decoder_blocks( 407 | self, 408 | channels: Tuple[int, ...], 409 | n_blocks_per_resolution: Tuple[int, ...], 410 | has_resampling: Tuple[bool, ...], 411 | dropout: Tuple[float, ...], 412 | block_fn: Callable[[], nn.Module], 413 | ) -> nn.ModuleList: 414 | blocks = nn.ModuleList() 415 | 416 | channel_pairs = list(zip(channels[:-1], channels[1:]))[::-1] 417 | for idx, (out_channels, in_channels) in enumerate(channel_pairs): 418 | if has_resampling[::-1][idx]: 419 | blocks.append(Upsample(in_channels)) 420 | 421 | inner_blocks = [] 422 | for _ in range(n_blocks_per_resolution[::-1][idx]): 423 | inner_blocks.append( 424 | block_fn(in_channels * 2, out_channels, dropout[::-1][idx]) 425 | ) 426 | out_channels = in_channels 427 | blocks.extend(inner_blocks[::-1]) 428 | 429 | return blocks 430 | 431 | def _make_top_block( 432 | self, in_channels: int, out_channels: int, dropout: float 433 | ) -> UNetBlock: 434 | return UNetBlock( 435 | in_channels, 436 | out_channels, 437 | self.config.noise_level_channels, 438 | dropout, 439 | ) 440 | 441 | def _make_mid_block( 442 | self, 443 | in_channels: int, 444 | out_channels: int, 445 | dropout: float, 446 | ) -> UNetBlockWithSelfAttention: 447 | return UNetBlockWithSelfAttention( 448 | in_channels, 449 | out_channels, 450 | self.config.noise_level_channels, 451 | self.config.n_heads, 452 | dropout, 453 | ) 454 | 455 | def save_pretrained(self, pretrained_path: str) -> None: 456 | os.makedirs(pretrained_path, exist_ok=True) 457 | 458 | with open(os.path.join(pretrained_path, "config.json"), mode="w") as f: 459 | json.dump(asdict(self.config), f) 460 | 461 | torch.save(self.state_dict(), os.path.join(pretrained_path, "model.pt")) 462 | 463 | @classmethod 464 | def from_pretrained(cls, pretrained_path: str) -> "UNet": 465 | with open(os.path.join(pretrained_path, "config.json"), mode="r") as f: 466 | config_dict = json.load(f) 467 | config = UNetConfig(**config_dict) 468 | 469 | model = cls(config) 470 | 471 | state_dict = torch.load( 472 | os.path.join(pretrained_path, "model.pt"), map_location=torch.device("cpu") 473 | ) 474 | model.load_state_dict(state_dict) 475 | 476 | return model 477 | 478 | @dataclass 479 | class LitImprovedConsistencyModelConfig: 480 | ema_decay_rate: float = 0.99993 481 | lr: float = 1e-4 482 | betas: Tuple[float, float] = (0.9, 0.995) 483 | lr_scheduler_start_factor: float = 1e-5 484 | lr_scheduler_iters: int = 10_000 485 | sample_every_n_steps: int = 10_000 486 | num_samples: int = 8 487 | sampling_sigmas: Tuple[Tuple[int, ...], ...] 
= ( 488 | (80,), 489 | (80.0, 0.661), 490 | (80.0, 24.4, 5.84, 0.9, 0.661), 491 | ) 492 | 493 | 494 | class LitImprovedConsistencyModel(LightningModule): 495 | def __init__( 496 | self, 497 | consistency_training: ImprovedConsistencyTraining, 498 | consistency_sampling: ConsistencySamplingAndEditing, 499 | model: UNet, 500 | ema_model: UNet, 501 | config: LitImprovedConsistencyModelConfig, 502 | ) -> None: 503 | super().__init__() 504 | 505 | self.consistency_training = consistency_training 506 | self.consistency_sampling = consistency_sampling 507 | self.model = model 508 | self.ema_model = ema_model 509 | self.config = config 510 | 511 | # Freeze the EMA model and set it to eval mode 512 | for param in self.ema_model.parameters(): 513 | param.requires_grad = False 514 | self.ema_model = self.ema_model.eval() 515 | 516 | def training_step(self, batch: Union[Tensor, List[Tensor]], batch_idx: int) -> None: 517 | # if isinstance(batch, list): 518 | # batch = batch[0] 519 | 520 | visible_images, infrared_images = batch # Unpack the batch 521 | 522 | output = self.consistency_training( 523 | self.model, 524 | infrared_images, 525 | visible_images, # Pass visible images to the training function 526 | self.global_step, 527 | self.trainer.max_steps 528 | ) 529 | 530 | loss = ( 531 | pseudo_huber_loss(output.predicted, output.target) * output.loss_weights 532 | ).mean() 533 | 534 | self.log_dict({"train_loss": loss, "num_timesteps": output.num_timesteps}) 535 | 536 | return loss 537 | 538 | def on_train_batch_end( 539 | self, outputs: Any, batch: Union[Tensor, List[Tensor]], batch_idx: int 540 | ) -> None: 541 | update_ema_model_(self.model, self.ema_model, self.config.ema_decay_rate) 542 | 543 | if ( 544 | (self.global_step + 1) % self.config.sample_every_n_steps == 0 545 | ) or self.global_step == 0: 546 | self.__sample_and_log_samples(batch) 547 | 548 | def configure_optimizers(self): 549 | opt = torch.optim.Adam( 550 | self.model.parameters(), lr=self.config.lr, betas=self.config.betas 551 | ) 552 | sched = torch.optim.lr_scheduler.LinearLR( 553 | opt, 554 | start_factor=self.config.lr_scheduler_start_factor, 555 | total_iters=self.config.lr_scheduler_iters, 556 | ) 557 | sched = {"scheduler": sched, "interval": "step", "frequency": 1} 558 | 559 | return [opt], [sched] 560 | 561 | @torch.no_grad() 562 | def __sample_and_log_samples(self, batch: Union[Tensor, List[Tensor]]) -> None: 563 | if isinstance(batch, list): 564 | batch = batch[0] 565 | 566 | # Ensure the number of samples does not exceed the batch size 567 | num_samples = min(self.config.num_samples, batch.shape[0]) 568 | noise = torch.randn_like(batch[:num_samples]) 569 | 570 | # Log ground truth samples 571 | self.__log_images( 572 | batch[:num_samples].detach().clone(), f"ground_truth", self.global_step 573 | ) 574 | 575 | for sigmas in self.config.sampling_sigmas: 576 | samples = self.consistency_sampling( 577 | self.ema_model, noise, sigmas, clip_denoised=True, verbose=True 578 | ) 579 | samples = samples.clamp(min=-1.0, max=1.0) 580 | 581 | # Generated samples 582 | self.__log_images( 583 | samples, 584 | f"generated_samples-sigmas={sigmas}", 585 | self.global_step, 586 | ) 587 | 588 | @torch.no_grad() 589 | def __log_images(self, images: Tensor, title: str, global_step: int) -> None: 590 | images = images.detach().float() 591 | 592 | grid = make_grid( 593 | images.clamp(-1.0, 1.0), value_range=(-1.0, 1.0), normalize=True 594 | ) 595 | self.logger.experiment.add_image(title, grid, global_step) 596 | 597 | 598 | @dataclass 599 | 
class TrainingConfig: 600 | image_dm_config: ImageDataModuleConfig 601 | unet_config: UNetConfig 602 | consistency_training: ImprovedConsistencyTraining 603 | consistency_sampling: ConsistencySamplingAndEditing 604 | lit_icm_config: LitImprovedConsistencyModelConfig 605 | trainer: Trainer 606 | model_ckpt_path: str = "checkpoints/lolv2_real" 607 | seed: int = 42 608 | 609 | def run_training(config: TrainingConfig) -> None: 610 | # Set seed 611 | seed_everything(config.seed) 612 | 613 | # Create data module 614 | dm = LLVIPDataModule(config.image_dm_config) 615 | dm.setup() 616 | print("DataModule setup complete.") 617 | 618 | # Create model and its EMA 619 | model = UNet(config.unet_config) 620 | ema_model = UNet(config.unet_config) 621 | ema_model.load_state_dict(model.state_dict()) 622 | 623 | # Create lightning module 624 | lit_icm = LitImprovedConsistencyModel( 625 | config.consistency_training, 626 | config.consistency_sampling, 627 | model, 628 | ema_model, 629 | config.lit_icm_config, 630 | ) 631 | 632 | print("Lightning module created.") 633 | 634 | # Run training 635 | print("Starting training...") 636 | config.trainer.fit(lit_icm, datamodule=dm) 637 | print("Training completed.") 638 | 639 | # Save model 640 | lit_icm.model.save_pretrained(config.model_ckpt_path) 641 | print("Model saved.") 642 | 643 | # Main function 644 | def main(): 645 | # Define the checkpoint callback 646 | checkpoint_callback = ModelCheckpoint( 647 | dirpath="checkpoints_lolv2_real", 648 | filename="{epoch}-{step}", 649 | save_top_k=-1, # Save all checkpoints 650 | every_n_epochs=100, # Adjust as needed 651 | ) 652 | 653 | # Set up the logger 654 | logger = TensorBoardLogger("logs", name="icm") 655 | 656 | training_config = TrainingConfig( 657 | image_dm_config=ImageDataModuleConfig(data_dir="../datasets/LOL-v2/"), 658 | unet_config=UNetConfig(), 659 | consistency_training=ImprovedConsistencyTraining(final_timesteps=11), 660 | consistency_sampling=ConsistencySamplingAndEditing(), 661 | lit_icm_config=LitImprovedConsistencyModelConfig( 662 | sample_every_n_steps=2100000, lr_scheduler_iters=1000 663 | ), 664 | trainer=Trainer( 665 | max_steps=100, 666 | precision="16", 667 | log_every_n_steps=10, 668 | logger=logger, 669 | callbacks=[ 670 | LearningRateMonitor(logging_interval="step"), 671 | checkpoint_callback, # Add the checkpoint callback here 672 | ], 673 | ), 674 | ) 675 | run_training(training_config) 676 | 677 | if __name__ == "__main__": 678 | main() -------------------------------------------------------------------------------- /lolv2/switch_between_real_and_synthetic.txt: -------------------------------------------------------------------------------- 1 | The current script.py and metrics.py are functional on the LOLv2 real dataset. 2 | To train and sample the model for LOLv2 synthetic the only change that is needed is the path to the datasets and where the model is being saved. 3 | Change the dataset paths 4 | 5 | to: 6 | datasets/LOL-v2/Synthetic 7 | 8 | from: 9 | datasets/LOL-v2/Real_captured 10 | 11 | everywhere necessary. 
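For reference, a minimal sketch of those edits (an illustration, not drop-in code): it assumes the Synthetic split mirrors the Real_captured layout (Train/Low, Train/Normal, Test/Low, Test/Normal) and uses placeholder names such as "checkpoints/lolv2_synthetic" and "results_lolv2_synthetic" for the new checkpoint and output locations.

# lolv2/script.py -- LLVIPDataModule.setup(): subfolders passed to os.path.join(data_dir, ...)
#   "Real_captured/Train/Low"    ->  "Synthetic/Train/Low"
#   "Real_captured/Train/Normal" ->  "Synthetic/Train/Normal"

# lolv2/script.py -- where the trained model and Lightning checkpoints are written
model_ckpt_path = "checkpoints/lolv2_synthetic"       # was "checkpoints/lolv2_real"
checkpoint_dirpath = "checkpoints_lolv2_synthetic"    # ModelCheckpoint dirpath; was "checkpoints_lolv2_real"

# lolv2/sampling_and_metrics.py -- evaluation paths
model_path = "checkpoints/lolv2_synthetic/"           # was "checkpoints/lolv2_real/"
visible_folder = "../datasets/LOL-v2/Synthetic/Test/Low/"
infrared_folder = "../datasets/LOL-v2/Synthetic/Test/Normal/"
output_folder = "results_lolv2_synthetic"             # was "results_lolv2_real"
metrics_file = "metrics_lolv2_synthetic.txt"          # was "metrics_lolv2_real.txt"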
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | Pillow 5 | matplotlib 6 | einops 7 | lightning 8 | tqdm 9 | scikit-image -------------------------------------------------------------------------------- /sid/sampling_and_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import torchvision.transforms as T 4 | import torchvision.transforms.functional as TF 5 | from improved_consistency_model_conditional import ConsistencySamplingAndEditing 6 | from sid.script import UNet # Replace 'script_name' with the actual script where UNet is defined 7 | import os 8 | import random 9 | import rawpy 10 | import numpy as np 11 | from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim 12 | 13 | # Function to load raw images 14 | def load_raw_image(path: str) -> torch.Tensor: 15 | with rawpy.imread(path) as raw: 16 | # Postprocess to get RGB image 17 | rgb_image = raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16) 18 | # Convert to float and normalize 19 | rgb_image = np.float32(rgb_image) / (2**16 - 1) 20 | return torch.from_numpy(rgb_image).permute(2, 0, 1) # Shape: (C, H, W) 21 | 22 | # Function to denormalize tensors 23 | def denormalize(tensor): 24 | return (tensor + 1) / 2 # Convert from [-1,1] to [0,1] 25 | 26 | # Define the transform 27 | def transform(image): 28 | return (image * 2) - 1 # Normalize to [-1, 1] 29 | 30 | # Paths 31 | data_dir = "../datasets/sid" 32 | txt_file = os.path.join(data_dir, "Sony_test_list.txt") # Use 'Sony_test_list.txt' if available 33 | output_folder = "results_sid" 34 | 35 | # Create output folder if it doesn't exist 36 | if not os.path.exists(output_folder): 37 | os.makedirs(output_folder) 38 | 39 | # Load the list of image pairs 40 | image_pairs = [] 41 | with open(txt_file, 'r') as f: 42 | lines = f.readlines() 43 | for line in lines: 44 | short_path, long_path, iso, f_number = line.strip().split() 45 | # Extract exposure times from file names 46 | short_exposure = float(os.path.basename(short_path).split('_')[-1].replace('s.ARW', '')) 47 | long_exposure = float(os.path.basename(long_path).split('_')[-1].replace('s.ARW', '')) 48 | ratio = long_exposure / short_exposure 49 | image_pairs.append((short_path, long_path, ratio)) 50 | 51 | # Number of images to process 52 | # num_images = 10 # Remove or comment out to process all images 53 | 54 | # Uncomment the following line to process all images 55 | num_images = len(image_pairs) 56 | 57 | # Optionally, limit the number of images (for testing purposes) 58 | # num_images = min(num_images, len(image_pairs)) 59 | 60 | # Randomly select images 61 | random.seed(40) # Optional: Set seed for reproducibility 62 | selected_images = random.sample(image_pairs, num_images) 63 | 64 | # Load the model 65 | model_path = "checkpoints/sid" # Replace with your actual model checkpoint path 66 | model = UNet.from_pretrained(model_path) 67 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 68 | model = model.to(device).eval() 69 | 70 | # Create the sampling instance 71 | consistency_sampling = ConsistencySamplingAndEditing() 72 | 73 | # Define the sigma schedule 74 | sigmas = [80.0, 40.0, 20.0, 10.0, 5.0, 2.5, 1.25, 0.625, 0.3125, 0.15625, 0.078125] 75 | 76 | # Initialize variables to 
accumulate PSNR and SSIM 77 | total_psnr = 0.0 78 | total_ssim = 0.0 79 | processed_count = 0 # To handle any skipped images 80 | 81 | for idx, (short_path, long_path, ratio) in enumerate(selected_images, start=1): 82 | short_image_path = os.path.join(data_dir, short_path) 83 | long_image_path = os.path.join(data_dir, long_path) 84 | 85 | # Load images 86 | try: 87 | short_image = load_raw_image(short_image_path) 88 | long_image = load_raw_image(long_image_path) 89 | except Exception as e: 90 | print(f"Error loading images: {e}, skipping {short_path}") 91 | continue 92 | 93 | # Multiply short_image by ratio and clip 94 | short_image = torch.clamp(short_image * ratio, 0.0, 1.0) 95 | 96 | # Apply center crop and resize 97 | # crop_size = (512, 512) 98 | resize_size = (512, 512) 99 | 100 | # # Synchronized center crop 101 | # center_crop = T.CenterCrop(crop_size) 102 | # short_image = center_crop(short_image) 103 | # long_image = center_crop(long_image) 104 | 105 | # Resize to desired size 106 | short_image = TF.resize(short_image, resize_size) 107 | long_image = TF.resize(long_image, resize_size) 108 | 109 | # Normalize to [-1,1] 110 | short_image = transform(short_image) 111 | long_image = transform(long_image) 112 | 113 | # Move tensors to device 114 | short_tensor = short_image.unsqueeze(0).to(device) 115 | long_tensor = long_image.unsqueeze(0).to(device) 116 | 117 | # Add noise to the long image 118 | max_sigma = sigmas[0] # Highest sigma value 119 | noise = torch.randn_like(long_tensor) * max_sigma 120 | noisy_long_tensor = noise # Start from pure noise 121 | 122 | # Generate the long image starting from the noisy long image 123 | try: 124 | with torch.no_grad(): 125 | generated_long_tensor = consistency_sampling( 126 | model=model, 127 | y=noisy_long_tensor, 128 | v=short_tensor, 129 | sigmas=sigmas, 130 | start_from_y=True, 131 | add_initial_noise=False, 132 | clip_denoised=True, 133 | verbose=False, 134 | ) 135 | except Exception as e: 136 | print(f"Error during model inference: {e}, skipping {short_path}") 137 | continue 138 | 139 | # Denormalize tensors 140 | short_denorm = denormalize(short_tensor.squeeze(0).cpu()) 141 | long_denorm = denormalize(long_tensor.squeeze(0).cpu()) 142 | generated_long_denorm = denormalize(generated_long_tensor.squeeze(0).cpu()) 143 | 144 | # Convert tensors to PIL images 145 | short_image_pil = TF.to_pil_image(short_denorm) 146 | long_image_pil = TF.to_pil_image(long_denorm) 147 | generated_long_image_pil = TF.to_pil_image(generated_long_denorm) 148 | 149 | # Combine images side by side (optional, can be skipped if not needed) 150 | combined_width = short_image_pil.width * 3 151 | combined_height = short_image_pil.height 152 | combined_image = Image.new('RGB', (combined_width, combined_height)) 153 | combined_image.paste(short_image_pil, (0, 0)) 154 | combined_image.paste(long_image_pil, (short_image_pil.width, 0)) 155 | combined_image.paste(generated_long_image_pil, (short_image_pil.width * 2, 0)) 156 | 157 | # Save the combined image 158 | output_filename = f'comparison_sid_{idx}.png' 159 | output_path = os.path.join(output_folder, output_filename) 160 | combined_image.save(output_path) 161 | 162 | # Convert denormalized tensors to NumPy arrays for metric computation 163 | # Ensure the arrays are in the range [0, 1] 164 | long_np = long_denorm.permute(1, 2, 0).numpy() # Shape: (H, W, C) 165 | generated_long_np = generated_long_denorm.permute(1, 2, 0).numpy() 166 | 167 | # Compute PSNR 168 | current_psnr = psnr(long_np, generated_long_np, 
data_range=1.0) 169 | 170 | # Compute SSIM 171 | # Convert RGB to grayscale for SSIM or compute multi-channel SSIM 172 | # Here, we'll compute multi-channel SSIM 173 | current_ssim = ssim(long_np, generated_long_np, data_range=1.0, multichannel=True, win_size=3) 174 | 175 | # Accumulate the metrics 176 | total_psnr += current_psnr 177 | total_ssim += current_ssim 178 | processed_count += 1 179 | 180 | # Optional: Print metrics for each image 181 | print(f"Image {idx}: PSNR = {current_psnr:.2f} dB, SSIM = {current_ssim:.4f}") 182 | 183 | # Calculate average PSNR and SSIM 184 | if processed_count > 0: 185 | average_psnr = total_psnr / processed_count 186 | average_ssim = total_ssim / processed_count 187 | print(f"\nProcessed {processed_count} images.") 188 | print(f"Average PSNR: {average_psnr:.2f} dB") 189 | print(f"Average SSIM: {average_ssim:.4f}") 190 | else: 191 | print("No images were processed successfully.") 192 | 193 | # Optional: Save the average metrics to a text file 194 | metrics_output_path = os.path.join(output_folder, "metrics.txt") 195 | with open(metrics_output_path, 'w') as f: 196 | f.write(f"Processed {processed_count} images.\n") 197 | f.write(f"Average PSNR: {average_psnr:.2f} dB\n") 198 | f.write(f"Average SSIM: {average_ssim:.4f}\n") 199 | print(f"Metrics saved to {metrics_output_path}") 200 | -------------------------------------------------------------------------------- /sid/script.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import asdict, dataclass 4 | from typing import Any, Callable, List, Optional, Tuple, Union 5 | 6 | import torch 7 | from torchvision.transforms import functional as TF 8 | import math 9 | from einops import rearrange 10 | from einops.layers.torch import Rearrange 11 | from lightning import LightningDataModule, LightningModule, Trainer, seed_everything 12 | from lightning.pytorch.callbacks import LearningRateMonitor 13 | from lightning.pytorch.loggers import TensorBoardLogger 14 | from matplotlib import pyplot as plt 15 | from torch import Tensor, nn 16 | from torch.nn import functional as F 17 | from torch.utils.data import DataLoader 18 | from torchinfo import summary 19 | from torchvision import transforms as T 20 | from torchvision.datasets import ImageFolder 21 | from torchvision.utils import make_grid 22 | import torch 23 | import torch.nn.functional as F 24 | from torch.utils.data import Dataset, DataLoader 25 | from torchmetrics.functional import structural_similarity_index_measure as ssim_metric 26 | from torchmetrics.functional import peak_signal_noise_ratio as psnr_metric 27 | 28 | 29 | import rawpy 30 | import numpy as np 31 | 32 | from improved_consistency_model_conditional import ( 33 | ConsistencySamplingAndEditing, 34 | ImprovedConsistencyTraining, 35 | pseudo_huber_loss, 36 | update_ema_model_, 37 | ) 38 | 39 | 40 | 41 | from torch.utils.data import Dataset 42 | from torchvision import transforms as T 43 | import os 44 | from PIL import Image 45 | import torch 46 | from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint 47 | 48 | def pack_raw(raw): 49 | im = raw.raw_image_visible.astype(np.float32) 50 | im = np.maximum(im - 512, 0) / (16383 - 512) # Normalize to [0,1] 51 | im = np.expand_dims(im, axis=2) 52 | H, W = im.shape[0], im.shape[1] 53 | out = np.concatenate((im[0:H:2,0:W:2,:], 54 | im[0:H:2,1:W:2,:], 55 | im[1:H:2,1:W:2,:], 56 | im[1:H:2,0:W:2,:]), axis=2) 57 | return out 58 | 59 | class SIDDataset(Dataset): 
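    """Paired short-/long-exposure RAW dataset for the SID Sony split.

    Each item postprocesses both .ARW files with rawpy, scales the short exposure
    by the long/short exposure ratio (capped at 300), applies the same random crop
    to both images, and returns the (short, long) pair normalized to [-1, 1].
    """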
60 | def __init__( 61 | self, 62 | txt_file: str, 63 | root_dir: str, 64 | crop_size: Tuple[int, int] = (128, 128), 65 | resize_size: Tuple[int, int] = (128, 128) 66 | ): 67 | self.txt_file = txt_file 68 | self.root_dir = root_dir 69 | self.crop_size = crop_size 70 | self.resize_size = resize_size 71 | self.image_pairs: List[Tuple[str, str]] = [] 72 | 73 | with open(self.txt_file, 'r') as f: 74 | lines = f.readlines() 75 | for line in lines: 76 | short_path, long_path, iso, f_number = line.strip().split() 77 | self.image_pairs.append((short_path, long_path)) 78 | 79 | def __len__(self) -> int: 80 | return len(self.image_pairs) 81 | 82 | def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]: 83 | short_path, long_path = self.image_pairs[idx] 84 | short_full_path = os.path.join(self.root_dir, short_path) 85 | long_full_path = os.path.join(self.root_dir, long_path) 86 | 87 | # Extract exposure times from filenames 88 | short_exposure = float(os.path.basename(short_path).split('_')[-1].replace('s.ARW', '')) 89 | long_exposure = float(os.path.basename(long_path).split('_')[-1].replace('s.ARW', '')) 90 | 91 | # Compute exposure ratio, capped at 300 92 | ratio = min(long_exposure / short_exposure, 300) 93 | 94 | # Load and postprocess the short RAW image 95 | with rawpy.imread(short_full_path) as raw: 96 | short_rgb = raw.postprocess( 97 | use_camera_wb=True, 98 | half_size=False, 99 | no_auto_bright=True, 100 | output_bps=16 101 | ) 102 | short_rgb = np.float32(short_rgb) / 65535.0 # Normalize to [0,1] 103 | short_rgb = short_rgb * ratio # Apply exposure ratio 104 | short_rgb = np.clip(short_rgb, 0.0, 1.0) # Clip to [0,1] 105 | 106 | # Load and postprocess the long RAW (target) image 107 | with rawpy.imread(long_full_path) as raw: 108 | long_rgb = raw.postprocess( 109 | use_camera_wb=True, 110 | half_size=False, 111 | no_auto_bright=True, 112 | output_bps=16 113 | ) 114 | long_rgb = np.float32(long_rgb) / 65535.0 # Normalize to [0,1] 115 | 116 | # Convert numpy arrays to torch tensors and rearrange dimensions to (C, H, W) 117 | input_image = torch.from_numpy(short_rgb).permute(2, 0, 1) # Shape: (3, H, W) 118 | target_image = torch.from_numpy(long_rgb).permute(2, 0, 1) # Shape: (3, H, W) 119 | 120 | # Random crop 121 | _, H, W = input_image.shape 122 | crop_h, crop_w = self.crop_size 123 | if H > crop_h and W > crop_w: 124 | i = np.random.randint(0, H - crop_h + 1) 125 | j = np.random.randint(0, W - crop_w + 1) 126 | input_image = input_image[:, i:i + crop_h, j:j + crop_w] 127 | target_image = target_image[:, i:i + crop_h, j:j + crop_w] 128 | 129 | # Resize to the desired size 130 | # input_image = TF.resize(input_image, self.resize_size) 131 | # target_image = TF.resize(target_image, self.resize_size) 132 | 133 | # Normalize images to [-1, 1] 134 | input_image = (input_image * 2) - 1 135 | target_image = (target_image * 2) - 1 136 | 137 | return input_image, target_image 138 | 139 | # Data module configuration remains the same 140 | @dataclass 141 | class ImageDataModuleConfig: 142 | data_dir: str = "datasets/sid" # Path to the dataset directory 143 | image_size_crop: Tuple[int, int] = (128, 128) # Size for random cropping 144 | image_size_resize: Tuple[int, int] = (128, 128) # Resize to 128x128 145 | batch_size: int = 34 # Number of images in each batch 146 | num_workers: int = 28 # Number of worker threads for data loading 147 | pin_memory: bool = True # Whether to pin memory in data loader 148 | persistent_workers: bool = True # Keep workers alive between epochs 149 | 150 | # New 
SIDDataModule class 151 | class SIDDataModule(LightningDataModule): 152 | def __init__(self, config: ImageDataModuleConfig) -> None: 153 | super().__init__() 154 | self.config = config 155 | 156 | def setup(self, stage: str = None) -> None: 157 | # Define transforms 158 | self.transform = T.Lambda(lambda x: (x * 2) - 1) # Normalize to [-1, 1] 159 | 160 | self.dataset = SIDDataset( 161 | txt_file=os.path.join(self.config.data_dir, "Sony_train_list.txt"), 162 | root_dir=self.config.data_dir, 163 | crop_size=self.config.image_size_crop, 164 | resize_size=self.config.image_size_resize 165 | ) 166 | # self.val_dataset = SIDValDataset( 167 | # txt_file=os.path.join(self.config.data_dir, "Sony_val_list.txt"), 168 | # root_dir=self.config.data_dir, 169 | # resize_size=self.config.image_size_resize 170 | # ) 171 | 172 | def train_dataloader(self) -> DataLoader: 173 | return DataLoader( 174 | self.dataset, 175 | batch_size=self.config.batch_size, 176 | shuffle=True, 177 | num_workers=self.config.num_workers, 178 | pin_memory=self.config.pin_memory, 179 | persistent_workers=self.config.persistent_workers, 180 | ) 181 | 182 | def GroupNorm(channels: int) -> nn.GroupNorm: 183 | return nn.GroupNorm(num_groups=min(32, channels // 4), num_channels=channels) 184 | 185 | 186 | class SelfAttention(nn.Module): 187 | def __init__( 188 | self, 189 | in_channels: int, 190 | out_channels: int, 191 | n_heads: int = 8, 192 | dropout: float = 0.3, 193 | ) -> None: 194 | super().__init__() 195 | 196 | self.dropout = dropout 197 | 198 | self.qkv_projection = nn.Sequential( 199 | GroupNorm(in_channels), 200 | nn.Conv2d(in_channels, 3 * in_channels, kernel_size=1, bias=False), 201 | Rearrange("b (i h d) x y -> i b h (x y) d", i=3, h=n_heads), 202 | ) 203 | self.output_projection = nn.Sequential( 204 | Rearrange("b h l d -> b l (h d)"), 205 | nn.Linear(in_channels, out_channels, bias=False), 206 | Rearrange("b l d -> b d l"), 207 | GroupNorm(out_channels), 208 | nn.Dropout1d(dropout), 209 | ) 210 | self.residual_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1) 211 | 212 | def forward(self, x: Tensor) -> Tensor: 213 | q, k, v = self.qkv_projection(x).unbind(dim=0) 214 | 215 | output = F.scaled_dot_product_attention( 216 | q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=False 217 | ) 218 | output = self.output_projection(output) 219 | output = rearrange(output, "b c (x y) -> b c x y", x=x.shape[-2], y=x.shape[-1]) 220 | 221 | return output + self.residual_projection(x) 222 | 223 | 224 | class UNetBlock(nn.Module): 225 | def __init__( 226 | self, 227 | in_channels: int, 228 | out_channels: int, 229 | noise_level_channels: int, 230 | dropout: float = 0.3, 231 | ) -> None: 232 | super().__init__() 233 | 234 | self.input_projection = nn.Sequential( 235 | GroupNorm(in_channels), 236 | nn.SiLU(), 237 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding="same"), 238 | nn.Dropout2d(dropout), 239 | ) 240 | self.noise_level_projection = nn.Sequential( 241 | nn.SiLU(), 242 | nn.Conv2d(noise_level_channels, out_channels, kernel_size=1), 243 | ) 244 | self.output_projection = nn.Sequential( 245 | GroupNorm(out_channels), 246 | nn.SiLU(), 247 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding="same"), 248 | nn.Dropout2d(dropout), 249 | ) 250 | self.residual_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1) 251 | 252 | def forward(self, x: Tensor, noise_level: Tensor) -> Tensor: 253 | h = self.input_projection(x) 254 | h = h + 
self.noise_level_projection(noise_level) 255 | 256 | return self.output_projection(h) + self.residual_projection(x) 257 | 258 | 259 | class UNetBlockWithSelfAttention(nn.Module): 260 | def __init__( 261 | self, 262 | in_channels: int, 263 | out_channels: int, 264 | noise_level_channels: int, 265 | n_heads: int = 8, 266 | dropout: float = 0.3, 267 | ) -> None: 268 | super().__init__() 269 | 270 | self.unet_block = UNetBlock( 271 | in_channels, out_channels, noise_level_channels, dropout 272 | ) 273 | self.self_attention = SelfAttention( 274 | out_channels, out_channels, n_heads, dropout 275 | ) 276 | 277 | def forward(self, x: Tensor, noise_level: Tensor) -> Tensor: 278 | return self.self_attention(self.unet_block(x, noise_level)) 279 | 280 | 281 | class Downsample(nn.Module): 282 | def __init__(self, channels: int) -> None: 283 | super().__init__() 284 | 285 | self.projection = nn.Sequential( 286 | Rearrange("b c (h ph) (w pw) -> b (c ph pw) h w", ph=2, pw=2), 287 | nn.Conv2d(4 * channels, channels, kernel_size=1), 288 | ) 289 | 290 | def forward(self, x: Tensor) -> Tensor: 291 | return self.projection(x) 292 | 293 | 294 | class Upsample(nn.Module): 295 | def __init__(self, channels: int) -> None: 296 | super().__init__() 297 | 298 | self.projection = nn.Sequential( 299 | nn.Upsample(scale_factor=2.0, mode="nearest"), 300 | nn.Conv2d(channels, channels, kernel_size=3, padding="same"), 301 | ) 302 | 303 | def forward(self, x: Tensor) -> Tensor: 304 | return self.projection(x) 305 | 306 | 307 | class NoiseLevelEmbedding(nn.Module): 308 | def __init__(self, channels: int, scale: float = 0.02) -> None: 309 | super().__init__() 310 | 311 | self.W = nn.Parameter(torch.randn(channels // 2) * scale, requires_grad=False) 312 | 313 | self.projection = nn.Sequential( 314 | nn.Linear(channels, 4 * channels), 315 | nn.SiLU(), 316 | nn.Linear(4 * channels, channels), 317 | Rearrange("b c -> b c () ()"), 318 | ) 319 | 320 | def forward(self, x: Tensor) -> Tensor: 321 | h = x[:, None] * self.W[None, :] * 2 * torch.pi 322 | h = torch.cat([torch.sin(h), torch.cos(h)], dim=-1) 323 | 324 | return self.projection(h) 325 | 326 | 327 | @dataclass 328 | class UNetConfig: 329 | channels: int = 3 330 | noise_level_channels: int = 256 331 | noise_level_scale: float = 0.02 332 | n_heads: int = 8 333 | top_blocks_channels: Tuple[int, ...] = (128, 128) 334 | top_blocks_n_blocks_per_resolution: Tuple[int, ...] = (2, 2) 335 | top_blocks_has_resampling: Tuple[bool, ...] = (True, True) 336 | top_blocks_dropout: Tuple[float, ...] = (0.0, 0.0) 337 | mid_blocks_channels: Tuple[int, ...] = (256, 512) 338 | mid_blocks_n_blocks_per_resolution: Tuple[int, ...] = (4, 4) 339 | mid_blocks_has_resampling: Tuple[bool, ...] = (True, False) 340 | mid_blocks_dropout: Tuple[float, ...] 
= (0.0, 0.3) 341 | 342 | 343 | class UNet(nn.Module): 344 | def __init__(self, config: UNetConfig) -> None: 345 | super().__init__() 346 | 347 | self.config = config 348 | 349 | self.input_projection = nn.Conv2d( 350 | config.channels * 2, 351 | config.top_blocks_channels[0], 352 | kernel_size=3, 353 | padding="same", 354 | ) 355 | self.noise_level_embedding = NoiseLevelEmbedding( 356 | config.noise_level_channels, config.noise_level_scale 357 | ) 358 | self.top_encoder_blocks = self._make_encoder_blocks( 359 | self.config.top_blocks_channels + self.config.mid_blocks_channels[:1], 360 | self.config.top_blocks_n_blocks_per_resolution, 361 | self.config.top_blocks_has_resampling, 362 | self.config.top_blocks_dropout, 363 | self._make_top_block, 364 | ) 365 | self.mid_encoder_blocks = self._make_encoder_blocks( 366 | self.config.mid_blocks_channels + self.config.mid_blocks_channels[-1:], 367 | self.config.mid_blocks_n_blocks_per_resolution, 368 | self.config.mid_blocks_has_resampling, 369 | self.config.mid_blocks_dropout, 370 | self._make_mid_block, 371 | ) 372 | self.mid_decoder_blocks = self._make_decoder_blocks( 373 | self.config.mid_blocks_channels + self.config.mid_blocks_channels[-1:], 374 | self.config.mid_blocks_n_blocks_per_resolution, 375 | self.config.mid_blocks_has_resampling, 376 | self.config.mid_blocks_dropout, 377 | self._make_mid_block, 378 | ) 379 | self.top_decoder_blocks = self._make_decoder_blocks( 380 | self.config.top_blocks_channels + self.config.mid_blocks_channels[:1], 381 | self.config.top_blocks_n_blocks_per_resolution, 382 | self.config.top_blocks_has_resampling, 383 | self.config.top_blocks_dropout, 384 | self._make_top_block, 385 | ) 386 | self.output_projection = nn.Conv2d( 387 | config.top_blocks_channels[0], 388 | config.channels, 389 | kernel_size=3, 390 | padding="same", 391 | ) 392 | 393 | def forward(self, x: Tensor, noise_level: Tensor, v: Tensor) -> Tensor: 394 | x = torch.cat([x, v], dim = 1) 395 | h = self.input_projection(x) 396 | noise_level = self.noise_level_embedding(noise_level) 397 | 398 | top_encoder_embeddings = [] 399 | for block in self.top_encoder_blocks: 400 | if isinstance(block, UNetBlock): 401 | h = block(h, noise_level) 402 | top_encoder_embeddings.append(h) 403 | else: 404 | h = block(h) 405 | 406 | mid_encoder_embeddings = [] 407 | for block in self.mid_encoder_blocks: 408 | if isinstance(block, UNetBlockWithSelfAttention): 409 | h = block(h, noise_level) 410 | mid_encoder_embeddings.append(h) 411 | else: 412 | h = block(h) 413 | 414 | for block in self.mid_decoder_blocks: 415 | if isinstance(block, UNetBlockWithSelfAttention): 416 | h = torch.cat((h, mid_encoder_embeddings.pop()), dim=1) 417 | h = block(h, noise_level) 418 | else: 419 | h = block(h) 420 | 421 | for block in self.top_decoder_blocks: 422 | if isinstance(block, UNetBlock): 423 | h = torch.cat((h, top_encoder_embeddings.pop()), dim=1) 424 | h = block(h, noise_level) 425 | else: 426 | h = block(h) 427 | 428 | output = self.output_projection(h) 429 | 430 | # Concatenate the infrared output with a 3-channel tensor of zeros 431 | # zero_channels = torch.zeros_like(output) 432 | 433 | # return torch.cat([output, zero_channels], dim=1) 434 | return output 435 | 436 | def _make_encoder_blocks( 437 | self, 438 | channels: Tuple[int, ...], 439 | n_blocks_per_resolution: Tuple[int, ...], 440 | has_resampling: Tuple[bool, ...], 441 | dropout: Tuple[float, ...], 442 | block_fn: Callable[[], nn.Module], 443 | ) -> nn.ModuleList: 444 | blocks = nn.ModuleList() 445 | 446 | 
channel_pairs = list(zip(channels[:-1], channels[1:])) 447 | for idx, (in_channels, out_channels) in enumerate(channel_pairs): 448 | for _ in range(n_blocks_per_resolution[idx]): 449 | blocks.append(block_fn(in_channels, out_channels, dropout[idx])) 450 | in_channels = out_channels 451 | 452 | if has_resampling[idx]: 453 | blocks.append(Downsample(out_channels)) 454 | 455 | return blocks 456 | 457 | def _make_decoder_blocks( 458 | self, 459 | channels: Tuple[int, ...], 460 | n_blocks_per_resolution: Tuple[int, ...], 461 | has_resampling: Tuple[bool, ...], 462 | dropout: Tuple[float, ...], 463 | block_fn: Callable[[], nn.Module], 464 | ) -> nn.ModuleList: 465 | blocks = nn.ModuleList() 466 | 467 | channel_pairs = list(zip(channels[:-1], channels[1:]))[::-1] 468 | for idx, (out_channels, in_channels) in enumerate(channel_pairs): 469 | if has_resampling[::-1][idx]: 470 | blocks.append(Upsample(in_channels)) 471 | 472 | inner_blocks = [] 473 | for _ in range(n_blocks_per_resolution[::-1][idx]): 474 | inner_blocks.append( 475 | block_fn(in_channels * 2, out_channels, dropout[::-1][idx]) 476 | ) 477 | out_channels = in_channels 478 | blocks.extend(inner_blocks[::-1]) 479 | 480 | return blocks 481 | 482 | def _make_top_block( 483 | self, in_channels: int, out_channels: int, dropout: float 484 | ) -> UNetBlock: 485 | return UNetBlock( 486 | in_channels, 487 | out_channels, 488 | self.config.noise_level_channels, 489 | dropout, 490 | ) 491 | 492 | def _make_mid_block( 493 | self, 494 | in_channels: int, 495 | out_channels: int, 496 | dropout: float, 497 | ) -> UNetBlockWithSelfAttention: 498 | return UNetBlockWithSelfAttention( 499 | in_channels, 500 | out_channels, 501 | self.config.noise_level_channels, 502 | self.config.n_heads, 503 | dropout, 504 | ) 505 | 506 | def save_pretrained(self, pretrained_path: str) -> None: 507 | os.makedirs(pretrained_path, exist_ok=True) 508 | 509 | with open(os.path.join(pretrained_path, "config.json"), mode="w") as f: 510 | json.dump(asdict(self.config), f) 511 | 512 | torch.save(self.state_dict(), os.path.join(pretrained_path, "model.pt")) 513 | 514 | @classmethod 515 | def from_pretrained(cls, pretrained_path: str) -> "UNet": 516 | with open(os.path.join(pretrained_path, "config.json"), mode="r") as f: 517 | config_dict = json.load(f) 518 | config = UNetConfig(**config_dict) 519 | 520 | model = cls(config) 521 | 522 | state_dict = torch.load( 523 | os.path.join(pretrained_path, "model.pt"), map_location=torch.device("cpu") 524 | ) 525 | model.load_state_dict(state_dict) 526 | 527 | return model 528 | 529 | 530 | # summary(UNet(UNetConfig()), input_size=((1, 6, 32, 32), (1,))) 531 | 532 | 533 | @dataclass 534 | class LitImprovedConsistencyModelConfig: 535 | ema_decay_rate: float = 0.99993 536 | lr: float = 1e-4 537 | betas: Tuple[float, float] = (0.9, 0.995) 538 | lr_scheduler_start_factor: float = 1e-5 539 | lr_scheduler_iters: int = 10_000 540 | sample_every_n_steps: int = 10_000 541 | num_samples: int = 8 542 | sampling_sigmas: Tuple[Tuple[int, ...], ...] 
= ( 543 | (80,), 544 | (80.0, 0.661), 545 | (80.0, 24.4, 5.84, 0.9, 0.661), 546 | ) 547 | 548 | 549 | class LitImprovedConsistencyModel(LightningModule): 550 | def __init__( 551 | self, 552 | consistency_training: ImprovedConsistencyTraining, 553 | consistency_sampling: ConsistencySamplingAndEditing, 554 | model: UNet, 555 | ema_model: UNet, 556 | config: LitImprovedConsistencyModelConfig, 557 | ) -> None: 558 | super().__init__() 559 | 560 | self.consistency_training = consistency_training 561 | self.consistency_sampling = consistency_sampling 562 | self.model = model 563 | self.ema_model = ema_model 564 | self.config = config 565 | 566 | # Freeze the EMA model and set it to eval mode 567 | for param in self.ema_model.parameters(): 568 | param.requires_grad = False 569 | self.ema_model = self.ema_model.eval() 570 | 571 | def training_step(self, batch: Union[Tensor, List[Tensor]], batch_idx: int) -> None: 572 | # if isinstance(batch, list): 573 | # batch = batch[0] 574 | 575 | visible_images, infrared_images = batch # Unpack the batch 576 | 577 | output = self.consistency_training( 578 | self.model, 579 | infrared_images, 580 | visible_images, # Pass visible images to the training function 581 | self.global_step, 582 | self.trainer.max_steps 583 | ) 584 | 585 | loss = ( 586 | pseudo_huber_loss(output.predicted, output.target) * output.loss_weights 587 | ).mean() 588 | 589 | self.log_dict({"train_loss": loss, "num_timesteps": output.num_timesteps}) 590 | 591 | return loss 592 | 593 | def on_train_batch_end( 594 | self, outputs: Any, batch: Union[Tensor, List[Tensor]], batch_idx: int 595 | ) -> None: 596 | update_ema_model_(self.model, self.ema_model, self.config.ema_decay_rate) 597 | 598 | if ( 599 | (self.global_step + 1) % self.config.sample_every_n_steps == 0 600 | ) or self.global_step == 0: 601 | self.__sample_and_log_samples(batch) 602 | 603 | def configure_optimizers(self): 604 | opt = torch.optim.Adam( 605 | self.model.parameters(), lr=self.config.lr, betas=self.config.betas 606 | ) 607 | sched = torch.optim.lr_scheduler.LinearLR( 608 | opt, 609 | start_factor=self.config.lr_scheduler_start_factor, 610 | total_iters=self.config.lr_scheduler_iters, 611 | ) 612 | sched = {"scheduler": sched, "interval": "step", "frequency": 1} 613 | 614 | return [opt], [sched] 615 | 616 | @torch.no_grad() 617 | def __sample_and_log_samples(self, batch: Union[Tensor, List[Tensor]]) -> None: 618 | if isinstance(batch, list): 619 | batch = batch[0] 620 | 621 | # Ensure the number of samples does not exceed the batch size 622 | num_samples = min(self.config.num_samples, batch.shape[0]) 623 | noise = torch.randn_like(batch[:num_samples]) 624 | 625 | # Log ground truth samples 626 | self.__log_images( 627 | batch[:num_samples].detach().clone(), f"ground_truth", self.global_step 628 | ) 629 | 630 | for sigmas in self.config.sampling_sigmas: 631 | samples = self.consistency_sampling( 632 | self.ema_model, noise, sigmas, clip_denoised=True, verbose=True 633 | ) 634 | samples = samples.clamp(min=-1.0, max=1.0) 635 | 636 | # Generated samples 637 | self.__log_images( 638 | samples, 639 | f"generated_samples-sigmas={sigmas}", 640 | self.global_step, 641 | ) 642 | 643 | @torch.no_grad() 644 | def __log_images(self, images: Tensor, title: str, global_step: int) -> None: 645 | images = images.detach().float() 646 | 647 | grid = make_grid( 648 | images.clamp(-1.0, 1.0), value_range=(-1.0, 1.0), normalize=True 649 | ) 650 | self.logger.experiment.add_image(title, grid, global_step) 651 | 652 | def 
on_train_epoch_end(self) -> None: 653 | # Retrieve the loss logged during the epoch 654 | train_loss = self.trainer.callback_metrics.get("train_loss", None) 655 | 656 | if train_loss is not None: 657 | print(f"Epoch {self.current_epoch} - Train Loss: {train_loss:.4f}") 658 | 659 | 660 | @dataclass 661 | class TrainingConfig: 662 | image_dm_config: ImageDataModuleConfig 663 | unet_config: UNetConfig 664 | consistency_training: ImprovedConsistencyTraining 665 | consistency_sampling: ConsistencySamplingAndEditing 666 | lit_icm_config: LitImprovedConsistencyModelConfig 667 | trainer: Trainer 668 | model_ckpt_path: str = "checkpoints/sid" 669 | seed: int = 42 670 | 671 | def run_training(config: TrainingConfig) -> None: 672 | # Set seed 673 | seed_everything(config.seed) 674 | 675 | # Create data module 676 | dm = SIDDataModule(config.image_dm_config) 677 | dm.setup() 678 | print("DataModule setup complete.") 679 | 680 | # Create model and its EMA 681 | model = UNet(config.unet_config) 682 | ema_model = UNet(config.unet_config) 683 | ema_model.load_state_dict(model.state_dict()) 684 | 685 | # Create lightning module 686 | lit_icm = LitImprovedConsistencyModel( 687 | config.consistency_training, 688 | config.consistency_sampling, 689 | model, 690 | ema_model, 691 | config.lit_icm_config, 692 | ) 693 | 694 | print("Lightning module created.") 695 | 696 | # Run training 697 | print("Starting training...") 698 | config.trainer.fit(lit_icm, datamodule=dm) 699 | print("Training completed.") 700 | 701 | # Save model 702 | lit_icm.model.save_pretrained(config.model_ckpt_path) 703 | print("Model saved.") 704 | 705 | def main(): 706 | # Define the checkpoint callback 707 | checkpoint_callback = ModelCheckpoint( 708 | dirpath="checkpoints_sid", 709 | filename="{epoch}-{step}", 710 | save_top_k=-1, # Save all checkpoints 711 | every_n_epochs=20, # Adjust as needed 712 | ) 713 | 714 | # Set up the logger 715 | logger = TensorBoardLogger("logs", name="sid") 716 | 717 | training_config = TrainingConfig( 718 | image_dm_config=ImageDataModuleConfig(data_dir="../datasets/sid"), 719 | unet_config=UNetConfig(), 720 | consistency_training=ImprovedConsistencyTraining(final_timesteps=11), 721 | consistency_sampling=ConsistencySamplingAndEditing(), 722 | lit_icm_config=LitImprovedConsistencyModelConfig( 723 | sample_every_n_steps=2100000, lr_scheduler_iters=1000 724 | ), 725 | trainer=Trainer( 726 | max_steps=100, 727 | precision="16", 728 | log_every_n_steps=10, 729 | logger=logger, 730 | callbacks=[ 731 | LearningRateMonitor(logging_interval="step"), 732 | checkpoint_callback, # Add the checkpoint callback here 733 | ], 734 | ), 735 | ) 736 | run_training(training_config) 737 | 738 | if __name__ == "__main__": 739 | main() --------------------------------------------------------------------------------
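
Note (not part of the repository): the SID training script above wires `SIDDataset` → `SIDDataModule` → `LitImprovedConsistencyModel`, with short-exposure RAW frames as the conditioning input and long-exposure frames as the target. The snippet below is a minimal, hypothetical sanity check of that data pipeline only, not code from the repo; it assumes the Sony SID data and `Sony_train_list.txt` have been extracted under `datasets/sid`, that `rawpy` is installed via `requirements.txt`, and that it is run from the repository root (adjust `data_dir` otherwise).

```python
# Hypothetical sanity check for the SID data pipeline defined in sid/script.py.
# Assumptions (not from the repo): run from the repository root, with the
# Sony SID data and Sony_train_list.txt extracted under datasets/sid.
from sid.script import ImageDataModuleConfig, SIDDataModule

config = ImageDataModuleConfig(
    data_dir="datasets/sid",  # adjust to wherever the dataset actually lives
    batch_size=2,             # small values for a quick local check
    num_workers=2,
)

dm = SIDDataModule(config)
dm.setup()

# Each batch is (short_exposure, long_exposure): RAW frames postprocessed with
# rawpy, exposure-ratio amplified (short frame only), randomly cropped to
# 128x128, and normalized to [-1, 1].
short_exposure, long_exposure = next(iter(dm.train_dataloader()))
print(tuple(short_exposure.shape), tuple(long_exposure.shape))   # expect (2, 3, 128, 128) twice
print(short_exposure.min().item(), short_exposure.max().item())  # expect values within [-1, 1]
```

Because `SIDDataset` multiplies the short exposure by the long/short exposure-time ratio (capped at 300) before clipping and normalization, the conditioning input is already roughly brightness-matched to the long-exposure target when it reaches the model.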