├── .gitignore ├── color.png ├── gray.png ├── naive.png ├── images ├── 01.png ├── 02.png ├── 03.png ├── 04.png ├── 05.png ├── 06.png └── 07.png ├── README.md ├── log ├── dist.lua ├── model.lua ├── test.lua ├── opts.lua ├── main.lua └── color2gray.lua /.gitignore: -------------------------------------------------------------------------------- 1 | log -------------------------------------------------------------------------------- /color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangky11/CNN-Color2Gray/HEAD/color.png -------------------------------------------------------------------------------- /gray.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangky11/CNN-Color2Gray/HEAD/gray.png -------------------------------------------------------------------------------- /naive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangky11/CNN-Color2Gray/HEAD/naive.png -------------------------------------------------------------------------------- /images/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangky11/CNN-Color2Gray/HEAD/images/01.png -------------------------------------------------------------------------------- /images/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangky11/CNN-Color2Gray/HEAD/images/02.png -------------------------------------------------------------------------------- /images/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangky11/CNN-Color2Gray/HEAD/images/03.png -------------------------------------------------------------------------------- /images/04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangky11/CNN-Color2Gray/HEAD/images/04.png -------------------------------------------------------------------------------- /images/05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangky11/CNN-Color2Gray/HEAD/images/05.png -------------------------------------------------------------------------------- /images/06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangky11/CNN-Color2Gray/HEAD/images/06.png -------------------------------------------------------------------------------- /images/07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangky11/CNN-Color2Gray/HEAD/images/07.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CNN-Color2Gray 2 | 3 | 4 | ## Introduction 5 | 6 | This is an implementation of [Color2Gray](http://www.cs.northwestern.edu/~ago820/color2gray/) with convolutional neural networks. 7 | 8 | ## Example 9 | 10 | original image 11 | ![color](https://raw.githubusercontent.com/yangky11/CNN-Color2Gray/master/color.png) 12 | 13 | naive color transformation 14 | ![naive](https://raw.githubusercontent.com/yangky11/CNN-Color2Gray/master/naive.png) 15 | 16 | Color2Gray 17 | ![gray](https://raw.githubusercontent.com/yangky11/CNN-Color2Gray/master/gray.png) 18 | 19 | ## Usage 20 | 21 | qlua main.lua -h 22 | qlua main.lua -inp ./color.png 23 | 24 | You can also use `th` but then no image will be displayed. 25 | -------------------------------------------------------------------------------- /log: -------------------------------------------------------------------------------- 1 | [program started on Thu Dec 24 00:24:28 2015] 2 | [command line arguments] 3 | oup gray.png 4 | LR 0.1 5 | theta 45 6 | cuda false 7 | momentum 0.5 8 | inp images/07.png 9 | optimAlgo cg 10 | alpha 10 11 | miniter 10 12 | webcam false 13 | batchsize 1 14 | test false 15 | maxiter 2000 16 | [----------------------] 17 | image size: 223x311 18 | ------------------ 19 | 32.02677154541 20 | 5.3820643424988 21 | 4.1118364334106 22 | 3.1892559528351 23 | 2.8391959667206 24 | 2.7098784446716 25 | 2.7098784446716 26 | 2.7098784446716 27 | 2.7098784446716 28 | 2.7098784446716 29 | 30 | loss failed to decrease, stopped earlier 31 | final loss: 2.709878 32 | time elapsed: 1.614357sec 33 | the result saved to gray.png 34 | -------------------------------------------------------------------------------- /dist.lua: -------------------------------------------------------------------------------- 1 | require 'math' 2 | 3 | 4 | function dist(img) 5 | local deltaL = model:forward(img:narrow(1, 1, 1)):clone() 6 | local deltaA = model:forward(img:narrow(1, 2, 1)):clone() 7 | local deltaB = model:forward(img:narrow(1, 3, 1)):clone() 8 | if opts.test then 9 | paths.dofile('test.lua') 10 | test(img[1], deltaL) 11 | test(img[2], deltaA) 12 | test(img[3], deltaB) 13 | end 14 | local normDeltaC = torch.sqrt(torch.cmul(deltaA, deltaA) + torch.cmul(deltaB, deltaB)) 15 | local function crunch(x) return torch.tanh(x / opts.alpha) * opts.alpha end 16 | local delta = deltaL:clone() 17 | local case23 = torch.lt(torch.abs(deltaL), crunch(normDeltaC)) 18 | local sign = torch.sign(deltaA * torch.cos(opts.theta * math.pi / 180) + deltaB * torch.sin(opts.theta * math.pi / 180)) 19 | delta[case23] = crunch(torch.cmul(sign, normDeltaC)[case23]) 20 | return delta 21 | end -------------------------------------------------------------------------------- /model.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | 4 | function createModel() 5 | local conv 6 | local criterion = nn.MSECriterion(true) 7 | if opts.cuda then 8 | require 'cunn' 9 | require 'cudnn' 10 | conv = cudnn.SpatialConvolution(1, 8, 3, 3, 1, 1, 1, 1) 11 | conv:cuda() 12 | criterion:cuda() 13 | else 14 | conv = nn.SpatialConvolution(1, 8, 3, 3, 1, 1, 1, 1) 15 | end 16 | 17 | conv.weight:zero() 18 | conv.weight:sub(1, -1, 1, -1, 2, 2, 2, 2):add(1) 19 | conv.weight[{1, 1, 1, 1}] = -1 20 | conv.weight[{2, 1, 1, 2}] = -1 21 | conv.weight[{3, 1, 1, 3}] = -1 22 | conv.weight[{4, 1, 2, 1}] = -1 23 | conv.weight[{5, 1, 2, 3}] = -1 24 | conv.weight[{6, 1, 3, 1}] = -1 25 | conv.weight[{7, 1, 3, 2}] = -1 26 | conv.weight[{8, 1, 3, 3}] = -1 27 | conv.bias:zero() 28 | if opts.cuda then 29 | conv:resetWeightDescriptors() 30 | end 31 | 32 | return conv, criterion 33 | end 34 | -------------------------------------------------------------------------------- /test.lua: -------------------------------------------------------------------------------- 1 | function test(img, delta) 2 | assert(img:dim() == 2 and delta:dim() == 3 and delta:size(1) == 8 3 | and img:size(1) == delta:size(2) and img:size(2) == delta:size(3)) 4 | local height = img:size(1) 5 | local width = img:size(2) 6 | local eps = 1e-5 7 | for i = 1, 8 do 8 | if i == 1 then 9 | u, v = -1, -1 10 | elseif i == 2 then 11 | u, v = -1, 0 12 | elseif i == 3 then 13 | u, v = -1, 1 14 | elseif i == 4 then 15 | u, v = 0, -1 16 | elseif i == 5 then 17 | u, v = 0, 1 18 | elseif i == 6 then 19 | u, v = 1, -1 20 | elseif i == 7 then 21 | u, v = 1, 0 22 | else 23 | u, v = 1, 1 24 | end 25 | for j = 1, height do 26 | for k = 1, width do 27 | if 1 <= j + u and j + u <= height and 1 <= k + v and k + v <= height then 28 | assert(torch.abs(delta[{i, j, k}] - img[{j, k}] + img[{j + u, k + v}]) < eps 29 | , string.format('testing failed: %f ~= %f', delta[{i, j, k}], img[{j, k}] - img[{j + u, k + v}])) 30 | end 31 | end 32 | end 33 | end 34 | end -------------------------------------------------------------------------------- /opts.lua: -------------------------------------------------------------------------------- 1 | require 'optim' 2 | 3 | 4 | local cmd = torch.CmdLine() 5 | cmd:text() 6 | cmd:text() 7 | cmd:text('Color2Gray: Salience-Preserving Color Removal') 8 | cmd:text() 9 | cmd:text('Options') 10 | cmd:option('-inp', 'color.png', 'path to the input image') 11 | cmd:option('-oup', 'gray.png', 'path to the output image') 12 | cmd:option('-batchsize', 1, 'the number of images in a mini-batch') 13 | cmd:option('-maxiter', 2000, 'the maximum number of iterations') 14 | cmd:option('-miniter', 10, 'the minimum number of iterations') 15 | cmd:option('-LR', 0.1, 'learning rate(if applicable)') 16 | cmd:option('-momentum', 0.5, 'momentum(if applicable)') 17 | cmd:option('-theta', 45, 'whether chromatic differences are mapped to increases or decreases in luminance value') 18 | cmd:option('-alpha', 10, 'how much chromatic variation is allowed to change the source luminance value') 19 | cmd:option('-test', false, 'whether to use naive implmentation as a double-check') 20 | cmd:option('-optimAlgo', 'cg', 'optimization algorithm: cg, lbfgs, etc.') 21 | cmd:option('-cuda', false, 'enable CUDA support') 22 | cmd:option('-webcam', false, 'take input from the webcam') 23 | cmd:text() 24 | 25 | opts = cmd:parse(arg) 26 | cmd:log('./log', opts) 27 | 28 | opts.optimAlgo = optim[opts.optimAlgo] 29 | assert(opts.miniter <= opts.maxiter, 'miniter > maxiter') -------------------------------------------------------------------------------- /main.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'paths' 3 | paths.dofile('opts.lua') 4 | if opts.cuda then 5 | require 'cutorch' 6 | end 7 | require 'image' 8 | torch.setdefaulttensortype('torch.FloatTensor') 9 | 10 | -- create the model and training criterion 11 | paths.dofile('model.lua') 12 | model, criterion = createModel() 13 | 14 | require 'color2gray' 15 | if opts.webcam then 16 | local cv = require 'cv' 17 | require 'cv.highgui' 18 | require 'cv.videoio' 19 | local cap = cv.VideoCapture({device=0}) 20 | assert(cap:isOpened(), 'failed to open the camera') 21 | cv.namedWindow({winname='CNN-Color2Gray', flags=cv.WINDOW_AUTOSIZE}) 22 | local _, frame = cap:read({}) 23 | while true do 24 | frame = frame:float() 25 | print(frame:size()) 26 | os.exit() 27 | local gray = color2gray(frame, model, criterion, false); 28 | --image.display(frame) 29 | --image.save(opts.oup, gray) 30 | print(gray:size()) 31 | os.exit() 32 | --cv.imshow({winname='CNN-Color2Gray', image=gray}) 33 | --cap:read({frame}) 34 | end 35 | else 36 | local rgbImg = image.load(opts.inp) 37 | print(string.format('image size: %dx%d\n------------------', rgbImg:size(2), rgbImg:size(3))) 38 | local gray = color2gray(rgbImg, model, criterion) 39 | print(string.format('the result saved to %s', opts.oup)) 40 | image.save(opts.oup, gray) 41 | local status, _ = pcall(function() require 'qlua' end) 42 | if status == true then 43 | image.display(rgbImg) 44 | image.display(gray) 45 | end 46 | end 47 | -------------------------------------------------------------------------------- /color2gray.lua: -------------------------------------------------------------------------------- 1 | require 'optim' 2 | 3 | 4 | function color2gray(rgbImg, model, criterion, silent) 5 | 6 | silent = silent or false 7 | local timer = torch.Timer() 8 | 9 | local labImg = image.rgb2lab(rgbImg) 10 | if opts.cuda then 11 | labImg = labImg:cuda() 12 | end 13 | 14 | -- compute the distances 15 | require 'dist' 16 | local delta = dist(labImg) 17 | 18 | -- define the loss function to optimize 19 | local function loss(x) 20 | local output = model:forward(x) 21 | local f = criterion:forward(output, delta) 22 | model:backward(x, criterion:backward(output, delta)) 23 | local grad = model.gradInput:clone() 24 | grad:add(-grad:mean()) 25 | return f, grad 26 | end 27 | 28 | -- train the model 29 | optimState = { 30 | learningRate = opts.LR, 31 | momentum = opts.momentum, 32 | } 33 | local gray = torch.mean(rgbImg, 1) 34 | if opts.cuda then 35 | gray = gray:cuda() 36 | end 37 | local errors = {} 38 | local cnt = 0 39 | for i = 1, opts.maxiter do 40 | local _, err = opts.optimAlgo(loss, gray, optimState) 41 | if #errors >= 1 and err[1] >= errors[#errors] then 42 | cnt = cnt + 1 43 | if cnt == 5 and i > opts.miniter then 44 | if not silent then print('\nloss failed to decrease, stopped earlier') end 45 | break 46 | end 47 | end 48 | errors[#errors + 1] = err[1] 49 | if not silent then xlua.progress(i, opts.maxiter) end 50 | print(err[1]) 51 | end 52 | if not silent then 53 | print(string.format('final loss: %f', errors[#errors])) 54 | print(string.format('time elapsed: %fsec', timer:time().real)) 55 | end 56 | 57 | -- normalize the result 58 | gray:add(-gray:min()) 59 | gray:div(gray:max()) 60 | 61 | return gray 62 | end 63 | --------------------------------------------------------------------------------