├── DDColor_colab.ipynb
├── DDColor_gradio_colab.ipynb
└── README.md
/DDColor_colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github"
7 | },
8 | "source": [
9 | "[](https://colab.research.google.com/github/camenduru/DDColor-colab/blob/main/DDColor_colab.ipynb)"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "%cd /content\n",
19 | "!git clone -b dev https://github.com/camenduru/DDColor\n",
20 | "\n",
21 | "!apt -y install -qq aria2\n",
22 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/cv_ddcolor_image-colorization/resolve/main/pytorch_model.pt -d /content/DDColor/models -o pytorch_model.pt\n",
23 | "\n",
24 | "!wget https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/audrey_hepburn.jpg -O /content/DDColor/in.jpg\n",
25 | "!pip install -q timm\n",
26 | "\n",
27 | "%cd /content/DDColor\n",
28 | "\n",
29 | "!sed -i 's/from \\.version import __gitsha__, __version__/# from \\.version import __gitsha__, __version__/' /content/DDColor/basicsr/__init__.py\n",
30 | "\n",
31 | "import argparse\n",
32 | "import cv2\n",
33 | "import numpy as np\n",
34 | "import os\n",
35 | "from tqdm import tqdm\n",
36 | "import torch\n",
37 | "from basicsr.archs.ddcolor_arch import DDColor\n",
38 | "import torch.nn.functional as F\n",
39 | "\n",
40 | "class ImageColorizationPipeline(object):\n",
41 | "\n",
42 | " def __init__(self, model_path, input_size=256, model_size='large'):\n",
43 | " \n",
44 | " self.input_size = input_size\n",
45 | " if torch.cuda.is_available():\n",
46 | " self.device = torch.device('cuda')\n",
47 | " else:\n",
48 | " self.device = torch.device('cpu')\n",
49 | "\n",
50 | " if model_size == 'tiny':\n",
51 | " self.encoder_name = 'convnext-t'\n",
52 | " else:\n",
53 | " self.encoder_name = 'convnext-l'\n",
54 | "\n",
55 | " self.decoder_type = \"MultiScaleColorDecoder\"\n",
56 | "\n",
57 | " if self.decoder_type == 'MultiScaleColorDecoder':\n",
58 | " self.model = DDColor(\n",
59 | " encoder_name=self.encoder_name,\n",
60 | " decoder_name='MultiScaleColorDecoder',\n",
61 | " input_size=[self.input_size, self.input_size],\n",
62 | " num_output_channels=2,\n",
63 | " last_norm='Spectral',\n",
64 | " do_normalize=False,\n",
65 | " num_queries=100,\n",
66 | " num_scales=3,\n",
67 | " dec_layers=9,\n",
68 | " ).to(self.device)\n",
69 | " else:\n",
70 | " self.model = DDColor(\n",
71 | " encoder_name=self.encoder_name,\n",
72 | " decoder_name='SingleColorDecoder',\n",
73 | " input_size=[self.input_size, self.input_size],\n",
74 | " num_output_channels=2,\n",
75 | " last_norm='Spectral',\n",
76 | " do_normalize=False,\n",
77 | " num_queries=256,\n",
78 | " ).to(self.device)\n",
79 | "\n",
80 | " self.model.load_state_dict(\n",
81 | " torch.load(model_path, map_location=torch.device('cpu'))['params'],\n",
82 | " strict=False)\n",
83 | " self.model.eval()\n",
84 | "\n",
85 | " @torch.no_grad()\n",
86 | " def process(self, img):\n",
87 | " self.height, self.width = img.shape[:2]\n",
88 | " # print(self.width, self.height)\n",
89 | " # if self.width * self.height < 100000:\n",
90 | " # self.input_size = 256\n",
91 | "\n",
92 | " img = (img / 255.0).astype(np.float32)\n",
93 | " orig_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1] # (h, w, 1)\n",
94 | "\n",
95 | " # resize rgb image -> lab -> get grey -> rgb\n",
96 | " img = cv2.resize(img, (self.input_size, self.input_size))\n",
97 | " img_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1]\n",
98 | " img_gray_lab = np.concatenate((img_l, np.zeros_like(img_l), np.zeros_like(img_l)), axis=-1)\n",
99 | " img_gray_rgb = cv2.cvtColor(img_gray_lab, cv2.COLOR_LAB2RGB)\n",
100 | "\n",
101 | " tensor_gray_rgb = torch.from_numpy(img_gray_rgb.transpose((2, 0, 1))).float().unsqueeze(0).to(self.device)\n",
102 | " output_ab = self.model(tensor_gray_rgb).cpu() # (1, 2, self.height, self.width)\n",
103 | "\n",
104 | " # resize ab -> concat original l -> rgb\n",
105 | " output_ab_resize = F.interpolate(output_ab, size=(self.height, self.width))[0].float().numpy().transpose(1, 2, 0)\n",
106 | " output_lab = np.concatenate((orig_l, output_ab_resize), axis=-1)\n",
107 | " output_bgr = cv2.cvtColor(output_lab, cv2.COLOR_LAB2BGR)\n",
108 | "\n",
109 | " output_img = (output_bgr * 255.0).round().astype(np.uint8) \n",
110 | "\n",
111 | " return output_img\n",
112 | "\n",
113 | "colorizer = ImageColorizationPipeline(model_path='/content/DDColor/models/pytorch_model.pt', input_size=512)\n",
114 | "\n",
115 | "# helper function taken from: https://huggingface.co/blog/stable_diffusion\n",
116 | "from PIL import Image\n",
117 | "def image_grid(imgs, rows, cols):\n",
118 | " assert len(imgs) == rows*cols\n",
119 | "\n",
120 | " w, h = imgs[0].size\n",
121 | " grid = Image.new('RGB', size=(cols*w, rows*h))\n",
122 | " grid_w, grid_h = grid.size\n",
123 | "\n",
124 | " for i, img in enumerate(imgs):\n",
125 | " grid.paste(img, box=(i%cols*w, i//cols*h))\n",
126 | " return grid"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "metadata": {},
133 | "outputs": [],
134 | "source": [
135 | "image_in = cv2.imread('/content/DDColor/in.jpg')\n",
136 | "image_out = colorizer.process(image_in)\n",
137 | "cv2.imwrite('/content/DDColor/out.jpg', image_out)\n",
138 | "image_in_pil = Image.fromarray(cv2.cvtColor(image_in, cv2.COLOR_BGR2RGB))\n",
139 | "image_out_pil = Image.fromarray(cv2.cvtColor(image_out, cv2.COLOR_BGR2RGB))\n",
140 | "images = [image_in_pil, image_out_pil]\n",
141 | "grid = image_grid(images, rows=1, cols=2)\n",
142 | "grid"
143 | ]
144 | }
145 | ],
146 | "metadata": {
147 | "accelerator": "GPU",
148 | "colab": {
149 | "gpuType": "T4",
150 | "provenance": []
151 | },
152 | "kernelspec": {
153 | "display_name": "Python 3",
154 | "name": "python3"
155 | },
156 | "language_info": {
157 | "name": "python"
158 | }
159 | },
160 | "nbformat": 4,
161 | "nbformat_minor": 0
162 | }
163 |
--------------------------------------------------------------------------------
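
Both notebooks colorize in Lab space: the input is reduced to its L (lightness) channel at a fixed input_size, the network predicts only the two ab (chrominance) channels, and those are upsampled and merged back with the original-resolution L channel. A minimal standalone sketch of those two steps, using only OpenCV and NumPy (pred_ab is a hypothetical stand-in for the network output, shaped (size, size, 2)):

import cv2
import numpy as np

def to_gray_input(bgr_uint8, size):
    # full-resolution BGR uint8 -> fixed-size "grayscale" RGB float input:
    # keep only the L channel, zero out ab, convert back to RGB
    img = (bgr_uint8 / 255.0).astype(np.float32)
    small = cv2.resize(img, (size, size))
    l = cv2.cvtColor(small, cv2.COLOR_BGR2Lab)[:, :, :1]
    gray_lab = np.concatenate((l, np.zeros_like(l), np.zeros_like(l)), axis=-1)
    return cv2.cvtColor(gray_lab, cv2.COLOR_LAB2RGB)

def recombine(bgr_uint8, pred_ab):
    # merge the original full-resolution L channel with the upsampled
    # predicted ab channels, then convert Lab -> BGR uint8
    h, w = bgr_uint8.shape[:2]
    img = (bgr_uint8 / 255.0).astype(np.float32)
    orig_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1]
    ab = cv2.resize(pred_ab.astype(np.float32), (w, h))
    lab = np.concatenate((orig_l, ab), axis=-1)
    return (cv2.cvtColor(lab, cv2.COLOR_LAB2BGR) * 255.0).round().astype(np.uint8)
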
/DDColor_gradio_colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github"
7 | },
8 | "source": [
9 | "[](https://colab.research.google.com/github/camenduru/DDColor-colab/blob/main/DDColor_gradio_colab.ipynb)"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "%cd /content\n",
19 | "!git clone -b dev https://github.com/camenduru/DDColor\n",
20 | "\n",
21 | "!apt -y install -qq aria2\n",
22 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/cv_ddcolor_image-colorization/resolve/main/pytorch_model.pt -d /content/DDColor/models -o pytorch_model.pt\n",
23 | "\n",
24 | "!wget https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/audrey_hepburn.jpg -O /content/DDColor/in.jpg\n",
25 | "!pip install -q timm gradio gradio_imageslider\n",
26 | "\n",
27 | "%cd /content/DDColor\n",
28 | "\n",
29 | "!sed -i 's/from \\.version import __gitsha__, __version__/# from \\.version import __gitsha__, __version__/' /content/DDColor/basicsr/__init__.py\n",
30 | "\n",
31 | "import argparse\n",
32 | "import cv2\n",
33 | "import numpy as np\n",
34 | "import os\n",
35 | "from tqdm import tqdm\n",
36 | "import torch\n",
37 | "from basicsr.archs.ddcolor_arch import DDColor\n",
38 | "import torch.nn.functional as F\n",
39 | "\n",
40 | "class ImageColorizationPipeline(object):\n",
41 | "\n",
42 | " def __init__(self, model_path, input_size=256, model_size='large'):\n",
43 | " \n",
44 | " self.input_size = input_size\n",
45 | " if torch.cuda.is_available():\n",
46 | " self.device = torch.device('cuda')\n",
47 | " else:\n",
48 | " self.device = torch.device('cpu')\n",
49 | "\n",
50 | " if model_size == 'tiny':\n",
51 | " self.encoder_name = 'convnext-t'\n",
52 | " else:\n",
53 | " self.encoder_name = 'convnext-l'\n",
54 | "\n",
55 | " self.decoder_type = \"MultiScaleColorDecoder\"\n",
56 | "\n",
57 | " if self.decoder_type == 'MultiScaleColorDecoder':\n",
58 | " self.model = DDColor(\n",
59 | " encoder_name=self.encoder_name,\n",
60 | " decoder_name='MultiScaleColorDecoder',\n",
61 | " input_size=[self.input_size, self.input_size],\n",
62 | " num_output_channels=2,\n",
63 | " last_norm='Spectral',\n",
64 | " do_normalize=False,\n",
65 | " num_queries=100,\n",
66 | " num_scales=3,\n",
67 | " dec_layers=9,\n",
68 | " ).to(self.device)\n",
69 | " else:\n",
70 | " self.model = DDColor(\n",
71 | " encoder_name=self.encoder_name,\n",
72 | " decoder_name='SingleColorDecoder',\n",
73 | " input_size=[self.input_size, self.input_size],\n",
74 | " num_output_channels=2,\n",
75 | " last_norm='Spectral',\n",
76 | " do_normalize=False,\n",
77 | " num_queries=256,\n",
78 | " ).to(self.device)\n",
79 | "\n",
80 | " self.model.load_state_dict(\n",
81 | " torch.load(model_path, map_location=torch.device('cpu'))['params'],\n",
82 | " strict=False)\n",
83 | " self.model.eval()\n",
84 | "\n",
85 | " @torch.no_grad()\n",
86 | " def process(self, img):\n",
87 | " self.height, self.width = img.shape[:2]\n",
88 | " # print(self.width, self.height)\n",
89 | " # if self.width * self.height < 100000:\n",
90 | " # self.input_size = 256\n",
91 | "\n",
92 | " img = (img / 255.0).astype(np.float32)\n",
93 | " orig_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1] # (h, w, 1)\n",
94 | "\n",
95 | " # resize rgb image -> lab -> get grey -> rgb\n",
96 | " img = cv2.resize(img, (self.input_size, self.input_size))\n",
97 | " img_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1]\n",
98 | " img_gray_lab = np.concatenate((img_l, np.zeros_like(img_l), np.zeros_like(img_l)), axis=-1)\n",
99 | " img_gray_rgb = cv2.cvtColor(img_gray_lab, cv2.COLOR_LAB2RGB)\n",
100 | "\n",
101 | " tensor_gray_rgb = torch.from_numpy(img_gray_rgb.transpose((2, 0, 1))).float().unsqueeze(0).to(self.device)\n",
102 | " output_ab = self.model(tensor_gray_rgb).cpu() # (1, 2, self.height, self.width)\n",
103 | "\n",
104 | " # resize ab -> concat original l -> rgb\n",
105 | " output_ab_resize = F.interpolate(output_ab, size=(self.height, self.width))[0].float().numpy().transpose(1, 2, 0)\n",
106 | " output_lab = np.concatenate((orig_l, output_ab_resize), axis=-1)\n",
107 | " output_bgr = cv2.cvtColor(output_lab, cv2.COLOR_LAB2BGR)\n",
108 | "\n",
109 | " output_img = (output_bgr * 255.0).round().astype(np.uint8) \n",
110 | "\n",
111 | " return output_img\n",
112 | "\n",
113 | "colorizer = ImageColorizationPipeline(model_path='/content/DDColor/models/pytorch_model.pt', input_size=512)\n",
114 | "\n",
115 | "from PIL import Image\n",
116 | "import gradio as gr\n",
117 | "import subprocess\n",
118 | "import shutil, os\n",
119 | "from gradio_imageslider import ImageSlider\n",
120 | "\n",
121 | "def generate(image):\n",
122 | " image_in = cv2.imread(image)\n",
123 | " image_out = colorizer.process(image_in)\n",
124 | " cv2.imwrite('/content/DDColor/out.jpg', image_out)\n",
125 | " image_in_pil = Image.fromarray(cv2.cvtColor(image_in, cv2.COLOR_BGR2RGB))\n",
126 | " image_out_pil = Image.fromarray(cv2.cvtColor(image_out, cv2.COLOR_BGR2RGB))\n",
127 | " return (image_in_pil, image_out_pil)\n",
128 | "\n",
129 | "with gr.Blocks() as demo:\n",
130 | " with gr.Row():\n",
131 | " with gr.Column():\n",
132 | " image = gr.Image(type='filepath')\n",
133 | " button = gr.Button()\n",
134 | " output_image = ImageSlider(show_label=False, type=\"filepath\", interactive=False)\n",
135 | " button.click(fn=generate, inputs=[image], outputs=[output_image])\n",
136 | "\n",
137 | "demo.queue().launch(inline=False, share=True, debug=True)"
138 | ]
139 | }
140 | ],
141 | "metadata": {
142 | "accelerator": "GPU",
143 | "colab": {
144 | "gpuType": "T4",
145 | "provenance": []
146 | },
147 | "kernelspec": {
148 | "display_name": "Python 3",
149 | "name": "python3"
150 | },
151 | "language_info": {
152 | "name": "python"
153 | }
154 | },
155 | "nbformat": 4,
156 | "nbformat_minor": 0
157 | }
158 |
--------------------------------------------------------------------------------
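
The Gradio notebook wraps the same pipeline in a browser UI with a before/after slider. For headless batch runs, a short sketch like the one below can reuse the ImageColorizationPipeline class defined in either notebook (the in/ and out/ folder paths are hypothetical, not created by the notebooks):

import glob
import os
import cv2

# assumes ImageColorizationPipeline (defined in the notebooks above) is in
# scope and the weights were downloaded to the path the notebooks use
colorizer = ImageColorizationPipeline(
    model_path='/content/DDColor/models/pytorch_model.pt', input_size=512)

os.makedirs('/content/DDColor/out', exist_ok=True)
for path in sorted(glob.glob('/content/DDColor/in/*.jpg')):
    bgr = cv2.imread(path)              # OpenCV reads images as BGR uint8
    colorized = colorizer.process(bgr)  # BGR uint8 at the original resolution
    cv2.imwrite(os.path.join('/content/DDColor/out', os.path.basename(path)), colorized)
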
/README.md:
--------------------------------------------------------------------------------
1 | 🐣 Please follow me for new updates: https://twitter.com/camenduru
2 | 🔥 Please join our Discord server: https://discord.gg/k5BwmmvJJU
3 | 🥳 Please join my Patreon community: https://patreon.com/camenduru
4 |
5 | ### 🦒 Colab
6 |
7 | | Colab | Info |
8 | | --- | --- |
9 | | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/DDColor-colab/blob/main/DDColor_colab.ipynb) | DDColor_colab |
10 | | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/DDColor-colab/blob/main/DDColor_gradio_colab.ipynb) | DDColor_gradio_colab |
11 |
12 | ### 🧬 Code
13 | https://github.com/piddnad/DDColor
14 |
15 | ### 📄 Paper
16 | https://arxiv.org/abs/2212.11613
17 |
18 | ### 🌐 Page
19 | https://www.modelscope.cn/models/damo/cv_ddcolor_image-colorization/summary
20 |
21 | ### 🖼 Output
22 |
23 | https://github.com/camenduru/DDColor-colab/assets/54370274/6f75b3ae-a1ed-48a8-882b-0c00ae332c3c
24 |
25 |
26 | ### 🏢 Sponsor
27 | https://modelslab.com
28 |
--------------------------------------------------------------------------------