├── .gitignore
├── utils
│   ├── crop_images.py
│   ├── extract_frames.sh
│   └── convert_images.py
├── eval.py
├── README.md
├── net.py
├── data_ops.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
temp/
*.*~
--------------------------------------------------------------------------------
/utils/crop_images.py:
--------------------------------------------------------------------------------
# crop images for crystal and diamond

import cv2
import sys
import glob
from tqdm import tqdm

image_dir = sys.argv[1]
images = glob.glob(image_dir+'*.png')

print 'Cropping images'
for image in tqdm(images):
    img = cv2.imread(image)
    img = img[0:360, 120:520]
    # uncomment the lines below to preview the crop before overwriting:
    #cv2.imshow('image', img)
    #cv2.waitKey(0)
    #cv2.destroyAllWindows()
    #exit()
    cv2.imwrite(image, img)
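# Usage sketch (the frames/ path is hypothetical):
#
#   python crop_images.py frames/video_1/
#
# The hardcoded crop box [0:360, 120:520] appears to assume 640x360 frames
# with a 120px border on each side; adjust the indices for other layouts.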
-colorspace Gray "' + resized_gray_image +'"') 43 | 44 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | from tqdm import tqdm 3 | import tensorflow as tf 4 | import colorarch 5 | from scipy import misc 6 | import numpy as np 7 | import argparse 8 | import ntpath 9 | import sys 10 | import os 11 | import time 12 | import glob 13 | 14 | sys.path.insert(0, 'ops/') 15 | sys.path.insert(0, 'config/') 16 | 17 | import data_ops 18 | 19 | if __name__ == '__main__': 20 | 21 | CHECKPOINT_DIR = 'checkpoints/' 22 | IMAGES_DIR = CHECKPOINT_DIR+'images/' 23 | BATCH_SIZE=1 24 | 25 | test_images = glob.glob(sys.argv[1]+'*.*') 26 | num_images = len(test_images) 27 | 28 | Data = data_ops.loadData(sys.argv[1], BATCH_SIZE, train=False) 29 | # The gray 'lightness' channel in range [-1, 1] 30 | gray_image = Data.inputs 31 | 32 | # The color channels in [-1, 1] range 33 | color_image = Data.targets 34 | 35 | # architecture from 36 | # http://hi.cs.waseda.ac.jp/~iizuka/projects/colorization/data/colorization_sig2016.pdf 37 | col_img = colorarch.architecture(gray_image, train=False) 38 | 39 | col_img = tf.image.convert_image_dtype(col_img, dtype=tf.uint8, saturate=True) 40 | 41 | saver = tf.train.Saver(max_to_keep=1) 42 | 43 | init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 44 | sess = tf.Session(config=tf.ConfigProto(log_device_placement=False)) 45 | sess.run(init) 46 | 47 | ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR) 48 | # restore previous model if there is one 49 | if ckpt and ckpt.model_checkpoint_path: 50 | print "Restoring previous model..." 51 | try: 52 | saver.restore(sess, ckpt.model_checkpoint_path) 53 | print "Model restored" 54 | except: 55 | print "Could not restore model" 56 | pass 57 | 58 | ########################################### training portion 59 | start = time.time() 60 | coord = tf.train.Coordinator() 61 | threads = tf.train.start_queue_runners(sess, coord=coord) 62 | 63 | prediction = np.squeeze(np.asarray(sess.run(col_img))) 64 | i = 1 65 | misc.imsave(IMAGES_DIR+str(i)+'.png', prediction) 66 | i+=1 67 | print 'Done. Images are in',IMAGES_DIR 68 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Colorizing Images 2 | 3 | **UPDATE - Completely cleaning up code for Tensorflow 1.0 and retraining models.** 4 | 5 | A deep learning approach to colorizing images, specifically for Pokemon. 6 | 7 | The current model was trained on screenshots taken from Pokemon Silver, Crystal, 8 | and Diamond, then tested on Pokemon Blue Version. Sample results below. 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | ## Basic Training Usage 17 | The files in the `images/train` folder are as follows: 18 | 19 | ## Evaluating on Images 20 | I've included a trained model in the `models/` directory that you can run your own images on. 21 | You can either run the model on one image or a folder of images. For one image, run `eval_one.py` 22 | and pass it the model and the image as parameters. To run it on multiple images, run `eval.py` 23 | and pass it the model and the folder to the images. `eval.py` will save your images in the 24 | `output` folder, where as `eval_one.py` will save them in the current directory. 
## Training your own data

There are scripts included to help create your own dataset, which is desirable because
a large amount of data is needed to obtain good results. The results above
were trained on about 50,000 images.

The easiest method to obtain images is to extract them from YouTube walkthrough videos of
different games. Given that you have a folder with videos

`videos/`

`video_1.mp4`

`video_2.mp4`

`...`


use `extract_frames.sh` to extract images from each video. Just pass it the folder containing the videos.

Depending on whether the video had a border around the game, you may need to use `crop_images.py` to crop
out the border. There are comments in the script you can uncomment to view an image before it crops
all of them, to be sure the cropping is correct.
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
import tensorflow as tf

'''
   Leaky ReLU
   https://arxiv.org/pdf/1502.01852.pdf
'''
def lrelu(x, leak=0.2, name='lrelu'):
    return tf.maximum(leak*x, x)

def architecture(gray_image, train=True):
    conv1 = lrelu(tf.layers.conv2d(gray_image, 32, 1, strides=1, name='conv1', padding='VALID'))
    print 'conv1:', conv1
    conv2 = lrelu(tf.layers.conv2d(conv1, 32, 1, strides=1, name='conv2', padding='VALID'))
    print 'conv2:', conv2
    conv3 = lrelu(tf.layers.conv2d(conv2, 64, 1, strides=1, name='conv3', padding='VALID'))
    print 'conv3:', conv3
    conv4 = lrelu(tf.layers.conv2d(conv3, 64, 1, strides=1, name='conv4', padding='VALID'))
    print 'conv4:', conv4
    conv5 = lrelu(tf.layers.conv2d(conv4, 128, 1, strides=1, name='conv5', padding='VALID'))
    print 'conv5:', conv5
    conv6 = lrelu(tf.layers.conv2d(conv5, 128, 1, strides=1, name='conv6', padding='VALID'))
    print 'conv6:', conv6
    conv7 = lrelu(tf.layers.conv2d(conv6, 256, 1, strides=1, name='conv7', padding='VALID'))
    print 'conv7:', conv7
    conv8 = lrelu(tf.layers.conv2d(conv7, 256, 1, strides=1, name='conv8', padding='VALID'))
    print 'conv8:', conv8
    conv9 = lrelu(tf.layers.conv2d(conv8, 128, 1, strides=1, name='conv9', padding='VALID'))
    print 'conv9:', conv9
    conv10 = lrelu(tf.layers.conv2d(conv9, 128, 1, strides=1, name='conv10', padding='VALID'))
    print 'conv10:', conv10
    conv11 = lrelu(tf.layers.conv2d(conv10, 64, 1, strides=1, name='conv11', padding='VALID'))
    print 'conv11:', conv11
    conv12 = lrelu(tf.layers.conv2d(conv11, 64, 1, strides=1, name='conv12', padding='VALID'))
    print 'conv12:', conv12
    conv13 = lrelu(tf.layers.conv2d(conv12, 32, 1, strides=1, name='conv13', padding='VALID'))
    print 'conv13:', conv13
    conv14 = lrelu(tf.layers.conv2d(conv13, 32, 1, strides=1, name='conv14', padding='VALID'))
    print 'conv14:', conv14
    conv15 = lrelu(tf.layers.conv2d(conv14, 16, 1, strides=1, name='conv15', padding='VALID'))
    print 'conv15:', conv15
    conv16 = lrelu(tf.layers.conv2d(conv15, 16, 1, strides=1, name='conv16', padding='VALID'))
    print 'conv16:', conv16
    conv17 = lrelu(tf.layers.conv2d(conv16, 8, 1, strides=1, name='conv17', padding='VALID'))
    print 'conv17:', conv17
    if train: conv17 = tf.nn.dropout(conv17, 0.8)
    conv18 = lrelu(tf.layers.conv2d(conv17, 3, 1, strides=1, name='conv18', padding='VALID'))
    if train: conv18 = tf.nn.dropout(conv18, 0.8)

    tf.add_to_collection('vars', conv1)
    tf.add_to_collection('vars', conv2)
    tf.add_to_collection('vars', conv3)
    tf.add_to_collection('vars', conv4)
    tf.add_to_collection('vars', conv5)
    tf.add_to_collection('vars', conv6)
    tf.add_to_collection('vars', conv7)
    tf.add_to_collection('vars', conv8)
    tf.add_to_collection('vars', conv9)
    tf.add_to_collection('vars', conv10)
    tf.add_to_collection('vars', conv11)
    tf.add_to_collection('vars', conv12)
    tf.add_to_collection('vars', conv13)
    tf.add_to_collection('vars', conv14)
    tf.add_to_collection('vars', conv15)
    tf.add_to_collection('vars', conv16)
    tf.add_to_collection('vars', conv17)
    tf.add_to_collection('vars', conv18)

    return conv18
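# A minimal usage sketch (shapes are an assumption; since every layer above is
# a 1x1 convolution, any HxW input works):
#
#   gray = tf.placeholder(tf.float32, [None, 160, 144, 1])
#   colorized = architecture(gray, train=False)   # -> [None, 160, 144, 3]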
--------------------------------------------------------------------------------
/data_ops.py:
--------------------------------------------------------------------------------
'''

   Operations used for data management

   MASSIVE help from https://github.com/affinelayer/pix2pix-tensorflow/blob/master/pix2pix.py

'''

from __future__ import division
from __future__ import absolute_import

import collections
import tensorflow as tf
import math
import random
import os
import fnmatch
import cPickle as pickle

Data = collections.namedtuple('trainData', 'paths, inputs, targets, count, steps_per_epoch')


# default to png since the utils scripts write out png frames
def getPaths(data_dir, ext='png'):
    pattern = '*.'+ext
    image_paths = []
    for d, s, fList in os.walk(data_dir):
        for filename in fList:
            if fnmatch.fnmatch(filename, pattern):
                image_paths.append(os.path.join(d,filename))
    return image_paths


def loadData(data_dir, batch_size, train=True):

    if data_dir is None or not os.path.exists(data_dir): raise Exception('data_dir does not exist')

    if train:
        pkl_train_file = 'pokemon.pkl'

        if os.path.isfile(pkl_train_file):
            print 'Found pickle file'
            train_paths = pickle.load(open(pkl_train_file, 'rb'))
        else:
            train_paths = getPaths(data_dir)
            random.shuffle(train_paths)

            pf = open(pkl_train_file, 'wb')
            data = pickle.dumps(train_paths)
            pf.write(data)
            pf.close()
        input_paths = train_paths

    else:
        # data_dir can be a directory of images or a single image file
        if os.path.isdir(data_dir): input_paths = getPaths(data_dir)
        else: input_paths = [data_dir]

    decode = tf.image.decode_image

    if len(input_paths) == 0: raise Exception('data_dir contains no image files')
    else: print 'Found',len(input_paths),'images!'

    with tf.name_scope('load_images'):
        path_queue = tf.train.string_input_producer(input_paths, shuffle=train)
        reader = tf.WholeFileReader()
        paths, contents = reader.read(path_queue)
        raw_input_ = decode(contents)
        raw_input_ = tf.image.convert_image_dtype(raw_input_, dtype=tf.float32)

        raw_input_.set_shape([None, None, 3])

        inputs  = tf.image.rgb_to_grayscale(raw_input_)
        targets = raw_input_

    scale_size = 180
    height = 160
    width = 144

    # use the same seed for both transforms so the gray input and the color
    # target are flipped the same way
    seed = random.randint(0, 2**31 - 1)
    def transform(image):
        r = image
        r = tf.image.random_flip_left_right(r, seed=seed)
        r = tf.image.resize_images(r, [height, width], method=tf.image.ResizeMethod.AREA)
        #offset = tf.cast(tf.floor(tf.random_uniform([2], 0, scale_size - width + 1, seed=seed)), dtype=tf.int32)
        #r = tf.image.crop_to_bounding_box(r, offset[0], offset[1], height, width)
        return r

    if train:
        input_images  = transform(inputs)
        target_images = transform(targets)
    else:
        input_images  = tf.image.resize_images(inputs, [160, 160], method=tf.image.ResizeMethod.AREA)
        target_images = tf.image.resize_images(targets, [160, 160], method=tf.image.ResizeMethod.AREA)

    paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images], batch_size=batch_size)
    steps_per_epoch = int(math.ceil(len(input_paths) / batch_size))

    return Data(
        paths=paths_batch,
        inputs=inputs_batch,
        targets=targets_batch,
        count=len(input_paths),
        steps_per_epoch=steps_per_epoch,
    )
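# Usage sketch (paths are hypothetical); with the defaults above the training
# batches come out as [batch_size, 160, 144, 1] inputs and
# [batch_size, 160, 144, 3] targets in [0, 1]:
#
#   data = loadData('images/train/', batch_size=32)
#   gray_batch, color_batch = data.inputs, data.targets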
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import cPickle as pickle
import tensorflow as tf
import argparse
import sys
import os
import time

import data_ops
import net

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--EPOCHS', required=False, default=10, type=int, help='Number of epochs to train for')
    parser.add_argument('--DATA_DIR', required=True, help='Directory where data is')
    parser.add_argument('--BATCH_SIZE', required=False, type=int, default=32, help='Batch size to use')
    a = parser.parse_args()

    EPOCHS = a.EPOCHS
    DATA_DIR = a.DATA_DIR
    BATCH_SIZE = a.BATCH_SIZE

    CHECKPOINT_DIR = 'checkpoints/'
    IMAGES_DIR = CHECKPOINT_DIR+'images/'

    try: os.mkdir(CHECKPOINT_DIR)
    except: pass
    try: os.mkdir(IMAGES_DIR)
    except: pass

    # write all this info to a pickle file in the experiments directory
    exp_info = dict()
    exp_info['EPOCHS'] = EPOCHS
    exp_info['DATA_DIR'] = DATA_DIR
    exp_info['BATCH_SIZE'] = BATCH_SIZE
    exp_pkl = open(CHECKPOINT_DIR+'info.pkl', 'wb')
    data = pickle.dumps(exp_info)
    exp_pkl.write(data)
    exp_pkl.close()

    print
    print 'EPOCHS:     ',EPOCHS
    print 'DATA_DIR:   ',DATA_DIR
    print 'BATCH_SIZE: ',BATCH_SIZE
    print

    # global step that is saved with a model to keep track of how many steps/epochs
    global_step = tf.Variable(0, name='global_step', trainable=False)

    # load data
    Data = data_ops.loadData(DATA_DIR, BATCH_SIZE)

    num_train   = Data.count
    gray_image  = Data.inputs
    color_image = Data.targets

    # architecture from
    # http://hi.cs.waseda.ac.jp/~iizuka/projects/colorization/data/colorization_sig2016.pdf
    col_img = net.architecture(gray_image)

    # l2_loss already returns a scalar
    loss = tf.nn.l2_loss(color_image - col_img)
    train_op = tf.train.AdamOptimizer(learning_rate=1e-6).minimize(loss, global_step=global_step)
    saver = tf.train.Saver(max_to_keep=1)

    # tensorboard summaries
    tf.summary.scalar('loss', loss)

    init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess = tf.Session()
    sess.run(init)

    # write out logs for tensorboard to the checkpoint dir
    summary_writer = tf.summary.FileWriter(CHECKPOINT_DIR+'logs/', graph=tf.get_default_graph())

    ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
    # restore previous model if there is one
    if ckpt and ckpt.model_checkpoint_path:
        print "Restoring previous model..."
        try:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print "Model restored"
        except:
            print "Could not restore model"
            pass

    ########################################### training portion
    step = sess.run(global_step)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)
    merged_summary_op = tf.summary.merge_all()
    start = time.time()
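    # one epoch is num_train/BATCH_SIZE steps: e.g. with the ~50,000 training
    # images mentioned in the README and the default BATCH_SIZE of 32, an epoch
    # is ~1562 steps, so the restored global step is converted back to an
    # epoch count below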
    epoch_num = step/(num_train/BATCH_SIZE)
    while epoch_num < EPOCHS:
        epoch_num = step/(num_train/BATCH_SIZE)
        s = time.time()
        # run the train op and fetch the loss/summary in one call so only one
        # batch is consumed per step
        _, loss_, summary = sess.run([train_op, loss, merged_summary_op])
        summary_writer.add_summary(summary, step)
        print 'epoch:',epoch_num,'step:',step,'loss:',loss_,'time:',time.time()-s
        step += 1

        if step%500 == 0:
            print 'Saving model...'
            saver.save(sess, CHECKPOINT_DIR+'checkpoint-'+str(step))
            saver.export_meta_graph(CHECKPOINT_DIR+'checkpoint-'+str(step)+'.meta')
            print 'Model saved\n'

    print 'Finished training', time.time()-start
    saver.save(sess, CHECKPOINT_DIR+'checkpoint-'+str(step))
    saver.export_meta_graph(CHECKPOINT_DIR+'checkpoint-'+str(step)+'.meta')
    coord.request_stop()
    coord.join(threads)
    exit()
--------------------------------------------------------------------------------