├── .gitignore ├── Procfile ├── README.md ├── data └── tfrecords │ ├── .gitignore │ └── .keep ├── eval ├── data │ └── tfrecords │ │ └── .gitignore ├── logdir │ └── .gitignore ├── misc │ ├── collage.py │ └── errors.py └── train_and_eval.py ├── logdir └── .gitignore ├── memo.txt ├── misc ├── clustering │ ├── .gitignore │ ├── calc.py │ ├── download.py │ ├── images │ │ └── .gitignore │ ├── outputs_from_images.py │ └── outputs_from_tfrecords.py ├── experiment │ └── distortion │ │ ├── .gitignore │ │ ├── face.png │ │ └── main.py ├── graph │ └── freeze.py └── projector │ ├── README.md │ ├── directory.py │ ├── images │ └── .gitignore │ ├── logdir │ └── .gitignore │ └── tfrecords.py ├── model.py ├── requirements.txt ├── runtime.txt ├── train.py └── web.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *,cover 45 | 46 | # Translations 47 | *.mo 48 | *.pot 49 | 50 | # Django stuff: 51 | *.log 52 | 53 | # Sphinx documentation 54 | docs/_build/ 55 | 56 | # PyBuilder 57 | target/ 58 | 59 | 60 | # Others 61 | venv/ 62 | data/**/*.tfrecords 63 | data/**/labels.json 64 | eval/data/**/*.tfrecords 65 | eval/misc/out.jpg 66 | -------------------------------------------------------------------------------- /Procfile: -------------------------------------------------------------------------------- 1 | web: python web.py --port $PORT 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tf-face-recognizer 2 | 3 | Image classifier model for face recognition. 
def main(argv=None):
    """Write a collage of the first records of a TFRecord file to ``out.jpg``.

    Decodes up to ``size ** 2`` JPEG images from ``FLAGS.data_file``, tiles
    them row by row into a ``size`` x ``size`` grid and saves the encoded
    result as ``out.jpg`` next to this script.

    Args:
        argv: Unused.

    Raises:
        ValueError: If ``FLAGS.data_file`` contains no records.
    """
    example = tf.placeholder(tf.string)
    features = tf.parse_single_example(example, features={
        'image_raw': tf.FixedLenFeature([], tf.string)
    })
    decode = tf.image.decode_jpeg(features['image_raw'], channels=3)

    size = 16
    num_tiles = size ** 2
    with tf.Session() as sess:
        images = []
        for record in tf.python_io.tf_record_iterator(FLAGS.data_file):
            images.append(sess.run(decode, feed_dict={example: record}))
            if len(images) >= num_tiles:
                break
        if not images:
            raise ValueError('no records found in {}'.format(FLAGS.data_file))
        # Fill the grid with black tiles when the file holds fewer than
        # size ** 2 records; otherwise the row concat would receive empty or
        # ragged rows and fail.
        # NOTE(review): assumes every record decodes to the same height and
        # width as the first one -- confirm against the data set.
        blank = tf.zeros_like(images[0])
        images.extend([blank] * (num_tiles - len(images)))
        rows = [tf.concat(images[i * size:(i + 1) * size], 1) for i in range(size)]
        collage = tf.concat(rows, 0)
        image = sess.run(tf.image.encode_jpeg(collage))
    with open(os.path.join(os.path.dirname(__file__), 'out.jpg'), 'wb') as f:
        f.write(image)
