├── LICENSE
├── README.md
├── colorizers
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   ├── base_color.cpython-37.pyc
    │   ├── eccv16.cpython-37.pyc
    │   ├── siggraph17.cpython-37.pyc
    │   └── util.cpython-37.pyc
    ├── base_color.py
    ├── eccv16.py
    ├── siggraph17.py
    └── util.py
├── demo_release.py
├── imgs
    ├── .DS_Store
    ├── ILSVRC2012_val_00041580.JPEG
    ├── ILSVRC2012_val_00046524.JPEG
    ├── ILSVRC2012_val_00046834.JPEG
    ├── ansel_adams.jpg
    ├── ansel_adams2.jpg
    └── ansel_adams3.jpg
├── imgs_out
    ├── saved_eccv16.png
    └── saved_siggraph17.png
└── requirements.txt


/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2016, Richard Zhang, Phillip Isola, Alexei A. Efros
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

## Colorful Image Colorization [[Project Page]](http://richzhang.github.io/colorization/)
[Richard Zhang](https://richzhang.github.io/), [Phillip Isola](http://web.mit.edu/phillipi/), [Alexei A. Efros](http://www.eecs.berkeley.edu/~efros/). In [ECCV, 2016](http://arxiv.org/pdf/1603.08511.pdf).

**+ automatic colorization functionality for Real-Time User-Guided Image Colorization with Learned Deep Priors, SIGGRAPH 2017!**

**[Sept20 Update]** Since it has been 3-4 years, I converted this repo to support minimal test-time usage in PyTorch. I also added our SIGGRAPH 2017 model (an interactive method that can also colorize fully automatically). See the [Caffe branch](https://github.com/richzhang/colorization/tree/caffe) for the original release.

![Teaser Image](http://richzhang.github.io/colorization/resources/images/teaser4.jpg)

**Clone the repository; install dependencies**

```
git clone https://github.com/richzhang/colorization.git
cd colorization
pip install -r requirements.txt
```

**Colorize!** This script colorizes an image. The results should match the images in the `imgs_out` folder.

```
python demo_release.py -i imgs/ansel_adams3.jpg
```

**Model loading in Python** The following loads the pretrained colorizers. See [demo_release.py](demo_release.py) for details on how to run the models. They need some pre- and post-processing: resize the image to 256x256 and feed the L channel of its Lab conversion to the network; then upsample the predicted ab channels to the original resolution, concatenate them with the original-resolution L channel, and convert the result back to RGB.

```python
import colorizers
colorizer_eccv16 = colorizers.eccv16().eval()
colorizer_siggraph17 = colorizers.siggraph17().eval()
```

### Original implementation (Caffe branch)

The original implementation contained training and testing code for our network, AlexNet (for the representation learning tests), and the representation learning tests themselves. It is written in Caffe and is no longer supported. Please see the [caffe](https://github.com/richzhang/colorization/tree/caffe) branch for it.
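The pre- and post-processing described under **Model loading in Python** can be sketched with NumPy alone. This is a simplified, hypothetical stand-in for `preprocess_img` / `postprocess_tens` in `colorizers/util.py`: `fake_colorizer` replaces the network, and nearest-neighbour upsampling replaces the bilinear `F.interpolate` call the repo uses, so only the shapes and the channel bookkeeping are meant to match.

```python
import numpy as np


def upsample_nearest(ab, out_hw):
    """Nearest-neighbour resize of a (2, h, w) ab map to (2, H, W).

    Simplification: the repo uses bilinear interpolation instead.
    """
    _, h, w = ab.shape
    H, W = out_hw
    rows = np.arange(H) * h // H  # source row for each output row
    cols = np.arange(W) * w // W  # source column for each output column
    return ab[:, rows[:, None], cols[None, :]]


def fake_colorizer(l_small):
    """Hypothetical stand-in for a colorizer: returns a constant ab map."""
    h, w = l_small.shape
    ab = np.zeros((2, h, w))
    ab[0] = 20.0   # a channel (green-red axis)
    ab[1] = -10.0  # b channel (blue-yellow axis)
    return ab


def colorize_lab(l_orig, l_small):
    """Combine full-resolution L with upsampled low-resolution ab.

    l_orig:  (H, W) L channel at the original resolution, values in [0, 100]
    l_small: (h, w) L channel of the resized image (256 x 256 in the demo)
    returns: (H, W, 3) Lab image, ready for skimage.color.lab2rgb
    """
    ab_small = fake_colorizer(l_small)                  # (2, h, w)
    ab_full = upsample_nearest(ab_small, l_orig.shape)  # (2, H, W)
    return np.concatenate([l_orig[None], ab_full], axis=0).transpose(1, 2, 0)


l_orig = np.random.default_rng(0).uniform(0, 100, size=(300, 450))
l_small = l_orig[::2, ::2]  # crude downsample, only for this sketch
lab = colorize_lab(l_orig, l_small)
print(lab.shape)  # (300, 450, 3)
```

The L channel that reaches the final image is the original-resolution one; only the ab prediction is computed at low resolution and upsampled.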

### Citation ###

If you find these models useful for your research, please cite them with these BibTeX entries.

```
@inproceedings{zhang2016colorful,
  title={Colorful Image Colorization},
  author={Zhang, Richard and Isola, Phillip and Efros, Alexei A},
  booktitle={ECCV},
  year={2016}
}

@article{zhang2017real,
  title={Real-Time User-Guided Image Colorization with Learned Deep Priors},
  author={Zhang, Richard and Zhu, Jun-Yan and Isola, Phillip and Geng, Xinyang and Lin, Angela S and Yu, Tianhe and Efros, Alexei A},
  journal={ACM Transactions on Graphics (TOG)},
  volume={36},
  number={4},
  year={2017},
  publisher={ACM}
}
```

### Misc ###
Contact Richard Zhang at rich.zhang at eecs.berkeley.edu for any questions or comments.

--------------------------------------------------------------------------------
/colorizers/__init__.py:
--------------------------------------------------------------------------------

from .base_color import *
from .eccv16 import *
from .siggraph17 import *
from .util import *

--------------------------------------------------------------------------------
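`colorizers/__init__.py` re-exports the public names of `base_color.py`, including `BaseColor`, the parent class of both generators. Its three constants (`l_cent = 50`, `l_norm = 100`, `ab_norm = 110`) fix how Lab values are scaled on the way into and out of the networks. The plain-Python sketch below reproduces that arithmetic; the module-level function names are hypothetical mirrors of the `BaseColor` methods, not part of the package.

```python
# Constants copied from BaseColor.__init__ in colorizers/base_color.py.
L_CENT = 50.0
L_NORM = 100.0
AB_NORM = 110.0


def normalize_l(l):
    """Map L in [0, 100] to [-0.5, 0.5]."""
    return (l - L_CENT) / L_NORM


def unnormalize_l(l_n):
    """Inverse of normalize_l."""
    return l_n * L_NORM + L_CENT


def normalize_ab(ab):
    """Scale a or b so that +/-110 maps to +/-1."""
    return ab / AB_NORM


def unnormalize_ab(ab_n):
    """Inverse of normalize_ab."""
    return ab_n * AB_NORM


for l in (0.0, 50.0, 100.0):
    print(l, "->", normalize_l(l))
# 0.0 -> -0.5
# 50.0 -> 0.0
# 100.0 -> 0.5

# SIGGRAPHGenerator ends in Tanh, whose output lies in [-1, 1], so its
# unnormalized ab predictions are bounded by +/-110.
print(unnormalize_ab(1.0))  # 110.0
```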
/colorizers/base_color.py:
--------------------------------------------------------------------------------

import torch
from torch import nn

class BaseColor(nn.Module):
    def __init__(self):
        super(BaseColor, self).__init__()

        # Lab L lies in [0, 100]; normalize_l maps it to [-0.5, 0.5]
        self.l_cent = 50.
        self.l_norm = 100.
        # a and b stay within about +/-110 for sRGB colors; normalize_ab maps them to roughly [-1, 1]
        self.ab_norm = 110.

    def normalize_l(self, in_l):
        return (in_l-self.l_cent)/self.l_norm

    def unnormalize_l(self, in_l):
        return in_l*self.l_norm + self.l_cent

    def normalize_ab(self, in_ab):
        return in_ab/self.ab_norm

    def unnormalize_ab(self, in_ab):
        return in_ab*self.ab_norm

--------------------------------------------------------------------------------
/colorizers/eccv16.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn
import numpy as np

from .base_color import *

class ECCVGenerator(BaseColor):
    def __init__(self, norm_layer=nn.BatchNorm2d):
        super(ECCVGenerator, self).__init__()

        # conv1: 1 -> 64 channels; the stride-2 conv halves the resolution (1/2)
        model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),]
        model1+=[nn.ReLU(True),]
        model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),]
        model1+=[nn.ReLU(True),]
        model1+=[norm_layer(64),]

        # conv2: 64 -> 128 channels, 1/4 resolution
        model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
        model2+=[nn.ReLU(True),]
        model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),]
        model2+=[nn.ReLU(True),]
        model2+=[norm_layer(128),]

        # conv3: 128 -> 256 channels, 1/8 resolution
        model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[norm_layer(256),]

        # conv4: 256 -> 512 channels, resolution unchanged
        model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[norm_layer(512),]

        # conv5: dilated 3x3 convs (dilation=2, padding=2) keep the 1/8 resolution
        model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[norm_layer(512),]

        # conv6: same dilated structure as conv5
        model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[norm_layer(512),]

        # conv7: plain 3x3 convs, still 1/8 resolution
        model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[norm_layer(512),]

        # conv8: the transposed conv upsamples back to 1/4 resolution
        model8=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]
        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]
        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]

        # 1x1 conv to 313 logits, one per quantized ab bin
        model8+=[nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),]

        self.model1 = nn.Sequential(*model1)
        self.model2 = nn.Sequential(*model2)
        self.model3 = nn.Sequential(*model3)
        self.model4 = nn.Sequential(*model4)
        self.model5 = nn.Sequential(*model5)
        self.model6 = nn.Sequential(*model6)
        self.model7 = nn.Sequential(*model7)
        self.model8 = nn.Sequential(*model8)

        self.softmax = nn.Softmax(dim=1)
        # bias-free 1x1 conv: each pixel's (a, b) is a probability-weighted sum over the 313 bins
        self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
        self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')

    def forward(self, input_l):
        conv1_2 = self.model1(self.normalize_l(input_l))
        conv2_2 = self.model2(conv1_2)
        conv3_3 = self.model3(conv2_2)
        conv4_3 = self.model4(conv3_3)
        conv5_3 = self.model5(conv4_3)
        conv6_3 = self.model6(conv5_3)
        conv7_3 = self.model7(conv6_3)
        conv8_3 = self.model8(conv7_3)
        out_reg = self.model_out(self.softmax(conv8_3))

        # ab is predicted at 1/4 resolution: upsample x4, then undo the ab normalization
        return self.unnormalize_ab(self.upsample4(out_reg))

def eccv16(pretrained=True):
    model = ECCVGenerator()
    if(pretrained):
        import torch.utils.model_zoo as model_zoo
        model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth',map_location='cpu',check_hash=True))
    return model

--------------------------------------------------------------------------------
/colorizers/siggraph17.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from .base_color import *

class SIGGRAPHGenerator(BaseColor):
    def __init__(self, norm_layer=nn.BatchNorm2d, classes=529):
        super(SIGGRAPHGenerator, self).__init__()

        # Conv1: 4 input channels (L, a and b hints, hint mask)
        model1=[nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True),]
        model1+=[nn.ReLU(True),]
        model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),]
        model1+=[nn.ReLU(True),]
        model1+=[norm_layer(64),]
        # forward() subsamples this output by 2 with [:,:,::2,::2]

        # Conv2
        model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
        model2+=[nn.ReLU(True),]
        model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
        model2+=[nn.ReLU(True),]
        model2+=[norm_layer(128),]
        # forward() subsamples this output by 2 with [:,:,::2,::2]

        # Conv3
        model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[norm_layer(256),]
        # forward() subsamples this output by 2 with [:,:,::2,::2]

        # Conv4
        model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[norm_layer(512),]

        # Conv5
        model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[norm_layer(512),]

        # Conv6
        model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[norm_layer(512),]

        # Conv7
        model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[norm_layer(512),]

        # Conv8
        model8up=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
        model3short8=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]

        model8=[nn.ReLU(True),]
        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]
        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]
        model8+=[norm_layer(256),]

        # Conv9
        model9up=[nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),]
        model2short9=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
        # add the two feature maps above

        model9=[nn.ReLU(True),]
        model9+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
        model9+=[nn.ReLU(True),]
        model9+=[norm_layer(128),]

        # Conv10
        model10up=[nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),]
        model1short10=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
        # add the two feature maps above

        model10=[nn.ReLU(True),]
        model10+=[nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True),]
        model10+=[nn.LeakyReLU(negative_slope=.2),]

        # classification output (not used by forward() in this port)
        model_class=[nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]

        # regression output; Tanh bounds the normalized ab to [-1, 1]
        model_out=[nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
        model_out+=[nn.Tanh()]

        self.model1 = nn.Sequential(*model1)
        self.model2 = nn.Sequential(*model2)
        self.model3 = nn.Sequential(*model3)
        self.model4 = nn.Sequential(*model4)
        self.model5 = nn.Sequential(*model5)
        self.model6 = nn.Sequential(*model6)
        self.model7 = nn.Sequential(*model7)
        self.model8up = nn.Sequential(*model8up)
        self.model8 = nn.Sequential(*model8)
        self.model9up = nn.Sequential(*model9up)
        self.model9 = nn.Sequential(*model9)
        self.model10up = nn.Sequential(*model10up)
        self.model10 = nn.Sequential(*model10)
        self.model3short8 = nn.Sequential(*model3short8)
        self.model2short9 = nn.Sequential(*model2short9)
        self.model1short10 = nn.Sequential(*model1short10)

        self.model_class = nn.Sequential(*model_class)
        self.model_out = nn.Sequential(*model_out)

        self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='bilinear'),])
        self.softmax = nn.Sequential(*[nn.Softmax(dim=1),])

    def forward(self, input_A, input_B=None, mask_B=None):
        # automatic mode: with no user hints, the ab hints and the hint mask are all zeros
        if(input_B is None):
            input_B = torch.cat((input_A*0, input_A*0), dim=1)
        if(mask_B is None):
            mask_B = input_A*0

        conv1_2 = self.model1(torch.cat((self.normalize_l(input_A),self.normalize_ab(input_B),mask_B),dim=1))
        conv2_2 = self.model2(conv1_2[:,:,::2,::2])
        conv3_3 = self.model3(conv2_2[:,:,::2,::2])
        conv4_3 = self.model4(conv3_3[:,:,::2,::2])
        conv5_3 = self.model5(conv4_3)
        conv6_3 = self.model6(conv5_3)
        conv7_3 = self.model7(conv6_3)

        # decoder: upsample x2 and add a shortcut from the matching encoder stage
        conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
        conv8_3 = self.model8(conv8_up)
        conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
        conv9_3 = self.model9(conv9_up)
        conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
        conv10_2 = self.model10(conv10_up)
        out_reg = self.model_out(conv10_2)

        return self.unnormalize_ab(out_reg)

def siggraph17(pretrained=True):
    model = SIGGRAPHGenerator()
    if(pretrained):
        import torch.utils.model_zoo as model_zoo
        model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',map_location='cpu',check_hash=True))
    return model


--------------------------------------------------------------------------------
/colorizers/util.py:
--------------------------------------------------------------------------------

from PIL import Image
import numpy as np
from skimage import color
import torch
import torch.nn.functional as F

def load_img(img_path):
    out_np = np.asarray(Image.open(img_path))
    # replicate grayscale images to 3 channels
    if(out_np.ndim==2):
        out_np = np.tile(out_np[:,:,None],3)
    return out_np

def resize_img(img, HW=(256,256), resample=3):
    # PIL expects (width, height); resample=3 is PIL's bicubic filter
    return np.asarray(Image.fromarray(img).resize((HW[1],HW[0]), resample=resample))

def preprocess_img(img_rgb_orig, HW=(256,256), resample=3):
    # return original size L and resized L as torch Tensors
    img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)

    img_lab_orig = color.rgb2lab(img_rgb_orig)
    img_lab_rs = color.rgb2lab(img_rgb_rs)

    img_l_orig = img_lab_orig[:,:,0]
    img_l_rs = img_lab_rs[:,:,0]

    tens_orig_l = torch.Tensor(img_l_orig)[None,None,:,:]
    tens_rs_l = torch.Tensor(img_l_rs)[None,None,:,:]

    return (tens_orig_l, tens_rs_l)

def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
    # tens_orig_l 1 x 1 x H_orig x W_orig
    # out_ab 1 x 2 x H x W

    HW_orig = tens_orig_l.shape[2:]
    HW = out_ab.shape[2:]

    # resize ab to the original resolution if needed
    if(HW_orig[0]!=HW[0] or HW_orig[1]!=HW[1]):
        out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode=mode)
    else:
        out_ab_orig = out_ab

    out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)
    return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0,...].transpose((1,2,0)))

--------------------------------------------------------------------------------
/demo_release.py:
--------------------------------------------------------------------------------

import argparse
import matplotlib.pyplot as plt
import torch

from colorizers import *

parser = argparse.ArgumentParser()
parser.add_argument('-i','--img_path', type=str, default='imgs/ansel_adams3.jpg')
parser.add_argument('--use_gpu', action='store_true', help='whether to use GPU')
parser.add_argument('-o','--save_prefix', type=str, default='saved', help='outputs are saved as <prefix>_eccv16.png and <prefix>_siggraph17.png')
opt = parser.parse_args()

# load colorizers
colorizer_eccv16 = eccv16(pretrained=True).eval()
colorizer_siggraph17 = siggraph17(pretrained=True).eval()
if(opt.use_gpu):
    colorizer_eccv16.cuda()
    colorizer_siggraph17.cuda()

# default size to process images is 256x256
# grab L channel in both original ("orig") and resized ("rs") resolutions
img = load_img(opt.img_path)
(tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256,256))
if(opt.use_gpu):
    tens_l_rs = tens_l_rs.cuda()

# grayscale input for display: original L with zero ab
img_bw = postprocess_tens(tens_l_orig, torch.cat((0*tens_l_orig,0*tens_l_orig),dim=1))
# each colorizer outputs a 256x256 ab map; resize it and concatenate with the original L channel
out_img_eccv16 = postprocess_tens(tens_l_orig, colorizer_eccv16(tens_l_rs).cpu())
out_img_siggraph17 = postprocess_tens(tens_l_orig, colorizer_siggraph17(tens_l_rs).cpu())

plt.imsave('%s_eccv16.png'%opt.save_prefix, out_img_eccv16)
plt.imsave('%s_siggraph17.png'%opt.save_prefix, out_img_siggraph17)

plt.figure(figsize=(12,8))
plt.subplot(2,2,1)
plt.imshow(img)
plt.title('Original')
plt.axis('off')

plt.subplot(2,2,2)
plt.imshow(img_bw)
plt.title('Input')
plt.axis('off')

plt.subplot(2,2,3)
plt.imshow(out_img_eccv16)
plt.title('Output (ECCV 16)')
plt.axis('off')

plt.subplot(2,2,4)
plt.imshow(out_img_siggraph17)
plt.title('Output (SIGGRAPH 17)')
plt.axis('off')
plt.show()

--------------------------------------------------------------------------------
/imgs/ILSVRC2012_val_00041580.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/richzhang/colorization/4f6009ed1495b1300231ebeb41cc4015557ddef7/imgs/ILSVRC2012_val_00041580.JPEG
--------------------------------------------------------------------------------
/imgs/ILSVRC2012_val_00046524.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/richzhang/colorization/4f6009ed1495b1300231ebeb41cc4015557ddef7/imgs/ILSVRC2012_val_00046524.JPEG
--------------------------------------------------------------------------------
/imgs/ILSVRC2012_val_00046834.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/richzhang/colorization/4f6009ed1495b1300231ebeb41cc4015557ddef7/imgs/ILSVRC2012_val_00046834.JPEG
--------------------------------------------------------------------------------
/imgs/ansel_adams.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/richzhang/colorization/4f6009ed1495b1300231ebeb41cc4015557ddef7/imgs/ansel_adams.jpg
--------------------------------------------------------------------------------
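In `ECCVGenerator.forward`, the 313-channel logits from `model8` pass through `nn.Softmax(dim=1)` and then `model_out`, a bias-free 1x1 `Conv2d(313, 2)`. A 1x1 convolution without bias is a per-pixel matrix product, so each output (a, b) pair is a probability-weighted sum of 313 two-vectors held in the conv kernel. The NumPy sketch below checks that equivalence on random data; the sizes and weights are illustrative placeholders, not the pretrained values.

```python
import numpy as np

rng = np.random.default_rng(0)
Q, H, W = 313, 4, 5                  # ab bins and a small illustrative spatial size

logits = rng.normal(size=(Q, H, W))  # stand-in for conv8_3 (batch dimension dropped)
weight = rng.normal(size=(2, Q))     # stand-in for model_out's (2, 313, 1, 1) kernel

# nn.Softmax(dim=1) on an (N, C, H, W) tensor normalizes over channels.
probs = np.exp(logits - logits.max(axis=0, keepdims=True))
probs /= probs.sum(axis=0, keepdims=True)

# Bias-free 1x1 conv: out[c, y, x] = sum_q weight[c, q] * probs[q, y, x].
out_conv = np.einsum("cq,qyx->cyx", weight, probs)

# The same result written as a per-pixel expectation over the Q bins.
out_loop = np.empty((2, H, W))
for y in range(H):
    for x in range(W):
        p = probs[:, y, x]               # distribution over the Q bins
        out_loop[:, y, x] = weight @ p   # sum_q p_q * weight[:, q]

print(np.allclose(out_conv, out_loop))     # True
print(np.allclose(probs.sum(axis=0), 1.0)) # True
```

No explicit temperature appears in this `forward`, whereas the ECCV 2016 paper describes an annealed-mean decoding with T = 0.38; whether that factor is folded into the released weights cannot be read off this code.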
/imgs/ansel_adams2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/richzhang/colorization/4f6009ed1495b1300231ebeb41cc4015557ddef7/imgs/ansel_adams2.jpg
--------------------------------------------------------------------------------
/imgs/ansel_adams3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/richzhang/colorization/4f6009ed1495b1300231ebeb41cc4015557ddef7/imgs/ansel_adams3.jpg
--------------------------------------------------------------------------------
/imgs_out/saved_eccv16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/richzhang/colorization/4f6009ed1495b1300231ebeb41cc4015557ddef7/imgs_out/saved_eccv16.png
--------------------------------------------------------------------------------
/imgs_out/saved_siggraph17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/richzhang/colorization/4f6009ed1495b1300231ebeb41cc4015557ddef7/imgs_out/saved_siggraph17.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch
scikit-image
numpy
matplotlib
pillow
--------------------------------------------------------------------------------
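The resolution changes inside `ECCVGenerator` follow from PyTorch's documented output-size formulas for `Conv2d` and `ConvTranspose2d` (here with `output_padding=0`). This plain-Python sketch traces one spatial side of a 256x256 input through the layers that change it; `conv_out`, `deconv_out` and `trace_eccv16` are hypothetical helper names. It shows why `forward` ends with a x4 upsample: the 313-bin map comes out at 64x64.

```python
def conv_out(n, k, s=1, p=0, d=1):
    """Output size of nn.Conv2d along one spatial dimension."""
    return (n + 2 * p - d * (k - 1) - 1) // s + 1


def deconv_out(n, k, s=1, p=0, d=1, output_padding=0):
    """Output size of nn.ConvTranspose2d along one spatial dimension."""
    return (n - 1) * s - 2 * p + d * (k - 1) + output_padding + 1


def trace_eccv16(n=256):
    """Spatial size after each ECCVGenerator stage for an n x n input."""
    sizes = {"input": n}
    n = conv_out(n, k=3, s=2, p=1)      # model1: last conv has stride 2
    sizes["model1"] = n
    n = conv_out(n, k=3, s=2, p=1)      # model2
    sizes["model2"] = n
    n = conv_out(n, k=3, s=2, p=1)      # model3
    sizes["model3"] = n
    # model4/model7 (k=3, p=1) and model5/model6 (k=3, d=2, p=2) keep the size.
    assert conv_out(n, k=3, p=1) == n
    assert conv_out(n, k=3, p=2, d=2) == n
    sizes["model4-7"] = n
    n = deconv_out(n, k=4, s=2, p=1)    # model8: ConvTranspose2d
    sizes["model8"] = n
    sizes["upsample4"] = 4 * n          # nn.Upsample(scale_factor=4)
    return sizes


print(trace_eccv16())
# {'input': 256, 'model1': 128, 'model2': 64, 'model3': 32, 'model4-7': 32, 'model8': 64, 'upsample4': 256}
```

The demo always resizes to 256x256, so the output matches the input there. For other sizes it may not (a 250-pixel side comes back as 256), which is one reason `postprocess_tens` interpolates the ab map to the original resolution instead of assuming it already matches.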