├── .gitignore
├── LICENSE
├── README.md
├── classifier.py
├── inputs.py
├── inputs_test.py
├── lang_dataset.sh
├── predictor.py
├── predictor_client.py
├── process_input.py
├── requirements.txt
├── text_utils.py
├── train_classifier.sh
└── train_langdetect.sh

/.gitignore:
--------------------------------------------------------------------------------
*~

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Alan Patterson

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FastText in Tensorflow

This project is based on the ideas in Facebook's
[FastText](https://github.com/facebookresearch/fastText) but implemented in
Tensorflow. However, it is not an exact replica of fastText.

Classification is done by embedding each word, taking the mean
embedding over the full text, and classifying that mean with a linear
classifier. The embedding is trained jointly with the classifier. You
can also use character ngrams of length 2 and up; these ngrams are
hashed into a fixed number of buckets and embedded in the same way as
the original words. Note that ngrams make training much slower but
give only marginal improvements in performance, at least in English.
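To make the mechanics concrete, here is an illustrative numpy sketch of that
forward pass (toy sizes and made-up ids, not the code in this repository):

    import numpy as np

    vocab_size, embed_dim, num_classes = 1000, 10, 4  # toy sizes

    rng = np.random.RandomState(0)
    embedding = rng.uniform(-0.1, 0.1, (vocab_size, embed_dim))  # word embedding table
    W = rng.uniform(-0.1, 0.1, (embed_dim, num_classes))         # linear classifier
    b = np.zeros(num_classes)

    word_ids = np.array([3, 17, 256])              # ids of the words in one text
    text_vec = embedding[word_ids].mean(axis=0)    # mean embedding over the text
    logits = text_vec @ W + b
    probs = np.exp(logits) / np.exp(logits).sum()  # softmax over the labels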
I may implement skipgram and cbow training later, or preloading of
embedding tables.

<< Still WIP >>

You can use [Horovod](https://github.com/uber/horovod) to distribute
training across multiple GPUs, on one or multiple servers. See the usage
section below.

## FastText Language Identification

I have added utilities to train a classifier to detect languages, as
described in [Fast and Accurate Language Identification using
FastText](https://fasttext.cc/blog/2017/10/02/blog-post.html)

See usage below. It works in essentially the same way as the default usage.

## Implemented:
- classification of text using word embeddings
- char ngrams, hashed to n bins
- training and prediction programs
- serving models on tensorflow serving
- preprocessing of facebook-format or plain-text input into tensorflow records

## Not Implemented:
- separate word vector training (though you can export the trained embeddings)
- hierarchical softmax
- quantized models (supported by tensorflow, but I haven't tried it yet)

# Usage

The following are examples of how to use the applications. Get full help with
the `--help` option on any of the programs.

To transform input data into tensorflow Example format:

    process_input.py --facebook_input=queries.txt --output_dir=. --ngrams=2,3,4

Or, using a text file with one example per line, plus a separate file with one
label per line:

    process_input.py --text_input=queries.txt --labels=labels.txt --output_dir=.
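For reference, facebook-format input puts the label first on each line,
optionally followed by a comma; made-up sample lines:

    __label__3 , wall st. bears claw back into the black
    __label__4 , new chip promises longer battery life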
To train a text classifier:

    classifier.py \
      --train_records=queries.tfrecords \
      --eval_records=queries.tfrecords \
      --label_file=labels.txt \
      --vocab_file=vocab.txt \
      --model_dir=model \
      --export_dir=model

To predict classifications for text, use a saved_model produced by
`classifier.py`. `classifier.py --export_dir` stores a saved model in a
numbered directory below `export_dir`. Pass this directory to the
following to use that model for predictions:

    predictor.py \
      --saved_model=model/12345678 \
      --text="some text to classify" \
      --signature_def=proba

To export the embedding layer, use the embedding signature of predictor.
Note that this is only the word embedding, not the ngram embeddings:

    predictor.py \
      --saved_model=model/12345678 \
      --text="some text to classify" \
      --signature_def=embedding

Use the provided script to train easily:

    train_classifier.sh path-to-data-directory

# Language Identification

To implement something similar to the method described in [Fast and
Accurate Language Identification using
FastText](https://fasttext.cc/blog/2017/10/02/blog-post.html) you need to
download the data:

    lang_dataset.sh [datadir]

You can then process the training and validation data using
`process_input.py` and `classifier.py` as described above.

There is a utility script to do this for you:

    train_langdetect.sh datadir

It reaches about 96% accuracy using word embeddings, rising to nearly 99%
when adding `--ngrams=2,3,4`.

# Distributed Training

You can run training across multiple GPUs, either on one or on multiple
servers. To do so you need to install MPI and
[Horovod](https://github.com/uber/horovod), then add the `--horovod`
option. Performance scales close to linearly with the number of GPUs:
with 2 GPUs on one server, training should run close to 2x the speed.

    NUM_GPUS=2
    mpirun -np $NUM_GPUS python classifier.py \
      --horovod \
      --train_records=queries.tfrecords \
      --eval_records=queries.tfrecords \
      --label_file=labels.txt \
      --vocab_file=vocab.txt \
      --model_dir=model \
      --export_dir=model

The training script `train_classifier.sh` already has this option added.

# Tensorflow Serving

As well as using `predictor.py` to run a saved model locally, it is easy
to serve a saved model using Tensorflow Serving in a client-server setup.
A simple rpc client (`predictor_client.py`) is supplied that gets
predictions from a tensorflow_model_server.

First make sure you install the tensorflow serving binaries. Instructions are
[here](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/g3doc/setup.md#installing-the-modelserver).

You then serve the latest saved model by supplying the base export
directory where you exported saved models to, i.e. the directory that
contains the numbered model directories:

    tensorflow_model_server --port=9000 --model_base_path=model

Now you can make requests to the server using gRPC calls. An example
simple client is provided in `predictor_client.py`:

    predictor_client.py --text="Some text to classify"

# Facebook Examples

<< NOT IMPLEMENTED YET >>

You can compare with Facebook's fastText by running similar examples
to what's provided in their repository.

    ./classification_example.sh
    ./classification_results.sh
--------------------------------------------------------------------------------
/classifier.py:
--------------------------------------------------------------------------------
"""Train a simple fastText-style classifier.

Inputs:
  words - text to classify
  ngrams - n char ngrams for each word in words
  labels - output classes to classify

Model:
  word embedding
  ngram embedding
  linear (softmax) classifier from the embeddings to the labels
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import inputs
import sys
import tensorflow as tf


tf.flags.DEFINE_string("train_records", None,
                       "Training file pattern for TFRecords, can use wildcards")
tf.flags.DEFINE_string("eval_records", None,
                       "Evaluation file pattern for TFRecords, can use wildcards")
tf.flags.DEFINE_string("predict_records", None,
                       "File pattern for TFRecords to predict, can use wildcards")
tf.flags.DEFINE_string("label_file", None, "File containing output labels")
tf.flags.DEFINE_integer("num_labels", None, "Number of output labels")
tf.flags.DEFINE_string("vocab_file", None, "Vocabulary file, one word per line")
tf.flags.DEFINE_integer("vocab_size", None, "Number of words in vocabulary")
tf.flags.DEFINE_integer("num_oov_vocab_buckets", 20,
                        "Number of hash buckets to use for OOV words")
tf.flags.DEFINE_string("model_dir", ".",
                       "Output directory for checkpoints and summaries")
tf.flags.DEFINE_string("export_dir", None, "Directory to store savedmodel")

tf.flags.DEFINE_integer("embedding_dimension", 10, "Dimension of word embedding")
tf.flags.DEFINE_boolean("use_ngrams", False, "Use character ngrams in embedding")
tf.flags.DEFINE_integer("num_ngram_buckets", 1000000,
                        "Number of hash buckets for ngrams")
tf.flags.DEFINE_integer("ngram_embedding_dimension", 10,
                        "Dimension of ngram embedding")

tf.flags.DEFINE_float("learning_rate", 0.001, "Learning rate for training")
tf.flags.DEFINE_float("clip_gradient", 5.0, "Clip gradient norm to this ratio")
tf.flags.DEFINE_integer("batch_size", 128, "Training minibatch size")
tf.flags.DEFINE_integer("train_steps", 1000,
                        "Number of train steps, None for continuous")
tf.flags.DEFINE_integer("eval_steps", 100, "Number of eval steps")
tf.flags.DEFINE_integer("num_epochs", None, "Number of training data epochs")
tf.flags.DEFINE_integer("checkpoint_steps", 1000,
                        "Steps between saving checkpoints")
tf.flags.DEFINE_integer("num_threads", 1, "Number of reader threads")
tf.flags.DEFINE_boolean("log_device_placement", False, "log where ops are located")
tf.flags.DEFINE_boolean("horovod", False,
                        "Run across multiple GPUs using Horovod MPI."
                        " https://github.com/uber/horovod")
https://github.com/uber/horovod") 59 | tf.flags.DEFINE_boolean("debug", False, "Debug") 60 | FLAGS = tf.flags.FLAGS 61 | 62 | if FLAGS.horovod: 63 | try: 64 | import horovod.tensorflow as hvd 65 | except ImportError, e: 66 | print(e) 67 | print("Make sure Horovod is installed: https://github.com/uber/horovod") 68 | sys.exit(1) 69 | hvd.init() 70 | 71 | 72 | def InputFn(mode, input_file): 73 | return inputs.InputFn( 74 | mode, FLAGS.use_ngrams, input_file, FLAGS.vocab_file, FLAGS.vocab_size, 75 | FLAGS.embedding_dimension, FLAGS.num_oov_vocab_buckets, 76 | FLAGS.label_file, FLAGS.num_labels, 77 | FLAGS.ngram_embedding_dimension, FLAGS.num_ngram_buckets, 78 | FLAGS.batch_size, FLAGS.num_epochs, FLAGS.num_threads) 79 | 80 | 81 | def Exports(probs, embedding): 82 | exports = { 83 | "proba": tf.estimator.export.ClassificationOutput(scores=probs), 84 | "embedding": tf.estimator.export.RegressionOutput(value=embedding), 85 | tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: \ 86 | tf.estimator.export.ClassificationOutput(scores=probs), 87 | } 88 | return exports 89 | 90 | 91 | def FastTextEstimator(model_dir, config=None): 92 | params = { 93 | "learning_rate": FLAGS.learning_rate, 94 | } 95 | def model_fn(features, labels, mode, params): 96 | features["text"] = tf.sparse_tensor_to_dense(features["text"], 97 | default_value=" ") 98 | if FLAGS.use_ngrams: 99 | features["ngrams"] = tf.sparse_tensor_to_dense(features["ngrams"], 100 | default_value=" ") 101 | text_lookup_table = tf.contrib.lookup.index_table_from_file( 102 | FLAGS.vocab_file, FLAGS.num_oov_vocab_buckets, FLAGS.vocab_size) 103 | text_ids = text_lookup_table.lookup(features["text"]) 104 | text_embedding_w = tf.Variable(tf.random_uniform( 105 | [FLAGS.vocab_size + FLAGS.num_oov_vocab_buckets, FLAGS.embedding_dimension], 106 | -0.1, 0.1)) 107 | text_embedding = tf.reduce_mean(tf.nn.embedding_lookup( 108 | text_embedding_w, text_ids), axis=-2) 109 | input_layer = text_embedding 110 | if FLAGS.use_ngrams: 111 | ngram_hash = tf.string_to_hash_bucket(features["ngrams"], 112 | FLAGS.num_ngram_buckets) 113 | ngram_embedding_w = tf.Variable(tf.random_uniform( 114 | [FLAGS.num_ngram_buckets, FLAGS.ngram_embedding_dimension], -0.1, 0.1)) 115 | ngram_embedding = tf.reduce_mean(tf.nn.embedding_lookup( 116 | ngram_embedding_w, ngram_hash), axis=-2) 117 | ngram_embedding = tf.expand_dims(ngram_embedding, -2) 118 | input_layer = tf.concat([text_embedding, ngram_embedding], -1) 119 | num_classes = FLAGS.num_labels 120 | logits = tf.contrib.layers.fully_connected( 121 | inputs=input_layer, num_outputs=num_classes, 122 | activation_fn=None) 123 | predictions = tf.argmax(logits, axis=-1) 124 | probs = tf.nn.softmax(logits) 125 | loss, train_op = None, None 126 | metrics = {} 127 | if mode != tf.estimator.ModeKeys.PREDICT: 128 | label_lookup_table = tf.contrib.lookup.index_table_from_file( 129 | FLAGS.label_file, vocab_size=FLAGS.num_labels) 130 | labels = label_lookup_table.lookup(labels) 131 | loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( 132 | labels=labels, logits=logits)) 133 | opt = tf.train.AdamOptimizer(params["learning_rate"]) 134 | if FLAGS.horovod: 135 | opt = hvd.DistributedOptimizer(opt) 136 | train_op = opt.minimize(loss, global_step=tf.train.get_global_step()) 137 | metrics = { 138 | "accuracy": tf.metrics.accuracy(labels, predictions) 139 | } 140 | exports = {} 141 | if FLAGS.export_dir: 142 | exports = Exports(probs, text_embedding) 143 | return tf.estimator.EstimatorSpec( 144 | mode, 
  session_config = tf.ConfigProto(
      log_device_placement=FLAGS.log_device_placement)
  if FLAGS.horovod:
    # Pin each Horovod process to a single, distinct local GPU.
    session_config.gpu_options.visible_device_list = str(hvd.local_rank())
  config = tf.contrib.learn.RunConfig(
      save_checkpoints_secs=None,
      save_checkpoints_steps=FLAGS.checkpoint_steps,
      session_config=session_config)
  return tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
                                params=params, config=config)


def FastTrain():
  print("FastTrain", FLAGS.train_steps)
  estimator = FastTextEstimator(FLAGS.model_dir)
  print("TEST " + FLAGS.train_records)
  train_input = InputFn(tf.estimator.ModeKeys.TRAIN, FLAGS.train_records)
  print("STARTING TRAIN")
  hooks = None
  if FLAGS.horovod:
    # Broadcast initial variables from rank 0 so all workers start in sync.
    hooks = [hvd.BroadcastGlobalVariablesHook(0)]
  estimator.train(input_fn=train_input, steps=FLAGS.train_steps, hooks=hooks)
  print("TRAIN COMPLETE")
  if not FLAGS.horovod or hvd.rank() == 0:
    print("EVALUATE")
    eval_input = InputFn(tf.estimator.ModeKeys.EVAL, FLAGS.eval_records)
    result = estimator.evaluate(input_fn=eval_input, steps=FLAGS.eval_steps,
                                hooks=None)
    print(result)
    print("DONE")
    if FLAGS.export_dir:
      print("EXPORTING")
      estimator.export_savedmodel(FLAGS.export_dir,
                                  inputs.ServingInputFn(FLAGS.use_ngrams))


def main(_):
  if not FLAGS.vocab_size:
    FLAGS.vocab_size = len(open(FLAGS.vocab_file).readlines())
  if not FLAGS.num_labels:
    FLAGS.num_labels = len(open(FLAGS.label_file).readlines())
  if FLAGS.horovod:
    # Split the step budget across workers; each runs train_steps // nproc.
    nproc = hvd.size()
    total = FLAGS.train_steps
    FLAGS.train_steps = total // nproc
    print("Running %d steps on each of %d processes for %d total" % (
        FLAGS.train_steps, nproc, total))
  FastTrain()


if __name__ == '__main__':
  if FLAGS.debug:
    tf.logging.set_verbosity(tf.logging.DEBUG)
  tf.app.run()
--------------------------------------------------------------------------------
/inputs.py:
--------------------------------------------------------------------------------
"""Input feature columns and input_fn for models.

Handles training, evaluation and inference.
4 | """ 5 | import tensorflow as tf 6 | 7 | 8 | def BuildTextExample(text, ngrams=None, label=None): 9 | record = tf.train.Example() 10 | text = [tf.compat.as_bytes(x) for x in text] 11 | record.features.feature["text"].bytes_list.value.extend(text) 12 | if label is not None: 13 | label = tf.compat.as_bytes(label) 14 | record.features.feature["label"].bytes_list.value.append(label) 15 | if ngrams is not None: 16 | ngrams = [tf.compat.as_bytes(x) for x in ngrams] 17 | record.features.feature["ngrams"].bytes_list.value.extend(ngrams) 18 | return record 19 | 20 | 21 | def ParseSpec(use_ngrams, include_target): 22 | parse_spec = {"text": tf.VarLenFeature(dtype=tf.string)} 23 | if use_ngrams: 24 | parse_spec["ngrams"] = tf.VarLenFeature(dtype=tf.string) 25 | if include_target: 26 | parse_spec["label"] = tf.FixedLenFeature(shape=(), dtype=tf.string, 27 | default_value=None) 28 | return parse_spec 29 | 30 | 31 | def InputFn(mode, 32 | use_ngrams, 33 | input_file, 34 | vocab_file, 35 | vocab_size, 36 | embedding_dimension, 37 | num_oov_vocab_buckets, 38 | label_file, 39 | label_size, 40 | ngram_embedding_dimension, 41 | num_ngram_hash_buckets, 42 | batch_size, 43 | num_epochs=None, 44 | num_threads=1): 45 | if num_epochs <= 0: 46 | num_epochs=None 47 | def input_fn(): 48 | include_target = mode != tf.estimator.ModeKeys.PREDICT 49 | parse_spec = ParseSpec(use_ngrams, include_target) 50 | print("ParseSpec", parse_spec) 51 | print("Input file:", input_file) 52 | features = tf.contrib.learn.read_batch_features( 53 | input_file, batch_size, parse_spec, tf.TFRecordReader, 54 | num_epochs=num_epochs, reader_num_threads=num_threads) 55 | label = None 56 | if include_target: 57 | label = features.pop("label") 58 | return features, label 59 | return input_fn 60 | 61 | 62 | def ServingInputFn(use_ngrams): 63 | parse_spec = ParseSpec(use_ngrams, include_target=False) 64 | return tf.estimator.export.build_parsing_serving_input_receiver_fn( 65 | parse_spec) 66 | -------------------------------------------------------------------------------- /inputs_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import tensorflow as tf 3 | from google.protobuf import text_format 4 | from inputs import FeatureColumns, InputFn 5 | 6 | VOCAB_FILE='/home/alan/Workspace/other/fastText/data/ag_news.train.vocab' 7 | VOCAB_SIZE=95810 8 | INPUT_FILE='/home/alan/Workspace/other/fastText/data/ag_news.train.tfrecords-1-of-1' 9 | 10 | def test_parse_spec(): 11 | fc = FeatureColumns( 12 | True, 13 | False, 14 | VOCAB_FILE, 15 | VOCAB_SIZE, 16 | 10, 17 | 10, 18 | 1000, 19 | 10) 20 | parse_spec = tf.feature_column.make_parse_example_spec(fc) 21 | print parse_spec 22 | reader = tf.python_io.tf_record_iterator(INPUT_FILE) 23 | sess = tf.Session() 24 | for record in reader: 25 | example = tf.parse_single_example( 26 | record, 27 | parse_spec) 28 | print sess.run(example) 29 | break 30 | 31 | 32 | def test_reading_inputs(): 33 | parse_spec = { 34 | "text": tf.VarLenFeature(tf.string), 35 | "label": tf.FixedLenFeature(shape=(1,), dtype=tf.int64, 36 | default_value=None) 37 | } 38 | sess = tf.Session() 39 | reader = tf.python_io.tf_record_iterator(INPUT_FILE) 40 | ESZ = 4 41 | HSZ = 100 42 | NC = 4 43 | n = 0 44 | text_lookup_table = tf.contrib.lookup.index_table_from_file( 45 | VOCAB_FILE, 10, VOCAB_SIZE) 46 | text_embedding_w = tf.Variable(tf.random_uniform( 47 | [VOCAB_SIZE, ESZ], -1.0, 1.0)) 48 | sess.run([tf.tables_initializer()]) 49 | for record in reader: 50 | example = 
--------------------------------------------------------------------------------
/inputs_test.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import inputs

VOCAB_FILE = '/home/alan/Workspace/other/fastText/data/ag_news.train.vocab'
VOCAB_SIZE = 95810
INPUT_FILE = '/home/alan/Workspace/other/fastText/data/ag_news.train.tfrecords-1-of-1'


def test_parse_spec():
  parse_spec = inputs.ParseSpec(use_ngrams=True, include_target=True)
  print(parse_spec)
  reader = tf.python_io.tf_record_iterator(INPUT_FILE)
  sess = tf.Session()
  for record in reader:
    example = tf.parse_single_example(record, parse_spec)
    print(sess.run(example))
    break


def test_reading_inputs():
  parse_spec = {
      "text": tf.VarLenFeature(tf.string),
      "label": tf.FixedLenFeature(shape=(1,), dtype=tf.int64,
                                  default_value=None)
  }
  sess = tf.Session()
  reader = tf.python_io.tf_record_iterator(INPUT_FILE)
  ESZ = 4    # embedding size
  HSZ = 100  # hidden size (unused below)
  NC = 4     # number of classes
  n = 0
  text_lookup_table = tf.contrib.lookup.index_table_from_file(
      VOCAB_FILE, 10, VOCAB_SIZE)
  text_embedding_w = tf.Variable(tf.random_uniform(
      [VOCAB_SIZE, ESZ], -1.0, 1.0))
  sess.run([tf.tables_initializer()])
  for record in reader:
    example = tf.parse_single_example(record, parse_spec)
    text = example["text"]
    labels = tf.subtract(example["label"], 1)
    text_ids = text_lookup_table.lookup(text)
    dense = tf.sparse_tensor_to_dense(text_ids)
    print(dense.shape)
    text_embedding = tf.reduce_mean(tf.nn.embedding_lookup(
        text_embedding_w, dense), axis=-2)
    print(text_embedding.shape)
    text_embedding = tf.expand_dims(text_embedding, -2)
    print(text_embedding.shape)
    text_embedding_2 = tf.contrib.layers.bow_encoder(
        dense, VOCAB_SIZE, ESZ)
    print(text_embedding_2.shape)
    logits = tf.contrib.layers.fully_connected(
        inputs=text_embedding, num_outputs=NC,
        activation_fn=None)
    sess.run([tf.global_variables_initializer()])
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    x = sess.run([text_embedding, text_embedding_2, logits, labels, loss])
    print(len(x), list(str(x[i]) for i in range(len(x))))
    if n > 2:
      break
    n += 1


if __name__ == '__main__':
  print("Test Parse Spec:")
  test_parse_spec()
  print("Test Input Fn")
  test_reading_inputs()
--------------------------------------------------------------------------------
/lang_dataset.sh:
--------------------------------------------------------------------------------
#!/bin/bash

if [[ $# -ne 1 ]]
then
    DATADIR="/tmp/lang_detect"
else
    DATADIR="$1"
fi

if [[ ! -d $DATADIR ]]
then
    mkdir -p $DATADIR
    if [[ $? -ne 0 ]]
    then
        echo "Failed to create $DATADIR"
        echo "Usage: lang_dataset.sh [datadir]"
        exit 1
    fi
fi

set -v

pushd $DATADIR
wget http://downloads.tatoeba.org/exports/sentences.tar.bz2
bunzip2 sentences.tar.bz2
tar xvf sentences.tar
# Turn "id<TAB>lang<TAB>sentence" rows into "__label__<lang> <sentence>" lines.
awk -F"\t" '{print"__label__"$2" "$3}' < sentences.csv | shuf > processed_sentences.txt
head -n 10000 processed_sentences.txt > valid.txt
tail -n +10001 processed_sentences.txt > train.txt
popd

ls -lh $DATADIR
echo DONE
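Each processed line is then `__label__<lang> <sentence>`, where Tatoeba uses
three-letter language codes; for example (illustrative lines):

    __label__eng Let's try something.
    __label__fra Il fait beau aujourd'hui.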
--ngrams=2,3,4") 19 | tf.flags.DEFINE_string("signature_def", "proba", 20 | "Stored signature key of method to call (proba|embedding)") 21 | tf.flags.DEFINE_string("saved_model", None, "Directory of SavedModel") 22 | tf.flags.DEFINE_string("tag", "serve", "SavedModel tag, serve|gpu") 23 | tf.flags.DEFINE_boolean("debug", False, "Debug") 24 | FLAGS = tf.flags.FLAGS 25 | 26 | 27 | def RunModel(saved_model_dir, signature_def_key, tag, text, ngrams_list=None): 28 | saved_model = reader.read_saved_model(saved_model_dir) 29 | meta_graph = None 30 | for meta_graph_def in saved_model.meta_graphs: 31 | if tag in meta_graph_def.meta_info_def.tags: 32 | meta_graph = meta_graph_def 33 | break 34 | if meta_graph_def is None: 35 | raise ValueError("Cannot find saved_model with tag" + tag) 36 | signature_def = signature_def_utils.get_signature_def_by_key( 37 | meta_graph, signature_def_key) 38 | text = text_utils.TokenizeText(text) 39 | ngrams = None 40 | if ngrams_list is not None: 41 | ngrams_list = text_utils.ParseNgramsOpts(ngrams_list) 42 | ngrams = text_utils.GenerateNgrams(text, ngrams_list) 43 | example = inputs.BuildTextExample(text, ngrams=ngrams) 44 | example = example.SerializeToString() 45 | inputs_feed_dict = { 46 | signature_def.inputs["inputs"].name: [example], 47 | } 48 | if signature_def_key == "proba": 49 | output_key = "scores" 50 | elif signature_def_key == "embedding": 51 | output_key = "outputs" 52 | else: 53 | raise ValueError("Unrecognised signature_def %s" % (signature_def_key)) 54 | output_tensor = signature_def.outputs[output_key].name 55 | with tf.Session() as sess: 56 | loader.load(sess, [tag], saved_model_dir) 57 | outputs = sess.run(output_tensor, 58 | feed_dict=inputs_feed_dict) 59 | return outputs 60 | 61 | 62 | def main(_): 63 | if not FLAGS.text: 64 | raise ValueError("No --text provided") 65 | outputs = RunModel(FLAGS.saved_model, FLAGS.signature_def, FLAGS.tag, 66 | FLAGS.text, FLAGS.ngrams) 67 | if FLAGS.signature_def == "proba": 68 | print("Proba:", outputs) 69 | print("Class(1-N):", np.argmax(outputs) + 1) 70 | elif FLAGS.signature_def == "embedding": 71 | print(outputs[0]) 72 | 73 | 74 | if __name__ == '__main__': 75 | if FLAGS.debug: 76 | tf.logging.set_verbosity(tf.logging.DEBUG) 77 | tf.app.run() 78 | 79 | -------------------------------------------------------------------------------- /predictor_client.py: -------------------------------------------------------------------------------- 1 | """Predict classification on provided text. 2 | 3 | Send request to a tensorflow_model_server. 4 | 5 | tensorflow_model_server --port=9000 --model_base_path=$export_dir_base 6 | 7 | Usage: 8 | 9 | predictor_client.py --text='some text' --ngrams=1,2,4 10 | 11 | """ 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import tensorflow as tf 17 | import inputs 18 | import text_utils 19 | 20 | from grpc.beta import implementations 21 | from tensorflow_serving.apis import classification_pb2 22 | from tensorflow_serving.apis import prediction_service_pb2 23 | 24 | 25 | tf.flags.DEFINE_string('server', 'localhost:9000', 26 | 'TensorflowService host:port') 27 | tf.flags.DEFINE_string("text", None, "Text to predict label of") 28 | tf.flags.DEFINE_string("ngrams", None, "List of ngram lengths, E.g. 
--ngrams=2,3,4") 29 | tf.flags.DEFINE_string("signature_def", "proba", 30 | "Stored signature key of method to call (proba|embedding)") 31 | FLAGS = tf.flags.FLAGS 32 | 33 | 34 | def Request(text, ngrams): 35 | text = text_utils.TokenizeText(text) 36 | ngrams = None 37 | if ngrams is not None: 38 | ngrams_list = text_utils.ParseNgramsOpts(ngrams) 39 | ngrams = text_utils.GenerateNgrams(text, ngrams_list) 40 | example = inputs.BuildTextExample(text, ngrams=ngrams) 41 | request = classification_pb2.ClassificationRequest() 42 | request.model_spec.name = 'default' 43 | request.model_spec.signature_name = 'proba' 44 | request.input.example_list.examples.extend([example]) 45 | return request 46 | 47 | 48 | def main(_): 49 | if not FLAGS.text: 50 | raise ValueError("No --text provided") 51 | host, port = FLAGS.server.split(':') 52 | channel = implementations.insecure_channel(host, int(port)) 53 | stub = prediction_service_pb2.beta_create_PredictionService_stub(channel) 54 | request = Request(FLAGS.text, FLAGS.ngrams) 55 | result = stub.Classify(request, 10.0) # 10 secs timeout 56 | print(result) 57 | 58 | 59 | if __name__ == '__main__': 60 | tf.app.run() 61 | 62 | -------------------------------------------------------------------------------- /process_input.py: -------------------------------------------------------------------------------- 1 | """Process input data into tensorflow examples, to ease training. 2 | 3 | Input data is in one of two formats: 4 | - facebook's format used in their fastText library. 5 | - two text files, one with input text per line, the other a label per line. 6 | """ 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os.path 12 | import re 13 | import sys 14 | import tensorflow as tf 15 | import inputs 16 | import text_utils 17 | from collections import Counter 18 | from six.moves import zip 19 | 20 | 21 | tf.flags.DEFINE_string("facebook_input", None, 22 | "Input file in facebook train|test format") 23 | tf.flags.DEFINE_string("text_input", None, 24 | """Input text file containing one text phrase per line. 25 | Must have --labels defined 26 | Used instead of --facebook_input""") 27 | tf.flags.DEFINE_string("labels", None, 28 | """Input text file containing one label for 29 | classification per line. 30 | Must have --text_input defined. 31 | Used instead of --facebook_input""") 32 | tf.flags.DEFINE_string("ngrams", None, 33 | "list of ngram sizes to create, e.g. --ngrams=2,3,4,5") 34 | tf.flags.DEFINE_string("output_dir", ".", 35 | "Directory to store resulting vector models and checkpoints in") 36 | tf.flags.DEFINE_integer("num_shards", 1, 37 | "Number of outputfiles to create") 38 | FLAGS = tf.flags.FLAGS 39 | 40 | 41 | def ParseFacebookInput(inputfile, ngrams): 42 | """Parse input in the format used by facebook FastText. 43 | labels are formatted as __label__1 44 | where the label values start at 0. 
  """
  examples = []
  for line in open(inputfile):
    words = line.split()
    # label is the first field, with the __label__ prefix removed
    match = re.match(r'__label__(.+)', words[0])
    label = match.group(1) if match else None
    # Strip out the label token and any "," that follows it
    first = 2 if words[1] == "," else 1
    words = words[first:]
    examples.append({
        "text": words,
        "label": label
    })
    if ngrams:
      examples[-1]["ngrams"] = text_utils.GenerateNgrams(words, ngrams)
  return examples


def ParseTextInput(textfile, labelsfile, ngrams):
  """Parse input from two text files: text and labels.

  Labels are specified one per line.
  """
  examples = []
  with open(textfile) as f1, open(labelsfile) as f2:
    for text, label in zip(f1, f2):
      words = text_utils.TokenizeText(text)
      examples.append({
          "text": words,
          "label": label.strip(),
      })
      if ngrams:
        examples[-1]["ngrams"] = text_utils.GenerateNgrams(words, ngrams)
  return examples


def WriteExamples(examples, outputfile, num_shards):
  """Write examples in TFRecord format.

  Args:
    examples: list of feature dicts.
        {'text': [words], 'label': [labels]}
    outputfile: full pathname of output file
    num_shards: number of output files to spread the examples over
  """
  shard = 0
  num_per_shard = len(examples) // num_shards + 1
  for n, example in enumerate(examples):
    if n % num_per_shard == 0:
      shard += 1
      writer = tf.python_io.TFRecordWriter(outputfile + '-%d-of-%d' %
                                           (shard, num_shards))
    record = inputs.BuildTextExample(
        example["text"], example.get("ngrams", None), example["label"])
    writer.write(record.SerializeToString())


def WriteVocab(examples, vocabfile, labelfile):
  words = Counter()
  labels = set()
  for example in examples:
    words.update(example["text"])
    labels.add(example["label"])
  with open(vocabfile, "w") as f:
    # Write out the vocab most-common-first.
    # We need this as NCE loss in TF assumes a Zipf distribution.
    for word in words.most_common():
      f.write(word[0] + '\n')
  with open(labelfile, "w") as f:
    labels = sorted(list(labels))
    for label in labels:
      f.write(str(label) + '\n')


def main(_):
  # Check flags
  if not (FLAGS.facebook_input or (FLAGS.text_input and FLAGS.labels)):
    print("Error: You must define either facebook_input or both "
          "text_input and labels", file=sys.stderr)
    sys.exit(1)
  ngrams = None
  if FLAGS.ngrams:
    ngrams = text_utils.ParseNgramsOpts(FLAGS.ngrams)
  if FLAGS.facebook_input:
    inputfile = FLAGS.facebook_input
    examples = ParseFacebookInput(FLAGS.facebook_input, ngrams)
  else:
    inputfile = FLAGS.text_input
    examples = ParseTextInput(FLAGS.text_input, FLAGS.labels, ngrams)
  outputfile = os.path.join(FLAGS.output_dir, inputfile + ".tfrecords")
  WriteExamples(examples, outputfile, FLAGS.num_shards)
  vocabfile = os.path.join(FLAGS.output_dir, inputfile + ".vocab")
  labelfile = os.path.join(FLAGS.output_dir, inputfile + ".labels")
  WriteVocab(examples, vocabfile, labelfile)


if __name__ == '__main__':
  tf.app.run()
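For example, running with `--facebook_input=queries.txt --num_shards=1`
produces files like (illustrative listing):

    queries.txt.tfrecords-1-of-1   # serialized tf.train.Examples
    queries.txt.vocab              # vocabulary, most frequent word first
    queries.txt.labels             # sorted class labels, one per line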
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tensorflow
six
nltk
--------------------------------------------------------------------------------
/text_utils.py:
--------------------------------------------------------------------------------
from nltk.tokenize import word_tokenize


def TokenizeText(text):
  return word_tokenize(text.lower())


def ParseNgramsOpts(opts):
  # Only ngram sizes in the range 2..6 are kept.
  ngrams = [int(g) for g in opts.split(',')]
  ngrams = [g for g in ngrams if (g > 1 and g < 7)]
  return ngrams


def GenerateNgrams(words, ngrams):
  # For each requested size, emit every character window of that size
  # from every word.
  nglist = []
  for ng in ngrams:
    for word in words:
      nglist.extend([word[n:n+ng] for n in range(len(word)-ng+1)])
  return nglist
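A worked example of `GenerateNgrams` (note that, unlike the original fastText,
no `<`/`>` word-boundary markers are added):

    >>> GenerateNgrams(["where"], [3])
    ['whe', 'her', 'ere']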
--------------------------------------------------------------------------------
/train_classifier.sh:
--------------------------------------------------------------------------------
#!/bin/bash

if [[ ! -d "$1" ]]
then
    echo "Usage: train_classifier.sh data_dir [dataset_name=ag_news]"
    exit 1
fi

DATADIR=$1
DATASET=${2:-ag_news}
OUTPUT=$DATADIR/models/${DATASET}
EXPORT_DIR=$DATADIR/models/${DATASET}
INPUT_TRAIN_FILE=$DATADIR/${DATASET}.train
INPUT_TEST_FILE=$DATADIR/${DATASET}.test
TRAIN_FILE="$DATADIR/${DATASET}.train.tfrecords-*"
TEST_FILE="$DATADIR/${DATASET}.test.tfrecords-*"

echo "Looking for $TRAIN_FILE"
if ls ${TRAIN_FILE} 1> /dev/null 2>&1
then
    echo "Found"
else
    echo "Not Found $TRAIN_FILE"
    echo "Processing training dataset file"
    python process_input.py --facebook_input=${INPUT_TRAIN_FILE} --ngrams=2,3,4
    if ls ${TRAIN_FILE} 1> /dev/null 2>&1
    then
        echo "$TRAIN_FILE created"
    else
        echo "Failed to create $TRAIN_FILE"
        exit 1
    fi
fi

echo "Looking for $TEST_FILE"
if ls ${TEST_FILE} 1> /dev/null 2>&1
then
    echo "Found"
else
    echo "Not Found $TEST_FILE"
    echo "Processing test dataset file"
    python process_input.py --facebook_input=${INPUT_TEST_FILE} --ngrams=2,3,4
    if ls ${TEST_FILE} 1> /dev/null 2>&1
    then
        echo "$TEST_FILE created"
    else
        echo "Failed to create $TEST_FILE"
        exit 1
    fi
fi

LABELS=$DATADIR/${DATASET}.train.labels
VOCAB=$DATADIR/${DATASET}.train.vocab
VOCAB_SIZE=`cat $VOCAB | wc -l | sed -e "s/[ \t]//g"`

echo $VOCAB
echo $VOCAB_SIZE

# Uncomment this block if you don't have horovod installed.
# python classifier.py \
#   --train_records=$TRAIN_FILE \
#   --eval_records=$TEST_FILE \
#   --label_file=$LABELS \
#   --vocab_file=$VOCAB \
#   --vocab_size=$VOCAB_SIZE \
#   --num_oov_vocab_buckets=100 \
#   --model_dir=$OUTPUT \
#   --export_dir=$EXPORT_DIR \
#   --embedding_dimension=10 \
#   --num_ngram_buckets=100000 \
#   --ngram_embedding_dimension=10 \
#   --learning_rate=0.01 \
#   --batch_size=32 \
#   --train_steps=5000 \
#   --eval_steps=100 \
#   --num_epochs=1 \
#   --num_threads=1 \
#   --nouse_ngrams \
#   --nolog_device_placement \
#   --debug

mpirun -np 2 python classifier.py \
  --train_records=$TRAIN_FILE \
  --eval_records=$TEST_FILE \
  --label_file=$LABELS \
  --vocab_file=$VOCAB \
  --vocab_size=$VOCAB_SIZE \
  --num_oov_vocab_buckets=100 \
  --model_dir=$OUTPUT \
  --export_dir=$EXPORT_DIR \
  --embedding_dimension=10 \
  --num_ngram_buckets=100000 \
  --ngram_embedding_dimension=10 \
  --learning_rate=0.01 \
  --batch_size=32 \
  --train_steps=5000 \
  --eval_steps=100 \
  --num_epochs=1 \
  --num_threads=1 \
  --nouse_ngrams \
  --nolog_device_placement \
  --horovod \
  --debug
--------------------------------------------------------------------------------
/train_langdetect.sh:
--------------------------------------------------------------------------------
#!/bin/bash

if [[ ! -d $1 ]]
then
    echo "Usage: train_langdetect.sh data_dir"
    exit 1
fi

if [[ ! -f $1/train.txt ]] || [[ ! -f $1/valid.txt ]]
then
    echo "data_dir must contain train.txt and valid.txt from lang_dataset.sh"
    exit 2
fi

set -v

DATADIR=$1
MODEL_NAME=langdetect  # subdirectory under $DATADIR/models for this run
OUTPUT=$DATADIR/models/${MODEL_NAME}
EXPORT_DIR=$DATADIR/models/${MODEL_NAME}
INPUT_TRAIN_FILE=$DATADIR/train.txt
INPUT_TEST_FILE=$DATADIR/valid.txt
TRAIN_FILE=$DATADIR/train.txt.tfrecords-1-of-1
TEST_FILE=$DATADIR/valid.txt.tfrecords-1-of-1

echo "Looking for $TRAIN_FILE"
if ls ${TRAIN_FILE} 1> /dev/null 2>&1
then
    echo "Found"
else
    echo "Not Found $TRAIN_FILE"
    echo "Processing training dataset file"
    python process_input.py --facebook_input=${INPUT_TRAIN_FILE} --ngrams=2,3,4
    if ls ${TRAIN_FILE} 1> /dev/null 2>&1
    then
        echo "$TRAIN_FILE created"
    else
        echo "Failed to create $TRAIN_FILE"
        exit 1
    fi
fi

echo "Looking for $TEST_FILE"
if ls ${TEST_FILE} 1> /dev/null 2>&1
then
    echo "Found"
else
    echo "Not Found $TEST_FILE"
    echo "Processing test dataset file"
    python process_input.py --facebook_input=${INPUT_TEST_FILE} --ngrams=2,3,4
    if ls ${TEST_FILE} 1> /dev/null 2>&1
    then
        echo "$TEST_FILE created"
    else
        echo "Failed to create $TEST_FILE"
        exit 1
    fi
fi

LABELS=$DATADIR/train.txt.labels
VOCAB=$DATADIR/train.txt.vocab
VOCAB_SIZE=`cat $VOCAB | wc -l | sed -e "s/[ \t]//g"`

echo $VOCAB
echo $VOCAB_SIZE
echo $LABELS

# Uncomment this block if you don't have horovod installed.
# python classifier.py \
#   --train_records=$TRAIN_FILE \
#   --eval_records=$TEST_FILE \
#   --label_file=$LABELS \
#   --vocab_file=$VOCAB \
#   --vocab_size=$VOCAB_SIZE \
#   --model_dir=$OUTPUT \
#   --export_dir=$EXPORT_DIR \
#   --embedding_dimension=16 \
#   --num_ngram_buckets=100000 \
#   --ngram_embedding_dimension=16 \
#   --learning_rate=0.01 \
#   --batch_size=128 \
#   --train_steps=20000 \
#   --eval_steps=1000 \
#   --num_epochs=1 \
#   --num_threads=1 \
#   --use_ngrams \
#   --nolog_device_placement \
#   --debug

mpirun -np 2 python classifier.py \
  --train_records=$TRAIN_FILE \
  --eval_records=$TEST_FILE \
  --label_file=$LABELS \
  --vocab_file=$VOCAB \
  --vocab_size=$VOCAB_SIZE \
  --model_dir=$OUTPUT \
  --export_dir=$EXPORT_DIR \
  --embedding_dimension=16 \
  --num_ngram_buckets=100000 \
  --ngram_embedding_dimension=16 \
  --learning_rate=0.01 \
  --batch_size=128 \
  --train_steps=20000 \
  --eval_steps=1000 \
  --num_epochs=1 \
  --num_threads=1 \
  --use_ngrams \
  --nolog_device_placement \
  --horovod \
  --debug
--------------------------------------------------------------------------------