tf.image.resize_images(image5, [96, 96]) 34 | inputs = tf.stack([ 35 | tf.image.per_image_standardization(image1), 36 | tf.image.per_image_standardization(image2), 37 | tf.image.per_image_standardization(image3), 38 | tf.image.per_image_standardization(image4), 39 | tf.image.per_image_standardization(image5) 40 | ]) 41 | # inputs = tf.expand_dims(tf.image.per_image_standardization(image), axis=0) 42 | labels = tf.expand_dims(features['label'], axis=0) 43 | 44 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) 45 | import model 46 | logits = tf.nn.softmax(model.inference(inputs, FLAGS.num_classes)) 47 | values, indices = tf.nn.top_k(logits) 48 | _, top_value = tf.nn.top_k(tf.transpose(values)) 49 | answer = tf.gather(indices, tf.squeeze(top_value)) 50 | # variable_averages = tf.train.ExponentialMovingAverage(model.MOVING_AVERAGE_DECAY) 51 | 52 | with tf.Session() as sess: 53 | tf.train.Saver(tf.trainable_variables()).restore(sess, FLAGS.checkpoint_path) 54 | # tf.train.Saver(variable_averages.variables_to_restore()).restore(sess, FLAGS.checkpoint_path) 55 | 56 | ok, ng = 0, 0 57 | errors = {} 58 | error_images = [] 59 | for i, record in enumerate(tf.python_io.tf_record_iterator(FLAGS.data_file)): 60 | labels_value, answer_value, image_value = sess.run([labels, answer, image2], feed_dict={example: record}) 61 | if labels_value[0] == answer_value[0]: 62 | ok += 1 63 | else: 64 | print('{:04d}: {:3d} - {:3d}'.format(i, answer_value[0], labels_value[0])) 65 | if labels_value[0] not in errors: 66 | errors[labels_value[0]] = 0 67 | errors[labels_value[0]] += 1 68 | ng += 1 69 | if len(error_images) < 100: 70 | error_images.append(image_value) 71 | size = 10 72 | collage = tf.concat([tf.concat(error_images[i*size:(i+1)*size], 1) for i in range(size)], 0) 73 | # collage = tf.image.convert_image_dtype(tf.div(collage, 255.0), tf.uint8) 74 | with open(os.path.join(os.path.dirname(__file__), 'errors.jpg'), 'wb') as f: 75 | 
def distorted_inputs(filenames, distortion=0, batch_size=128):
    """Build the shuffled, randomly distorted training input pipeline.

    Args:
        filenames: List of TFRecord file paths to read examples from.
        distortion: Cropping mode. ``0`` takes a random 96x96 crop; ``1``
            samples a distorted bounding box and resizes it to 96x96.
        batch_size: Number of examples per batch.

    Returns:
        A ``(images, labels)`` pair of batched tensors produced by
        ``tf.train.shuffle_batch``.

    Raises:
        ValueError: If ``distortion`` is neither ``0`` nor ``1``.
    """
    # Fail fast: with any other value neither cropping branch runs, the image
    # keeps an undefined static shape and batching fails much later with an
    # unrelated-looking shape error.
    if distortion not in (0, 1):
        raise ValueError('unsupported distortion mode: {}'.format(distortion))

    fqueue = tf.train.string_input_producer(filenames)
    reader = tf.TFRecordReader()
    _, value = reader.read(fqueue)
    features = tf.parse_single_example(value, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string),
    })
    label = features['label']
    image = tf.image.decode_jpeg(features['image_raw'], channels=3)

    if distortion == 0:
        image = tf.random_crop(image, [96, 96, 3])
    else:
        # Box coordinates are normalised by 112.
        # NOTE(review): presumably the source images are 112x112 -- confirm.
        bounding_boxes = tf.div(tf.constant([[[8, 8, 104, 104]]], dtype=tf.float32), 112.0)
        begin, size, _ = tf.image.sample_distorted_bounding_box(
            tf.shape(image), bounding_boxes,
            min_object_covered=(80.0 * 80.0) / (96.0 * 96.0),
            aspect_ratio_range=[9.0 / 10.0, 10.0 / 9.0])
        image = tf.slice(image, begin, size)
        image = tf.image.resize_images(image, [96, 96])
    # common distortion
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.4)
    image = tf.image.random_contrast(image, lower=0.6, upper=1.4)
    image = tf.image.random_hue(image, max_delta=0.04)
    image = tf.image.random_saturation(image, lower=0.6, upper=1.4)
    image = tf.image.per_image_standardization(image)

    # Keep 40% of a training epoch buffered so that shuffling mixes well.
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(FLAGS.num_examples_per_epoch_for_train * min_fraction_of_examples_in_queue)
    images, labels = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples)
    # Summary tag spelling fixed (was 'disotrted_inputs').
    tf.summary.image('distorted_inputs', images, max_outputs=16)
    return images, labels
