├── .gitignore ├── GAN ├── AC-GAN │ ├── README.md │ ├── discriminator.py │ ├── generate.py │ ├── generator.py │ ├── mnist.py │ ├── optimizer.py │ ├── queue_context.py │ ├── train.py │ └── utils.py └── Info-GAN │ ├── README.md │ ├── discriminator.py │ ├── generate.py │ ├── generator.py │ ├── mnist.py │ ├── optimizer.py │ ├── queue_context.py │ ├── train.py │ └── utils.py ├── README.md ├── algorithm ├── README.md ├── __init__.py ├── __init__.pyc ├── input_data.py ├── input_data.pyc ├── train_mnist_multi_perceptron.py ├── train_mnist_single_perceptron.py └── train_mnist_softmax.py ├── chinese_hand_write_rec └── src │ ├── README.md │ ├── chinese_rec.py │ └── data_to_png.py ├── covert_to_tfrecord ├── covert_datasets_tfrecord.py ├── covert_sex_to_tfrecord.py └── covert_somedata_to_tfrecord.py ├── deepcolor ├── .gitignore ├── README.md ├── download_images.py ├── download_images_multithread.py ├── download_images_v2.py ├── imgs-valid │ ├── 1.jpg │ ├── 2.jpg │ ├── 5000_1.jpg │ ├── 5000_2.jpg │ ├── 5001_1.jpg │ ├── 5001_2.jpg │ ├── 5001_3.jpg │ ├── 5001_4.jpg │ ├── 5001_5.jpg │ ├── 5001_6.jpg │ ├── 5001_7.jpg │ ├── 5001_8.jpg │ ├── 5001_9.jpg │ ├── 5002_1.jpg │ ├── 5002_2.jpg │ ├── 5002_3.jpg │ ├── 5002_4.jpg │ ├── 5002_5.jpg │ ├── 5002_6.jpg │ ├── 5002_7.jpg │ ├── 5002_8.jpg │ ├── 5003_1.jpg │ ├── 5003_2.jpg │ ├── 5003_3.jpg │ ├── 5003_4.jpg │ ├── 5003_5.jpg │ ├── 5003_6.jpg │ ├── 5003_7.jpg │ ├── 5003_8.jpg │ ├── 5003_9.jpg │ ├── 5004_1.jpg │ ├── 5004_10.jpg │ ├── 5004_2.jpg │ ├── 5004_3.jpg │ ├── 5004_4.jpg │ ├── 5004_5.jpg │ ├── 5004_6.jpg │ ├── 5004_7.jpg │ ├── 5004_8.jpg │ ├── 5004_9.jpg │ ├── 5005_1.jpg │ ├── 5005_2.jpg │ ├── 5005_3.jpg │ ├── 5005_4.jpg │ ├── 5005_5.jpg │ ├── 5005_6.jpg │ ├── 5005_7.jpg │ ├── 5005_8.jpg │ ├── 5006_1.jpg │ ├── 5006_2.jpg │ ├── 5006_3.jpg │ ├── 5006_4.jpg │ ├── 5006_5.jpg │ ├── 5006_6.jpg │ ├── 5007_1.jpg │ ├── 5007_10.jpg │ ├── 5007_2.jpg │ ├── 5007_3.jpg │ ├── 5007_4.jpg │ ├── 5007_5.jpg │ ├── 5007_6.jpg │ ├── 5007_7.jpg │ ├── 5007_8.jpg │ └── 5007_9.jpg ├── lines.py ├── main.py ├── nohup.out ├── ops.py ├── report.txt ├── server.py ├── testserver.py ├── tf_upgrade.py ├── uploaded │ ├── colors.jpg │ ├── colors3.jpg │ ├── convert.png │ ├── gen.jpg │ ├── gen3.jpg │ ├── lines.jpg │ ├── picasso.png │ └── sanae.png ├── utils.py └── web │ ├── colorpicker │ ├── css │ │ ├── bootstrap-colorpicker.css │ │ ├── bootstrap-colorpicker.css.map │ │ ├── bootstrap-colorpicker.min.css │ │ └── bootstrap-colorpicker.min.css.map │ ├── img │ │ └── bootstrap-colorpicker │ │ │ ├── alpha-horizontal.png │ │ │ ├── alpha.png │ │ │ ├── hue-horizontal.png │ │ │ ├── hue.png │ │ │ └── saturation.png │ └── js │ │ ├── bootstrap-colorpicker.js │ │ └── bootstrap-colorpicker.min.js │ ├── css │ ├── bootstrap.min.css │ └── grid.css │ ├── draw.html │ ├── image_examples │ ├── armscross.jpg │ ├── armscross.png │ ├── picasso.png │ └── sanae.png │ ├── images │ ├── 1_color.png │ ├── 1_line.png │ ├── 2_color.png │ ├── 2_line.png │ ├── 3_color.png │ ├── 3_line.png │ ├── 4_color.png │ ├── 4_line.png │ ├── 5_color.png │ ├── 5_line.png │ ├── 6_color.png │ ├── 6_line.png │ ├── 7_color.png │ ├── 7_line.png │ └── a_line.png │ ├── index.html │ ├── jquery-1.11.2.min.js │ └── sketch.js ├── finetuning ├── README.md ├── convert_pys │ ├── covert_datasets_tfrecord.py │ └── covert_somedata_to_tfrecord.py ├── datasets │ ├── __init__.py │ ├── dataset_factory.py │ ├── dataset_utils.py │ └── fisher.py ├── deployment │ ├── __init__.py │ └── model_deploy.py ├── eval_image_classifier.py ├── flask │ ├── flask_inference.py │ ├── 
flask_inference_1000.py │ ├── imagenet_metadata.txt │ ├── sysnet.txt │ └── uploads │ │ ├── 3a43c94e-241c-4988-a2f5-a3d34595ff40.png │ │ ├── 3d060309-0e80-479e-994b-924deec9b34d.png │ │ ├── 60f94ed5-1851-43aa-af5d-7a3ad9c2fbb0.jpg │ │ ├── 779ee9bb-421a-441b-87bf-b26373b7d43e.jpg │ │ ├── cdbc38e5-5001-46f7-868c-76e9037299d8.jpg │ │ ├── dfa92819-9372-4519-be57-81995751f2e6.png │ │ ├── f513b5c8-32b8-4fdd-bad9-d59a19dd094f.42 │ │ └── logo.jpeg ├── inference │ ├── fish_inference.py │ └── inference_1000.py ├── nets │ ├── __init__.py │ ├── alexnet.py │ ├── alexnet_test.py │ ├── cifarnet.py │ ├── inception.py │ ├── inception_resnet_v2.py │ ├── inception_resnet_v2_test.py │ ├── inception_utils.py │ ├── inception_v1.py │ ├── inception_v1_test.py │ ├── inception_v2.py │ ├── inception_v2_test.py │ ├── inception_v3.py │ ├── inception_v3_test.py │ ├── inception_v4.py │ ├── inception_v4_test.py │ ├── lenet.py │ ├── nets_factory.py │ ├── nets_factory_test.py │ ├── overfeat.py │ ├── overfeat_test.py │ ├── resnet_utils.py │ ├── resnet_v1.py │ ├── resnet_v1_test.py │ ├── resnet_v2.py │ ├── resnet_v2_test.py │ ├── vgg.py │ └── vgg_test.py ├── preprocessing │ ├── __init__.py │ ├── inception_preprocessing.py │ └── preprocessing_factory.py ├── pretrain_model │ └── .gitkeep ├── requirements.txt ├── run_scripts │ ├── run.sh │ ├── run_all.sh │ ├── run_all_eval.sh │ └── run_eval.sh ├── tfrecords │ └── labels.txt ├── train │ ├── ALB │ │ └── img_00003.jpg │ ├── BET │ │ └── img_04557.jpg │ ├── DOL │ │ └── img_00951.jpg │ ├── LAG │ │ └── img_01644.jpg │ ├── NoF │ │ └── img_00028.jpg │ ├── OTHER │ │ └── img_00063.jpg │ ├── SHARK │ │ └── img_02176.jpg │ └── YFT │ │ └── img_00184.jpg ├── train_image_classifier.py └── uploads │ └── 1.png ├── images ├── acgan-fig-01.png ├── acgan-result-01.png ├── acgan-result.png ├── demo_result.png ├── flask_with_pretrain_model.png ├── flask_with_pretrain_model_00.png ├── infogan-fig-01.png ├── infogan-result-01.png ├── infogan-result.png ├── mac_blogs_deepcolor-01.png ├── mac_blogs_deepcolor-03.png ├── mltookit_log_00.png ├── mnist_client_result.png └── mnist_server.png ├── machinelearning_toolkit ├── README.md ├── scripts │ ├── linear_classifier.py │ ├── simple-tf-rf.py │ ├── simple-tf-svm.py │ ├── tf-rf.py │ ├── tf-svm.py │ └── tf_wide_deep.py └── wide_deep_scripts │ ├── linear_classifier.py │ ├── simple-tf-rf.py │ ├── simple-tf-svm.py │ ├── simple_tf_wide_deep.py │ ├── tf-rf.py │ ├── tf-svm.py │ └── tf_wide_deep.py ├── nlp ├── NMT │ ├── README.md │ └── scripts │ │ ├── __init__.py │ │ ├── config.py │ │ ├── dataset_helpers │ │ ├── Constants.py │ │ ├── Dict.py │ │ ├── IO.py │ │ ├── __init__.py │ │ ├── copus.py │ │ ├── preprocess.py │ │ ├── preprocess.sh │ │ └── test.log │ │ └── models │ │ └── models.py ├── Tag2Vec │ ├── README.md │ └── scripts │ │ ├── gen_w2v.py │ │ ├── tsne.png │ │ ├── visual_embeddings.py │ │ ├── visual_word2vec.py │ │ ├── word2vec.py │ │ ├── word2vec_basic.py │ │ └── word2vec_ops.so └── text_classifier │ ├── README.md │ ├── data │ └── origin_data │ │ └── sample.csv │ └── scripts │ ├── __init__.py │ ├── bow_text_classifier.py │ ├── cnn_lstm_text_classifier.py │ ├── cnn_text_classifier.py │ ├── cnn_text_classifier_v2.py │ ├── cnn_text_classifier_v3.py │ ├── config.py │ ├── dataset_helpers │ ├── __init__.py │ ├── cut_doc.py │ ├── doc_dataset.py │ └── gen_w2v.py │ ├── lstm_text_classifier.py │ └── tfidf_text_classifier.py ├── serving ├── READMD.md ├── checkpoint │ ├── checkpoint │ ├── checkpoint.ckpt │ └── checkpoint.ckpt.meta ├── generate_grpc_file.sh ├── mnist_client.py 
├── mnist_server.py ├── predict.proto ├── predict_pb2.py ├── predict_pb2.pyc └── train_mnist_softmax4serving.py ├── tf_upgrade.py └── zhihu_code ├── README.md └── src ├── gen_verification.py ├── samples ├── 亭猫那吉腊咖材年辜惯_3921.jpeg ├── 希巫瀑汝蓝渡靡旱淋迈_052931.jpeg ├── 慑嘻佬晚决幢聚色签品_90617.jpeg ├── 抨玩厘撕唇惧茸死氦徒_36410.jpeg ├── 毒气饺春类厂全编揉辅_39841.jpeg ├── 片蹬片藤袒俏酶讫垢管_803547.jpeg ├── 畸具岔狂硼瘁务廷撕刮_92634.jpeg ├── 算嫩曼掂羌屁解嫩历担_429637.jpeg ├── 茎搔逞沫彬培韶眠僧痛_675.jpeg └── 陇琴卉底仅铰嗽寒售澜_021346.jpeg └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | algorithm/mnist/ 2 | mnist_jpgs/ 3 | mnist_jpgs.zip 4 | .vscode/ 5 | *.pyc 6 | train 7 | .DS_Store 8 | inception_v3.ckpt 9 | GAN/*/asset/* 10 | data/images/*.png 11 | billiard_detection/data/images/*.png -------------------------------------------------------------------------------- /GAN/AC-GAN/README.md: -------------------------------------------------------------------------------- 1 | ## AC-GAN 2 | A TensorFlow implementation of AC-GAN. 3 | 4 | Reference: [sugartensor acgan](https://github.com/buriburisuri/acgan) 5 | ## AC-GAN Structure 6 | ![](../../images/acgan-fig-01.png) 7 | 8 | AC-GAN builds a GAN whose discriminator outputs not only the probability that the input is real or fake but also a distribution over class labels. 9 | [https://arxiv.org/pdf/1610.09585v3.pdf](https://arxiv.org/pdf/1610.09585v3.pdf) -------------------------------------------------------------------------------- /GAN/AC-GAN/discriminator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | from tensorflow.python.ops import variable_scope 4 | 5 | def leaky_relu(x): 6 | return tf.where(tf.greater(x, 0), x, 0.01 * x) 7 | 8 | def discriminator(tensor, num_category=10, batch_size=32, num_cont=2): 9 | """ 10 | """ 11 | 12 | reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0 13 | print reuse 14 | print tensor.get_shape() 15 | with variable_scope.variable_scope('discriminator', reuse=reuse): 16 | tensor = slim.conv2d(tensor, num_outputs = 64, kernel_size=[4,4], stride=2, activation_fn=leaky_relu) 17 | tensor = slim.conv2d(tensor, num_outputs=128, kernel_size=[4,4], stride=2, activation_fn=leaky_relu) 18 | tensor = slim.flatten(tensor) 19 | shared_tensor = slim.fully_connected(tensor, num_outputs=1024, activation_fn = leaky_relu) 20 | recog_shared = slim.fully_connected(shared_tensor, num_outputs=128, activation_fn = leaky_relu) 21 | disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=None) 22 | disc = tf.squeeze(disc, -1) 23 | recog_cat = slim.fully_connected(recog_shared, num_outputs=num_category, activation_fn=None) 24 | recog_cont = slim.fully_connected(recog_shared, num_outputs=num_cont, activation_fn=tf.nn.sigmoid) 25 | return disc, recog_cat, recog_cont 26 | -------------------------------------------------------------------------------- /GAN/AC-GAN/generate.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from generator import generator 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | import os 8 | 9 | _logger = tf.logging._logger 10 | _logger.setLevel(0) 11 | 12 | batch_size = 100 # batch size 13 | cat_dim = 10 # total categorical factor 14 | con_dim = 2 # total continuous factor 15 | rand_dim = 38 # total random latent dimension 16 | 17 | 
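# The three placeholders below control conditional sampling: the digit class to
# draw and the two continuous latent codes learned during training (on MNIST
# these often end up capturing factors such as stroke width or slant, though
# exactly which factor each code encodes is not fixed in advance).
# The latent vector z assembled below is one_hot(target_num) [cat_dim=10]
# ++ two continuous codes [con_dim=2] ++ Gaussian noise [rand_dim=38],
# i.e. 50 dimensions in total, matching the generator input used in train.py.
18 | target_num = 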
tf.placeholder(dtype=tf.int32, shape=batch_size) 19 | target_cval_1 = tf.placeholder(dtype=tf.float32, shape=batch_size) 20 | target_cval_2 = tf.placeholder(dtype=tf.float32, shape=batch_size) 21 | 22 | z = tf.one_hot(tf.ones(batch_size, dtype=tf.int32) * target_num, depth=cat_dim) 23 | z = tf.concat(axis=z.get_shape().ndims-1, values=[z, tf.expand_dims(target_cval_1, -1), tf.expand_dims(target_cval_2, -1)]) 24 | 25 | z = tf.concat(axis=z.get_shape().ndims-1, values=[z, tf.random_normal((batch_size, rand_dim))]) 26 | 27 | gen = tf.squeeze(generator(z), -1) 28 | 29 | def run_generator(num, x1, x2, fig_name='sample.png'): 30 | with tf.Session() as sess: 31 | sess.run(tf.group(tf.global_variables_initializer(), 32 | tf.local_variables_initializer())) 33 | saver = tf.train.Saver() 34 | saver.restore(sess, tf.train.latest_checkpoint('checkpoint_dir')) 35 | imgs = sess.run(gen, {target_num: num, target_cval_1: x1, target_cval_2:x2}) 36 | 37 | _, ax = plt.subplots(10,10, sharex=True, sharey=True) 38 | for i in range(10): 39 | for j in range(10): 40 | ax[i][j].imshow(imgs[i*10+j], 'gray') 41 | ax[i][j].set_axis_off() 42 | plt.savefig(os.path.join('result/',fig_name), dpi=600) 43 | print 'Sample image save to "result/{0}"'.format(fig_name) 44 | plt.close() 45 | 46 | a = np.random.randint(0, cat_dim, batch_size) 47 | print a 48 | run_generator(a, 49 | np.random.uniform(0, 1, batch_size), np.random.uniform(0, 1, batch_size), 50 | fig_name='fake.png') 51 | 52 | # classified image 53 | run_generator(np.arange(10).repeat(10), np.linspace(0, 1, 10).repeat(10), np.expand_dims(np.linspace(0, 1, 10), axis=1).repeat(10, axis=1).T.flatten(),) 54 | -------------------------------------------------------------------------------- /GAN/AC-GAN/generator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | from tensorflow.python.ops import variable_scope 4 | 5 | def generator(tensor): 6 | reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0 7 | print tensor.get_shape() 8 | with variable_scope.variable_scope('generator', reuse = reuse): 9 | tensor = slim.fully_connected(tensor, 1024) 10 | print tensor 11 | tensor = slim.batch_norm(tensor, activation_fn=tf.nn.relu) 12 | tensor = slim.fully_connected(tensor, 7*7*128) 13 | tensor = slim.batch_norm(tensor, activation_fn=tf.nn.relu) 14 | tensor = tf.reshape(tensor, [-1, 7, 7, 128]) 15 | # print '22',tensor.get_shape() 16 | tensor = slim.conv2d_transpose(tensor, 64, kernel_size=[4,4], stride=2, activation_fn = None) 17 | print 'gen',tensor.get_shape() 18 | tensor = slim.batch_norm(tensor, activation_fn = tf.nn.relu) 19 | tensor = slim.conv2d_transpose(tensor, 1, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.sigmoid) 20 | return tensor -------------------------------------------------------------------------------- /GAN/AC-GAN/mnist.py: -------------------------------------------------------------------------------- 1 | from tensorflow.examples.tutorials.mnist import input_data 2 | from utils import data_to_tensor, Opt 3 | 4 | 5 | 6 | 7 | 8 | class Mnist(object): 9 | r"""Downloads Mnist datasets and puts them in queues. 
10 | """ 11 | _data_dir = './asset/data/mnist' 12 | 13 | def __init__(self, batch_size=128, num_epochs=30, reshape=False, one_hot=False): 14 | 15 | # load sg_data set 16 | data_set = input_data.read_data_sets(Mnist._data_dir, reshape=reshape, one_hot=one_hot) 17 | 18 | self.batch_size = batch_size 19 | 20 | # save each sg_data set 21 | _train = data_set.train 22 | _valid = data_set.validation 23 | _test = data_set.test 24 | 25 | # member initialize 26 | self.train, self.valid, self.test = Opt(), Opt(), Opt() 27 | 28 | # convert to tensor queue 29 | self.train.image, self.train.label = \ 30 | data_to_tensor([_train.images, _train.labels.astype('int32')], batch_size, name='train') 31 | self.valid.image, self.valid.label = \ 32 | data_to_tensor([_valid.images, _valid.labels.astype('int32')], batch_size, name='valid') 33 | self.test.image, self.test.label = \ 34 | data_to_tensor([_test.images, _test.labels.astype('int32')], batch_size, name='test') 35 | 36 | # calc total batch count 37 | self.train.num_batch = _train.labels.shape[0] // batch_size 38 | self.valid.num_batch = _valid.labels.shape[0] // batch_size 39 | self.test.num_batch = _test.labels.shape[0] // batch_size 40 | -------------------------------------------------------------------------------- /GAN/AC-GAN/optimizer.py: -------------------------------------------------------------------------------- 1 | from utils import Opt 2 | import tensorflow as tf 3 | 4 | def optim(loss, **kwargs): 5 | r"""Applies gradients to variables. 6 | 7 | Args: 8 | loss: A 0-D `Tensor` containing the value to minimize. 9 | kwargs: 10 | optim: A name for optimizer. 'MaxProp' (default), 'AdaMax', 'Adam', or 'sgd'. 11 | lr: A Python Scalar (optional). Learning rate. Default is .001. 12 | beta1: A Python Scalar (optional). Default is .9. 13 | beta2: A Python Scalar (optional). Default is .99. 14 | category: A string or string list. Specifies the variables that should be trained (optional). 15 | Only if the name of a trainable variable starts with `category`, it's value is updated. 16 | Default is '', which means all trainable variables are updated. 
17 | """ 18 | opt = Opt(kwargs) 19 | # opt += Opt(optim='MaxProp', lr=0.001, beta1=0.9, beta2=0.99, category='') 20 | 21 | # default training options 22 | opt += Opt(optim='MaxProp', lr=0.001, beta1=0.9, beta2=0.99, category='') 23 | 24 | # select optimizer 25 | # if opt.optim == 'MaxProp': 26 | # optim = tf.sg_optimize.MaxPropOptimizer(learning_rate=opt.lr, beta2=opt.beta2) 27 | # elif opt.optim == 'AdaMax': 28 | # optim = tf.sg_optimize.AdaMaxOptimizer(learning_rate=opt.lr, beta1=opt.beta1, beta2=opt.beta2) 29 | # elif opt.optim == 'Adam': 30 | if opt.optim == 'Adm': 31 | optim = tf.train.AdamOptimizer(learning_rate=opt.lr, beta1=opt.beta1, beta2=opt.beta2) 32 | else: 33 | optim = tf.train.GradientDescentOptimizer(learning_rate=opt.lr) 34 | 35 | # get trainable variables 36 | if isinstance(opt.category, (tuple, list)): 37 | var_list = [] 38 | for cat in opt.category: 39 | var_list.extend([t for t in tf.trainable_variables() if t.name.startswith(cat)]) 40 | else: 41 | var_list = [t for t in tf.trainable_variables() if t.name.startswith(opt.category)] 42 | 43 | # calc gradient 44 | gradient = optim.compute_gradients(loss, var_list=var_list) 45 | 46 | # add summary 47 | for v, g in zip(var_list, gradient): 48 | # exclude batch normal statics 49 | if 'mean' not in v.name and 'variance' not in v.name \ 50 | and 'beta' not in v.name and 'gamma' not in v.name: 51 | prefix = '' 52 | # summary name 53 | name = prefix + ''.join(v.name.split(':')[:-1]) 54 | # summary statistics 55 | # noinspection PyBroadException 56 | try: 57 | tf.summary.scalar(name + '/grad', tf.global_norm([g])) 58 | tf.summary.histogram(name + '/grad-h', g) 59 | except: 60 | pass 61 | global_step = tf.Variable(0, name='global_step', trainable=False) 62 | # gradient update op 63 | return optim.apply_gradients(gradient, global_step=global_step), global_step -------------------------------------------------------------------------------- /GAN/AC-GAN/queue_context.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def queue_context(sess=None): 4 | r"""Context helper for queue routines. 5 | 6 | Args: 7 | sess: A session to open queues. If not specified, a new session is created. 
8 | 9 | Returns: 10 | A `(coordinator, threads)` tuple. 11 | """ 12 | 13 | # default session 14 | sess = tf.get_default_session() if sess is None else sess 15 | 16 | # thread coordinator 17 | coord = tf.train.Coordinator() 18 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 19 | return coord, threads -------------------------------------------------------------------------------- /GAN/AC-GAN/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import tensorflow as tf 3 | import numpy as np 4 | import logging 5 | from mnist import Mnist 6 | import tensorflow.contrib.slim as slim 7 | from generator import generator 8 | from discriminator import discriminator 9 | from optimizer import optim 10 | from queue_context import queue_context 11 | import os 12 | import sys 13 | 14 | 15 | tf.logging.set_verbosity(0) 16 | 17 | 18 | # 19 | # hyper parameters 20 | # 21 | 22 | batch_size = 32 # batch size 23 | cat_dim = 10 # total categorical factor 24 | con_dim = 2 # total continuous factor 25 | rand_dim = 38 26 | num_epochs = 30 27 | debug_max_steps = 1000 28 | save_epoch = 5 29 | max_epochs = 50 30 | 31 | # 32 | # inputs 33 | # 34 | 35 | # MNIST input tensor ( with QueueRunner ) 36 | data = Mnist(batch_size=batch_size, num_epochs=num_epochs) 37 | num_batch_per_epoch = data.train.num_batch 38 | 39 | 40 | # input images and labels 41 | x = data.train.image 42 | y = data.train.label 43 | 44 | # labels for discriminator 45 | y_real = tf.ones(batch_size) 46 | y_fake = tf.zeros(batch_size) 47 | 48 | 49 | # discriminator labels ( half 1s, half 0s ) 50 | y_disc = tf.concat(axis=0, values=[y, y * 0]) 51 | 52 | # 53 | # create generator 54 | # 55 | 56 | # get random class number 57 | if(int(tf.__version__.split(".")[1])<13 and int(tf.__version__.split(".")[0])<2): ### tf version < 1.13 58 | z_cat = tf.multinomial(tf.ones((batch_size, cat_dim), dtype=tf.float32) / cat_dim, 1) 59 | else: ### tf version >= 1.13 60 | z_cat = tf.random.categorical(tf.ones((batch_size, cat_dim), dtype=tf.float32) / cat_dim, 1) 61 | 62 | z_cat = tf.squeeze(z_cat, -1) 63 | z_cat = tf.cast(z_cat, tf.int32) 64 | 65 | # continuous latent variable 66 | z_con = tf.random_normal((batch_size, con_dim)) 67 | z_rand = tf.random_normal((batch_size, rand_dim)) 68 | 69 | z = tf.concat(axis=1, values=[tf.one_hot(z_cat, depth = cat_dim), z_con, z_rand]) 70 | 71 | 72 | # generator network 73 | gen = generator(z) 74 | 75 | # add image summary 76 | # tf.sg_summary_image(gen) 77 | tf.summary.image('real', x) 78 | tf.summary.image('fake', gen) 79 | 80 | # 81 | # discriminator 82 | disc_real, cat_real, _ = discriminator(x) 83 | disc_fake, cat_fake, con_fake = discriminator(gen) 84 | 85 | # discriminator loss 86 | loss_d_r = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=y_real)) 87 | loss_d_f = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_fake)) 88 | loss_d = (loss_d_r + loss_d_f) / 2 89 | print 'loss_d', loss_d.get_shape() 90 | # generator loss 91 | loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_real)) 92 | 93 | # categorical factor loss 94 | loss_c_r = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=cat_real, labels=y)) 95 | loss_c_d = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=cat_fake, labels=z_cat)) 96 | loss_c = (loss_c_r + loss_c_d) / 2 97 | print 'loss_c', loss_c.get_shape() 98 | # continuous factor loss 
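# The MSE between the codes recovered by the discriminator's auxiliary head
# (con_fake) and the codes fed to the generator (z_con) acts, in effect, as the
# InfoGAN-style mutual-information term for the continuous factors: under a
# fixed-variance Gaussian posterior, the variational lower bound on the mutual
# information reduces (up to constants) to this squared-error reconstruction.
99 | loss_con 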
=tf.reduce_mean(tf.square(con_fake-z_con)) 100 | print 'loss_con', loss_con.get_shape() 101 | 102 | 103 | 104 | train_disc, disc_global_step = optim(loss_d + loss_c + loss_con, lr=0.0001, optim = 'Adm', category='discriminator') 105 | train_gen, gen_global_step = optim(loss_g + loss_c + loss_con, lr=0.001, optim = 'Adm', category='generator') 106 | init = tf.global_variables_initializer() 107 | saver = tf.train.Saver() 108 | print train_gen 109 | 110 | cur_epoch = 0 111 | cur_step = 0 112 | 113 | 114 | with tf.Session() as sess: 115 | sess.run(init) 116 | coord, threads = queue_context(sess) 117 | try: 118 | while not coord.should_stop(): 119 | cur_step += 1 120 | dis_part = cur_step*1.0/num_batch_per_epoch 121 | dis_part = int(dis_part*50) 122 | sys.stdout.write("process bar ::|"+"<"* dis_part+'|'+str(cur_step*1.0/num_batch_per_epoch*100)+'%'+'\r') 123 | sys.stdout.flush() 124 | l_disc, _, l_d_step = sess.run([loss_d, train_disc, disc_global_step]) 125 | l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step]) 126 | last_epoch = cur_epoch 127 | cur_epoch = l_d_step / num_batch_per_epoch 128 | if cur_epoch > max_epochs: 129 | break 130 | 131 | if cur_epoch> last_epoch: 132 | cur_step = 0 133 | print 'cur epoch {0} update l_d step {1}, loss_disc {2}, loss_gen {3}'.format(cur_epoch, l_d_step, l_disc, l_gen) 134 | if cur_epoch % save_epoch == 0: 135 | # save 136 | saver.save(sess, os.path.join('./checkpoint_dir', 'ac_gan'), global_step=l_d_step) 137 | except tf.errors.OutOfRangeError: 138 | print 'Train Finished' 139 | finally: 140 | coord.request_stop() 141 | -------------------------------------------------------------------------------- /GAN/AC-GAN/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import six 3 | import tensorflow as tf 4 | import sys 5 | 6 | 7 | 8 | 9 | class Opt(collections.MutableMapping): 10 | r"""Option utility class. 11 | 12 | This class is only internally used for sg_opt. 13 | """ 14 | 15 | def __init__(self, *args, **kwargs): 16 | self.__dict__.update(*args, **kwargs) 17 | 18 | def __setitem__(self, key, value): 19 | self.__dict__[key] = value 20 | 21 | def __getitem__(self, key): 22 | return self.__dict__[key] 23 | 24 | def __delitem__(self, key): 25 | del self.__dict__[key] 26 | 27 | # noinspection PyUnusedLocal,PyUnusedLocal 28 | def __getattr__(self, key): 29 | return None 30 | 31 | def __iter__(self): 32 | return iter(self.__dict__) 33 | 34 | def __len__(self): 35 | return len(self.__dict__) 36 | 37 | def __str__(self): 38 | return str(self.__dict__) 39 | 40 | def __repr__(self): 41 | return self.__dict__.__repr__() 42 | 43 | def __add__(self, other): 44 | r"""Overloads `+` operator. 45 | 46 | It does NOT overwrite the existing item. 47 | 48 | For example, 49 | 50 | ```python 51 | import sugartensor as tf 52 | 53 | opt = tf.sg_opt(size=1) 54 | opt += tf.sg_opt(size=2) 55 | print(opt) # Should be {'size': 1} 56 | ``` 57 | """ 58 | res = Opt(self.__dict__) 59 | for k, v in six.iteritems(other): 60 | if k not in res.__dict__ or res.__dict__[k] is None: 61 | res.__dict__[k] = v 62 | return res 63 | 64 | def __mul__(self, other): 65 | r"""Overloads `*` operator. 66 | 67 | It overwrites the existing item. 
68 | 69 | For example, 70 | 71 | ```python 72 | import sugartensor as tf 73 | 74 | opt = tf.sg_opt(size=1) 75 | opt *= tf.sg_opt(size=2) 76 | print(opt) # Should be {'size': 2} 77 | ``` 78 | """ 79 | res = Opt(self.__dict__) 80 | for k, v in six.iteritems(other): 81 | res.__dict__[k] = v 82 | return res 83 | 84 | 85 | def data_to_tensor(data_list, batch_size, name=None): 86 | r"""Returns batch queues from the whole data. 87 | 88 | Args: 89 | data_list: A list of ndarrays. Every array must have the same size in the first dimension. 90 | batch_size: An integer. 91 | name: A name for the operations (optional). 92 | 93 | Returns: 94 | A list of tensors of `batch_size`. 95 | """ 96 | # convert to constant tensor 97 | const_list = [tf.constant(data) for data in data_list] 98 | 99 | # create queue from constant tensor 100 | queue_list = tf.train.slice_input_producer(const_list, capacity=batch_size*128, name=name) 101 | 102 | # create batch queue 103 | return tf.train.shuffle_batch(queue_list, batch_size, capacity=batch_size*128, 104 | min_after_dequeue=batch_size*32, name=name) 105 | -------------------------------------------------------------------------------- /GAN/Info-GAN/README.md: -------------------------------------------------------------------------------- 1 | ## Info-GAN 2 | A TensorFlow implementation of Info-GAN. 3 | 4 | Reference: [sugartensor](https://github.com/buriburisuri/sugartensor) 5 | ## Info-GAN Structure 6 | ![](../../images/infogan-fig-01.png) 7 | 8 | Info-GAN adds a loss term built on the mutual information between the latent code and the generator output. 9 | [https://arxiv.org/pdf/1606.03657v1.pdf](https://arxiv.org/pdf/1606.03657v1.pdf) 10 | 11 | -------------------------------------------------------------------------------- /GAN/Info-GAN/discriminator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | from tensorflow.python.ops import variable_scope 4 | 5 | def leaky_relu(x): 6 | return tf.where(tf.greater(x, 0), x, 0.01 * x) 7 | 8 | def discriminator(tensor, num_category=10, batch_size=32, num_cont=2): 9 | """ 10 | """ 11 | 12 | reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0 13 | print reuse 14 | print tensor.get_shape() 15 | with variable_scope.variable_scope('discriminator', reuse=reuse): 16 | tensor = slim.conv2d(tensor, num_outputs = 64, kernel_size=[4,4], stride=2, activation_fn=leaky_relu) 17 | tensor = slim.conv2d(tensor, num_outputs=128, kernel_size=[4,4], stride=2, activation_fn=leaky_relu) 18 | tensor = slim.flatten(tensor) 19 | shared_tensor = slim.fully_connected(tensor, num_outputs=1024, activation_fn = leaky_relu) 20 | recog_shared = slim.fully_connected(shared_tensor, num_outputs=128, activation_fn = leaky_relu) 21 | disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=None) 22 | disc = tf.squeeze(disc, -1) 23 | recog_cat = slim.fully_connected(recog_shared, num_outputs=num_category, activation_fn=None) 24 | recog_cont = slim.fully_connected(recog_shared, num_outputs=num_cont, activation_fn=tf.nn.sigmoid) 25 | return disc, recog_cat, recog_cont 26 | -------------------------------------------------------------------------------- /GAN/Info-GAN/generate.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from generator import generator 4 | import matplotlib 5 | matplotlib.use('Agg')
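# 'Agg' is a non-interactive matplotlib backend: it renders to files only, so
# savefig() works on headless machines with no display server. It has to be
# selected before pyplot is imported for the first time, as done here.
6 | import matplotlib.pyplot 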
as plt 7 | import os 8 | 9 | _logger = tf.logging._logger 10 | _logger.setLevel(0) 11 | 12 | batch_size = 100 # batch size 13 | cat_dim = 10 # total categorical factor 14 | con_dim = 2 # total continuous factor 15 | rand_dim = 38 # total random latent dimension 16 | 17 | 18 | target_num = tf.placeholder(dtype=tf.int32, shape=batch_size) 19 | target_cval_1 = tf.placeholder(dtype=tf.float32, shape=batch_size) 20 | target_cval_2 = tf.placeholder(dtype=tf.float32, shape=batch_size) 21 | 22 | z = tf.one_hot(tf.ones(batch_size, dtype=tf.int32) * target_num, depth=cat_dim) 23 | z = tf.concat(axis=z.get_shape().ndims-1, values=[z, tf.expand_dims(target_cval_1, -1), tf.expand_dims(target_cval_2, -1)]) 24 | 25 | z = tf.concat(axis=z.get_shape().ndims-1, values=[z, tf.random_normal((batch_size, rand_dim))]) 26 | 27 | gen = tf.squeeze(generator(z), -1) 28 | 29 | def run_generator(num, x1, x2, fig_name='sample.png'): 30 | with tf.Session() as sess: 31 | try: 32 | sess.run(tf.group(tf.global_variables_initializer(), 33 | tf.local_variables_initializer())) 34 | except AttributeError: 35 | sess.run(tf.group(tf.initialize_all_variables(), 36 | tf.initialize_local_variables())) 37 | saver = tf.train.Saver() 38 | saver.restore(sess, tf.train.latest_checkpoint('checkpoint_dir')) 39 | imgs = sess.run(gen, {target_num: num, target_cval_1: x1, target_cval_2:x2}) 40 | 41 | _, ax = plt.subplots(10,10, sharex=True, sharey=True) 42 | for i in range(10): 43 | for j in range(10): 44 | ax[i][j].imshow(imgs[i*10+j], 'gray') 45 | ax[i][j].set_axis_off() 46 | plt.savefig(os.path.join('result/',fig_name), dpi=600) 47 | print 'Sample image save to "result/{0}"'.format(fig_name) 48 | plt.close() 49 | 50 | a = np.random.randint(0, cat_dim, batch_size) 51 | print a 52 | run_generator(a, 53 | np.random.uniform(0, 1, batch_size), np.random.uniform(0, 1, batch_size), 54 | fig_name='fake.png') 55 | 56 | # classified image 57 | run_generator(np.arange(10).repeat(10), np.linspace(0, 1, 10).repeat(10), np.expand_dims(np.linspace(0, 1, 10), axis=1).repeat(10, axis=1).T.flatten(),) 58 | -------------------------------------------------------------------------------- /GAN/Info-GAN/generator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | from tensorflow.python.ops import variable_scope 4 | 5 | def generator(tensor): 6 | reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0 7 | print tensor.get_shape() 8 | with variable_scope.variable_scope('generator', reuse = reuse): 9 | tensor = slim.fully_connected(tensor, 1024) 10 | print tensor 11 | tensor = slim.batch_norm(tensor, activation_fn=tf.nn.relu) 12 | tensor = slim.fully_connected(tensor, 7*7*128) 13 | tensor = slim.batch_norm(tensor, activation_fn=tf.nn.relu) 14 | tensor = tf.reshape(tensor, [-1, 7, 7, 128]) 15 | # print '22',tensor.get_shape() 16 | tensor = slim.conv2d_transpose(tensor, 64, kernel_size=[4,4], stride=2, activation_fn = None) 17 | print 'gen',tensor.get_shape() 18 | tensor = slim.batch_norm(tensor, activation_fn = tf.nn.relu) 19 | tensor = slim.conv2d_transpose(tensor, 1, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.sigmoid) 20 | return tensor -------------------------------------------------------------------------------- /GAN/Info-GAN/mnist.py: -------------------------------------------------------------------------------- 1 | from tensorflow.examples.tutorials.mnist import input_data 2 | from utils import data_to_tensor, 
Opt 3 | 4 | 5 | 6 | 7 | 8 | class Mnist(object): 9 | r"""Downloads Mnist datasets and puts them in queues. 10 | """ 11 | _data_dir = './asset/data/mnist' 12 | 13 | def __init__(self, batch_size=128, num_epochs=30, reshape=False, one_hot=False): 14 | 15 | # load sg_data set 16 | data_set = input_data.read_data_sets(Mnist._data_dir, reshape=reshape, one_hot=one_hot) 17 | 18 | self.batch_size = batch_size 19 | 20 | # save each sg_data set 21 | _train = data_set.train 22 | _valid = data_set.validation 23 | _test = data_set.test 24 | 25 | # member initialize 26 | self.train, self.valid, self.test = Opt(), Opt(), Opt() 27 | 28 | # convert to tensor queue 29 | self.train.image, self.train.label = \ 30 | data_to_tensor([_train.images, _train.labels.astype('int32')], batch_size, name='train') 31 | self.valid.image, self.valid.label = \ 32 | data_to_tensor([_valid.images, _valid.labels.astype('int32')], batch_size, name='valid') 33 | self.test.image, self.test.label = \ 34 | data_to_tensor([_test.images, _test.labels.astype('int32')], batch_size, name='test') 35 | 36 | # calc total batch count 37 | self.train.num_batch = _train.labels.shape[0] // batch_size 38 | self.valid.num_batch = _valid.labels.shape[0] // batch_size 39 | self.test.num_batch = _test.labels.shape[0] // batch_size 40 | -------------------------------------------------------------------------------- /GAN/Info-GAN/optimizer.py: -------------------------------------------------------------------------------- 1 | from utils import Opt 2 | import tensorflow as tf 3 | 4 | def optim(loss, **kwargs): 5 | r"""Applies gradients to variables. 6 | 7 | Args: 8 | loss: A 0-D `Tensor` containing the value to minimize. 9 | kwargs: 10 | optim: A name for optimizer. 'MaxProp' (default), 'AdaMax', 'Adam', or 'sgd'. 11 | lr: A Python Scalar (optional). Learning rate. Default is .001. 12 | beta1: A Python Scalar (optional). Default is .9. 13 | beta2: A Python Scalar (optional). Default is .99. 14 | category: A string or string list. Specifies the variables that should be trained (optional). 15 | Only if the name of a trainable variable starts with `category`, it's value is updated. 16 | Default is '', which means all trainable variables are updated. 
17 | """ 18 | opt = Opt(kwargs) 19 | # opt += Opt(optim='MaxProp', lr=0.001, beta1=0.9, beta2=0.99, category='') 20 | 21 | # default training options 22 | opt += Opt(optim='MaxProp', lr=0.001, beta1=0.9, beta2=0.99, category='') 23 | 24 | # select optimizer 25 | # if opt.optim == 'MaxProp': 26 | # optim = tf.sg_optimize.MaxPropOptimizer(learning_rate=opt.lr, beta2=opt.beta2) 27 | # elif opt.optim == 'AdaMax': 28 | # optim = tf.sg_optimize.AdaMaxOptimizer(learning_rate=opt.lr, beta1=opt.beta1, beta2=opt.beta2) 29 | # elif opt.optim == 'Adam': 30 | if opt.optim == 'Adm': 31 | optim = tf.train.AdamOptimizer(learning_rate=opt.lr, beta1=opt.beta1, beta2=opt.beta2) 32 | else: 33 | optim = tf.train.GradientDescentOptimizer(learning_rate=opt.lr) 34 | 35 | # get trainable variables 36 | if isinstance(opt.category, (tuple, list)): 37 | var_list = [] 38 | for cat in opt.category: 39 | var_list.extend([t for t in tf.trainable_variables() if t.name.startswith(cat)]) 40 | else: 41 | var_list = [t for t in tf.trainable_variables() if t.name.startswith(opt.category)] 42 | 43 | # calc gradient 44 | gradient = optim.compute_gradients(loss, var_list=var_list) 45 | 46 | # add summary 47 | for v, g in zip(var_list, gradient): 48 | # exclude batch normal statics 49 | if 'mean' not in v.name and 'variance' not in v.name \ 50 | and 'beta' not in v.name and 'gamma' not in v.name: 51 | prefix = '' 52 | # summary name 53 | name = prefix + ''.join(v.name.split(':')[:-1]) 54 | # summary statistics 55 | # noinspection PyBroadException 56 | try: 57 | tf.summary.scalar(name + '/grad', tf.global_norm([g])) 58 | tf.summary.histogram(name + '/grad-h', g) 59 | except AttributeError: 60 | tf.scalar_summary(name + '/grad', tf.global_norm([g])) 61 | tf.histogram_summary(name + '/grad-h', g) 62 | except: 63 | pass 64 | global_step = tf.Variable(0, name='global_step', trainable=False) 65 | # gradient update op 66 | return optim.apply_gradients(gradient, global_step=global_step), global_step -------------------------------------------------------------------------------- /GAN/Info-GAN/queue_context.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def queue_context(sess=None): 4 | r"""Context helper for queue routines. 5 | 6 | Args: 7 | sess: A session to open queues. If not specified, a new session is created. 
8 | 9 | Returns: 10 | None 11 | """ 12 | 13 | # default session 14 | sess = tf.get_default_session() if sess is None else sess 15 | 16 | # thread coordinator 17 | coord = tf.train.Coordinator() 18 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 19 | return coord, threads -------------------------------------------------------------------------------- /GAN/Info-GAN/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import tensorflow as tf 3 | import numpy as np 4 | import logging 5 | from mnist import Mnist 6 | import tensorflow.contrib.slim as slim 7 | from generator import generator 8 | from discriminator import discriminator 9 | from optimizer import optim 10 | from queue_context import queue_context 11 | import os 12 | import sys 13 | 14 | 15 | tf.logging.set_verbosity(0) 16 | 17 | 18 | # 19 | # hyper parameters 20 | # 21 | 22 | batch_size = 32 # batch size 23 | cat_dim = 10 # total categorical factor 24 | con_dim = 2 # total continuous factor 25 | rand_dim = 38 26 | num_epochs = 30 27 | debug_max_steps = 1000 28 | save_epoch = 5 29 | max_epochs = 50 30 | 31 | # 32 | # inputs 33 | # 34 | 35 | # MNIST input tensor ( with QueueRunner ) 36 | data = Mnist(batch_size=batch_size, num_epochs=num_epochs) 37 | num_batch_per_epoch = data.train.num_batch 38 | 39 | 40 | # input images and labels 41 | x = data.train.image 42 | y = data.train.label 43 | 44 | # labels for discriminator 45 | y_real = tf.ones(batch_size) 46 | y_fake = tf.zeros(batch_size) 47 | 48 | 49 | # discriminator labels ( half 1s, half 0s ) 50 | y_disc = tf.concat(axis=0, values=[y, y * 0]) 51 | 52 | # 53 | # create generator 54 | # 55 | 56 | # get random class number 57 | if(int(tf.__version__.split(".")[1])<13 and int(tf.__version__.split(".")[0])<2): ### tf version < 1.13 58 | z_cat = tf.multinomial(tf.ones((batch_size, cat_dim), dtype=tf.float32) / cat_dim, 1) 59 | else: ### tf version >= 1.13 60 | z_cat = tf.random.categorical(tf.ones((batch_size, cat_dim), dtype=tf.float32) / cat_dim, 1) 61 | 62 | z_cat = tf.squeeze(z_cat, -1) 63 | z_cat = tf.cast(z_cat, tf.int32) 64 | 65 | # continuous latent variable 66 | z_con = tf.random_normal((batch_size, con_dim)) 67 | z_rand = tf.random_normal((batch_size, rand_dim)) 68 | 69 | z = tf.concat(axis=1, values=[tf.one_hot(z_cat, depth = cat_dim), z_con, z_rand]) 70 | 71 | 72 | # generator network 73 | gen = generator(z) 74 | 75 | # add image summary 76 | # tf.sg_summary_image(gen) 77 | try: 78 | tf.summary.image('real', x) 79 | tf.summary.image('fake', gen) 80 | except AttributeError: 81 | tf.image_summary('real', x) 82 | tf.image_summary('fake', gen) 83 | 84 | # 85 | # discriminator 86 | disc_real, _, _ = discriminator(x) 87 | disc_fake, cat_fake, con_fake = discriminator(gen) 88 | 89 | # discriminator loss 90 | loss_d_r = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=y_real)) 91 | loss_d_f = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_fake)) 92 | loss_d = (loss_d_r + loss_d_f) / 2 93 | print 'loss_d', loss_d.get_shape() 94 | # generator loss 95 | loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_real)) 96 | 97 | # categorical factor loss 98 | loss_c = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=cat_fake, labels=z_cat)) 99 | 100 | # continuous factor loss 101 | loss_con =tf.reduce_mean(tf.square(con_fake-z_con)) 102 | 103 | train_disc, disc_global_step = 
optim(loss_d + loss_c + loss_con, lr=0.0001, optim = 'Adm', category='discriminator') 104 | train_gen, gen_global_step = optim(loss_g + loss_c + loss_con, lr=0.001, optim = 'Adm', category='generator') 105 | try: 106 | init = tf.global_variables_initializer() 107 | except AttributeError: 108 | init = tf.initialize_all_variables() 109 | saver = tf.train.Saver() 110 | print train_gen 111 | 112 | cur_epoch = 0 113 | cur_step = 0 114 | 115 | 116 | with tf.Session() as sess: 117 | sess.run(init) 118 | coord, threads = queue_context(sess) 119 | try: 120 | while not coord.should_stop(): 121 | cur_step += 1 122 | dis_part = cur_step*1.0/num_batch_per_epoch 123 | dis_part = int(dis_part*50) 124 | sys.stdout.write("process bar ::|"+"<"* dis_part+'|'+str(cur_step*1.0/num_batch_per_epoch*100)+'%'+'\r') 125 | sys.stdout.flush() 126 | l_disc, _, l_d_step = sess.run([loss_d, train_disc, disc_global_step]) 127 | l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step]) 128 | last_epoch = cur_epoch 129 | cur_epoch = l_d_step / num_batch_per_epoch 130 | if cur_epoch > max_epochs: 131 | break 132 | 133 | if cur_epoch> last_epoch: 134 | cur_step = 0 135 | print 'cur epoch {0} update l_d step {1}, loss_disc {2}, loss_gen {3}'.format(cur_epoch, l_d_step, l_disc, l_gen) 136 | if cur_epoch % save_epoch == 0: 137 | # save 138 | saver.save(sess, os.path.join('./checkpoint_dir', 'ac_gan'), global_step=l_d_step) 139 | except tf.errors.OutOfRangeError: 140 | print 'Train Finished' 141 | finally: 142 | coord.request_stop() 143 | -------------------------------------------------------------------------------- /GAN/Info-GAN/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import six 3 | import tensorflow as tf 4 | import sys 5 | 6 | 7 | 8 | 9 | class Opt(collections.MutableMapping): 10 | r"""Option utility class. 11 | 12 | This class is only internally used for sg_opt. 13 | """ 14 | 15 | def __init__(self, *args, **kwargs): 16 | self.__dict__.update(*args, **kwargs) 17 | 18 | def __setitem__(self, key, value): 19 | self.__dict__[key] = value 20 | 21 | def __getitem__(self, key): 22 | return self.__dict__[key] 23 | 24 | def __delitem__(self, key): 25 | del self.__dict__[key] 26 | 27 | # noinspection PyUnusedLocal,PyUnusedLocal 28 | def __getattr__(self, key): 29 | return None 30 | 31 | def __iter__(self): 32 | return iter(self.__dict__) 33 | 34 | def __len__(self): 35 | return len(self.__dict__) 36 | 37 | def __str__(self): 38 | return str(self.__dict__) 39 | 40 | def __repr__(self): 41 | return self.__dict__.__repr__() 42 | 43 | def __add__(self, other): 44 | r"""Overloads `+` operator. 45 | 46 | It does NOT overwrite the existing item. 47 | 48 | For example, 49 | 50 | ```python 51 | import sugartensor as tf 52 | 53 | opt = tf.sg_opt(size=1) 54 | opt += tf.sg_opt(size=2) 55 | print(opt) # Should be {'size': 1} 56 | ``` 57 | """ 58 | res = Opt(self.__dict__) 59 | for k, v in six.iteritems(other): 60 | if k not in res.__dict__ or res.__dict__[k] is None: 61 | res.__dict__[k] = v 62 | return res 63 | 64 | def __mul__(self, other): 65 | r"""Overloads `*` operator. 66 | 67 | It overwrites the existing item. 
68 | 69 | For example, 70 | 71 | ```python 72 | import sugartensor as tf 73 | 74 | opt = tf.sg_opt(size=1) 75 | opt *= tf.sg_opt(size=2) 76 | print(opt) # Should be {'size': 2} 77 | ``` 78 | """ 79 | res = Opt(self.__dict__) 80 | for k, v in six.iteritems(other): 81 | res.__dict__[k] = v 82 | return res 83 | 84 | 85 | def data_to_tensor(data_list, batch_size, name=None): 86 | r"""Returns batch queues from the whole data. 87 | 88 | Args: 89 | data_list: A list of ndarrays. Every array must have the same size in the first dimension. 90 | batch_size: An integer. 91 | name: A name for the operations (optional). 92 | 93 | Returns: 94 | A list of tensors of `batch_size`. 95 | """ 96 | # convert to constant tensor 97 | const_list = [tf.constant(data) for data in data_list] 98 | 99 | # create queue from constant tensor 100 | queue_list = tf.train.slice_input_producer(const_list, capacity=batch_size*128, name=name) 101 | 102 | # create batch queue 103 | return tf.train.shuffle_batch(queue_list, batch_size, capacity=batch_size*128, 104 | min_after_dequeue=batch_size*32, name=name) 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## algorithm 2 | 3 | MLP for mnist done 4 | 5 | ## convert to tfrecord 6 | 7 | test on 102flowers done 8 | command like this: 9 | 10 | `python covert_somedata_to_tfrecord.py --dataset_name=102flowers --dataset_dir='./'` 11 | 12 | ## serving 13 | test on MLP model on mnist done 14 | 15 | download the mnist jpgs from baidu yunpan [mnist jpgs](https://pan.baidu.com/s/1o8EWkVS) 16 | 17 | ![](./images/mnist_server.png) 18 | ![](./images/mnist_client_result.png) 19 | 20 | 21 | 22 | 23 | Reference: [https://github.com/tobegit3hub/deep_recommend_system](https://github.com/tobegit3hub/deep_recommend_system) 24 | 25 | 26 | ## finetuning and deploying a model with flask 27 | 28 | ![](./images/flask_with_pretrain_model.png) 29 | ![](./images/flask_with_pretrain_model_00.png) 30 | 31 | In the finetuning folder, we use tf.slim to finetune a pretrained model (I use the same method in my porn-detection work) and use Flask to build a very simple inference system. 32 | 33 | 34 | 35 | ## Inference Demo 36 | ![](./images/demo_result.png) 37 | I deployed an image classification demo at the [demo page](http://demo.duanshishi.com). It is based on TensorFlow and Flask. Feel free to try it. 38 | 39 | ## Chinese rec 40 | 41 | ![](./images/chinese_rec_example.png) 42 | 43 | You can find a detailed introduction in [TensorFlow与中文手写汉字识别](http://hacker.duanshishi.com/?p=1753) 44 | 45 | ## GAN 46 | ### AC-GAN and Info-GAN in TensorFlow 47 | The GAN folder includes AC-GAN and Info-GAN: 48 | 49 | - AC-GAN ![](./images/acgan-fig-01.png) 50 | - Info-GAN ![](./images/infogan-fig-01.png) 51 | 52 | ### Result 53 | Info-GAN result: 54 | 55 | ![](./images/infogan-result.png) 56 | ![](./images/infogan-result-01.png) 57 | AC-GAN result: 58 | ![](./images/acgan-result.png) 59 | ![](./images/acgan-result-01.png) 60 | 61 | You can get detailed information in [GAN的理解与TF的实现](http://hacker.duanshishi.com/?p=1766) 62 | 63 | 
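Both trainers share the same three-part objective; the sketch below condenses it from `GAN/AC-GAN/train.py` and `GAN/Info-GAN/train.py` (the `gan_losses` wrapper is added here for illustration and is not part of the repo):

```python
import tensorflow as tf

def gan_losses(disc_real, disc_fake, cat_fake, z_cat, con_fake, z_con):
    # adversarial terms (sigmoid cross-entropy on the real/fake logit)
    y_real = tf.ones_like(disc_real)
    y_fake = tf.zeros_like(disc_fake)
    loss_d = (tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=y_real))
              + tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_fake))) / 2
    loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_real))
    # auxiliary heads: categorical code (cross-entropy) and continuous codes (MSE)
    loss_c = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=cat_fake, labels=z_cat))
    loss_con = tf.reduce_mean(tf.square(con_fake - z_con))
    # both networks are trained with the auxiliary losses added on top
    return loss_d + loss_c + loss_con, loss_g + loss_c + loss_con
```

AC-GAN additionally averages the categorical loss over real samples (`cat_real` against the true labels); see `GAN/AC-GAN/train.py` for that variant.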
64 | ## DeepColor 65 | 66 | The images come from [safebooru.org](http://safebooru.org); you can download them from baidu yunpan: [https://pan.baidu.com/s/1c1HOIHU](https://pan.baidu.com/s/1c1HOIHU). 67 | 68 | ![](./images/mac_blogs_deepcolor-01.png) 69 | 70 | The results: 71 | ![](./images/mac_blogs_deepcolor-03.png) 72 | 73 | 74 | Details in [CGAN之deepcolor实践](http://www.duanshishi.com/?p=1791) 75 | 76 | 77 | ## Text Classifier 78 | 79 | You can get detailed information in [自然语言处理第一番之文本分类器](http://hacker.duanshishi.com/?p=1805). 80 | Classic methods such as word-count distributions and TF-IDF, as well as deep learning methods based on CNN and C-LSTM, are provided in [nlp/text_classifier](https://github.com/burness/tensorflow-101/tree/master/nlp/text_classifier) -------------------------------------------------------------------------------- /algorithm/README.md: -------------------------------------------------------------------------------- 1 | ## train_mnist_softmax.py 2 | steps: 10000, acc: 0.9146 3 | 4 | ## train_mnist_single_perceptron.py 5 | steps: 10000, acc: 0.9427 6 | add one more perceptron layer in train_mnist_single_perceptron.py: 7 | steps: 10000, acc: 0.9 8 | 9 | -------------------------------------------------------------------------------- /algorithm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/algorithm/__init__.py -------------------------------------------------------------------------------- /algorithm/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/algorithm/__init__.pyc -------------------------------------------------------------------------------- /algorithm/input_data.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/algorithm/input_data.pyc -------------------------------------------------------------------------------- /algorithm/train_mnist_multi_perceptron.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import sys 7 | from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets 8 | import tensorflow as tf 9 | import time 10 | 11 | FLAGS = tf.app.flags.FLAGS 12 | 13 | tf.flags.DEFINE_string("data_dir", "./mnist", 14 | "mnist data_dir") 15 | 16 | 17 | def main(_): 18 | mnist = read_data_sets(FLAGS.data_dir, one_hot=True) 19 | x = tf.placeholder(tf.float32, [None, 784]) 20 | W1 = tf.Variable(tf.random_normal([784, 256])) 21 | b1 = tf.Variable(tf.random_normal([256])) 22 | W2 = tf.Variable(tf.random_normal([256, 256])) 23 | b2 = tf.Variable(tf.random_normal([256])) 24 | W3 = tf.Variable(tf.random_normal([256,10])) 25 | b3 = tf.Variable(tf.random_normal([10])) 26 | 27 | lay1 = tf.nn.relu(tf.add(tf.matmul(x, W1),b1)) 28 | lay2 = tf.nn.relu(tf.add(tf.matmul(lay1, W2), b2)) 29 | y = tf.add(tf.matmul(lay2, W3),b3) 30 | 31 | y_ = tf.placeholder(tf.float32, [None, 10]) 32 | cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)) 33 | train_step = tf.train.GradientDescentOptimizer(0.0095).minimize(cross_entropy) 34 | 35 | sess = tf.InteractiveSession() 36 | tf.global_variables_initializer().run() 37 | for index in range(100000): 38 | # print('process the {}th batch'.format(index)) 39 | start_train = time.time()
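# Each loop iteration pulls a fresh mini-batch of 100 examples from the
# training split: batch_xs has shape [100, 784] (flattened 28x28 images) and
# batch_ys has shape [100, 10] (one-hot labels, since read_data_sets was
# called with one_hot=True).
40 | batch_xs, batch_ys = 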
mnist.train.next_batch(100) 41 | sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) 42 | # print('the {0} batch takes time: {1}'.format(index, time.time()-start_train)) 43 | 44 | correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 45 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 46 | print(sess.run(accuracy, feed_dict={x: mnist.test.images, 47 | y_: mnist.test.labels})) 48 | 49 | if __name__ == '__main__': 50 | tf.app.run() -------------------------------------------------------------------------------- /algorithm/train_mnist_single_perceptron.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import sys 7 | from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets 8 | import tensorflow as tf 9 | import time 10 | 11 | FLAGS = tf.app.flags.FLAGS 12 | 13 | tf.flags.DEFINE_string("data_dir", "./mnist", 14 | "mnist data_dir") 15 | 16 | 17 | def main(_): 18 | mnist = read_data_sets(FLAGS.data_dir, one_hot=True) 19 | x = tf.placeholder(tf.float32, [None, 784]) 20 | W1 = tf.Variable(tf.random_normal([784, 256])) 21 | b1 = tf.Variable(tf.random_normal([256])) 22 | W2 = tf.Variable(tf.random_normal([256, 10])) 23 | b2 = tf.Variable(tf.random_normal([10])) 24 | lay1 = tf.nn.relu(tf.matmul(x, W1) + b1) 25 | y = tf.add(tf.matmul(lay1, W2),b2) 26 | 27 | y_ = tf.placeholder(tf.float32, [None, 10]) 28 | cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)) 29 | train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 30 | 31 | sess = tf.InteractiveSession() 32 | tf.global_variables_initializer().run() 33 | for index in range(10000): 34 | # print('process the {}th batch'.format(index)) 35 | start_train = time.time() 36 | batch_xs, batch_ys = mnist.train.next_batch(100) 37 | sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) 38 | # print('the {0} batch takes time: {1}'.format(index, time.time()-start_train)) 39 | 40 | correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 41 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 42 | print(sess.run(accuracy, feed_dict={x: mnist.test.images, 43 | y_: mnist.test.labels})) 44 | 45 | if __name__ == '__main__': 46 | tf.app.run() -------------------------------------------------------------------------------- /algorithm/train_mnist_softmax.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import sys 7 | from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets 8 | import tensorflow as tf 9 | import time 10 | 11 | FLAGS = tf.app.flags.FLAGS 12 | 13 | tf.flags.DEFINE_string("data_dir", "./mnist", 14 | "mnist data_dir") 15 | 16 | 17 | def main(_): 18 | mnist = read_data_sets(FLAGS.data_dir, one_hot=True) 19 | x = tf.placeholder(tf.float32, [None, 784]) 20 | W = tf.Variable(tf.random_normal([784, 10])) 21 | b = tf.Variable(tf.random_normal([10])) 22 | y = tf.matmul(x, W) + b 23 | 24 | y_ = tf.placeholder(tf.float32, [None, 10]) 25 | cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)) 26 | train_step = tf.train.GradientDescentOptimizer(0.2).minimize(cross_entropy) 27 | 28 | sess = 
tf.InteractiveSession() 29 | tf.global_variables_initializer().run() 30 | for index in range(10000): 31 | print('process the {}th batch'.format(index)) 32 | start_train = time.time() 33 | batch_xs, batch_ys = mnist.train.next_batch(100) 34 | sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) 35 | print('the {0} batch takes time: {1}'.format(index, time.time()-start_train)) 36 | 37 | correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 38 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 39 | print(sess.run(accuracy, feed_dict={x: mnist.test.images, 40 | y_: mnist.test.labels})) 41 | 42 | if __name__ == '__main__': 43 | tf.app.run() -------------------------------------------------------------------------------- /chinese_hand_write_rec/src/README.md: -------------------------------------------------------------------------------- 1 | ## Dataset 2 | Download the dataset from [baidu yun](https://pan.baidu.com/s/1o84jIrg) 3 | 4 | ## Train 5 | 6 | run the command `python chinese_rec.py --mode=train --max_steps=200000 --eval_steps=1000 --save_steps=10000` 7 | 8 | ## Validation 9 | run the command `python chinese_rec.py --mode=validation` 10 | 11 | 12 | ## Inference 13 | run the command `python chinese_rec.py --mode=inference` 14 | 15 | 16 | 17 | **You can find the detailed write-up in [TensorFlow与中文手写汉字识别](http://hacker.duanshishi.com/?p=1753)** 18 | 19 | TODOs: 20 | 21 | - delete the placeholders in chinese_rec.py 22 | 23 | -------------------------------------------------------------------------------- /chinese_hand_write_rec/src/data_to_png.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import struct 4 | from PIL import Image 5 | 6 | 7 | data_dir = '../data' 8 | # train_data_dir = "../data/HWDB1.1trn_gnt" 9 | train_data_dir = os.path.join(data_dir, 'HWDB1.1trn_gnt') 10 | test_data_dir = os.path.join(data_dir, 'HWDB1.1tst_gnt') 11 | 12 | 13 | def read_from_gnt_dir(gnt_dir=train_data_dir): 14 | def one_file(f): 15 | header_size = 10 16 | while True: 17 | header = np.fromfile(f, dtype='uint8', count=header_size) 18 | if not header.size: break 19 | sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24) 20 | tagcode = header[5] + (header[4]<<8) 21 | width = header[6] + (header[7]<<8) 22 | height = header[8] + (header[9]<<8) 23 | if header_size + width*height != sample_size: 24 | break 25 | image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width)) 26 | yield image, tagcode 27 | for file_name in os.listdir(gnt_dir): 28 | if file_name.endswith('.gnt'): 29 | file_path = os.path.join(gnt_dir, file_name) 30 | with open(file_path, 'rb') as f: 31 | for image, tagcode in one_file(f): 32 | yield image, tagcode 33 | char_set = set() 34 | for _, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir): 35 | tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312') 36 | char_set.add(tagcode_unicode) 37 | char_list = list(char_set) 38 | char_dict = dict(zip(sorted(char_list), range(len(char_list)))) 39 | print len(char_dict) 40 | import pickle 41 | f = open('char_dict', 'wb') 42 | pickle.dump(char_dict, f) 43 | f.close() 44 | train_counter = 0 45 | test_counter = 0 46 | for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir): 47 | tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312') 48 | im = Image.fromarray(image) 49 | dir_name = '../data/train/' + '%0.5d'%char_dict[tagcode_unicode]
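# Each character class gets its own zero-padded directory under
# ../data/train/ (00000, 00001, ...), which is the layout chinese_rec.py reads
# back for training. For HWDB1.1 this should yield 3755 GB2312 level-1
# character folders -- the exact count is what `print len(char_dict)` reported.
50 | if not os.path.exists(dir_name): 51 | 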
os.makedirs(dir_name) 52 | im.convert('RGB').save(dir_name+'/' + str(train_counter) + '.png') 53 | train_counter += 1 54 | for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir): 55 | tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312') 56 | im = Image.fromarray(image) 57 | dir_name = '../data/test/' + '%0.5d'%char_dict[tagcode_unicode] 58 | if not os.path.exists(dir_name): 59 | os.makedirs(dir_name) 60 | im.convert('RGB').save(dir_name+'/' + str(test_counter) + '.png') 61 | test_counter += 1 62 | --------------------------------------------------------------------------------
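The tag code parsed above is a 2-byte, big-endian GB2312 code point, which the script turns into a unicode character with `struct.pack('>H', tagcode).decode('gb2312')`. A one-line sketch of that decoding step (the value 0xB0A1 is just an illustrative sample; it is the first character of the GB2312 level-1 set):

```python
import struct

tagcode = 0xB0A1                                    # example GB2312 code point
print struct.pack('>H', tagcode).decode('gb2312')   # -> u'\u554a' ("啊")
```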
/covert_to_tfrecord/covert_sex_to_tfrecord.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import covert_datasets_tfrecord 7 | 8 | FLAGS = tf.app.flags.FLAGS 9 | 10 | tf.app.flags.DEFINE_string( 11 | 'dataset_name', 12 | None, 13 | 'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".') 14 | 15 | tf.app.flags.DEFINE_string( 16 | 'dataset_dir', 17 | None, 18 | 'The directory where the output TFRecords and temporary files are saved.') 19 | 20 | 21 | def main(_): 22 | if not FLAGS.dataset_name: 23 | raise ValueError('You must supply the dataset name with --dataset_name') 24 | if not FLAGS.dataset_dir: 25 | raise ValueError('You must supply the dataset directory with --dataset_dir') 26 | 27 | covert_datasets_tfrecord.run(FLAGS.dataset_dir, FLAGS.dataset_name) 28 | 29 | 30 | if __name__ == '__main__': 31 | tf.app.run() -------------------------------------------------------------------------------- /covert_to_tfrecord/covert_somedata_to_tfrecord.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import covert_datasets_tfrecord 7 | 8 | FLAGS = tf.app.flags.FLAGS 9 | 10 | tf.app.flags.DEFINE_string( 11 | 'dataset_name', 12 | None, 13 | 'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".') 14 | 15 | tf.app.flags.DEFINE_string( 16 | 'dataset_dir', 17 | None, 18 | 'The directory where the output TFRecords and temporary files are saved.') 19 | 20 | 21 | def main(_): 22 | if not FLAGS.dataset_name: 23 | raise ValueError('You must supply the dataset name with --dataset_name') 24 | if not FLAGS.dataset_dir: 25 | raise ValueError('You must supply the dataset directory with --dataset_dir') 26 | 27 | covert_datasets_tfrecord.run(FLAGS.dataset_dir, FLAGS.dataset_name) 28 | 29 | 30 | if __name__ == '__main__': 31 | tf.app.run() -------------------------------------------------------------------------------- /deepcolor/.gitignore: -------------------------------------------------------------------------------- 1 | imgs/ 2 | results_straight/ 3 | results_colortrain/ 4 | oldres/ 5 | checkpoint/ 6 | checkpoint_old/ 7 | main.pyc 8 | utils.pyc 9 | server.pyc 10 | ops.pyc 11 | cv.py 12 | cv2.so 13 | *.pyc 14 | -------------------------------------------------------------------------------- /deepcolor/README.md: -------------------------------------------------------------------------------- 1 | # deepcolor: Automatic coloring and shading of manga-style lineart, using Tensorflow + cGANs 2 | 3 | ![]() 4 | 5 | ![]() 6 | 7 | Setup: 8 | ``` 9 | 0. Have tensorflow + OpenCV installed. 10 | 1. Make a folder called "results" 11 | 2. Make a folder called "imgs" 12 | 3. Fill the "imgs" folder with your own .jpg images, or run "download_images.py" to download from Safebooru. 13 | 4. Run "python main.py train". I trained for ~20 epochs, taking about 16 hours on one GPU. 14 | 5. To sample, run "python main.py sample" 15 | 6. To start the server, run "python server.py". It will host on port 8000. 16 | ``` 17 | 18 | Read the writeup: 19 | http://kvfrans.com/coloring-and-shading-line-art-automatically-through-conditional-gans/ 20 | 21 | Try the demo: 22 | http://color.kvfrans.com 23 | 24 | Code based off [this pix2pix implementation](https://github.com/yenchenlin/pix2pix-tensorflow). 25 | 26 | 27 | ------------ 28 | Adapted from the original [deepcolor](https://github.com/kvfrans/deepcolor) repository. 29 | -------------------------------------------------------------------------------- /deepcolor/download_images.py: -------------------------------------------------------------------------------- 1 | import urllib2 2 | import urllib 3 | import json 4 | import numpy as np 5 | import cv2 6 | import untangle 7 | 8 | maxsize = 512 9 | 10 | # tags = ["asu_tora","puuakachan","mankun","hammer_%28sunset_beach%29",""] 11 | 12 | # for tag in tags: 13 | 14 | count = 0 15 | 16 | for i in xrange(299,10000): 17 | header = {'Referer':'http://safebooru.org/index.php?page=post&s=list','User-Agent' : 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_3) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/56.0.2924.87 Safari/537.36'} 18 | url = "http://safebooru.org/index.php?page=dapi&s=post&q=index&tags=1girl%20solo&pid="+str(i+5000) 19 | request = urllib2.Request(url, headers=header) 20 | stringreturn = urllib2.urlopen(request).read() 21 | print stringreturn 22 | xmlreturn = untangle.parse(stringreturn) 23 | for post in xmlreturn.posts.post: 24 | imgurl = "http:" + post["sample_url"] 25 | print imgurl 26 | if ("png" in imgurl) or ("jpg" in imgurl): 27 | 28 | resp = urllib.urlopen(imgurl) 29 | image = np.asarray(bytearray(resp.read()), dtype="uint8") 30 | image = cv2.imdecode(image, cv2.IMREAD_COLOR) 31 | height, width = image.shape[:2] 32 | if height > width: 33 | scalefactor = (maxsize*1.0) / width 34 | res = cv2.resize(image,(int(width * scalefactor), int(height*scalefactor)), interpolation = cv2.INTER_CUBIC) 35 | cropped = res[0:maxsize,0:maxsize] 36 | if width >= height: 37 | scalefactor = (maxsize*1.0) / height 38 | res = cv2.resize(image,(int(width * scalefactor), int(height*scalefactor)), interpolation = cv2.INTER_CUBIC) 39 | center_x = int(round(width*scalefactor*0.5)) 40 | print center_x 41 | cropped = res[0:maxsize,center_x - maxsize/2:center_x + maxsize/2] 42 | 43 | # img_edge = cv2.adaptiveThreshold(cropped, 255, 44 | # cv2.ADAPTIVE_THRESH_MEAN_C, 45 | # cv2.THRESH_BINARY, 46 | # blockSize=9, 47 | # C=2) 48 | 49 | count += 1 50 | cv2.imwrite("imgs-valid/"+str(count)+".jpg",cropped) 51 | # cv2.imwrite("imgs/"+str(post["id"])+"-edge.jpg",img_edge) 52 | --------------------------------------------------------------------------------
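The resize-and-center-crop logic above reappears almost verbatim in the multithreaded downloader below. A factored-out sketch of the same idea (the helper name is mine, not the repo's; this variant center-crops along both axes, whereas the scripts crop from the top for portrait images):

```python
import cv2

def resize_and_center_crop(image, maxsize=512):
    # Scale the short side to maxsize, then cut a centered maxsize x maxsize square.
    height, width = image.shape[:2]
    scale = (maxsize * 1.0) / min(height, width)
    res = cv2.resize(image, (int(round(width * scale)), int(round(height * scale))),
                     interpolation=cv2.INTER_CUBIC)
    center_x = int(round(width * scale * 0.5))
    center_y = int(round(height * scale * 0.5))
    return res[center_y - maxsize / 2:center_y + maxsize / 2,
               center_x - maxsize / 2:center_x + maxsize / 2]
```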
/deepcolor/download_images_multithread.py: -------------------------------------------------------------------------------- 1 | import os 2 | import Queue 3 | from threading import Thread 4 | from time import time 5 | from itertools import chain 6 | import urllib2 7 | import untangle 8 | import numpy as np 9 | import cv2 10 | 11 | def download_imgs(url): 12 | # count = 0 13 | maxsize = 512 14 | file_name = url.split('=')[-1] 15 | header = {'Referer':'http://safebooru.org/index.php?page=post&s=list','User-Agent' : 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_3) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/56.0.2924.87 Safari/537.36'} 16 | request = urllib2.Request(url, headers=header) 17 | stringreturn = urllib2.urlopen(request).read() 18 | xmlreturn = untangle.parse(stringreturn) 19 | count = 0 20 | print xmlreturn.posts[0]['sample_url'] 21 | try: 22 | for post in xmlreturn.posts.post: 23 | try: 24 | imgurl = "http:" + post["sample_url"] 25 | print imgurl 26 | if ("png" in imgurl) or ("jpg" in imgurl): 27 | resp = urllib2.urlopen(imgurl) 28 | image = np.asarray(bytearray(resp.read()), dtype="uint8") 29 | image = cv2.imdecode(image, cv2.IMREAD_COLOR) 30 | height, width = image.shape[:2] 31 | if height > width: 32 | scalefactor = (maxsize*1.0) / width 33 | res = cv2.resize(image,(int(width * scalefactor), int(height*scalefactor)), interpolation = cv2.INTER_CUBIC) 34 | cropped = res[0:maxsize,0:maxsize] 35 | if width >= height: 36 | scalefactor = (maxsize*1.0) / height 37 | res = cv2.resize(image,(int(width * scalefactor), int(height*scalefactor)), interpolation = cv2.INTER_CUBIC) 38 | center_x = int(round(width*scalefactor*0.5)) 39 | print center_x 40 | cropped = res[0:maxsize,center_x - maxsize/2:center_x + maxsize/2] 41 | count += 1 42 | cv2.imwrite("imgs-valid/"+file_name+'_'+str(count)+'.jpg',cropped) 43 | except: 44 | continue 45 | except: 46 | print "no post in xml" 47 | return 48 | 49 | class DownloadWorker(Thread): 50 | def __init__(self, queue): 51 | Thread.__init__(self) 52 | self.queue = queue 53 | 54 | def run(self): 55 | while True: 56 | # Get the work from the queue and expand the tuple 57 | url = self.queue.get() 58 | if url is None: 59 | break 60 | # download_link(directory, link) 61 | download_imgs(url) 62 | self.queue.task_done() 63 | 64 | if __name__ == '__main__': 65 | start = time() 66 | download_queue = Queue.Queue(maxsize=100) 67 | for x in range(8): 68 | worker = DownloadWorker(download_queue) 69 | worker.daemon = True 70 | worker.start() 71 | 72 | url_links = ["http://safebooru.org/index.php?page=dapi&s=post&q=index&tags=1girl%20solo&pid="+str(i+5000) for i in xrange(299,10000)] 73 | # print url_links[:10] 74 | 75 | for link in url_links: 76 | download_queue.put(link) 77 | download_queue.join() 78 | print "the number of pages fetched is {0}".format(len(url_links)) 79 | print "took time : {0}".format(time() - start) 80 | 81 | 82 | 83 | --------------------------------------------------------------------------------
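Note that `DownloadWorker.run` above already exits when it dequeues `None`, but the main block never sends that sentinel and instead relies on the worker threads being daemons. A sketch of an explicit clean shutdown after `download_queue.join()` returns (the `workers` list is my addition; the script would need to collect its `DownloadWorker` instances into it):

```python
# One sentinel per worker wakes each thread so its `if url is None: break`
# branch fires; join() then waits for the threads to exit cleanly.
for _ in range(8):
    download_queue.put(None)
for worker in workers:
    worker.join()
```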
/deepcolor/download_images_v2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import Queue 3 | from threading import Thread 4 | from time import time 5 | from itertools import chain 6 | import urllib2 7 | import untangle 8 | import numpy as np 9 | import cv2 10 | 11 | def download_imgs(url): 12 | # count = 0 13 | maxsize = 512 14 | file_name = url.split('=')[-1] 15 | header = {'Referer':'http://safebooru.org/index.php?page=post&s=list','User-Agent' : 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_3) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/56.0.2924.87 Safari/537.36'} 16 | def get_links(self): 17 | return 18 | 19 | class DownloadWorker(Thread): 20 | def __init__(self, queue): 21 | Thread.__init__(self) 22 | self.queue = queue 23 | 24 | def run(self): 25 | while True: 26 | # Get the work from the queue and expand the tuple 27 | url = self.queue.get() 28 | if url is None: 29 | break 30 | # download_link(directory, link) 31 | download_imgs(url) 32 | self.queue.task_done() 33 | 34 | if __name__ == '__main__': 35 | start = time() 36 | download_queue = Queue.Queue(maxsize=100) 37 | for x in range(8): 38 | worker = DownloadWorker(download_queue) 39 | worker.daemon = True 40 | worker.start() 41 | 42 | url_links = ["http://safebooru.org/index.php?page=post&s=view&id={0}".format(str(i)) for i in xrange(299,10000)] 43 | # print url_links[:10] 44 | 45 | for link in url_links: 46 | download_queue.put(link) 47 | download_queue.join() 48 | print "the number of pages fetched is {0}".format(len(url_links)) 49 | print "took time : {0}".format(time() - start) 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /deepcolor/imgs-valid/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/1.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/2.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5000_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5000_1.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5000_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5000_2.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5001_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5001_1.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5001_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5001_2.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5001_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5001_3.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5001_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5001_4.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5001_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5001_5.jpg
-------------------------------------------------------------------------------- /deepcolor/imgs-valid/5001_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5001_6.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5001_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5001_7.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5001_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5001_8.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5001_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5001_9.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5002_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5002_1.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5002_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5002_2.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5002_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5002_3.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5002_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5002_4.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5002_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5002_5.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5002_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5002_6.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5002_7.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5002_7.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5002_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5002_8.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5003_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5003_1.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5003_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5003_2.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5003_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5003_3.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5003_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5003_4.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5003_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5003_5.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5003_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5003_6.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5003_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5003_7.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5003_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5003_8.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5003_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5003_9.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5004_1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5004_1.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5004_10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5004_10.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5004_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5004_2.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5004_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5004_3.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5004_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5004_4.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5004_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5004_5.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5004_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5004_6.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5004_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5004_7.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5004_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5004_8.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5004_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5004_9.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5005_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5005_1.jpg -------------------------------------------------------------------------------- 
/deepcolor/imgs-valid/5005_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5005_2.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5005_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5005_3.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5005_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5005_4.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5005_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5005_5.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5005_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5005_6.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5005_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5005_7.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5005_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5005_8.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5006_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5006_1.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5006_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5006_2.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5006_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5006_3.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5006_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5006_4.jpg 
-------------------------------------------------------------------------------- /deepcolor/imgs-valid/5006_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5006_5.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5006_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5006_6.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5007_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5007_1.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5007_10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5007_10.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5007_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5007_2.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5007_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5007_3.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5007_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5007_4.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5007_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5007_5.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5007_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5007_6.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5007_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5007_7.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5007_8.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5007_8.jpg -------------------------------------------------------------------------------- /deepcolor/imgs-valid/5007_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/imgs-valid/5007_9.jpg -------------------------------------------------------------------------------- /deepcolor/lines.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from matplotlib import pyplot as plt 4 | from glob import glob 5 | from random import randint 6 | 7 | data = glob("imgs-valid/*.jpg") 8 | for imname in data: 9 | 10 | cimg = cv2.imread(imname,1) 11 | cimg = np.fliplr(cimg.reshape(-1,3)).reshape(cimg.shape) 12 | cimg = cv2.resize(cimg, (256,256)) 13 | 14 | img = cv2.imread(imname,0) 15 | 16 | # kernel = np.ones((5,5),np.float32)/25 17 | for i in xrange(30): 18 | randx = randint(0,205) 19 | randy = randint(0,205) 20 | cimg[randx:randx+50, randy:randy+50] = 255 21 | blur = cv2.blur(cimg,(100,100)) 22 | 23 | 24 | # img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 25 | img_edge = cv2.adaptiveThreshold(img, 255, 26 | cv2.ADAPTIVE_THRESH_MEAN_C, 27 | cv2.THRESH_BINARY, 28 | blockSize=9, 29 | C=2) 30 | # img_edge = cv2.cvtColor(img_edge, cv2.COLOR_GRAY2RGB) 31 | # img_cartoon = cv2.bitwise_and(img, img_edge) 32 | 33 | plt.subplot(131),plt.imshow(cimg) 34 | plt.title('Original Image'), plt.xticks([]), plt.yticks([]) 35 | 36 | plt.subplot(132),plt.imshow(blur) 37 | plt.title('Blurred Color Image'), plt.xticks([]), plt.yticks([]) 38 | 39 | plt.subplot(133),plt.imshow(img_edge,cmap = 'gray') 40 | plt.title('Edge Image'), plt.xticks([]), plt.yticks([]) 41 | 42 | plt.show() 43 | -------------------------------------------------------------------------------- /deepcolor/nohup.out: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/nohup.out -------------------------------------------------------------------------------- /deepcolor/ops.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/ops.py -------------------------------------------------------------------------------- /deepcolor/testserver.py: -------------------------------------------------------------------------------- 1 | from bottle import route, run 2 | 3 | @route('/hello') 4 | def hello(): 5 | return "Hello World!" 6 | run(host='0.0.0.0', port=8080, debug=True) 7 | --------------------------------------------------------------------------------
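A quick smoke test for the toy bottle server above once it is running (a sketch; assumes the local binding and port shown in the script):

```python
import urllib2

# Should print "Hello World!" if testserver.py is listening on port 8080.
print urllib2.urlopen("http://localhost:8080/hello").read()
```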
/deepcolor/uploaded/colors.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/uploaded/colors.jpg -------------------------------------------------------------------------------- /deepcolor/uploaded/colors3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/uploaded/colors3.jpg -------------------------------------------------------------------------------- /deepcolor/uploaded/convert.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/uploaded/convert.png -------------------------------------------------------------------------------- /deepcolor/uploaded/gen.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/uploaded/gen.jpg -------------------------------------------------------------------------------- /deepcolor/uploaded/gen3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/uploaded/gen3.jpg -------------------------------------------------------------------------------- /deepcolor/uploaded/lines.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/uploaded/lines.jpg -------------------------------------------------------------------------------- /deepcolor/uploaded/picasso.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/uploaded/picasso.png -------------------------------------------------------------------------------- /deepcolor/uploaded/sanae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/uploaded/sanae.png -------------------------------------------------------------------------------- /deepcolor/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import cv2 4 | 5 | class batch_norm(object): 6 | # h1 = lrelu(tf.contrib.layers.batch_norm(conv2d(h0, self.df_dim*2, name='d_h1_conv'),decay=0.9,updates_collections=None,epsilon=0.00001,scale=True,scope="d_h1_conv")) 7 | def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"): 8 | with tf.variable_scope(name): 9 | self.epsilon = epsilon 10 | self.momentum = momentum 11 | self.name = name 12 | 13 | def __call__(self, x, train=True): 14 | return tf.contrib.layers.batch_norm(x, decay=self.momentum, updates_collections=None, epsilon=self.epsilon, scale=True, scope=self.name) 15 | 16 | batchnorm_count = 0 17 | def bn(x): 18 | global batchnorm_count 19 | batch_object = 
batch_norm(name=("bn" + str(batchnorm_count))) 20 | batchnorm_count += 1 21 | return batch_object(x) 22 | 23 | def conv2d(input_, output_dim, 24 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 25 | name="conv2d"): 26 | with tf.variable_scope(name): 27 | w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], 28 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 29 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') 30 | 31 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 32 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) 33 | 34 | return conv 35 | 36 | def deconv2d(input_, output_shape, 37 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 38 | name="deconv2d", with_w=False): 39 | with tf.variable_scope(name): 40 | # filter : [height, width, output_channels, in_channels] 41 | w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]], initializer=tf.random_normal_initializer(stddev=stddev)) 42 | deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1]) 43 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 44 | deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) 45 | if with_w: 46 | return deconv, w, biases 47 | else: 48 | return deconv 49 | 50 | 51 | def lrelu(x, leak=0.2, name="lrelu"): 52 | return tf.maximum(x, leak*x) 53 | 54 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): 55 | shape = input_.get_shape().as_list() 56 | with tf.variable_scope(scope or "Linear"): 57 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, 58 | tf.random_normal_initializer(stddev=stddev)) 59 | bias = tf.get_variable("bias", [output_size], 60 | initializer=tf.constant_initializer(bias_start)) 61 | if with_w: 62 | return tf.matmul(input_, matrix) + bias, matrix, bias 63 | else: 64 | return tf.matmul(input_, matrix) + bias 65 | 66 | def get_image(image_path): 67 | return transform(imread(image_path)) 68 | 69 | def transform(image, npx=512, is_crop=True): 70 | cropped_image = cv2.resize(image, (256,256)) 71 | 72 | return np.array(cropped_image) 73 | 74 | def imread(path): 75 | readimage = cv2.imread(path, 1) 76 | return readimage 77 | 78 | def merge_color(images, size): 79 | h, w = images.shape[1], images.shape[2] 80 | img = np.zeros((h * size[0], w * size[1], 3)) 81 | 82 | for idx, image in enumerate(images): 83 | i = idx % size[1] 84 | j = idx / size[1] 85 | img[j*h:j*h+h, i*w:i*w+w, :] = image 86 | 87 | return img 88 | 89 | def merge(images, size): 90 | h, w = images.shape[1], images.shape[2] 91 | img = np.zeros((h * size[0], w * size[1], 1)) 92 | 93 | for idx, image in enumerate(images): 94 | i = idx % size[1] 95 | j = idx / size[1] 96 | img[j*h:j*h+h, i*w:i*w+w] = image 97 | 98 | return img[:,:,0] 99 | 100 | def ims(name, img): 101 | print "saving img " + name 102 | cv2.imwrite(name, img*255) 103 | -------------------------------------------------------------------------------- /deepcolor/web/colorpicker/css/bootstrap-colorpicker.min.css: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Bootstrap Colorpicker v2.5.1 3 | * https://itsjavi.com/bootstrap-colorpicker/ 4 | * 5 | * Originally written by (c) 2012 Stefan Petre 6 | * Licensed under the Apache License v2.0 7 | * http://www.apache.org/licenses/LICENSE-2.0.txt 8 | * 9 | */.colorpicker-saturation{width:100px;height:100px;background-image:url(../img/bootstrap-colorpicker/saturation.png);cursor:crosshair;float:left}.colorpicker-saturation i{display:block;height:5px;width:5px;border:1px solid #000;-webkit-border-radius:5px;-moz-border-radius:5px;border-radius:5px;position:absolute;top:0;left:0;margin:-4px 0 0 -4px}.colorpicker-saturation i b{display:block;height:5px;width:5px;border:1px solid #fff;-webkit-border-radius:5px;-moz-border-radius:5px;border-radius:5px}.colorpicker-alpha,.colorpicker-hue{width:15px;height:100px;float:left;cursor:row-resize;margin-left:4px;margin-bottom:4px}.colorpicker-alpha i,.colorpicker-hue i{display:block;height:1px;background:#000;border-top:1px solid #fff;position:absolute;top:0;left:0;width:100%;margin-top:-1px}.colorpicker-hue{background-image:url(../img/bootstrap-colorpicker/hue.png)}.colorpicker-alpha{background-image:url(../img/bootstrap-colorpicker/alpha.png);display:none}.colorpicker-alpha,.colorpicker-hue,.colorpicker-saturation{background-size:contain}.colorpicker{padding:4px;min-width:130px;margin-top:1px;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;z-index:2500}.colorpicker:after,.colorpicker:before{display:table;content:"";line-height:0}.colorpicker:after{clear:both}.colorpicker:before{content:'';display:inline-block;border-left:7px solid transparent;border-right:7px solid transparent;border-bottom:7px solid #ccc;border-bottom-color:rgba(0,0,0,.2);position:absolute;top:-7px;left:6px}.colorpicker:after{content:'';display:inline-block;border-left:6px solid transparent;border-right:6px solid transparent;border-bottom:6px solid #fff;position:absolute;top:-6px;left:7px}.colorpicker div{position:relative}.colorpicker.colorpicker-with-alpha{min-width:140px}.colorpicker.colorpicker-with-alpha .colorpicker-alpha{display:block}.colorpicker-color{height:10px;margin-top:5px;clear:both;background-image:url(../img/bootstrap-colorpicker/alpha.png);background-position:0 100%}.colorpicker-color div{height:10px}.colorpicker-selectors{display:none;height:10px;margin-top:5px;clear:both}.colorpicker-selectors i{cursor:pointer;float:left;height:10px;width:10px}.colorpicker-selectors i+i{margin-left:3px}.colorpicker-element .add-on i,.colorpicker-element .input-group-addon i{display:inline-block;cursor:pointer;height:16px;vertical-align:text-top;width:16px}.colorpicker.colorpicker-inline{position:relative;display:inline-block;float:none;z-index:auto}.colorpicker.colorpicker-horizontal{width:110px;min-width:110px;height:auto}.colorpicker.colorpicker-horizontal .colorpicker-saturation{margin-bottom:4px}.colorpicker.colorpicker-horizontal .colorpicker-color{width:100px}.colorpicker.colorpicker-horizontal .colorpicker-alpha,.colorpicker.colorpicker-horizontal .colorpicker-hue{width:100px;height:15px;float:left;cursor:col-resize;margin-left:0;margin-bottom:4px}.colorpicker.colorpicker-horizontal .colorpicker-alpha i,.colorpicker.colorpicker-horizontal .colorpicker-hue i{display:block;height:15px;background:#fff;position:absolute;top:0;left:0;width:1px;border:none;margin-top:0}.colorpicker.colorpicker-horizontal .colorpicker-hue{background-image:url(../img/bootstrap-colorpicker/hue-horizontal.png)}.colorpicker.colorpicker-horizontal 
.colorpicker-alpha{background-image:url(../img/bootstrap-colorpicker/alpha-horizontal.png)}.colorpicker-right:before{left:auto;right:6px}.colorpicker-right:after{left:auto;right:7px}.colorpicker-no-arrow:before{border-right:0;border-left:0}.colorpicker-no-arrow:after{border-right:0;border-left:0}.colorpicker-alpha.colorpicker-visible,.colorpicker-hue.colorpicker-visible,.colorpicker-saturation.colorpicker-visible,.colorpicker-selectors.colorpicker-visible,.colorpicker.colorpicker-visible{display:block}.colorpicker-alpha.colorpicker-hidden,.colorpicker-hue.colorpicker-hidden,.colorpicker-saturation.colorpicker-hidden,.colorpicker-selectors.colorpicker-hidden,.colorpicker.colorpicker-hidden{display:none}.colorpicker-inline.colorpicker-visible{display:inline-block} 10 | /*# sourceMappingURL=bootstrap-colorpicker.min.css.map */ -------------------------------------------------------------------------------- /deepcolor/web/colorpicker/css/bootstrap-colorpicker.min.css.map: -------------------------------------------------------------------------------- 1 | {"version":3,"sources":["src/less/colorpicker.less"],"names":[],"mappings":";;;;;;;;AAqBA,wBACE,MAAA,MACA,OAAA,MAXA,iBAAsB,iDAatB,OAAA,UACA,MAAA,KACA,0BACE,QAAA,MACA,OAAA,IACA,MAAA,IACA,OAAA,IAAA,MAAA,KAfF,sBAAA,IACA,mBAAA,IACA,cAAA,IAeE,SAAA,SACA,IAAA,EACA,KAAA,EACA,OAAA,KAAA,EAAA,EAAA,KACA,4BACE,QAAA,MACA,OAAA,IACA,MAAA,IACA,OAAA,IAAA,MAAA,KAzBJ,sBAAA,IACA,mBAAA,IACA,cAAA,IA8BF,mBADA,iBAEE,MAAA,KACA,OAAA,MACA,MAAA,KACA,OAAA,WACA,YAAA,IACA,cAAA,IAIiB,qBADF,mBAEf,QAAA,MACA,OAAA,IACA,WAAA,KACA,WAAA,IAAA,MAAA,KACA,SAAA,SACA,IAAA,EACA,KAAA,EACA,MAAA,KACA,WAAA,KAGF,iBA1DE,iBAAsB,0CA8DxB,mBA9DE,iBAAsB,4CAgEtB,QAAA,KAKF,mBADA,iBADA,wBAGE,gBAAA,QAGF,aACE,QAAA,IACA,UAAA,MACA,WAAA,IAxEA,sBAAA,IACA,mBAAA,IACA,cAAA,IAwEA,QAAA,KAIU,mBADA,oBAEV,QAAA,MACA,QAAA,GACA,YAAA,EAGU,mBACV,MAAA,KAGU,oBACV,QAAA,GACA,QAAA,aACA,YAAA,IAAA,MAAA,YACA,aAAA,IAAA,MAAA,YACA,cAAA,IAAA,MAAA,KACA,oBAAA,eACA,SAAA,SACA,IAAA,KACA,KAAA,IAGU,mBACV,QAAA,GACA,QAAA,aACA,YAAA,IAAA,MAAA,YACA,aAAA,IAAA,MAAA,YACA,cAAA,IAAA,MAAA,KACA,SAAA,SACA,IAAA,KACA,KAAA,IAGW,iBACX,SAAA,SAGU,oCACV,UAAA,MAGkC,uDAClC,QAAA,MAGF,mBACE,OAAA,KACA,WAAA,IACA,MAAA,KAlIA,iBAAsB,4CAoItB,oBAAA,EAAA,KAGiB,uBACjB,OAAA,KAGF,uBACE,QAAA,KACA,OAAA,KACA,WAAA,IACA,MAAA,KAGqB,yBACrB,OAAA,QACA,MAAA,KACA,OAAA,KACA,MAAA,KAGuB,2BACvB,YAAA,IAI2B,+BADW,0CAEtC,QAAA,aACA,OAAA,QACA,OAAA,KACA,eAAA,SACA,MAAA,KAGU,gCACV,SAAA,SACA,QAAA,aACA,MAAA,KACA,QAAA,KAGU,oCACV,MAAA,MACA,UAAA,MACA,OAAA,KAGkC,4DAClC,cAAA,IAGkC,uDAClC,MAAA,MAIkC,uDADA,qDAElC,MAAA,MACA,OAAA,KACA,MAAA,KACA,OAAA,WACA,YAAA,EACA,cAAA,IAIqD,yDADF,uDAEnD,QAAA,MACA,OAAA,KACA,WAAA,KACA,SAAA,SACA,IAAA,EACA,KAAA,EACA,MAAA,IACA,OAAA,KACA,WAAA,EAGkC,qDAlNlC,iBAAsB,qDAsNY,uDAtNlC,iBAAsB,uDA0NN,0BAChB,KAAA,KACA,MAAA,IAGgB,yBAChB,KAAA,KACA,MAAA,IAGmB,6BACnB,aAAA,EACA,YAAA,EAGmB,4BACnB,aAAA,EACA,YAAA,EAQC,uCAAA,qCAAA,4CAAA,2CAAA,iCACC,QAAA,MASD,sCAAA,oCAAA,2CAAA,0CAAA,gCACC,QAAA,KAIe,wCACjB,QAAA"} -------------------------------------------------------------------------------- /deepcolor/web/colorpicker/img/bootstrap-colorpicker/alpha-horizontal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/colorpicker/img/bootstrap-colorpicker/alpha-horizontal.png -------------------------------------------------------------------------------- /deepcolor/web/colorpicker/img/bootstrap-colorpicker/alpha.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/colorpicker/img/bootstrap-colorpicker/alpha.png -------------------------------------------------------------------------------- /deepcolor/web/colorpicker/img/bootstrap-colorpicker/hue-horizontal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/colorpicker/img/bootstrap-colorpicker/hue-horizontal.png -------------------------------------------------------------------------------- /deepcolor/web/colorpicker/img/bootstrap-colorpicker/hue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/colorpicker/img/bootstrap-colorpicker/hue.png -------------------------------------------------------------------------------- /deepcolor/web/colorpicker/img/bootstrap-colorpicker/saturation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/colorpicker/img/bootstrap-colorpicker/saturation.png -------------------------------------------------------------------------------- /deepcolor/web/css/grid.css: -------------------------------------------------------------------------------- 1 | h4 { 2 | margin-top: 25px; 3 | } 4 | .row { 5 | margin-bottom: 20px; 6 | } 7 | .row .row { 8 | margin-top: 10px; 9 | margin-bottom: 0; 10 | } 11 | [class*="col-"] { 12 | padding-top: 15px; 13 | padding-bottom: 15px; 14 | background-color: #eee; 15 | background-color: rgba(86,61,124,.15); 16 | border: 1px solid #ddd; 17 | border: 1px solid rgba(86,61,124,.2); 18 | } 19 | 20 | hr { 21 | margin-top: 40px; 22 | margin-bottom: 40px; 23 | } 24 | -------------------------------------------------------------------------------- /deepcolor/web/image_examples/armscross.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/image_examples/armscross.jpg -------------------------------------------------------------------------------- /deepcolor/web/image_examples/armscross.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/image_examples/armscross.png -------------------------------------------------------------------------------- /deepcolor/web/image_examples/picasso.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/image_examples/picasso.png -------------------------------------------------------------------------------- /deepcolor/web/image_examples/sanae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/image_examples/sanae.png -------------------------------------------------------------------------------- 
/deepcolor/web/images/1_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/images/1_color.png -------------------------------------------------------------------------------- /deepcolor/web/images/1_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/images/1_line.png -------------------------------------------------------------------------------- /deepcolor/web/images/2_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/images/2_color.png -------------------------------------------------------------------------------- /deepcolor/web/images/2_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/images/2_line.png -------------------------------------------------------------------------------- /deepcolor/web/images/3_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/images/3_color.png -------------------------------------------------------------------------------- /deepcolor/web/images/3_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/images/3_line.png -------------------------------------------------------------------------------- /deepcolor/web/images/4_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/images/4_color.png -------------------------------------------------------------------------------- /deepcolor/web/images/4_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/images/4_line.png -------------------------------------------------------------------------------- /deepcolor/web/images/5_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/images/5_color.png -------------------------------------------------------------------------------- /deepcolor/web/images/5_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/images/5_line.png -------------------------------------------------------------------------------- /deepcolor/web/images/6_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/images/6_color.png 
-------------------------------------------------------------------------------- /deepcolor/web/images/6_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/images/6_line.png -------------------------------------------------------------------------------- /deepcolor/web/images/7_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/images/7_color.png -------------------------------------------------------------------------------- /deepcolor/web/images/7_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/images/7_line.png -------------------------------------------------------------------------------- /deepcolor/web/images/a_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/deepcolor/web/images/a_line.png -------------------------------------------------------------------------------- /deepcolor/web/index.html: -------------------------------------------------------------------------------- [HTML markup lost during extraction; recoverable text only: page title "the color move"; gallery captions: Examples; Only line image [source]; Line image with color hints; Fine art [source (Picasso)]; Picasso with color hint; Confusing background; Heterochromia; My bad attempt at drawing]
-------------------------------------------------------------------------------- /finetuning/README.md: -------------------------------------------------------------------------------- 1 | ## Using tf.slim to fine-tune a model for a new task 2 | 3 | 1. Download the training dataset [fisher data](https://pan.baidu.com/s/1nvyLmx7) and the [pretrained model](https://pan.baidu.com/s/1pLRh2DP), then uncompress the dataset into the `train` folder. 4 | 5 | 2. Run `cd convert_pys; python covert_somedata_to_tfrecord.py --dataset_name=train --dataset_dir=. --nFold=4` to split the training data into train and validation sets over 4 folds. 6 | The `tfrecords` folder then contains the fish_train_00000-of-nFold-*-00001.tfrecord and fish_validation_00000-of-nFold-*-00001.tfrecord files. 7 | 8 | 3. Run `cd run_scripts; sh run.sh` to fine-tune the last layers for the new task (an 8-class classification problem). After fine-tuning, run `sh run_eval.sh` to evaluate the model. 9 | 10 | 4. If you want to fine-tune all layers, run run_all.sh and run_all_eval.sh to train all layers and evaluate the model. 11 | 12 | **PS**: When you train or evaluate the model, make sure `tfrecords` contains the tfrecord files of only one fold. 13 | 14 | 5. In `fish_inference.py`, we run inference with the fine-tuned model. 15 | 16 | ## Make a RESTful API with your model 17 | 18 | In flask_inference.py, we build a serving model with Flask. It is deliberately simple: given the path of an image file on the local machine, it runs inference and returns the predicted class as JSON, 19 | and the model stays loaded in memory for as long as the script is running. With the Flask development server on its default port, a request looks like `curl "http://127.0.0.1:5000/model?image=/path/to/fish.jpg"`. 20 | 21 | 22 | I deployed an image-classification demo at the [demo page](http://demo.duanshishi.com). Feel free to try it. -------------------------------------------------------------------------------- /finetuning/convert_pys/covert_somedata_to_tfrecord.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import covert_datasets_tfrecord 7 | 8 | FLAGS = tf.app.flags.FLAGS 9 | 10 | tf.app.flags.DEFINE_string( 11 | 'dataset_name', None, 12 | 'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".') 13 | 14 | tf.app.flags.DEFINE_string( 15 | 'dataset_dir', None, 16 | 'The directory where the output TFRecords and temporary files are saved.') 17 | 18 | tf.app.flags.DEFINE_integer('nFold', 1, "The number of folds for cross-validation.") 19 | 20 | 21 | def main(_): 22 | if not FLAGS.dataset_name: 23 | raise ValueError( 24 | 'You must supply the dataset name with --dataset_name') 25 | if not FLAGS.dataset_dir: 26 | raise ValueError('You must supply the dataset directory with --dataset_dir') 27 | 28 | covert_datasets_tfrecord.run(FLAGS) 29 | 30 | 31 | if __name__ == '__main__': 32 | tf.app.run() -------------------------------------------------------------------------------- /finetuning/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/datasets/__init__.py -------------------------------------------------------------------------------- /finetuning/datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from datasets import fisher 6 | 7 |
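# Note on the registry pattern used below: each module registered in
# `datasets_map` must expose get_split(split_name, dataset_dir, file_pattern,
# reader), which is exactly what datasets/fisher.py provides. To support a new
# dataset, import its module above and add a name -> module entry to the map,
# e.g. datasets_map = {'fisher': fisher, 'flowers': flowers} (where `flowers`
# would be a hypothetical sibling module, not one shipped in this repo).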
datasets_map = {'fisher': fisher} 8 | 9 | 10 | def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None): 11 | """Given a dataset name and a split_name returns a Dataset. 12 | 13 | Args: 14 | name: String, the name of the dataset. 15 | split_name: A train/test split name. 16 | dataset_dir: The directory where the dataset files are stored. 17 | file_pattern: The file pattern to use for matching the dataset source files. 18 | reader: The subclass of tf.ReaderBase. If left as `None`, then the default 19 | reader defined by each dataset is used. 20 | 21 | Returns: 22 | A `Dataset` class. 23 | 24 | Raises: 25 | ValueError: If the dataset `name` is unknown. 26 | """ 27 | if name not in datasets_map: 28 | raise ValueError('Name of dataset unknown %s' % name) 29 | return datasets_map[name].get_split(split_name, dataset_dir, file_pattern, 30 | reader) 31 | -------------------------------------------------------------------------------- /finetuning/datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains utilities for downloading and converting datasets.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | import sys 22 | import tarfile 23 | 24 | from six.moves import urllib 25 | import tensorflow as tf 26 | 27 | LABELS_FILENAME = 'labels.txt' 28 | 29 | 30 | def int64_feature(values): 31 | """Returns a TF-Feature of int64s. 32 | 33 | Args: 34 | values: A scalar or list of values. 35 | 36 | Returns: 37 | a TF-Feature. 38 | """ 39 | if not isinstance(values, (tuple, list)): 40 | values = [values] 41 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) 42 | 43 | 44 | def bytes_feature(values): 45 | """Returns a TF-Feature of bytes. 46 | 47 | Args: 48 | values: A string. 49 | 50 | Returns: 51 | a TF-Feature. 52 | """ 53 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) 54 | 55 | 56 | def image_to_tfexample(image_data, image_format, height, width, class_id): 57 | return tf.train.Example(features=tf.train.Features(feature={ 58 | 'image/encoded': bytes_feature(image_data), 59 | 'image/format': bytes_feature(image_format), 60 | 'image/class/label': int64_feature(class_id), 61 | 'image/height': int64_feature(height), 62 | 'image/width': int64_feature(width), 63 | })) 64 | 65 | 66 | def download_and_uncompress_tarball(tarball_url, dataset_dir): 67 | """Downloads the `tarball_url` and uncompresses it locally. 68 | 69 | Args: 70 | tarball_url: The URL of a tarball file. 71 | dataset_dir: The directory where the temporary files are stored. 
72 | """ 73 | filename = tarball_url.split('/')[-1] 74 | filepath = os.path.join(dataset_dir, filename) 75 | 76 | def _progress(count, block_size, total_size): 77 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 78 | filename, float(count * block_size) / float(total_size) * 100.0)) 79 | sys.stdout.flush() 80 | 81 | filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress) 82 | print() 83 | statinfo = os.stat(filepath) 84 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 85 | tarfile.open(filepath, 'r:gz').extractall(dataset_dir) 86 | 87 | 88 | def write_label_file(labels_to_class_names, 89 | dataset_dir, 90 | filename=LABELS_FILENAME): 91 | """Writes a file with the list of class names. 92 | 93 | Args: 94 | labels_to_class_names: A map of (integer) labels to class names. 95 | dataset_dir: The directory in which the labels file should be written. 96 | filename: The filename where the class names are written. 97 | """ 98 | labels_filename = os.path.join(dataset_dir, filename) 99 | with tf.gfile.Open(labels_filename, 'w') as f: 100 | for label in labels_to_class_names: 101 | class_name = labels_to_class_names[label] 102 | f.write('%d:%s\n' % (label, class_name)) 103 | 104 | 105 | def has_labels(dataset_dir, filename=LABELS_FILENAME): 106 | """Specifies whether or not the dataset directory contains a label map file. 107 | 108 | Args: 109 | dataset_dir: The directory in which the labels file is found. 110 | filename: The filename where the class names are written. 111 | 112 | Returns: 113 | `True` if the labels file exists and `False` otherwise. 114 | """ 115 | return tf.gfile.Exists(os.path.join(dataset_dir, filename)) 116 | 117 | 118 | def read_label_file(dataset_dir, filename=LABELS_FILENAME): 119 | """Reads the labels file and returns a mapping from ID to class name. 120 | 121 | Args: 122 | dataset_dir: The directory in which the labels file is found. 123 | filename: The filename where the class names are written. 124 | 125 | Returns: 126 | A map from a label (integer) to class name. 127 | """ 128 | labels_filename = os.path.join(dataset_dir, filename) 129 | with tf.gfile.Open(labels_filename, 'r') as f: 130 | lines = f.read().decode() 131 | lines = lines.split('\n') 132 | lines = filter(None, lines) 133 | 134 | labels_to_class_names = {} 135 | for line in lines: 136 | index = line.index(':') 137 | labels_to_class_names[int(line[:index])] = line[index + 1:] 138 | return labels_to_class_names 139 | -------------------------------------------------------------------------------- /finetuning/datasets/fisher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the Cifar10 dataset. 
16 | 17 | The TFRecords for this dataset are created by the scripts in 18 | finetuning/convert_pys/. 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | import tensorflow as tf 27 | 28 | from datasets import dataset_utils 29 | 30 | slim = tf.contrib.slim 31 | 32 | _FILE_PATTERN = 'fisher_%s_*.tfrecord' 33 | 34 | SPLITS_TO_SIZES = {'train': 3320, 'validation': 350} 35 | 36 | _NUM_CLASSES = 8 37 | 38 | _ITEMS_TO_DESCRIPTIONS = { 39 | 'image': 'A color image of varying size.', 40 | 'label': 'A single integer between 0 and 7', 41 | } 42 | 43 | 44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 45 | """Gets a dataset tuple with instructions for reading fisher. 46 | 47 | Args: 48 | split_name: A train/validation split name. 49 | dataset_dir: The base directory of the dataset sources. 50 | file_pattern: The file pattern to use when matching the dataset sources. 51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 54 | 55 | Returns: 56 | A `Dataset` namedtuple. 57 | 58 | Raises: 59 | ValueError: if `split_name` is not a valid train/validation split. 60 | """ 61 | if split_name not in SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' % split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default. 69 | if reader is None: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature( 74 | (), tf.string, default_value=''), 75 | 'image/format': tf.FixedLenFeature( 76 | (), tf.string, default_value='png'), 77 | 'image/class/label': tf.FixedLenFeature( 78 | [], tf.int64, default_value=tf.zeros( 79 | [], dtype=tf.int64)), 80 | } 81 | 82 | items_to_handlers = { 83 | 'image': slim.tfexample_decoder.Image(), 84 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 85 | } 86 | 87 | decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, 88 | items_to_handlers) 89 | 90 | labels_to_names = None 91 | if dataset_utils.has_labels(dataset_dir): 92 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 93 | 94 | return slim.dataset.Dataset( 95 | data_sources=file_pattern, 96 | reader=reader, 97 | decoder=decoder, 98 | num_samples=SPLITS_TO_SIZES[split_name], 99 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 100 | num_classes=_NUM_CLASSES, 101 | labels_to_names=labels_to_names) 102 | -------------------------------------------------------------------------------- /finetuning/deployment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/deployment/__init__.py -------------------------------------------------------------------------------- /finetuning/flask/flask_inference.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from flask import Flask, request, jsonify 3 | slim = tf.contrib.slim 4 | from PIL import Image 5 | import sys 6 | sys.path.append("..") 7 | from nets.inception_v3 import * 8 | import numpy as np 9 | import os 10 | import time 11 |
| """ 13 | Load a tensorflow model and make it available as a REST service 14 | """ 15 | app = Flask(__name__) 16 | 17 | 18 | class myTfModel(object): 19 | def __init__(self, model_dir, prefix): 20 | self.model_dir = model_dir 21 | self.prefix = prefix 22 | self.output = {} 23 | self.load_model() 24 | 25 | def load_model(self): 26 | sess = tf.Session() 27 | input_tensor = tf.placeholder(tf.float32, [None, 299, 299, 3]) 28 | arg_scope = inception_v3_arg_scope() 29 | with slim.arg_scope(arg_scope): 30 | logits, end_points = inception_v3( 31 | input_tensor, is_training=False, num_classes=8) 32 | saver = tf.train.Saver() 33 | params_file = tf.train.latest_checkpoint(self.model_dir) 34 | saver.restore(sess, params_file) 35 | self.output['sess'] = sess 36 | self.output['input_tensor'] = input_tensor 37 | self.output['logits'] = logits 38 | self.output['end_points'] = end_points 39 | # return sess, input_tensor, logits, end_points 40 | 41 | def execute(self, data, **kwargs): 42 | sess = self.output['sess'] 43 | input_tensor = self.output['input_tensor'] 44 | logits = self.output['logits'] 45 | end_points = self.output['end_points'] 46 | # ims = [] 47 | # for i in range(kwargs['batch_size']): 48 | im = Image.open(data).resize((299, 299)) 49 | im = np.array(im) / 255.0 50 | im = im.reshape(-1, 299, 299, 3) 51 | # ims.append(im) 52 | # ims = np.array(ims) 53 | # print ims.shape 54 | start = time.time() 55 | predict_values, logit_values = sess.run( 56 | [end_points['Predictions'], logits], feed_dict={input_tensor: im}) 57 | return predict_values 58 | # print 'the porn score with the {0} is {1} '.format( 59 | 60 | # data, predict_values[1][1]) 61 | # print 'a image take time {0}'.format(time.time() - start) 62 | 63 | 64 | mymodel = myTfModel('./train_log', 'model.ckpt') 65 | 66 | 67 | @app.route('/model', methods=['GET', 'POST']) 68 | def apply_model(): 69 | image = request.args.get('image') 70 | predict_values = mymodel.execute(image, batch_size=1) 71 | predicted_class = np.argmax(predict_values[0]) 72 | return jsonify(output=int(predicted_class)) 73 | 74 | 75 | if __name__ == '__main__': 76 | app.run(debug=True) -------------------------------------------------------------------------------- /finetuning/flask/uploads/3a43c94e-241c-4988-a2f5-a3d34595ff40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/flask/uploads/3a43c94e-241c-4988-a2f5-a3d34595ff40.png -------------------------------------------------------------------------------- /finetuning/flask/uploads/3d060309-0e80-479e-994b-924deec9b34d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/flask/uploads/3d060309-0e80-479e-994b-924deec9b34d.png -------------------------------------------------------------------------------- /finetuning/flask/uploads/60f94ed5-1851-43aa-af5d-7a3ad9c2fbb0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/flask/uploads/60f94ed5-1851-43aa-af5d-7a3ad9c2fbb0.jpg -------------------------------------------------------------------------------- /finetuning/flask/uploads/779ee9bb-421a-441b-87bf-b26373b7d43e.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/flask/uploads/779ee9bb-421a-441b-87bf-b26373b7d43e.jpg -------------------------------------------------------------------------------- /finetuning/flask/uploads/cdbc38e5-5001-46f7-868c-76e9037299d8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/flask/uploads/cdbc38e5-5001-46f7-868c-76e9037299d8.jpg -------------------------------------------------------------------------------- /finetuning/flask/uploads/dfa92819-9372-4519-be57-81995751f2e6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/flask/uploads/dfa92819-9372-4519-be57-81995751f2e6.png -------------------------------------------------------------------------------- /finetuning/flask/uploads/f513b5c8-32b8-4fdd-bad9-d59a19dd094f.42: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/flask/uploads/f513b5c8-32b8-4fdd-bad9-d59a19dd094f.42 -------------------------------------------------------------------------------- /finetuning/flask/uploads/logo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/flask/uploads/logo.jpeg -------------------------------------------------------------------------------- /finetuning/inference/fish_inference.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | slim = tf.contrib.slim 3 | from PIL import Image 4 | import sys 5 | sys.path.append("..") 6 | from nets.inception_v3 import * 7 | import numpy as np 8 | import os 9 | import time 10 | 11 | 12 | def inference(image): 13 | checkpoint_dir = '../train_log' 14 | input_tensor = tf.placeholder(tf.float32, [None, 299, 299, 3]) 15 | sess = tf.Session() 16 | arg_scope = inception_v3_arg_scope() 17 | with slim.arg_scope(arg_scope): 18 | logits, end_points = inception_v3( 19 | input_tensor, is_training=False, num_classes=8) 20 | saver = tf.train.Saver() 21 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 22 | if ckpt and ckpt.model_checkpoint_path: 23 | saver.restore(sess, ckpt.model_checkpoint_path) 24 | im = Image.open(image).resize((299, 299)) 25 | im = np.array(im) / 255.0 26 | im = im.reshape(-1, 299, 299, 3) 27 | start = time.time() 28 | predict_values, logit_values = sess.run( 29 | [end_points['Predictions'], logits], feed_dict={input_tensor: im}) 30 | print 'one image takes {0}s'.format(time.time() - start) 31 | return image, predict_values 32 | 33 | 34 | if __name__ == "__main__": 35 | sample_images = 'train/ALB/img_00003.jpg' 36 | image, predict = inference(sample_images) 37 | print 'the prediction for {0} is {1}'.format(image, predict) 38 | -------------------------------------------------------------------------------- /finetuning/inference/inference_1000.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | slim = tf.contrib.slim 3 | from PIL import Image 4 | import sys 5 | sys.path.append("..") 6 | from nets.inception_v3 import * 7 |
import numpy as np 8 | import os 9 | import time 10 | 11 | 12 | def get_label(sysnet_file, metadata_file): 13 | index_sysnet = [] 14 | with open(sysnet_file, 'r') as fread: 15 | for line in fread.readlines(): 16 | line = line.strip('\n') 17 | index_sysnet.append(line) 18 | sys_label = {} 19 | with open(metadata_file, 'r') as fread: 20 | for line in fread.readlines(): 21 | index = line.strip('\n').split('\t')[0] 22 | val = line.strip('\n').split('\t')[1] 23 | sys_label[index] = val 24 | 25 | index_label = [sys_label[i] for i in index_sysnet] 26 | index_label.append("i don't know") 27 | return index_label 28 | 29 | 30 | def inference(image): 31 | checkpoint_dir = '../pretrain_model' 32 | checkpoint_file = '../pretrain_model/inception_v3.ckpt' 33 | input_tensor = tf.placeholder(tf.float32, [None, 299, 299, 3]) 34 | sess = tf.Session() 35 | arg_scope = inception_v3_arg_scope() 36 | with slim.arg_scope(arg_scope): 37 | logits, end_points = inception_v3( 38 | input_tensor, is_training=False, num_classes=1001) 39 | saver = tf.train.Saver() 40 | # ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 41 | saver.restore(sess, checkpoint_file) 42 | im = Image.open(image).resize((299, 299)) 43 | im = np.array(im) / 255.0 44 | im = im.reshape(-1, 299, 299, 3) 45 | start = time.time() 46 | predict_values, logit_values = sess.run( 47 | [end_points['Predictions'], logits], feed_dict={input_tensor: im}) 48 | print 'one image takes {0}s'.format(time.time() - start) 49 | return image, predict_values 50 | 51 | 52 | if __name__ == "__main__": 53 | sample_images = './cat.jpeg' 54 | image, predict = inference(sample_images) 55 | 56 | print np.argmax(predict[0]) 57 | index_label = get_label('./sysnet.txt', 'imagenet_metadata.txt') 58 | print 'the image {0}, predict label is {1}'.format( 59 | sample_images, index_label[np.argmax(predict[0]) - 1])  # class 0 is background, so shift by one 60 | -------------------------------------------------------------------------------- /finetuning/nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /finetuning/nets/cifarnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the CIFAR-10 model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev) 26 | 27 | 28 | def cifarnet(images, num_classes=10, is_training=False, 29 | dropout_keep_prob=0.5, 30 | prediction_fn=slim.softmax, 31 | scope='CifarNet'): 32 | """Creates a variant of the CifarNet model.
33 | 34 | Note that since the output is a set of 'logits', the values fall in the 35 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 36 | probability distribution over the characters, one will need to convert them 37 | using the softmax function: 38 | 39 | logits = cifarnet.cifarnet(images, is_training=False) 40 | probabilities = tf.nn.softmax(logits) 41 | predictions = tf.argmax(logits, 1) 42 | 43 | Args: 44 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 45 | num_classes: the number of classes in the dataset. 46 | is_training: specifies whether or not we're currently training the model. 47 | This variable will determine the behaviour of the dropout layer. 48 | dropout_keep_prob: the percentage of activation values that are retained. 49 | prediction_fn: a function to get predictions out of logits. 50 | scope: Optional variable_scope. 51 | 52 | Returns: 53 | logits: the pre-softmax activations, a tensor of size 54 | [batch_size, `num_classes`] 55 | end_points: a dictionary from components of the network to the corresponding 56 | activation. 57 | """ 58 | end_points = {} 59 | 60 | with tf.variable_scope(scope, 'CifarNet', [images, num_classes]): 61 | net = slim.conv2d(images, 64, [5, 5], scope='conv1') 62 | end_points['conv1'] = net 63 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 64 | end_points['pool1'] = net 65 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1') 66 | net = slim.conv2d(net, 64, [5, 5], scope='conv2') 67 | end_points['conv2'] = net 68 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2') 69 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 70 | end_points['pool2'] = net 71 | net = slim.flatten(net) 72 | end_points['Flatten'] = net 73 | net = slim.fully_connected(net, 384, scope='fc3') 74 | end_points['fc3'] = net 75 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 76 | scope='dropout3') 77 | net = slim.fully_connected(net, 192, scope='fc4') 78 | end_points['fc4'] = net 79 | logits = slim.fully_connected(net, num_classes, 80 | biases_initializer=tf.zeros_initializer(), 81 | weights_initializer=trunc_normal(1/192.0), 82 | weights_regularizer=None, 83 | activation_fn=None, 84 | scope='logits') 85 | 86 | end_points['Logits'] = logits 87 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 88 | 89 | return logits, end_points 90 | cifarnet.default_image_size = 32 91 | 92 | 93 | def cifarnet_arg_scope(weight_decay=0.004): 94 | """Defines the default cifarnet argument scope. 95 | 96 | Args: 97 | weight_decay: The weight decay to use for regularizing the model. 98 | 99 | Returns: 100 | An `arg_scope` to use for the inception v3 model. 101 | """ 102 | with slim.arg_scope( 103 | [slim.conv2d], 104 | weights_initializer=tf.truncated_normal_initializer(stddev=5e-2), 105 | activation_fn=tf.nn.relu): 106 | with slim.arg_scope( 107 | [slim.fully_connected], 108 | biases_initializer=tf.constant_initializer(0.1), 109 | weights_initializer=trunc_normal(0.04), 110 | weights_regularizer=slim.l2_regularizer(weight_decay), 111 | activation_fn=tf.nn.relu) as sc: 112 | return sc 113 | -------------------------------------------------------------------------------- /finetuning/nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_v1 import inception_v1 25 | from nets.inception_v1 import inception_v1_arg_scope 26 | from nets.inception_v1 import inception_v1_base 27 | from nets.inception_v2 import inception_v2 28 | from nets.inception_v2 import inception_v2_arg_scope 29 | from nets.inception_v2 import inception_v2_base 30 | from nets.inception_v3 import inception_v3 31 | from nets.inception_v3 import inception_v3_arg_scope 32 | from nets.inception_v3 import inception_v3_base 33 | from nets.inception_v4 import inception_v4 34 | from nets.inception_v4 import inception_v4_arg_scope 35 | from nets.inception_v4 import inception_v4_base 36 | # pylint: enable=unused-import 37 | -------------------------------------------------------------------------------- /finetuning/nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001): 36 | """Defines the default arg scope for inception models. 37 | 38 | Args: 39 | weight_decay: The weight decay to use for regularizing the model. 40 | use_batch_norm: If `True`, batch_norm is applied after each convolution. 41 | batch_norm_decay: Decay for batch norm moving average.
42 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 43 | in batch norm. 44 | 45 | Returns: 46 | An `arg_scope` to use for the inception models. 47 | """ 48 | batch_norm_params = { 49 | # Decay for the moving averages. 50 | 'decay': batch_norm_decay, 51 | # epsilon to prevent 0s in variance. 52 | 'epsilon': batch_norm_epsilon, 53 | # collection containing update_ops. 54 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 55 | } 56 | if use_batch_norm: 57 | normalizer_fn = slim.batch_norm 58 | normalizer_params = batch_norm_params 59 | else: 60 | normalizer_fn = None 61 | normalizer_params = {} 62 | # Set weight_decay for weights in Conv and FC layers. 63 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 64 | weights_regularizer=slim.l2_regularizer(weight_decay)): 65 | with slim.arg_scope( 66 | [slim.conv2d], 67 | weights_initializer=slim.variance_scaling_initializer(), 68 | activation_fn=tf.nn.relu, 69 | normalizer_fn=normalizer_fn, 70 | normalizer_params=normalizer_params) as sc: 71 | return sc 72 | -------------------------------------------------------------------------------- /finetuning/nets/lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the LeNet model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def lenet(images, num_classes=10, is_training=False, 27 | dropout_keep_prob=0.5, 28 | prediction_fn=slim.softmax, 29 | scope='LeNet'): 30 | """Creates a variant of the LeNet model. 31 | 32 | Note that since the output is a set of 'logits', the values fall in the 33 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 34 | probability distribution over the characters, one will need to convert them 35 | using the softmax function: 36 | 37 | logits = lenet.lenet(images, is_training=False) 38 | probabilities = tf.nn.softmax(logits) 39 | predictions = tf.argmax(logits, 1) 40 | 41 | Args: 42 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 43 | num_classes: the number of classes in the dataset. 44 | is_training: specifies whether or not we're currently training the model. 45 | This variable will determine the behaviour of the dropout layer. 46 | dropout_keep_prob: the percentage of activation values that are retained. 47 | prediction_fn: a function to get predictions out of logits. 48 | scope: Optional variable_scope. 49 | 50 | Returns: 51 | logits: the pre-softmax activations, a tensor of size 52 | [batch_size, `num_classes`] 53 | end_points: a dictionary from components of the network to the corresponding 54 | activation. 
55 | """ 56 | end_points = {} 57 | 58 | with tf.variable_scope(scope, 'LeNet', [images, num_classes]): 59 | net = slim.conv2d(images, 32, [5, 5], scope='conv1') 60 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 61 | net = slim.conv2d(net, 64, [5, 5], scope='conv2') 62 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 63 | net = slim.flatten(net) 64 | end_points['Flatten'] = net 65 | 66 | net = slim.fully_connected(net, 1024, scope='fc3') 67 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 68 | scope='dropout3') 69 | logits = slim.fully_connected(net, num_classes, activation_fn=None, 70 | scope='fc4') 71 | 72 | end_points['Logits'] = logits 73 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 74 | 75 | return logits, end_points 76 | lenet.default_image_size = 28 77 | 78 | 79 | def lenet_arg_scope(weight_decay=0.0): 80 | """Defines the default lenet argument scope. 81 | 82 | Args: 83 | weight_decay: The weight decay to use for regularizing the model. 84 | 85 | Returns: 86 | An `arg_scope` to use for the inception v3 model. 87 | """ 88 | with slim.arg_scope( 89 | [slim.conv2d, slim.fully_connected], 90 | weights_regularizer=slim.l2_regularizer(weight_decay), 91 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1), 92 | activation_fn=tf.nn.relu) as sc: 93 | return sc 94 | -------------------------------------------------------------------------------- /finetuning/nets/nets_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import functools 21 | 22 | import tensorflow as tf 23 | 24 | from nets import alexnet 25 | from nets import cifarnet 26 | from nets import inception 27 | from nets import lenet 28 | from nets import overfeat 29 | from nets import resnet_v1 30 | from nets import resnet_v2 31 | from nets import vgg 32 | 33 | slim = tf.contrib.slim 34 | 35 | networks_map = {'alexnet_v2': alexnet.alexnet_v2, 36 | 'cifarnet': cifarnet.cifarnet, 37 | 'overfeat': overfeat.overfeat, 38 | 'vgg_a': vgg.vgg_a, 39 | 'vgg_16': vgg.vgg_16, 40 | 'vgg_19': vgg.vgg_19, 41 | 'inception_v1': inception.inception_v1, 42 | 'inception_v2': inception.inception_v2, 43 | 'inception_v3': inception.inception_v3, 44 | 'inception_v4': inception.inception_v4, 45 | 'inception_resnet_v2': inception.inception_resnet_v2, 46 | 'lenet': lenet.lenet, 47 | 'resnet_v1_50': resnet_v1.resnet_v1_50, 48 | 'resnet_v1_101': resnet_v1.resnet_v1_101, 49 | 'resnet_v1_152': resnet_v1.resnet_v1_152, 50 | 'resnet_v1_200': resnet_v1.resnet_v1_200, 51 | 'resnet_v2_50': resnet_v2.resnet_v2_50, 52 | 'resnet_v2_101': resnet_v2.resnet_v2_101, 53 | 'resnet_v2_152': resnet_v2.resnet_v2_152, 54 | 'resnet_v2_200': resnet_v2.resnet_v2_200, 55 | } 56 | 57 | arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope, 58 | 'cifarnet': cifarnet.cifarnet_arg_scope, 59 | 'overfeat': overfeat.overfeat_arg_scope, 60 | 'vgg_a': vgg.vgg_arg_scope, 61 | 'vgg_16': vgg.vgg_arg_scope, 62 | 'vgg_19': vgg.vgg_arg_scope, 63 | 'inception_v1': inception.inception_v3_arg_scope, 64 | 'inception_v2': inception.inception_v3_arg_scope, 65 | 'inception_v3': inception.inception_v3_arg_scope, 66 | 'inception_v4': inception.inception_v4_arg_scope, 67 | 'inception_resnet_v2': 68 | inception.inception_resnet_v2_arg_scope, 69 | 'lenet': lenet.lenet_arg_scope, 70 | 'resnet_v1_50': resnet_v1.resnet_arg_scope, 71 | 'resnet_v1_101': resnet_v1.resnet_arg_scope, 72 | 'resnet_v1_152': resnet_v1.resnet_arg_scope, 73 | 'resnet_v1_200': resnet_v1.resnet_arg_scope, 74 | 'resnet_v2_50': resnet_v2.resnet_arg_scope, 75 | 'resnet_v2_101': resnet_v2.resnet_arg_scope, 76 | 'resnet_v2_152': resnet_v2.resnet_arg_scope, 77 | 'resnet_v2_200': resnet_v2.resnet_arg_scope, 78 | } 79 | 80 | 81 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): 82 | """Returns a network_fn such as `logits, end_points = network_fn(images)`. 83 | 84 | Args: 85 | name: The name of the network. 86 | num_classes: The number of classes to use for classification. 87 | weight_decay: The l2 coefficient for the model weights. 88 | is_training: `True` if the model is being used for training and `False` 89 | otherwise. 90 | 91 | Returns: 92 | network_fn: A function that applies the model to a batch of images. It has 93 | the following signature: 94 | logits, end_points = network_fn(images) 95 | Raises: 96 | ValueError: If network `name` is not recognized. 
97 | """ 98 | if name not in networks_map: 99 | raise ValueError('Name of network unknown %s' % name) 100 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay) 101 | func = networks_map[name] 102 | @functools.wraps(func) 103 | def network_fn(images): 104 | with slim.arg_scope(arg_scope): 105 | return func(images, num_classes, is_training=is_training) 106 | if hasattr(func, 'default_image_size'): 107 | network_fn.default_image_size = func.default_image_size 108 | 109 | return network_fn 110 | -------------------------------------------------------------------------------- /finetuning/nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for slim.inception.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from nets import nets_factory 26 | 27 | 28 | class NetworksTest(tf.test.TestCase): 29 | 30 | def testGetNetworkFn(self): 31 | batch_size = 5 32 | num_classes = 1000 33 | for net in nets_factory.networks_map: 34 | with self.test_session(): 35 | net_fn = nets_factory.get_network_fn(net, num_classes) 36 | # Most networks use 224 as their default_image_size 37 | image_size = getattr(net_fn, 'default_image_size', 224) 38 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 39 | logits, end_points = net_fn(inputs) 40 | self.assertTrue(isinstance(logits, tf.Tensor)) 41 | self.assertTrue(isinstance(end_points, dict)) 42 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 43 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 44 | 45 | if __name__ == '__main__': 46 | tf.test.main() 47 | -------------------------------------------------------------------------------- /finetuning/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /finetuning/preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import inception_preprocessing 24 | 25 | slim = tf.contrib.slim 26 | 27 | 28 | def get_preprocessing(name, is_training=False): 29 | """Returns preprocessing_fn(image, height, width, **kwargs). 30 | 31 | Args: 32 | name: The name of the preprocessing function. 33 | is_training: `True` if the model is being used for training and `False` 34 | otherwise. 35 | 36 | Returns: 37 | preprocessing_fn: A function that preprocessing a single image (pre-batch). 38 | It has the following signature: 39 | image = preprocessing_fn(image, output_height, output_width, ...). 40 | 41 | Raises: 42 | ValueError: If Preprocessing `name` is not recognized. 43 | """ 44 | preprocessing_fn_map = { 45 | 'inception': inception_preprocessing, 46 | 'inception_v1': inception_preprocessing, 47 | 'inception_v2': inception_preprocessing, 48 | 'inception_v3': inception_preprocessing, 49 | 'inception_v4': inception_preprocessing, 50 | 'inception_resnet_v2': inception_preprocessing 51 | } 52 | 53 | if name not in preprocessing_fn_map: 54 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 55 | 56 | def preprocessing_fn(image, output_height, output_width, **kwargs): 57 | return preprocessing_fn_map[name].preprocess_image( 58 | image, 59 | output_height, 60 | output_width, 61 | is_training=is_training, 62 | **kwargs) 63 | 64 | return preprocessing_fn 65 | -------------------------------------------------------------------------------- /finetuning/pretrain_model/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/pretrain_model/.gitkeep -------------------------------------------------------------------------------- /finetuning/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==0.12.0rc1 2 | Flask==0.11.1 3 | Pillow==3.4.2 4 | -------------------------------------------------------------------------------- /finetuning/run_scripts/run.sh: -------------------------------------------------------------------------------- 1 | TRAIN_DIR=./train_log 2 | DATASET_DIR=./tfrecords 3 | PRETRAINED_CHECKPOINT_DIR=./pretrain_model 4 | 5 | python train_image_classifier.py \ 6 | --train_dir=${TRAIN_DIR} \ 7 | --dataset_name=fisher \ 8 | --dataset_split_name=train \ 9 | --dataset_dir=${DATASET_DIR} \ 10 | --model_name=inception_v3 \ 11 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt \ 12 | --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \ 13 | --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \ 14 | --max_number_of_steps=1000 \ 15 | --batch_size=32 \ 16 | --learning_rate=0.01 \ 17 | --learning_rate_decay_type=fixed \ 18 | --save_interval_secs=60 \ 19 | --save_summaries_secs=60 \ 20 | --log_every_n_steps=100 \ 21 | --optimizer=rmsprop \ 22 | --weight_decay=0.00004 -------------------------------------------------------------------------------- /finetuning/run_scripts/run_all.sh: 
-------------------------------------------------------------------------------- 1 | cd .. 2 | TRAIN_DIR=./train_log 3 | DATASET_DIR=./tfrecords 4 | PRETRAINED_CHECKPOINT_DIR=./pretrain_model 5 | 6 | 7 | python train_image_classifier.py \ 8 | --train_dir=${TRAIN_DIR}/all \ 9 | --dataset_name=fisher \ 10 | --dataset_split_name=train \ 11 | --dataset_dir=${DATASET_DIR} \ 12 | --model_name=inception_v3 \ 13 | --checkpoint_path=${TRAIN_DIR} \ 14 | --max_number_of_steps=500 \ 15 | --batch_size=32 \ 16 | --learning_rate=0.0001 \ 17 | --learning_rate_decay_type=fixed \ 18 | --save_interval_secs=60 \ 19 | --save_summaries_secs=60 \ 20 | --log_every_n_steps=10 \ 21 | --optimizer=rmsprop \ 22 | --weight_decay=0.00004 23 | -------------------------------------------------------------------------------- /finetuning/run_scripts/run_all_eval.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | TRAIN_DIR=./train_log 3 | DATASET_DIR=./tfrecords 4 | PRETRAINED_CHECKPOINT_DIR=./pretrain_model 5 | 6 | python eval_image_classifier.py \ 7 | --checkpoint_path=${TRAIN_DIR}/all \ 8 | --eval_dir=${TRAIN_DIR}/all \ 9 | --dataset_name=fisher \ 10 | --dataset_split_name=validation \ 11 | --dataset_dir=${DATASET_DIR} \ 12 | --model_name=inception_v3 13 | -------------------------------------------------------------------------------- /finetuning/run_scripts/run_eval.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | TRAIN_DIR=./train_log 3 | DATASET_DIR=./tfrecords 4 | PRETRAINED_CHECKPOINT_DIR=./pretrain_model 5 | 6 | python eval_image_classifier.py \ 7 | --checkpoint_path=${TRAIN_DIR} \ 8 | --eval_dir=${TRAIN_DIR} \ 9 | --dataset_name=fisher \ 10 | --dataset_split_name=validation \ 11 | --dataset_dir=${DATASET_DIR} \ 12 | --model_name=inception_v3 13 | -------------------------------------------------------------------------------- /finetuning/tfrecords/labels.txt: -------------------------------------------------------------------------------- 1 | 0:ALB 2 | 1:BET 3 | 2:DOL 4 | 3:LAG 5 | 4:NoF 6 | 5:OTHER 7 | 6:SHARK 8 | 7:YFT 9 | -------------------------------------------------------------------------------- /finetuning/train/ALB/img_00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/train/ALB/img_00003.jpg -------------------------------------------------------------------------------- /finetuning/train/BET/img_04557.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/train/BET/img_04557.jpg -------------------------------------------------------------------------------- /finetuning/train/DOL/img_00951.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/train/DOL/img_00951.jpg -------------------------------------------------------------------------------- /finetuning/train/LAG/img_01644.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/train/LAG/img_01644.jpg -------------------------------------------------------------------------------- 
/finetuning/train/NoF/img_00028.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/train/NoF/img_00028.jpg -------------------------------------------------------------------------------- /finetuning/train/OTHER/img_00063.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/train/OTHER/img_00063.jpg -------------------------------------------------------------------------------- /finetuning/train/SHARK/img_02176.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/train/SHARK/img_02176.jpg -------------------------------------------------------------------------------- /finetuning/train/YFT/img_00184.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/train/YFT/img_00184.jpg -------------------------------------------------------------------------------- /finetuning/uploads/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/finetuning/uploads/1.png -------------------------------------------------------------------------------- /images/acgan-fig-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/images/acgan-fig-01.png -------------------------------------------------------------------------------- /images/acgan-result-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/images/acgan-result-01.png -------------------------------------------------------------------------------- /images/acgan-result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/images/acgan-result.png -------------------------------------------------------------------------------- /images/demo_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/images/demo_result.png -------------------------------------------------------------------------------- /images/flask_with_pretrain_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/images/flask_with_pretrain_model.png -------------------------------------------------------------------------------- /images/flask_with_pretrain_model_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/images/flask_with_pretrain_model_00.png 
-------------------------------------------------------------------------------- /images/infogan-fig-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/images/infogan-fig-01.png -------------------------------------------------------------------------------- /images/infogan-result-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/images/infogan-result-01.png -------------------------------------------------------------------------------- /images/infogan-result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/images/infogan-result.png -------------------------------------------------------------------------------- /images/mac_blogs_deepcolor-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/images/mac_blogs_deepcolor-01.png -------------------------------------------------------------------------------- /images/mac_blogs_deepcolor-03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/images/mac_blogs_deepcolor-03.png -------------------------------------------------------------------------------- /images/mltookit_log_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/images/mltookit_log_00.png -------------------------------------------------------------------------------- /images/mnist_client_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/images/mnist_client_result.png -------------------------------------------------------------------------------- /images/mnist_server.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/images/mnist_server.png -------------------------------------------------------------------------------- /machinelearning_toolkit/README.md: -------------------------------------------------------------------------------- 1 | ## Examples of tf.learn 2 | 3 | ### LinearClassifier 4 | ### RandomForest 5 | ### SVM 6 | 7 | 8 | ### Add monitors to the wide and deep model 9 | [https://www.tensorflow.org/get_started/monitors](https://www.tensorflow.org/get_started/monitors) 10 | 11 | 1. Add `tf.logging.set_verbosity(tf.logging.INFO)` and the training loss gets logged: 12 | ![](../images/mltookit_log_00.png) 13 | 2. validation_monitor: there is a bug when it is used with an input_fn; eval_steps=None triggers it, and setting eval_steps=1 works around it. 14 | ![](../images/mltookit_log_01.png) 15 | 3. How to set validation_metrics: see the code in [tensorflow/contrib/learn/python/learn/metric_spec.py](https://github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/contrib/learn/python/learn/metric_spec.py). Specify the prediction_key and label_key (None may be the default). A minimal sketch of all three points follows.
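A minimal sketch wiring the three points above together, assuming a contrib.learn estimator `m` (e.g. the wide-and-deep `DNNLinearCombinedClassifier` from the tutorial) and `train_input_fn`/`eval_input_fn` like the ones in the scripts below; these names are placeholders, not code from this repo:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec

tf.logging.set_verbosity(tf.logging.INFO)  # point 1: log the training loss

validation_metrics = {
    # point 3: prediction_key must match a key the estimator emits
    # ('classes' for the wide-and-deep classifier); label_key can stay None.
    'accuracy': MetricSpec(metric_fn=tf.contrib.metrics.streaming_accuracy,
                           prediction_key='classes'),
}
validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
    input_fn=eval_input_fn,
    eval_steps=1,  # point 2: eval_steps=None hangs with an input_fn; use 1
    every_n_steps=100,
    metrics=validation_metrics)

m.fit(input_fn=train_input_fn, steps=2000, monitors=[validation_monitor])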
Specify the `prediction_key` and `label_key` (`None` may be the default). -------------------------------------------------------------------------------- /machinelearning_toolkit/scripts/simple-tf-rf.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import pandas as pd 3 | from tensorflow.contrib.tensor_forest.client import random_forest 4 | 5 | detailed_occupation_recode = tf.contrib.layers.sparse_column_with_hash_bucket( 6 | column_name='detailed_occupation_recode', hash_bucket_size=1000) 7 | education = tf.contrib.layers.sparse_column_with_hash_bucket( 8 | column_name='education', hash_bucket_size=1000) 9 | # Continuous base columns 10 | age = tf.contrib.layers.real_valued_column('age') 11 | wage_per_hour = tf.contrib.layers.real_valued_column('wage_per_hour') 12 | 13 | columns = [ 14 | 'age', 'detailed_occupation_recode', 'education', 'wage_per_hour', 'label' 15 | ] 16 | FEATURE_COLUMNS = [ 17 | # age, age_buckets, class_of_worker, detailed_industry_recode, 18 | age, 19 | detailed_occupation_recode, 20 | education, 21 | wage_per_hour 22 | ] 23 | 24 | LABEL_COLUMN = 'label' 25 | 26 | CONTINUOUS_COLUMNS = ['age', 'wage_per_hour'] 27 | 28 | CATEGORICAL_COLUMNS = ['detailed_occupation_recode', 'education'] 29 | 30 | df_train = pd.DataFrame( 31 | [[12, '12', '7th and 8th grade', 40, '- 50000'], 32 | [40, '45', '7th and 8th grade', 40, '50000+'], 33 | [50, '50', '10th grade', 40, '50000+'], 34 | [60, '30', '7th and 8th grade', 40, '- 50000']], 35 | columns=[ 36 | 'age', 'detailed_occupation_recode', 'education', 'wage_per_hour', 37 | 'label' 38 | ]) 39 | 40 | df_test = pd.DataFrame( 41 | [[12, '12', '7th and 8th grade', 40, '- 50000'], 42 | [40, '45', '7th and 8th grade', 40, '50000+'], 43 | [50, '50', '10th grade', 40, '50000+'], 44 | [60, '30', '7th and 8th grade', 40, '- 50000']], 45 | columns=[ 46 | 'age', 'detailed_occupation_recode', 'education', 'wage_per_hour', 47 | 'label' 48 | ]) 49 | df_train[LABEL_COLUMN] = ( 50 | df_train[LABEL_COLUMN].apply(lambda x: '+' in x)).astype(int) 51 | df_test[LABEL_COLUMN] = ( 52 | df_test[LABEL_COLUMN].apply(lambda x: '+' in x)).astype(int) 53 | dtypess = df_train.dtypes 54 | 55 | print df_train 56 | print df_test 57 | 58 | 59 | def input_fn(df): 60 | continuous_cols = { 61 | k: tf.expand_dims(tf.constant(df[k].values), 1) 62 | for k in CONTINUOUS_COLUMNS 63 | } 64 | # continuous_cols = { 65 | # k: tf.constant(df[k].values) 66 | # for k in CONTINUOUS_COLUMNS 67 | # } 68 | categorical_cols = { 69 | k: tf.SparseTensor( 70 | indices=[[i, 0] for i in range(df[k].size)], 71 | values=df[k].values, 72 | dense_shape=[df[k].size, 1]) 73 | for k in CATEGORICAL_COLUMNS 74 | } 75 | feature_cols = dict(continuous_cols.items() + categorical_cols.items()) 76 | label = tf.constant(df[LABEL_COLUMN].values) 77 | return feature_cols, label 78 | 79 | 80 | def train_input_fn(): 81 | return input_fn(df_train) 82 | 83 | 84 | def eval_input_fn(): 85 | return input_fn(df_test) 86 | 87 | 88 | model_dir = '../rf_model_dir' 89 | 90 | hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams( 91 | num_trees=10, max_nodes=1000, num_classes=2, num_features=4) 92 | classifier = random_forest.TensorForestEstimator(hparams, model_dir=model_dir) 93 | classifier.fit(input_fn=train_input_fn, steps=100) 94 | results = classifier.evaluate(input_fn=eval_input_fn, steps=1) 95 | print results 96 | for key in sorted(results): 97 | print("%s: %s" % (key, results[key])) 98 |
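A minimal usage sketch of the ValidationMonitor workaround described in machinelearning_toolkit/README.md above. It assumes the `classifier`, `train_input_fn`, and `eval_input_fn` defined in simple-tf-rf.py; `every_n_steps=50` is an illustrative choice, not a value taken from this repo:

```python
import tensorflow as tf

# eval_steps=1 is the workaround from the README; eval_steps=None
# would trigger the input_fn bug mentioned there.
validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
    input_fn=eval_input_fn,   # eval_input_fn from simple-tf-rf.py above
    eval_steps=1,
    every_n_steps=50)         # illustrative validation frequency

tf.logging.set_verbosity(tf.logging.INFO)  # log the loss during fit()
classifier.fit(input_fn=train_input_fn,
               steps=100,
               monitors=[validation_monitor])
```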
-------------------------------------------------------------------------------- /machinelearning_toolkit/scripts/simple-tf-svm.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import pandas as pd 3 | from tensorflow.contrib.learn.python.learn.estimators import svm 4 | 5 | detailed_occupation_recode = tf.contrib.layers.sparse_column_with_hash_bucket( 6 | column_name='detailed_occupation_recode', hash_bucket_size=1000) 7 | education = tf.contrib.layers.sparse_column_with_hash_bucket( 8 | column_name='education', hash_bucket_size=1000) 9 | # Continuous base columns 10 | age = tf.contrib.layers.real_valued_column('age') 11 | wage_per_hour = tf.contrib.layers.real_valued_column('wage_per_hour') 12 | 13 | columns = [ 14 | 'age', 'detailed_occupation_recode', 'education', 'wage_per_hour', 'label' 15 | ] 16 | FEATURE_COLUMNS = [ 17 | # age, age_buckets, class_of_worker, detailed_industry_recode, 18 | age, 19 | detailed_occupation_recode, 20 | education, 21 | wage_per_hour 22 | ] 23 | 24 | LABEL_COLUMN = 'label' 25 | 26 | CONTINUOUS_COLUMNS = ['age', 'wage_per_hour'] 27 | 28 | CATEGORICAL_COLUMNS = ['detailed_occupation_recode', 'education'] 29 | 30 | df_train = pd.DataFrame( 31 | [[12, '12', '7th and 8th grade', 40, '- 50000'], 32 | [40, '45', '7th and 8th grade', 40, '50000+'], 33 | [50, '50', '10th grade', 40, '50000+'], 34 | [60, '30', '7th and 8th grade', 40, '- 50000']], 35 | columns=[ 36 | 'age', 'detailed_occupation_recode', 'education', 'wage_per_hour', 37 | 'label' 38 | ]) 39 | 40 | df_test = pd.DataFrame( 41 | [[12, '12', '7th and 8th grade', 40, '- 50000'], 42 | [40, '45', '7th and 8th grade', 40, '50000+'], 43 | [50, '50', '10th grade', 40, '50000+'], 44 | [60, '30', '7th and 8th grade', 40, '- 50000']], 45 | columns=[ 46 | 'age', 'detailed_occupation_recode', 'education', 'wage_per_hour', 47 | 'label' 48 | ]) 49 | df_train[LABEL_COLUMN] = ( 50 | df_train[LABEL_COLUMN].apply(lambda x: '+' in x)).astype(int) 51 | df_test[LABEL_COLUMN] = ( 52 | df_test[LABEL_COLUMN].apply(lambda x: '+' in x)).astype(int) 53 | dtypess = df_train.dtypes 54 | 55 | 56 | def input_fn(df): 57 | # continuous_cols = {k: tf.constant(df[k].values) for k in CONTINUOUS_COLUMNS} 58 | continuous_cols = { 59 | k: tf.expand_dims(tf.constant(df[k].values), 1) 60 | for k in CONTINUOUS_COLUMNS 61 | } 62 | categorical_cols = { 63 | k: tf.SparseTensor( 64 | indices=[[i, 0] for i in range(df[k].size)], 65 | values=df[k].values, 66 | dense_shape=[df[k].size, 1]) 67 | for k in CATEGORICAL_COLUMNS 68 | } 69 | feature_cols = dict(continuous_cols.items() + categorical_cols.items()) 70 | feature_cols['example_id'] = tf.constant( 71 | [str(i + 1) for i in range(df['age'].size)]) 72 | label = tf.constant(df[LABEL_COLUMN].values) 73 | return feature_cols, label 74 | 75 | 76 | def train_input_fn(): 77 | return input_fn(df_train) 78 | 79 | 80 | def eval_input_fn(): 81 | return input_fn(df_test) 82 | 83 | 84 | model_dir = '../svm_model_dir' 85 | 86 | model = svm.SVM(example_id_column='example_id', 87 | feature_columns=FEATURE_COLUMNS, 88 | model_dir=model_dir) 89 | model.fit(input_fn=train_input_fn, steps=10) 90 | results = model.evaluate(input_fn=eval_input_fn, steps=1) 91 | for key in sorted(results): 92 | print("%s: %s" % (key, results[key])) 93 | -------------------------------------------------------------------------------- /machinelearning_toolkit/wide_deep_scripts/simple-tf-rf.py: -------------------------------------------------------------------------------- 1 
| import tensorflow as tf 2 | import pandas as pd 3 | from tensorflow.contrib.tensor_forest.client import random_forest 4 | 5 | detailed_occupation_recode = tf.contrib.layers.sparse_column_with_hash_bucket( 6 | column_name='detailed_occupation_recode', hash_bucket_size=1000) 7 | education = tf.contrib.layers.sparse_column_with_hash_bucket( 8 | column_name='education', hash_bucket_size=1000) 9 | # Continuous base columns 10 | age = tf.contrib.layers.real_valued_column('age') 11 | wage_per_hour = tf.contrib.layers.real_valued_column('wage_per_hour') 12 | 13 | columns = [ 14 | 'age', 'detailed_occupation_recode', 'education', 'wage_per_hour', 'label' 15 | ] 16 | FEATURE_COLUMNS = [ 17 | # age, age_buckets, class_of_worker, detailed_industry_recode, 18 | age, 19 | detailed_occupation_recode, 20 | education, 21 | wage_per_hour 22 | ] 23 | 24 | LABEL_COLUMN = 'label' 25 | 26 | CONTINUOUS_COLUMNS = ['age', 'wage_per_hour'] 27 | 28 | CATEGORICAL_COLUMNS = ['detailed_occupation_recode', 'education'] 29 | 30 | df_train = pd.DataFrame( 31 | [[12, '12', '7th and 8th grade', 40, '- 50000'], 32 | [40, '45', '7th and 8th grade', 40, '50000+'], 33 | [50, '50', '10th grade', 40, '50000+'], 34 | [60, '30', '7th and 8th grade', 40, '- 50000']], 35 | columns=[ 36 | 'age', 'detailed_occupation_recode', 'education', 'wage_per_hour', 37 | 'label' 38 | ]) 39 | 40 | df_test = pd.DataFrame( 41 | [[12, '12', '7th and 8th grade', 40, '- 50000'], 42 | [40, '45', '7th and 8th grade', 40, '50000+'], 43 | [50, '50', '10th grade', 40, '50000+'], 44 | [60, '30', '7th and 8th grade', 40, '- 50000']], 45 | columns=[ 46 | 'age', 'detailed_occupation_recode', 'education', 'wage_per_hour', 47 | 'label' 48 | ]) 49 | df_train[LABEL_COLUMN] = ( 50 | df_train[LABEL_COLUMN].apply(lambda x: '+' in x)).astype(int) 51 | df_test[LABEL_COLUMN] = ( 52 | df_test[LABEL_COLUMN].apply(lambda x: '+' in x)).astype(int) 53 | dtypess = df_train.dtypes 54 | 55 | print df_train 56 | print df_test 57 | 58 | 59 | def input_fn(df): 60 | continuous_cols = { 61 | k: tf.expand_dims(tf.constant(df[k].values), 1) 62 | for k in CONTINUOUS_COLUMNS 63 | } 64 | # continuous_cols = { 65 | # k: tf.constant(df[k].values) 66 | # for k in CONTINUOUS_COLUMNS 67 | # } 68 | categorical_cols = { 69 | k: tf.SparseTensor( 70 | indices=[[i, 0] for i in range(df[k].size)], 71 | values=df[k].values, 72 | dense_shape=[df[k].size, 1]) 73 | for k in CATEGORICAL_COLUMNS 74 | } 75 | feature_cols = dict(continuous_cols.items() + categorical_cols.items()) 76 | label = tf.constant(df[LABEL_COLUMN].values) 77 | return feature_cols, label 78 | 79 | 80 | def train_input_fn(): 81 | return input_fn(df_train) 82 | 83 | 84 | def eval_input_fn(): 85 | return input_fn(df_test) 86 | 87 | 88 | model_dir = '../rf_model_dir' 89 | 90 | hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams( 91 | num_trees=10, max_nodes=1000, num_classes=2, num_features=4) 92 | classifier = random_forest.TensorForestEstimator(hparams, model_dir=model_dir) 93 | classifier.fit(input_fn=train_input_fn, steps=100) 94 | results = classifier.evaluate(input_fn=eval_input_fn, steps=1) 95 | print results 96 | for key in sorted(results): 97 | print("%s: %s" % (key, results[key])) 98 | -------------------------------------------------------------------------------- /machinelearning_toolkit/wide_deep_scripts/simple-tf-svm.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import pandas as pd 3 | from 
tensorflow.contrib.learn.python.learn.estimators import svm 4 | 5 | detailed_occupation_recode = tf.contrib.layers.sparse_column_with_hash_bucket( 6 | column_name='detailed_occupation_recode', hash_bucket_size=1000) 7 | education = tf.contrib.layers.sparse_column_with_hash_bucket( 8 | column_name='education', hash_bucket_size=1000) 9 | # Continuous base columns 10 | age = tf.contrib.layers.real_valued_column('age') 11 | wage_per_hour = tf.contrib.layers.real_valued_column('wage_per_hour') 12 | 13 | columns = [ 14 | 'age', 'detailed_occupation_recode', 'education', 'wage_per_hour', 'label' 15 | ] 16 | FEATURE_COLUMNS = [ 17 | # age, age_buckets, class_of_worker, detailed_industry_recode, 18 | age, 19 | detailed_occupation_recode, 20 | education, 21 | wage_per_hour 22 | ] 23 | 24 | LABEL_COLUMN = 'label' 25 | 26 | CONTINUOUS_COLUMNS = ['age', 'wage_per_hour'] 27 | 28 | CATEGORICAL_COLUMNS = ['detailed_occupation_recode', 'education'] 29 | 30 | df_train = pd.DataFrame( 31 | [[12, '12', '7th and 8th grade', 40, '- 50000'], 32 | [40, '45', '7th and 8th grade', 40, '50000+'], 33 | [50, '50', '10th grade', 40, '50000+'], 34 | [60, '30', '7th and 8th grade', 40, '- 50000']], 35 | columns=[ 36 | 'age', 'detailed_occupation_recode', 'education', 'wage_per_hour', 37 | 'label' 38 | ]) 39 | 40 | df_test = pd.DataFrame( 41 | [[12, '12', '7th and 8th grade', 40, '- 50000'], 42 | [40, '45', '7th and 8th grade', 40, '50000+'], 43 | [50, '50', '10th grade', 40, '50000+'], 44 | [60, '30', '7th and 8th grade', 40, '- 50000']], 45 | columns=[ 46 | 'age', 'detailed_occupation_recode', 'education', 'wage_per_hour', 47 | 'label' 48 | ]) 49 | df_train[LABEL_COLUMN] = ( 50 | df_train[LABEL_COLUMN].apply(lambda x: '+' in x)).astype(int) 51 | df_test[LABEL_COLUMN] = ( 52 | df_test[LABEL_COLUMN].apply(lambda x: '+' in x)).astype(int) 53 | dtypess = df_train.dtypes 54 | 55 | 56 | def input_fn(df): 57 | # continuous_cols = {k: tf.constant(df[k].values) for k in CONTINUOUS_COLUMNS} 58 | continuous_cols = { 59 | k: tf.expand_dims(tf.constant(df[k].values), 1) 60 | for k in CONTINUOUS_COLUMNS 61 | } 62 | categorical_cols = { 63 | k: tf.SparseTensor( 64 | indices=[[i, 0] for i in range(df[k].size)], 65 | values=df[k].values, 66 | dense_shape=[df[k].size, 1]) 67 | for k in CATEGORICAL_COLUMNS 68 | } 69 | feature_cols = dict(continuous_cols.items() + categorical_cols.items()) 70 | feature_cols['example_id'] = tf.constant( 71 | [str(i + 1) for i in range(df['age'].size)]) 72 | label = tf.constant(df[LABEL_COLUMN].values) 73 | return feature_cols, label 74 | 75 | 76 | def train_input_fn(): 77 | return input_fn(df_train) 78 | 79 | 80 | def eval_input_fn(): 81 | return input_fn(df_test) 82 | 83 | 84 | model_dir = '../svm_model_dir' 85 | 86 | model = svm.SVM(example_id_column='example_id', 87 | feature_columns=FEATURE_COLUMNS, 88 | model_dir=model_dir) 89 | model.fit(input_fn=train_input_fn, steps=10) 90 | results = model.evaluate(input_fn=eval_input_fn, steps=1) 91 | for key in sorted(results): 92 | print("%s: %s" % (key, results[key])) 93 | -------------------------------------------------------------------------------- /nlp/NMT/README.md: -------------------------------------------------------------------------------- 1 | ## Coding Notes 2 | 3 | ### Convert the corpus to vector representation 4 | Refer to how source and target are handled in [nmt-keras/model-zoo.py](https://github.com/lvapeab/nmt-keras/blob/master/model_zoo.py): when pretrained weights are available they can optionally be used, and it also works without pretrained weights (a hedged embedding sketch follows this README). 5 |
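A minimal Keras sketch of the idea in the NMT README above: building a source-side `Embedding` that is initialized from a pretrained weight matrix when one is available, and randomly otherwise. `vocab_size`, `embed_dim`, and `embedding_matrix` are illustrative assumptions, not values from this repo:

```python
from keras.layers import Input, Embedding
from keras.models import Model

vocab_size, embed_dim = 30000, 50  # assumed sizes for the sketch
embedding_matrix = None            # e.g. a (vocab_size, embed_dim) numpy array when pretrained weights exist

if embedding_matrix is not None:
    # Initialize the lookup table from the pretrained word vectors.
    src_embedding = Embedding(vocab_size, embed_dim,
                              weights=[embedding_matrix], trainable=True)
else:
    # No pretrained weights: fall back to random initialization.
    src_embedding = Embedding(vocab_size, embed_dim)

src_text = Input(shape=(None,), dtype="int32", name="NMT_input")
src_vectors = src_embedding(src_text)  # (batch, time, embed_dim)
model = Model(src_text, src_vectors)
model.summary()
```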
-------------------------------------------------------------------------------- /nlp/NMT/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/nlp/NMT/scripts/__init__.py -------------------------------------------------------------------------------- /nlp/NMT/scripts/config.py: -------------------------------------------------------------------------------- 1 | import re 2 | import logging 3 | import sys 4 | 5 | NON_ALPHA_PAT = re.compile('[\.,-]') 6 | COPUS_TYPE = "middle" 7 | TRAIN_X = "train.en" 8 | TRAIN_Y = "train.vi" 9 | TEST_X_2012 = "tst2012.en" 10 | TEST_Y_2012 = "tst2012.cs" 11 | INPUR_SEQ_LENGTH = 30 12 | 13 | logger = logging.getLogger("Neural Machine Translator") 14 | formatter = logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s') 15 | file_handler = logging.FileHandler("test.log") 16 | file_handler.setFormatter(formatter) 17 | console_handler = logging.StreamHandler(sys.stdout) 18 | console_handler.formatter = formatter 19 | logger.addHandler(file_handler) 20 | logger.addHandler(console_handler) 21 | logger.setLevel(logging.INFO) -------------------------------------------------------------------------------- /nlp/NMT/scripts/dataset_helpers/Constants.py: -------------------------------------------------------------------------------- 1 | PAD = 0 2 | UNK = 1 3 | BOS = 2 4 | EOS = 3 5 | 6 | PAD_WORD = '<blank>' 7 | UNK_WORD = '<unk>' 8 | BOS_WORD = '<s>' 9 | EOS_WORD = '</s>' 10 | -------------------------------------------------------------------------------- /nlp/NMT/scripts/dataset_helpers/Dict.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import Constants 3 | 4 | 5 | class Dict(object): 6 | def __init__(self, data=None, lower=False): 7 | self.idxToLabel = {} 8 | self.labelToIdx = {} 9 | self.frequencies = {} 10 | self.lower = lower 11 | 12 | # Special entries will not be pruned. 13 | self.special = [] 14 | 15 | if data is not None: 16 | if type(data) == str: 17 | self.loadFile(data) 18 | else: 19 | self.addSpecials(data) 20 | 21 | def size(self): 22 | return len(self.idxToLabel) 23 | 24 | def loadFile(self, filename): 25 | "Load entries from a file." 26 | for line in codecs.open(filename, 'r', 'utf-8'): 27 | fields = line.split() 28 | label = fields[0] 29 | idx = int(fields[1]) 30 | self.add(label, idx) 31 | 32 | def writeFile(self, filename): 33 | "Write entries to a file." 34 | with codecs.open(filename, 'w', 'utf-8') as file: 35 | for i in range(self.size()): 36 | label = self.idxToLabel[i] 37 | file.write('%s %d\n' % (label, i)) 38 | 39 | file.close() 40 | 41 | def lookup(self, key, default=None): 42 | key = key.lower() if self.lower else key 43 | try: 44 | return self.labelToIdx[key] 45 | except KeyError: 46 | return default 47 | 48 | def align(self, other): 49 | "Find the id of each label in other dict." 50 | alignment = [Constants.PAD] * self.size() 51 | for idx, label in self.idxToLabel.items(): 52 | if label in other.labelToIdx: 53 | alignment[idx] = other.labelToIdx[label] 54 | return alignment 55 | 56 | def getLabel(self, idx, default=None): 57 | try: 58 | return self.idxToLabel[idx] 59 | except KeyError: 60 | return default 61 | 62 | def addSpecial(self, label, idx=None): 63 | "Mark this `label` and `idx` as special (i.e. will not be pruned)."
64 | idx = self.add(label, idx) 65 | self.special += [idx] 66 | 67 | def addSpecials(self, labels): 68 | "Mark all labels in `labels` as specials (i.e. will not be pruned)." 69 | for label in labels: 70 | self.addSpecial(label) 71 | 72 | def add(self, label, idx=None): 73 | "Add `label` in the dictionary. Use `idx` as its index if given." 74 | label = label.lower() if self.lower else label 75 | if idx is not None: 76 | self.idxToLabel[idx] = label 77 | self.labelToIdx[label] = idx 78 | else: 79 | if label in self.labelToIdx: 80 | idx = self.labelToIdx[label] 81 | else: 82 | idx = len(self.idxToLabel) 83 | self.idxToLabel[idx] = label 84 | self.labelToIdx[label] = idx 85 | 86 | if idx not in self.frequencies: 87 | self.frequencies[idx] = 1 88 | else: 89 | self.frequencies[idx] += 1 90 | 91 | return idx 92 | 93 | def prune(self, size): 94 | "Return a new dictionary with the `size` most frequent entries." 95 | if size >= self.size(): 96 | return self 97 | 98 | # Only keep the `size` most frequent entries. 99 | freq = [self.frequencies[i] for i in range(len(self.frequencies))] 100 | print freq[:100] 101 | idx = sorted(range(len(freq)), key=lambda k: freq[k], reverse=True) 102 | print idx[:100] 103 | 104 | newDict = Dict() 105 | newDict.lower = self.lower 106 | 107 | # Add special entries in all cases. 108 | for i in self.special: 109 | newDict.addSpecial(self.idxToLabel[i]) 110 | 111 | for i in idx[:size]: 112 | newDict.add(self.idxToLabel[i]) 113 | 114 | return newDict 115 | 116 | def convertToIdx(self, labels, unkWord, bosWord=None, eosWord=None): 117 | """ 118 | Convert `labels` to indices. Use `unkWord` if not found. 119 | Optionally insert `bosWord` at the beginning and `eosWord` at the end. 120 | """ 121 | vec = [] 122 | 123 | if bosWord is not None: 124 | vec += [self.lookup(bosWord)] 125 | 126 | unk = self.lookup(unkWord) 127 | vec += [self.lookup(label, default=unk) for label in labels] 128 | 129 | if eosWord is not None: 130 | vec += [self.lookup(eosWord)] 131 | 132 | return vec 133 | 134 | def convertToLabels(self, idx, stop): 135 | """ 136 | Convert `idx` to labels. 137 | If index `stop` is reached, convert it and return. 138 | """ 139 | 140 | labels = [] 141 | 142 | for i in idx: 143 | labels += [self.getLabel(i)] 144 | if i == stop: 145 | break 146 | 147 | return labels 148 | -------------------------------------------------------------------------------- /nlp/NMT/scripts/dataset_helpers/IO.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import Constants 3 | import numpy as np 4 | 5 | 6 | def align(src_tokens, tgt_tokens): 7 | """ 8 | Given two sequences of tokens, return 9 | a mask of where there is overlap.
10 | 11 | Returns: 12 | mask: tgt_len x src_len 13 | """ 14 | mask = np.zeros([len(src_tokens), len(tgt_tokens)]) 15 | 16 | for i in range(len(src_tokens)): 17 | for j in range(len(tgt_tokens)): 18 | if src_tokens[i] == tgt_tokens[j]: 19 | mask[i][j] = 1 20 | return mask 21 | 22 | 23 | def readSrcLine(src_line, 24 | src_dict, 25 | src_feature_dicts, 26 | _type="text", 27 | src_img_dir=""): 28 | srcFeats = None 29 | if _type == "text": 30 | srcWords, srcFeatures, _ = extractFeatures(src_line) 31 | srcData = src_dict.convertToIdx(srcWords, Constants.UNK_WORD) 32 | if src_feature_dicts: 33 | srcFeats = [ 34 | src_feature_dicts[j].convertToIdx(srcFeatures[j], 35 | Constants.UNK_WORD) 36 | for j in range(len(src_feature_dicts)) 37 | ] 38 | # elif _type == "img": 39 | # if not transforms: 40 | # loadImageLibs() 41 | # srcData = transforms.ToTensor()(Image.open(src_img_dir + "/" + 42 | # srcWords[0])) 43 | 44 | return srcWords, srcData, srcFeats 45 | 46 | 47 | def readTgtLine(tgt_line, tgt_dict, tgt_feature_dicts, _type="text"): 48 | tgtFeats = None 49 | tgtWords, tgtFeatures, _ = extractFeatures(tgt_line) 50 | tgtData = tgt_dict.convertToIdx(tgtWords, Constants.UNK_WORD, 51 | Constants.BOS_WORD, Constants.EOS_WORD) 52 | if tgt_feature_dicts: 53 | tgtFeats = [ 54 | tgt_feature_dicts[j].convertToIdx(tgtFeatures[j], 55 | Constants.UNK_WORD) 56 | for j in range(len(tgt_feature_dicts)) 57 | ] 58 | 59 | return tgtWords, tgtData, tgtFeats 60 | 61 | 62 | def extractFeatures(tokens): 63 | "Given a list of token separate out words and features (if any)." 64 | words = [] 65 | features = [] 66 | numFeatures = None 67 | 68 | for t in range(len(tokens)): 69 | field = tokens[t].split(u"│") 70 | word = field[0] 71 | if len(word) > 0: 72 | words.append(word) 73 | 74 | if numFeatures is None: 75 | numFeatures = len(field) - 1 76 | else: 77 | assert (len(field) - 1 == numFeatures), \ 78 | "all words must have the same number of features" 79 | 80 | if len(field) > 1: 81 | for i in range(1, len(field)): 82 | if len(features) <= i - 1: 83 | features.append([]) 84 | features[i - 1].append(field[i]) 85 | assert (len(features[i - 1]) == len(words)) 86 | return words, features, numFeatures if numFeatures else 0 87 | -------------------------------------------------------------------------------- /nlp/NMT/scripts/dataset_helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/nlp/NMT/scripts/dataset_helpers/__init__.py -------------------------------------------------------------------------------- /nlp/NMT/scripts/dataset_helpers/copus.py: -------------------------------------------------------------------------------- 1 | from keras.preprocessing.text import Tokenizer 2 | import os 3 | import operator 4 | import sys 5 | sys.path.append("..") 6 | from config import * 7 | 8 | 9 | class copus: 10 | def __init__(self, train_data_path, test_data_path, test_type="2013"): 11 | self.train_data_path = train_data_path 12 | self.test_data_path = test_data_path 13 | self.__build_dataset(test_type) 14 | 15 | def __build_dataset(self, test_type="2013"): 16 | self.train_source_file = os.path.join(self.train_data_path, TRAIN_X) 17 | self.train_target_file = os.path.join(self.train_data_path, TRAIN_Y) 18 | if test_type == "2013": 19 | self.test_source_file = os.path.join(self.test_data_path, 20 | TEST_X_2013) 21 | self.test_target_file = os.path.join(self.test_data_path, 22 | TEST_Y_2013) 23 | 
if test_type == "2014": 24 | self.test_source_file = os.path.join(self.test_data_path, 25 | TEST_X_2014) 26 | self.test_target_file = os.path.join(self.test_data_path, 27 | TEST_Y_2014) 28 | if test_type == "2015": 29 | self.test_source_file = os.path.join(self.test_data_path, 30 | TEST_X_2015) 31 | self.test_target_file = os.path.join(self.test_data_path, 32 | TEST_Y_2015) 33 | 34 | def read_copus_generator(self, file_name, batch_size=64): 35 | """ Read the corpus in `file_name` and return the vectorized matrix (batch_size is currently unused). 36 | """ 37 | logger.info("Begin reading copus {0}".format(file_name)) 38 | data = [] 39 | index = 0 40 | with open(file_name, 'r') as fread: 41 | while True: 42 | try: 43 | line = fread.readline() 44 | data.append(line) 45 | index += 1 46 | if index % 100000 == 0: 47 | logger.info("The program has processed {0} lines ". 48 | format(index)) 49 | except: 50 | logger.info("Read End") 51 | break 52 | tokenizer = Tokenizer(nb_words=30000) 53 | tokenizer.fit_on_texts(data) 54 | logger.info("word num: {0}".format(len(tokenizer.word_counts))) 55 | sorted_word_counts = sorted( 56 | tokenizer.word_counts.items(), 57 | key=operator.itemgetter(1), 58 | reverse=True) 59 | # save the word_counts to the meta 60 | with open(file_name.replace("train.", "meta."), "w") as fwrite: 61 | for word_cnt in sorted_word_counts: 62 | key = word_cnt[0] 63 | val = word_cnt[1] 64 | line = key + ":" + str(val) + "\n" 65 | fwrite.write(line) 66 | vectorize_data = tokenizer.texts_to_matrix(data) 67 | return vectorize_data 68 | 69 | 70 | if __name__ == "__main__": 71 | copus_obj = copus("../../datasets/stanford/train", 72 | "../../datasets/stanford/test") 73 | train_source_data = copus_obj.read_copus_generator(copus_obj.train_source_file) 74 | logger.info(train_source_data[0]) 75 | logger.info("train copus shape {0}".format(train_source_data.shape)) 76 | -------------------------------------------------------------------------------- /nlp/NMT/scripts/dataset_helpers/preprocess.sh: -------------------------------------------------------------------------------- 1 | python preprocess.py -train_src ../../datasets/stanford/train/small/train.en -train_tgt ../../datasets/stanford/train/small/train.vi -valid_src ../../datasets/stanford/train/small/tst2012.en -valid_tgt ../../datasets/stanford/train/small/tst2012.vi -save_data ../../datasets/stanford/demo/small 2 | -------------------------------------------------------------------------------- /nlp/NMT/scripts/models/models.py: -------------------------------------------------------------------------------- 1 | from keras.optimizers import Adam, RMSprop, Nadam, Adadelta, SGD, Adagrad, Adamax 2 | from ..config import * 3 | from keras.layers import * 4 | 5 | 6 | class TranslationModel(): 7 | def __init__(self, optimizer_type, lr, loss_type): 8 | self.optimizer_type = optimizer_type 9 | self.lr = lr 10 | self.loss_type = loss_type 11 | 12 | def setOptimizer(self): 13 | logger.info("Preparing optimizer: {0}, [LR: {1} - LOSS: {2}.]".format( 14 | self.optimizer_type, self.lr, self.loss_type)) 15 | if self.optimizer_type.lower() == "sgd": 16 | self.optimizer = SGD(lr=self.lr) 17 | elif self.optimizer_type.lower() == "rmsprop": 18 | self.optimizer = RMSprop(lr=self.lr) 19 | elif self.optimizer_type.lower() == "adagrad": 20 | self.optimizer = Adagrad(lr=self.lr) 21 | elif self.optimizer_type.lower() == "adam": 22 | self.optimizer = Adam(lr=self.lr) 23 | elif self.optimizer_type.lower() == "adamax": 24 | self.optimizer = Adamax(lr=self.lr) 25 | elif self.optimizer_type.lower() == "nadam": 26 | self.optimizer = Nadam(lr=self.lr) 27 | else: 28 | logger.info("\t
WARNING: Not supported Now") 29 | 30 | def setLoss(self): 31 | pass 32 | 33 | def buildModel(self): 34 | src_text = Input( 35 | name="NMT_input", batch_shape=tuple([None, None]), dtype="int32") 36 | -------------------------------------------------------------------------------- /nlp/Tag2Vec/README.md: -------------------------------------------------------------------------------- 1 | ## Tag2Vec 2 | 3 | Details in [Word2Vec transfer learning in practice: Tag2Vec](http://hacker.duanshishi.com/?p=1813). 4 | 5 | ![](./scripts/tsne.png) -------------------------------------------------------------------------------- /nlp/Tag2Vec/scripts/gen_w2v.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | from gensim.models import word2vec 3 | # from config import * 4 | import logging 5 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) 6 | sentence = word2vec.LineSentence( 7 | '../data/tag_day_ok.csv' 8 | ) 9 | model = word2vec.Word2Vec(sentences=sentence, size=50, workers=4, min_count=5) 10 | news_w2v = '../data/tag_word2vec.model' 11 | model.save(news_w2v) 12 | -------------------------------------------------------------------------------- /nlp/Tag2Vec/scripts/tsne.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/nlp/Tag2Vec/scripts/tsne.png -------------------------------------------------------------------------------- /nlp/Tag2Vec/scripts/visual_embeddings.py: -------------------------------------------------------------------------------- 1 | from sklearn.manifold import TSNE 2 | import matplotlib.pyplot as plt 3 | import pickle 4 | import numpy as np 5 | plt.rcParams['font.sans-serif'] = ['SimHei'] 6 | plt.rcParams['axes.unicode_minus'] = False 7 | np.random.seed(1) 8 | 9 | # load final_embeddings 10 | final_embeddings = pickle.load(open("../data/final_embeddings.model", "r")) 11 | reverse_dictionary = pickle.load(open("../data/reverse_dictionary.dict", "r")) 12 | dictionary = dict(zip(reverse_dictionary.values(), reverse_dictionary.keys())) 13 | 14 | # read from t_tag_infos.csv and save in a dict 15 | tag_id_name = {'UNK': "UNk"} 16 | with open("../data/t_tag_infos.csv", "r") as fread: 17 | for line in fread.readlines(): 18 | tag_id_name_list = line.split("\t") 19 | tag_id = tag_id_name_list[0] 20 | tag_name = tag_id_name_list[1].strip() 21 | tag_id_name[tag_id] = tag_name 22 | 23 | 24 | def plot_with_labels(low_dim_embs, labels, filename='tsne.png'): 25 | assert low_dim_embs.shape[0] >= len(labels), 'More labels than embeddings' 26 | plt.figure(figsize=(18, 18)) # in inches 27 | for i, label in enumerate(labels): 28 | x, y = low_dim_embs[i, :] 29 | plt.scatter(x, y) 30 | plt.annotate( 31 | label, 32 | xy=(x, y), 33 | xytext=(5, 2), 34 | textcoords='offset points', 35 | ha='right', 36 | va='bottom') 37 | 38 | plt.savefig(filename) 39 | 40 | 41 | tsne = TSNE( 42 | perplexity=30, 43 | n_components=2, 44 | init='pca', 45 | random_state=1, 46 | n_iter=5000, 47 | method='exact') 48 | plot_only = 500 49 | print("final_embeddings size " + str(len(final_embeddings))) 50 | valid_embeddings = final_embeddings[:plot_only, :] 51 | valid_index = [] 52 | for index in xrange(plot_only): 53 | key = reverse_dictionary[index] 54 | if tag_id_name.has_key(key): 55 | valid_index.append(index) 56 | 57 | low_dim_embs = tsne.fit_transform(valid_embeddings[valid_index]) 58 | labels = [ 59 |
tag_id_name[reverse_dictionary[i]].decode('utf-8') for i in valid_index 60 | ] 61 | plot_with_labels(low_dim_embs, labels) 62 | 63 | 64 | def get_topk(index, final_embeddings, k=10): 65 | print index 66 | presentation_labels = [] 67 | similarity = np.matmul(final_embeddings, np.transpose(final_embeddings)) 68 | nearest = (-similarity[index, :]).argsort()[1:10 + 1] 69 | print nearest 70 | for k in nearest: 71 | presentation_labels.append(tag_id_name[reverse_dictionary[k]]) 72 | 73 | print "{0} nearest labels : {1}".format( 74 | tag_id_name[reverse_dictionary[index]], ' '.join(presentation_labels)) 75 | 76 | 77 | # 1000629 78 | print dictionary['1000121'] 79 | get_topk(dictionary['1000121'], final_embeddings, k=10) -------------------------------------------------------------------------------- /nlp/Tag2Vec/scripts/visual_word2vec.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | import gensim 3 | import matplotlib.pyplot as plt 4 | from sklearn.manifold import TSNE 5 | plt.rcParams['font.sans-serif'] = ['SimHei'] 6 | plt.rcParams['axes.unicode_minus'] = False 7 | model = gensim.models.Word2Vec.load("../data/tag_word2vec.model") 8 | tag_id_name = {'UNK': "UNk"} 9 | # tag_id=>tag_name 10 | with open("../data/t_tag_infos.csv", "r") as fread: 11 | for line in fread.readlines(): 12 | tag_id_name_list = line.split("\t") 13 | tag_id = tag_id_name_list[0] 14 | tag_name = tag_id_name_list[1].strip() 15 | tag_id_name[tag_id] = tag_name 16 | 17 | tsne = TSNE( 18 | perplexity=30, 19 | n_components=2, 20 | init='pca', 21 | random_state=1, 22 | n_iter=5000, 23 | method='exact') 24 | 25 | 26 | def plot_with_labels(low_dim_embs, labels, filename='tsne.png'): 27 | assert low_dim_embs.shape[0] >= len(labels), 'More labels than embeddings' 28 | plt.figure(figsize=(18, 18)) # in inches 29 | for i, label in enumerate(labels): 30 | x, y = low_dim_embs[i, :] 31 | plt.scatter(x, y) 32 | plt.annotate( 33 | label, 34 | xy=(x, y), 35 | xytext=(5, 2), 36 | textcoords='offset points', 37 | ha='right', 38 | va='bottom') 39 | 40 | plt.savefig(filename) 41 | # plt.imshow() 42 | 43 | 44 | X = model[model.wv.vocab] 45 | X_tsne = tsne.fit_transform(X[:500]) 46 | labels = model.wv.vocab.keys()[:500] 47 | labels = [tag_id_name[i].decode('utf-8') for i in labels] 48 | plot_with_labels(X_tsne, labels) 49 | 50 | tag_name_id = dict(zip(tag_id_name.values(), tag_id_name.keys())) 51 | 52 | 53 | def get_topk(tag_word, model, topk=50): 54 | nearest_list = model.wv.similar_by_word(tag_name_id[tag_word], topn=topk) 55 | nearest_words = [tag_id_name[i[0]] for i in nearest_list] 56 | # nearest_words_score = [tag_id_name[i] for i in nearest_list] 57 | print "near the {0}, the top {1} words are {2}".format( 58 | tag_word, topk, ' '.join(nearest_words)) 59 | 60 | 61 | get_topk("知乎", model) 62 | -------------------------------------------------------------------------------- /nlp/Tag2Vec/scripts/word2vec_ops.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/nlp/Tag2Vec/scripts/word2vec_ops.so -------------------------------------------------------------------------------- /nlp/text_classifier/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/nlp/text_classifier/scripts/__init__.py 
-------------------------------------------------------------------------------- /nlp/text_classifier/scripts/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | DATA_DIR = "../../data/origin_data" 3 | DATA_DIR = os.path.abspath(DATA_DIR) 4 | print DATA_DIR 5 | filename = [ 6 | 100, 7 | 103, 8 | 104, 9 | 105, 10 | 106, 11 | 107, 12 | 108, 13 | 109, 14 | 110, 15 | 111, 16 | 112, 17 | 115, 18 | 116, 19 | 118, 20 | 119, 21 | 121, 22 | 122, 23 | 123, 24 | 124, 25 | 148, 26 | ] 27 | all_text_filename = os.path.join(DATA_DIR, 'all.csv') 28 | filename_label = dict(zip(filename, range(len(filename)))) 29 | # W2V_FILE = "../../data/pretrain_w2v/zh/test.tsv" 30 | W2V_FILE = "/Users/burness/git_repository/dl_opensource/nlp/oxford-cs-deepnlp-2017/practical-2/data/pretrain_w2v/w2vgood_20170209.model" 31 | # W2V_FILE = os.path.abspath(W2V_FILE) 32 | EMBEDDING_DIM = 200 33 | MAX_SEQUENCE_LENGTH = 1500 34 | CLASS_NUM = 2 35 | WORD_DICT = "/Users/burness/git_repository/dl_opensource/nlp/oxford-cs-deepnlp-2017/practical-2/data/origin_data/t_tag_infos.txt" 36 | all_title_filename = os.path.join(DATA_DIR, 'all_title.csv') 37 | MAX_TITLE_LENGTH = 20 38 | -------------------------------------------------------------------------------- /nlp/text_classifier/scripts/dataset_helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/nlp/text_classifier/scripts/dataset_helpers/__init__.py -------------------------------------------------------------------------------- /nlp/text_classifier/scripts/dataset_helpers/doc_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import sys 3 | sys.path.append("../") 4 | from config import * 5 | import os 6 | import logging 7 | import operator 8 | logger = logging.getLogger(__name__) 9 | handler = logging.StreamHandler() 10 | formatter = logging.Formatter( 11 | '%(asctime)s %(name)-12s %(levelname)-8s %(message)s') 12 | handler.setFormatter(formatter) 13 | logger.addHandler(handler) 14 | logger.setLevel(logging.DEBUG) 15 | 16 | 17 | def prepare_dataset(): 18 | dataset_info = {} 19 | files = os.listdir(os.path.join(DATA_DIR.replace('origin_data', 'title'))) 20 | with open(all_text_filename, 'w') as fwrite: 21 | for file in files: 22 | i = 0 23 | file_path = os.path.join( 24 | DATA_DIR.replace('origin_data', 'title'), file) 25 | logger.info("Process file {0}".format(file_path)) 26 | with open(file_path, 'r') as fread: 27 | for line in fread.readlines(): 28 | i += 1 29 | line_list = line.split("|") 30 | # print line 31 | if len(line_list) >= 3: 32 | doc_text = line.split("|")[2] 33 | w_line = str(filename_label[int( 34 | file)]) + "\t" + doc_text 35 | fwrite.write(w_line) 36 | dataset_info[file] = i 37 | print dataset_info 38 | sorted_dataset_info = sorted( 39 | dataset_info.items(), key=operator.itemgetter(1), reverse=True) 40 | print sorted_dataset_info 41 | 42 | 43 | def prepare_title_dataset(): 44 | files = os.listdir(DATA_DIR.replace('origin_data', 'title')) 45 | with open(all_title_filename, 'w') as fwrite: 46 | for file in files: 47 | i = 0 48 | file_path = os.path.join( 49 | DATA_DIR.replace('origin_data', 'title'), file) 50 | logger.info("Process file {0}".format(file_path)) 51 | with open(file_path, 'r') as fread: 52 | for line in fread.readlines(): 53 | i += 1 54 | line_list = line.split("|") 55 | if 
len(line_list) >= 3: 56 | doc_title = line.split("|")[1] 57 | w_line = str(filename_label[int( 58 | file)]) + "\t" + doc_title + '\n' 59 | fwrite.write(w_line) 60 | # data 61 | 62 | 63 | if __name__ == "__main__": 64 | prepare_dataset() 65 | # prepare_title_dataset() 66 | -------------------------------------------------------------------------------- /nlp/text_classifier/scripts/dataset_helpers/gen_w2v.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | from gensim.models import word2vec 3 | # from config import * 4 | import logging 5 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) 6 | sentence = word2vec.LineSentence( 7 | '/Users/burness/git_repository/dl_opensource/tensorflow-101/nlp/text_classifier/data/origin_data/all_token.csv' 8 | ) 9 | model = word2vec.Word2Vec(sentences=sentence, size=50, workers=4, min_count=5) 10 | # model.most_similar() 11 | news_w2v = '/Users/burness/git_repository/dl_opensource/tensorflow-101/nlp/text_classifier/data/origin_data/news_w2v.model' 12 | model.save(news_w2v) 13 | # model.save 14 | # model.wv.similar_by_word(u"习近平", topn=10) -------------------------------------------------------------------------------- /serving/READMD.md: -------------------------------------------------------------------------------- 1 | ## bazel 2 | ## gRPC 3 | 4 | pip install grpcio -------------------------------------------------------------------------------- /serving/checkpoint/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "checkpoint.ckpt" 2 | all_model_checkpoint_paths: "checkpoint.ckpt" 3 | -------------------------------------------------------------------------------- /serving/checkpoint/checkpoint.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/serving/checkpoint/checkpoint.ckpt -------------------------------------------------------------------------------- /serving/checkpoint/checkpoint.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/serving/checkpoint/checkpoint.ckpt.meta -------------------------------------------------------------------------------- /serving/generate_grpc_file.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | set -e 5 | 6 | cp predict.proto ./tensorflow/ 7 | cd ./tensorflow/ 8 | python -m grpc.tools.protoc -I./ --python_out=.. --grpc_python_out=..
./predict.proto 9 | rm ./predict.proto 10 | mv ./predict_pb2.py ../ 11 | -------------------------------------------------------------------------------- /serving/mnist_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import print_function 5 | 6 | import grpc 7 | import json 8 | import numpy as np 9 | from tensorflow.python.framework import tensor_util 10 | 11 | import predict_pb2 12 | from PIL import Image 13 | 14 | 15 | def main(): 16 | # Connect with the gRPC server 17 | server_address = "127.0.0.1:50051" 18 | request_timeout = 5.0 19 | channel = grpc.insecure_channel(server_address) 20 | stub = predict_pb2.PredictionServiceStub(channel) 21 | 22 | # Make request data 23 | request = predict_pb2.PredictRequest() 24 | image = Image.open('../mnist_jpgs/4/pic_test1010.png') 25 | array = np.array(image)/(255*1.0) 26 | samples_features = array.reshape([-1,784]) 27 | 28 | # samples_features = np.array( 29 | # [[10, 10, 10, 8, 6, 1, 8, 9, 1], [10, 10, 10, 8, 6, 1, 8, 9, 1]]) 30 | samples_keys = np.array([1]) 31 | # Convert numpy to TensorProto 32 | request.inputs["features"].CopyFrom(tensor_util.make_tensor_proto( 33 | samples_features)) 34 | request.inputs["key"].CopyFrom(tensor_util.make_tensor_proto(samples_keys)) 35 | 36 | # Invoke gRPC request 37 | response = stub.Predict(request, request_timeout) 38 | 39 | # Convert TensorProto to numpy 40 | result = {} 41 | for k, v in response.outputs.items(): 42 | result[k] = tensor_util.MakeNdarray(v) 43 | print(result) 44 | 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /serving/mnist_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from concurrent import futures 5 | import time 6 | import json 7 | import grpc 8 | import numpy as np 9 | import tensorflow as tf 10 | import logging 11 | from tensorflow.python.framework import tensor_util 12 | 13 | import predict_pb2 14 | 15 | logging.basicConfig(level=logging.DEBUG) 16 | 17 | _ONE_DAY_IN_SECONDS = 60 * 60 * 24 18 | 19 | 20 | class PredictionService(predict_pb2.PredictionServiceServicer): 21 | def __init__(self, checkpoint_file, graph_file): 22 | self.checkpoint_file = checkpoint_file 23 | self.graph_file = graph_file 24 | self.sess = None 25 | self.inputs = None 26 | self.outputs = None 27 | 28 | self.init_session_handler() 29 | 30 | def init_session_handler(self): 31 | self.sess = tf.Session() 32 | 33 | # Restore graph and weights from the model file 34 | ckpt = tf.train.get_checkpoint_state(self.checkpoint_file) 35 | if ckpt and ckpt.model_checkpoint_path: 36 | logging.info("Use the model: {}".format( 37 | ckpt.model_checkpoint_path)) 38 | saver = tf.train.import_meta_graph(self.graph_file) 39 | saver.restore(self.sess, ckpt.model_checkpoint_path) 40 | 41 | self.inputs = json.loads(tf.get_collection('input')[0]) 42 | self.outputs = json.loads(tf.get_collection('output')[0]) 43 | else: 44 | logging.error("No model found, exit") 45 | exit() 46 | 47 | def Predict(self, request, context): 48 | """Run predict op for each request. 49 | 50 | Args: 51 | request: The TensorProto which contains the map of "inputs". The request.inputs looks like {'features': dtype: DT_FLOAT tensor_shape { dim { size: 2 } } tensor_content: "\000\000 A\000\000?" }. 
52 | context: The grpc.beta._server_adaptations._FaceServicerContext object. 53 | 54 | Returns: 55 | The TensorProto which contains the map of "outputs". The response.outputs looks like {'softmax': dtype: DT_FLOAT tensor_shape { dim { size: 2 } } tensor_content: "\\\326\242=4\245k?\\\326\242=4\245k?" } 56 | """ 57 | request_map = request.inputs 58 | feed_dict = {} 59 | for k, v in self.inputs.items(): 60 | # Convert TensorProto objects to numpy 61 | feed_dict[v] = tensor_util.MakeNdarray(request_map[k]) 62 | 63 | # Example result: {'key': array([ 2., 2.], dtype=float32), 'prediction': array([1, 1]), 'softmax': array([[ 0.07951042, 0.92048955], [ 0.07951042, 0.92048955]], dtype=float32)} 64 | predict_result = self.sess.run(self.outputs, feed_dict=feed_dict) 65 | 66 | response = predict_pb2.PredictResponse() 67 | for k, v in predict_result.items(): 68 | # Convert numpy objects to TensorProto 69 | response.outputs[k].CopyFrom(tensor_util.make_tensor_proto(v)) 70 | return response 71 | 72 | 73 | def serve(prediction_service): 74 | """Start the gRPC service.""" 75 | logging.info("Start gRPC server with PredictionService: {}".format(vars( 76 | prediction_service))) 77 | 78 | # TODO: not able to use ThreadPoolExecutor 79 | #server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) 80 | #inference_pb2.add_InferenceServiceService_to_server(InferenceService(), server) 81 | server = predict_pb2.beta_create_PredictionService_server( 82 | prediction_service) 83 | server.add_insecure_port('[::]:50051') 84 | server.start() 85 | try: 86 | while True: 87 | time.sleep(_ONE_DAY_IN_SECONDS) 88 | except KeyboardInterrupt: 89 | server.stop(0) 90 | 91 | 92 | if __name__ == '__main__': 93 | # Specify the model files 94 | checkpoint_file = "./checkpoint/" 95 | graph_file = "./checkpoint/checkpoint.ckpt.meta" 96 | prediction_service = PredictionService(checkpoint_file, graph_file) 97 | 98 | serve(prediction_service) 99 | -------------------------------------------------------------------------------- /serving/predict.proto: -------------------------------------------------------------------------------- 1 | // Refer to https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/predict.proto 2 | syntax = "proto3"; 3 | 4 | package tensorflow.serving; 5 | option cc_enable_arenas = true; 6 | 7 | import "tensorflow/core/framework/tensor.proto"; 8 | 9 | message PredictRequest { 10 | map<string, TensorProto> inputs = 1; 11 | } 12 | 13 | message PredictResponse { 14 | map<string, TensorProto> outputs = 1; 15 | } 16 | 17 | service PredictionService { 18 | rpc Predict(PredictRequest) returns (PredictResponse); 19 | } 20 | -------------------------------------------------------------------------------- /serving/predict_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/serving/predict_pb2.pyc -------------------------------------------------------------------------------- /serving/train_mnist_softmax4serving.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | 7 | 8 | import tensorflow as tf 9 | import time 10 | import json 11 | import sys 12 | sys.path.append('../') 13 | from algorithm import input_data 14 | FLAGS = tf.app.flags.FLAGS 15 | 16 | tf.flags.DEFINE_string("data_dir", "../algorithm/mnist", 17 | "mnist data_dir") 18 |
tf.flags.DEFINE_string('tensorboard_dir',"./tensorboard", "log dir") 19 | tf.flags.DEFINE_string('checkpoint_dir','./checkpoint', 'checkpoint dir') 20 | 21 | 22 | 23 | 24 | def inference(inputs, reuse=False): 25 | # Share one set of weights between the training and the serving graph; 26 | # the original created fresh variables on the second call, so the 27 | # exported inference path never saw the trained weights. 28 | with tf.variable_scope("mlp", reuse=reuse): 29 | W1 = tf.get_variable("W1", [784, 256], initializer=tf.random_normal_initializer()) 30 | b1 = tf.get_variable("b1", [256], initializer=tf.random_normal_initializer()) 31 | W2 = tf.get_variable("W2", [256, 10], initializer=tf.random_normal_initializer()) 32 | b2 = tf.get_variable("b2", [10], initializer=tf.random_normal_initializer()) 33 | lay1 = tf.nn.relu(tf.matmul(inputs, W1) + b1) 34 | y = tf.add(tf.matmul(lay1, W2), b2) 35 | return y 36 | 37 | 38 | def main(_): 39 | mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) 40 | input = tf.placeholder(tf.float32, [None, 784]) 41 | y = inference(input) 42 | y_ = tf.placeholder(tf.float32, [None, 10]) 43 | global_step = tf.Variable(0, name='global_step', trainable=False) 44 | cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)) 45 | train_step = tf.train.GradientDescentOptimizer(0.2).minimize(cross_entropy, global_step=global_step) 46 | correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 47 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 48 | 49 | inference_features = tf.placeholder(tf.float32, [None, 784]) 50 | inference_logits = inference(inference_features, reuse=True) 51 | inference_softmax = tf.nn.softmax(inference_logits) 52 | inference_result = tf.argmax(inference_softmax, 1) 53 | checkpoint_dir = FLAGS.checkpoint_dir 54 | checkpoint_file = checkpoint_dir + "/checkpoint.ckpt" 55 | try: 56 | init_op = tf.global_variables_initializer() 57 | except AttributeError: 58 | init_op = tf.initialize_all_variables() 59 | # add the tensors to the tensorboard logs 60 | try: 61 | tf.summary.scalar('loss', cross_entropy) 62 | tf.summary.scalar('accuracy', accuracy) 63 | except AttributeError: 64 | tf.scalar_summary('loss', cross_entropy) 65 | tf.scalar_summary('accuracy', accuracy) 66 | 67 | saver = tf.train.Saver() 68 | keys_placeholder = tf.placeholder("float") 69 | keys = tf.identity(keys_placeholder) 70 | tf.add_to_collection("input", json.dumps({'key': keys_placeholder.name, 'features': inference_features.name})) 71 | tf.add_to_collection('output', json.dumps({'key': keys.name, 'softmax': inference_softmax.name, 'prediction': inference_result.name})) 72 | with tf.Session() as sess: 73 | try: 74 | summary_op = tf.summary.merge_all() 75 | except AttributeError: 76 | summary_op = tf.merge_all_summaries() 77 | tensorboard_dir = FLAGS.tensorboard_dir 78 | writer = tf.summary.FileWriter(tensorboard_dir, sess.graph) 79 | try: 80 | tf.global_variables_initializer().run() 81 | except AttributeError: 82 | tf.initialize_all_variables().run() 83 | for index in range(10000): 84 | print('process the {}th batch'.format(index)) 85 | start_train = time.time() 86 | batch_xs, batch_ys = mnist.train.next_batch(100) 87 | _,summary_val, step = sess.run([train_step, summary_op, global_step], feed_dict={input: batch_xs, y_: batch_ys}) 88 | writer.add_summary(summary_val, step) 89 | print('the {0} batch takes time: {1}'.format(index, time.time()-start_train)) 90 | print('the test dataset acc: ', sess.run(accuracy, feed_dict={input: mnist.test.images,y_: mnist.test.labels})) 91 | saver.save(sess, checkpoint_file) 92 | 93 | 94 | if __name__ == '__main__': 95 | tf.app.run() -------------------------------------------------------------------------------- /zhihu_code/README.md:
-------------------------------------------------------------------------------- 1 | ## How to build a DL model for Zhihu verification code recognition 2 | 3 | **Issue 1** 4 | 5 | How to generate images with randomly flipped text: `draw.text(pos, txt, font=self.font, fill=fill)` combined with `rotate` does not seem to work directly. 6 | Search for a way to render randomly flipped text inside an image, 7 | or concatenate per-character images (some of them flipped) into a new image. -------------------------------------------------------------------------------- /zhihu_code/src/gen_verification.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from PIL import Image,ImageDraw,ImageFont 3 | import random 4 | import math, string 5 | import logging 6 | # logger = logging.Logger(name='gen verification') 7 | 8 | class RandomChar(): 9 | @staticmethod 10 | def Unicode(): 11 | val = random.randint(0x4E00, 0x9FBF) 12 | return unichr(val) 13 | 14 | @staticmethod 15 | def GB2312(): 16 | head = random.randint(0xB0, 0xCF) 17 | body = random.randint(0xA, 0xF) 18 | tail = random.randint(0, 0xF) 19 | val = ( head << 8 ) | (body << 4) | tail 20 | str = "%x" % val 21 | return str.decode('hex').decode('gb2312') 22 | 23 | class ImageChar(): 24 | def __init__(self, fontColor = (0, 0, 0), 25 | size = (100, 40), 26 | fontPath = '/Library/Fonts/Arial Unicode.ttf', 27 | bgColor = (255, 255, 255), 28 | fontSize = 20): 29 | self.size = size 30 | self.fontPath = fontPath 31 | self.bgColor = bgColor 32 | self.fontSize = fontSize 33 | self.fontColor = fontColor 34 | self.font = ImageFont.truetype(self.fontPath, self.fontSize) 35 | self.image = Image.new('RGB', size, bgColor) 36 | 37 | def drawText(self, pos, txt, fill): 38 | draw = ImageDraw.Draw(self.image) 39 | draw.text(pos, txt, font=self.font, fill=fill) 40 | del draw 41 | 42 | def drawTextV2(self, pos, txt, fill, angle=180): 43 | image=Image.new('RGB', (25,25), (255,255,255)) 44 | draw = ImageDraw.Draw(image) 45 | draw.text( (0, -3), txt, font=self.font, fill=fill) 46 | w=image.rotate(angle, expand=1) 47 | self.image.paste(w, box=pos) 48 | del draw 49 | 50 | def randRGB(self): 51 | return (0,0,0) 52 | 53 | def randChinese(self, num, num_flip): 54 | gap = 1 55 | start = 0 56 | num_flip_list = random.sample(range(num), num_flip) 57 | # logger.info('num flip list:{0}'.format(num_flip_list)) 58 | print 'num flip list:{0}'.format(num_flip_list) 59 | char_list = [] 60 | for i in range(0, num): 61 | char = RandomChar().GB2312() 62 | char_list.append(char) 63 | x = start + self.fontSize * i + gap + gap * i 64 | if i in num_flip_list: 65 | self.drawTextV2((x, 6), char, self.randRGB()) 66 | else: 67 | self.drawText((x, 0), char, self.randRGB()) 68 | return char_list, num_flip_list 69 | def save(self, path): 70 | self.image.save(path) 71 | 72 | 73 | 74 | err_num = 0 75 | for i in range(10): 76 | try: 77 | ic = ImageChar(fontColor=(100,211, 90), size=(280,28), fontSize = 25) 78 | num_flip = random.randint(3,6) 79 | char_list, num_flip_list = ic.randChinese(10, num_flip) 80 | ic.save(''.join(char_list)+'_'+''.join(str(i) for i in num_flip_list)+".jpeg") 81 | except: 82 | err_num += 1 83 | continue 84 | # print ''.join(char_list) 85 | # print ''.join(str(i) for i in num_flip_list) -------------------------------------------------------------------------------- /zhihu_code/src/samples/亭猫那吉腊咖材年辜惯_3921.jpeg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/zhihu_code/src/samples/亭猫那吉腊咖材年辜惯_3921.jpeg -------------------------------------------------------------------------------- /zhihu_code/src/samples/希巫瀑汝蓝渡靡旱淋迈_052931.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/zhihu_code/src/samples/希巫瀑汝蓝渡靡旱淋迈_052931.jpeg -------------------------------------------------------------------------------- /zhihu_code/src/samples/慑嘻佬晚决幢聚色签品_90617.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/zhihu_code/src/samples/慑嘻佬晚决幢聚色签品_90617.jpeg -------------------------------------------------------------------------------- /zhihu_code/src/samples/抨玩厘撕唇惧茸死氦徒_36410.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/zhihu_code/src/samples/抨玩厘撕唇惧茸死氦徒_36410.jpeg -------------------------------------------------------------------------------- /zhihu_code/src/samples/毒气饺春类厂全编揉辅_39841.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/zhihu_code/src/samples/毒气饺春类厂全编揉辅_39841.jpeg -------------------------------------------------------------------------------- /zhihu_code/src/samples/片蹬片藤袒俏酶讫垢管_803547.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/zhihu_code/src/samples/片蹬片藤袒俏酶讫垢管_803547.jpeg -------------------------------------------------------------------------------- /zhihu_code/src/samples/畸具岔狂硼瘁务廷撕刮_92634.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/zhihu_code/src/samples/畸具岔狂硼瘁务廷撕刮_92634.jpeg -------------------------------------------------------------------------------- /zhihu_code/src/samples/算嫩曼掂羌屁解嫩历担_429637.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/zhihu_code/src/samples/算嫩曼掂羌屁解嫩历担_429637.jpeg -------------------------------------------------------------------------------- /zhihu_code/src/samples/茎搔逞沫彬培韶眠僧痛_675.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/zhihu_code/src/samples/茎搔逞沫彬培韶眠僧痛_675.jpeg -------------------------------------------------------------------------------- /zhihu_code/src/samples/陇琴卉底仅铰嗽寒售澜_021346.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/burness/tensorflow-101/c775a54af86542940e6e69b7d90d8d7e8aa9aeb9/zhihu_code/src/samples/陇琴卉底仅铰嗽寒售澜_021346.jpeg --------------------------------------------------------------------------------
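The sample filenames above encode their labels: gen_verification.py saves each image as the ten generated characters, an underscore, and the digit indices of the flipped characters (so 亭猫那吉腊咖材年辜惯_3921.jpeg has the characters at positions 3, 9, 2, and 1 flipped). A minimal Python 2 sketch for recovering training labels from such a filename; the helper name is illustrative and not part of this repo:

```python
# -*- coding: utf-8 -*-
import os


def parse_sample_name(path):
    """Recover the characters and per-position flip labels from a
    sample filename such as '亭猫那吉腊咖材年辜惯_3921.jpeg'
    (hypothetical helper, based on how gen_verification.py saves files)."""
    stem = os.path.basename(path).rsplit('.', 1)[0]  # drop the ".jpeg" extension
    chars, flip_digits = stem.rsplit('_', 1)         # characters / flipped indices
    chars = chars.decode('utf-8')                    # Python 2, as in gen_verification.py
    flipped = set(int(d) for d in flip_digits)       # each digit is one flipped position
    labels = [1 if i in flipped else 0 for i in range(len(chars))]
    return chars, labels


chars, labels = parse_sample_name('亭猫那吉腊咖材年辜惯_3921.jpeg')
print labels  # one 0/1 flip indicator per character position
```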