├── DDColor_colab.ipynb ├── DDColor_gradio_colab.ipynb └── README.md /DDColor_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github" 7 | }, 8 | "source": [ 9 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/DDColor-colab/blob/main/DDColor_colab.ipynb)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%cd /content\n", 19 | "!git clone -b dev https://github.com/camenduru/DDColor\n", 20 | "\n", 21 | "!apt -y install -qq aria2\n", 22 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/cv_ddcolor_image-colorization/resolve/main/pytorch_model.pt -d /content/DDColor/models -o pytorch_model.pt\n", 23 | "\n", 24 | "!wget https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/audrey_hepburn.jpg -O /content/DDColor/in.jpg\n", 25 | "!pip install -q timm\n", 26 | "\n", 27 | "%cd /content/DDColor\n", 28 | "\n", 29 | "!sed -i 's/from \\.version import __gitsha__, __version__/# from \\.version import __gitsha__, __version__/' /content/DDColor/basicsr/__init__.py\n", 30 | "\n", 31 | "import argparse\n", 32 | "import cv2\n", 33 | "import numpy as np\n", 34 | "import os\n", 35 | "from tqdm import tqdm\n", 36 | "import torch\n", 37 | "from basicsr.archs.ddcolor_arch import DDColor\n", 38 | "import torch.nn.functional as F\n", 39 | "\n", 40 | "class ImageColorizationPipeline(object):\n", 41 | "\n", 42 | " def __init__(self, model_path, input_size=256, model_size='large'):\n", 43 | " \n", 44 | " self.input_size = input_size\n", 45 | " if torch.cuda.is_available():\n", 46 | " self.device = torch.device('cuda')\n", 47 | " else:\n", 48 | " self.device = torch.device('cpu')\n", 49 | "\n", 50 | " if model_size == 'tiny':\n", 51 | " self.encoder_name = 'convnext-t'\n", 52 | " else:\n", 53 | " self.encoder_name = 'convnext-l'\n", 54 | "\n", 55 | " self.decoder_type = \"MultiScaleColorDecoder\"\n", 56 | "\n", 57 | " if self.decoder_type == 'MultiScaleColorDecoder':\n", 58 | " self.model = DDColor(\n", 59 | " encoder_name=self.encoder_name,\n", 60 | " decoder_name='MultiScaleColorDecoder',\n", 61 | " input_size=[self.input_size, self.input_size],\n", 62 | " num_output_channels=2,\n", 63 | " last_norm='Spectral',\n", 64 | " do_normalize=False,\n", 65 | " num_queries=100,\n", 66 | " num_scales=3,\n", 67 | " dec_layers=9,\n", 68 | " ).to(self.device)\n", 69 | " else:\n", 70 | " self.model = DDColor(\n", 71 | " encoder_name=self.encoder_name,\n", 72 | " decoder_name='SingleColorDecoder',\n", 73 | " input_size=[self.input_size, self.input_size],\n", 74 | " num_output_channels=2,\n", 75 | " last_norm='Spectral',\n", 76 | " do_normalize=False,\n", 77 | " num_queries=256,\n", 78 | " ).to(self.device)\n", 79 | "\n", 80 | " self.model.load_state_dict(\n", 81 | " torch.load(model_path, map_location=torch.device('cpu'))['params'],\n", 82 | " strict=False)\n", 83 | " self.model.eval()\n", 84 | "\n", 85 | " @torch.no_grad()\n", 86 | " def process(self, img):\n", 87 | " self.height, self.width = img.shape[:2]\n", 88 | " # print(self.width, self.height)\n", 89 | " # if self.width * self.height < 100000:\n", 90 | " # self.input_size = 256\n", 91 | "\n", 92 | " img = (img / 255.0).astype(np.float32)\n", 93 | " orig_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1] # (h, w, 1)\n", 94 | 
"\n", 95 | " # resize rgb image -> lab -> get grey -> rgb\n", 96 | " img = cv2.resize(img, (self.input_size, self.input_size))\n", 97 | " img_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1]\n", 98 | " img_gray_lab = np.concatenate((img_l, np.zeros_like(img_l), np.zeros_like(img_l)), axis=-1)\n", 99 | " img_gray_rgb = cv2.cvtColor(img_gray_lab, cv2.COLOR_LAB2RGB)\n", 100 | "\n", 101 | " tensor_gray_rgb = torch.from_numpy(img_gray_rgb.transpose((2, 0, 1))).float().unsqueeze(0).to(self.device)\n", 102 | " output_ab = self.model(tensor_gray_rgb).cpu() # (1, 2, self.height, self.width)\n", 103 | "\n", 104 | " # resize ab -> concat original l -> rgb\n", 105 | " output_ab_resize = F.interpolate(output_ab, size=(self.height, self.width))[0].float().numpy().transpose(1, 2, 0)\n", 106 | " output_lab = np.concatenate((orig_l, output_ab_resize), axis=-1)\n", 107 | " output_bgr = cv2.cvtColor(output_lab, cv2.COLOR_LAB2BGR)\n", 108 | "\n", 109 | " output_img = (output_bgr * 255.0).round().astype(np.uint8) \n", 110 | "\n", 111 | " return output_img\n", 112 | "\n", 113 | "colorizer = ImageColorizationPipeline(model_path='/content/DDColor/models/pytorch_model.pt', input_size=512)\n", 114 | "\n", 115 | "# helper function taken from: https://huggingface.co/blog/stable_diffusion\n", 116 | "from PIL import Image\n", 117 | "def image_grid(imgs, rows, cols):\n", 118 | " assert len(imgs) == rows*cols\n", 119 | "\n", 120 | " w, h = imgs[0].size\n", 121 | " grid = Image.new('RGB', size=(cols*w, rows*h))\n", 122 | " grid_w, grid_h = grid.size\n", 123 | "\n", 124 | " for i, img in enumerate(imgs):\n", 125 | " grid.paste(img, box=(i%cols*w, i//cols*h))\n", 126 | " return grid" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "image_in = cv2.imread('/content/DDColor/in.jpg')\n", 136 | "image_out = colorizer.process(image_in)\n", 137 | "cv2.imwrite('/content/DDColor/out.jpg', image_out)\n", 138 | "image_in_pil = Image.fromarray(cv2.cvtColor(image_in, cv2.COLOR_BGR2RGB))\n", 139 | "image_out_pil = Image.fromarray(cv2.cvtColor(image_out, cv2.COLOR_BGR2RGB))\n", 140 | "images = [image_in_pil, image_out_pil]\n", 141 | "grid = image_grid(images, rows=1, cols=2)\n", 142 | "grid" 143 | ] 144 | } 145 | ], 146 | "metadata": { 147 | "accelerator": "GPU", 148 | "colab": { 149 | "gpuType": "T4", 150 | "provenance": [] 151 | }, 152 | "kernelspec": { 153 | "display_name": "Python 3", 154 | "name": "python3" 155 | }, 156 | "language_info": { 157 | "name": "python" 158 | } 159 | }, 160 | "nbformat": 4, 161 | "nbformat_minor": 0 162 | } 163 | -------------------------------------------------------------------------------- /DDColor_gradio_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github" 7 | }, 8 | "source": [ 9 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/DDColor-colab/blob/main/DDColor_gradio_colab.ipynb)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%cd /content\n", 19 | "!git clone -b dev https://github.com/camenduru/DDColor\n", 20 | "\n", 21 | "!apt -y install -qq aria2\n", 22 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M 
https://huggingface.co/camenduru/cv_ddcolor_image-colorization/resolve/main/pytorch_model.pt -d /content/DDColor/models -o pytorch_model.pt\n", 23 | "\n", 24 | "!wget https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/audrey_hepburn.jpg -O /content/DDColor/in.jpg\n", 25 | "!pip install -q timm gradio gradio_imageslider\n", 26 | "\n", 27 | "%cd /content/DDColor\n", 28 | "\n", 29 | "!sed -i 's/from \\.version import __gitsha__, __version__/# from \\.version import __gitsha__, __version__/' /content/DDColor/basicsr/__init__.py\n", 30 | "\n", 31 | "import argparse\n", 32 | "import cv2\n", 33 | "import numpy as np\n", 34 | "import os\n", 35 | "from tqdm import tqdm\n", 36 | "import torch\n", 37 | "from basicsr.archs.ddcolor_arch import DDColor\n", 38 | "import torch.nn.functional as F\n", 39 | "\n", 40 | "class ImageColorizationPipeline(object):\n", 41 | "\n", 42 | " def __init__(self, model_path, input_size=256, model_size='large'):\n", 43 | " \n", 44 | " self.input_size = input_size\n", 45 | " if torch.cuda.is_available():\n", 46 | " self.device = torch.device('cuda')\n", 47 | " else:\n", 48 | " self.device = torch.device('cpu')\n", 49 | "\n", 50 | " if model_size == 'tiny':\n", 51 | " self.encoder_name = 'convnext-t'\n", 52 | " else:\n", 53 | " self.encoder_name = 'convnext-l'\n", 54 | "\n", 55 | " self.decoder_type = \"MultiScaleColorDecoder\"\n", 56 | "\n", 57 | " if self.decoder_type == 'MultiScaleColorDecoder':\n", 58 | " self.model = DDColor(\n", 59 | " encoder_name=self.encoder_name,\n", 60 | " decoder_name='MultiScaleColorDecoder',\n", 61 | " input_size=[self.input_size, self.input_size],\n", 62 | " num_output_channels=2,\n", 63 | " last_norm='Spectral',\n", 64 | " do_normalize=False,\n", 65 | " num_queries=100,\n", 66 | " num_scales=3,\n", 67 | " dec_layers=9,\n", 68 | " ).to(self.device)\n", 69 | " else:\n", 70 | " self.model = DDColor(\n", 71 | " encoder_name=self.encoder_name,\n", 72 | " decoder_name='SingleColorDecoder',\n", 73 | " input_size=[self.input_size, self.input_size],\n", 74 | " num_output_channels=2,\n", 75 | " last_norm='Spectral',\n", 76 | " do_normalize=False,\n", 77 | " num_queries=256,\n", 78 | " ).to(self.device)\n", 79 | "\n", 80 | " self.model.load_state_dict(\n", 81 | " torch.load(model_path, map_location=torch.device('cpu'))['params'],\n", 82 | " strict=False)\n", 83 | " self.model.eval()\n", 84 | "\n", 85 | " @torch.no_grad()\n", 86 | " def process(self, img):\n", 87 | " self.height, self.width = img.shape[:2]\n", 88 | " # print(self.width, self.height)\n", 89 | " # if self.width * self.height < 100000:\n", 90 | " # self.input_size = 256\n", 91 | "\n", 92 | " img = (img / 255.0).astype(np.float32)\n", 93 | " orig_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1] # (h, w, 1)\n", 94 | "\n", 95 | " # resize rgb image -> lab -> get grey -> rgb\n", 96 | " img = cv2.resize(img, (self.input_size, self.input_size))\n", 97 | " img_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1]\n", 98 | " img_gray_lab = np.concatenate((img_l, np.zeros_like(img_l), np.zeros_like(img_l)), axis=-1)\n", 99 | " img_gray_rgb = cv2.cvtColor(img_gray_lab, cv2.COLOR_LAB2RGB)\n", 100 | "\n", 101 | " tensor_gray_rgb = torch.from_numpy(img_gray_rgb.transpose((2, 0, 1))).float().unsqueeze(0).to(self.device)\n", 102 | " output_ab = self.model(tensor_gray_rgb).cpu() # (1, 2, self.height, self.width)\n", 103 | "\n", 104 | " # resize ab -> concat original l -> rgb\n", 105 | " output_ab_resize = F.interpolate(output_ab, size=(self.height, self.width))[0].float().numpy().transpose(1, 2, 
0)\n", 106 | " output_lab = np.concatenate((orig_l, output_ab_resize), axis=-1)\n", 107 | " output_bgr = cv2.cvtColor(output_lab, cv2.COLOR_LAB2BGR)\n", 108 | "\n", 109 | " output_img = (output_bgr * 255.0).round().astype(np.uint8) \n", 110 | "\n", 111 | " return output_img\n", 112 | "\n", 113 | "colorizer = ImageColorizationPipeline(model_path='/content/DDColor/models/pytorch_model.pt', input_size=512)\n", 114 | "\n", 115 | "from PIL import Image\n", 116 | "import gradio as gr\n", 117 | "import subprocess\n", 118 | "import shutil, os\n", 119 | "from gradio_imageslider import ImageSlider\n", 120 | "\n", 121 | "def generate(image):\n", 122 | " image_in = cv2.imread(image)\n", 123 | " image_out = colorizer.process(image_in)\n", 124 | " cv2.imwrite('/content/DDColor/out.jpg', image_out)\n", 125 | " image_in_pil = Image.fromarray(cv2.cvtColor(image_in, cv2.COLOR_BGR2RGB))\n", 126 | " image_out_pil = Image.fromarray(cv2.cvtColor(image_out, cv2.COLOR_BGR2RGB))\n", 127 | " return (image_in_pil, image_out_pil)\n", 128 | "\n", 129 | "with gr.Blocks() as demo:\n", 130 | " with gr.Row():\n", 131 | " with gr.Column():\n", 132 | " image = gr.Image(type='filepath')\n", 133 | " button = gr.Button()\n", 134 | " output_image = ImageSlider(show_label=False, type=\"filepath\", interactive=False)\n", 135 | " button.click(fn=generate, inputs=[image], outputs=[output_image])\n", 136 | "\n", 137 | "demo.queue().launch(inline=False, share=True, debug=True)" 138 | ] 139 | } 140 | ], 141 | "metadata": { 142 | "accelerator": "GPU", 143 | "colab": { 144 | "gpuType": "T4", 145 | "provenance": [] 146 | }, 147 | "kernelspec": { 148 | "display_name": "Python 3", 149 | "name": "python3" 150 | }, 151 | "language_info": { 152 | "name": "python" 153 | } 154 | }, 155 | "nbformat": 4, 156 | "nbformat_minor": 0 157 | } 158 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 🐣 Please follow me for new updates https://twitter.com/camenduru
2 | 🔥 Please join our Discord server https://discord.gg/k5BwmmvJJU <br />
3 | 🥳 Please join my Patreon community https://patreon.com/camenduru <br />
4 | 
5 | ### 🦒 Colab
6 | 
7 | | Colab | Info |
8 | | --- | --- |
9 | | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/DDColor-colab/blob/main/DDColor_colab.ipynb) | DDColor_colab |
10 | | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/DDColor-colab/blob/main/DDColor_gradio_colab.ipynb) | DDColor_gradio_colab |
11 | 
12 | ### 🧬 Code
13 | https://github.com/piddnad/DDColor
14 | 
15 | ### 📄 Paper
16 | https://arxiv.org/abs/2212.11613
17 | 
18 | ### 🌐 Page
19 | https://www.modelscope.cn/models/damo/cv_ddcolor_image-colorization/summary
20 | 
21 | ### 🖼 Output
22 | 
23 | https://github.com/camenduru/DDColor-colab/assets/54370274/6f75b3ae-a1ed-48a8-882b-0c00ae332c3c
24 | 
25 | 
26 | ### 🏢 Sponsor
27 | https://modelslab.com
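28 | 
29 | ### 🐍 Pipeline sketch
30 | 
31 | Both notebooks wrap the same inference code: the network predicts only the two ab (chroma) channels of a Lab image, and the original full-resolution L (lightness) channel is reattached afterwards, so no luminance detail is lost. As a rough sketch of running this outside Colab (assuming the repo is cloned, `torch`, `timm` and `opencv-python` are installed, the working directory is the repo root, and `pytorch_model.pt` sits in `models/`; the paths here are illustrative), the whole pipeline condenses to:
32 | 
33 | ```python
34 | # Condensed from DDColor_colab.ipynb; see the notebook for the full pipeline class.
35 | import cv2
36 | import numpy as np
37 | import torch
38 | import torch.nn.functional as F
39 | # NOTE: the notebooks first comment out the version import in basicsr/__init__.py
40 | # (the sed line); do the same locally if this import fails.
41 | from basicsr.archs.ddcolor_arch import DDColor
42 | 
43 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44 | input_size = 512  # network input resolution used by the notebooks
45 | 
46 | # Build the large model exactly as the notebooks do and load the weights.
47 | model = DDColor(encoder_name='convnext-l', decoder_name='MultiScaleColorDecoder',
48 |                 input_size=[input_size, input_size], num_output_channels=2,
49 |                 last_norm='Spectral', do_normalize=False, num_queries=100,
50 |                 num_scales=3, dec_layers=9).to(device)
51 | state = torch.load('models/pytorch_model.pt', map_location='cpu')['params']
52 | model.load_state_dict(state, strict=False)
53 | model.eval()
54 | 
55 | img = cv2.imread('in.jpg')  # any photo; color inputs are grayscaled below
56 | h, w = img.shape[:2]
57 | img = (img / 255.0).astype(np.float32)
58 | orig_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1]  # full-resolution L channel
59 | 
60 | # Grayscale the input in Lab space at the network resolution:
61 | # resize, keep L, zero the ab channels, convert back to RGB.
62 | small = cv2.resize(img, (input_size, input_size))
63 | small_l = cv2.cvtColor(small, cv2.COLOR_BGR2Lab)[:, :, :1]
64 | gray_lab = np.concatenate((small_l, np.zeros_like(small_l), np.zeros_like(small_l)), axis=-1)
65 | gray_rgb = cv2.cvtColor(gray_lab, cv2.COLOR_LAB2RGB)
66 | 
67 | with torch.no_grad():
68 |     x = torch.from_numpy(gray_rgb.transpose((2, 0, 1))).float().unsqueeze(0).to(device)
69 |     ab = model(x).cpu()  # (1, 2, input_size, input_size) predicted chroma
70 | 
71 | # Upsample the predicted ab channels, reattach the original L, save as BGR.
72 | ab = F.interpolate(ab, size=(h, w))[0].float().numpy().transpose(1, 2, 0)
73 | out = cv2.cvtColor(np.concatenate((orig_l, ab), axis=-1), cv2.COLOR_LAB2BGR)
74 | cv2.imwrite('out.jpg', (out * 255.0).round().astype(np.uint8))
75 | ```
76 | 
77 | The Gradio notebook wraps this same logic in a `generate` function and shows the input and the colorized output side by side with `gradio_imageslider`.
78 | 
--------------------------------------------------------------------------------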