def main(argv=None):
    """Train the model and periodically evaluate it on a held-out TFRecord.

    Every ``*.tfrecords`` file in ``FLAGS.datadir`` except ``FLAGS.eval_file``
    is used for training. Every 500 steps the top-1 precision on the
    evaluation file is computed and written to the summary log; every 1000
    steps a checkpoint is saved to ``FLAGS.logdir``.

    Args:
        argv: Unused.

    Raises:
        AssertionError: If the training loss becomes NaN.
    """
    # Compare absolute paths: a plain string comparison misses equivalent
    # spellings of the same file (e.g. './x' vs 'x'), which would silently
    # put the evaluation data into the training set.
    eval_path = os.path.abspath(FLAGS.eval_file)
    filenames = []
    for name in os.listdir(FLAGS.datadir):
        if not name.endswith('.tfrecords'):
            continue
        filepath = os.path.join(FLAGS.datadir, name)
        if os.path.abspath(filepath) != eval_path:
            filenames.append(filepath)
    t_images, t_labels = distorted_inputs(filenames, distortion=1)
    e_images, e_labels = inputs(FLAGS.eval_file)

    sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
    import model
    t_logits = model.inference(t_images, FLAGS.num_classes, reuse=False)
    e_logits = model.inference(e_images, FLAGS.num_classes, reuse=True)
    # train ops
    losses = model.loss(t_logits, t_labels)
    train_op = model.train(losses)
    is_nan = tf.is_nan(losses)
    # eval ops, variables
    e_batch_size = int(e_logits.get_shape()[0])
    num_iter = int(math.ceil(1.0 * FLAGS.num_examples_per_epoch_for_eval / e_batch_size))
    # Batching `num_iter` per-batch hit counts and summing them yields the
    # number of correct predictions over `num_iter` evaluation batches in a
    # single run of the op.
    true_count_op = tf.reduce_sum(
        tf.train.batch([tf.count_nonzero(tf.nn.in_top_k(e_logits, e_labels, 1))], num_iter))
    total_count = num_iter * e_batch_size

    # summary
    summary_op = tf.summary.merge_all()

    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph)
        # initialize (and restore) variables
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(FLAGS.logdir)
        if ckpt and ckpt.model_checkpoint_path:
            print('restore variables from {}.'.format(ckpt.model_checkpoint_path))
            tf.train.Saver(tf.trainable_variables()).restore(sess, ckpt.model_checkpoint_path)
        # checkpoint saver
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=21)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            for step in range(FLAGS.max_steps):
                start_time = time.time()
                _, loss_value, is_nan_value = sess.run([train_op, losses, is_nan])
                duration = time.time() - start_time

                # Explicit raise instead of `assert` so the check survives
                # `python -O`; the exception type is kept for compatibility.
                if is_nan_value:
                    raise AssertionError('Model diverged with loss = NaN')

                print('{}: step {:05d}, loss = {:.5f} ({:.3f} sec/batch)'.format(
                    datetime.now(), step, loss_value, duration))

                if step % 500 == 0:
                    true_count = sess.run(true_count_op)
                    precision = 100.0 * true_count / total_count
                    print('{}: precision = {:.3f} %'.format(datetime.now(), precision))
                    # write summary
                    summary = tf.Summary()
                    summary.ParseFromString(sess.run(summary_op))
                    summary.value.add(tag='precision', simple_value=precision)
                    summary_writer.add_summary(summary, global_step=step)
                if step % 1000 == 0:
                    checkpoint_path = os.path.join(FLAGS.logdir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)
        finally:
            # Always stop the input threads and flush pending summaries, even
            # when training aborts with an exception.
            coord.request_stop()
            coord.join(threads)
            summary_writer.close()
"""Download 100 random face images from the API into ``images/``.

Requires the ``API_ENDPOINT``, ``API_AUTH_EMAIL`` and ``API_AUTH_TOKEN``
environment variables.
"""
import os
import json
from urllib.request import urlopen, urlretrieve, Request

url = os.environ['API_ENDPOINT'] + '/faces/random.json'
auth_headers = {
    'X-User-Email': os.environ['API_AUTH_EMAIL'],
    'X-User-Token': os.environ['API_AUTH_TOKEN'],
}
req = Request(url, None, auth_headers)
# The target directory does not change between iterations.
images_dir = os.path.join(os.path.dirname(__file__), 'images')
for _ in range(100):
    # Close each HTTP response explicitly instead of leaving the connection
    # to the garbage collector.
    with urlopen(req) as response:
        data = json.loads(response.read().decode())
    filename, _headers = urlretrieve(
        data['image_url'], os.path.join(images_dir, '%07d.jpg' % data['id']))
    print(filename)
