├── .gitignore
├── utils
│   ├── crop_images.py
│   ├── extract_frames.sh
│   └── convert_images.py
├── eval.py
├── README.md
├── net.py
├── data_ops.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
temp/
*.*~

--------------------------------------------------------------------------------
/utils/crop_images.py:
--------------------------------------------------------------------------------
# Crop the screenshots taken from the Crystal and Diamond walkthrough
# videos to remove the border around the game screen.
#
# Usage: python crop_images.py <image_dir>/
# Note the trailing slash: images are matched with <image_dir>*.png
# and are overwritten in place.

import cv2
import sys
import glob
from tqdm import tqdm

image_dir = sys.argv[1]
images = glob.glob(image_dir+'*.png')

print 'Cropping images'
for image in tqdm(images):
    img = cv2.imread(image)
    # keep rows 0-360 and columns 120-520 (the game screen)
    img = img[0:360, 120:520]
    # uncomment to preview the crop before running on everything
    #cv2.imshow('image',img)
    #cv2.waitKey(0)
    #cv2.destroyAllWindows()
    #exit()
    cv2.imwrite(image, img)

--------------------------------------------------------------------------------
/utils/extract_frames.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Script for extracting frames from multiple video files.
# If you have a folder filled with videos like so...
#
# folder/
#    - video_1.mp4
#    - video_2.mp4
#    - ...
#
# you can simply run the script by doing `./extract_frames.sh folder/`
# and it will create an individual folder for each video and place the
# images in their respective folders.

for file in "$1"/*.*; do
    destination="${file%.*}"
    echo "Extracting from $file..."
    mkdir -p "$destination"
    # -r 1/1 writes out one frame per second of video
    ffmpeg -i "$file" -r 1/1 "$destination/image_%03d.png"
done

echo ""
echo "Deleting images that are less than 40kb..."
find "$1" -name "*.png" -size -40k -delete

--------------------------------------------------------------------------------
/utils/convert_images.py:
--------------------------------------------------------------------------------
'''
Cameron Fabbri

Script to resize images and convert them to grayscale for
reading into tensorflow. The new image size can be set; the
default here is 160x144, the Game Boy Color resolution.

This script will NOT overwrite the original images; it
creates two new images from every image read in:

image_1.png              (original, untouched)
image_1_resized.png
image_1_resized_gray.png

Requires ImageMagick's `convert` to be on the PATH.
'''

import sys
import fnmatch
import os
from tqdm import tqdm

if __name__ == '__main__':

    data_dir = sys.argv[1]
    pattern = "*.png"
    image_list = list()
    for d, s, fList in os.walk(data_dir):
        for filename in fList:
            if fnmatch.fnmatch(filename, pattern):
                image_list.append(os.path.join(d, filename))

    print 'Working on images...'
    for image in tqdm(image_list):
        image_dir = os.path.dirname(image)
        base = image.split('/')[-1].split('.')[0]
        resized_image = image_dir+'/'+base+'_resized.png'
        resized_gray_image = image_dir+'/'+base+'_resized_gray.png'

        # the 'true' color image to be used in tensorflow (label)
        os.system('convert "' + image + '" -resize 160x144\! "' + resized_image + '"')

        # the gray image that we are training on
        os.system('convert "' + image + '" -resize 160x144\! -colorspace Gray "' + resized_gray_image + '"')

--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
import cPickle as pickle
from tqdm import tqdm
import tensorflow as tf
from scipy import misc
import numpy as np
import argparse
import ntpath
import sys
import os
import time
import glob

sys.path.insert(0, 'ops/')
sys.path.insert(0, 'config/')

import data_ops
import net

if __name__ == '__main__':

    CHECKPOINT_DIR = 'checkpoints/'
    IMAGES_DIR = CHECKPOINT_DIR+'images/'
    BATCH_SIZE = 1

    # make sure the output directory exists
    if not os.path.isdir(IMAGES_DIR):
        os.makedirs(IMAGES_DIR)

    test_images = glob.glob(sys.argv[1]+'*.*')
    num_images = len(test_images)

    Data = data_ops.loadData(sys.argv[1], BATCH_SIZE, train=False)
    # The gray 'lightness' channel
    gray_image = Data.inputs

    # The color channels
    color_image = Data.targets

    # architecture from
    # http://hi.cs.waseda.ac.jp/~iizuka/projects/colorization/data/colorization_sig2016.pdf
    col_img = net.architecture(gray_image, train=False)

    col_img = tf.image.convert_image_dtype(col_img, dtype=tf.uint8, saturate=True)

    saver = tf.train.Saver(max_to_keep=1)

    init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    sess.run(init)

    ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
    # restore previous model if there is one
    if ckpt and ckpt.model_checkpoint_path:
        print "Restoring previous model..."
        try:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print "Model restored"
        except:
            print "Could not restore model"

    ########################################### evaluation portion
    start = time.time()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)

    # run the network once per test image and save each colorized result
    for i in tqdm(range(num_images)):
        prediction = np.squeeze(np.asarray(sess.run(col_img)))
        misc.imsave(IMAGES_DIR+str(i+1)+'.png', prediction)

    coord.request_stop()
    coord.join(threads)
    print 'Done. Images are in', IMAGES_DIR

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Colorizing Images

**UPDATE - Completely cleaning up code for Tensorflow 1.0 and retraining models.**

A deep learning approach to colorizing images, specifically for Pokemon.

The current model was trained on screenshots taken from Pokemon Silver, Crystal,
and Diamond, then tested on Pokemon Blue Version. Sample results below.

## Basic Training Usage
Training is done with `train.py`. Point it at a folder of color training images with
`--DATA_DIR` (required); `--EPOCHS` (default 10) and `--BATCH_SIZE` (default 32) are optional.

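A minimal invocation (the data path here is just an example):

```bash
python train.py --DATA_DIR=images/train/ --EPOCHS=10 --BATCH_SIZE=32
```

Checkpoints and TensorBoard logs are written to `checkpoints/`.
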
## Evaluating on Images
I've included a trained model in the `models/` directory that you can run your own images on.
You can run the model either on a single image or on a folder of images. For one image, run
`eval_one.py` and pass it the model and the image as parameters. To run it on multiple images,
run `eval.py` and pass it the folder containing the images. `eval.py` saves its output in
`checkpoints/images/`, whereas `eval_one.py` saves to the current directory. Examples:

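For the folder case, a minimal example (assuming a trained checkpoint already sits in
`checkpoints/` and your grayscale test images are in `images/test/`; the paths are
illustrative):

```bash
python eval.py images/test/
```
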
## Training your own data

There are scripts included to help create your own dataset, which is worth doing because a
large amount of data is needed to obtain good results. The results above were trained on
about 50,000 images.

The easiest method to obtain images is to extract them from Youtube walkthrough videos of
different games. Given that you have a folder of videos like so:

    videos/
        video_1.mp4
        video_2.mp4
        ...

use `extract_frames.sh` to extract images from each video. Just pass it the folder containing
the videos.

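For the `videos/` folder above, that is:

```bash
./utils/extract_frames.sh videos/
```

The script pulls one frame per second from each video with `ffmpeg`, places the frames in a
folder per video, and then deletes any extracted frame smaller than 40kb.
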
Depending on whether the video had a border around the game, you may need `crop_images.py` to
crop out the border. There are commented-out lines in the script that you can uncomment to
view an image before it crops all of them, to make sure the crop is correct.

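For example, to crop the frames extracted from `video_1` (the path is illustrative; note the
trailing slash, which the script's glob expects, and note that the images are overwritten in
place):

```bash
python utils/crop_images.py videos/video_1/
```
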
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import sys

'''
Leaky RELU
https://arxiv.org/pdf/1502.01852.pdf
'''
def lrelu(x, leak=0.2, name='lrelu'):
    return tf.maximum(leak*x, x)

'''
Fully convolutional colorization network: a stack of 18 1x1
convolutions mapping the single gray channel to 3 color channels.
Filter counts ramp up to 256 and back down; dropout is applied to
the last two layers at train time.
'''
def architecture(gray_image, train=True):
    num_filters = [32, 32, 64, 64, 128, 128, 256, 256,
                   128, 128, 64, 64, 32, 32, 16, 16, 8]
    layer = gray_image
    for i, f in enumerate(num_filters, 1):
        layer = lrelu(tf.layers.conv2d(layer, f, 1, strides=1, name='conv'+str(i), padding='VALID'))
        print 'conv'+str(i)+':', layer
        tf.add_to_collection('vars', layer)

    if train: layer = tf.nn.dropout(layer, 0.8)
    conv18 = lrelu(tf.layers.conv2d(layer, 3, 1, strides=1, name='conv18', padding='VALID'))
    if train: conv18 = tf.nn.dropout(conv18, 0.8)
    tf.add_to_collection('vars', conv18)

    return conv18

--------------------------------------------------------------------------------
/data_ops.py:
--------------------------------------------------------------------------------
'''

Operations used for data management

MASSIVE help from https://github.com/affinelayer/pix2pix-tensorflow/blob/master/pix2pix.py

'''

from __future__ import division
from __future__ import absolute_import

from scipy import misc
from skimage import color
import collections
import tensorflow as tf
import numpy as np
import math
import time
import random
import glob
import os
import fnmatch
import cPickle as pickle

Data = collections.namedtuple('trainData', 'paths, inputs, targets, count, steps_per_epoch')


def getPaths(data_dir, ext='jpg'):
    pattern = '*.'+ext
    image_paths = []
    for d, s, fList in os.walk(data_dir):
        for filename in fList:
            if fnmatch.fnmatch(filename, pattern):
                image_paths.append(os.path.join(d, filename))
    return image_paths


def loadData(data_dir, batch_size, train=True):

    if data_dir is None or not os.path.exists(data_dir): raise Exception('data_dir does not exist')

    if train:
        # cache the shuffled list of training paths so restarts reuse the same order
        pkl_train_file = 'pokemon.pkl'

        if os.path.isfile(pkl_train_file):
            print 'Found pickle file'
            train_paths = pickle.load(open(pkl_train_file, 'rb'))
        else:
            train_paths = getPaths(data_dir)
            random.shuffle(train_paths)

            pf = open(pkl_train_file, 'wb')
            data = pickle.dumps(train_paths)
            pf.write(data)
            pf.close()
        input_paths = train_paths

    else:
        # at eval time, gather every image under the given directory
        # (frames produced by extract_frames.sh are png)
        input_paths = getPaths(data_dir, ext='png')

    decode = tf.image.decode_image

    if len(input_paths) == 0: raise Exception('data_dir contains no image files')
    else: print 'Found', len(input_paths), 'images!'

    with tf.name_scope('load_images'):
        path_queue = tf.train.string_input_producer(input_paths, shuffle=train)
        reader = tf.WholeFileReader()
        paths, contents = reader.read(path_queue)
        raw_input_ = decode(contents)
        raw_input_ = tf.image.convert_image_dtype(raw_input_, dtype=tf.float32)

        raw_input_.set_shape([None, None, 3])

        inputs = tf.image.rgb_to_grayscale(raw_input_)
        targets = raw_input_

    scale_size = 180
    height = 160
    width = 144

    # share one seed so the random flip is identical for input and target
    seed = random.randint(0, 2**31 - 1)
    def transform(image):
        r = image
        r = tf.image.random_flip_left_right(r, seed=seed)
        r = tf.image.resize_images(r, [height, width], method=tf.image.ResizeMethod.AREA)
        #offset = tf.cast(tf.floor(tf.random_uniform([2], 0, scale_size - width + 1, seed=seed)), dtype=tf.int32)
        #r = tf.image.crop_to_bounding_box(r, offset[0], offset[1], height, width)
        return r

    if train:
        input_images = transform(inputs)
        target_images = transform(targets)
    else:
        input_images = tf.image.resize_images(inputs, [160, 160], method=tf.image.ResizeMethod.AREA)
        target_images = tf.image.resize_images(targets, [160, 160], method=tf.image.ResizeMethod.AREA)

    paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images], batch_size=batch_size)
    steps_per_epoch = int(math.ceil(len(input_paths) / batch_size))

    return Data(
        paths=paths_batch,
        inputs=inputs_batch,
        targets=targets_batch,
        count=len(input_paths),
        steps_per_epoch=steps_per_epoch,
    )

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import cPickle as pickle
import tensorflow as tf
from scipy import misc
import numpy as np
import argparse
import ntpath
import sys
import os
import time

import data_ops
import net

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--EPOCHS', required=False, default=10, type=int, help='Number of epochs to train for')
    parser.add_argument('--DATA_DIR', required=True, help='Directory where data is')
    parser.add_argument('--BATCH_SIZE', required=False, type=int, default=32, help='Batch size to use')
    a = parser.parse_args()

    EPOCHS = a.EPOCHS
    DATA_DIR = a.DATA_DIR
    BATCH_SIZE = a.BATCH_SIZE

    CHECKPOINT_DIR = 'checkpoints/'
    IMAGES_DIR = CHECKPOINT_DIR+'images/'

    try: os.mkdir(CHECKPOINT_DIR)
    except: pass
    try: os.mkdir(IMAGES_DIR)
    except: pass

    # write all this info to a pickle file in the checkpoint directory
    exp_info = dict()
    exp_info['EPOCHS'] = EPOCHS
    exp_info['DATA_DIR'] = DATA_DIR
    exp_info['BATCH_SIZE'] = BATCH_SIZE
    exp_pkl = open(CHECKPOINT_DIR+'info.pkl', 'wb')
    data = pickle.dumps(exp_info)
    exp_pkl.write(data)
    exp_pkl.close()

    print
    print 'EPOCHS:     ', EPOCHS
    print 'DATA_DIR:   ', DATA_DIR
    print 'BATCH_SIZE: ', BATCH_SIZE
    print

    # global step that is saved with a model to keep track of how many steps/epochs
    global_step = tf.Variable(0, name='global_step', trainable=False)

    # load data
    Data = data_ops.loadData(DATA_DIR, BATCH_SIZE)

    num_train = Data.count
    gray_image = Data.inputs
    color_image = Data.targets

    # architecture from
    # http://hi.cs.waseda.ac.jp/~iizuka/projects/colorization/data/colorization_sig2016.pdf
    col_img = net.architecture(gray_image)

    #loss = tf.reduce_mean((ab_image-col_img)**2)
    loss = tf.reduce_mean(tf.nn.l2_loss(color_image-col_img))
    train_op = tf.train.AdamOptimizer(learning_rate=1e-6).minimize(loss, global_step=global_step)
    saver = tf.train.Saver(max_to_keep=1)

    # tensorboard summaries
    tf.summary.scalar('loss', loss)

    init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess = tf.Session()
    sess.run(init)

    # write out logs for tensorboard to the checkpoint dir
    summary_writer = tf.summary.FileWriter(CHECKPOINT_DIR+'logs/', graph=tf.get_default_graph())

    ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
    # restore previous model if there is one
    if ckpt and ckpt.model_checkpoint_path:
        print "Restoring previous model..."
        try:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print "Model restored"
        except:
            print "Could not restore model"

    ########################################### training portion
    step = sess.run(global_step)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)
    merged_summary_op = tf.summary.merge_all()
    start = time.time()

    epoch_num = step/(num_train/BATCH_SIZE)
    while epoch_num < EPOCHS:
        epoch_num = step/(num_train/BATCH_SIZE)
        s = time.time()
        # run the train op and fetch the loss/summary in a single call so
        # they are computed on the same batch
        _, loss_, summary = sess.run([train_op, loss, merged_summary_op])
        summary_writer.add_summary(summary, step)
        print 'epoch:', epoch_num, 'step:', step, 'loss:', loss_, 'time:', time.time()-s
        step += 1

        if step % 500 == 0:
            print 'Saving model...'
            saver.save(sess, CHECKPOINT_DIR+'checkpoint-'+str(step))
            saver.export_meta_graph(CHECKPOINT_DIR+'checkpoint-'+str(step)+'.meta')
            print 'Model saved\n'

    print 'Finished training', time.time()-start
    saver.save(sess, CHECKPOINT_DIR+'checkpoint-'+str(step))
    saver.export_meta_graph(CHECKPOINT_DIR+'checkpoint-'+str(step)+'.meta')
    coord.request_stop()
    coord.join(threads)
    exit()

--------------------------------------------------------------------------------