├── README.md ├── evaluation ├── FID_KID.ipynb ├── LPIPS.ipynb ├── README.md └── segmentation │ ├── run_segmentation.ipynb │ └── transforms.py ├── inference ├── README.md ├── api_inference_controlnet11_tissues_01_49.ipynb ├── api_inference_controlnet11_tissues_25_66.ipynb └── api_inference_controlnet11_tissues_52_56.ipynb ├── preprocessing ├── README.md ├── Texture-Synthesis │ ├── README.md │ ├── tex_01_49 │ │ ├── 01_a.png │ │ ├── 01_b.png │ │ ├── 02_a.png │ │ ├── 02_b.png │ │ ├── 04_a.png │ │ ├── 04_b.png │ │ ├── 10_a.png │ │ ├── 10_b.png │ │ ├── 12_a.png │ │ ├── 12_b.png │ │ ├── tex01_a.png │ │ ├── tex01_a_rand.png │ │ ├── tex01_b.png │ │ ├── tex01_b_rand.png │ │ ├── tex02_a.png │ │ ├── tex02_a_rand.png │ │ ├── tex02_b.png │ │ ├── tex02_b_rand.png │ │ ├── tex04_a.png │ │ ├── tex04_a_rand.png │ │ ├── tex04_b.png │ │ ├── tex04_b_rand.png │ │ ├── tex10_a.png │ │ ├── tex10_a_rand.png │ │ ├── tex10_b.png │ │ ├── tex10_b_rand.png │ │ ├── tex12_a.png │ │ ├── tex12_a_rand.png │ │ ├── tex12_b.png │ │ └── tex12_b_rand.png │ ├── tex_25_66 │ │ ├── 01_a.png │ │ ├── 01_b.png │ │ ├── 02_a.png │ │ ├── 02_b.png │ │ ├── 04_a.png │ │ ├── 04_b.png │ │ ├── 10_a.png │ │ ├── 10_b.png │ │ ├── 12_a.png │ │ ├── 12_b.png │ │ ├── tex01_a.png │ │ ├── tex01_a_rand.png │ │ ├── tex01_b.png │ │ ├── tex01_b_rand.png │ │ ├── tex02_a.png │ │ ├── tex02_a_rand.png │ │ ├── tex02_b.png │ │ ├── tex02_b_rand.png │ │ ├── tex04_a.png │ │ ├── tex04_a_rand.png │ │ ├── tex04_b.png │ │ ├── tex04_b_rand.png │ │ ├── tex10_a.png │ │ ├── tex10_a_rand.png │ │ ├── tex10_b.png │ │ ├── tex10_b_rand.png │ │ ├── tex12_a.png │ │ ├── tex12_a_rand.png │ │ ├── tex12_b.png │ │ └── tex12_b_rand.png │ ├── tex_52_56 │ │ ├── 01_a.png │ │ ├── 01_b.png │ │ ├── 02_a.png │ │ ├── 02_b.png │ │ ├── 04_a.png │ │ ├── 04_b.png │ │ ├── 10_a.png │ │ ├── 10_b.png │ │ ├── 12_a.png │ │ ├── 12_b.png │ │ ├── tex01_a.png │ │ ├── tex01_a_rand.png │ │ ├── tex01_b.png │ │ ├── tex01_b_rand.png │ │ ├── tex02_a.png │ │ ├── tex02_a_rand.png │ │ ├── tex02_b.png │ │ ├── tex02_b_old.png │ │ ├── tex02_b_rand.png │ │ ├── tex02_b_rand_old.png │ │ ├── tex04_a.png │ │ ├── tex04_a_rand.png │ │ ├── tex04_b.png │ │ ├── tex04_b_rand.png │ │ ├── tex10_a.png │ │ ├── tex10_a_rand.png │ │ ├── tex10_b.png │ │ ├── tex10_b_rand.png │ │ ├── tex12_a.png │ │ ├── tex12_a_rand.png │ │ ├── tex12_b.png │ │ └── tex12_b_rand.png │ └── texture_synthesis.py ├── add_textures.ipynb └── prepare_segmentation_maps.ipynb ├── requirements.txt ├── sim2real_asset.png └── training ├── README.md ├── convert_diffusers_to_sd.py ├── train_dreambooth.py └── train_dreambooth_commands.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Minimal Data Requirement for Realistic Endoscopic Image Generation with Stable Diffusion 2 | 3 | [![Paper](https://img.shields.io/badge/Read_Paper-Available-green)](https://link.springer.com/article/10.1007/s11548-023-03030-w) [![Dataset](https://img.shields.io/badge/Dataset-Download-blue)](https://sanoscience-my.sharepoint.com/:f:/g/personal/j_kaleta_sanoscience_org/Eu61v4XSSvZJrAvnkpVBNukBG6CTcGmqLoySF-lT6_sAcQ?e=kIMkfh) 4 | 5 | This is an official repository for **Minimal Data Requirement for Realistic Endoscopic Image Generation with Stable Diffusion**. 6 | ![Sim2Real results](sim2real_asset.png) 7 | ## Overview 8 | This repository provides resources and workflows for generating realistic endoscopic images using Stable Diffusion. The methodology is divided into four key steps: 9 | 10 | 1. 
**Training** 11 | Train Stable Diffusion for endoscopic image generation. 12 | 2. **Preprocessing** 13 | Prepare raw simulation images for inference. 14 | 3. **Inference** 15 | Generate realistic endoscopic images using the trained model. 16 | 4. **Evaluation** 17 | Evaluate the quality and realism of the generated images. 18 | 19 | Each step has a detailed README in its respective folder. 20 | 21 | ## Resources 22 | Access all resources, including datasets, pretrained models, and scripts: [Dataset and Resources](https://sanoscience-my.sharepoint.com/:f:/g/personal/j_kaleta_sanoscience_org/Eu61v4XSSvZJrAvnkpVBNukBG6CTcGmqLoySF-lT6_sAcQ?e=kIMkfh) 23 | 24 | ## Citation 25 | 26 | If you use this work in your research, please cite it as follows: 27 | 28 | ```bibtex 29 | @article{Kaleta2024, 30 | title={Minimal data requirement for realistic endoscopic image generation with Stable Diffusion}, 31 | author={Kaleta, Joanna and Dall'Alba, Diego and Płotka, Szymon and Korzeniowski, Przemysław}, 32 | journal={International Journal of Computer Assisted Radiology and Surgery}, 33 | volume={19}, 34 | number={3}, 35 | pages={531--539}, 36 | year={2024}, 37 | doi={10.1007/s11548-023-03030-w}, 38 | url={https://doi.org/10.1007/s11548-023-03030-w} 39 | } 40 | -------------------------------------------------------------------------------- /evaluation/FID_KID.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# FID & KID computation" 9 | ] 10 | }, 11 | { 12 | "attachments": {}, 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Copy 10k CHOLECT45 data from https://github.com/CAMMA-public/cholect45 to dest_path. \n", 17 | "### Comment out cells below if you need to prepare a folder with real data.\n", 18 | "### Copy with margin becasue some frames are all black. Exclude videos 01, 49, 25, 66, 52, 56. 
" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "# import os, shutil\n", 28 | "# import random \n", 29 | "# import glob\n", 30 | "# from tqdm import tqdm\n", 31 | "# import random\n", 32 | "# import os\n", 33 | "# from PIL import Image\n", 34 | "# from torchvision.transforms import functional as F\n", 35 | "# from torchvision.transforms import InterpolationMode\n", 36 | "# import torchvision.transforms as transforms" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 1, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "name": "stderr", 46 | "output_type": "stream", 47 | "text": [ 48 | "100%|██████████| 39/39 [01:28<00:00, 2.26s/it]\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "# n=10500 #draw more to exclude all black frames\n", 54 | "\n", 55 | "# cholect_path = '/path/to/real_data'\n", 56 | "# dest_path = 'dest/path/to/real_data'\n", 57 | "# os.makedirs(dest_path, exist_ok=True)\n", 58 | "\n", 59 | "# items=[vid for vid in os.listdir(cholect_path) if all(t not in vid for t in ['01', '49', '25', '66', '52', '56']) ]\n", 60 | "\n", 61 | "# for i in tqdm(items):\n", 62 | "# imgs = [img for img in os.listdir(os.path.join(cholect_path, i)) if \".png\" in img]\n", 63 | "# frames_to_sample = n//len(items)\n", 64 | "# imgs = random.sample(imgs, frames_to_sample+1)\n", 65 | "# for img in imgs:\n", 66 | "# src_path = os.path.join(cholect_path, i, img)\n", 67 | "# out_path = os.path.join(dest_path, str(i +\"_\" + img))\n", 68 | "# shutil.copy(src_path, out_path)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "### Remove black frames" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 3, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "name": "stderr", 85 | "output_type": "stream", 86 | "text": [ 87 | "100%|██████████| 10530/10530 [01:37<00:00, 107.88it/s]\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "\n", 93 | "# def is_almost_black(image_path, threshold=10):\n", 94 | "# \"\"\"\n", 95 | "# Check if an image is almost black by comparing pixel values with a threshold.\n", 96 | "# \"\"\"\n", 97 | "# image = Image.open(image_path).convert('L') # Open the image and convert it to grayscale\n", 98 | "# pixels = image.getdata()\n", 99 | "# for pixel in pixels:\n", 100 | "# if pixel > threshold:\n", 101 | "# return False\n", 102 | "# return True\n", 103 | "\n", 104 | "# def remove_black_images(folder_path):\n", 105 | "# \"\"\"\n", 106 | "# Remove black and almost black images from a folder.\n", 107 | "# \"\"\"\n", 108 | "# for filename in tqdm(os.listdir(folder_path)):\n", 109 | "# if filename.endswith('.jpg') or filename.endswith('.png'): # Adjust the file extensions as needed\n", 110 | "# image_path = os.path.join(folder_path, filename)\n", 111 | "# if is_almost_black(image_path):\n", 112 | "# os.remove(image_path)\n", 113 | "\n", 114 | "# # Example usage\n", 115 | "# remove_black_images(dest_path)\n" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "### Remove excessive images to keep equal 10k needed for FID, KID" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 4, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "# file_list = os.listdir(dest_path)\n", 132 | "# files_to_delete = random.sample(file_list, len(file_list) - 10000)\n", 133 | "\n", 134 | "# for file_name in files_to_delete:\n", 135 | "# file_path = 
os.path.join(dest_path, file_name)\n", 136 | "# os.remove(file_path)\n" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "### Resize and crop images, save to separate folder" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 8, 149 | "metadata": {}, 150 | "outputs": [ 151 | { 152 | "name": "stderr", 153 | "output_type": "stream", 154 | "text": [ 155 | "100%|██████████| 10000/10000 [06:53<00:00, 24.20it/s]\n" 156 | ] 157 | } 158 | ], 159 | "source": [ 160 | "# dest_path_cropped = dest_path +\"_cropped\"\n", 161 | "\n", 162 | "# os.makedirs(folder_cropped, exist_ok=True)\n", 163 | "\n", 164 | "# for filename in tqdm.tqdm(os.listdir(dest_path)):\n", 165 | "# if filename.endswith('.jpeg') or filename.endswith('.png') or filename.endswith('.jpg'):\n", 166 | "# # Open the image file\n", 167 | "# image = Image.open(os.path.join(dest_path, filename)).convert(\"RGB\")\n", 168 | "# image = transforms.Resize(299)(image)\n", 169 | "# image = transforms.CenterCrop(299)(image)\n", 170 | "# # Save the cropped image with the same name as the original\n", 171 | "# image.save(os.path.join(dest_path_cropped, filename))" 172 | ] 173 | }, 174 | { 175 | "attachments": {}, 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "## Calculate FID and KID \n", 180 | "### We first resize synthetic images to 299 and copy them to tmp folder - we always do it that way to have unified pipeline (resizing impacts fid & kid values, we want to have full control)." 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "# !pip install torch-fidelity" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "import shutil\n", 199 | "import os\n", 200 | "import tqdm\n", 201 | "from PIL import Image\n", 202 | "from torchvision import transforms\n", 203 | "import random\n", 204 | "import glob\n", 205 | "import torch_fidelity" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "real_data_path = 'path/to/real/data/cropped' # dest_path_cropped from previous cells\n", 215 | "data_path_25_66 = '/path/to/data' #output_vid25_66\n", 216 | "data_path_01_49 = '/path/to/data' #output_vid01_49\n", 217 | "data_path_52_56 = '/path/to/data' #output_vid52_56\n", 218 | "\n", 219 | "# calculate FID & KID for mixed style\n", 220 | "rand_dict = {1: data_path_25_66,\n", 221 | " 2: data_path_01_49,\n", 222 | " 3: data_path_52_56}\n", 223 | "\n", 224 | "filenames_common =set(os.listdir(rand_dict[1]))& set(os.listdir(rand_dict[2])) & set(os.listdir(rand_dict[3]))\n", 225 | "filenames = [os.path.join(rand_dict[random.randint(1, 3)], f) for f in filenames_common]\n", 226 | "\n", 227 | "random.seed(420420)\n", 228 | "filenames = random.sample(filenames,10000)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "# save images in a tmp folder. 
It is important to keep exactly the same resizing method.\n", 238 | "tmp_folder = \"/path/to/tmp/folder\"\n", 239 | "\n", 240 | "if os.path.exists(tmp_folder):\n", 241 | " shutil.rmtree(tmp_folder)\n", 242 | "os.makedirs(tmp_folder, exist_ok=True)\n", 243 | "\n", 244 | "for filename in tqdm.tqdm(filenames):\n", 245 | " if filename.endswith('.jpeg') or filename.endswith('.png') or filename.endswith('.jpg'):\n", 246 | " # Open the image file\n", 247 | " image = Image.open(os.path.join(filename)).convert(\"RGB\")\n", 248 | " image = transforms.Resize(299)(image)\n", 249 | " # Save the cropped image with the same name as the original\n", 250 | " image.save(os.path.join(tmp_folder, os.path.basename(filename)))" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "# calculate FID and KID\n", 260 | "metrics_dict = torch_fidelity.calculate_metrics(\n", 261 | " input1=real_data_path,\n", 262 | " input2=tmp_folder, \n", 263 | " cuda=True, \n", 264 | " isc=False, \n", 265 | " fid=True, \n", 266 | " kid=True, \n", 267 | " prc=False, \n", 268 | " verbose=True,\n", 269 | " kid_subset_size =1000,\n", 270 | " cache = False\n", 271 | ")\n", 272 | "\n" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [] 281 | } 282 | ], 283 | "metadata": { 284 | "kernelspec": { 285 | "display_name": "fid", 286 | "language": "python", 287 | "name": "python3" 288 | }, 289 | "language_info": { 290 | "codemirror_mode": { 291 | "name": "ipython", 292 | "version": 3 293 | }, 294 | "file_extension": ".py", 295 | "mimetype": "text/x-python", 296 | "name": "python", 297 | "nbconvert_exporter": "python", 298 | "pygments_lexer": "ipython3", 299 | "version": "3.11.3" 300 | }, 301 | "orig_nbformat": 4 302 | }, 303 | "nbformat": 4, 304 | "nbformat_minor": 2 305 | } 306 | -------------------------------------------------------------------------------- /evaluation/LPIPS.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Calculate LPIPS for 10k image pairs from 1 folder \n", 8 | "### Pair: image N & image (N + 1)" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "# !pip install lpips" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import os\n", 27 | "import lpips\n", 28 | "import numpy as np\n", 29 | "from tqdm import tqdm\n", 30 | "import random\n", 31 | "\n", 32 | "# select IRCAD images \n", 33 | "# Since we calculate average distance between 10k pairs of current and consecutive images, it makes sense to have one list specifying which IRCAD frames in which order appear in calculation.\n", 34 | "# For that purpose in our paper we choose images generated from model vid25_66 part 1\n", 35 | "list_ = os.listdir('/path/to/output_vid25_66_part1')\n", 36 | "\n", 37 | "# Now we can calculate LPIPS for any data. 
Below example for mixed style, paths can be for inputs or baseline output data.\n", 38 | "data_path_25_66 = '/path/to/data' #output_vid25_66\n", 39 | "data_path_01_49 = '/path/to/data' #output_vid01_49\n", 40 | "data_path_52_56 = '/path/to/data' #output_vid52_56\n", 41 | "\n", 42 | "# calculate FID & KID for mixed style\n", 43 | "rand_dict = {1: data_path_25_66,\n", 44 | " 2: data_path_01_49,\n", 45 | " 3: data_path_52_56}\n", 46 | "\n", 47 | "files_common =set(os.listdir(rand_dict[1]))& set(os.listdir(rand_dict[2])) & set(os.listdir(rand_dict[3]))\n", 48 | "files = [os.path.join(rand_dict[random.randint(1, 3)], f) for f in files_common]\n", 49 | "\n", 50 | "random.seed(420420)\n", 51 | "files = random.sample(files,10000)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "# Initializing the model\n", 61 | "loss_fn = lpips.LPIPS(net='vgg',version='0.1')\n", 62 | "loss_fn.to('cuda:1')\n", 63 | "\n", 64 | "dists = []\n", 65 | "for (ff,file) in tqdm(enumerate(files[:-1])):\n", 66 | "\n", 67 | "\timg0 = lpips.im2tensor(lpips.load_image(file)) # RGB image from [-1,1]\n", 68 | "\timg0 = img0.to('cuda:1')\n", 69 | "\n", 70 | "\tfiles1 = [files[ff+1],]\n", 71 | "\t\n", 72 | "\tfor file1 in files1:\n", 73 | "\t\timg1 = lpips.im2tensor(lpips.load_image(file1))\n", 74 | "\t\timg1 = img1.to('cuda:1')\n", 75 | "\n", 76 | "\t\t# Compute distance\n", 77 | "\t\tdist01 = loss_fn.forward(img0,img1)\n", 78 | "\t\tdists.append(dist01.item())\n", 79 | "\n", 80 | "avg_dist = np.mean(np.array(dists))\n", 81 | "stderr_dist = np.std(np.array(dists))/np.sqrt(len(dists))\n", 82 | "\n", 83 | "print('Avg: %.5f +/- %.5f'%(avg_dist,stderr_dist))\n", 84 | "\n" 85 | ] 86 | } 87 | ], 88 | "metadata": { 89 | "kernelspec": { 90 | "display_name": "base", 91 | "language": "python", 92 | "name": "python3" 93 | }, 94 | "language_info": { 95 | "codemirror_mode": { 96 | "name": "ipython", 97 | "version": 3 98 | }, 99 | "file_extension": ".py", 100 | "mimetype": "text/x-python", 101 | "name": "python", 102 | "nbconvert_exporter": "python", 103 | "pygments_lexer": "ipython3", 104 | "version": "3.9.13" 105 | }, 106 | "orig_nbformat": 4 107 | }, 108 | "nbformat": 4, 109 | "nbformat_minor": 2 110 | } 111 | -------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- 1 | 1. Segmentation: 2 | Train Unet for segmentation on real CholecSeg8k data, evaluate on synthetic data. Follow `segmentation / run_segmentation.ipynb` notebook. CholecSeg8k dataset together with our train/val split is available on OneDrive. 3 | Credits: https://github.com/nezih-niegu/Courses/blob/main/Surgical%20Data%20Science/EDU4SDS_Lecture7_SemSeg.ipynb 4 | 5 | 2. FID & KID: 6 | Follow `FID_KID.ipynb` notebook. For these metrics having at least 10k real images downloaded is required. 7 | 8 | 3. LPIPS: 9 | Follow `FID_KID.ipynb` notebook, distances within only one dataset are computed. 
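For quick reference, the FID/KID step boils down to a single `torch_fidelity.calculate_metrics` call on two folders of identically preprocessed images (resized/cropped to 299 px, ~10k images each), exactly as done in `FID_KID.ipynb`. A minimal sketch is shown below; the folder paths are placeholders.

```python
# Minimal sketch of the metric call from FID_KID.ipynb.
# Assumes both folders already contain ~10k images resized/cropped to 299 px
# with the same torchvision pipeline; the paths below are placeholders.
import torch_fidelity

metrics = torch_fidelity.calculate_metrics(
    input1="/path/to/real_data_cropped",       # real CHOLECT45 frames
    input2="/path/to/synthetic_data_resized",  # generated images
    cuda=True,
    fid=True,
    kid=True,
    isc=False,
    prc=False,
    kid_subset_size=1000,  # same KID subset size as in the notebook
    cache=False,
    verbose=True,
)
print(metrics)  # dictionary containing the FID and KID values
```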
10 | -------------------------------------------------------------------------------- /evaluation/segmentation/run_segmentation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Notebook for UNet training and testing.\n", 8 | "\n", 9 | "- If you want to test synthetic data with our trained model, download `revised_model_epoch_29.pth` from OneDrive to `/evaluation/segmentation/` folder. \n", 10 | "- If you want to train UNet on CholecSeg8k data, download the CholecSeg8k dataset from OneDrive to `evaulation/segmentation/cholecseg8k` folder." 11 | ] 12 | }, 13 | { 14 | "attachments": {}, 15 | "cell_type": "markdown", 16 | "id": "10f9d2a8", 17 | "metadata": {}, 18 | "source": [ 19 | "## Setup" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 1, 25 | "id": "a4d67fb5", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "# install dependencies\n", 30 | "# !pip install numpy\n", 31 | "# !pip install matplotlib\n", 32 | "# !pip install torch\n", 33 | "# !pip install torchvision\n", 34 | "# !pip install tqdm\n", 35 | "# !pip install ipywidgets" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "id": "a59cba21", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "import torch\n", 46 | "import torch.nn as nn\n", 47 | "import transforms as T\n", 48 | "import cv2\n", 49 | "import random\n", 50 | "import torchvision\n", 51 | "from tqdm.notebook import tqdm\n", 52 | "import matplotlib.pyplot as plt\n", 53 | "import glob\n", 54 | "import os\n", 55 | "from tqdm.notebook import tqdm\n", 56 | "import json\n", 57 | "import numpy as np\n", 58 | "import shutil\n", 59 | "from PIL import Image, ImageColor\n", 60 | "import io\n", 61 | "import torchvision.transforms as transforms\n", 62 | "\n", 63 | "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 64 | "\n", 65 | "NUM_EPOCHS = 30\n", 66 | "DO_TRAINING = False\n", 67 | "FINAL_MODEL_PATH = \"revised_model_epoch_29.pth\"\n", 68 | "MODEL_NAME = \"fcn_resnet50\" \n", 69 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n", 70 | "\n", 71 | "# Defining a color used to depict each semantic class being segmented\n", 72 | "META_DATA_ORIGINAL = [\n", 73 | " (\"black_background\", (0,0,0)),\n", 74 | " (\"abdominal_wall\", (33, 191, 197)),\n", 75 | " (\"liver\", (231, 126, 9)),\n", 76 | " (\"gastrointestinal_tract\", (209, 53, 84)),\n", 77 | " (\"fat\", (80, 155, 4)),\n", 78 | " (\"grasper\", (255, 207, 210)),\n", 79 | " (\"connective_tissue\", (169, 52, 199)),\n", 80 | " (\"blood\", (229, 18, 18)),\n", 81 | " (\"cystic_duct\", (149, 50, 18)),\n", 82 | " (\"l-hook_electrocautery\", (46, 43, 180)),\n", 83 | " (\"gallbladder\", (148, 55, 66)),\n", 84 | " (\"hepatic_vein\", (214, 51, 149)),\n", 85 | " (\"liver_ligament\", (240, 79, 10)),\n", 86 | "]\n", 87 | "\n", 88 | "#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n", 89 | "# THESE 2 LINES REDUCE / MERGE CLASSES. TO REDUCE PUT 255 VALUE. TO MERGE PUT DESIRED CLASS AS VALUE. 
REMAINING CLASSES HAVE TO BE FROM 0 to N.\n", 90 | "\n", 91 | "CLASSES_TO_IGNORE = [\"black_background\",\"gastrointestinal_tract\", \"connective_tissue\", \"blood\", \"cystic_duct\", \"l-hook_electrocautery\",\"hepatic_vein\", \"liver_ligament\"]\n", 92 | "REPLACE_CLASS = {0:255, 1:0, 2:1,3:255,4:2, 5:3,6:255,7:255, 8:255, 9:3, 10:4,11:255, 12:255, 13:3, 14:3, 15:3, 16:3}\n", 93 | "#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n", 94 | "\n", 95 | "META_DATA = [x for x in META_DATA_ORIGINAL if x[0] not in CLASSES_TO_IGNORE]\n", 96 | "\n", 97 | "\n", 98 | "# Optimizer parameters\n", 99 | "learning_rate = 0.00125\n", 100 | "momentum = 0.9\n", 101 | "power = 0.9\n", 102 | "weight_decay = 1e-4" 103 | ] 104 | }, 105 | { 106 | "attachments": {}, 107 | "cell_type": "markdown", 108 | "id": "3eab8ca5", 109 | "metadata": {}, 110 | "source": [ 111 | "## Helper functions and classes" 112 | ] 113 | }, 114 | { 115 | "attachments": {}, 116 | "cell_type": "markdown", 117 | "id": "8c9c5420", 118 | "metadata": {}, 119 | "source": [ 120 | "Defining some reusable function that we will use throughout this notebook" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 3, 126 | "id": "32d81b90", 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "def cat_list(images, fill_value=0):\n", 131 | " max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))\n", 132 | " batch_shape = (len(images),) + max_size\n", 133 | " batched_imgs = images[0].new(*batch_shape).fill_(fill_value)\n", 134 | " for img, pad_img in zip(images, batched_imgs):\n", 135 | " pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)\n", 136 | " return batched_imgs\n", 137 | "\n", 138 | "def collate_fn(batch):\n", 139 | " images, targets = list(zip(*batch))\n", 140 | " batched_imgs = cat_list(images, fill_value=0)\n", 141 | " batched_targets = cat_list(targets, fill_value=255)\n", 142 | " return batched_imgs, batched_targets\n", 143 | "\n", 144 | "# Helper function to do a cross entropy loss between the ground truth and predicted values\n", 145 | "def criterion(inputs, target):\n", 146 | " losses = {}\n", 147 | " for name, x in inputs.items():\n", 148 | " losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)\n", 149 | " if len(losses) == 1:\n", 150 | " return losses[\"out\"]\n", 151 | " return losses[\"out\"] + 0.5 * losses[\"aux\"]\n", 152 | "\n", 153 | "\n", 154 | "# Helper function to compute relevant metrics using a confusion matrix\n", 155 | "# see: https://en.wikipedia.org/wiki/Confusion_matrix\n", 156 | "class ConfusionMatrix:\n", 157 | " def __init__(self, num_classes):\n", 158 | " self.num_classes = num_classes\n", 159 | " self.mat = None\n", 160 | "\n", 161 | " def update(self, a, b):\n", 162 | " n = self.num_classes\n", 163 | " if self.mat is None:\n", 164 | " self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)\n", 165 | " with torch.no_grad():\n", 166 | " k = (a >= 0) & (a < n)\n", 167 | " inds = n * a[k].to(torch.int64) + b[k]\n", 168 | " self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)\n", 169 | "\n", 170 | " def reset(self):\n", 171 | " self.mat.zero_()\n", 172 | "\n", 173 | " def compute(self):\n", 174 | " h = self.mat.float()\n", 175 | " acc_global = torch.diag(h).sum() / h.sum()\n", 176 | " acc = torch.diag(h) / h.sum(1)\n", 177 | " iou = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))\n", 178 | " return acc_global, acc, iou\n", 179 | " \n", 180 | " # Return overall accuracy, per-class accuracy, per-class Intersection over Union 
(IoU) and mean IoU\n", 181 | " def __str__(self):\n", 182 | " acc_global, acc, iou = self.compute()\n", 183 | " return (\"global correct: {:.2f}\\naverage row correct: {}\\nIoU: {}\\nmean IoU: {:.2f}\").format(\n", 184 | " acc_global.item() * 100,\n", 185 | " [f\"{i:.1f}\" for i in (acc * 100).tolist()],\n", 186 | " [f\"{i:.1f}\" for i in (iou * 100).tolist()],\n", 187 | " iou.mean().item() * 100,\n", 188 | " )\n", 189 | " " 190 | ] 191 | }, 192 | { 193 | "attachments": {}, 194 | "cell_type": "markdown", 195 | "id": "5a7c96b6", 196 | "metadata": {}, 197 | "source": [ 198 | "## CholeSeg8k\n", 199 | "The CholecSeg8k dataset [1] consists of subset of Cholec80 [2] annotated with semantic segmentation labels with 13 semantic classes for 17 video clips.\n", 200 | "\n", 201 | "\n", 202 | "1. _Hong, W-Y., C-L. Kao, Y-H. Kuo, J-R. Wang, W-L. Chang, and C-S. Shih. \"CholecSeg8k: A Semantic Segmentation Dataset for Laparoscopic Cholecystectomy Based on Cholec80.\" arXiv preprint arXiv:2012.12453 (2020)._\n", 203 | "\n", 204 | "2. _Twinanda, Andru P., Sherif Shehata, Didier Mutter, Jacques Marescaux, Michel De Mathelin, and Nicolas Padoy. \"Endonet: a deep architecture for recognition tasks on laparoscopic videos.\" IEEE transactions on medical imaging 36, no. 1 (2016): 86-97._" 205 | ] 206 | }, 207 | { 208 | "attachments": {}, 209 | "cell_type": "markdown", 210 | "id": "0af3530c", 211 | "metadata": {}, 212 | "source": [ 213 | "## Dataset class" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 4, 219 | "id": "1d0e8beb", 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "# We define a dataset class that delivers images and correponding ground truth segmentation masks\n", 224 | "# from the CholecSeg8k. Please refer to Lecture 6 for more info on torch Datasets.\n", 225 | "def map_values(x):\n", 226 | " return REPLACE_CLASS.get(x, x)\n", 227 | "\n", 228 | "class CholecDatasetSegm(torch.utils.data.Dataset):\n", 229 | " def __init__(self, gt_json, meta_data, root_dir = \"./cholecseg8k\", data_split = \"train\", transforms = None):\n", 230 | " self.gt_json = gt_json\n", 231 | " self.root_dir = root_dir\n", 232 | " self.data_split = data_split\n", 233 | " self.transforms = transforms\n", 234 | " gt_data = json.load(open(gt_json))\n", 235 | " self.images = [os.path.join(self.root_dir, g[\"file_name\"]) for g in gt_data]\n", 236 | " self.targets = [os.path.join(self.root_dir, g[\"mask_name\"]) for g in gt_data]\n", 237 | " self.metadata = meta_data\n", 238 | " \n", 239 | " def __len__(self):\n", 240 | " return len(self.images)\n", 241 | " \n", 242 | " def __getitem__(self, index: int):\n", 243 | " img = Image.open(self.images[index]).convert(\"RGB\")\n", 244 | " target = Image.open(self.targets[index]).convert(\"L\")\n", 245 | " target = target.resize(img.size, resample=Image.NEAREST)\n", 246 | " \n", 247 | " if CLASSES_TO_IGNORE:\n", 248 | " target = np.array(target)\n", 249 | " \n", 250 | " # use numpy's vectorize function to apply the mapping to the whole array\n", 251 | " map_func = np.vectorize(map_values)\n", 252 | " target = map_func(target)\n", 253 | "\n", 254 | " target = Image.fromarray(target.astype(np.uint8), mode='L')\n", 255 | " if self.transforms is not None:\n", 256 | " img, target = self.transforms(img, target) \n", 257 | " return img, target" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 5, 263 | "id": "a3e7df1e", 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "class SegmentationPresetTrain:\n", 268 
| " def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):\n", 269 | " min_size = int(0.5 * base_size)\n", 270 | " max_size = int(2.0 * base_size)\n", 271 | "\n", 272 | " trans = [T.RandomResize(min_size, max_size)]\n", 273 | " if hflip_prob > 0:\n", 274 | " trans.append(T.RandomHorizontalFlip(hflip_prob))\n", 275 | " trans.extend(\n", 276 | " [\n", 277 | " T.RandomCrop(crop_size),\n", 278 | " T.PILToTensor(),\n", 279 | " T.ConvertImageDtype(torch.float),\n", 280 | " T.Normalize(mean=mean, std=std),\n", 281 | " ]\n", 282 | " )\n", 283 | " self.transforms = T.Compose(trans)\n", 284 | "\n", 285 | " def __call__(self, img, target):\n", 286 | " return self.transforms(img, target)\n", 287 | "\n", 288 | "\n", 289 | "class SegmentationPresetEval:\n", 290 | " def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):\n", 291 | " self.transforms = T.Compose(\n", 292 | " [\n", 293 | " T.RandomResize(base_size, base_size),\n", 294 | " T.PILToTensor(),\n", 295 | " T.ConvertImageDtype(torch.float),\n", 296 | " T.Normalize(mean=mean, std=std),\n", 297 | " ]\n", 298 | " )\n", 299 | "\n", 300 | " def __call__(self, img, target):\n", 301 | " return self.transforms(img, target)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 6, 307 | "id": "b0fd7802", 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "# Defining Data Loaders for the training and testing splits.\n", 312 | "# Please refer to Lecture 6 for more info on torch Data Loaders.\n", 313 | "\n", 314 | "\n", 315 | "def get_transform(train=True):\n", 316 | " if train:\n", 317 | " return SegmentationPresetTrain(base_size=512, crop_size=400)\n", 318 | " else:\n", 319 | " return SegmentationPresetEval(base_size=400)\n", 320 | " \n", 321 | "# Train loader\n", 322 | "dataset = CholecDatasetSegm(\"./cholecseg8k/train_final.json\", META_DATA, data_split=\"train\", transforms=get_transform())\n", 323 | "num_classes = len(META_DATA)\n", 324 | "train_sampler = torch.utils.data.RandomSampler(dataset)\n", 325 | "data_loader = torch.utils.data.DataLoader(\n", 326 | " dataset,\n", 327 | " batch_size=2,\n", 328 | " sampler=train_sampler,\n", 329 | " collate_fn=collate_fn,\n", 330 | " drop_last=True,\n", 331 | ")\n", 332 | "\n", 333 | "# Test loader\n", 334 | "dataset_test = CholecDatasetSegm(\"./cholecseg8k/val_final.json\", META_DATA, data_split=\"val\", transforms=get_transform(False))\n", 335 | "test_sampler = torch.utils.data.SequentialSampler(dataset_test)\n", 336 | "data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, sampler=test_sampler, collate_fn=collate_fn)" 337 | ] 338 | }, 339 | { 340 | "attachments": {}, 341 | "cell_type": "markdown", 342 | "id": "60e8373e", 343 | "metadata": {}, 344 | "source": [ 345 | "## CREATE IRCAD TEST DATASET" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 20, 351 | "id": "bd86f757", 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "# mIOU for mixed style\n", 356 | "\n", 357 | "data_path_25_66 = '/path/to/data' #output_vid25_66\n", 358 | "data_path_01_49 = '/path/to/data' #output_vid01_49\n", 359 | "data_path_52_56 = '/path/to/data' #output_vid52_56\n", 360 | "\n", 361 | "data_path_seg = '/path/to/segmentation_maps' # Segmentation maps need to have CholecSeg classes. 
They were generated in the preprocessing step\n", 362 | "\n", 363 | "# get list of files\n", 364 | "rand_dict = {1: data_path_25_66,\n", 365 | " 2: data_path_01_49,\n", 366 | " 3: data_path_52_56}\n", 367 | "\n", 368 | "filenames_common =set(os.listdir(rand_dict[1]))& set(os.listdir(rand_dict[2])) & set(os.listdir(rand_dict[3]))\n", 369 | "filenames = [os.path.join(rand_dict[random.randint(1, 1)], f) for f in filenames_common]\n", 370 | "\n", 371 | "random.seed(420420)\n", 372 | "filenames = random.sample(filenames,10)\n", 373 | "\n", 374 | "# get list of corresponding segmentation masks\n", 375 | "masks = [os.path.join(data_path_seg, os.path.basename(f)) for f in filenames]\n", 376 | "\n", 377 | "# Save json with test files + masks\n", 378 | "data = []\n", 379 | "for f, m in zip(filenames, masks):\n", 380 | " data.append({'file_name': f, 'mask_name': m})\n", 381 | "with open('test_synthetic.json', 'w') as f:\n", 382 | " json.dump(data, f)\n", 383 | "\n", 384 | "#Test IRCAD\n", 385 | "transforms_ircad = SegmentationPresetEval(base_size=512)\n", 386 | "dataset_test_ircad = CholecDatasetSegm(\"test_synthetic.json\", META_DATA, root_dir='',data_split=\"val\", transforms=transforms_ircad)\n", 387 | "test_sampler_ircad = torch.utils.data.SequentialSampler(dataset_test_ircad)\n", 388 | "data_loader_test_ircad = torch.utils.data.DataLoader(dataset_test_ircad, batch_size=1, sampler=test_sampler_ircad, collate_fn=collate_fn)" 389 | ] 390 | }, 391 | { 392 | "attachments": {}, 393 | "cell_type": "markdown", 394 | "id": "773fc52d", 395 | "metadata": {}, 396 | "source": [ 397 | "## Segmentation model" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": 22, 403 | "id": "234be465", 404 | "metadata": {}, 405 | "outputs": [ 406 | { 407 | "name": "stderr", 408 | "output_type": "stream", 409 | "text": [ 410 | "/home/jk/anaconda3/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", 411 | " warnings.warn(\n", 412 | "/home/jk/anaconda3/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1`. 
You can also use `weights=FCN_ResNet50_Weights.DEFAULT` to get the most up-to-date weights.\n", 413 | " warnings.warn(msg)\n" 414 | ] 415 | } 416 | ], 417 | "source": [ 418 | "model = torchvision.models.segmentation.__dict__[MODEL_NAME](pretrained=True)\n", 419 | "model.classifier[4] = nn.Conv2d(512, num_classes, 1)\n", 420 | "model.aux_classifier [4] = nn.Conv2d(256, num_classes, 1)\n", 421 | "model = model.to(DEVICE)" 422 | ] 423 | }, 424 | { 425 | "attachments": {}, 426 | "cell_type": "markdown", 427 | "id": "88470175", 428 | "metadata": {}, 429 | "source": [ 430 | "## Optimizer and learning rate scheduler" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 23, 436 | "id": "c3a6ab41", 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [ 440 | "params_to_optimize = [\n", 441 | " {\"params\": [p for p in model.backbone.parameters() if p.requires_grad]},\n", 442 | " {\"params\": [p for p in model.classifier.parameters() if p.requires_grad]},\n", 443 | "]\n", 444 | "params = [p for p in model.aux_classifier.parameters() if p.requires_grad]\n", 445 | "params_to_optimize.append({\"params\": params, \"lr\": learning_rate * 10})\n", 446 | "\n", 447 | "iters_per_epoch = len(data_loader)\n", 448 | "optimizer = torch.optim.SGD(params_to_optimize, lr=learning_rate, momentum=momentum, weight_decay=weight_decay)\n", 449 | "lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: (1 - x / (iters_per_epoch * NUM_EPOCHS)) ** power)" 450 | ] 451 | }, 452 | { 453 | "attachments": {}, 454 | "cell_type": "markdown", 455 | "id": "c52c47d6", 456 | "metadata": {}, 457 | "source": [ 458 | "## Helper function for training and validation for one epoch" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": 24, 464 | "id": "91e82f15", 465 | "metadata": {}, 466 | "outputs": [], 467 | "source": [ 468 | "# Helper function to train\n", 469 | "def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device):\n", 470 | " model.train()\n", 471 | " train_loss = 0.0\n", 472 | " pbar = tqdm(data_loader)\n", 473 | " for image, target in pbar:\n", 474 | " image, target = image.to(device), target.to(device)\n", 475 | " output = model(image)\n", 476 | " loss = criterion(output, target)\n", 477 | " optimizer.zero_grad()\n", 478 | " loss.backward()\n", 479 | " optimizer.step()\n", 480 | " lr_scheduler.step()\n", 481 | " train_loss += loss.item()\n", 482 | " pbar.set_description(\"train_loss: {:.3f} lr: {:.3f}\".format(loss.item(), \n", 483 | " optimizer.param_groups[0][\"lr\"]))\n", 484 | " train_loss /= len(data_loader)\n", 485 | " return train_loss, optimizer.param_groups[0][\"lr\"]\n", 486 | "\n", 487 | "# Helper function to evaluate\n", 488 | "def evaluate(model, data_loader, device, num_classes):\n", 489 | " model.eval()\n", 490 | " confmat = ConfusionMatrix(num_classes)\n", 491 | " pbar = tqdm(data_loader)\n", 492 | " with torch.no_grad():\n", 493 | " for image, target in pbar:\n", 494 | " image, target = image.to(device), target.to(device)\n", 495 | " output = model(image)\n", 496 | " output = output[\"out\"]\n", 497 | " confmat.update(target.flatten(), output.argmax(1).flatten())\n", 498 | " pbar.set_description(\"eval\")\n", 499 | " return confmat" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": 25, 505 | "id": "19d306f1", 506 | "metadata": {}, 507 | "outputs": [ 508 | { 509 | "name": "stdout", 510 | "output_type": "stream", 511 | "text": [ 512 | "=> loaded model weights from revised_model_epoch_29.pth 
\n", 513 | "missing keys = [] invalid keys []\n" 514 | ] 515 | }, 516 | { 517 | "data": { 518 | "application/vnd.jupyter.widget-view+json": { 519 | "model_id": "0cb3d4bbbd104087a019c397b6abb798", 520 | "version_major": 2, 521 | "version_minor": 0 522 | }, 523 | "text/plain": [ 524 | " 0%| | 0/10 [00:00 loaded model weights from {} \\nmissing keys = {} invalid keys {}\".format(FINAL_MODEL_PATH, m, v))\n", 560 | "\n", 561 | "\n", 562 | " confmat = evaluate(model, data_loader_test_ircad, device=DEVICE, num_classes=num_classes)\n", 563 | " acc_global, acc, iu = confmat.compute() \n", 564 | " print(\n", 565 | " \"acc_global: {:.3f} iou: {:.3f}\".format(\n", 566 | " acc_global.item() * 100, iu.mean().item() * 100\n", 567 | " )\n", 568 | " )\n", 569 | " print(\"confmat:\", confmat)\n" 570 | ] 571 | }, 572 | { 573 | "cell_type": "code", 574 | "execution_count": null, 575 | "metadata": {}, 576 | "outputs": [], 577 | "source": [] 578 | } 579 | ], 580 | "metadata": { 581 | "kernelspec": { 582 | "display_name": "fid", 583 | "language": "python", 584 | "name": "python3" 585 | }, 586 | "language_info": { 587 | "codemirror_mode": { 588 | "name": "ipython", 589 | "version": 3 590 | }, 591 | "file_extension": ".py", 592 | "mimetype": "text/x-python", 593 | "name": "python", 594 | "nbconvert_exporter": "python", 595 | "pygments_lexer": "ipython3", 596 | "version": "3.9.13" 597 | }, 598 | "orig_nbformat": 4 599 | }, 600 | "nbformat": 4, 601 | "nbformat_minor": 2 602 | } 603 | -------------------------------------------------------------------------------- /evaluation/segmentation/transforms.py: -------------------------------------------------------------------------------- 1 | # taken from https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py 2 | import random 3 | import numpy as np 4 | import torch 5 | from torchvision import transforms as T 6 | from torchvision.transforms import functional as F 7 | 8 | 9 | def pad_if_smaller(img, size, fill=0): 10 | min_size = min(img.size) 11 | if min_size < size: 12 | ow, oh = img.size 13 | padh = size - oh if oh < size else 0 14 | padw = size - ow if ow < size else 0 15 | img = F.pad(img, (0, 0, padw, padh), fill=fill) 16 | return img 17 | 18 | 19 | class Compose: 20 | def __init__(self, transforms): 21 | self.transforms = transforms 22 | 23 | def __call__(self, image, target): 24 | for t in self.transforms: 25 | image, target = t(image, target) 26 | return image, target 27 | 28 | 29 | class RandomResize: 30 | def __init__(self, min_size, max_size=None): 31 | self.min_size = min_size 32 | if max_size is None: 33 | max_size = min_size 34 | self.max_size = max_size 35 | 36 | def __call__(self, image, target): 37 | size = random.randint(self.min_size, self.max_size) 38 | image = F.resize(image, size) 39 | target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST) 40 | return image, target 41 | 42 | 43 | class RandomHorizontalFlip: 44 | def __init__(self, flip_prob): 45 | self.flip_prob = flip_prob 46 | 47 | def __call__(self, image, target): 48 | if random.random() < self.flip_prob: 49 | image = F.hflip(image) 50 | target = F.hflip(target) 51 | return image, target 52 | 53 | 54 | class RandomCrop: 55 | def __init__(self, size): 56 | self.size = size 57 | 58 | def __call__(self, image, target): 59 | image = pad_if_smaller(image, self.size) 60 | target = pad_if_smaller(target, self.size, fill=255) 61 | crop_params = T.RandomCrop.get_params(image, (self.size, self.size)) 62 | image = F.crop(image, *crop_params) 63 | target = 
F.crop(target, *crop_params) 64 | return image, target 65 | 66 | 67 | class CenterCrop: 68 | def __init__(self, size): 69 | self.size = size 70 | 71 | def __call__(self, image, target): 72 | image = F.center_crop(image, self.size) 73 | target = F.center_crop(target, self.size) 74 | return image, target 75 | 76 | 77 | class PILToTensor: 78 | def __call__(self, image, target): 79 | image = F.pil_to_tensor(image) 80 | target = torch.as_tensor(np.array(target), dtype=torch.int64) 81 | return image, target 82 | 83 | 84 | class ConvertImageDtype: 85 | def __init__(self, dtype): 86 | self.dtype = dtype 87 | 88 | def __call__(self, image, target): 89 | image = F.convert_image_dtype(image, self.dtype) 90 | return image, target 91 | 92 | 93 | class Normalize: 94 | def __init__(self, mean, std): 95 | self.mean = mean 96 | self.std = std 97 | 98 | def __call__(self, image, target): 99 | image = F.normalize(image, mean=self.mean, std=self.std) 100 | return image, target 101 | -------------------------------------------------------------------------------- /inference/README.md: -------------------------------------------------------------------------------- 1 | UPDATE: ControlNet pipelines are out 2 | TODO: move code to diffusers 3 | 4 | --- 5 | 6 | We use webui for inference. Currently, according to the official ControlNet repository, it is the recommended way to merge ControlNet weights with Custom SD weights: 7 | https://github.com/lllyasviel/ControlNet/discussions/12 8 | 9 | To run inference we first have to run WebUI + ControlNet extension 10 | - https://github.com/AUTOMATIC1111/stable-diffusion-webui 11 | - https://github.com/Mikubill/sd-webui-controlnet 12 | Please follow the guidance on official repository explaining how to run webui, extensions and where to put ControlNet models. 13 | 14 | To use WebUI API: 15 | - https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API 16 | 17 | We use following ControlNet 1.1 checkpoints: 18 | - https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_softedge.pth 19 | - https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1e_sd15_tile.pth 20 | 21 | For inference we use Dreambooth model trained with Diffusers package. 22 | **Please remember, you have to first convert diffusers_checkpoint to original_stable_diffusion_checkpoint. Then you have to place your checkpoint in `stable-diffusion-webui/models/Stable-diffusion/` folder. Both these operations can be done as a last step in `training/train_dreambooth.ipynb`.** 23 | Next, select proper checkpoint in WebUI (upper left corner). Having the server running, you can run inference notebook. 24 | 25 | -------------------------------------------------------------------------------- /preprocessing/README.md: -------------------------------------------------------------------------------- 1 | To perform preprocessing: 2 | 3 | 1. run `prepare_segmentation_maps.ipynb` to get consistent classes. These segmentation maps will be used both for Texture enrichment (step 2.) and in evaluation phase. 4 | 2. run `texture_masks.ipynb` 5 | 6 | Credits: 7 | https://github.com/Devashi-Choudhary/Texture-Synthesis/tree/master -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/README.md: -------------------------------------------------------------------------------- 1 | # Texture-Synthesis 2 | Texture synthesis is a technique of generating new images by stitching together patches of existing images. 
The purpose of this study is to understand the 3 | patch-based method for generating arbitrarily large textures from small real-world samples. The study is an implementation of the paper 4 | [Image Quilting for Texture Synthesis and Transfer](https://people.eecs.berkeley.edu/~efros/research/quilting.html), 5 | based on the work of [Alexei Efros and William Freeman](https://github.com/lschlessinger1/image-quilting). 6 | 7 | # Dependencies 8 | 1. numpy 9 | 2. skimage 10 | 3. PIL 11 | 4. argparse 12 | 5. pyqt5 (for the GUI implementation) 13 | 14 | # How to execute the code 15 | **Using the command line:** 16 | 1. Download the repository and extract its contents into a folder. 17 | 2. Make sure you have the correct version of Python installed on your machine. This code requires Python 3.6 or above. 18 | 3. Install all dependencies mentioned above. 19 | 4. Open the folder and run texture_synthesis.py from the command prompt. 20 | > `python texture_synthesis.py --image_path <image_path> --block_size <block_size> --num_block <num_block> --mode <mode>` 21 | 22 | For example: 23 | 24 | > `python texture_synthesis.py --image_path data/input1.jpg --block_size 60 --num_block 8 --mode Best` 25 | 26 | The default values are block_size 50, num_block 6, and mode "Cut". 27 | 28 | **Using the GUI:** 29 | 1. Go to the Texture_Synthesis_GUI folder and run 30 | > `Texture_Synthesis_GUI.py` 31 | 2. The above file opens the GUI, from which you can run the synthesis. 32 | 33 | **In the GUI, the default values (block_size 50, num_block 6, mode "Cut") cannot be changed.** 34 | # Results 35 | ![Output1](https://github.com/Devashi-Choudhary/Texture-Synthesis/blob/master/Results/output1.jpg) 36 | ![Output2](https://github.com/Devashi-Choudhary/Texture-Synthesis/blob/master/Results/output2.jpg) 37 | **Note:** For more details on patch-based texture synthesis, see [Texture Synthesis: Generating arbitrarily large textures from image patches](https://medium.com/@Devashi_Choudhary/texture-synthesis-generating-arbitrarily-large-textures-from-image-patches-32dd49e2d637). 38 | 39 | # References 40 | 1. [Image Quilting for Texture Synthesis and Transfer](https://people.eecs.berkeley.edu/~efros/research/quilting/quilting.pdf) 41 | 2. [Image-Quilting-for-Texture-Synthesis](https://github.com/rohitrango/Image-Quilting-for-Texture-Synthesis) 42 | 3. 
[GUI](https://github.com/Baticsute/Panorama-Simple-App-Demo#using-opencv-library--pyqt5-for-ui--imutils--resize-image) and [Style](https://github.com/mayste/StyleTransfer) 43 | -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/01_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/01_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/01_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/01_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/02_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/02_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/02_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/02_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/04_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/04_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/04_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/04_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/10_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/10_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/10_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/10_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/12_a.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/12_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/12_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/12_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex01_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex01_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex01_a_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex01_a_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex01_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex01_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex01_b_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex01_b_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex02_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex02_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex02_a_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex02_a_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex02_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex02_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex02_b_rand.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex02_b_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex04_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex04_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex04_a_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex04_a_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex04_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex04_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex04_b_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex04_b_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex10_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex10_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex10_a_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex10_a_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex10_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex10_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex10_b_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex10_b_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex12_a.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex12_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex12_a_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex12_a_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex12_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex12_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_01_49/tex12_b_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_01_49/tex12_b_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/01_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/01_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/01_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/01_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/02_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/02_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/02_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/02_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/04_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/04_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/04_b.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/04_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/10_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/10_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/10_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/10_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/12_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/12_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/12_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/12_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex01_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex01_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex01_a_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex01_a_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex01_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex01_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex01_b_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex01_b_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex02_a.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex02_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex02_a_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex02_a_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex02_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex02_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex02_b_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex02_b_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex04_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex04_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex04_a_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex04_a_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex04_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex04_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex04_b_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex04_b_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex10_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex10_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex10_a_rand.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex10_a_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex10_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex10_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex10_b_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex10_b_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex12_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex12_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex12_a_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex12_a_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex12_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex12_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_25_66/tex12_b_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_25_66/tex12_b_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/01_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/01_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/01_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/01_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/02_a.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/02_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/02_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/02_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/04_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/04_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/04_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/04_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/10_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/10_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/10_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/10_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/12_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/12_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/12_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/12_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex01_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex01_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex01_a_rand.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex01_a_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex01_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex01_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex01_b_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex01_b_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex02_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex02_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex02_a_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex02_a_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex02_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex02_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex02_b_old.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex02_b_old.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex02_b_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex02_b_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex02_b_rand_old.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex02_b_rand_old.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex04_a.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex04_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex04_a_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex04_a_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex04_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex04_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex04_b_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex04_b_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex10_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex10_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex10_a_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex10_a_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex10_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex10_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex10_b_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex10_b_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex12_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex12_a.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex12_a_rand.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex12_a_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex12_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex12_b.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/tex_52_56/tex12_b_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/preprocessing/Texture-Synthesis/tex_52_56/tex12_b_rand.png -------------------------------------------------------------------------------- /preprocessing/Texture-Synthesis/texture_synthesis.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import math 4 | from skimage import io, util 5 | import heapq 6 | from skimage.transform import rescale, resize 7 | from PIL import Image 8 | import argparse 9 | import os 10 | import matplotlib.pyplot as plt 11 | import cv2 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("-i", "--image_path", required=True, type=str, help="path of image you want to quilt") 15 | parser.add_argument("-o", "--output", required=True, type=str, help="output_path") 16 | parser.add_argument("-n", "--num_block", type=int, default=6, help="number of blocks you want") 17 | parser.add_argument("-m", "--mode", type=str, default='Cut', help="which mode --random placement of block(Random)/Neighbouring blocks constrained by overlap(Best)/Minimum error boundary cut(Cut)") 18 | args = parser.parse_args() 19 | 20 | 21 | def randomPatch(texture, block_size): 22 | h, w, _ = texture.shape 23 | i = np.random.randint(h - block_size) 24 | j = np.random.randint(w - block_size) 25 | 26 | return texture[i:i+block_size, j:j+block_size] 27 | 28 | def L2OverlapDiff(patch, block_size, overlap, res, y, x): 29 | error = 0 30 | if x > 0: 31 | left = patch[:, :overlap] - res[y:y+block_size, x:x+overlap] 32 | error += np.sum(left**2) 33 | 34 | if y > 0: 35 | up = patch[:overlap, :] - res[y:y+overlap, x:x+block_size] 36 | error += np.sum(up**2) 37 | 38 | if x > 0 and y > 0: 39 | corner = patch[:overlap, :overlap] - res[y:y+overlap, x:x+overlap] 40 | error -= np.sum(corner**2) 41 | 42 | return error 43 | 44 | 45 | def randomBestPatch(texture, block_size, overlap, res, y, x): 46 | h, w, _ = texture.shape 47 | errors = np.zeros((h - block_size, w - block_size)) 48 | 49 | for i in range(h - block_size): 50 | for j in range(w - block_size): 51 | patch = texture[i:i+block_size, j:j+block_size] 52 | e = L2OverlapDiff(patch, block_size, overlap, res, y, x) 53 | errors[i, j] = e 54 | 55 | i, j = np.unravel_index(np.argmin(errors), errors.shape) 56 | return texture[i:i+block_size, j:j+block_size] 57 | 58 | 59 | 60 | def minCutPath(errors): 61 | # dijkstra's algorithm vertical 62 | pq = [(error, [i]) for i, error in enumerate(errors[0])] 63 | heapq.heapify(pq) 64 | 65 | h, w = errors.shape 66 | seen = set() 67 | 68 | while pq: 69 | error, path = 
heapq.heappop(pq) 70 | curDepth = len(path) 71 | curIndex = path[-1] 72 | 73 | if curDepth == h: 74 | return path 75 | 76 | for delta in -1, 0, 1: 77 | nextIndex = curIndex + delta 78 | 79 | if 0 <= nextIndex < w: 80 | if (curDepth, nextIndex) not in seen: 81 | cumError = error + errors[curDepth, nextIndex] 82 | heapq.heappush(pq, (cumError, path + [nextIndex])) 83 | seen.add((curDepth, nextIndex)) 84 | 85 | 86 | def minCutPatch(patch, block_size, overlap, res, y, x): 87 | patch = patch.copy() 88 | dy, dx, _ = patch.shape 89 | minCut = np.zeros_like(patch, dtype=bool) 90 | 91 | if x > 0: 92 | left = patch[:, :overlap] - res[y:y+dy, x:x+overlap] 93 | leftL2 = np.sum(left**2, axis=2) 94 | for i, j in enumerate(minCutPath(leftL2)): 95 | minCut[i, :j] = True 96 | 97 | if y > 0: 98 | up = patch[:overlap, :] - res[y:y+overlap, x:x+dx] 99 | upL2 = np.sum(up**2, axis=2) 100 | for j, i in enumerate(minCutPath(upL2.T)): 101 | minCut[:i, j] = True 102 | 103 | np.copyto(patch, res[y:y+dy, x:x+dx], where=minCut) 104 | 105 | return patch 106 | 107 | 108 | def quilt(image_path, block_size, num_block, mode, sequence=False): 109 | texture = Image.open(image_path) 110 | texture = util.img_as_float(texture) 111 | 112 | overlap = block_size // 6 113 | num_blockHigh, num_blockWide = num_block 114 | 115 | h = (num_blockHigh * block_size) - (num_blockHigh - 1) * overlap 116 | w = (num_blockWide * block_size) - (num_blockWide - 1) * overlap 117 | 118 | res = np.zeros((h, w, texture.shape[2])) 119 | 120 | for i in range(num_blockHigh): 121 | for j in range(num_blockWide): 122 | y = i * (block_size - overlap) 123 | x = j * (block_size - overlap) 124 | 125 | if i == 0 and j == 0 or mode == "Random": 126 | patch = randomPatch(texture, block_size) 127 | elif mode == "Best": 128 | patch = randomBestPatch(texture, block_size, overlap, res, y, x) 129 | elif mode == "Cut": 130 | patch = randomBestPatch(texture, block_size, overlap, res, y, x) 131 | patch = minCutPatch(patch, block_size, overlap, res, y, x) 132 | 133 | res[y:y+block_size, x:x+block_size] = patch 134 | 135 | image = Image.fromarray((res * 255).astype(np.uint8)) 136 | return image 137 | 138 | if __name__ == "__main__": 139 | image_path = args.image_path 140 | img = cv2.imread(image_path) 141 | # block size calculated instead passed in args 142 | block_size = min(img.shape[0]//10*10 - 5, img.shape[1]//10*10 -5) 143 | num_block = args.num_block 144 | mode = args.mode 145 | img = quilt(image_path, block_size, (num_block, num_block), mode) 146 | img.save(args.output) 147 | 148 | -------------------------------------------------------------------------------- /preprocessing/prepare_segmentation_maps.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Notebook for adjusting IRCAD segmentation maps (labels) to CholecSeg8k." 
8 | ] 9 | }, 10 | { 11 | "attachments": {}, 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "## Cholect8kSeg dataset - classes and corresponding values" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 8, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "# Our official color mappig, matching with Cholect -> class : rgb\n", 25 | "\n", 26 | "# Class 0 Black Background #7F7F7F 127 127 127\n", 27 | "# Class 1 Abdominal Wall #D28C8C 210 140 140\n", 28 | "# Class 2 Liver #FF7272 255 114 114\n", 29 | "# Class 3 Gastrointestinal Tract #E7469C 231 70 156\n", 30 | "# Class 4 Fat #BAB74B 186 183 75\n", 31 | "# Class 5 Grasper #AAFF00 170 255 0\n", 32 | "# Class 6 Connective Tissue #FF5500 255 85 0\n", 33 | "# Class 7 Blood #FF0000 255 0 0\n", 34 | "# Class 8 Cystic Duct #FFFF00 255 255 0\n", 35 | "# Class 9 L-hook Electrocautery #A9FFB8 169 255 184\n", 36 | "# Class 10 Gallbladder #FFA0A5255 160 165\n", 37 | "# Class 11 Hepatic Vein #003280 0 50 128\n", 38 | "# Class 12 Liver Ligament #6F4A00 111 74 0\n" 39 | ] 40 | }, 41 | { 42 | "attachments": {}, 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## IRCAD dataset - classes and corresponding values" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# IRCAD classes and corresponding values\n", 56 | "# Liver 26 26 26 \n", 57 | "# Abdominal Wall 77 77 77 \n", 58 | "# Fat 102, 102, 102 \n", 59 | "# Gallblader 51 51 51 \n", 60 | "# tool shaft 179 179 179 \n", 61 | "# tool tip 153 153 153\n", 62 | "# liver ligament 128 128 128" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "## Assign Cholect8kSeg classes to IRCAD classes. Segmentation maps will be saved as greyscale." 
70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 2, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "import numpy as np\n", 79 | "from tqdm import tqdm \n", 80 | "import os\n", 81 | "import torch\n", 82 | "import glob\n", 83 | "import cv2\n", 84 | "from PIL import Image\n", 85 | "import matplotlib.pyplot as plt" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 1, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "\n", 95 | "speidel_mapping = {\n", 96 | "2: (26, 26, 26),\n", 97 | "1: (77,77,77),\n", 98 | "4:(102, 102, 102),\n", 99 | "10:(51, 51, 51),\n", 100 | "5:(179, 179, 179),\n", 101 | "12:(128,128,128)\n", 102 | "}\n", 103 | "\n", 104 | "speidel_mapping_r = {v: k for k, v in speidel_mapping.items()}\n", 105 | "speidel_mapping_r[(153,153,153)] = 5 # tool tip goes as tool\n" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 3, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "name": "stderr", 115 | "output_type": "stream", 116 | "text": [ 117 | "100%|██████████| 20000/20000 [01:21<00:00, 244.19it/s]\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "ircad_labels_path = 'path/to/ircad/labels' # original IRCAD segmentation maps are in folder `labels`\n", 123 | "dest_path = \"path/to/new/segmentation/maps\"\n", 124 | "\n", 125 | "images = [v for v in glob.glob(f\"{ircad_labels_path}/*.png\")]\n", 126 | "\n", 127 | "if not os.path.exists(dest_path):\n", 128 | " os.makedirs(dest_path)\n", 129 | "\n", 130 | "for im in tqdm(images):\n", 131 | " img = cv2.imread(im)\n", 132 | " img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n", 133 | " \n", 134 | " seg_map_tensor = torch.from_numpy(np.array(img)).float().cuda()\n", 135 | "\n", 136 | " # Convert the RGB segmentation map to grayscale using the mapping dictionary\n", 137 | " gray_seg_map = torch.zeros(seg_map_tensor.shape[:2]).cuda()\n", 138 | " for color, class_id in speidel_mapping_r.items():\n", 139 | " mask = (seg_map_tensor == torch.tensor(color).cuda().view(1, 1, -1)).all(dim=2)\n", 140 | " gray_seg_map[mask] = class_id\n", 141 | "\n", 142 | " # save greyscale img\n", 143 | " output_path = os.path.join(dest_path,os.path.basename(im))\n", 144 | " cv2.imwrite(output_path, gray_seg_map.cpu().numpy())\n" 145 | ] 146 | }, 147 | { 148 | "attachments": {}, 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "### and sanity check" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 4, 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | "name": "stdout", 162 | "output_type": "stream", 163 | "text": [ 164 | "(256, 452)\n", 165 | "(3, 256, 452)\n", 166 | "(256, 452, 3)\n" 167 | ] 168 | }, 169 | { 170 | "data": { 171 | "text/plain": [ 172 | "Text(0.5, 1.0, 'restored')" 173 | ] 174 | }, 175 | "execution_count": 4, 176 | "metadata": {}, 177 | "output_type": "execute_result" 178 | }, 179 | { 180 | "data": { 181 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAz8AAAEHCAYAAACA4/iqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAtGElEQVR4nO3deXRUZZ7/8U9VUlVUQqXIXikSIGxuCdAGWYItSyDIADaiR1pnumF0HDc45oBji8xp6Y3YzLhMt42e6WEI4oLdI9AqTg+hkSgHnGFtAWcYnUYJ3YlRliTEmLA8vz/6cH8WCZAKSW4t79c59xxz71OV730S68O3cuu5DmOMEQAAAADEOKfdBQAAAABAT6D5AQAAABAXaH4AAAAAxAWaHwAAAABxgeYHAAAAQFyg+QEAAAAQF2h+AAAAAMQFmh8AAAAAcYHmBwAAAEBcoPkBbFRRUSGHw6FPPvnE7lIAADFu2bJl2rBhg91ltGvp0qVyOBx2l4E4QPMD2Gj69OnasWOHcnJy7C4FABDjIrn5AXpKot0FAPGoublZvXr1UmZmpjIzM+0uBwAQgb788kslJSXZXcYlnT59Wg6HQ4mJ/JMS0YG//ABXaNu2bSopKZHP51NSUpKKi4u1ceNG6/j5S9s2bdqku+++W5mZmUpKSlJLS0u7l70ZY7Rs2TL1799fvXr10siRI1VZWakJEyZowoQJPX+CAIBud/6yrz179uj2229XamqqBg0aJGOMVqxYoREjRsjr9So1NVW33367/vCHP4Q8fu/evZoxY4aysrLk8XgUDAY1ffp0HT16VJLkcDjU1NSk1atXy+FwyOFwhGTKgQMH9K1vfUupqanq1auXRowYodWrV4d8j61bt8rhcGjNmjVatGiR+vbtK4/Ho48//liStHnzZpWUlCglJUVJSUkaN26cfve737U5140bN2rEiBHyeDzKz8/XP/7jP3bxbAIXR/MDXIGqqipNmjRJ9fX1WrlypV599VX5fD7NnDlTr732WsjYu+++Wy6XS2vWrNG//du/yeVytfucS5Ys0ZIlS3TzzTfrN7/5je6//379zd/8jf73f/+3J04JAGCj2bNna/Dgwfr1r3+tF154Qffdd5/Kyso0efJkbdiwQStWrNDBgwdVXFyszz77TJLU1NSkKVOm6LPPPtMvfvELVVZW6tlnn1W/fv3U2NgoSdqxY4e8Xq/+4i/+Qjt27NCOHTu0YsUKSdKhQ4dUXFysgwcP6mc/+5nWrVuna6+9VvPmzdPy5cvb1Lh48WIdOXJEL7zwgt58801lZWXppZdeUmlpqVJSUrR69Wr96le/UlpamqZOnRrSAP3ud7/Tt771Lfl8Pq1du1b/8A//oF/96ldatWpVD8wuIMkA6LQxY8aYrKws09jYaO07c+aMKSgoMLm5uebcuXNm1apVRpL57ne/2+bx548dPnzYGGPM8ePHjcfjMXPmzAkZt2PHDiPJjB8/vjtPBwBgkyeeeMJIMt///vetfedf+5966qmQsdXV1cbr9ZpHH33UGGPMrl27jCSzYcOGS36P5ORkM3fu3Db7v/3tbxuPx2OOHDkSsn/atGkmKSnJnDx50hhjzDvvvGMkmZtuuilkXFNTk0lLSzMzZ84M2X/27FkzfPhwM2rUKGvf6NGjTTAYNM3Nzda+hoYGk5aWZvhnKXoCf/kBOqmpqUn/+Z//qdtvv129e/e29ickJOg73/mOjh49qkOHDln7b7vttss+5/vvv6+WlhbdcccdIfvHjBmjAQMGdFntAIDI9PWseOutt+RwOPRXf/VXOnPmjLUFAgENHz5cW7dulSQNHjxYqamp+t73vqcXXnhBH374YVjfc8uWLSopKVFeXl7I/nnz5unLL7/Ujh07LlqjJG3fvl3Hjx/X3LlzQ+o8d+6cbr75Zu3cuVNNTU1qamrSzp07NXv2bPXq1ct6/PkrJoCewKfTgE46ceKEjDHtrtQWDAYlSceOHbP2dWRFt/Pjs7Oz2xxrbx8AILZ8PSs+++wzGWMu+vo/cOBASZLf71dVVZV+8pOf6PHHH9eJEyeUk5Oje++9V3//939/0cuszzt27FiHs+zCGs/XKUm33377Rb/H8ePH5XA4dO7cOQUCgTbH29sHdAeaH6CTUlNT5XQ6VVNT0+bYn/70J0lSRkaGPvroI0nq0P0L0tPTJf3/IPm62tpa/voDADHu61mRkZEhh8Oh9957Tx6Pp83Yr+8rLCzU2rVrZYzRBx98oIqKCv3whz+U1+vVY489dsnvmZ6eftksu1iNXz/+85//XGPGjGn3e2RnZ1srw9XW1rY53t4+oDtw2RvQScnJyRo9erTWrVun5uZma/+5c+f00ksvKTc3V0OHDg3rOUePHi2Px9NmsYT3339fn376aZfUDQCIDjNmzJAxRn/84x81cuTINlthYWGbxzgcDg0fPlzPPPOM+vTpoz179ljHPB5PSF6dV1JSoi1btljNznkvvviikpKSLtrQnDdu3Dj16dNHH374Ybt1jhw5Um63W8nJyRo1apTWrVunr776ynp8Y2Oj3nzzzXCnB+gU/vIDXIHy8nJNmTJFEydO1COPPCK3260VK1bowIEDevXVV8O+W3VaWpoWLlyo8vJypaam6tZbb9XRo0f1gx/8QDk5OXI6eb8CAOLFuHHj9Ld/+7f667/+a+3atUs33XSTkpOTVVNTo23btqmwsFAPPPCA3nrrLa1YsUKzZs3SwIEDZYzRunXrdPLkSU2ZMsV6vsLCQm3dulVvvvmmcnJy5PP5dNVVV+mJJ57QW2+9pYkTJ+r73/++0tLS9PLLL2vjxo1avny5/H7/Jevs3bu3fv7zn2vu3Lk6fvy4br/9dmVlZenzzz/X73//e33++ed6/vnnJUk/+tGPdPPNN2vKlClatGiRzp49q5/+9KdKTk7W8ePHu3U+AUksqwFcqffee89MmjTJJCcnG6/Xa8aMGWPefPNN6/j5Fd127tzZ5rEXrvZmjDHnzp0zP/7xj01ubq5xu91m2LBh5q233jLDhw83t956a0+cEgCgh51f7e3zzz9vc+xf//VfzejRo62cGTRokPnud79rdu3aZYwx5n/+53/MnXfeaQYNGmS8Xq/x+/1m1KhRpqKiIuR59u3bZ8aNG2eSkpLarCC6f/9+M3PmTOP3+43b7TbDhw83q1atCnn8+dXefv3rX7d7DlVVVWb69OkmLS3NuFwu07dvXzN9+vQ249944w0zbNgw43a7Tb9+/cyTTz5pnT/Q3RzGGGNn8wXg8g4fPqyrr75aTzzxhB5//HG7ywEAAIhKND9AhPn973+vV199VcXFxUpJSdGhQ4e0fPlyNTQ06MCBA6z6BgAA0El85geIMMnJydq1a5dWrlypkydPyu/3a8KECfrJT35C4wMAAHAF+MsPAAAAgLhg69JRK1asUH5+vnr16qWioiK99957dpYDAIhz5BIAxDbbmp/XXntNZWVlWrJkifbu3atvfvObmjZtmo4cOWJXSQCAOEYuAUDss+2yt9GjR+v666+31n2XpGuuuUazZs
1SeXm5HSUBAOIYuQQAsc+WBQ9aW1u1e/duPfbYYyH7S0tLtX379jbjW1pa1NLSYn197tw5HT9+XOnp6WHfRBIAcGWMMWpsbFQwGIyZG++Gm0sS2QQAkSKcXLKl+fniiy909uzZNitXZWdnq7a2ts348vJy/eAHP+ip8gAAHVBdXa3c3Fy7y+gS4eaSRDYBQKTpSC7ZutT1he+MGWPafbds8eLFWrhwofV1fX29+vXrp6ysrJh51xGA/a6//nplZWXZXUbEa21t1SuvvCKfz2d3KV2uo7kkkU0AegbZdHnh5JItzU9GRoYSEhLavJtWV1fX7n1MPB6PPB5Pm/1Op5OAAdAlHA6HcnJylJCQYHcpUSOWLu0KN5cksglA9yObwtORXLLl1dntdquoqEiVlZUh+ysrK1VcXGxHSQDiXGpqakz9Yx7hIZcARCKyqevZdtnbwoUL9Z3vfEcjR47U2LFj9c///M86cuSI7r//frtKAhDHrrnmGt6tj3PkEoBIQzZ1Pduanzlz5ujYsWP64Q9/qJqaGhUUFOjtt99W//797SoJABDHyCUAiH223efnSjQ0NMjv9ysQCNANA7hiV111lQYPHsylBR3U2tqqiooK1dfXKyUlxe5yIgbZBKArkU0dF04u8eoMIK55vV716dOHcAEARAyyqfvQ/ACIaykpKcrMzLS7DAAALGRT96H5ARDXWD4UABBpyKbuQ/MDIG653W594xvfsLsMAAAsZFP3ovkBELdYxQsAEGnIpu5F8wMgbvXv358PkwIAIgrZ1L1su88PANglMTFRI0aMkMfjsbsUAAAkkU09heYHQNzJy8tTIBCwuwwAACxkU8/gsjcAAAAAcYHmB0Bc8fl8Gjx4sN1lAABgIZt6Ds0PgLjidDq5nhoAEFHIpp5D8wMAAAAgLtD8AAAAAIgLND8A4obD4dCoUaPsLgMAAAvZ1LNofgDEDYfDIZfLZXcZAABYyKaeRfMDIG4MHDiQu2YDACIK2dSzaH4AxI2cnBwCBgAQUcimnkXzAyAu9OrVSwkJCXaXAQCAhWzqeTQ/AOLCoEGD1Lt3b7vLAADAQjb1PJofAAAAAHGB5gdAzMvOzlZeXp7dZQAAYCGb7EHzAyDmOZ1OJSYm2l0GAAAWsskeND8AYlpiYqL69u1rdxkAAFjIJvvQ/ACIaYmJiQoEAnaXAQCAhWyyD80PgJiWlJRkdwkAAIQgm+xD8wMgpo0cOdLuEgAACEE22YfmBwAAAEBcoPkBELO4azYAINKQTfbq8uZn6dKlcjgcIdvXP9BljNHSpUsVDAbl9Xo1YcIEHTx4sKvLAABdf/31crlcdpeBCEA2AYgUZJO9uuUvP9ddd51qamqsbf/+/dax5cuX6+mnn9Zzzz2nnTt3KhAIaMqUKWpsbOyOUgDEqbS0NPXu3VsOh8PuUhAhyCYAdiOb7Nctzc/55fvOb5mZmZL+/M7as88+qyVLlmj27NkqKCjQ6tWr9eWXX+qVV17pjlIAxKnU1FQlJyfbXQYiCNkEwG5kk/26pfn56KOPFAwGlZ+fr29/+9v6wx/+IEk6fPiwamtrVVpaao31eDwaP368tm/fftHna2lpUUNDQ8gGAEA4yCYAQJc3P6NHj9aLL76o//iP/9Avf/lL1dbWqri4WMeOHVNtba0kKTs7O+Qx2dnZ1rH2lJeXy+/3W1teXl5Xlw0ghrjdbvn9frvLQAQhmwDYjWyKDF3e/EybNk233XabCgsLNXnyZG3cuFGStHr1amvMhdc5GmMuee3j4sWLVV9fb23V1dVdXTaAGJKcnKxgMGh3GYggZBMAu5FNkaHbl7pOTk5WYWGhPvroI2tlnQvfSaurq2vzjtvXeTwepaSkhGwAAHQW2QQA8anbm5+Wlhb993//t3JycpSfn69AIKDKykrreGtrq6qqqlRcXNzdpQCIE/wjFJdDNgHoaWRTZEjs6id85JFHNHPmTPXr1091dXX68Y9/rIaGBs2dO1cOh0NlZWVatmyZhgwZoiFDhmjZsmVKSkrSXXfd1dWlAIhDCQkJKigosLsMRBiyCYCdyKbI0eXNz9GjR3XnnXfqiy++UGZmpsaMGaP3339f/fv3lyQ9+uijam5u1oMPPqgTJ05o9OjR2rRpk3w+X1eXAgCAJLIJAPBnDmOMsbuIcDU0NMjv9ysQCMjp7PYr9wBEkXHjxqlPnz7cQK4btba2qqKiQvX19VzG8TVkE4CLIZu6Vzi5xKszgJiSmJhIuAAAIgrZFDlofgDEjGAwKI/HY3cZAABYyKbIQvMDIGZkZGTI7XbbXQYAABayKbLQ/ACICS6Xi3ABAEQUsiny0PwAiAlZWVnWzSoBAIgEZFPkofkBAAAAEBdofgBEPa/Xy83jAAARhWyKTDQ/AKKew+GQy+WyuwwAACxkU2Si+QEAAAAQF2h+AES9ESNG2F0CAAAhyKbIRPMDIOr5fD67SwAAIATZFJlofgBEtaysLDmdvJQBACIH2RS5+KkAiGoDBw5UQkKC3WUAAGAhmyIXzQ+AqOV0OuVwOOwuAwAAC9kU2Wh+AEStwYMHKy0tze4yAACwkE2RjeYHQFTj3TUAQKQhmyIXzQ8AAACAuEDzAyAqOZ1O7pwNAIgoZFPko/kBEJXS0tKUn59vdxkAAFjIpshH8wMg6jidTsIFABBRyKboQPMDIOo4nU5lZWXZXQYAABayKTrQ/AAAAACICzQ/AKJOXl6e3SUAABCCbIoOND8Aok5+fj73UAAARBSyKTrQ/ACIKi6Xi3ABAEQUsil60PwAiCrXXXedvF6v3WUAAGAhm6IHzQ8AAACAuEDzAyBqBAIBlhEFAEQUsim6hN38vPvuu5o5c6aCwaAcDoc2bNgQctwYo6VLlyoYDMrr9WrChAk6ePBgyJiWlhYtWLBAGRkZSk5O1i233KKjR49e0YkAiH1ut1tut9vuMhBhyCUAdiKbokvYzU9TU5OGDx+u5557rt3jy5cv19NPP63nnntOO3fuVCAQ0JQpU9TY2GiNKSsr0/r167V27Vpt27ZNp06d0owZM3T27NnOnwmAmOZyuZSRkWF3GYhA5BIAu5BN0Scx3AdMmzZN06ZNa/eYMUbPPvuslixZotmzZ0uSVq9erezsbL3yyiu67777VF9fr5UrV2rNmjWaPHmyJOmll15SXl6eNm/erKlTp17B6QCIVR6PR8Fg0O4yEIHIJQB2IZuiT5d+5ufw4cOqra1VaWmptc/j8Wj8+PHavn27JGn37t06ffp0yJhgMKiCggJrzIVaWlrU0NAQsgGILy6Xy+4SEIW6K5cksgkA2RSNurT5qa2tlSRlZ2eH7M/OzraO1dbWyu12KzU19aJjLlReXi6/329t3EEXiD9jxoyxuwREoe7KJYlsAkA2RaNuWe3twps8GWMue+OnS41ZvHix6uvrra26urrLagUAxL6uziWJbAKAaNSlzU8gEJCkNu+U1dXVWe+6BQIBtba26sSJExcdcyGPx6OUlJSQDUD8KCwslNPJyvwIX3flkkQ2AfGObIpOXfoTy8/PVyAQUGVlpbWvt
bVVVVVVKi4uliQVFRXJ5XKFjKmpqdGBAwesMQDwdT6f77Lv0gPtIZcAdBeyKTqFvdrbqVOn9PHHH1tfHz58WPv27VNaWpr69eunsrIyLVu2TEOGDNGQIUO0bNkyJSUl6a677pIk+f1+3XPPPVq0aJHS09OVlpamRx55RIWFhdYqOwAAdBS5BADoqLCbn127dmnixInW1wsXLpQkzZ07VxUVFXr00UfV3NysBx98UCdOnNDo0aO1adMm+Xw+6zHPPPOMEhMTdccdd6i5uVklJSWqqKhQQkJCF5wSACCekEsAgI5yGGOM3UWEq6GhQX6/X4FAgGstgRg3aNAgDR06lH+ERpDW1lZVVFSovr6ez7l8DdkExA+yKbKEk0u8OgOIaC6Xi3ABAEQUsil60fwAiFjJyclKS0uzuwwAACxkU3Sj+QEQsQgYAECkIZuiG80PgIjl8XjsLgEAgBBkU3Sj+QEQkdxut4YNG2Z3GQAAWMim6EfzAwAAACAu0PwAiEgjR460uwQAAEKQTdGP5gdARPJ6vXI4HHaXAQCAhWyKfjQ/ACJOdna2EhMT7S4DAAAL2RQbaH4ARJzc3Fy5XC67ywAAwEI2xQaaHwARJSEhQU4nL00AgMhBNsUOfooAIkpeXp6ys7PtLgMAAAvZFDtofgAAAADEBZofABGjd+/eGjRokN1lAABgIZtiC80PgIjhdDrl9XrtLgMAAAvZFFtofgBEjGAwaHcJAACEIJtiC80PgIjgcDg0cOBAu8sAAMBCNsUemh8AAAAAcYHmBwAAAEBcoPkBEBEGDBggh8NhdxkAAFjIpthD8wMgIvTt25eAAQBEFLIp9tD8ALCdx+NRQkKC3WUAAGAhm2ITzQ8A2w0ePFg+n8/uMgAAsJBNsYnmBwAAAEBcoPkBYKvMzEzl5ubaXQYAABayKXbR/ACwVUJCglwul91lAABgIZtiF80PANskJCQoGAzaXQYAABayKbbR/ACwTWJiIgEDAIgoZFNsC7v5effddzVz5kwFg0E5HA5t2LAh5Pi8efPkcDhCtjFjxoSMaWlp0YIFC5SRkaHk5GTdcsstOnr06BWdCIDo4/V67S4BMYBcAtCVyKbYFnbz09TUpOHDh+u555676Jibb75ZNTU11vb222+HHC8rK9P69eu1du1abdu2TadOndKMGTN09uzZ8M8AQNS64YYb7C4BMYBcAtCVyKbYlhjuA6ZNm6Zp06ZdcozH41EgEGj3WH19vVauXKk1a9Zo8uTJkqSXXnpJeXl52rx5s6ZOnRpuSQCAOEYuAQA6qls+87N161ZlZWVp6NChuvfee1VXV2cd2717t06fPq3S0lJrXzAYVEFBgbZv397u87W0tKihoSFkAxDdCgoKWEkHPaarc0kim4BYRDbFvi5vfqZNm6aXX35ZW7Zs0VNPPaWdO3dq0qRJamlpkSTV1tbK7XYrNTU15HHZ2dmqra1t9znLy8vl9/utLS8vr6vLBtDDvF6vnE7WXEH3645cksgmIBaRTbEv7MveLmfOnDnWfxcUFGjkyJHq37+/Nm7cqNmzZ1/0ccYYORyOdo8tXrxYCxcutL5uaGggZAAAHdIduSSRTQAQjbq9tc3JyVH//v310UcfSZICgYBaW1t14sSJkHF1dXXKzs5u9zk8Ho9SUlJCNgDRKxAIqE+fPnaXgTjVFbkkkU1ArCGb4kO3Nz/Hjh1TdXW1cnJyJElFRUVyuVyqrKy0xtTU1OjAgQMqLi7u7nIA2MzhcMjv98vj8dhdCuIUuQTgQmRT/Aj7srdTp07p448/tr4+fPiw9u3bp7S0NKWlpWnp0qW67bbblJOTo08++USPP/64MjIydOutt0qS/H6/7rnnHi1atEjp6elKS0vTI488osLCQmuVHQCxq0+fPhoyZIjdZSCGkEsArhTZFD/Cbn527dqliRMnWl+fv9557ty5ev7557V//369+OKLOnnypHJycjRx4kS99tpr8vl81mOeeeYZJSYm6o477lBzc7NKSkpUUVGhhISELjglAEA8IZcAAB3lMMYYu4sIV0NDg/x+vwKBACtyAFEmNTVV48aNs7sMXIHW1lZVVFSovr6ez7l8DdkERC+yKbqFk0u8OgPoUcnJyXaXAABACLIpftD8AOgxTqdTw4YNs7sMAAAsZFN8ofkBAAAAEBdofgD0mNGjR1/yppEAAPQ0sim+0PwA6BE+n09JSUkEDAAgYpBN8YfmB0CPGDBggLxer91lAABgIZviD80PgG6XmZmpQCBgdxkAAFjIpvhE8wOg27ndbnk8HrvLAADAQjbFJ5ofAN0qISFB6enpdpcBAICFbIpfND8AupXb7Va/fv3sLgMAAAvZFL9ofgB0q1GjRtldAgAAIcim+EXzA6Db+P1+VtEBAEQUsim+0fwA6DZXX321EhMT7S4DAAAL2RTfaH4AAAAAxAWaHwDdYtCgQaykAwCIKGQTaH4AdAun0ymnk5cYAEDkIJvATx8AAABAXKD5AdDl+vTpowEDBthdBgAAFrIJEs0PgG7Qp08feTweu8sAAMBCNkGi+QHQxRITE3XdddfZXQYAABayCefR/ADoUoQLACDSkE04j+YHQJdxuVxKT0+Xw+GwuxQAACSRTQhF8wOgy1x77bVKSkqyuwwAACxkE76O5gcAAABAXKD5AdAlsrOzlZ2dbXcZAABYyCZciOYHwBVzOp3y+/1yu912lwIAgCSyCe2j+QFwxbxer4YOHWp3GQAAWMgmtIfmB8AVu+aaa+wuAQCAEGQT2hNW81NeXq4bbrhBPp9PWVlZmjVrlg4dOhQyxhijpUuXKhgMyuv1asKECTp48GDImJaWFi1YsEAZGRlKTk7WLbfcoqNHj1752QCwRUZGht0lII6RTQDaQzahPWE1P1VVVXrooYf0/vvvq7KyUmfOnFFpaamampqsMcuXL9fTTz+t5557Tjt37lQgENCUKVPU2NhojSkrK9P69eu1du1abdu2TadOndKMGTN09uzZrjszAD0iNTWVeyfAVmQTgAuRTbgYhzHGdPbBn3/+ubKyslRVVaWbbrpJxhgFg0GVlZXpe9/7nqQ/v5OWnZ2tn/70p7rvvvtUX1+vzMxMrVmzRnPmzJEk/elPf1JeXp7efvttTZ069bLft6GhQX6/X4FAQE4nV+4BdiouLlZaWprdZaAHtba2qqKiQvX19UpJSbG7nDbIJgBkU3wJJ5eu6NW5vr5ekqxfrsOHD6u2tlalpaXWGI/Ho/Hjx2v79u2SpN27d+v06dMhY4LBoAoKCqwxF2ppaVFDQ0PIBgBAe8gmAMDFdLr5McZo4cKFuvHGG1VQUCBJqq2tlaQ266lnZ2dbx2pra+V2u5WamnrRMRcqLy+X3++3try8vM6WDaALud1uJSQk2F0GYCGbAJBNuJRONz/z58/XBx98oFdffbXNsQuvsTTGXPa6y0uNWbx4serr662turq6s2UD6EL5+fny+/12lwFYyCYAZBMupVPNz4IFC/TGG2/onXfeUW5urrU/EAhIUpt3yerq6qx33AKBgFpbW3XixImLjrmQ
x+NRSkpKyAbAXj6fz/p/HogEZBMAsgmXE1bzY4zR/PnztW7dOm3ZskX5+fkhx/Pz8xUIBFRZWWnta21tVVVVlYqLiyVJRUVFcrlcIWNqamp04MABawyAyJecnCyfz2d3GQDZBMBCNuFyEsMZ/NBDD+mVV17Rb37zG/l8PutdNL/fL6/XK4fDobKyMi1btkxDhgzRkCFDtGzZMiUlJemuu+6yxt5zzz1atGiR0tPTlZaWpkceeUSFhYWaPHly158hgC7ncrlUVFRkdxmAJLIJwJ+RTeiIsJqf559/XpI0YcKEkP2rVq3SvHnzJEmPPvqompub9eCDD+rEiRMaPXq0Nm3aFNKFP/PMM0pMTNQdd9yh5uZmlZSUqKKigg+nAVGif//+dpcAWMgmABLZhI65ovv82IV7KQD2KikpkdfrtbsM2CTS7/NjF7IJsBfZFL967D4/AOKP2+3mrtkAgIhCNqGjaH4AhKWgoEC9evWyuwwAACxkEzqK5gdAh6Wnp3PvBABARCGbEA6aHwAdlpKSouTkZLvLAADAQjYhHDQ/ADrE5/Pp2muvtbsMAAAsZBPCRfMDoMP4MCkAINKQTQgHzQ+Ay3I4HLruuuvsLgMAAAvZhM6g+QHQIampqXaXAABACLIJ4aL5AXBZGRkZXFYAAIgoZBM6g+YHwGVdddVV3LEeABBRyCZ0Br8xAC6Jd9UAAJGGbEJn0fwAuKSrr76am8cBACIK2YTOovkBcFFJSUnq06cP77ABACIG2YQrQfMD4KJSUlKUnp5udxkAAFjIJlwJmh8A7XK73Ro2bJjdZQAAYCGbcKVofgC0Kz09XW632+4yAACwkE24UjQ/ANpVUFBgdwkAAIQgm3ClaH4AtHHNNdfI5XLZXQYAABayCV0h0e4CAEQOp9OpoUOHauDAgayiAwCICGQTuhLNDwBLSkqKBg8ebHcZAABYyCZ0JS57AyDpz3fLHjJkiN1lAABgIZvQ1fjLDwA5nU7dcMMNysjIsLsUAAAkkU3oHjQ/QJw7f8+EjIwMrqUGAEQEsgndheYHiHMZGRkKBAJ2lwEAgIVsQnfhMz9AHPN4PFxLDQCIKGQTuhN/+QHilMfj0Te/+U316tXL7lIAAJBENqH70fwAccjn8+kb3/gG4QIAiBhkE3oCl70BcSg7O1spKSl2lwEAgIVsQk8Iq/kpLy/XDTfcIJ/Pp6ysLM2aNUuHDh0KGTNv3jw5HI6QbcyYMSFjWlpatGDBAmVkZCg5OVm33HKLjh49euVnA+CyUlNTlZ+fb3cZQJchm4DoRzahp4TV/FRVVemhhx7S+++/r8rKSp05c0alpaVqamoKGXfzzTerpqbG2t5+++2Q42VlZVq/fr3Wrl2rbdu26dSpU5oxY4bOnj175WcE4KJSU1M1duxYeTweu0sBugzZBEQ3sgk9KazP/Pz2t78N+XrVqlXKysrS7t27ddNNN1n7PR7PRZcnrK+v18qVK7VmzRpNnjxZkvTSSy8pLy9Pmzdv1tSpU8M9BwAdkJWVpREjRsjp5GpXxBayCYheZBN62hX9ptXX10uS0tLSQvZv3bpVWVlZGjp0qO69917V1dVZx3bv3q3Tp0+rtLTU2hcMBlVQUKDt27e3+31aWlrU0NAQsgHouIyMDBUWFsrtdttdCtDtyCYgOpBNsEOnmx9jjBYuXKgbb7xRBQUF1v5p06bp5Zdf1pYtW/TUU09p586dmjRpklpaWiRJtbW1crvdSk1NDXm+7Oxs1dbWtvu9ysvL5ff7rS0vL6+zZQNxJykpSSNHjpTX67W7FKDbkU1AdCCbYJdOL3U9f/58ffDBB9q2bVvI/jlz5lj/XVBQoJEjR6p///7auHGjZs+efdHnM8bI4XC0e2zx4sVauHCh9XVDQwMhA3RQZmamEhNZ1R7xgWwCogPZBLt06rduwYIFeuONN/Tuu+8qNzf3kmNzcnLUv39/ffTRR5KkQCCg1tZWnThxIuQdtrq6OhUXF7f7HB6Phw/BAZ2Qn5+va665xu4ygB5BNgHRgWyCncJqfowxWrBggdavX6+tW7d2aEnCY8eOqbq6Wjk5OZKkoqIiuVwuVVZW6o477pAk1dTU6MCBA1q+fHmH65Ckc+fOhVM+EFfy8vKUn5+vM2fO2F0KYkxra6uk//9abDeyCYgeZBO6Q1i5ZMLwwAMPGL/fb7Zu3Wpqamqs7csvvzTGGNPY2GgWLVpktm/fbg4fPmzeeecdM3bsWNO3b1/T0NBgPc/9999vcnNzzebNm82ePXvMpEmTzPDhw82ZM2c6VEd1dbWRxMbGxsZm41ZdXR1OhHSbSMmm//u//7P9Z8LGxsYWz1tHcslhTMffurvYdc+rVq3SvHnz1NzcrFmzZmnv3r06efKkcnJyNHHiRP3oRz8KuQ76q6++0t/93d/plVdeUXNzs0pKSrRixYoOXyt97tw5HTp0SNdee62qq6u5G3AYzl+TzryFh3nrHOYtfNEwZ8YYNTY2KhgMRsTytJGSTSdPnlRqaqqOHDkiv9/fJecWD6Lhdz4SMW+dw7x1TqTPWzi5FFbzE0kaGhrk9/tVX18fkT+ESMW8dQ7z1jnMW/iYs+jFz65zmLfOYd46h3nrnFiaN/vfsgMAAACAHkDzAwAAACAuRG3z4/F49MQTT7DMaJiYt85h3jqHeQsfcxa9+Nl1DvPWOcxb5zBvnRNL8xa1n/kBAAAAgHBE7V9+AAAAACAcND8AAAAA4gLNDwAAAIC4QPMDAAAAIC7Q/AAAAACIC1HZ/KxYsUL5+fnq1auXioqK9N5779ldkq3effddzZw5U8FgUA6HQxs2bAg5bozR0qVLFQwG5fV6NWHCBB08eDBkTEtLixYsWKCMjAwlJyfrlltu0dGjR3vwLHpWeXm5brjhBvl8PmVlZWnWrFk6dOhQyBjmra3nn39ew4YNU0pKilJSUjR27Fj9+7//u3WcOeuY8vJyORwOlZWVWfuYu+hHNoUim8JHNnUO2dQ14iabTJRZu3atcblc5pe//KX58MMPzcMPP2ySk5PNp59+andptnn77bfNkiVLzOuvv24kmfXr14ccf/LJJ43P5zOvv/662b9/v5kzZ47JyckxDQ0N1pj777/f9O3b11RWVpo9e/aYiRMnmuHDh5szZ8708Nn0jKlTp5pVq1aZAwcOmH379pnp06ebfv36mVOnTlljmLe23njjDbNx40Zz6NAhc+jQIfP4448bl8tlDhw4YIxhzjriv/7rv8yAAQPMsGHDzMMPP2ztZ+6iG9nUFtkUPrKpc8imKxdP2RR1zc+oUaPM/fffH7Lv6quvNo899phNFUWWCwPm3LlzJhAImCeffNLa99VXXxm/329eeOEFY4wxJ0+eNC6Xy6xdu9Ya88c//tE4nU7z29/+tsdqt1NdXZ2RZKqqqowxzFs4UlNTzb/8y78wZx3Q2NhohgwZYiorK8348eOtgGHuoh/ZdGlkU+eQTZ1HNnVcvGVTVF321traqt27d6u0tDRkf2lpqbZv325TVZHt8OHDqq2tDZkzj8ej8ePHW3O2e/dunT59OmR
MMBhUQUFB3MxrfX29JCktLU0S89YRZ8+e1dq1a9XU1KSxY8cyZx3w0EMPafr06Zo8eXLIfuYuupFN4eN3vmPIpvCRTeGLt2xKtLuAcHzxxRc6e/assrOzQ/ZnZ2ertrbWpqoi2/l5aW/OPv30U2uM2+1WampqmzHxMK/GGC1cuFA33nijCgoKJDFvl7J//36NHTtWX331lXr37q3169fr2muvtV7kmLP2rV27Vnv27NHOnTvbHOP3LbqRTeHjd/7yyKbwkE2dE4/ZFFXNz3kOhyPka2NMm30I1Zk5i5d5nT9/vj744ANt27atzTHmra2rrrpK+/bt08mTJ/X6669r7ty5qqqqso4zZ21VV1fr4Ycf1qZNm9SrV6+LjmPuohvZFD5+5y+ObAoP2RS+eM2mqLrsLSMjQwkJCW06ybq6ujZdKf4sEAhI0iXnLBAIqLW1VSdOnLjomFi1YMECvfHGG3rnnXeUm5tr7WfeLs7tdmvw4MEaOXKkysvLNXz4cP3TP/0Tc3YJu3fvVl1dnYqKipSYmKjExERVVVXpZz/7mRITE61zZ+6iE9kUPl4vLo1sCh/ZFL54zaaoan7cbreKiopUWVkZsr+yslLFxcU2VRXZ8vPzFQgEQuastbVVVVVV1pwVFRXJ5XKFjKmpqdGBAwdidl6NMZo/f77WrVunLVu2KD8/P+Q489Zxxhi1tLQwZ5dQUlKi/fv3a9++fdY2cuRI/eVf/qX27dungQMHMndRjGwKH68X7SObug7ZdHlxm009t7ZC1zi/nOjKlSvNhx9+aMrKykxycrL55JNP7C7NNo2NjWbv3r1m7969RpJ5+umnzd69e60lVp988knj9/vNunXrzP79+82dd97Z7jKFubm5ZvPmzWbPnj1m0qRJEb1M4ZV64IEHjN/vN1u3bjU1NTXW9uWXX1pjmLe2Fi9ebN59911z+PBh88EHH5jHH3/cOJ1Os2nTJmMMcxaOr6+oYwxzF+3IprbIpvCRTZ1DNnWdeMimqGt+jDHmF7/4henfv79xu93m+uuvt5aAjFfvvPOOkdRmmzt3rjHmz0sVPvHEEyYQCBiPx2Nuuukms3///pDnaG5uNvPnzzdpaWnG6/WaGTNmmCNHjthwNj2jvfmSZFatWmWNYd7auvvuu63/9zIzM01JSYkVLsYwZ+G4MGCYu+hHNoUim8JHNnUO2dR14iGbHMYY03N/ZwIAAAAAe0TVZ34AAAAAoLNofgAAAADEBZofAAAAAHGB5gcAAABAXKD5AQAAABAXaH4AAAAAxAWaHwAAAABxgeYHAAAAQFyg+QEAAAAQF2h+AAAAAMQFmh8AAAAAceH/AeK4W3T0haZyAAAAAElFTkSuQmCC", 182 | "text/plain": [ 183 | "
" 184 | ] 185 | }, 186 | "metadata": {}, 187 | "output_type": "display_data" 188 | } 189 | ], 190 | "source": [ 191 | "import matplotlib.pyplot as plt\n", 192 | "\n", 193 | "img_p = 'img00087_ircad18.png'\n", 194 | "\n", 195 | "orig_mask = cv2.imread(os.path.join(ircad_labels_path,img_p))\n", 196 | "orig_mask = cv2.cvtColor(orig_mask, cv2.COLOR_BGR2RGB)\n", 197 | "\n", 198 | "img = cv2.imread(os.path.join(dest_path,img_p), cv2.IMREAD_GRAYSCALE)\n", 199 | "print(img.shape)\n", 200 | "\n", 201 | "speidel_mapping = {\n", 202 | "2: (26, 26, 26),\n", 203 | "1: (77,77,77),\n", 204 | "4:(102, 102, 102),\n", 205 | "10:(51, 51, 51),\n", 206 | "5:(179, 179, 179),\n", 207 | "12:(128,128,128)\n", 208 | "}\n", 209 | "\n", 210 | "def map_function(mapping_dict, i): \n", 211 | " return lambda x: mapping_dict[x][i]\n", 212 | "\n", 213 | "\n", 214 | "# apply the mapping for each channel using np.vectorize and the mapping dictionaries\n", 215 | "mapped_array_channel_1 = np.vectorize(map_function(speidel_mapping, 0))(img[:,:])\n", 216 | "mapped_array_channel_2 = np.vectorize(map_function(speidel_mapping, 1))(img[:,:])\n", 217 | "mapped_array_channel_3 = np.vectorize(map_function(speidel_mapping, 2))(img[:,:])\n", 218 | "\n", 219 | "# combine the mapped channels back into a single array\n", 220 | "mapped_array = np.array([mapped_array_channel_1, mapped_array_channel_2, mapped_array_channel_3])\n", 221 | "print(mapped_array.shape)\n", 222 | "mapped_array = mapped_array.transpose((1, 2, 0)).astype('uint8')\n", 223 | "print(mapped_array.shape)\n", 224 | "\n", 225 | "# Create a figure with two subplots\n", 226 | "fig, (ax1, ax2) = plt.subplots(1, 2,figsize=(10, 10))\n", 227 | "\n", 228 | "# Display the first image in the first subplot\n", 229 | "ax1.imshow(orig_mask)\n", 230 | "ax1.set_title('orig')\n", 231 | "\n", 232 | "# Display the second image in the second subplot\n", 233 | "ax2.imshow(mapped_array)\n", 234 | "ax2.set_title('restored')" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 21, 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "data": { 244 | "text/plain": [ 245 | "" 246 | ] 247 | }, 248 | "execution_count": 21, 249 | "metadata": {}, 250 | "output_type": "execute_result" 251 | }, 252 | { 253 | "data": { 254 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAigAAAFJCAYAAACxTdbnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAjOElEQVR4nO3df3xU5YHv8e9kJhlCSMaEwExGIsY2tGqA1oBA1vJDMJgVKeK+oGpb2HW9WiFrFllX8O4tu9sllr2CbRGqloKCNNxdAelCLUEwSCkrDVACWooFNGimqRQyCcYJhGf/6O20w+9AkufM5PN+vc7rxZzzZPiOTyTfPHPOHJcxxggAAMBBkmwHAAAAOBsFBQAAOA4FBQAAOA4FBQAAOA4FBQAAOA4FBQAAOA4FBQAAOA4FBQAAOA4FBQAAOA4FBQAAOI7VgrJo0SLl5eWpW7duKiws1FtvvWUzDgAAcAhrBWXVqlUqKyvTU089pd27d+tLX/qSSkpK9MEHH9iKBAAAHMJl62aBQ4YM0S233KLFixdH9914442aMGGCysvLL/q1Z86c0UcffaT09HS5XK6OjgoAANqBMUaNjY0KBoNKSrr4GomnkzLFaGlpUXV1tZ588smY/cXFxdq+ffs54yORiCKRSPTxhx9+qJtuuqnDcwIAgPZXW1urPn36XHSMlYLy8ccfq7W1VX6/P2a/3+9XKBQ6Z3x5ebn++Z//+Zz9t+kv5VFyh+UEAMSZwTfrhZeXqJc7zXYSnEe46Yz63nJE6enplxxrpaD80dlvzxhjzvuWzaxZszRjxozo43A4rNzcXHmULI+LggIA+AOTkqrPXHPpH36w63JOz7BSULKzs+V2u89ZLamvrz9nVUWSvF6vvF5vZ8UDAMShyF8O1l3zNtuOgXZi5SqelJQUFRYWqrKyMmZ/ZWWlioqKbEQCAMS5TzPd+oes39iOgXZi7S2eGTNm6Gtf+5oGDRqkYcOG6YUXXtAHH3ygRx55xFYkAADgENYKyuTJk3Xs2DH9y7/8i+rq6lRQUKANGzaob9++tiIBAOLVrf3129GnbadAO7L2OShXIxwOy+fzaaS+zEmyAADV/lOR3vnGItsxcAnhxjPK7HdIDQ0NysjIuOhY7sUDAAAch4ICAIhrSWlpOp0ad28G4BKsfg4KAABX61ffvVGHSxZfeiDiCisoAADAcVhBAQDErZOv36Bf9V8scduThMMKCgAgbqUmn5KXqzkTEgUFAAA4Dm/xAADi0q9/OEgr8l4Qv2snJmYVABCXHhlcpb/oxo+xRMXMAgDiTujvi/TF1CO2Y6AD8RYPACDuPF/6PQ3t5rYdAx2IFRQAQHwZOkDdk07ZToEOxgoKACBueK4NauV/Pi9fUqrtKOhgFBQAQFxweTxav3ODJMpJV8BbPAAAwHEoKAAAwHEoKAAAx3Nn99S897bZjoFOREEBADifK0kDUrrZToFOREEBAACOQ0EBAACOQ0EBAACOQ0EBADia+xqf6ibn246BTkZBAQA42pm8Pto9e5HtGOhkFBQAAOA4FBQAgHMluXUmlbuydEUUFACAY528Z5A2/udLtmPAAgoKAABwHAoKAABwHN7YAwA40scPD9O8J16wHQOWsIICAHCklnSXRqe22o4BSygoAADAcSgoAADH+eSeIepZ/JHtGLCIc1AAAI7z4RjpUMFa2zFgESsoAADAcSgoAABH8VwblNJP2Y4By3iLBwDgKL97IU2HvvhD2zFgGSsoAADAcSgoAADAcSgoAADH6LOjh976wkrbMeAAFBQAgGNkpZyU15VsOwYcgIICAAAch4ICAAAch4ICAHCEI/82THf69tqOAYegoAAAHGH+5KXcvRhR7V5Q5syZI5fLFbMFAoHocWOM5syZo2AwqNTUVI0cOVL79+9v7xgAgDjy6bhb1cvdaDsGHKRDVlBuvvlm1dXVRbeamprosXnz5mn+/PlauHChdu7cqUAgoDvuuEONjXxjAkBXVbFovm71cvUO/qRDCorH41EgEIhuvXr1kvSH1ZNnn31WTz31lCZOnKiCggK99NJL+uSTT7RyJde9AwCAP+iQgnLw4EEFg0Hl5eXpK1/5ig4dOiRJOnz4sEKhkIqLi6NjvV6vRowYoe3bt1/w+SKRiMLhcMwGAAASV7sXlCFDhujll1/WT3/6U7344osKhUIqKirSsWPHFAqFJEl+vz/ma/x+f/TY+ZSXl8vn80W33Nzc9o4NAAAcpN0LSklJie699171799fY8aM0fr16yVJL730UnSMy+WK+RpjzDn7/tysWbPU0NAQ3Wpra9s7NgDABpdLr9T+TDmeHraTwGE6/DLjtLQ09e/fXwcPHoxezXP2akl9ff05qyp/zuv1KiMjI2YDACSGbHea7QhwoA4vKJFIRO+++65ycnKUl5enQCCgysrK6PGWlhZVVVWpqKioo6MAAIA44WnvJ5w5c6buvvtuXXfddaqvr9e3vvUthcNhTZkyRS6XS2VlZZo7d67y8/OVn5+vuXPnqnv37rr//vvbOwoAAIhT7V5Qjh49qvvuu08ff/yxevXqpaFDh2rHjh3q27evJOmJJ55Qc3OzHn30UR0/flxDhgzRxo0blZ6e3t5RAAAOltStm371vf6SdtuOAgdyGWOM7RBtFQ6H5fP5NFJflofbcgNAXHJf49OGd6psx0AnCjeeUWa/Q2poaLjk+aTciwcAADgOBQUA0OlcXq8it3zWdgw4GAUFANDpkj7TV2+sWGI7BhyMggIAAByHggIAAByHggIA6FSRksH6z40rbMeAw1FQAACdyyV1T0qxnQIOR0EBAACOQ0EBAHSaT+4ZosJ/rbYdA3GAggIA6DRNOW49k7PLdgzEAQoKAABwHAoKAKBTuApv1okvnLIdA3GCggIA6BTvfSVDh8e9aDsG4gQFBQAAOA4FBQAAOA4FBQAAOI7HdgAAQOL74D/660DRc+L3YlwuvlMAAB0uKemM3C5+5ODy8d0CAAAch7d4AAAd6tdLBumlL/7AdgzEGVZQAAAdauTNBzS8m+0UiDcUFAAA4DgUFABAh2n46lANyjhiOwbiEOegAAA6TOk//YceSD9mOwbiECsoAADAcSgoAADAcSgoAADAcSgoAADAcThJFgDQIWb9Zq++1O20+F0YV4LvGgBAh8j1hLn/Dq4Y3zkAgHZ3/dup6uPx2o6BOMZbPACAdrfw2m1KdiXbjoE4xgoKAABwHAoKAABwHAoKAKD9uFxq/vKtSpLLdhLEOc5BAQC0m6Tu3bV18Qvi919cLb6DAACA41BQAADtxuV2246ABEFBAQC0C88N12vDr7bajoEEQUEBAACOQ0EBAACOQ0EBAFw116ACjVxXYzsGEggFBQBw1U73SNE/ZP3GdgwkEAoKAABwnDYXlK1bt+ruu+9WMBiUy+XS2rVrY44bYzRnzhwFg0GlpqZq5MiR2r9/f8yYSCSi0tJSZWdnKy0tTePHj9fRo0ev6oUAAOxw3/w5HforbgyI9tXmgnLy5EkNHDhQCxcuPO/xefPmaf78+Vq4cKF27typQC
CgO+64Q42NjdExZWVlWrNmjSoqKrRt2zY1NTVp3Lhxam1tvfJXAgCw4ne3ZunQxOdtx0CCafNH3ZeUlKikpOS8x4wxevbZZ/XUU09p4sSJkqSXXnpJfr9fK1eu1MMPP6yGhgYtWbJEy5cv15gxYyRJK1asUG5urjZt2qSxY8dexcsBAACJoF3PQTl8+LBCoZCKi4uj+7xer0aMGKHt27dLkqqrq3Xq1KmYMcFgUAUFBdExZ4tEIgqHwzEbAMC+pLQ0nerBjQHR/tq1oIRCIUmS3++P2e/3+6PHQqGQUlJSlJmZecExZysvL5fP54tuubm57RkbAHCFDj05QHtmLbIdAwmoQ67icbli27Qx5px9Z7vYmFmzZqmhoSG61dbWtltWAADgPO1aUAKBgCSdsxJSX18fXVUJBAJqaWnR8ePHLzjmbF6vVxkZGTEbAABIXO1aUPLy8hQIBFRZWRnd19LSoqqqKhUVFUmSCgsLlZycHDOmrq5O+/bti44BADjfr58frDem/LvtGEhQbb6Kp6mpSe+991708eHDh7Vnzx5lZWXpuuuuU1lZmebOnav8/Hzl5+dr7ty56t69u+6//35Jks/n04MPPqjHH39cPXv2VFZWlmbOnKn+/ftHr+oBADifu8dp9fH0sB0DCarNBeUXv/iFRo0aFX08Y8YMSdKUKVO0bNkyPfHEE2pubtajjz6q48ePa8iQIdq4caPS09OjX7NgwQJ5PB5NmjRJzc3NGj16tJYtWya3290OLwkAAMQ7lzHG2A7RVuFwWD6fTyP1ZXlcfHohANjwm1e+qPdGLbUdA3Ek3HhGmf0OqaGh4ZLnk3IvHgAA4DgUFABAm7WMHaSb+tTZjoEERkEBALSZb3at1uW/bjsGEhgFBQAAOA4FBQDQJi6vV54k7j6PjtXmy4wBAF3bXbtCKs38b9sxkOBYQQEAAI5DQQEAAI5DQQEAXLYv7Ja+7vuV7RjoAigoAIDLdt81/y1fUqrtGOgCKCgAAMBxKCgAgMtyuHyYgp7TtmOgi+AyYwDAZdn1tQXqkZRmOwa6CFZQAACA41BQAACX5O73GSXxIwOdiLd4AAAX5fJ6teHNVyWl2I6CLoQ6DAAAHIeCAgAAHIeCAgC4IE9OQC8cfMN2DHRBFBQAwIUlJek6Tw/bKdAFUVAAAIDjUFAAAOflzr9BJ5Z0sx0DXRQFBQBwXq09e+hnA1bbjoEuioICAAAch4ICADiHJ+DXbwdzcizsoaAAAM5xfGSe9sxaZDsGujAKCgAAcBwKCgAAcBwKCgAAcBwKCgAAcByP7QAAAGcJlRVpedl8SXxIG+xhBQUAEON0d2lACuUEdlFQAACA41BQAABRDV8dqhtLfm07BsA5KACAP/nt8Fbt+Mwm2zEAVlAAAIDzUFAAAJIkd/4N6pb5qe0YgCTe4gEA/H+fPHdG7xYstx0DkMQKCgAAcCAKCgAAcBwKCgBAA3a59MbNq23HAKIoKAAA9XBH5HbxIwHOwXcjAABwHAoKAHRxdWtv1Fevedt2DCBGmwvK1q1bdffddysYDMrlcmnt2rUxx6dOnSqXyxWzDR06NGZMJBJRaWmpsrOzlZaWpvHjx+vo0aNX9UIAAFfmu/1X6TPJPWzHAGK0uaCcPHlSAwcO1MKFCy845s4771RdXV1027BhQ8zxsrIyrVmzRhUVFdq2bZuampo0btw4tba2tv0VAACAhNPmD2orKSlRSUnJRcd4vV4FAoHzHmtoaNCSJUu0fPlyjRkzRpK0YsUK5ebmatOmTRo7dmxbIwEArlD4/qHq5f65pFTbUYAYHXIOyptvvqnevXurX79+euihh1RfXx89Vl1drVOnTqm4uDi6LxgMqqCgQNu3bz/v80UiEYXD4ZgNAHD1fvztZ3RzCuUEztPuBaWkpESvvPKKNm/erGeeeUY7d+7U7bffrkgkIkkKhUJKSUlRZmZmzNf5/X6FQqHzPmd5ebl8Pl90y83Nbe/YAADAQdr9XjyTJ0+O/rmgoECDBg1S3759tX79ek2cOPGCX2eMkcvlOu+xWbNmacaMGdHH4XCYkgIAQALr8MuMc3Jy1LdvXx08eFCSFAgE1NLSouPHj8eMq6+vl9/vP+9zeL1eZWRkxGwAACBxdXhBOXbsmGpra5WTkyNJKiwsVHJysiorK6Nj6urqtG/fPhUVFXV0HACAJCW5te7Dncp2p9lOApxXm9/iaWpq0nvvvRd9fPjwYe3Zs0dZWVnKysrSnDlzdO+99yonJ0dHjhzR7NmzlZ2drXvuuUeS5PP59OCDD+rxxx9Xz549lZWVpZkzZ6p///7Rq3oAAB3P60q2HQG4oDYXlF/84hcaNWpU9PEfzw2ZMmWKFi9erJqaGr388ss6ceKEcnJyNGrUKK1atUrp6enRr1mwYIE8Ho8mTZqk5uZmjR49WsuWLZPb7W6HlwQAuBh3Zqbu3f6u7RjARbmMMcZ2iLYKh8Py+XwaqS/Lw28AANAm7uye2rD3Ddsx0AWFG88os98hNTQ0XPJ8Uu7FAwAAHIeCAgAAHIeCAgAAHIeCAgBdSFL37mr80mdtxwAuiYICAF1Jfl+99dzztlMAl0RBAQAAjkNBAQAAjkNBAYAuonnCrfqvDa/YjgFcFgoKAHQhbhf/7CM+8J0KAF3Asb8dpm/Nf8F2DOCyUVAAoAs41cOl4d1spwAuHwUFABLcsb8dphFf3Wk7BtAmFBQASHC/H3BG3w1SUBBfKCgAAMBxKCgAkMBcgwqUem2T7RhAm3lsBwAAdJz6/3NK+wetsh0DaDNWUAAgQbkzMuRNPm07BnBFKCgAkKD6bflEPx/4qu0YwBWhoAAAAMfhHBQASED3vluvv86oleS2HQW4IqygAEAC6u0JK9lFOUH8oqAAAADHoaAAQIKpW3ujhnhDtmMAV4WCAgAJZsMtLyrH08N2DOCqUFAAAIDjUFAAAIDjUFAAIIG4Cm9WistlOwZw1fgcFABIIOvXLZfblWY7BnDVWEEBgAThSk6xHQFoNxQUAEgA7owMvf7+23K7+GcdiYHvZAAA4DgUFACIc+7P5mnOL9+wHQNoVxQUAIh3brdu9SbbTgG0KwoKAMQx16ACZfzw97ZjAO2OggIAcSzSs5sq8jbbjgG0OwoKAABwHAoKAMQpzw3Xq+4vOPcEiYmCAgBxKnRHjn71t4ttxwA6BAUFAOKQOyNDLRnccweJi3vxAEAcOjj7Zh38+iLbMYAOwwoKAABwHFZQACDOHK4YoHe+tFCS23YUoMOwggIAcSYpySjZRTlBYmtTQSkvL9fgwYOVnp6u3r17a8KECTpw4EDMGGOM5syZo2AwqNTUVI0cOVL79++PGROJRFRaWqrs7GylpaVp/PjxOnr06NW/GgAAkBDaVFCqqqo0bdo07dixQ5WVlTp9+rSKi4t18uTJ6Jh58+Zp/vz5WrhwoXbu3KlAIKA77rhDjY2N0TFlZWVas2aNKioqtG3bNjU1NWncuHFqbW1tv1cGAAno19+/Vc8VrrQdA+hwLmOMudIv/t3vfqfevXurqqpKw4cPlzFGwWBQZ
WVl+sd//EdJf1gt8fv9+va3v62HH35YDQ0N6tWrl5YvX67JkydLkj766CPl5uZqw4YNGjt27CX/3nA4LJ/Pp5H6sjwuPqQIQNfR82eZWpm3xXYM4IqEG88os98hNTQ0KCMj46Jjr+oclIaGBklSVlaWJOnw4cMKhUIqLi6OjvF6vRoxYoS2b98uSaqurtapU6dixgSDQRUUFETHnC0SiSgcDsdsAAAgcV1xQTHGaMaMGbrttttUUFAgSQqFQpIkv98fM9bv90ePhUIhpaSkKDMz84JjzlZeXi6fzxfdcnNzrzQ2AACIA1dcUKZPn669e/fqRz/60TnHXK7YTzc0xpyz72wXGzNr1iw1NDREt9ra2iuNDQBxyxQN1PXdj9mOAXSKKyoopaWlWrdunbZs2aI+ffpE9wcCAUk6ZyWkvr4+uqoSCATU0tKi48ePX3DM2bxerzIyMmI2AOhq7v/hTzTXv9d2DKBTtKmgGGM0ffp0rV69Wps3b1ZeXl7M8by8PAUCAVVWVkb3tbS0qKqqSkVFRZKkwsJCJScnx4ypq6vTvn37omMAALGS0tKUpDO2YwCdpk2fJDtt2jStXLlSr732mtLT06MrJT6fT6mpqXK5XCorK9PcuXOVn5+v/Px8zZ07V927d9f9998fHfvggw/q8ccfV8+ePZWVlaWZM2eqf//+GjNmTPu/QgBIAN95Z6P6JafZjgF0mjYVlMWL/3Bb75EjR8bsX7p0qaZOnSpJeuKJJ9Tc3KxHH31Ux48f15AhQ7Rx40alp6dHxy9YsEAej0eTJk1Sc3OzRo8erWXLlsnt5pMRAQDAVX4Oii18DgqAruTZI9t1Y0p32zGAq9Zpn4MCAOh4NyTzixi6Hu5mDAAO5fJ61fetJHlZKUYXxAoKADiUy+XS831+bjsGYAUFBQAAOA4FBQAcKCktTR89covtGIA1FBQAcKCkXj31yycW2Y4BWENBAQCHcXm9+vQzvWzHAKyioACAw7QOvlFvLF9iOwZgFQUFAJzGffG7vwNdAQUFABzkxNeHqfJHS23HAKyjoAAAAMehoAAAAMfho+4BwCF+941h+rvHXrUdA3AEVlAAwCGa/S5Nzai3HQNwBAoKAABwHAoKADhApGSwegz62HYMwDE4BwUAHOCDyWd06Jb/ZzsG4BisoAAAAMehoACAZe5rfHKntNqOATgKb/EAgG1r0nTwc8tspwAchRUUAADgOBQUALBo6C9P6bV+P7YdA3AcCgoAWORzNyvZ5bYdA3AcCgoAWBLcka6v+2psxwAciYICAJZM97+hbHea7RiAI1FQAMCCX78wWLmeU7ZjAI7FZcYAYEHNX35PPZJYPQEuhBUUAOhMLpcaHhiqJP75BS6KFRQA6CQuj0etw/prx79/X1KK7TiAo1HhAaATuDweuQr6qXLVUttRgLhAQQGAzvCFz+snG1baTgHEDQoKAABwHAoKAHSwk381RK+9xls7QFtwkiwAdKBQWZHW/v08eV09bEcB4goFBQA6yJF/HaYXH1ikvGTKCdBWFBQA6AC/+b9D9e3xr2h4N9tJgPjEOSgA0AH8BfW6t0fYdgwgblFQAKCd/fbvijShzy9txwDiGm/xAEA7Ct8/VP8wbZUeSD9mOwoQ1ygoANBOzoz4or7zb9/Trd5k21GAuEdBAYB24Lnher20YqGy3dyhGGgPFBQAuFpJbq3ftlYS5QRoL5wkCwAAHIeCAgBXwZ2Zqf+qfdt2DCDhtKmglJeXa/DgwUpPT1fv3r01YcIEHThwIGbM1KlT5XK5YrahQ4fGjIlEIiotLVV2drbS0tI0fvx4HT169OpfDQB0IvdN/fTcnh8r2eW2HQVIOG0qKFVVVZo2bZp27NihyspKnT59WsXFxTp58mTMuDvvvFN1dXXRbcOGDTHHy8rKtGbNGlVUVGjbtm1qamrSuHHj1NraevWvCAA6wenRhZq2bh0fYw90kDadJPv666/HPF66dKl69+6t6upqDR8+PLrf6/UqEAic9zkaGhq0ZMkSLV++XGPGjJEkrVixQrm5udq0aZPGjh3b1tcAAJ2qadJQ3fW/t+iu7p/ajgIkrKs6B6WhoUGSlJWVFbP/zTffVO/evdWvXz899NBDqq+vjx6rrq7WqVOnVFxcHN0XDAZVUFCg7du3n/fviUQiCofDMRsA2PD7vx6mgpm/1OzsA5ceDOCKXXFBMcZoxowZuu2221RQUBDdX1JSoldeeUWbN2/WM888o507d+r2229XJBKRJIVCIaWkpCgzMzPm+fx+v0Kh0Hn/rvLycvl8vuiWm5t7pbEB4IodnzpMn3voXT3f5+e2owAJ74o/B2X69Onau3evtm3bFrN/8uTJ0T8XFBRo0KBB6tu3r9avX6+JEyde8PmMMXK5XOc9NmvWLM2YMSP6OBwOU1IAdLpPJ5zQiuvftB0D6BKuaAWltLRU69at05YtW9SnT5+Ljs3JyVHfvn118OBBSVIgEFBLS4uOHz8eM66+vl5+v/+8z+H1epWRkRGzAUBnMkUDdUPm723HALqMNhUUY4ymT5+u1atXa/PmzcrLy7vk1xw7dky1tbXKycmRJBUWFio5OVmVlZXRMXV1ddq3b5+KioraGB8AOp47/wYNem631ub/1HYUoMto01s806ZN08qVK/Xaa68pPT09es6Iz+dTamqqmpqaNGfOHN17773KycnRkSNHNHv2bGVnZ+uee+6Jjn3wwQf1+OOPq2fPnsrKytLMmTPVv3//6FU9AOAU7mt8+l8bNmpCWpPtKECX0qaCsnjxYknSyJEjY/YvXbpUU6dOldvtVk1NjV5++WWdOHFCOTk5GjVqlFatWqX09PTo+AULFsjj8WjSpElqbm7W6NGjtWzZMrndl/dhR8YYSdJpnZJMW14BALTN/J/9WLlneijcaDsJEP/CTWck/enn+MW4zOWMcpijR49ykiwAAHGqtrb2kuewxmVBOXPmjA4cOKCbbrpJtbW1nDQbR/54BRbzFj+Ys/jEvMWnRJ83Y4waGxsVDAaVlHTx02Cv+DJjm5KSknTttddKElf1xCnmLf4wZ/GJeYtPiTxvPp/vssZxN2MAAOA4FBQAAOA4cVtQvF6vvvnNb8rr9dqOgjZg3uIPcxafmLf4xLz9SVyeJAsAABJb3K6gAACAxEVBAQAAjkNBAQAAjkNBAQAAjkNBAQAAjhOXBWXRokXKy8tTt27dVFhYqLfeest2pC5t69atuvvuuxUMBuVyubR27dqY48YYzZkzR8FgUKmpqRo5cqT2798fMyYSiai0tFTZ2dlKS0vT+PHjdfTo0U58FV1LeXm5Bg8erPT0dPXu3VsTJkzQgQMHYsYwb86zePFiDRgwIPopo8OGDdNPfvKT6HHmzPnKy8vlcrlUVlYW3ce8XYCJMxUVFSY5Odm8+OKL5p133jGPPfaYSUtLM++//77taF3Whg0bzFNPPWVeffVVI8msWbMm5vjTTz9t0tPTzauvvmpqamrM5MmTTU5OjgmHw9ExjzzyiLn22mtNZWWl2bVrlxk1apQZOHCgOX36dCe/mq5h
7NixZunSpWbfvn1mz5495q677jLXXXedaWpqio5h3pxn3bp1Zv369ebAgQPmwIEDZvbs2SY5Odns27fPGMOcOd3bb79trr/+ejNgwADz2GOPRfczb+cXdwXl1ltvNY888kjMvs9//vPmySeftJQIf+7sgnLmzBkTCATM008/Hd336aefGp/PZ77//e8bY4w5ceKESU5ONhUVFdExH374oUlKSjKvv/56p2Xvyurr640kU1VVZYxh3uJJZmam+cEPfsCcOVxjY6PJz883lZWVZsSIEdGCwrxdWFy9xdPS0qLq6moVFxfH7C8uLtb27dstpcLFHD58WKFQKGbOvF6vRowYEZ2z6upqnTp1KmZMMBhUQUEB89pJGhoaJElZWVmSmLd40NraqoqKCp08eVLDhg1jzhxu2rRpuuuuuzRmzJiY/czbhcXV3Yw//vhjtba2yu/3x+z3+/0KhUKWUuFi/jgv55uz999/PzomJSVFmZmZ54xhXjueMUYzZszQbbfdpoKCAknMm5PV1NRo2LBh+vTTT9WjRw+tWbNGN910U/QHFXPmPBUVFdq1a5d27tx5zjH+X7uwuCoof+RyuWIeG2PO2QdnuZI5Y147x/Tp07V3715t27btnGPMm/N87nOf0549e3TixAm9+uqrmjJliqqqqqLHmTNnqa2t1WOPPaaNGzeqW7duFxzHvJ0rrt7iyc7OltvtPqcx1tfXn9M+4QyBQECSLjpngUBALS0tOn78+AXHoGOUlpZq3bp12rJli/r06RPdz7w5V0pKij772c9q0KBBKi8v18CBA/Wd73yHOXOo6upq1dfXq7CwUB6PRx6PR1VVVfrud78rj8cT/e/OvJ0rrgpKSkqKCgsLVVlZGbO/srJSRUVFllLhYvLy8hQIBGLmrKWlRVVVVdE5KywsVHJycsyYuro67du3j3ntIMYYTZ8+XatXr9bmzZuVl5cXc5x5ix/GGEUiEebMoUaPHq2amhrt2bMnug0aNEgPPPCA9uzZoxtuuIF5uxA75+ZeuT9eZrxkyRLzzjvvmLKyMpOWlmaOHDliO1qX1djYaHbv3m12795tJJn58+eb3bt3Ry/9fvrpp43P5zOrV682NTU15r777jvvJXR9+vQxmzZtMrt27TK33357wl9CZ9M3vvEN4/P5zJtvvmnq6uqi2yeffBIdw7w5z6xZs8zWrVvN4cOHzd69e83s2bNNUlKS2bhxozGGOYsXf34VjzHM24XEXUExxpjnnnvO9O3b16SkpJhbbrklemkk7NiyZYuRdM42ZcoUY8wfLqP75je/aQKBgPF6vWb48OGmpqYm5jmam5vN9OnTTVZWlklNTTXjxo0zH3zwgYVX0zWcb74kmaVLl0bHMG/O8zd/8zfRf/t69eplRo8eHS0nxjBn8eLsgsK8nZ/LGGPsrN0AAACcX1ydgwIAALoGCgoAAHAcCgoAAHAcCgoAAHAcCgoAAHAcCgoAAHAcCgoAAHAcCgoAAHAcCgoAAHAcCgoAAHAcCgoAAHCc/wFTY73IppXWjAAAAABJRU5ErkJggg==", 255 | "text/plain": [ 256 | "
" 257 | ] 258 | }, 259 | "metadata": {}, 260 | "output_type": "display_data" 261 | } 262 | ], 263 | "source": [ 264 | "plt.imshow(mapped_array[:,:,0])" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 22, 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "data": { 274 | "text/plain": [ 275 | "array([2, 4], dtype=uint8)" 276 | ] 277 | }, 278 | "execution_count": 22, 279 | "metadata": {}, 280 | "output_type": "execute_result" 281 | } 282 | ], 283 | "source": [ 284 | "np.unique(img)" 285 | ] 286 | } 287 | ], 288 | "metadata": { 289 | "kernelspec": { 290 | "display_name": "base", 291 | "language": "python", 292 | "name": "python3" 293 | }, 294 | "language_info": { 295 | "codemirror_mode": { 296 | "name": "ipython", 297 | "version": 3 298 | }, 299 | "file_extension": ".py", 300 | "mimetype": "text/x-python", 301 | "name": "python", 302 | "nbconvert_exporter": "python", 303 | "pygments_lexer": "ipython3", 304 | "version": "3.9.13" 305 | }, 306 | "orig_nbformat": 4, 307 | "vscode": { 308 | "interpreter": { 309 | "hash": "c3ba92e74a4191627d049a556b86b8de7025ad6ea6b9bd6e1b0735ce4be03fa6" 310 | } 311 | } 312 | }, 313 | "nbformat": 4, 314 | "nbformat_minor": 2 315 | } 316 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # to be added 2 | -------------------------------------------------------------------------------- /sim2real_asset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/sim2real_with_Stable_Diffusion/d975733c12e5d9dfaebdf7672b88b00ff10eb229/sim2real_asset.png -------------------------------------------------------------------------------- /training/README.md: -------------------------------------------------------------------------------- 1 | To run Dreambooth training follow `train_dreambooth_commands.ipynb` notebook. -------------------------------------------------------------------------------- /training/convert_diffusers_to_sd.py: -------------------------------------------------------------------------------- 1 | # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. 2 | # *Only* converts the UNet, VAE, and Text Encoder. 3 | # Does not convert optimizer state or any other thing. 
4 | 5 | import argparse 6 | import os.path as osp 7 | import re 8 | 9 | import torch 10 | from safetensors.torch import load_file, save_file 11 | 12 | 13 | # =================# 14 | # UNet Conversion # 15 | # =================# 16 | 17 | unet_conversion_map = [ 18 | # (stable-diffusion, HF Diffusers) 19 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), 20 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), 21 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), 22 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), 23 | ("input_blocks.0.0.weight", "conv_in.weight"), 24 | ("input_blocks.0.0.bias", "conv_in.bias"), 25 | ("out.0.weight", "conv_norm_out.weight"), 26 | ("out.0.bias", "conv_norm_out.bias"), 27 | ("out.2.weight", "conv_out.weight"), 28 | ("out.2.bias", "conv_out.bias"), 29 | ] 30 | 31 | unet_conversion_map_resnet = [ 32 | # (stable-diffusion, HF Diffusers) 33 | ("in_layers.0", "norm1"), 34 | ("in_layers.2", "conv1"), 35 | ("out_layers.0", "norm2"), 36 | ("out_layers.3", "conv2"), 37 | ("emb_layers.1", "time_emb_proj"), 38 | ("skip_connection", "conv_shortcut"), 39 | ] 40 | 41 | unet_conversion_map_layer = [] 42 | # hardcoded number of downblocks and resnets/attentions... 43 | # would need smarter logic for other networks. 44 | for i in range(4): 45 | # loop over downblocks/upblocks 46 | 47 | for j in range(2): 48 | # loop over resnets/attentions for downblocks 49 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." 50 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 51 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 52 | 53 | if i < 3: 54 | # no attention layers in down_blocks.3 55 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." 56 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." 57 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 58 | 59 | for j in range(3): 60 | # loop over resnets/attentions for upblocks 61 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 62 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 63 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 64 | 65 | if i > 0: 66 | # no attention layers in up_blocks.0 67 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." 68 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 69 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 70 | 71 | if i < 3: 72 | # no downsample in down_blocks.3 73 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 74 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 75 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) 76 | 77 | # no upsample in up_blocks.3 78 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 79 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." 80 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) 81 | 82 | hf_mid_atn_prefix = "mid_block.attentions.0." 83 | sd_mid_atn_prefix = "middle_block.1." 84 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 85 | 86 | for j in range(2): 87 | hf_mid_res_prefix = f"mid_block.resnets.{j}." 88 | sd_mid_res_prefix = f"middle_block.{2*j}." 89 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 90 | 91 | 92 | def convert_unet_state_dict(unet_state_dict): 93 | # buyer beware: this is a *brittle* function, 94 | # and correct output requires that all of these pieces interact in 95 | # the exact order in which I have arranged them. 
96 | mapping = {k: k for k in unet_state_dict.keys()} 97 | for sd_name, hf_name in unet_conversion_map: 98 | mapping[hf_name] = sd_name 99 | for k, v in mapping.items(): 100 | if "resnets" in k: 101 | for sd_part, hf_part in unet_conversion_map_resnet: 102 | v = v.replace(hf_part, sd_part) 103 | mapping[k] = v 104 | for k, v in mapping.items(): 105 | for sd_part, hf_part in unet_conversion_map_layer: 106 | v = v.replace(hf_part, sd_part) 107 | mapping[k] = v 108 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} 109 | return new_state_dict 110 | 111 | 112 | # ================# 113 | # VAE Conversion # 114 | # ================# 115 | 116 | vae_conversion_map = [ 117 | # (stable-diffusion, HF Diffusers) 118 | ("nin_shortcut", "conv_shortcut"), 119 | ("norm_out", "conv_norm_out"), 120 | ("mid.attn_1.", "mid_block.attentions.0."), 121 | ] 122 | 123 | for i in range(4): 124 | # down_blocks have two resnets 125 | for j in range(2): 126 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." 127 | sd_down_prefix = f"encoder.down.{i}.block.{j}." 128 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) 129 | 130 | if i < 3: 131 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." 132 | sd_downsample_prefix = f"down.{i}.downsample." 133 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) 134 | 135 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 136 | sd_upsample_prefix = f"up.{3-i}.upsample." 137 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) 138 | 139 | # up_blocks have three resnets 140 | # also, up blocks in hf are numbered in reverse from sd 141 | for j in range(3): 142 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." 143 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." 144 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) 145 | 146 | # this part accounts for mid blocks in both the encoder and the decoder 147 | for i in range(2): 148 | hf_mid_res_prefix = f"mid_block.resnets.{i}." 149 | sd_mid_res_prefix = f"mid.block_{i+1}." 
150 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) 151 | 152 | 153 | vae_conversion_map_attn = [ 154 | # (stable-diffusion, HF Diffusers) 155 | ("norm.", "group_norm."), 156 | ("q.", "query."), 157 | ("k.", "key."), 158 | ("v.", "value."), 159 | ("proj_out.", "proj_attn."), 160 | ] 161 | 162 | 163 | def reshape_weight_for_sd(w): 164 | # convert HF linear weights to SD conv2d weights 165 | return w.reshape(*w.shape, 1, 1) 166 | 167 | 168 | def convert_vae_state_dict(vae_state_dict): 169 | mapping = {k: k for k in vae_state_dict.keys()} 170 | for k, v in mapping.items(): 171 | for sd_part, hf_part in vae_conversion_map: 172 | v = v.replace(hf_part, sd_part) 173 | mapping[k] = v 174 | for k, v in mapping.items(): 175 | if "attentions" in k: 176 | for sd_part, hf_part in vae_conversion_map_attn: 177 | v = v.replace(hf_part, sd_part) 178 | mapping[k] = v 179 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} 180 | weights_to_convert = ["q", "k", "v", "proj_out"] 181 | for k, v in new_state_dict.items(): 182 | for weight_name in weights_to_convert: 183 | if f"mid.attn_1.{weight_name}.weight" in k: 184 | print(f"Reshaping {k} for SD format") 185 | new_state_dict[k] = reshape_weight_for_sd(v) 186 | return new_state_dict 187 | 188 | 189 | # =========================# 190 | # Text Encoder Conversion # 191 | # =========================# 192 | 193 | 194 | textenc_conversion_lst = [ 195 | # (stable-diffusion, HF Diffusers) 196 | ("resblocks.", "text_model.encoder.layers."), 197 | ("ln_1", "layer_norm1"), 198 | ("ln_2", "layer_norm2"), 199 | (".c_fc.", ".fc1."), 200 | (".c_proj.", ".fc2."), 201 | (".attn", ".self_attn"), 202 | ("ln_final.", "transformer.text_model.final_layer_norm."), 203 | ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), 204 | ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), 205 | ] 206 | protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst} 207 | textenc_pattern = re.compile("|".join(protected.keys())) 208 | 209 | # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp 210 | code2idx = {"q": 0, "k": 1, "v": 2} 211 | 212 | 213 | def convert_text_enc_state_dict_v20(text_enc_dict): 214 | new_state_dict = {} 215 | capture_qkv_weight = {} 216 | capture_qkv_bias = {} 217 | for k, v in text_enc_dict.items(): 218 | if ( 219 | k.endswith(".self_attn.q_proj.weight") 220 | or k.endswith(".self_attn.k_proj.weight") 221 | or k.endswith(".self_attn.v_proj.weight") 222 | ): 223 | k_pre = k[: -len(".q_proj.weight")] 224 | k_code = k[-len("q_proj.weight")] 225 | if k_pre not in capture_qkv_weight: 226 | capture_qkv_weight[k_pre] = [None, None, None] 227 | capture_qkv_weight[k_pre][code2idx[k_code]] = v 228 | continue 229 | 230 | if ( 231 | k.endswith(".self_attn.q_proj.bias") 232 | or k.endswith(".self_attn.k_proj.bias") 233 | or k.endswith(".self_attn.v_proj.bias") 234 | ): 235 | k_pre = k[: -len(".q_proj.bias")] 236 | k_code = k[-len("q_proj.bias")] 237 | if k_pre not in capture_qkv_bias: 238 | capture_qkv_bias[k_pre] = [None, None, None] 239 | capture_qkv_bias[k_pre][code2idx[k_code]] = v 240 | continue 241 | 242 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) 243 | new_state_dict[relabelled_key] = v 244 | 245 | for k_pre, tensors in capture_qkv_weight.items(): 246 | if None in tensors: 247 | raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") 248 | 
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) 249 | new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) 250 | 251 | for k_pre, tensors in capture_qkv_bias.items(): 252 | if None in tensors: 253 | raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") 254 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) 255 | new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) 256 | 257 | return new_state_dict 258 | 259 | 260 | def convert_text_enc_state_dict(text_enc_dict): 261 | return text_enc_dict 262 | 263 | 264 | if __name__ == "__main__": 265 | parser = argparse.ArgumentParser() 266 | 267 | parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") 268 | parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") 269 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.") 270 | parser.add_argument( 271 | "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt." 272 | ) 273 | 274 | args = parser.parse_args() 275 | 276 | assert args.model_path is not None, "Must provide a model path!" 277 | 278 | assert args.checkpoint_path is not None, "Must provide a checkpoint path!" 279 | 280 | # Path for safetensors 281 | unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors") 282 | vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors") 283 | text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors") 284 | 285 | # Load models from safetensors if it exists, if it doesn't pytorch 286 | if osp.exists(unet_path): 287 | unet_state_dict = load_file(unet_path, device="cpu") 288 | else: 289 | unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") 290 | unet_state_dict = torch.load(unet_path, map_location="cpu") 291 | 292 | if osp.exists(vae_path): 293 | vae_state_dict = load_file(vae_path, device="cpu") 294 | else: 295 | vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") 296 | vae_state_dict = torch.load(vae_path, map_location="cpu") 297 | 298 | if osp.exists(text_enc_path): 299 | text_enc_dict = load_file(text_enc_path, device="cpu") 300 | else: 301 | text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") 302 | text_enc_dict = torch.load(text_enc_path, map_location="cpu") 303 | 304 | # Convert the UNet model 305 | unet_state_dict = convert_unet_state_dict(unet_state_dict) 306 | unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} 307 | 308 | # Convert the VAE model 309 | vae_state_dict = convert_vae_state_dict(vae_state_dict) 310 | vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} 311 | 312 | # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper 313 | is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict 314 | 315 | if is_v20_model: 316 | # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm 317 | text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} 318 | text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict) 319 | text_enc_dict = {"cond_stage_model.model." 
+ k: v for k, v in text_enc_dict.items()} 320 | else: 321 | text_enc_dict = convert_text_enc_state_dict(text_enc_dict) 322 | text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} 323 | 324 | # Put together new checkpoint 325 | state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} 326 | if args.half: 327 | state_dict = {k: v.half() for k, v in state_dict.items()} 328 | 329 | if args.use_safetensors: 330 | save_file(state_dict, args.checkpoint_path) 331 | else: 332 | state_dict = {"state_dict": state_dict} 333 | torch.save(state_dict, args.checkpoint_path) -------------------------------------------------------------------------------- /training/train_dreambooth.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import itertools 4 | import random 5 | import json 6 | import logging 7 | import math 8 | import os 9 | from contextlib import nullcontext 10 | from pathlib import Path 11 | from typing import Optional 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.utils.checkpoint 16 | from torch.utils.data import Dataset 17 | 18 | from accelerate import Accelerator 19 | from accelerate.logging import get_logger 20 | from accelerate.utils import set_seed 21 | from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel 22 | from diffusers.optimization import get_scheduler 23 | from diffusers.utils.import_utils import is_xformers_available 24 | from huggingface_hub import HfFolder, Repository, whoami 25 | from PIL import Image 26 | from torchvision import transforms 27 | from tqdm.auto import tqdm 28 | from transformers import CLIPTextModel, CLIPTokenizer 29 | 30 | 31 | torch.backends.cudnn.benchmark = True 32 | 33 | 34 | logger = get_logger(__name__) 35 | 36 | 37 | def parse_args(input_args=None): 38 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 39 | parser.add_argument( 40 | "--pretrained_model_name_or_path", 41 | type=str, 42 | default=None, 43 | required=True, 44 | help="Path to pretrained model or model identifier from huggingface.co/models.", 45 | ) 46 | parser.add_argument( 47 | "--pretrained_vae_name_or_path", 48 | type=str, 49 | default=None, 50 | help="Path to pretrained vae or vae identifier from huggingface.co/models.", 51 | ) 52 | parser.add_argument( 53 | "--revision", 54 | type=str, 55 | default=None, 56 | required=False, 57 | help="Revision of pretrained model identifier from huggingface.co/models.", 58 | ) 59 | parser.add_argument( 60 | "--tokenizer_name", 61 | type=str, 62 | default=None, 63 | help="Pretrained tokenizer name or path if not the same as model_name", 64 | ) 65 | parser.add_argument( 66 | "--instance_data_dir", 67 | type=str, 68 | default=None, 69 | help="A folder containing the training data of instance images.", 70 | ) 71 | parser.add_argument( 72 | "--class_data_dir", 73 | type=str, 74 | default=None, 75 | help="A folder containing the training data of class images.", 76 | ) 77 | parser.add_argument( 78 | "--instance_prompt", 79 | type=str, 80 | default=None, 81 | help="The prompt with identifier specifying the instance", 82 | ) 83 | parser.add_argument( 84 | "--class_prompt", 85 | type=str, 86 | default=None, 87 | help="The prompt to specify images in the same class as provided instance images.", 88 | ) 89 | parser.add_argument( 90 | "--save_sample_prompt", 91 | type=str, 92 | default=None, 93 | help="The prompt used 
to generate sample outputs to save.", 94 | ) 95 | parser.add_argument( 96 | "--save_sample_negative_prompt", 97 | type=str, 98 | default=None, 99 | help="The negative prompt used to generate sample outputs to save.", 100 | ) 101 | parser.add_argument( 102 | "--n_save_sample", 103 | type=int, 104 | default=4, 105 | help="The number of samples to save.", 106 | ) 107 | parser.add_argument( 108 | "--save_guidance_scale", 109 | type=float, 110 | default=7.5, 111 | help="CFG for save sample.", 112 | ) 113 | parser.add_argument( 114 | "--save_infer_steps", 115 | type=int, 116 | default=20, 117 | help="The number of inference steps for save sample.", 118 | ) 119 | parser.add_argument( 120 | "--pad_tokens", 121 | default=False, 122 | action="store_true", 123 | help="Flag to pad tokens to length 77.", 124 | ) 125 | parser.add_argument( 126 | "--with_prior_preservation", 127 | default=False, 128 | action="store_true", 129 | help="Flag to add prior preservation loss.", 130 | ) 131 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") 132 | parser.add_argument( 133 | "--num_class_images", 134 | type=int, 135 | default=100, 136 | help=( 137 | "Minimal class images for prior preservation loss. If not have enough images, additional images will be" 138 | " sampled with class_prompt." 139 | ), 140 | ) 141 | parser.add_argument( 142 | "--output_dir", 143 | type=str, 144 | default="text-inversion-model", 145 | help="The output directory where the model predictions and checkpoints will be written.", 146 | ) 147 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 148 | parser.add_argument( 149 | "--resolution", 150 | type=int, 151 | default=512, 152 | help=( 153 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 154 | " resolution" 155 | ), 156 | ) 157 | parser.add_argument( 158 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" 159 | ) 160 | parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") 161 | parser.add_argument( 162 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 163 | ) 164 | parser.add_argument( 165 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 166 | ) 167 | parser.add_argument("--num_train_epochs", type=int, default=1) 168 | parser.add_argument( 169 | "--max_train_steps", 170 | type=int, 171 | default=None, 172 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 173 | ) 174 | parser.add_argument( 175 | "--gradient_accumulation_steps", 176 | type=int, 177 | default=1, 178 | help="Number of updates steps to accumulate before performing a backward/update pass.", 179 | ) 180 | parser.add_argument( 181 | "--gradient_checkpointing", 182 | action="store_true", 183 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 184 | ) 185 | parser.add_argument( 186 | "--learning_rate", 187 | type=float, 188 | default=5e-6, 189 | help="Initial learning rate (after the potential warmup period) to use.", 190 | ) 191 | parser.add_argument( 192 | "--scale_lr", 193 | action="store_true", 194 | default=False, 195 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 196 | ) 197 | parser.add_argument( 198 | "--lr_scheduler", 199 | type=str, 200 | default="constant", 201 | help=( 202 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 203 | ' "constant", "constant_with_warmup"]' 204 | ), 205 | ) 206 | parser.add_argument( 207 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 208 | ) 209 | parser.add_argument( 210 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 211 | ) 212 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 213 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 214 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 215 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 216 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 217 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 218 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 219 | parser.add_argument( 220 | "--hub_model_id", 221 | type=str, 222 | default=None, 223 | help="The name of the repository to keep in sync with the local `output_dir`.", 224 | ) 225 | parser.add_argument( 226 | "--logging_dir", 227 | type=str, 228 | default="logs", 229 | help=( 230 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 231 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 232 | ), 233 | ) 234 | parser.add_argument("--log_interval", type=int, default=10, help="Log every N steps.") 235 | parser.add_argument("--save_interval", type=int, default=10_000, help="Save weights every N steps.") 236 | parser.add_argument("--save_min_steps", type=int, default=0, help="Start saving weights after N steps.") 237 | parser.add_argument( 238 | "--mixed_precision", 239 | type=str, 240 | default=None, 241 | choices=["no", "fp16", "bf16"], 242 | help=( 243 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 244 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 245 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 
246 | ), 247 | ) 248 | parser.add_argument("--not_cache_latents", action="store_true", help="Do not precompute and cache latents from VAE.") 249 | parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.") 250 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 251 | parser.add_argument( 252 | "--concepts_list", 253 | type=str, 254 | default=None, 255 | help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.", 256 | ) 257 | parser.add_argument( 258 | "--read_prompts_from_txts", 259 | action="store_true", 260 | help="Use prompt per image. Put prompts in the same directory as images, e.g. for image.png create image.png.txt.", 261 | ) 262 | 263 | if input_args is not None: 264 | args = parser.parse_args(input_args) 265 | else: 266 | args = parser.parse_args() 267 | 268 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 269 | if env_local_rank != -1 and env_local_rank != args.local_rank: 270 | args.local_rank = env_local_rank 271 | 272 | return args 273 | 274 | 275 | class DreamBoothDataset(Dataset): 276 | """ 277 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 278 | It pre-processes the images and the tokenizes prompts. 279 | """ 280 | 281 | def __init__( 282 | self, 283 | concepts_list, 284 | tokenizer, 285 | with_prior_preservation=True, 286 | size=512, 287 | center_crop=False, 288 | num_class_images=None, 289 | pad_tokens=False, 290 | hflip=False, 291 | read_prompts_from_txts=False, 292 | ): 293 | self.size = size 294 | self.center_crop = center_crop 295 | self.tokenizer = tokenizer 296 | self.with_prior_preservation = with_prior_preservation 297 | self.pad_tokens = pad_tokens 298 | self.read_prompts_from_txts = read_prompts_from_txts 299 | 300 | self.instance_images_path = [] 301 | self.class_images_path = [] 302 | 303 | for concept in concepts_list: 304 | inst_img_path = [ 305 | (x, concept["instance_prompt"]) 306 | for x in Path(concept["instance_data_dir"]).iterdir() 307 | if x.is_file() and not str(x).endswith(".txt") 308 | ] 309 | self.instance_images_path.extend(inst_img_path) 310 | 311 | if with_prior_preservation: 312 | class_img_path = [(x, concept["class_prompt"]) for x in Path(concept["class_data_dir"]).iterdir() if x.is_file()] 313 | self.class_images_path.extend(class_img_path[:num_class_images]) 314 | 315 | random.shuffle(self.instance_images_path) 316 | self.num_instance_images = len(self.instance_images_path) 317 | self.num_class_images = len(self.class_images_path) 318 | self._length = max(self.num_class_images, self.num_instance_images) 319 | 320 | self.image_transforms = transforms.Compose( 321 | [ 322 | transforms.RandomHorizontalFlip(0.5 * hflip), 323 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), 324 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 325 | transforms.ToTensor(), 326 | transforms.Normalize([0.5], [0.5]), 327 | ] 328 | ) 329 | 330 | def __len__(self): 331 | return self._length 332 | 333 | def __getitem__(self, index): 334 | example = {} 335 | instance_path, instance_prompt = self.instance_images_path[index % self.num_instance_images] 336 | 337 | if self.read_prompts_from_txts: 338 | with open(str(instance_path) + ".txt") as f: 339 | instance_prompt = f.read().strip() 340 | 341 | instance_image = Image.open(instance_path) 342 | if not instance_image.mode == "RGB": 343 | instance_image = 
instance_image.convert("RGB") 344 | 345 | example["instance_images"] = self.image_transforms(instance_image) 346 | example["instance_prompt_ids"] = self.tokenizer( 347 | instance_prompt, 348 | padding="max_length" if self.pad_tokens else "do_not_pad", 349 | truncation=True, 350 | max_length=self.tokenizer.model_max_length, 351 | ).input_ids 352 | 353 | if self.with_prior_preservation: 354 | class_path, class_prompt = self.class_images_path[index % self.num_class_images] 355 | class_image = Image.open(class_path) 356 | if not class_image.mode == "RGB": 357 | class_image = class_image.convert("RGB") 358 | example["class_images"] = self.image_transforms(class_image) 359 | example["class_prompt_ids"] = self.tokenizer( 360 | class_prompt, 361 | padding="max_length" if self.pad_tokens else "do_not_pad", 362 | truncation=True, 363 | max_length=self.tokenizer.model_max_length, 364 | ).input_ids 365 | 366 | return example 367 | 368 | 369 | class PromptDataset(Dataset): 370 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 371 | 372 | def __init__(self, prompt, num_samples): 373 | self.prompt = prompt 374 | self.num_samples = num_samples 375 | 376 | def __len__(self): 377 | return self.num_samples 378 | 379 | def __getitem__(self, index): 380 | example = {} 381 | example["prompt"] = self.prompt 382 | example["index"] = index 383 | return example 384 | 385 | 386 | class LatentsDataset(Dataset): 387 | def __init__(self, latents_cache, text_encoder_cache): 388 | self.latents_cache = latents_cache 389 | self.text_encoder_cache = text_encoder_cache 390 | 391 | def __len__(self): 392 | return len(self.latents_cache) 393 | 394 | def __getitem__(self, index): 395 | return self.latents_cache[index], self.text_encoder_cache[index] 396 | 397 | 398 | class AverageMeter: 399 | def __init__(self, name=None): 400 | self.name = name 401 | self.reset() 402 | 403 | def reset(self): 404 | self.sum = self.count = self.avg = 0 405 | 406 | def update(self, val, n=1): 407 | self.sum += val * n 408 | self.count += n 409 | self.avg = self.sum / self.count 410 | 411 | 412 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 413 | if token is None: 414 | token = HfFolder.get_token() 415 | if organization is None: 416 | username = whoami(token)["name"] 417 | return f"{username}/{model_id}" 418 | else: 419 | return f"{organization}/{model_id}" 420 | 421 | 422 | def main(args): 423 | logging_dir = Path(args.output_dir, "0", args.logging_dir) 424 | 425 | accelerator = Accelerator( 426 | gradient_accumulation_steps=args.gradient_accumulation_steps, 427 | mixed_precision=args.mixed_precision, 428 | log_with="tensorboard", 429 | logging_dir=logging_dir, 430 | ) 431 | 432 | logging.basicConfig( 433 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 434 | datefmt="%m/%d/%Y %H:%M:%S", 435 | level=logging.INFO, 436 | ) 437 | 438 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 439 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 440 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 441 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 442 | raise ValueError( 443 | "Gradient accumulation is not supported when training the text encoder in distributed training. 
" 444 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 445 | ) 446 | 447 | if args.seed is not None: 448 | set_seed(args.seed) 449 | 450 | if args.concepts_list is None: 451 | args.concepts_list = [ 452 | { 453 | "instance_prompt": args.instance_prompt, 454 | "class_prompt": args.class_prompt, 455 | "instance_data_dir": args.instance_data_dir, 456 | "class_data_dir": args.class_data_dir 457 | } 458 | ] 459 | else: 460 | with open(args.concepts_list, "r") as f: 461 | args.concepts_list = json.load(f) 462 | 463 | if args.with_prior_preservation: 464 | pipeline = None 465 | for concept in args.concepts_list: 466 | class_images_dir = Path(concept["class_data_dir"]) 467 | class_images_dir.mkdir(parents=True, exist_ok=True) 468 | cur_class_images = len(list(class_images_dir.iterdir())) 469 | 470 | if cur_class_images < args.num_class_images: 471 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 472 | if pipeline is None: 473 | pipeline = StableDiffusionPipeline.from_pretrained( 474 | args.pretrained_model_name_or_path, 475 | vae=AutoencoderKL.from_pretrained( 476 | args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path, 477 | subfolder=None if args.pretrained_vae_name_or_path else "vae", 478 | revision=None if args.pretrained_vae_name_or_path else args.revision, 479 | torch_dtype=torch_dtype 480 | ), 481 | torch_dtype=torch_dtype, 482 | safety_checker=None, 483 | revision=args.revision 484 | ) 485 | pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) 486 | if is_xformers_available(): 487 | pipeline.enable_xformers_memory_efficient_attention() 488 | pipeline.set_progress_bar_config(disable=True) 489 | pipeline.to(accelerator.device) 490 | 491 | num_new_images = args.num_class_images - cur_class_images 492 | logger.info(f"Number of class images to sample: {num_new_images}.") 493 | 494 | sample_dataset = PromptDataset(concept["class_prompt"], num_new_images) 495 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) 496 | 497 | sample_dataloader = accelerator.prepare(sample_dataloader) 498 | 499 | with torch.autocast("cuda"), torch.inference_mode(): 500 | for example in tqdm( 501 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process 502 | ): 503 | images = pipeline( 504 | example["prompt"], 505 | num_inference_steps=args.save_infer_steps 506 | ).images 507 | 508 | for i, image in enumerate(images): 509 | hash_image = hashlib.sha1(image.tobytes()).hexdigest() 510 | image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" 511 | image.save(image_filename) 512 | 513 | del pipeline 514 | if torch.cuda.is_available(): 515 | torch.cuda.empty_cache() 516 | 517 | # Load the tokenizer 518 | if args.tokenizer_name: 519 | tokenizer = CLIPTokenizer.from_pretrained( 520 | args.tokenizer_name, 521 | revision=args.revision, 522 | ) 523 | elif args.pretrained_model_name_or_path: 524 | tokenizer = CLIPTokenizer.from_pretrained( 525 | args.pretrained_model_name_or_path, 526 | subfolder="tokenizer", 527 | revision=args.revision, 528 | ) 529 | 530 | # Load models and create wrapper for stable diffusion 531 | text_encoder = CLIPTextModel.from_pretrained( 532 | args.pretrained_model_name_or_path, 533 | subfolder="text_encoder", 534 | revision=args.revision, 535 | ) 536 | vae = AutoencoderKL.from_pretrained( 537 | args.pretrained_model_name_or_path, 538 | subfolder="vae", 539 
| revision=args.revision, 540 | ) 541 | unet = UNet2DConditionModel.from_pretrained( 542 | args.pretrained_model_name_or_path, 543 | subfolder="unet", 544 | revision=args.revision, 545 | torch_dtype=torch.float32 546 | ) 547 | 548 | vae.requires_grad_(False) 549 | if not args.train_text_encoder: 550 | text_encoder.requires_grad_(False) 551 | 552 | if is_xformers_available(): 553 | vae.enable_xformers_memory_efficient_attention() 554 | unet.enable_xformers_memory_efficient_attention() 555 | else: 556 | logger.warning("xformers is not available. Make sure it is installed correctly") 557 | 558 | if args.gradient_checkpointing: 559 | unet.enable_gradient_checkpointing() 560 | if args.train_text_encoder: 561 | text_encoder.gradient_checkpointing_enable() 562 | 563 | if args.scale_lr: 564 | args.learning_rate = ( 565 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 566 | ) 567 | 568 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 569 | if args.use_8bit_adam: 570 | try: 571 | import bitsandbytes as bnb 572 | except ImportError: 573 | raise ImportError( 574 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 575 | ) 576 | 577 | optimizer_class = bnb.optim.AdamW8bit 578 | else: 579 | optimizer_class = torch.optim.AdamW 580 | 581 | params_to_optimize = ( 582 | itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() 583 | ) 584 | optimizer = optimizer_class( 585 | params_to_optimize, 586 | lr=args.learning_rate, 587 | betas=(args.adam_beta1, args.adam_beta2), 588 | weight_decay=args.adam_weight_decay, 589 | eps=args.adam_epsilon, 590 | ) 591 | 592 | noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") 593 | 594 | train_dataset = DreamBoothDataset( 595 | concepts_list=args.concepts_list, 596 | tokenizer=tokenizer, 597 | with_prior_preservation=args.with_prior_preservation, 598 | size=args.resolution, 599 | center_crop=args.center_crop, 600 | num_class_images=args.num_class_images, 601 | pad_tokens=args.pad_tokens, 602 | hflip=args.hflip, 603 | read_prompts_from_txts=args.read_prompts_from_txts, 604 | ) 605 | 606 | def collate_fn(examples): 607 | input_ids = [example["instance_prompt_ids"] for example in examples] 608 | pixel_values = [example["instance_images"] for example in examples] 609 | 610 | # Concat class and instance examples for prior preservation. 611 | # We do this to avoid doing two forward passes. 
612 | if args.with_prior_preservation: 613 | input_ids += [example["class_prompt_ids"] for example in examples] 614 | pixel_values += [example["class_images"] for example in examples] 615 | 616 | pixel_values = torch.stack(pixel_values) 617 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 618 | 619 | input_ids = tokenizer.pad( 620 | {"input_ids": input_ids}, 621 | padding=True, 622 | return_tensors="pt", 623 | ).input_ids 624 | 625 | batch = { 626 | "input_ids": input_ids, 627 | "pixel_values": pixel_values, 628 | } 629 | return batch 630 | 631 | train_dataloader = torch.utils.data.DataLoader( 632 | train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True 633 | ) 634 | 635 | weight_dtype = torch.float32 636 | if args.mixed_precision == "fp16": 637 | weight_dtype = torch.float16 638 | elif args.mixed_precision == "bf16": 639 | weight_dtype = torch.bfloat16 640 | 641 | # Move text_encode and vae to gpu. 642 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 643 | # as these models are only used for inference, keeping weights in full precision is not required. 644 | vae.to(accelerator.device, dtype=weight_dtype) 645 | if not args.train_text_encoder: 646 | text_encoder.to(accelerator.device, dtype=weight_dtype) 647 | 648 | if not args.not_cache_latents: 649 | latents_cache = [] 650 | text_encoder_cache = [] 651 | for batch in tqdm(train_dataloader, desc="Caching latents"): 652 | with torch.no_grad(): 653 | batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype) 654 | batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True) 655 | latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) 656 | if args.train_text_encoder: 657 | text_encoder_cache.append(batch["input_ids"]) 658 | else: 659 | text_encoder_cache.append(text_encoder(batch["input_ids"])[0]) 660 | train_dataset = LatentsDataset(latents_cache, text_encoder_cache) 661 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True) 662 | 663 | del vae 664 | if not args.train_text_encoder: 665 | del text_encoder 666 | if torch.cuda.is_available(): 667 | torch.cuda.empty_cache() 668 | 669 | # Scheduler and math around the number of training steps. 670 | overrode_max_train_steps = False 671 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 672 | if args.max_train_steps is None: 673 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 674 | overrode_max_train_steps = True 675 | 676 | lr_scheduler = get_scheduler( 677 | args.lr_scheduler, 678 | optimizer=optimizer, 679 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 680 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 681 | ) 682 | 683 | if args.train_text_encoder: 684 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 685 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler 686 | ) 687 | else: 688 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 689 | unet, optimizer, train_dataloader, lr_scheduler 690 | ) 691 | 692 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
693 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 694 | if overrode_max_train_steps: 695 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 696 | # Afterwards we recalculate our number of training epochs 697 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 698 | 699 | # We need to initialize the trackers we use, and also store our configuration. 700 | # The trackers initializes automatically on the main process. 701 | if accelerator.is_main_process: 702 | accelerator.init_trackers("dreambooth") 703 | 704 | # Train! 705 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 706 | 707 | logger.info("***** Running training *****") 708 | logger.info(f" Num examples = {len(train_dataset)}") 709 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 710 | logger.info(f" Num Epochs = {args.num_train_epochs}") 711 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 712 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 713 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 714 | logger.info(f" Total optimization steps = {args.max_train_steps}") 715 | 716 | def save_weights(step): 717 | # Create the pipeline using using the trained modules and save it. 718 | if accelerator.is_main_process: 719 | if args.train_text_encoder: 720 | text_enc_model = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) 721 | else: 722 | text_enc_model = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision) 723 | pipeline = StableDiffusionPipeline.from_pretrained( 724 | args.pretrained_model_name_or_path, 725 | unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True), 726 | text_encoder=text_enc_model, 727 | vae=AutoencoderKL.from_pretrained( 728 | args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path, 729 | subfolder=None if args.pretrained_vae_name_or_path else "vae", 730 | revision=None if args.pretrained_vae_name_or_path else args.revision, 731 | ), 732 | safety_checker=None, 733 | torch_dtype=torch.float16, 734 | revision=args.revision, 735 | ) 736 | pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) 737 | if is_xformers_available(): 738 | pipeline.enable_xformers_memory_efficient_attention() 739 | save_dir = os.path.join(args.output_dir, f"{step}") 740 | pipeline.save_pretrained(save_dir) 741 | with open(os.path.join(save_dir, "args.json"), "w") as f: 742 | json.dump(args.__dict__, f, indent=2) 743 | 744 | if args.save_sample_prompt is not None: 745 | pipeline = pipeline.to(accelerator.device) 746 | g_cuda = torch.Generator(device=accelerator.device).manual_seed(args.seed) 747 | pipeline.set_progress_bar_config(disable=True) 748 | sample_dir = os.path.join(save_dir, "samples") 749 | os.makedirs(sample_dir, exist_ok=True) 750 | with torch.autocast("cuda"), torch.inference_mode(): 751 | for i in tqdm(range(args.n_save_sample), desc="Generating samples"): 752 | images = pipeline( 753 | args.save_sample_prompt, 754 | negative_prompt=args.save_sample_negative_prompt, 755 | guidance_scale=args.save_guidance_scale, 756 | num_inference_steps=args.save_infer_steps, 757 | generator=g_cuda 758 | ).images 759 | images[0].save(os.path.join(sample_dir, f"{i}.png")) 760 | del pipeline 761 | if torch.cuda.is_available(): 762 | 
torch.cuda.empty_cache() 763 | print(f"[*] Weights saved at {save_dir}") 764 | 765 | # Only show the progress bar once on each machine. 766 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 767 | progress_bar.set_description("Steps") 768 | global_step = 0 769 | loss_avg = AverageMeter() 770 | text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad() 771 | for epoch in range(args.num_train_epochs): 772 | unet.train() 773 | if args.train_text_encoder: 774 | text_encoder.train() 775 | for step, batch in enumerate(train_dataloader): 776 | with accelerator.accumulate(unet): 777 | # Convert images to latent space 778 | with torch.no_grad(): 779 | if not args.not_cache_latents: 780 | latent_dist = batch[0][0] 781 | else: 782 | latent_dist = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist 783 | latents = latent_dist.sample() * 0.18215 784 | 785 | # Sample noise that we'll add to the latents 786 | noise = torch.randn_like(latents) 787 | bsz = latents.shape[0] 788 | # Sample a random timestep for each image 789 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 790 | timesteps = timesteps.long() 791 | 792 | # Add noise to the latents according to the noise magnitude at each timestep 793 | # (this is the forward diffusion process) 794 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 795 | 796 | # Get the text embedding for conditioning 797 | with text_enc_context: 798 | if not args.not_cache_latents: 799 | if args.train_text_encoder: 800 | encoder_hidden_states = text_encoder(batch[0][1])[0] 801 | else: 802 | encoder_hidden_states = batch[0][1] 803 | else: 804 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 805 | 806 | # Predict the noise residual 807 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 808 | 809 | # Get the target for loss depending on the prediction type 810 | if noise_scheduler.config.prediction_type == "epsilon": 811 | target = noise 812 | elif noise_scheduler.config.prediction_type == "v_prediction": 813 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 814 | else: 815 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 816 | 817 | if args.with_prior_preservation: 818 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 819 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 820 | target, target_prior = torch.chunk(target, 2, dim=0) 821 | 822 | # Compute instance loss 823 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 824 | 825 | # Compute prior loss 826 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") 827 | 828 | # Add the prior loss to the instance loss. 
829 | loss = loss + args.prior_loss_weight * prior_loss 830 | else: 831 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 832 | 833 | accelerator.backward(loss) 834 | # if accelerator.sync_gradients: 835 | # params_to_clip = ( 836 | # itertools.chain(unet.parameters(), text_encoder.parameters()) 837 | # if args.train_text_encoder 838 | # else unet.parameters() 839 | # ) 840 | # accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 841 | optimizer.step() 842 | lr_scheduler.step() 843 | optimizer.zero_grad(set_to_none=True) 844 | loss_avg.update(loss.detach_(), bsz) 845 | 846 | if not global_step % args.log_interval: 847 | logs = {"loss": loss_avg.avg.item(), "lr": lr_scheduler.get_last_lr()[0]} 848 | progress_bar.set_postfix(**logs) 849 | accelerator.log(logs, step=global_step) 850 | 851 | if global_step > 0 and not global_step % args.save_interval and global_step >= args.save_min_steps: 852 | save_weights(global_step) 853 | 854 | progress_bar.update(1) 855 | global_step += 1 856 | 857 | if global_step >= args.max_train_steps: 858 | break 859 | 860 | accelerator.wait_for_everyone() 861 | 862 | save_weights(global_step) 863 | 864 | accelerator.end_training() 865 | 866 | 867 | if __name__ == "__main__": 868 | args = parse_args() 869 | main(args) -------------------------------------------------------------------------------- /training/train_dreambooth_commands.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Dreambooth training \n", 9 | "### We recommend [official Diffusers tutorial](https://huggingface.co/docs/diffusers/training/dreambooth) which helps to get familiar with this type of fine-tuning, explains the environment setup and used parameters in detail. \n", 10 | "### Our fine-tuned models and training data are available on OneDrive." 11 | ] 12 | }, 13 | { 14 | "attachments": {}, 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "## Crop and resize selected images to 512x512." 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "from PIL import Image\n", 28 | "import os\n", 29 | "from torchvision.transforms import functional as F\n", 30 | "from torchvision.transforms import InterpolationMode\n", 31 | "import random\n", 32 | "import json" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "folder = \"/path/to/training/images\"\n", 42 | "folder_cropped = \"/path/to/training/images_croped\"\n", 43 | "os.makedirs(folder_cropped, exist_ok=True)\n", 44 | "\n", 45 | "for filename in os.listdir(folder):\n", 46 | " if filename.endswith('.jpeg') or filename.endswith('.png') or filename.endswith('.jpg'):\n", 47 | " image = Image.open(os.path.join(folder, filename)).convert(\"RGB\")\n", 48 | " image = F.center_crop(image, 480) # this size (eg. 
480) depends on the CHOLECT45 video\n", 49 |     "        image = F.resize(image, 512, interpolation = InterpolationMode.BILINEAR)\n", 50 |     "        image.save(os.path.join(folder_cropped, filename))" 51 |    ] 52 |   }, 53 |   { 54 |    "attachments": {}, 55 |    "cell_type": "markdown", 56 |    "metadata": {}, 57 |    "source": [ 58 |     "## Creating the JSON concept list" 59 |    ] 60 |   }, 61 |   { 62 |    "cell_type": "code", 63 |    "execution_count": 3, 64 |    "metadata": {}, 65 |    "outputs": [], 66 |    "source": [ 67 |     "concepts_list = [\n", 68 |     "    {\n", 69 |     "        \"instance_prompt\": \"cholect45\",\n", 70 |     "        \"class_prompt\": \"\",\n", 71 |     "        \"instance_data_dir\": f\"{folder_cropped}\"\n", 72 |     "    },\n", 73 |     "]\n", 74 |     "\n", 75 |     "with open(\"./concepts_list.json\", \"w\") as f:\n", 76 |     "    json.dump(concepts_list, f, indent=4)" 77 |    ] 78 |   }, 79 |   { 80 |    "attachments": {}, 81 |    "cell_type": "markdown", 82 |    "metadata": {}, 83 |    "source": [ 84 |     "## DREAMBOOTH TRAINING\n", 85 |     "### Run the following commands in a terminal. Remember to activate the `diffusers_venv` virtual environment first!" 86 |    ] 87 |   }, 88 |   { 89 |    "cell_type": "code", 90 |    "execution_count": null, 91 |    "metadata": {}, 92 |    "outputs": [], 93 |    "source": [ 94 |     "'''\n", 95 |     "cd training\n", 96 |     "source diffusers_venv/bin/activate\n", 97 |     "export MODEL_NAME=\"runwayml/stable-diffusion-v1-5\"\n", 98 |     "export OUTPUT_DIR=\"/path/to/save/checkpoints\"\n", 99 |     "'''" 100 |    ] 101 |   }, 102 |   { 103 |    "attachments": {}, 104 |    "cell_type": "markdown", 105 |    "metadata": {}, 106 |    "source": [ 107 |     "### Template for the Dreambooth training command" 108 |    ] 109 |   }, 110 |   { 111 |    "cell_type": "code", 112 |    "execution_count": null, 113 |    "metadata": {}, 114 |    "outputs": [], 115 |    "source": [ 116 |     "#template\n", 117 |     "'''\n", 118 |     "!accelerate launch train_dreambooth.py \\\n", 119 |     "  --pretrained_model_name_or_path=$MODEL_NAME \\\n", 120 |     "  --pretrained_vae_name_or_path=\"stabilityai/sd-vae-ft-mse\" \\\n", 121 |     "  --output_dir=$OUTPUT_DIR \\\n", 122 |     "  --revision=\"fp16\" \\\n", 123 |     "  --with_prior_preservation --prior_loss_weight=1.0 \\\n", 124 |     "  --seed=1337 \\\n", 125 |     "  --resolution=512 \\\n", 126 |     "  --train_batch_size=1 \\\n", 127 |     "  --train_text_encoder \\\n", 128 |     "  --mixed_precision=\"fp16\" \\\n", 129 |     "  --use_8bit_adam \\\n", 130 |     "  --gradient_accumulation_steps=1 \\\n", 131 |     "  --learning_rate=1e-6 \\\n", 132 |     "  --lr_scheduler=\"constant\" \\\n", 133 |     "  --lr_warmup_steps=0 \\\n", 134 |     "  --num_class_images=50 \\\n", 135 |     "  --sample_batch_size=4 \\\n", 136 |     "  --max_train_steps=800 \\\n", 137 |     "  --save_interval=10000 \\\n", 138 |     "  --save_sample_prompt=\"photo of zwx dog\" \\\n", 139 |     "  --concepts_list=\"concepts_list.json\"\n", 140 |     "'''" 141 |    ] 142 |   }, 143 |   { 144 |    "cell_type": "markdown", 145 |    "metadata": {}, 146 |    "source": [ 147 |     "### The actual command (with the proper parameters) we used for training all styles - run in a terminal" 148 |    ] 149 |   }, 150 |   { 151 |    "cell_type": "code", 152 |    "execution_count": null, 153 |    "metadata": {}, 154 |    "outputs": [], 155 |    "source": [ 156 |     "'''\n", 157 |     "CUDA_VISIBLE_DEVICES=1 accelerate launch train_dreambooth.py \\\n", 158 |     "--pretrained_model_name_or_path=$MODEL_NAME \\\n", 159 |     "--output_dir=$OUTPUT_DIR \\\n", 160 |     "--concepts_list=\"concepts_list.json\" \\\n", 161 |     "--revision=\"fp16\" \\\n", 162 |     "--train_text_encoder \\\n", 163 |     "--seed=1337 \\\n", 164 |     "--resolution=512 \\\n", 165 |     "--train_batch_size=4 \\\n", 166 |     "--mixed_precision=\"fp16\" \\\n", 167 |     "--gradient_accumulation_steps=1 \\\n", 168 |     "--learning_rate=1e-6 
\\\n", 169 | "--lr_warmup_steps=0 \\\n", 170 | "--num_class_images=50 \\\n", 171 | "--save_interval=500 \\\n", 172 | "--max_train_steps=3000\n", 173 | "'''" 174 | ] 175 | }, 176 | { 177 | "attachments": {}, 178 | "cell_type": "markdown", 179 | "metadata": {}, 180 | "source": [ 181 | "### Quick inference for sanity check using diffusers pipeline" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "# #Quick inference for sanity check\n", 191 | "# from diffusers import StableDiffusionPipeline\n", 192 | "# import torch\n", 193 | "\n", 194 | "# model_id = \"/path/to/save/checkpoints\"\n", 195 | "# pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\n", 196 | "\n", 197 | "# prompt = \"cholect45\"\n", 198 | "# image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]\n", 199 | "\n", 200 | "# image" 201 | ] 202 | }, 203 | { 204 | "attachments": {}, 205 | "cell_type": "markdown", 206 | "metadata": {}, 207 | "source": [ 208 | "### If you want to upload the fine-tuned model to WebUI, convert diffusers format to original SD format and save it to proper WebUI models folder" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 4, 214 | "metadata": {}, 215 | "outputs": [ 216 | { 217 | "name": "stdout", 218 | "output_type": "stream", 219 | "text": [ 220 | "Reshaping encoder.mid.attn_1.q.weight for SD format\n", 221 | "Reshaping encoder.mid.attn_1.k.weight for SD format\n", 222 | "Reshaping encoder.mid.attn_1.v.weight for SD format\n", 223 | "Reshaping encoder.mid.attn_1.proj_out.weight for SD format\n", 224 | "Reshaping decoder.mid.attn_1.q.weight for SD format\n", 225 | "Reshaping decoder.mid.attn_1.k.weight for SD format\n", 226 | "Reshaping decoder.mid.attn_1.v.weight for SD format\n", 227 | "Reshaping decoder.mid.attn_1.proj_out.weight for SD format\n" 228 | ] 229 | } 230 | ], 231 | "source": [ 232 | "!python convert_diffusers_to_sd.py --model_path ./cholect_vid52_56_v2_ckpts/2000 \\\n", 233 | " --checkpoint_path ../stable-diffusion-webui/models/Stable-diffusion/cholect_vid52_56_v2_2000.safetensors --half --use_safetensors\n" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [] 242 | } 243 | ], 244 | "metadata": { 245 | "kernelspec": { 246 | "display_name": "base", 247 | "language": "python", 248 | "name": "python3" 249 | }, 250 | "language_info": { 251 | "codemirror_mode": { 252 | "name": "ipython", 253 | "version": 3 254 | }, 255 | "file_extension": ".py", 256 | "mimetype": "text/x-python", 257 | "name": "python", 258 | "nbconvert_exporter": "python", 259 | "pygments_lexer": "ipython3", 260 | "version": "3.9.16" 261 | }, 262 | "orig_nbformat": 4 263 | }, 264 | "nbformat": 4, 265 | "nbformat_minor": 2 266 | } 267 | --------------------------------------------------------------------------------
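A note on the prior-preservation branch shown in `training/train_dreambooth.py` above: when `--with_prior_preservation` is enabled (as in the template command), the dataloader is expected to concatenate instance and class (prior) examples along the batch dimension, which is why `model_pred` and `target` can simply be split in half with `torch.chunk` before the two MSE terms are combined. The snippet below is a minimal, self-contained sketch of that loss computation only; the random tensors and their shapes are illustrative placeholders, not the repository's actual data.

```python
import torch
import torch.nn.functional as F

# Illustrative batch: 2 instance + 2 prior examples stacked along dim 0, as 4x64x64 SD latents.
model_pred = torch.randn(4, 4, 64, 64)  # UNet prediction for the concatenated batch
target = torch.randn(4, 4, 64, 64)      # epsilon (or v-prediction) target for the same batch
prior_loss_weight = 1.0                 # mirrors --prior_loss_weight=1.0

# Split the concatenated batch back into its instance and prior halves.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)

# The instance loss fits the endoscopic training images; the prior loss
# regularizes the model toward the class prior to limit drift from the base model.
instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
loss = instance_loss + prior_loss_weight * prior_loss
```

The command actually used for training (without `--with_prior_preservation`) exercises only the plain MSE branch of the loop.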
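After running `convert_diffusers_to_sd.py`, a quick optional sanity check before placing the exported `.safetensors` file in the WebUI models folder is to load it and inspect its tensor groups. The sketch below assumes only that the standard `safetensors` package is installed; the checkpoint path mirrors the conversion command above, so replace it with your own output path.

```python
from collections import Counter
from safetensors.torch import load_file

# Path mirrors the conversion command above; replace it with your own checkpoint.
ckpt_path = "../stable-diffusion-webui/models/Stable-diffusion/cholect_vid52_56_v2_2000.safetensors"
state_dict = load_file(ckpt_path)

# Original SD-format checkpoints keep the UNet under "model.diffusion_model.*",
# the VAE under "first_stage_model.*", and the text encoder under "cond_stage_model.*".
groups = Counter(key.split(".")[0] for key in state_dict)
print(f"{len(state_dict)} tensors, top-level groups: {dict(groups)}")
```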