def main(argv=None):
    """Dump the ``fc5`` and ``fc6`` activations of every JPEG to CSV files.

    Runs the recognizer on each ``*.jpg`` file in ``FLAGS.imgdir`` and writes
    ``fc5.csv`` and ``fc6.csv`` next to this script. Each row holds the image
    file name followed by the flattened activation values.

    Args:
        argv: Unused.
    """
    r = Recognizer(batch_size=1)
    data = tf.placeholder(tf.string)
    image = tf.image.decode_jpeg(data, channels=3)
    image = tf.image.resize_image_with_crop_or_pad(image, FLAGS.input_size, FLAGS.input_size)
    image = tf.image.per_image_standardization(image)
    inputs = tf.expand_dims(image, axis=0)
    r.inference(inputs, 0)
    graph = tf.get_default_graph()
    fc5 = graph.get_tensor_by_name('fc5/fc5:0')
    fc6 = graph.get_tensor_by_name('fc6/fc6:0')

    with tf.Session() as sess:
        variable_averages = tf.train.ExponentialMovingAverage(r.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        # Restore variable by variable so that one the checkpoint cannot
        # provide falls back to its initializer instead of aborting the run.
        # NOTE(review): Saver([v]) looks the variable up under its own name,
        # not under the moving-average key `name` -- confirm this is intended.
        for name, v in variables_to_restore.items():
            try:
                tf.train.Saver([v]).restore(sess, FLAGS.checkpoint_path)
            except Exception:
                print('initialize %s' % name)
                sess.run(tf.variables_initializer([v]))

        outputs = {}
        # Honour the --imgdir flag; its default is the previously hard-coded
        # 'images' directory next to this script.
        dirname = FLAGS.imgdir
        for filename in os.listdir(dirname):
            if not filename.endswith('.jpg'):
                continue
            with open(os.path.join(dirname, filename), 'rb') as f:
                results = sess.run({'fc5': fc5, 'fc6': fc6}, feed_dict={data: f.read()})
            outputs[filename] = {
                'fc5': results['fc5'].flatten().tolist(),
                'fc6': results['fc6'].flatten().tolist(),
            }
        for out in ['fc5', 'fc6']:
            csv_path = os.path.join(os.path.dirname(__file__), '%s.csv' % out)
            with open(csv_path, 'w') as f:
                for name, values in outputs.items():
                    f.write(','.join([name] + [str(x) for x in values[out]]) + '\n')
def main(argv=None):
    """Dump ``fc5`` and ``fc6`` activations for every record of a TFRecord.

    Loads the frozen graph at ``FLAGS.model_path``, feeds the raw JPEG bytes
    of each record in ``FLAGS.file`` to its ``contents:0`` input and writes
    ``fc5.csv`` and ``fc6.csv`` next to this script. Each row holds the
    zero-padded record id followed by the flattened activation values.

    Args:
        argv: Unused.
    """
    if not os.path.isfile(FLAGS.model_path):
        print('No model data file found: {}'.format(FLAGS.model_path))
        # Exit with a non-zero status so callers can detect the failure.
        sys.exit(1)
    # load graph
    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(FLAGS.model_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
    # fc5, fc6
    graph = tf.get_default_graph()
    fetches = {
        'fc5': graph.get_tensor_by_name('fc5/fc5:0'),
        'fc6': graph.get_tensor_by_name('fc6/fc6:0'),
    }

    example = tf.placeholder(tf.string)
    features = tf.parse_single_example(example, features={
        'id': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string)
    })
    with tf.Session() as sess:
        # read records and run for getting outputs
        outputs = {}
        for i, record in enumerate(tf.python_io.tf_record_iterator(FLAGS.file)):
            print('processing {:04d}...'.format(i))
            data = sess.run(features, feed_dict={example: record})
            results = sess.run(fetches, feed_dict={'contents:0': data['image_raw']})
            outputs[data['id']] = {
                'fc5': results['fc5'].flatten().tolist(),
                'fc6': results['fc6'].flatten().tolist(),
            }
    # write outputs to CSV
    for out in ['fc5', 'fc6']:
        filename = os.path.join(os.path.dirname(__file__), '{}.csv'.format(out))
        with open(filename, 'w') as f:
            for name, values in outputs.items():
                f.write(','.join(['{:07d}'.format(name)] + [str(x) for x in values[out]]) + '\n')
def main(argv=None):
    """Render an 8x8 montage of randomly distorted copies of ``face.png``.

    ``FLAGS.distortion == 0`` takes a random 96x96 crop; any other value
    samples a distorted bounding box and resizes it to 96x96. Brightness,
    contrast, hue and saturation are then randomised, 64 samples are batched
    and the resulting montage is written as ``out.jpg`` next to this script.

    Args:
        argv: Unused.
    """
    with open(os.path.join(os.path.dirname(__file__), 'face.png'), 'rb') as f:
        png = f.read()
    image = tf.image.decode_png(png, channels=3)

    if FLAGS.distortion == 0:
        image = tf.to_float(tf.random_crop(image, [96, 96, 3]))
    else:
        # Box coordinates are normalised by 112.
        # NOTE(review): presumably face.png is 112x112 -- confirm.
        bounding_boxes = tf.div(tf.constant([[[8, 8, 104, 104]]], dtype=tf.float32), 112.0)
        begin, size, _ = tf.image.sample_distorted_bounding_box(
            tf.shape(image), bounding_boxes,
            min_object_covered=(80.0 * 80.0) / (96.0 * 96.0),
            aspect_ratio_range=[9.0 / 10.0, 10.0 / 9.0])
        image = tf.slice(image, begin, size)
        image = tf.image.resize_images(image, [96, 96])
    image = tf.image.random_brightness(image, max_delta=0.4)
    image = tf.image.random_contrast(image, lower=0.6, upper=1.4)
    image = tf.image.random_hue(image, max_delta=0.04)
    image = tf.image.random_saturation(image, lower=0.6, upper=1.4)

    images = tf.unstack(tf.train.batch([image], 64))
    rows = [tf.concat(images[x * 8:(x + 1) * 8], 1) for x in range(8)]
    montage = tf.concat(rows, 0)
    montage = tf.image.encode_jpeg(
        tf.image.convert_image_dtype(tf.div(montage, 255.0), tf.uint8, saturate=True))

    with tf.Session() as sess:
        # Run the queue runners under a Coordinator so the enqueue threads
        # are stopped before the session is closed.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            jpeg = sess.run(montage)
        finally:
            coord.request_stop()
            coord.join(threads)

    # Write only after a successful run so a failure leaves no empty file.
    with open(os.path.join(os.path.dirname(__file__), 'out.jpg'), 'wb') as f:
        f.write(jpeg)
def main(argv=None):
    """Export embeddings of a directory of JPEGs for the TensorBoard projector.

    Runs the recognizer on every ``*.jpg`` file in ``FLAGS.imgdir``, collects
    the ``fc5`` and ``fc6`` activations, writes a sprite image of the inputs
    and saves the activations as variables in a checkpoint under
    ``FLAGS.logdir`` together with the projector configuration.

    Args:
        argv: Unused.

    Raises:
        Exception: If ``FLAGS.imgdir`` does not exist or holds no JPEG files.
    """
    if not os.path.exists(FLAGS.imgdir):
        raise Exception('%s does not exist' % FLAGS.imgdir)

    r = Recognizer(batch_size=1)
    data = tf.placeholder(tf.string)
    orig_image = tf.image.decode_jpeg(data, channels=3)
    image = tf.image.resize_image_with_crop_or_pad(orig_image, 96, 96)
    image = tf.image.per_image_standardization(image)
    r.inference(tf.expand_dims(image, axis=0), 1)
    graph = tf.get_default_graph()
    fc5 = graph.get_tensor_by_name('fc5/fc5:0')
    fc6 = graph.get_tensor_by_name('fc6/fc6:0')
    with tf.Session() as sess:
        variable_averages = tf.train.ExponentialMovingAverage(r.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        # Restore variable by variable so that one the checkpoint cannot
        # provide falls back to its initializer instead of aborting the run.
        for name, v in variables_to_restore.items():
            try:
                tf.train.Saver([v]).restore(sess, FLAGS.checkpoint_path)
            except Exception:
                print('initialize %s' % name)
                sess.run(tf.variables_initializer([v]))

        outputs = {
            'fc5': [],
            'fc6': [],
            'images': []
        }
        for filename in os.listdir(FLAGS.imgdir):
            if not filename.endswith('.jpg'):
                continue
            print('processing {}...'.format(filename))
            with open(os.path.join(FLAGS.imgdir, filename), 'rb') as f:
                results = sess.run({
                    'fc5': fc5,
                    'fc6': fc6,
                    'image': orig_image
                }, feed_dict={data: f.read()})
            outputs['fc5'].append(results['fc5'].flatten().tolist())
            outputs['fc6'].append(results['fc6'].flatten().tolist())
            outputs['images'].append(results['image'])
        if not outputs['images']:
            # Without this guard np.stack() below fails on the empty lists
            # with a message that does not mention the image directory.
            raise Exception('no .jpg files found in %s' % FLAGS.imgdir)

        # write to sprite image file
        image_path = os.path.join(FLAGS.logdir, 'sprite.jpg')
        images = outputs['images']
        size = int(math.sqrt(len(images))) + 1
        # Pad the grid with black tiles.
        # NOTE(review): assumes every input image is 112x112 -- confirm.
        while len(images) < size * size:
            images.append(np.zeros((112, 112, 3), dtype=np.uint8))
        # tf.concat takes (values, axis) since TensorFlow 1.0, which is the
        # argument order the other scripts in this repository already use.
        rows = [tf.concat(images[i * size:(i + 1) * size], 1) for i in range(size)]
        jpeg = tf.image.encode_jpeg(tf.concat(rows, 0))
        with open(image_path, 'wb') as f:
            f.write(sess.run(jpeg))
        # add embedding data
        targets = [
            tf.Variable(np.stack(outputs['fc5']), name='fc5'),
            tf.Variable(np.stack(outputs['fc6']), name='fc6'),
        ]
        config = projector.ProjectorConfig()
        for v in targets:
            embedding = config.embeddings.add()
            embedding.tensor_name = v.name
            # embedding.metadata_path = metadata_path
            embedding.sprite.image_path = image_path
            embedding.sprite.single_image_dim.extend([112, 112])
        sess.run(tf.variables_initializer(targets))
        summary_writer = tf.summary.FileWriter(FLAGS.logdir)
        projector.visualize_embeddings(summary_writer, config)
        graph_saver = tf.train.Saver(targets)
        graph_saver.save(sess, os.path.join(FLAGS.logdir, 'model.ckpt'))
21 | key, value = reader.read(fqueue) 22 | features = tf.parse_single_example(value, features={ 23 | 'label': tf.FixedLenFeature([], tf.int64), 24 | 'image_raw': tf.FixedLenFeature([], tf.string), 25 | }) 26 | label = features['label'] 27 | image = tf.image.decode_jpeg(features['image_raw'], channels=3) 28 | image = tf.image.resize_image_with_crop_or_pad(image, 96, 96) 29 | return tf.train.batch( 30 | [tf.image.per_image_standardization(image), image, label], batch_size 31 | ) 32 | 33 | 34 | def main(argv=None): 35 | filepath = FLAGS.input_file 36 | if not os.path.exists(filepath): 37 | raise Exception('%s does not exist' % filepath) 38 | 39 | r = Recognizer(batch_size=900) 40 | input_images, orig_images, labels = inputs([filepath], batch_size=r.batch_size) 41 | r.inference(input_images, 1) 42 | fc5 = tf.get_default_graph().get_tensor_by_name('fc5/fc5:0') 43 | fc6 = tf.get_default_graph().get_tensor_by_name('fc6/fc6:0') 44 | with tf.Session() as sess: 45 | variable_averages = tf.train.ExponentialMovingAverage(r.MOVING_AVERAGE_DECAY) 46 | variables_to_restore = variable_averages.variables_to_restore() 47 | for name, v in variables_to_restore.items(): 48 | try: 49 | tf.train.Saver([v]).restore(sess, FLAGS.checkpoint_path) 50 | except Exception: 51 | sess.run(tf.variables_initializer([v])) 52 | 53 | tf.train.start_queue_runners(sess=sess) 54 | outputs = sess.run({'fc5': fc5, 'fc6': fc6, 'images': orig_images, 'labels': labels}) 55 | 56 | # write to metadata file 57 | metadata_path = os.path.join(FLAGS.logdir, 'metadata.tsv') 58 | with open(metadata_path, 'w') as f: 59 | for index in outputs['labels']: 60 | f.write('%d\n' % index) 61 | # write to sprite image file 62 | image_path = os.path.join(FLAGS.logdir, 'sprite.jpg') 63 | unpacked = tf.unpack(outputs['images'], 900) 64 | rows = [] 65 | for i in range(30): 66 | rows.append(tf.concat(1, unpacked[i*30:(i+1)*30])) 67 | jpeg = tf.image.encode_jpeg(tf.concat(0, rows)) 68 | with open(image_path, 'wb') as f: 69 | 
def inference(images, num_classes, reuse=False):
    """Build the recognizer network and return its unscaled class scores.

    The graph is four stages of 3x3 ReLU convolution followed by 3x3/stride-2
    max pooling (60, 90, 120 and 150 filters), two 200-unit ReLU dense layers
    (``fc5``, ``fc6``) and a linear ``fc7`` layer with ``num_classes`` outputs.

    Args:
        images: input batch; its first dimension must be statically known
            because it is used to flatten the last pooling output.
        num_classes: width of the final ``fc7`` layer.
        reuse: passed to every ``tf.variable_scope``; when true the sparsity
            summaries are not registered again.

    Returns:
        The ``fc7`` tensor (named ``fc7/fc7``).
    """
    def _activation_summary(x):
        if reuse:
            return
        tf.summary.scalar(x.op.name + '/sparsity', tf.nn.zero_fraction(x))

    net = tf.identity(images, name='inputs')

    # conv1..conv4, each followed by pool1..pool4
    for stage, filters in enumerate([60, 90, 120, 150], start=1):
        with tf.variable_scope('conv%d' % stage, reuse=reuse) as scope:
            conv = tf.layers.conv2d(net, filters, [3, 3], padding='SAME', activation=tf.nn.relu)
            _activation_summary(conv)
            conv = tf.identity(conv, name=scope.name)
        net = tf.nn.max_pool(
            conv,
            ksize=[1, 3, 3, 1],
            strides=[1, 2, 2, 1],
            padding='SAME',
            name='pool%d' % stage)

    with tf.variable_scope('fc5', reuse=reuse) as scope:
        batch = images.get_shape()[0].value
        flattened = tf.reshape(net, [batch, -1])
        hidden = tf.layers.dense(flattened, 200, activation=tf.nn.relu)
        _activation_summary(hidden)
        hidden = tf.identity(hidden, name=scope.name)

    with tf.variable_scope('fc6', reuse=reuse) as scope:
        hidden = tf.layers.dense(hidden, 200, activation=tf.nn.relu)
        _activation_summary(hidden)
        hidden = tf.identity(hidden, name=scope.name)

    with tf.variable_scope('fc7', reuse=reuse) as scope:
        logits = tf.layers.dense(hidden, num_classes, activation=None)
        logits = tf.identity(logits, name=scope.name)

    return logits


def loss(logits, labels):
    """Return the total loss: mean cross entropy plus L2 decay on fc5/fc6.

    Every term is also added to the ``'losses'`` graph collection.

    Args:
        logits: unscaled class scores.
        labels: sparse integer class labels.

    Returns:
        The sum of everything in the ``'losses'`` collection, named
        ``total_loss``.
    """
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
    tf.add_to_collection('losses', tf.reduce_mean(per_example, name='cross_entropy'))

    # L2 weight decay on the dense kernels of the two hidden layers
    for layer, scale in (('fc5', 0.001), ('fc6', 0.001)):
        with tf.variable_scope(layer, reuse=True):
            kernel = tf.get_variable('dense/kernel')
            penalty = tf.multiply(tf.nn.l2_loss(kernel), scale, name='weight_loss')
            tf.add_to_collection('losses', penalty)

    return tf.add_n(tf.get_collection('losses'), name='total_loss')
def train(total_loss):
    """Create the training op for one optimization step.

    Registers raw scalar summaries for every loss term, histograms for the
    trainable variables and their gradients, applies Adam, and maintains
    exponential moving averages of the trainable variables.

    Args:
        total_loss: the scalar loss to minimize.

    Returns:
        A no-op named ``train`` that runs after the gradient update and the
        moving-average update.
    """
    tracked_losses = tf.get_collection('losses') + [total_loss]
    loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
    loss_averages_op = loss_averages.apply(tracked_losses)

    for loss_tensor in tracked_losses:
        tf.summary.scalar(loss_tensor.op.name + ' (raw)', loss_tensor)

    # Gradients are computed only after the loss averages were updated.
    with tf.control_dependencies([loss_averages_op]):
        optimizer = tf.train.AdamOptimizer()
        grads_and_vars = optimizer.compute_gradients(total_loss)
    apply_gradient_op = optimizer.apply_gradients(grads_and_vars)

    for variable in tf.trainable_variables():
        tf.summary.histogram(variable.op.name, variable)
    for gradient, variable in grads_and_vars:
        if gradient is None:
            continue
        tf.summary.histogram(variable.op.name + '/gradients', gradient)

    # Track the moving averages of all trainable variables
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
        return tf.no_op(name='train')
def inputs(batch_size, files, num_examples_per_epoch_for_train=5000):
    """Build the shuffled, augmented training input pipeline.

    The files are spread round-robin over at most five reader queues, each
    decoding and distorting its own examples; the results are merged with
    ``tf.train.shuffle_batch_join``.

    Args:
        batch_size: number of examples per batch.
        files: list of TFRecord file paths.
        num_examples_per_epoch_for_train: used as ``min_after_dequeue`` of the
            shuffle queue (and as the base of its capacity).

    Returns:
        A ``(images, labels)`` tuple of batched tensors.

    Raises:
        ValueError: if ``files`` is empty.
    """
    if not files:
        # Without any reader there is nothing to join; fail with a clear
        # message instead of an error from deep inside TensorFlow.
        raise ValueError('no TFRecord files were given')

    # Round-robin the files into (at most) five groups, one reader per group.
    queues = {}
    for index, path in enumerate(files):
        queues.setdefault(index % 5, []).append(path)

    def read_files(files):
        fqueue = tf.train.string_input_producer(files)
        reader = tf.TFRecordReader()
        _, value = reader.read(fqueue)
        features = tf.parse_single_example(value, features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })
        image = tf.image.decode_jpeg(features['image_raw'], channels=3)
        image = tf.cast(image, tf.float32)

        # distort
        # The box is given in normalized coordinates: 8..104 out of 112
        # (assumes 112x112 source images -- TODO confirm).
        bounding_boxes = tf.div(tf.constant([[[8, 8, 104, 104]]], dtype=tf.float32), 112.0)
        begin, size, _ = tf.image.sample_distorted_bounding_box(
            tf.shape(image), bounding_boxes,
            min_object_covered=(80.0*80.0)/(96.0*96.0),
            aspect_ratio_range=[9.0/10.0, 10.0/9.0])
        image = tf.slice(image, begin, size)
        image = tf.image.resize_images(image, [FLAGS.input_size, FLAGS.input_size])
        # NOTE(review): pixels are float32 in the 0..255 range here (cast
        # above, never rescaled); confirm the brightness delta and the
        # hue/saturation ops behave as intended at that scale.
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, max_delta=0.4)
        image = tf.image.random_contrast(image, lower=0.6, upper=1.4)
        image = tf.image.random_hue(image, max_delta=0.04)
        image = tf.image.random_saturation(image, lower=0.6, upper=1.4)
        # The dynamic slice loses the static channel count; restore it.
        image.set_shape([None, None, 3])

        return [tf.image.per_image_standardization(image), features['label']]

    min_queue_examples = num_examples_per_epoch_for_train
    images, labels = tf.train.shuffle_batch_join(
        [read_files(group) for group in queues.values()],
        batch_size=batch_size,
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples
    )
    images = tf.image.resize_images(images, [FLAGS.input_size, FLAGS.input_size])
    tf.summary.image('images', images)
    return images, labels


