├── .github └── workflows │ └── test.yml ├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── Demo.ipynb ├── LICENSE ├── README.md ├── colorization ├── __init__.py ├── baseline │ ├── __init__.py │ ├── colorizer.py │ └── utils.py ├── include │ └── colorization.h ├── iterative_colorizer │ ├── __init__.py │ ├── iterative_colorizer.py │ ├── utils.py │ └── window_neighbour.py └── main.cpp ├── colorize.py ├── data ├── original │ ├── example.png │ ├── example2.png │ ├── example3.png │ ├── example4.png │ ├── example5.png │ ├── example6.png │ ├── example7.png │ └── example8.png ├── results │ ├── result.png │ ├── result2.png │ ├── result3.png │ ├── result4.png │ ├── result5.png │ ├── result6.png │ ├── result7.png │ └── result8.png └── visual-clues │ ├── example.png │ ├── example2_marked.png │ ├── example3_marked.png │ ├── example4.png │ ├── example5.png │ ├── example6.png │ ├── example7.png │ └── example8.png ├── install.sh ├── requirements.txt └── tests.py /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | on: 3 | push: 4 | branches: 5 | - master 6 | - feature/colorizer 7 | - feature/cpp-implementation 8 | paths: 9 | - .github/** 10 | - colorization/** 11 | - tests.py 12 | - requirements.txt 13 | jobs: 14 | testing-on-linux: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v2 18 | - run: | 19 | pip install -U pip 20 | pip install -r requirements.txt 21 | pytest tests.py -s 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv/ 2 | .idea/ 3 | build/ 4 | .vscode/ 5 | **pycache** 6 | .ipynb_checkpoints/ 7 | cmake-build-debug/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "colorization/include/cpptqdm"] 2 
| path = colorization/include/cpptqdm 3 | url = https://github.com/aminnj/cpptqdm 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.19) 2 | project(colorization) 3 | set(CMAKE_CXX_STANDARD 14) 4 | 5 | find_package(Eigen3 REQUIRED) 6 | include_directories(${EIGEN3_INCLUDE_DIR}) 7 | 8 | find_package(OpenCV REQUIRED) 9 | include_directories(${OpenCV_INCLUDE_DIRS}) 10 | 11 | include_directories(${PROJECT_NAME} ${CMAKE_SOURCE_DIR}/colorization/include) 12 | 13 | add_executable(${PROJECT_NAME} colorization/main.cpp colorization/include/colorization.h) 14 | target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS}) 15 | -------------------------------------------------------------------------------- /Demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e6deadda", 6 | "metadata": { 7 | "colab_type": "text", 8 | "id": "view-in-github" 9 | }, 10 | "source": [ 11 | "\"Open" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "id": "Mrgh7NHkUNAo", 17 | "metadata": { 18 | "id": "Mrgh7NHkUNAo" 19 | }, 20 | "source": [ 21 | "## Clone the repository" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "gMeivfsWUMtQ", 28 | "metadata": { 29 | "colab": { 30 | "base_uri": "https://localhost:8080/" 31 | }, 32 | "id": "gMeivfsWUMtQ", 33 | "outputId": "09c4990d-67c7-4063-bfa0-c58445e69866" 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "!git clone https://github.com/soumik12345/colorization-using-optimization\n", 38 | "%cd colorization-using-optimization" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "id": "17f836a3", 44 | "metadata": { 45 | "id": "17f836a3" 46 | }, 47 | "source": [ 48 | "## Baseline Colorizer" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | 
"execution_count": null, 54 | "id": "47540fb4", 55 | "metadata": { 56 | "id": "47540fb4" 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "from colorization import Colorizer" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "id": "fb01dd85", 67 | "metadata": { 68 | "id": "fb01dd85" 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "colorizer = Colorizer(\n", 73 | " gray_image_file='./data/original/example.png',\n", 74 | " visual_clues_file='./data/visual-clues/example.png'\n", 75 | ")" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "id": "21fcde2e", 82 | "metadata": { 83 | "colab": { 84 | "base_uri": "https://localhost:8080/", 85 | "height": 298 86 | }, 87 | "id": "21fcde2e", 88 | "outputId": "892a6982-1ffe-4cbb-c687-d4cb38271681" 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "colorizer.plot_inputs()" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "id": "7bd9428a", 99 | "metadata": { 100 | "colab": { 101 | "base_uri": "https://localhost:8080/" 102 | }, 103 | "id": "7bd9428a", 104 | "outputId": "89963492-5ad3-46bf-d83e-a634eca1b665" 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "%%time\n", 109 | "result = colorizer.colorize()" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "id": "c8648e42", 116 | "metadata": { 117 | "colab": { 118 | "base_uri": "https://localhost:8080/", 119 | "height": 386 120 | }, 121 | "id": "c8648e42", 122 | "outputId": "f6587c81-d824-4af8-b82a-57fde8c32e83" 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "colorizer.plot_results(result)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "id": "17e517b1", 132 | "metadata": { 133 | "id": "17e517b1" 134 | }, 135 | "source": [ 136 | "## Iterative Colorizer" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "48e35f19", 143 | "metadata": { 144 | "id": "48e35f19" 145 | }, 146 | 
"outputs": [], 147 | "source": [ 148 | "from colorization import IterativeColorizer" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "67dac6b5", 155 | "metadata": { 156 | "id": "67dac6b5" 157 | }, 158 | "outputs": [], 159 | "source": [ 160 | "colorizer = IterativeColorizer(\n", 161 | " original_image='./data/original/example.png',\n", 162 | " visual_clues='./data/visual-clues/example.png'\n", 163 | ")" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "id": "48734ad9", 170 | "metadata": { 171 | "colab": { 172 | "base_uri": "https://localhost:8080/", 173 | "height": 298 174 | }, 175 | "id": "48734ad9", 176 | "outputId": "1d14f808-f2f2-4e9c-c87b-95f3172bfea6" 177 | }, 178 | "outputs": [], 179 | "source": [ 180 | "colorizer.plot_inputs()" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "id": "02056cde", 187 | "metadata": { 188 | "colab": { 189 | "base_uri": "https://localhost:8080/" 190 | }, 191 | "id": "02056cde", 192 | "outputId": "491a83dc-e87d-46fb-f4f3-648bb7756c9e" 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "%%time\n", 197 | "colorizer.colorize(epochs=600, log_interval=100)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "id": "8a74de03", 204 | "metadata": { 205 | "colab": { 206 | "base_uri": "https://localhost:8080/", 207 | "height": 1000 208 | }, 209 | "id": "8a74de03", 210 | "outputId": "9104deb8-4089-4e0e-ad10-04bd7a636bd6" 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "colorizer.plot_results(log_interval=100)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "id": "jdB19ikpX1Or", 221 | "metadata": { 222 | "id": "jdB19ikpX1Or" 223 | }, 224 | "outputs": [], 225 | "source": [] 226 | } 227 | ], 228 | "metadata": { 229 | "colab": { 230 | "include_colab_link": true, 231 | "name": "Demo.ipynb", 232 | "provenance": [] 233 | }, 234 | 
"interpreter": { 235 | "hash": "556b822321b97cfa85b8ef545cd8388f5d22aa515c082e37848490687a34f054" 236 | }, 237 | "kernelspec": { 238 | "display_name": "Python 3 (ipykernel)", 239 | "language": "python", 240 | "name": "python3" 241 | }, 242 | "language_info": { 243 | "codemirror_mode": { 244 | "name": "ipython", 245 | "version": 3 246 | }, 247 | "file_extension": ".py", 248 | "mimetype": "text/x-python", 249 | "name": "python", 250 | "nbconvert_exporter": "python", 251 | "pygments_lexer": "ipython3", 252 | "version": "3.8.10" 253 | } 254 | }, 255 | "nbformat": 4, 256 | "nbformat_minor": 5 257 | } 258 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Soumik Rakshit 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Colorization using Optimization 2 | 3 | build-failing 4 | 5 | Python and C++ implementations of a user-guided image/video colorization technique as proposed by the paper 6 | [Colorization Using Optimization](https://dl.acm.org/doi/10.1145/1015706.1015780). The algorithm is based on a simple premise: neighboring pixels in space-time that have similar intensities should have similar colors. This premise is formalized as a quadratic cost function, which yields an optimization problem that can be solved efficiently using standard techniques. **With this algorithm, an artist only needs to annotate the image with a few color scribbles or visual clues; the indicated colors are then automatically propagated in both space and time to produce a fully colorized image or sequence.** The annotation can be done using any drawing tool such as [JSPaint](https://jspaint.app/) or [GIMP](https://www.gimp.org/). 7 | 8 | ## Instructions 9 | 10 | ### Instructions for running the Python version 11 | 12 | 1. Create a virtualenv using: 13 | - `virtualenv venv --python=python3` 14 | - `source venv/bin/activate` 15 | - `pip install -r requirements.txt` 16 | 17 | 2. Colorize images using the CLI: 18 | ``` 19 | python colorize.py 20 | 21 | Options: 22 | --original_image TEXT Original Image Path 23 | --visual_clue TEXT Visual Clue Image Path 24 | --result_path TEXT Colorized Image Path (without file extensions) 25 | -i, --use_itercative Use Iterative Mode 26 | --epochs INTEGER Number of epochs for Iterative Mode 27 | --log_intervals INTEGER Log Interval 28 | --help Show this message and exit. 29 | ``` 30 | 31 | 3. Alternatively, you can run on Google Colab using Open In Colab 32 | 33 | ### Instructions to build C++ version 34 | 35 | 1. Install dependencies using `sh install.sh` 36 | 37 | 2.
Create a build directory using `mkdir build && cd build` 38 | 39 | 3. Generate Makefiles and compile using `cmake .. && make` 40 | 41 | 4. Run the executable using `./colorization [input-image] [visual-clues] [result] [gamma] [threshold]` 42 | 43 | 5. Alternatively, you can download the prebuilt executable from [here](https://github.com/soumik12345/colorization-using-optimization/releases/download/0.1/colorization) and run it (the dependencies still need to be installed). 44 | 45 | ## Results 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 |
Original ImageVisual CluesColorized Image
94 | -------------------------------------------------------------------------------- /colorization/__init__.py: -------------------------------------------------------------------------------- 1 | from .baseline import Colorizer 2 | from .iterative_colorizer import IterativeColorizer 3 | -------------------------------------------------------------------------------- /colorization/baseline/__init__.py: -------------------------------------------------------------------------------- 1 | from .colorizer import Colorizer 2 | -------------------------------------------------------------------------------- /colorization/baseline/colorizer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from PIL import Image 4 | from tqdm import tqdm 5 | from scipy import sparse 6 | from typing import Tuple 7 | from matplotlib import pyplot as plt 8 | from scipy.sparse.linalg import spsolve 9 | 10 | from .utils import position_to_id, find_neighbour 11 | 12 | 13 | class Colorizer: 14 | 15 | def __init__(self, gray_image_file: str, visual_clues_file: str) -> None: 16 | self.original_gray_image = self.gray_image = cv2.cvtColor(cv2.imread(gray_image_file), cv2.COLOR_BGR2RGB) 17 | self.original_visual_clues = self.visual_clues = cv2.cvtColor(cv2.imread(visual_clues_file), cv2.COLOR_BGR2RGB) 18 | 19 | def _preprocess(self): 20 | self.gray_image = cv2.cvtColor( 21 | self.gray_image, cv2.COLOR_RGB2YUV) / 255.0 22 | self.visual_clues = cv2.cvtColor( 23 | self.visual_clues, cv2.COLOR_RGB2YUV) / 255.0 24 | 25 | def plot_inputs(self, figure_size: Tuple[int, int] = (12, 12)) -> None: 26 | figure = plt.figure(figsize=figure_size) 27 | figure.add_subplot(1, 2, 1).set_title('Black & White') 28 | plt.imshow(self.original_gray_image) 29 | plt.axis('off') 30 | figure.add_subplot(1, 2, 2).set_title('Color Hints') 31 | plt.imshow(self.original_visual_clues) 32 | plt.axis('off') 33 | plt.show() 34 | 35 | def plot_results(self, 
result: np.ndarray) -> None: 36 | fig = plt.figure(figsize=(25, 17)) 37 | fig.add_subplot(1, 3, 1).set_title('Black & White') 38 | plt.imshow(self.original_gray_image) 39 | plt.axis('off') 40 | fig.add_subplot(1, 3, 2).set_title('Color Hints') 41 | plt.imshow(self.original_visual_clues) 42 | plt.axis('off') 43 | fig.add_subplot(1, 3, 3).set_title('Colorized') 44 | plt.imshow(result) 45 | plt.axis('off') 46 | plt.show() 47 | 48 | def colorize(self) -> np.ndarray: 49 | self._preprocess() 50 | n, m = self.visual_clues.shape[0], self.visual_clues.shape[1] 51 | size = n * m 52 | W = sparse.lil_matrix((size, size), dtype = float) 53 | b1 = np.zeros(shape = (size)) 54 | b2 = np.zeros(shape = (size)) 55 | for i in tqdm(range(n)): 56 | for j in range(m): 57 | if self.visual_clues[i, j, 0] > 1 - 1e-3: 58 | id = position_to_id(i, j, m) 59 | W[id, id] = 1 60 | b1[id] = self.gray_image[i, j, 1] 61 | b2[id] = self.gray_image[i, j, 2] 62 | continue 63 | if abs( 64 | self.gray_image[i, j, 0] - self.visual_clues[i, j, 0] 65 | ) > 1e-2 or abs( 66 | self.gray_image[i, j, 1] - self.gray_image[i, j, 1] 67 | ) > 1e-2 or abs( 68 | self.gray_image[i, j, 2] - self.visual_clues[i, j, 2] 69 | ) > 1e-2: 70 | id = position_to_id(i, j, m) 71 | W[id, id] = 1 72 | b1[id] = self.visual_clues[i, j, 1] 73 | b2[id] = self.visual_clues[i, j, 2] 74 | continue 75 | Y = self.gray_image[i, j, 0] 76 | id = position_to_id(i, j, m) 77 | neighbour = find_neighbour(i, j, n, m) 78 | Ys, ids, weights = [], [], [] 79 | for pos in neighbour: 80 | Ys.append(self.gray_image[pos[0], pos[1], 0]) 81 | ids.append(position_to_id(pos[0], pos[1], m)) 82 | sigma = np.std(Ys) 83 | sum = 0. 84 | for k in range(len(neighbour)): 85 | if sigma > 1e-3: 86 | w = np.exp(-1 * (Ys[k] - Y) * (Ys[k] - Y) / 2 / sigma / sigma) 87 | sum += w 88 | weights.append(w) 89 | else: 90 | sum += 1. 91 | weights.append(1.) 92 | for k in range(len(neighbour)): 93 | weights[k] /= sum 94 | W[id, ids[k]] += -1 * weights[k] 95 | W[id, id] += 1. 
96 | result = np.zeros(shape = (n, m, 3)) 97 | result[:, :, 0] = self.gray_image[:, :, 0] 98 | W = W.tocsc() 99 | u = spsolve(W, b1) 100 | v = spsolve(W, b2) 101 | for i in range(n): 102 | for j in range(m): 103 | id = position_to_id(i, j, m) 104 | result[i, j, 1], result[i, j, 2] = u[id], v[id] 105 | result = (np.clip(result, 0., 1.) * 255).astype(np.uint8) 106 | result = cv2.cvtColor(result, cv2.COLOR_YUV2RGB) 107 | return result 108 | -------------------------------------------------------------------------------- /colorization/baseline/utils.py: -------------------------------------------------------------------------------- 1 | def position_to_id(x, y, m): 2 | return x * m + y 3 | 4 | 5 | def find_neighbour(x, y, n, m, d = 2): 6 | neighbour = [] 7 | for i in range(max(0, x - d), min(n, x + d + 1)): 8 | for j in range(max(0, y - d), min(m, y + d + 1)): 9 | if (i != x) or (j != y): 10 | neighbour.append([i, j]) 11 | return neighbour 12 | -------------------------------------------------------------------------------- /colorization/include/colorization.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by geekyrakshit on 8/16/21. 
3 | // 4 | 5 | #ifndef COLORIZATION_COLORIZATION_H 6 | #define COLORIZATION_COLORIZATION_H 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | #include 18 | #include 19 | 20 | #include 21 | 22 | 23 | namespace colorization { // non-template functions below are inline so this header can be included from more than one .cpp without duplicate-symbol link errors 24 | 25 | 26 | inline cv::Mat eigen2opencv(Eigen::VectorXd& v, int nRows, int nCols) { 27 | cv::Mat X(nRows, nCols, CV_64FC1, v.data()); // wraps v's buffer without copying; v must outlive X 28 | return X; 29 | } 30 | 31 | 32 | inline cv::Mat getVisualClueMask(const cv::Mat& image, const cv::Mat& visual_clues, double eps = 1, int nErosions=1) { 33 | cv::Mat diff; 34 | cv::absdiff(image, visual_clues, diff); 35 | std::vector channels; 36 | cv::split(diff, channels); 37 | cv::Mat mask = channels[0] + channels[1] + channels[2]; 38 | cv::threshold(mask, mask, eps, 255, cv::THRESH_BINARY); 39 | cv::erode(mask, mask, cv::Mat(), cv::Point(-1, -1), nErosions); 40 | return mask; 41 | } 42 | 43 | 44 | template inline T squaredDifference(const std::vector& X, int r, int s) { 45 | return (X[r] - X[s]) * (X[r] - X[s]); 46 | } 47 | 48 | 49 | template void to1D(const cv::Mat& m, std::vector& v) { 50 | v.clear(); 51 | auto nRows = m.rows; 52 | auto nCols = m.cols; 53 | v.reserve(nRows * nCols); 54 | int total = 0; 55 | for (auto i = 0; i < m.rows; ++i) { 56 | v.insert(v.end(), m.ptr(i), m.ptr(i) + nCols); 57 | total += nCols; 58 | } 59 | } 60 | 61 | 62 | template T variance(const std::vector& values, T eps=0.01) { 63 | T sum = 0; 64 | T squaredSum = 0; 65 | for (auto v : values) { 66 | sum += v; 67 | squaredSum += v * v; 68 | } 69 | assert (sum >= 0); 70 | assert (squaredSum >= 0); 71 | T n = values.size(); 72 | return squaredSum / n - (sum * sum) / (n * n) + eps; 73 | } 74 | 75 | 76 | template 77 | void getNeighbours(int i, int j, int nRows, int nCols, std::vector& neighbors) { 78 | neighbors.clear(); 79 | for (int dx = -1; dx < 2; dx += 1) { 80 | for (int dy = -1; dy < 2; dy += 1) { 81 | int m = i + dy; 82 | int n = j + dx; 83 | if ((dx == 0 && dy == 0) || m < 0
|| n < 0 || m >= nRows || n >= nCols) 84 | continue; 85 | T s = m * nCols + n; 86 | neighbors.push_back(s); 87 | } 88 | } 89 | } 90 | 91 | 92 | template inline void getWeights( 93 | const std::vector& values, Ti r, 94 | const std::vector& neighbors, std::vector& neighborsWeights, Tw gamma) { 95 | neighborsWeights.clear(); 96 | std::vector neighborsValues; 97 | neighborsValues.reserve(neighbors.size() + 1); 98 | for (auto s : neighbors) { 99 | neighborsWeights.push_back(squaredDifference(values, r, s)); 100 | neighborsValues.push_back(values[s]); 101 | } 102 | neighborsValues.push_back(values[r]); 103 | Tw var = variance(neighborsValues); 104 | Tw normalizer = 0.0; 105 | for (auto& w : neighborsWeights) { 106 | w = std::exp(- gamma * w / (2 * var)); 107 | normalizer += w; 108 | } 109 | for (auto& w : neighborsWeights) { 110 | w /= normalizer; 111 | assert(w >= 0); 112 | } 113 | } 114 | 115 | 116 | inline void setupProblem( 117 | const cv::Mat& Y, const cv::Mat& visualClues, const cv::Mat& mask, 118 | Eigen::SparseMatrix& A, Eigen::VectorXd& bu, 119 | Eigen::VectorXd& bv, double gamma) { 120 | 121 | typedef Eigen::Triplet TD; 122 | auto nRows = Y.rows; 123 | auto nCols = Y.cols; 124 | auto nPixels = nRows * nCols; 125 | A.resize(nPixels, nPixels); 126 | std::vector coefficients; 127 | coefficients.reserve(nPixels * 3); 128 | bu.resize(nPixels); 129 | bv.resize(nPixels); 130 | bu.setZero(); 131 | bv.setZero(); 132 | cv::Mat yuvVisualClues; 133 | cv::cvtColor(visualClues, yuvVisualClues, cv::COLOR_BGR2YUV); 134 | yuvVisualClues.convertTo(yuvVisualClues, CV_64FC3); 135 | std::vector channels; 136 | cv::split(yuvVisualClues, channels); 137 | cv::Mat& U = channels[1]; 138 | cv::Mat& V = channels[2]; 139 | std::vector y, u, v; 140 | std::vector hasColor; 141 | to1D(Y, y); 142 | to1D(U, u); 143 | to1D(V, v); 144 | to1D(mask, hasColor); 145 | const int numNeighbors = 8; 146 | std::vector weights; 147 | weights.reserve(numNeighbors); 148 | std::vector neighbors; 149 |
neighbors.reserve(numNeighbors); 150 | std::cout << "Finding Neighbours..." << std::endl; 151 | tqdm progressBar; 152 | for (auto i = 0; i < nRows; ++i) { 153 | progressBar.progress(i, nRows); 154 | for (auto j = 0; j < nCols; ++j) { 155 | unsigned long r = i * nCols + j; 156 | getNeighbours(i, j, nRows, nCols, neighbors); 157 | getWeights(y, r, neighbors, weights, gamma); 158 | coefficients.emplace_back(r, r, 1); 159 | for (auto k = 0u; k < neighbors.size(); ++k) { 160 | auto s = neighbors[k]; 161 | auto w = weights[k]; 162 | if (hasColor[s]) { 163 | // Move value to RHS of Ax = b 164 | bu(r) += w * u[s]; 165 | bv(r) += w * v[s]; 166 | } else { 167 | coefficients.emplace_back(r, s, -w); 168 | } 169 | } 170 | } 171 | } 172 | progressBar.finish(); 173 | A.setFromTriplets(coefficients.begin(), coefficients.end()); 174 | } 175 | 176 | 177 | inline cv::Mat colorize( 178 | const cv::Mat& image, const cv::Mat& visualClues, 179 | const cv::Mat& mask, double gamma=2.0) { 180 | cv::Mat yuvImage; 181 | cv::cvtColor(image, yuvImage, cv::COLOR_BGR2YUV); 182 | yuvImage.convertTo(yuvImage, CV_64FC3); 183 | std::vector channels; 184 | cv::split(yuvImage, channels); 185 | cv::Mat Y = channels[0]; 186 | // Set up matrices for U and V channels 187 | Eigen::SparseMatrix A; 188 | Eigen::VectorXd bu; 189 | Eigen::VectorXd bv; 190 | setupProblem(Y, visualClues, mask, A, bu, bv, gamma); 191 | // Solve for U, V channels 192 | std::cout << "Solving for U channel..." << std::endl; 193 | Eigen::BiCGSTAB, 194 | Eigen::DiagonalPreconditioner > solver; 195 | solver.compute(A); 196 | Eigen::VectorXd U = solver.solve(bu); 197 | if (solver.info() != Eigen::Success) 198 | throw std::runtime_error("Failed to solve for U channel."); 199 | std::cout << "Solving for V channel..."
<< std::endl; 200 | Eigen::VectorXd V = solver.solve(bv); 201 | if (solver.info() != Eigen::Success) 202 | throw std::runtime_error("Failed to solve for V channel."); 203 | std::cout << "Finished coloring" << std::endl; 204 | const int nRows = Y.rows; 205 | const int nCols = Y.cols; 206 | cv::Mat mU = eigen2opencv(U, nRows, nCols); 207 | cv::Mat mV = eigen2opencv(V, nRows, nCols); 208 | channels[1] = mU; channels[2] = mV; 209 | cv::Mat colorImage; 210 | cv::merge(channels, colorImage); 211 | colorImage.convertTo(colorImage, CV_8UC3); 212 | cv::cvtColor(colorImage, colorImage, cv::COLOR_YUV2BGR); 213 | return colorImage; 214 | } 215 | } 216 | 217 | 218 | #endif //COLORIZATION_COLORIZATION_H 219 | -------------------------------------------------------------------------------- /colorization/iterative_colorizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .iterative_colorizer import IterativeColorizer 2 | -------------------------------------------------------------------------------- /colorization/iterative_colorizer/iterative_colorizer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import scipy 3 | import colorsys 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from typing import Tuple 8 | from skimage.io import imread 9 | from matplotlib import pyplot as plt 10 | from scipy.sparse.linalg import spsolve 11 | 12 | from .window_neighbour import WindowNeighbor 13 | from .utils import affinity_a, to_seq 14 | 15 | 16 | class IterativeColorizer: 17 | 18 | def __init__(self, original_image: str, visual_clues: str) -> None: 19 | self.image_oiginal_rgb = cv2.cvtColor(cv2.imread(original_image), cv2.COLOR_BGR2RGB) 20 | self.image_original = self.image_oiginal_rgb.astype(float) / 255 21 | self.image_clues_rgb = cv2.cvtColor(cv2.imread(visual_clues), cv2.COLOR_BGR2RGB) 22 | self.image_clues = self.image_clues_rgb.astype(float) / 255
23 | self.result_history = [] 24 | 25 | def plot_inputs(self, figure_size: Tuple[int, int] = (12, 12)) -> None: 26 | figure = plt.figure(figsize=figure_size) 27 | figure.add_subplot(1, 2, 1).set_title('Black & White') 28 | plt.imshow(self.image_original) 29 | plt.axis('off') 30 | figure.add_subplot(1, 2, 2).set_title('Color Hints') 31 | plt.imshow(self.image_clues) 32 | plt.axis('off') 33 | plt.show() 34 | 35 | def plot_results(self, log_interval: int = 100) -> None: 36 | index = log_interval 37 | for result in self.result_history: # show every logged iteration, including the final one 38 | plt.imshow(result) 39 | plt.title('Result of Iteration: {}'.format(index)) 40 | plt.axis('off') 41 | plt.show() 42 | index += log_interval 43 | 44 | def yuv_channels_to_rgb(self, channel_y, channel_u, channel_v) -> np.ndarray: 45 | """Combine 3 channels of YUV to a RGB photo: n x n x 3 array""" 46 | result_rgb = [colorsys.yiq_to_rgb( 47 | channel_y[i], channel_u[i], channel_v[i] 48 | ) for i in range(len(self.result_y))] 49 | result_rgb = np.array(result_rgb) 50 | image_rgb = np.zeros(self.image_yuv.shape) 51 | image_rgb[:, :, 0] = result_rgb[:, 0].reshape(self.image_rows, self.image_cols, order='F') 52 | image_rgb[:, :, 1] = result_rgb[:, 1].reshape(self.image_rows, self.image_cols, order='F') 53 | image_rgb[:, :, 2] = result_rgb[:, 2].reshape(self.image_rows, self.image_cols, order='F') 54 | return image_rgb 55 | 56 | def jacobi(self, weight_matrix, b_u, b_v, epoch: int, interval: int) -> None: 57 | D_u = weight_matrix.diagonal() 58 | D_v = weight_matrix.diagonal() 59 | R_u = weight_matrix - scipy.sparse.diags(D_u) 60 | R_v = weight_matrix - scipy.sparse.diags(D_v) 61 | x_u = np.zeros(weight_matrix.shape[0]) 62 | x_v = np.zeros(weight_matrix.shape[0]) 63 | print('Optimizing iteratively...') 64 | for epoch in tqdm(range(1, epoch + 1)): 65 | x_u = (b_u - R_u.dot(x_u)) / D_u 66 | x_v = (b_v - R_v.dot(x_v)) / D_v 67 | if epoch % interval == 0: 68 | self.result_history.append( 69 | self.yuv_channels_to_rgb(self.result_y, x_u,
x_v)) 70 | 71 | 72 | def colorize(self, epochs: int = 500, log_interval: int = 100) -> None: 73 | (self.image_rows, self.image_cols, _) = self.image_original.shape 74 | image_size = self.image_rows * self.image_cols 75 | channel_Y, _, _ = colorsys.rgb_to_yiq( 76 | self.image_original[:, :, 0], 77 | self.image_original[:, :, 1], 78 | self.image_original[:, :, 2] 79 | ) 80 | _, channel_U, channel_V = colorsys.rgb_to_yiq( 81 | self.image_clues[:, :, 0], 82 | self.image_clues[:, :, 1], 83 | self.image_clues[:, :, 2] 84 | ) 85 | map_colored = (abs(channel_U) + abs(channel_V)) > 0.0001 86 | self.image_yuv = np.dstack((channel_Y, channel_U, channel_V)) 87 | weight_data = [] 88 | wd_width = 1 89 | print('Finding neighbouring pixels...') 90 | for c in tqdm(range(self.image_cols)): 91 | for r in range(self.image_rows): 92 | window_neighbour = WindowNeighbor(wd_width, (r, c), self.image_yuv) 93 | if not map_colored[r,c]: 94 | weights = affinity_a(window_neighbour) 95 | for e in weights: 96 | weight_data.append([window_neighbour.center, (e[0], e[1]), e[2]]) 97 | weight_data.append([ 98 | window_neighbour.center, (window_neighbour.center[0], window_neighbour.center[1]), 1.]) 99 | sparse_index_data = [ 100 | [ 101 | to_seq(e[0][0], e[0][1], self.image_rows), 102 | to_seq(e[1][0], e[1][1], self.image_rows), e[2] 103 | ] for e in weight_data 104 | ] 105 | sparse_index_row_col = np.array(sparse_index_data, dtype=np.int64)[:, 0:2] # np.integer is an abstract type; using it as a dtype is deprecated in NumPy 106 | sparse_data = np.array(sparse_index_data, dtype=np.float64)[:, 2] 107 | weight_matrix = scipy.sparse.csr_matrix( 108 | (sparse_data, (sparse_index_row_col[:,0], sparse_index_row_col[:,1])), 109 | shape=(image_size, image_size) 110 | ) 111 | b_u = np.zeros(image_size) 112 | b_v = np.zeros(image_size) 113 | idx_colored = np.nonzero(map_colored.reshape(image_size, order='F')) 114 | pic_u_flat = self.image_yuv[:,:,1].reshape(image_size, order='F') 115 | b_u[idx_colored] = pic_u_flat[idx_colored] 116 | pic_v_flat =
self.image_yuv[:,:,2].reshape(image_size, order='F') 117 | b_v[idx_colored] = pic_v_flat[idx_colored] 118 | self.result_y = self.image_yuv[:, :, 0].reshape(image_size, order='F') 119 | self.jacobi( 120 | weight_matrix, b_u, b_v, 121 | epoch=epochs, interval=log_interval 122 | ) 123 | -------------------------------------------------------------------------------- /colorization/iterative_colorizer/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def to_seq(r, c, rows): 5 | return c * rows + r 6 | 7 | 8 | def fr_seq(seq, rows): 9 | r = seq % rows 10 | c = int((seq - r) / rows) 11 | return (r, c) 12 | 13 | 14 | def affinity_a(w): 15 | """affinity functions, calculate weights of 16 | pixels in a window by their intensity""" 17 | nbs = np.array(w.neighbors) 18 | sY = nbs[:,2] 19 | cY = w.center[2] 20 | diff = sY - cY 21 | sig = np.var(np.append(sY, cY)) 22 | if sig < 1e-6: 23 | sig = 1e-6 24 | wrs = np.exp(- np.power(diff,2) / (sig * 2.0)) 25 | wrs = - wrs / np.sum(wrs) 26 | nbs[:,2] = wrs 27 | return nbs 28 | -------------------------------------------------------------------------------- /colorization/iterative_colorizer/window_neighbour.py: -------------------------------------------------------------------------------- 1 | class WindowNeighbor: 2 | """The window class for finding the 3 | neighbor pixels around the center""" 4 | 5 | def __init__(self, width, center, image): 6 | # center is a list of [row, col, Y_intensity] 7 | self.center = [center[0], center[1], image[center][0]] 8 | self.width = width 9 | self.neighbors = None 10 | self.find_neighbors(image) 11 | self.mean = None 12 | self.var = None 13 | 14 | def find_neighbors(self, image): 15 | self.neighbors = [] 16 | ix_r_min = max(0, self.center[0] - self.width) 17 | ix_r_max = min(image.shape[0], self.center[0] + self.width + 1) 18 | ix_c_min = max(0, self.center[1] - self.width) 19 | ix_c_max = min(image.shape[1], self.center[1] +
self.width + 1) 20 | for r in range(ix_r_min, ix_r_max): 21 | for c in range(ix_c_min, ix_c_max): 22 | if r == self.center[0] and c == self.center[1]: 23 | continue 24 | self.neighbors.append([r, c, image[r, c, 0]]) 25 | 26 | def __str__(self): 27 | return 'windows c=(%d, %d, %f) size: %d' % ( 28 | self.center[0], self.center[1], 29 | self.center[2], len(self.neighbors) 30 | ) 31 | -------------------------------------------------------------------------------- /colorization/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "colorization.h" 10 | 11 | 12 | int main(int argc, char* argv[]) { 13 | 14 | if (argc < 4) { 15 | std::cerr << argv[0] << " [input-image] [visual-clues] [result] [gamma] [threshold]" << std::endl; 16 | return 1; 17 | } 18 | 19 | double gamma = 2.0; 20 | if (argc >= 5) 21 | gamma = std::stod(argv[4]); 22 | 23 | int threshold = 10; 24 | if (argc >= 6) 25 | threshold = std::stoi(argv[5]); 26 | 27 | std::string inputImagePath{argv[1]}; 28 | std::string visualCluesPath{argv[2]}; 29 | std::string resultPath{argv[3]}; 30 | 31 | cv::Mat image = cv::imread(inputImagePath); 32 | cv::Mat visual_clues = cv::imread(visualCluesPath); 33 | 34 | if (image.empty()) { 35 | std::cerr << "Failed to read file from " << inputImagePath << std::endl; 36 | return 1; 37 | } 38 | 39 | if (visual_clues.empty()) { 40 | std::cerr << "Failed to read file from " << visualCluesPath << std::endl; 41 | return 1; 42 | } 43 | 44 | if (image.size() != visual_clues.size()) { std::cerr << "Input image and visual clues must have the same size" << std::endl; return 1; } // runtime check: assert() is compiled out under NDEBUG 45 | cv::Mat mask = colorization::getVisualClueMask(image, visual_clues, threshold); 46 | cv::Mat colorImage = colorization::colorize(image, visual_clues, mask, gamma); 47 | if (!cv::imwrite(resultPath, colorImage)) { std::cerr << "Failed to write file to " << resultPath << std::endl; return 1; } 48 | return 0; 49 | } 50 | -------------------------------------------------------------------------------- /colorize.py:
-------------------------------------------------------------------------------- 1 | import cv2 2 | import click 3 | 4 | from colorization import Colorizer, IterativeColorizer 5 | 6 | 7 | @click.command() 8 | @click.option('--original_image', help='Original Image Path') 9 | @click.option('--visual_clue', help='Visual Clue Image Path') 10 | @click.option('--result_path', default='./result', help='Colorized Image Path (without file extensions)') 11 | @click.option('--use_itercative', '-i', is_flag=True, help='Use Iterative Mode') 12 | @click.option('--epochs', default=500, help='Number of epochs for Iterative Mode') 13 | @click.option('--log_intervals', default=100, help='Log Interval') 14 | def colorize(original_image, visual_clue, result_path, use_itercative, epochs, log_intervals): 15 | if use_itercative: 16 | colorizer = Colorizer( 17 | gray_image_file=original_image, 18 | visual_clues_file=visual_clue 19 | ) 20 | colorizer.plot_inputs() 21 | result = colorizer.colorize() 22 | colorizer.plot_results(result) 23 | cv2.imwrite( 24 | result_path + '.png', 25 | cv2.cvtColor(result, cv2.COLOR_RGB2BGR) 26 | ) 27 | else: 28 | colorizer = IterativeColorizer( 29 | original_image=original_image, 30 | visual_clues=visual_clue 31 | ) 32 | colorizer.plot_inputs() 33 | colorizer.colorize( 34 | epochs=epochs, log_interval=log_intervals 35 | ) 36 | colorizer.plot_results(log_intervals=log_intervals) 37 | for i, result in enumerate(colorizer.result_history): 38 | cv2.imwrite( 39 | result_path + '{}.png'.format(i + 1), 40 | cv2.cvtColor(result, cv2.COLOR_RGB2BGR) 41 | ) 42 | 43 | 44 | if __name__ == '__main__': 45 | colorize() 46 | -------------------------------------------------------------------------------- /data/original/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/original/example.png 
-------------------------------------------------------------------------------- /data/original/example2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/original/example2.png -------------------------------------------------------------------------------- /data/original/example3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/original/example3.png -------------------------------------------------------------------------------- /data/original/example4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/original/example4.png -------------------------------------------------------------------------------- /data/original/example5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/original/example5.png -------------------------------------------------------------------------------- /data/original/example6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/original/example6.png -------------------------------------------------------------------------------- /data/original/example7.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/original/example7.png -------------------------------------------------------------------------------- /data/original/example8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/original/example8.png -------------------------------------------------------------------------------- /data/results/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/results/result.png -------------------------------------------------------------------------------- /data/results/result2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/results/result2.png -------------------------------------------------------------------------------- /data/results/result3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/results/result3.png -------------------------------------------------------------------------------- /data/results/result4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/results/result4.png -------------------------------------------------------------------------------- /data/results/result5.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/results/result5.png -------------------------------------------------------------------------------- /data/results/result6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/results/result6.png -------------------------------------------------------------------------------- /data/results/result7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/results/result7.png -------------------------------------------------------------------------------- /data/results/result8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/results/result8.png -------------------------------------------------------------------------------- /data/visual-clues/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/visual-clues/example.png -------------------------------------------------------------------------------- /data/visual-clues/example2_marked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/visual-clues/example2_marked.png 
-------------------------------------------------------------------------------- /data/visual-clues/example3_marked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/visual-clues/example3_marked.png -------------------------------------------------------------------------------- /data/visual-clues/example4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/visual-clues/example4.png -------------------------------------------------------------------------------- /data/visual-clues/example5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/visual-clues/example5.png -------------------------------------------------------------------------------- /data/visual-clues/example6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/visual-clues/example6.png -------------------------------------------------------------------------------- /data/visual-clues/example7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/visual-clues/example7.png -------------------------------------------------------------------------------- /data/visual-clues/example8.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/soumik12345/colorization-using-optimization/85a38e19810092b3bb630c3485f040a1a39a647d/data/visual-clues/example8.png -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | # shellcheck disable=SC2164 2 | # Abort on the first failing command (including a failed cd) so later steps never run on a failed download or build. 3 | set -e 4 | 5 | wget https://gitlab.com/libeigen/eigen/-/archive/3.4-rc1/eigen-3.4-rc1.zip 6 | unzip eigen-3.4-rc1.zip 7 | rm eigen-3.4-rc1.zip 8 | mkdir -p eigen-3.4-rc1/build 9 | cd eigen-3.4-rc1/build 10 | cmake .. 11 | make install 12 | cd ../../ 13 | 14 | wget -O opencv.zip https://github.com/opencv/opencv/archive/master.zip 15 | unzip opencv.zip 16 | rm opencv.zip 17 | mkdir -p opencv-master/build 18 | cd opencv-master/build 19 | cmake .. 20 | make install 21 | cd ../../ 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python==4.5.3.56 2 | Pillow==8.3.1 3 | scikit-image==0.18.2 4 | tqdm==4.62.1 5 | scipy==1.7.1 6 | matplotlib==3.4.3 7 | pytest==6.2.4 8 | -------------------------------------------------------------------------------- /tests.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from colorization import Colorizer, IterativeColorizer 4 | 5 | 6 | class TestColorizer(TestCase): 7 | 8 | def test_output_shape(self): 9 | colorizer = Colorizer( 10 | gray_image_file='./data/original/example.png', 11 | visual_clues_file='./data/visual-clues/example.png' 12 | ) 13 | result = colorizer.colorize() 14 | assert colorizer.original_gray_image.shape[0] == result.shape[0] 15 | assert colorizer.original_gray_image.shape[1] == result.shape[1] 16 | 17 | 18 | class TestIterativeColorizer(TestCase): 19 | 20 | def test_output_shape(self): 21 | colorizer = IterativeColorizer( 22 |
original_image='./data/original/example.png', 23 | visual_clues='./data/visual-clues/example.png' 24 | ) 25 | colorizer.colorize(epochs=600, log_interval=100) 26 | for result in colorizer.result_history: 27 | assert colorizer.image_original.shape[0] == result.shape[0] 28 | assert colorizer.image_original.shape[1] == result.shape[1] 29 | --------------------------------------------------------------------------------