def labels_json():
    """Return the raw text of ``labels.json`` found in ``FLAGS.datadir``."""
    filepath = os.path.join(FLAGS.datadir, 'labels.json')
    with open(filepath, 'r') as f:
        return f.read()


def restore_or_initialize(sess):
    """Restore variables from ``FLAGS.checkpoint_path`` where possible.

    Trainable variables and variables whose name contains
    ``ExponentialMovingAverage`` are restored one by one when the checkpoint
    exists; everything else, and anything that fails to restore, is
    initialized instead.

    Args:
        sess: the session in which variables are restored or initialized.
    """
    # Both values are the same for every variable, so compute them once
    # instead of once per loop iteration.
    trainable_names = {v.name for v in tf.trainable_variables()}
    has_checkpoint = tf.train.checkpoint_exists(FLAGS.checkpoint_path)

    initialize_variables = []
    for v in tf.global_variables():
        restorable = v.name in trainable_names or "ExponentialMovingAverage" in v.name
        if restorable and has_checkpoint:
            print('restore variable "%s"' % v.name)
            restorer = tf.train.Saver([v])
            try:
                restorer.restore(sess, FLAGS.checkpoint_path)
                continue
            except Exception:
                # Best effort: the variable may be absent from the checkpoint
                # or have a different shape; fall through and initialize it.
                print('could not restore, initialize!')
        initialize_variables.append(v)
    sess.run(tf.variables_initializer(initialize_variables))
def main(argv=None):
    """Train the model for ``FLAGS.max_steps`` steps.

    Builds the input pipeline and the model, restores or initializes the
    variables, then runs the training loop. Summaries are written every 100
    steps, and a checkpoint every 250 steps and on the last step; both go to
    ``FLAGS.logdir``.

    Raises:
        AssertionError: if the loss becomes NaN.
    """
    labels_data = labels_json()
    # Store the label mapping in the graph so an exported model carries it.
    tf.Variable(labels_data, trainable=False, name='labels')

    batch_size = 128
    files = [
        os.path.join(FLAGS.datadir, f)
        for f in os.listdir(FLAGS.datadir)
        if f.endswith('.tfrecords')
    ]
    with tf.variable_scope('inputs'):
        images, labels = inputs(batch_size, files)
    logits = model.inference(images, len(json.loads(labels_data)) + 1)
    losses = model.loss(logits, labels)
    train_op = model.train(losses)
    summary_op = tf.summary.merge_all()
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=21, keep_checkpoint_every_n_hours=1)
    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph=sess.graph)
        restore_or_initialize(sess)

        # A coordinator lets the queue-runner threads be stopped cleanly
        # before the session closes.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for step in range(FLAGS.max_steps):
                start_time = time.time()
                _, loss_value = sess.run([train_op, losses])
                duration = time.time() - start_time

                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                format_str = '%s: step %d, loss = %.5f (%.3f sec/batch)'
                print(format_str % (datetime.now(), step, loss_value, duration))

                if step % 100 == 0:
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)
                if step % 250 == 0 or (step + 1) == FLAGS.max_steps:
                    checkpoint_path = os.path.join(FLAGS.logdir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step, write_meta_graph=False)
        finally:
            coord.request_stop()
            coord.join(threads)
            # Flush pending events; the writer buffers them asynchronously.
            summary_writer.close()
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('model_path', '/tmp/model.pb',
                           """Path to model data.""")
tf.app.flags.DEFINE_integer('port', 5000,
                            """Application port.""")
tf.app.flags.DEFINE_integer('top_k', 5,
                            """Finds the k largest entries""")
tf.app.flags.DEFINE_integer('input_size', 96,
                            """Size of input image""")

sess = tf.Session()

# load model data, get top_k
if not os.path.isfile(FLAGS.model_path):
    # No local copy yet: fetch the frozen graph from MODEL_DOWNLOAD_URL.
    print('No model data file found')
    urllib.request.urlretrieve(os.environ['MODEL_DOWNLOAD_URL'], FLAGS.model_path)
graph_def = tf.GraphDef()
with tf.gfile.FastGFile(FLAGS.model_path, 'rb') as f:
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
fc7 = sess.graph.get_tensor_by_name('fc7/fc7:0')
top_values, top_indices = tf.nn.top_k(tf.nn.softmax(fc7), k=FLAGS.top_k)
# retrieve labels
labels = json.loads(sess.run(sess.graph.get_tensor_by_name('labels:0')).decode())
print('{} labels loaded.'.format(len(labels)))

# Flask setup
app = Flask(__name__)
# The server binds to 0.0.0.0, and the Werkzeug debugger allows arbitrary code
# execution, so debug mode is opt-in (FLASK_DEBUG=1) instead of always on.
app.debug = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true')


@app.route('/labels')
def label():
    """Return the label mapping stored in the model as JSON."""
    return jsonify(labels=labels)


@app.route('/', methods=['POST'])
def api():
    """Classify every image in the ``images`` form field.

    Each value is base64-encoded image data, either bare or as a data URL
    (``<header>,<payload>``). The response is
    ``{"results": [{"top": [{"label": ..., "value": ...}, ...]}, ...]}``
    with the top-k softmax scores per image. A value that is not valid
    base64 yields a 400 response with an ``error`` message.
    """
    results = []
    ops = [top_values, top_indices]
    for image in request.form.getlist('images'):
        # Data URLs carry the payload after the first comma; bare base64
        # (no comma) used to crash with an IndexError.
        parts = image.split(',')
        encoded = parts[1] if len(parts) > 1 else parts[0]
        try:
            contents = base64.b64decode(encoded)
        except ValueError:
            # binascii.Error (bad padding) is a ValueError subclass.
            return jsonify(error='invalid base64 image data'), 400
        values, indices = sess.run(ops, feed_dict={'contents:0': contents})
        # Convert once per image instead of once per rank.
        ranked = zip(indices.flatten().tolist(), values.flatten().tolist())
        top_k = [
            {'label': labels.get(str(index), {}), 'value': value}
            for index, value in ranked
        ]
        results.append({'top': top_k})
    return jsonify(results=results)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=FLAGS.port)