├── .github └── ISSUE_TEMPLATE.md ├── .gitignore ├── .gitmodules ├── AUTHORS ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── autoencoder ├── AdditiveGaussianNoiseAutoencoderRunner.py ├── AutoencoderRunner.py ├── MaskingNoiseAutoencoderRunner.py ├── Utils.py ├── VariationalAutoencoderRunner.py ├── __init__.py └── autoencoder_models │ ├── Autoencoder.py │ ├── DenoisingAutoencoder.py │ ├── VariationalAutoencoder.py │ └── __init__.py ├── compression ├── README.md ├── decoder.py ├── encoder.py ├── example.png └── msssim.py ├── im2txt ├── .gitignore ├── README.md ├── WORKSPACE ├── g3doc │ ├── COCO_val2014_000000224477.jpg │ ├── example_captions.jpg │ └── show_and_tell_architecture.png └── im2txt │ ├── BUILD │ ├── configuration.py │ ├── data │ ├── build_mscoco_data.py │ └── download_and_preprocess_mscoco.sh │ ├── evaluate.py │ ├── inference_utils │ ├── BUILD │ ├── caption_generator.py │ ├── caption_generator_test.py │ ├── inference_wrapper_base.py │ └── vocabulary.py │ ├── inference_wrapper.py │ ├── ops │ ├── BUILD │ ├── image_embedding.py │ ├── image_embedding_test.py │ ├── image_processing.py │ └── inputs.py │ ├── run_inference.py │ ├── show_and_tell_model.py │ ├── show_and_tell_model_test.py │ └── train.py ├── inception ├── .gitignore ├── README.md ├── WORKSPACE ├── g3doc │ └── inception_v3_architecture.png └── inception │ ├── BUILD │ ├── data │ ├── build_image_data.py │ ├── build_imagenet_data.py │ ├── download_and_preprocess_flowers.sh │ ├── download_and_preprocess_flowers_mac.sh │ ├── download_and_preprocess_imagenet.sh │ ├── download_imagenet.sh │ ├── imagenet_2012_validation_synset_labels.txt │ ├── imagenet_lsvrc_2015_synsets.txt │ ├── imagenet_metadata.txt │ ├── preprocess_imagenet_validation_data.py │ └── process_bounding_boxes.py │ ├── dataset.py │ ├── flowers_data.py │ ├── flowers_eval.py │ ├── flowers_train.py │ ├── image_processing.py │ ├── imagenet_data.py │ ├── imagenet_distributed_train.py │ ├── imagenet_eval.py │ ├── imagenet_train.py │ ├── 
inception_distributed_train.py │ ├── inception_eval.py │ ├── inception_model.py │ ├── inception_train.py │ └── slim │ ├── BUILD │ ├── README.md │ ├── collections_test.py │ ├── inception_model.py │ ├── inception_test.py │ ├── losses.py │ ├── losses_test.py │ ├── ops.py │ ├── ops_test.py │ ├── scopes.py │ ├── scopes_test.py │ ├── slim.py │ ├── variables.py │ └── variables_test.py ├── lm_1b ├── BUILD ├── README.md ├── data_utils.py └── lm_1b_eval.py ├── namignizer ├── .gitignore ├── README.md ├── data_utils.py ├── model.py └── names.py ├── neural_gpu ├── README.md ├── data_utils.py ├── neural_gpu.py └── neural_gpu_trainer.py ├── privacy ├── README.md ├── aggregation.py ├── deep_cnn.py ├── input.py ├── metrics.py ├── train_student.py ├── train_teachers.py └── utils.py ├── resnet ├── BUILD ├── README.md ├── cifar_input.py ├── g3doc │ ├── cifar_resnet.gif │ └── cifar_resnet_legends.gif ├── resnet_main.py └── resnet_model.py ├── slim ├── BUILD ├── README.md ├── datasets │ ├── __init__.py │ ├── cifar10.py │ ├── dataset_factory.py │ ├── dataset_utils.py │ ├── download_and_convert_cifar10.py │ ├── download_and_convert_flowers.py │ ├── download_and_convert_mnist.py │ ├── flowers.py │ ├── imagenet.py │ └── mnist.py ├── deployment │ ├── __init__.py │ ├── model_deploy.py │ └── model_deploy_test.py ├── download_and_convert_data.py ├── eval_image_classifier.py ├── nets │ ├── __init__.py │ ├── alexnet.py │ ├── alexnet_test.py │ ├── cifarnet.py │ ├── inception.py │ ├── inception_resnet_v2.py │ ├── inception_resnet_v2_test.py │ ├── inception_v1.py │ ├── inception_v1_test.py │ ├── inception_v2.py │ ├── inception_v2_test.py │ ├── inception_v3.py │ ├── inception_v3_test.py │ ├── lenet.py │ ├── nets_factory.py │ ├── nets_factory_test.py │ ├── overfeat.py │ ├── overfeat_test.py │ ├── resnet_utils.py │ ├── resnet_v1.py │ ├── resnet_v1_test.py │ ├── resnet_v2.py │ ├── resnet_v2_test.py │ ├── vgg.py │ └── vgg_test.py ├── preprocessing │ ├── __init__.py │ ├── cifarnet_preprocessing.py │ ├── 
inception_preprocessing.py │ ├── lenet_preprocessing.py │ ├── preprocessing_factory.py │ └── vgg_preprocessing.py ├── scripts │ ├── finetune_inception_v1_on_flowers.sh │ ├── finetune_inception_v3_on_flowers.sh │ ├── train_cifarnet_on_cifar10.sh │ └── train_lenet_on_mnist.sh ├── slim_walkthough.ipynb └── train_image_classifier.py ├── swivel ├── .gitignore ├── README.md ├── analogy.cc ├── eval.mk ├── fastprep.cc ├── fastprep.mk ├── glove_to_shards.py ├── nearest.py ├── prep.py ├── swivel.py ├── text2bin.py ├── vecs.py └── wordsim.py ├── syntaxnet ├── .gitignore ├── Dockerfile ├── README.md ├── WORKSPACE ├── beam_search_training.png ├── ff_nn_schematic.png ├── looping-parser.gif ├── sawman.png ├── syntaxnet │ ├── BUILD │ ├── affix.cc │ ├── affix.h │ ├── arc_standard_transitions.cc │ ├── arc_standard_transitions_test.cc │ ├── base.h │ ├── beam_reader_ops.cc │ ├── beam_reader_ops_test.py │ ├── binary_segment_state.cc │ ├── binary_segment_state.h │ ├── binary_segment_state_test.cc │ ├── binary_segment_transitions.cc │ ├── binary_segment_transitions_test.cc │ ├── char_properties.cc │ ├── char_properties.h │ ├── char_properties_test.cc │ ├── conll2tree.py │ ├── context.pbtxt │ ├── demo.sh │ ├── dictionary.proto │ ├── document_filters.cc │ ├── document_format.cc │ ├── document_format.h │ ├── embedding_feature_extractor.cc │ ├── embedding_feature_extractor.h │ ├── feature_extractor.cc │ ├── feature_extractor.h │ ├── feature_extractor.proto │ ├── feature_types.h │ ├── fml_parser.cc │ ├── fml_parser.h │ ├── graph_builder.py │ ├── graph_builder_test.py │ ├── kbest_syntax.proto │ ├── lexicon_builder.cc │ ├── lexicon_builder_test.py │ ├── load_parser_ops.py │ ├── models │ │ ├── parsey_mcparseface │ │ │ ├── context.pbtxt │ │ │ ├── fine-to-universal.map │ │ │ ├── label-map │ │ │ ├── parser-params │ │ │ ├── prefix-table │ │ │ ├── suffix-table │ │ │ ├── tag-map │ │ │ ├── tagger-params │ │ │ └── word-map │ │ └── parsey_universal │ │ │ ├── context-tokenize-zh.pbtxt │ │ │ ├── 
context.pbtxt │ │ │ ├── parse.sh │ │ │ ├── tokenize.sh │ │ │ └── tokenize_zh.sh │ ├── morpher_transitions.cc │ ├── morphology_label_set.cc │ ├── morphology_label_set.h │ ├── morphology_label_set_test.cc │ ├── ops │ │ └── parser_ops.cc │ ├── parser_eval.py │ ├── parser_features.cc │ ├── parser_features.h │ ├── parser_features_test.cc │ ├── parser_state.cc │ ├── parser_state.h │ ├── parser_trainer.py │ ├── parser_trainer_test.sh │ ├── parser_transitions.cc │ ├── parser_transitions.h │ ├── populate_test_inputs.cc │ ├── populate_test_inputs.h │ ├── proto_io.h │ ├── reader_ops.cc │ ├── reader_ops_test.py │ ├── registry.cc │ ├── registry.h │ ├── segmenter_utils.cc │ ├── segmenter_utils.h │ ├── segmenter_utils_test.cc │ ├── sentence.proto │ ├── sentence_batch.cc │ ├── sentence_batch.h │ ├── sentence_features.cc │ ├── sentence_features.h │ ├── sentence_features_test.cc │ ├── shared_store.cc │ ├── shared_store.h │ ├── shared_store_test.cc │ ├── sparse.proto │ ├── structured_graph_builder.py │ ├── syntaxnet.bzl │ ├── tagger_transitions.cc │ ├── tagger_transitions_test.cc │ ├── task_context.cc │ ├── task_context.h │ ├── task_spec.proto │ ├── term_frequency_map.cc │ ├── term_frequency_map.h │ ├── test_main.cc │ ├── testdata │ │ ├── context.pbtxt │ │ ├── document │ │ └── mini-training-set │ ├── text_formats.cc │ ├── text_formats_test.py │ ├── unpack_sparse_features.cc │ ├── utils.cc │ ├── utils.h │ ├── workspace.cc │ └── workspace.h ├── third_party │ └── utf │ │ ├── BUILD │ │ ├── README │ │ ├── rune.c │ │ ├── runestrcat.c │ │ ├── runestrchr.c │ │ ├── runestrcmp.c │ │ ├── runestrcpy.c │ │ ├── runestrdup.c │ │ ├── runestrecpy.c │ │ ├── runestrlen.c │ │ ├── runestrncat.c │ │ ├── runestrncmp.c │ │ ├── runestrncpy.c │ │ ├── runestrrchr.c │ │ ├── runestrstr.c │ │ ├── runetype.c │ │ ├── runetypebody.c │ │ ├── utf.h │ │ ├── utfdef.h │ │ ├── utfecpy.c │ │ ├── utflen.c │ │ ├── utfnlen.c │ │ ├── utfrrune.c │ │ ├── utfrune.c │ │ └── utfutf.c ├── tools │ └── bazel.rc ├── universal.md └── 
util │ └── utf8 │ ├── BUILD │ ├── gtest_main.cc │ ├── unicodetext.cc │ ├── unicodetext.h │ ├── unicodetext_main.cc │ ├── unicodetext_unittest.cc │ ├── unilib.cc │ ├── unilib.h │ └── unilib_utf8_utils.h ├── textsum ├── BUILD ├── README.md ├── batch_reader.py ├── beam_search.py ├── data.py ├── data │ ├── data │ └── vocab ├── data_convert_example.py ├── seq2seq_attention.py ├── seq2seq_attention_decode.py ├── seq2seq_attention_model.py └── seq2seq_lib.py ├── transformer ├── README.md ├── cluttered_mnist.py ├── data │ └── README.md ├── example.py ├── spatial_transformer.py └── tf_utils.py └── video_prediction ├── README.md ├── download_data.sh ├── lstm_ops.py ├── prediction_input.py ├── prediction_model.py ├── prediction_train.py └── push_datafiles.txt /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Please let us know which model this issue is about (specify the top-level directory) 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tensorflow"] 2 | path = syntaxnet/tensorflow 3 | url = https://github.com/tensorflow/tensorflow.git 4 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the official list of authors for copyright purposes. 2 | # This file is distinct from the CONTRIBUTORS files. 3 | # See the latter for an explanation. 4 | 5 | # Names should be added to this file as: 6 | # Name or Organization 7 | # The email address is not required for organizations. 8 | 9 | Google Inc. 
10 | David Dao 11 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing guidelines 2 | 3 | If you have created a model and would like to publish it here, please send us a 4 | pull request. For those just getting started with pull requests, GitHub has a 5 | [howto](https://help.github.com/articles/using-pull-requests/). 6 | 7 | The code for any model in this repository is licensed under the Apache License 8 | 2.0. 9 | 10 | In order to accept your code, we have to make sure that we can publish your code: 11 | You have to sign a Contributor License Agreement (CLA). 12 | 13 | ### Contributor License Agreements 14 | 15 | Please fill out either the individual or corporate Contributor License Agreement (CLA). 16 | 17 | * If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](http://code.google.com/legal/individual-cla-v1.0.html). 18 | * If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](http://code.google.com/legal/corporate-cla-v1.0.html). 19 | 20 | Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests. 21 | 22 | ***NOTE***: Only original source code from you and other people that have signed the CLA can be accepted into the repository. 23 | 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Models 2 | 3 | This repository contains machine learning models implemented in 4 | [TensorFlow](https://tensorflow.org). The models are maintained by their 5 | respective authors.
6 | 7 | To propose a model for inclusion please submit a pull request. 8 | 9 | 10 | ## Models 11 | - [autoencoder](autoencoder) -- various autoencoders 12 | - [inception](inception) -- deep convolutional networks for computer vision 13 | - [namignizer](namignizer) -- recognize and generate names 14 | - [neural_gpu](neural_gpu) -- highly parallel neural computer 15 | - [privacy](privacy) -- privacy-preserving student models from multiple teachers 16 | - [resnet](resnet) -- deep and wide residual networks 17 | - [slim](slim) -- image classification models in TF-Slim 18 | - [swivel](swivel) -- the Swivel algorithm for generating word embeddings 19 | - [syntaxnet](syntaxnet) -- neural models of natural language syntax 20 | - [textsum](textsum) -- sequence-to-sequence with attention model for text summarization. 21 | - [transformer](transformer) -- spatial transformer network, which allows the spatial manipulation of data within the network 22 | - [im2txt](im2txt) -- image-to-text neural network for image captioning. 
23 | -------------------------------------------------------------------------------- /autoencoder/AdditiveGaussianNoiseAutoencoderRunner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import sklearn.preprocessing as prep 4 | import tensorflow as tf 5 | from tensorflow.examples.tutorials.mnist import input_data 6 | 7 | from autoencoder.autoencoder_models.DenoisingAutoencoder import AdditiveGaussianNoiseAutoencoder 8 | 9 | mnist = input_data.read_data_sets('MNIST_data', one_hot = True) 10 | 11 | def standard_scale(X_train, X_test): 12 | preprocessor = prep.StandardScaler().fit(X_train) 13 | X_train = preprocessor.transform(X_train) 14 | X_test = preprocessor.transform(X_test) 15 | return X_train, X_test 16 | 17 | def get_random_block_from_data(data, batch_size): 18 | start_index = np.random.randint(0, len(data) - batch_size) 19 | return data[start_index:(start_index + batch_size)] 20 | 21 | X_train, X_test = standard_scale(mnist.train.images, mnist.test.images) 22 | 23 | n_samples = int(mnist.train.num_examples) 24 | training_epochs = 20 25 | batch_size = 128 26 | display_step = 1 27 | 28 | autoencoder = AdditiveGaussianNoiseAutoencoder(n_input = 784, 29 | n_hidden = 200, 30 | transfer_function = tf.nn.softplus, 31 | optimizer = tf.train.AdamOptimizer(learning_rate = 0.001), 32 | scale = 0.01) 33 | 34 | for epoch in range(training_epochs): 35 | avg_cost = 0. 
36 | total_batch = int(n_samples / batch_size) 37 | # Loop over all batches 38 | for i in range(total_batch): 39 | batch_xs = get_random_block_from_data(X_train, batch_size) 40 | 41 | # Fit training using batch data 42 | cost = autoencoder.partial_fit(batch_xs) 43 | # Compute average loss 44 | avg_cost += cost / n_samples * batch_size 45 | 46 | # Display logs per epoch step 47 | if epoch % display_step == 0: 48 | print "Epoch:", '%04d' % (epoch + 1), \ 49 | "cost=", "{:.9f}".format(avg_cost) 50 | 51 | print "Total cost: " + str(autoencoder.calc_total_cost(X_test)) 52 | -------------------------------------------------------------------------------- /autoencoder/AutoencoderRunner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import sklearn.preprocessing as prep 4 | import tensorflow as tf 5 | from tensorflow.examples.tutorials.mnist import input_data 6 | 7 | from autoencoder.autoencoder_models.Autoencoder import Autoencoder 8 | 9 | mnist = input_data.read_data_sets('MNIST_data', one_hot = True) 10 | 11 | def standard_scale(X_train, X_test): 12 | preprocessor = prep.StandardScaler().fit(X_train) 13 | X_train = preprocessor.transform(X_train) 14 | X_test = preprocessor.transform(X_test) 15 | return X_train, X_test 16 | 17 | def get_random_block_from_data(data, batch_size): 18 | start_index = np.random.randint(0, len(data) - batch_size) 19 | return data[start_index:(start_index + batch_size)] 20 | 21 | X_train, X_test = standard_scale(mnist.train.images, mnist.test.images) 22 | 23 | n_samples = int(mnist.train.num_examples) 24 | training_epochs = 20 25 | batch_size = 128 26 | display_step = 1 27 | 28 | autoencoder = Autoencoder(n_input = 784, 29 | n_hidden = 200, 30 | transfer_function = tf.nn.softplus, 31 | optimizer = tf.train.AdamOptimizer(learning_rate = 0.001)) 32 | 33 | for epoch in range(training_epochs): 34 | avg_cost = 0. 
35 | total_batch = int(n_samples / batch_size) 36 | # Loop over all batches 37 | for i in range(total_batch): 38 | batch_xs = get_random_block_from_data(X_train, batch_size) 39 | 40 | # Fit training using batch data 41 | cost = autoencoder.partial_fit(batch_xs) 42 | # Compute average loss 43 | avg_cost += cost / n_samples * batch_size 44 | 45 | # Display logs per epoch step 46 | if epoch % display_step == 0: 47 | print "Epoch:", '%04d' % (epoch + 1), \ 48 | "cost=", "{:.9f}".format(avg_cost) 49 | 50 | print "Total cost: " + str(autoencoder.calc_total_cost(X_test)) 51 | -------------------------------------------------------------------------------- /autoencoder/MaskingNoiseAutoencoderRunner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import sklearn.preprocessing as prep 4 | import tensorflow as tf 5 | from tensorflow.examples.tutorials.mnist import input_data 6 | 7 | from autoencoder.autoencoder_models.DenoisingAutoencoder import MaskingNoiseAutoencoder 8 | 9 | mnist = input_data.read_data_sets('MNIST_data', one_hot = True) 10 | 11 | def standard_scale(X_train, X_test): 12 | preprocessor = prep.StandardScaler().fit(X_train) 13 | X_train = preprocessor.transform(X_train) 14 | X_test = preprocessor.transform(X_test) 15 | return X_train, X_test 16 | 17 | def get_random_block_from_data(data, batch_size): 18 | start_index = np.random.randint(0, len(data) - batch_size) 19 | return data[start_index:(start_index + batch_size)] 20 | 21 | X_train, X_test = standard_scale(mnist.train.images, mnist.test.images) 22 | 23 | 24 | n_samples = int(mnist.train.num_examples) 25 | training_epochs = 100 26 | batch_size = 128 27 | display_step = 1 28 | 29 | autoencoder = MaskingNoiseAutoencoder(n_input = 784, 30 | n_hidden = 200, 31 | transfer_function = tf.nn.softplus, 32 | optimizer = tf.train.AdamOptimizer(learning_rate = 0.001), 33 | dropout_probability = 0.95) 34 | 35 | for epoch in range(training_epochs): 36 | 
avg_cost = 0. 37 | total_batch = int(n_samples / batch_size) 38 | for i in range(total_batch): 39 | batch_xs = get_random_block_from_data(X_train, batch_size) 40 | 41 | cost = autoencoder.partial_fit(batch_xs) 42 | 43 | avg_cost += cost / n_samples * batch_size 44 | 45 | if epoch % display_step == 0: 46 | print "Epoch:", '%04d' % (epoch + 1), \ 47 | "cost=", "{:.9f}".format(avg_cost) 48 | 49 | print "Total cost: " + str(autoencoder.calc_total_cost(X_test)) 50 | -------------------------------------------------------------------------------- /autoencoder/Utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | def xavier_init(fan_in, fan_out, constant = 1): 5 | low = -constant * np.sqrt(6.0 / (fan_in + fan_out)) 6 | high = constant * np.sqrt(6.0 / (fan_in + fan_out)) 7 | return tf.random_uniform((fan_in, fan_out), 8 | minval = low, maxval = high, 9 | dtype = tf.float32) 10 | -------------------------------------------------------------------------------- /autoencoder/VariationalAutoencoderRunner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import sklearn.preprocessing as prep 4 | import tensorflow as tf 5 | from tensorflow.examples.tutorials.mnist import input_data 6 | 7 | from autoencoder.autoencoder_models.VariationalAutoencoder import VariationalAutoencoder 8 | 9 | mnist = input_data.read_data_sets('MNIST_data', one_hot = True) 10 | 11 | 12 | 13 | def min_max_scale(X_train, X_test): 14 | preprocessor = prep.MinMaxScaler().fit(X_train) 15 | X_train = preprocessor.transform(X_train) 16 | X_test = preprocessor.transform(X_test) 17 | return X_train, X_test 18 | 19 | 20 | def get_random_block_from_data(data, batch_size): 21 | start_index = np.random.randint(0, len(data) - batch_size) 22 | return data[start_index:(start_index + batch_size)] 23 | 24 | 25 | X_train, X_test = min_max_scale(mnist.train.images, 
mnist.test.images) 26 | 27 | n_samples = int(mnist.train.num_examples) 28 | training_epochs = 20 29 | batch_size = 128 30 | display_step = 1 31 | 32 | autoencoder = VariationalAutoencoder(n_input = 784, 33 | n_hidden = 200, 34 | optimizer = tf.train.AdamOptimizer(learning_rate = 0.001)) 35 | 36 | for epoch in range(training_epochs): 37 | avg_cost = 0. 38 | total_batch = int(n_samples / batch_size) 39 | # Loop over all batches 40 | for i in range(total_batch): 41 | batch_xs = get_random_block_from_data(X_train, batch_size) 42 | 43 | # Fit training using batch data 44 | cost = autoencoder.partial_fit(batch_xs) 45 | # Compute average loss 46 | avg_cost += cost / n_samples * batch_size 47 | 48 | # Display logs per epoch step 49 | if epoch % display_step == 0: 50 | print "Epoch:", '%04d' % (epoch + 1), \ 51 | "cost=", "{:.9f}".format(avg_cost) 52 | 53 | print "Total cost: " + str(autoencoder.calc_total_cost(X_test)) 54 | -------------------------------------------------------------------------------- /autoencoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/autoencoder/__init__.py -------------------------------------------------------------------------------- /autoencoder/autoencoder_models/Autoencoder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import autoencoder.Utils 4 | 5 | class Autoencoder(object): 6 | 7 | def __init__(self, n_input, n_hidden, transfer_function=tf.nn.softplus, optimizer = tf.train.AdamOptimizer()): 8 | self.n_input = n_input 9 | self.n_hidden = n_hidden 10 | self.transfer = transfer_function 11 | 12 | network_weights = self._initialize_weights() 13 | self.weights = network_weights 14 | 15 | # model 16 | self.x = tf.placeholder(tf.float32, [None, self.n_input]) 17 | self.hidden = 
self.transfer(tf.add(tf.matmul(self.x, self.weights['w1']), self.weights['b1'])) 18 | self.reconstruction = tf.add(tf.matmul(self.hidden, self.weights['w2']), self.weights['b2']) 19 | 20 | # cost 21 | self.cost = 0.5 * tf.reduce_sum(tf.pow(tf.sub(self.reconstruction, self.x), 2.0)) 22 | self.optimizer = optimizer.minimize(self.cost) 23 | 24 | init = tf.initialize_all_variables() 25 | self.sess = tf.Session() 26 | self.sess.run(init) 27 | 28 | 29 | def _initialize_weights(self): 30 | all_weights = dict() 31 | all_weights['w1'] = tf.Variable(autoencoder.Utils.xavier_init(self.n_input, self.n_hidden)) 32 | all_weights['b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype=tf.float32)) 33 | all_weights['w2'] = tf.Variable(tf.zeros([self.n_hidden, self.n_input], dtype=tf.float32)) 34 | all_weights['b2'] = tf.Variable(tf.zeros([self.n_input], dtype=tf.float32)) 35 | return all_weights 36 | 37 | def partial_fit(self, X): 38 | cost, opt = self.sess.run((self.cost, self.optimizer), feed_dict={self.x: X}) 39 | return cost 40 | 41 | def calc_total_cost(self, X): 42 | return self.sess.run(self.cost, feed_dict = {self.x: X}) 43 | 44 | def transform(self, X): 45 | return self.sess.run(self.hidden, feed_dict={self.x: X}) 46 | 47 | def generate(self, hidden = None): 48 | if hidden is None: 49 | hidden = np.random.normal(size=self.weights["b1"]) 50 | return self.sess.run(self.reconstruction, feed_dict={self.hidden: hidden}) 51 | 52 | def reconstruct(self, X): 53 | return self.sess.run(self.reconstruction, feed_dict={self.x: X}) 54 | 55 | def getWeights(self): 56 | return self.sess.run(self.weights['w1']) 57 | 58 | def getBiases(self): 59 | return self.sess.run(self.weights['b1']) 60 | 61 | -------------------------------------------------------------------------------- /autoencoder/autoencoder_models/VariationalAutoencoder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import autoencoder.Utils 
4 | 5 | class VariationalAutoencoder(object): 6 | 7 | def __init__(self, n_input, n_hidden, optimizer = tf.train.AdamOptimizer()): 8 | self.n_input = n_input 9 | self.n_hidden = n_hidden 10 | 11 | network_weights = self._initialize_weights() 12 | self.weights = network_weights 13 | 14 | # model 15 | self.x = tf.placeholder(tf.float32, [None, self.n_input]) 16 | self.z_mean = tf.add(tf.matmul(self.x, self.weights['w1']), self.weights['b1']) 17 | self.z_log_sigma_sq = tf.add(tf.matmul(self.x, self.weights['log_sigma_w1']), self.weights['log_sigma_b1']) 18 | 19 | # sample from gaussian distribution 20 | eps = tf.random_normal(tf.pack([tf.shape(self.x)[0], self.n_hidden]), 0, 1, dtype = tf.float32) 21 | self.z = tf.add(self.z_mean, tf.mul(tf.sqrt(tf.exp(self.z_log_sigma_sq)), eps)) 22 | 23 | self.reconstruction = tf.add(tf.matmul(self.z, self.weights['w2']), self.weights['b2']) 24 | 25 | # cost 26 | reconstr_loss = 0.5 * tf.reduce_sum(tf.pow(tf.sub(self.reconstruction, self.x), 2.0)) 27 | latent_loss = -0.5 * tf.reduce_sum(1 + self.z_log_sigma_sq 28 | - tf.square(self.z_mean) 29 | - tf.exp(self.z_log_sigma_sq), 1) 30 | self.cost = tf.reduce_mean(reconstr_loss + latent_loss) 31 | self.optimizer = optimizer.minimize(self.cost) 32 | 33 | init = tf.initialize_all_variables() 34 | self.sess = tf.Session() 35 | self.sess.run(init) 36 | 37 | def _initialize_weights(self): 38 | all_weights = dict() 39 | all_weights['w1'] = tf.Variable(autoencoder.Utils.xavier_init(self.n_input, self.n_hidden)) 40 | all_weights['log_sigma_w1'] = tf.Variable(autoencoder.Utils.xavier_init(self.n_input, self.n_hidden)) 41 | all_weights['b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype=tf.float32)) 42 | all_weights['log_sigma_b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype=tf.float32)) 43 | all_weights['w2'] = tf.Variable(tf.zeros([self.n_hidden, self.n_input], dtype=tf.float32)) 44 | all_weights['b2'] = tf.Variable(tf.zeros([self.n_input], dtype=tf.float32)) 45 | return all_weights 46 | 47 | 
def partial_fit(self, X): 48 | cost, opt = self.sess.run((self.cost, self.optimizer), feed_dict={self.x: X}) 49 | return cost 50 | 51 | def calc_total_cost(self, X): 52 | return self.sess.run(self.cost, feed_dict = {self.x: X}) 53 | 54 | def transform(self, X): 55 | return self.sess.run(self.z_mean, feed_dict={self.x: X}) 56 | 57 | def generate(self, hidden = None): 58 | if hidden is None: 59 | hidden = np.random.normal(size=self.weights["b1"]) 60 | return self.sess.run(self.reconstruction, feed_dict={self.z_mean: hidden}) 61 | 62 | def reconstruct(self, X): 63 | return self.sess.run(self.reconstruction, feed_dict={self.x: X}) 64 | 65 | def getWeights(self): 66 | return self.sess.run(self.weights['w1']) 67 | 68 | def getBiases(self): 69 | return self.sess.run(self.weights['b1']) 70 | 71 | -------------------------------------------------------------------------------- /autoencoder/autoencoder_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/autoencoder/autoencoder_models/__init__.py -------------------------------------------------------------------------------- /compression/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/compression/example.png -------------------------------------------------------------------------------- /im2txt/.gitignore: -------------------------------------------------------------------------------- 1 | /bazel-bin 2 | /bazel-ci_build-cache 3 | /bazel-genfiles 4 | /bazel-out 5 | /bazel-im2txt 6 | /bazel-testlogs 7 | /bazel-tf 8 | -------------------------------------------------------------------------------- /im2txt/WORKSPACE: -------------------------------------------------------------------------------- 1 | workspace(name = "im2txt") 2 | 
-------------------------------------------------------------------------------- /im2txt/g3doc/COCO_val2014_000000224477.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/im2txt/g3doc/COCO_val2014_000000224477.jpg -------------------------------------------------------------------------------- /im2txt/g3doc/example_captions.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/im2txt/g3doc/example_captions.jpg -------------------------------------------------------------------------------- /im2txt/g3doc/show_and_tell_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/im2txt/g3doc/show_and_tell_architecture.png -------------------------------------------------------------------------------- /im2txt/im2txt/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = [":internal"]) 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | exports_files(["LICENSE"]) 6 | 7 | package_group( 8 | name = "internal", 9 | packages = [ 10 | "//im2txt/...", 11 | ], 12 | ) 13 | 14 | py_binary( 15 | name = "build_mscoco_data", 16 | srcs = [ 17 | "data/build_mscoco_data.py", 18 | ], 19 | ) 20 | 21 | sh_binary( 22 | name = "download_and_preprocess_mscoco", 23 | srcs = ["data/download_and_preprocess_mscoco.sh"], 24 | data = [ 25 | ":build_mscoco_data", 26 | ], 27 | ) 28 | 29 | py_library( 30 | name = "configuration", 31 | srcs = ["configuration.py"], 32 | srcs_version = "PY2AND3", 33 | ) 34 | 35 | py_library( 36 | name = "show_and_tell_model", 37 | srcs = ["show_and_tell_model.py"], 38 | srcs_version = "PY2AND3", 39 | 
deps = [ 40 | "//im2txt/ops:image_embedding", 41 | "//im2txt/ops:image_processing", 42 | "//im2txt/ops:inputs", 43 | ], 44 | ) 45 | 46 | py_test( 47 | name = "show_and_tell_model_test", 48 | size = "large", 49 | srcs = ["show_and_tell_model_test.py"], 50 | deps = [ 51 | ":configuration", 52 | ":show_and_tell_model", 53 | ], 54 | ) 55 | 56 | py_library( 57 | name = "inference_wrapper", 58 | srcs = ["inference_wrapper.py"], 59 | srcs_version = "PY2AND3", 60 | deps = [ 61 | ":show_and_tell_model", 62 | "//im2txt/inference_utils:inference_wrapper_base", 63 | ], 64 | ) 65 | 66 | py_binary( 67 | name = "train", 68 | srcs = ["train.py"], 69 | srcs_version = "PY2AND3", 70 | deps = [ 71 | ":configuration", 72 | ":show_and_tell_model", 73 | ], 74 | ) 75 | 76 | py_binary( 77 | name = "evaluate", 78 | srcs = ["evaluate.py"], 79 | srcs_version = "PY2AND3", 80 | deps = [ 81 | ":configuration", 82 | ":show_and_tell_model", 83 | ], 84 | ) 85 | 86 | py_binary( 87 | name = "run_inference", 88 | srcs = ["run_inference.py"], 89 | srcs_version = "PY2AND3", 90 | deps = [ 91 | ":configuration", 92 | ":inference_wrapper", 93 | "//im2txt/inference_utils:caption_generator", 94 | "//im2txt/inference_utils:vocabulary", 95 | ], 96 | ) 97 | -------------------------------------------------------------------------------- /im2txt/im2txt/data/download_and_preprocess_mscoco.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | # Script to download and preprocess the MSCOCO data set. 18 | # 19 | # The outputs of this script are sharded TFRecord files containing serialized 20 | # SequenceExample protocol buffers. See build_mscoco_data.py for details of how 21 | # the SequenceExample protocol buffers are constructed. 22 | # 23 | # usage: 24 | # ./download_and_preprocess_mscoco.sh 25 | set -e 26 | 27 | if [ -z "$1" ]; then 28 | echo "usage download_and_preproces_mscoco.sh [data dir]" 29 | exit 30 | fi 31 | 32 | if [ "$(uname)" == "Darwin" ]; then 33 | UNZIP="tar -xf" 34 | else 35 | UNZIP="unzip -nq" 36 | fi 37 | 38 | # Create the output directories. 39 | OUTPUT_DIR="${1%/}" 40 | SCRATCH_DIR="${OUTPUT_DIR}/raw-data" 41 | mkdir -p "${OUTPUT_DIR}" 42 | mkdir -p "${SCRATCH_DIR}" 43 | CURRENT_DIR=$(pwd) 44 | WORK_DIR="$0.runfiles/im2txt/im2txt" 45 | 46 | # Helper function to download and unpack a .zip file. 47 | function download_and_unzip() { 48 | local BASE_URL=${1} 49 | local FILENAME=${2} 50 | 51 | if [ ! -f ${FILENAME} ]; then 52 | echo "Downloading ${FILENAME} to $(pwd)" 53 | wget -nd -c "${BASE_URL}/${FILENAME}" 54 | else 55 | echo "Skipping download of ${FILENAME}" 56 | fi 57 | echo "Unzipping ${FILENAME}" 58 | ${UNZIP} ${FILENAME} 59 | } 60 | 61 | cd ${SCRATCH_DIR} 62 | 63 | # Download the images. 
64 | BASE_IMAGE_URL="http://msvocds.blob.core.windows.net/coco2014" 65 | 66 | TRAIN_IMAGE_FILE="train2014.zip" 67 | download_and_unzip ${BASE_IMAGE_URL} ${TRAIN_IMAGE_FILE} 68 | TRAIN_IMAGE_DIR="${SCRATCH_DIR}/train2014" 69 | 70 | VAL_IMAGE_FILE="val2014.zip" 71 | download_and_unzip ${BASE_IMAGE_URL} ${VAL_IMAGE_FILE} 72 | VAL_IMAGE_DIR="${SCRATCH_DIR}/val2014" 73 | 74 | # Download the captions. 75 | BASE_CAPTIONS_URL="http://msvocds.blob.core.windows.net/annotations-1-0-3" 76 | CAPTIONS_FILE="captions_train-val2014.zip" 77 | download_and_unzip ${BASE_CAPTIONS_URL} ${CAPTIONS_FILE} 78 | TRAIN_CAPTIONS_FILE="${SCRATCH_DIR}/annotations/captions_train2014.json" 79 | VAL_CAPTIONS_FILE="${SCRATCH_DIR}/annotations/captions_val2014.json" 80 | 81 | # Build TFRecords of the image data. 82 | cd "${CURRENT_DIR}" 83 | BUILD_SCRIPT="${WORK_DIR}/build_mscoco_data" 84 | "${BUILD_SCRIPT}" \ 85 | --train_image_dir="${TRAIN_IMAGE_DIR}" \ 86 | --val_image_dir="${VAL_IMAGE_DIR}" \ 87 | --train_captions_file="${TRAIN_CAPTIONS_FILE}" \ 88 | --val_captions_file="${VAL_CAPTIONS_FILE}" \ 89 | --output_dir="${OUTPUT_DIR}" \ 90 | --word_counts_output_file="${OUTPUT_DIR}/word_counts.txt" \ 91 | -------------------------------------------------------------------------------- /im2txt/im2txt/inference_utils/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//im2txt:internal"]) 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | exports_files(["LICENSE"]) 6 | 7 | py_library( 8 | name = "inference_wrapper_base", 9 | srcs = ["inference_wrapper_base.py"], 10 | srcs_version = "PY2AND3", 11 | ) 12 | 13 | py_library( 14 | name = "vocabulary", 15 | srcs = ["vocabulary.py"], 16 | srcs_version = "PY2AND3", 17 | ) 18 | 19 | py_library( 20 | name = "caption_generator", 21 | srcs = ["caption_generator.py"], 22 | srcs_version = "PY2AND3", 23 | ) 24 | 25 | py_test( 26 | name = "caption_generator_test", 27 | srcs = 
["caption_generator_test.py"], 28 | deps = [ 29 | ":caption_generator", 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /im2txt/im2txt/inference_utils/vocabulary.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Vocabulary class for an image-to-text model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | 22 | import tensorflow as tf 23 | 24 | 25 | class Vocabulary(object): 26 | """Vocabulary class for an image-to-text model.""" 27 | 28 | def __init__(self, 29 | vocab_file, 30 | start_word="", 31 | end_word="", 32 | unk_word=""): 33 | """Initializes the vocabulary. 34 | 35 | Args: 36 | vocab_file: File containing the vocabulary, where the words are the first 37 | whitespace-separated token on each line (other tokens are ignored) and 38 | the word ids are the corresponding line numbers. 39 | start_word: Special word denoting sentence start. 40 | end_word: Special word denoting sentence end. 41 | unk_word: Special word denoting unknown words. 
42 | """ 43 | if not tf.gfile.Exists(vocab_file): 44 | tf.logging.fatal("Vocab file %s not found.", vocab_file) 45 | tf.logging.info("Initializing vocabulary from file: %s", vocab_file) 46 | 47 | with tf.gfile.GFile(vocab_file, mode="r") as f: 48 | reverse_vocab = list(f.readlines()) 49 | reverse_vocab = [line.split()[0] for line in reverse_vocab] 50 | assert start_word in reverse_vocab 51 | assert end_word in reverse_vocab 52 | if unk_word not in reverse_vocab: 53 | reverse_vocab.append(unk_word) 54 | vocab = dict([(x, y) for (y, x) in enumerate(reverse_vocab)]) 55 | 56 | tf.logging.info("Created vocabulary with %d words" % len(vocab)) 57 | 58 | self.vocab = vocab # vocab[word] = id 59 | self.reverse_vocab = reverse_vocab # reverse_vocab[id] = word 60 | 61 | # Save special word ids. 62 | self.start_id = vocab[start_word] 63 | self.end_id = vocab[end_word] 64 | self.unk_id = vocab[unk_word] 65 | 66 | def word_to_id(self, word): 67 | """Returns the integer word id of a word string.""" 68 | if word in self.vocab: 69 | return self.vocab[word] 70 | else: 71 | return self.unk_id 72 | 73 | def id_to_word(self, word_id): 74 | """Returns the word string of an integer word id.""" 75 | if word_id >= len(self.reverse_vocab): 76 | return self.reverse_vocab[self.unk_id] 77 | else: 78 | return self.reverse_vocab[word_id] 79 | -------------------------------------------------------------------------------- /im2txt/im2txt/inference_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Model wrapper class for performing inference with a ShowAndTellModel.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | 24 | from im2txt import show_and_tell_model 25 | from im2txt.inference_utils import inference_wrapper_base 26 | 27 | 28 | class InferenceWrapper(inference_wrapper_base.InferenceWrapperBase): 29 | """Model wrapper class for performing inference with a ShowAndTellModel.""" 30 | 31 | def __init__(self): 32 | super(InferenceWrapper, self).__init__() 33 | 34 | def build_model(self, model_config): 35 | model = show_and_tell_model.ShowAndTellModel(model_config, mode="inference") 36 | model.build() 37 | return model 38 | 39 | def feed_image(self, sess, encoded_image): 40 | initial_state = sess.run(fetches="lstm/initial_state:0", 41 | feed_dict={"image_feed:0": encoded_image}) 42 | return initial_state 43 | 44 | def inference_step(self, sess, input_feed, state_feed): 45 | softmax_output, state_output = sess.run( 46 | fetches=["softmax:0", "lstm/state:0"], 47 | feed_dict={ 48 | "input_feed:0": input_feed, 49 | "lstm/state_feed:0": state_feed, 50 | }) 51 | return softmax_output, state_output, None 52 | -------------------------------------------------------------------------------- /im2txt/im2txt/ops/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = 
["//im2txt:internal"]) 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | exports_files(["LICENSE"]) 6 | 7 | py_library( 8 | name = "image_processing", 9 | srcs = ["image_processing.py"], 10 | srcs_version = "PY2AND3", 11 | ) 12 | 13 | py_library( 14 | name = "image_embedding", 15 | srcs = ["image_embedding.py"], 16 | srcs_version = "PY2AND3", 17 | ) 18 | 19 | py_test( 20 | name = "image_embedding_test", 21 | size = "small", 22 | srcs = ["image_embedding_test.py"], 23 | deps = [ 24 | ":image_embedding", 25 | ], 26 | ) 27 | 28 | py_library( 29 | name = "inputs", 30 | srcs = ["inputs.py"], 31 | srcs_version = "PY2AND3", 32 | ) 33 | -------------------------------------------------------------------------------- /im2txt/im2txt/run_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | r"""Generate captions for images using default beam search parameters.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import math 22 | import os 23 | 24 | 25 | import tensorflow as tf 26 | 27 | from im2txt import configuration 28 | from im2txt import inference_wrapper 29 | from im2txt.inference_utils import caption_generator 30 | from im2txt.inference_utils import vocabulary 31 | 32 | FLAGS = tf.flags.FLAGS 33 | 34 | tf.flags.DEFINE_string("checkpoint_path", "", 35 | "Model checkpoint file or directory containing a " 36 | "model checkpoint file.") 37 | tf.flags.DEFINE_string("vocab_file", "", "Text file containing the vocabulary.") 38 | tf.flags.DEFINE_string("input_files", "", 39 | "File pattern or comma-separated list of file patterns " 40 | "of image files.") 41 | 42 | 43 | def main(_): 44 | # Build the inference graph. 45 | g = tf.Graph() 46 | with g.as_default(): 47 | model = inference_wrapper.InferenceWrapper() 48 | restore_fn = model.build_graph_from_config(configuration.ModelConfig(), 49 | FLAGS.checkpoint_path) 50 | g.finalize() 51 | 52 | # Create the vocabulary. 53 | vocab = vocabulary.Vocabulary(FLAGS.vocab_file) 54 | 55 | filenames = [] 56 | for file_pattern in FLAGS.input_files.split(","): 57 | filenames.extend(tf.gfile.Glob(file_pattern)) 58 | tf.logging.info("Running caption generation on %d files matching %s", 59 | len(filenames), FLAGS.input_files) 60 | 61 | with tf.Session(graph=g) as sess: 62 | # Load the model from checkpoint. 63 | restore_fn(sess) 64 | 65 | # Prepare the caption generator. Here we are implicitly using the default 66 | # beam search parameters. See caption_generator.py for a description of the 67 | # available beam search parameters. 
68 | generator = caption_generator.CaptionGenerator(model, vocab) 69 | 70 | for filename in filenames: 71 | with tf.gfile.GFile(filename, "r") as f: 72 | image = f.read() 73 | captions = generator.beam_search(sess, image) 74 | print("Captions for image %s:" % os.path.basename(filename)) 75 | for i, caption in enumerate(captions): 76 | # Ignore begin and end words. 77 | sentence = [vocab.id_to_word(w) for w in caption.sentence[1:-1]] 78 | sentence = " ".join(sentence) 79 | print(" %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob))) 80 | 81 | 82 | if __name__ == "__main__": 83 | tf.app.run() 84 | -------------------------------------------------------------------------------- /inception/.gitignore: -------------------------------------------------------------------------------- 1 | /bazel-bin 2 | /bazel-ci_build-cache 3 | /bazel-genfiles 4 | /bazel-out 5 | /bazel-inception 6 | /bazel-testlogs 7 | /bazel-tf 8 | -------------------------------------------------------------------------------- /inception/WORKSPACE: -------------------------------------------------------------------------------- 1 | workspace(name = "inception") 2 | -------------------------------------------------------------------------------- /inception/g3doc/inception_v3_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/inception/g3doc/inception_v3_architecture.png -------------------------------------------------------------------------------- /inception/inception/data/preprocess_imagenet_validation_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Process the ImageNet Challenge bounding boxes for TensorFlow model training. 17 | 18 | Associate the ImageNet 2012 Challenge validation data set with labels. 19 | 20 | The raw ImageNet validation data set is expected to reside in JPEG files 21 | located in the following directory structure. 22 | 23 | data_dir/ILSVRC2012_val_00000001.JPEG 24 | data_dir/ILSVRC2012_val_00000002.JPEG 25 | ... 26 | data_dir/ILSVRC2012_val_00050000.JPEG 27 | 28 | This script moves the files into a directory structure like such: 29 | data_dir/n01440764/ILSVRC2012_val_00000293.JPEG 30 | data_dir/n01440764/ILSVRC2012_val_00000543.JPEG 31 | ... 32 | where 'n01440764' is the unique synset label associated with 33 | these images. 34 | 35 | This directory reorganization requires a mapping from validation image 36 | number (i.e. suffix of the original file) to the associated label. This 37 | is provided in the ImageNet development kit via a Matlab file. 38 | 39 | In order to make life easier and divorce ourselves from Matlab, we instead 40 | supply a custom text file that provides this mapping for us. 
41 | 42 | Sample usage: 43 | ./preprocess_imagenet_validation_data.py ILSVRC2012_img_val \ 44 | imagenet_2012_validation_synset_labels.txt 45 | """ 46 | 47 | from __future__ import absolute_import 48 | from __future__ import division 49 | from __future__ import print_function 50 | 51 | import os 52 | import os.path 53 | import sys 54 | 55 | 56 | if __name__ == '__main__': 57 | if len(sys.argv) < 3: 58 | print('Invalid usage\n' 59 | 'usage: preprocess_imagenet_validation_data.py ' 60 | ' ') 61 | sys.exit(-1) 62 | data_dir = sys.argv[1] 63 | validation_labels_file = sys.argv[2] 64 | 65 | # Read in the 50000 synsets associated with the validation data set. 66 | labels = [l.strip() for l in open(validation_labels_file).readlines()] 67 | unique_labels = set(labels) 68 | 69 | # Make all sub-directories in the validation data dir. 70 | for label in unique_labels: 71 | labeled_data_dir = os.path.join(data_dir, label) 72 | os.makedirs(labeled_data_dir) 73 | 74 | # Move all of the image to the appropriate sub-directory. 75 | for i in xrange(len(labels)): 76 | basename = 'ILSVRC2012_val_000%.5d.JPEG' % (i + 1) 77 | original_filename = os.path.join(data_dir, basename) 78 | if not os.path.exists(original_filename): 79 | print('Failed to find: ' % original_filename) 80 | sys.exit(-1) 81 | new_filename = os.path.join(data_dir, labels[i], basename) 82 | os.rename(original_filename, new_filename) 83 | -------------------------------------------------------------------------------- /inception/inception/flowers_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Small library that points to the flowers data set. 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | 22 | 23 | from inception.dataset import Dataset 24 | 25 | 26 | class FlowersData(Dataset): 27 | """Flowers data set.""" 28 | 29 | def __init__(self, subset): 30 | super(FlowersData, self).__init__('Flowers', subset) 31 | 32 | def num_classes(self): 33 | """Returns the number of classes in the data set.""" 34 | return 5 35 | 36 | def num_examples_per_epoch(self): 37 | """Returns the number of examples in the data subset.""" 38 | if self.subset == 'train': 39 | return 3170 40 | if self.subset == 'validation': 41 | return 500 42 | 43 | def download_message(self): 44 | """Instruction to download and extract the tarball from Flowers website.""" 45 | 46 | print('Failed to find any Flowers %s files'% self.subset) 47 | print('') 48 | print('If you have already downloaded and processed the data, then make ' 49 | 'sure to set --data_dir to point to the directory containing the ' 50 | 'location of the sharded TFRecords.\n') 51 | print('Please see README.md for instructions on how to build ' 52 | 'the flowers dataset using download_and_preprocess_flowers.\n') 53 | -------------------------------------------------------------------------------- /inception/inception/flowers_eval.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A binary to evaluate Inception on the flowers data set. 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | 22 | import tensorflow as tf 23 | 24 | from inception import inception_eval 25 | from inception.flowers_data import FlowersData 26 | 27 | FLAGS = tf.app.flags.FLAGS 28 | 29 | 30 | def main(unused_argv=None): 31 | dataset = FlowersData(subset=FLAGS.subset) 32 | assert dataset.data_files() 33 | if tf.gfile.Exists(FLAGS.eval_dir): 34 | tf.gfile.DeleteRecursively(FLAGS.eval_dir) 35 | tf.gfile.MakeDirs(FLAGS.eval_dir) 36 | inception_eval.evaluate(dataset) 37 | 38 | 39 | if __name__ == '__main__': 40 | tf.app.run() 41 | -------------------------------------------------------------------------------- /inception/inception/flowers_train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A binary to train Inception on the flowers data set. 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from inception import inception_train 26 | from inception.flowers_data import FlowersData 27 | 28 | FLAGS = tf.app.flags.FLAGS 29 | 30 | 31 | def main(_): 32 | dataset = FlowersData(subset=FLAGS.subset) 33 | assert dataset.data_files() 34 | if tf.gfile.Exists(FLAGS.train_dir): 35 | tf.gfile.DeleteRecursively(FLAGS.train_dir) 36 | tf.gfile.MakeDirs(FLAGS.train_dir) 37 | inception_train.train(dataset) 38 | 39 | 40 | if __name__ == '__main__': 41 | tf.app.run() 42 | -------------------------------------------------------------------------------- /inception/inception/imagenet_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Small library that points to the ImageNet data set. 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | 22 | 23 | from inception.dataset import Dataset 24 | 25 | 26 | class ImagenetData(Dataset): 27 | """ImageNet data set.""" 28 | 29 | def __init__(self, subset): 30 | super(ImagenetData, self).__init__('ImageNet', subset) 31 | 32 | def num_classes(self): 33 | """Returns the number of classes in the data set.""" 34 | return 1000 35 | 36 | def num_examples_per_epoch(self): 37 | """Returns the number of examples in the data set.""" 38 | # Bounding box data consists of 615299 bounding boxes for 544546 images. 39 | if self.subset == 'train': 40 | return 1281167 41 | if self.subset == 'validation': 42 | return 50000 43 | 44 | def download_message(self): 45 | """Instruction to download and extract the tarball from Flowers website.""" 46 | 47 | print('Failed to find any ImageNet %s files'% self.subset) 48 | print('') 49 | print('If you have already downloaded and processed the data, then make ' 50 | 'sure to set --data_dir to point to the directory containing the ' 51 | 'location of the sharded TFRecords.\n') 52 | print('If you have not downloaded and prepared the ImageNet data in the ' 53 | 'TFRecord format, you will need to do this at least once. This ' 54 | 'process could take several hours depending on the speed of your ' 55 | 'computer and network connection\n') 56 | print('Please see README.md for instructions on how to build ' 57 | 'the ImageNet dataset using download_and_preprocess_imagenet.\n') 58 | print('Note that the raw data size is 300 GB and the processed data size ' 59 | 'is 150 GB. 
Please ensure you have at least 500GB disk space.') 60 | -------------------------------------------------------------------------------- /inception/inception/imagenet_distributed_train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | # pylint: disable=line-too-long 16 | """A binary to train Inception in a distributed manner using multiple systems. 17 | 18 | Please see accompanying README.md for details and instructions. 19 | """ 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import tensorflow as tf 25 | 26 | from inception import inception_distributed_train 27 | from inception.imagenet_data import ImagenetData 28 | 29 | FLAGS = tf.app.flags.FLAGS 30 | 31 | 32 | def main(unused_args): 33 | assert FLAGS.job_name in ['ps', 'worker'], 'job_name must be ps or worker' 34 | 35 | # Extract all the hostnames for the ps and worker jobs to construct the 36 | # cluster spec. 
37 | ps_hosts = FLAGS.ps_hosts.split(',') 38 | worker_hosts = FLAGS.worker_hosts.split(',') 39 | tf.logging.info('PS hosts are: %s' % ps_hosts) 40 | tf.logging.info('Worker hosts are: %s' % worker_hosts) 41 | 42 | cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts, 43 | 'worker': worker_hosts}) 44 | server = tf.train.Server( 45 | {'ps': ps_hosts, 46 | 'worker': worker_hosts}, 47 | job_name=FLAGS.job_name, 48 | task_index=FLAGS.task_id) 49 | 50 | if FLAGS.job_name == 'ps': 51 | # `ps` jobs wait for incoming connections from the workers. 52 | server.join() 53 | else: 54 | # `worker` jobs will actually do the work. 55 | dataset = ImagenetData(subset=FLAGS.subset) 56 | assert dataset.data_files() 57 | # Only the chief checks for or creates train_dir. 58 | if FLAGS.task_id == 0: 59 | if not tf.gfile.Exists(FLAGS.train_dir): 60 | tf.gfile.MakeDirs(FLAGS.train_dir) 61 | inception_distributed_train.train(server.target, dataset, cluster_spec) 62 | 63 | if __name__ == '__main__': 64 | tf.logging.set_verbosity(tf.logging.INFO) 65 | tf.app.run() 66 | -------------------------------------------------------------------------------- /inception/inception/imagenet_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """A binary to evaluate Inception on the flowers data set. 16 | 17 | Note that using the supplied pre-trained inception checkpoint, the eval should 18 | achieve: 19 | precision @ 1 = 0.7874 recall @ 5 = 0.9436 [50000 examples] 20 | 21 | See the README.md for more details. 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | 28 | import tensorflow as tf 29 | 30 | from inception import inception_eval 31 | from inception.imagenet_data import ImagenetData 32 | 33 | FLAGS = tf.app.flags.FLAGS 34 | 35 | 36 | def main(unused_argv=None): 37 | dataset = ImagenetData(subset=FLAGS.subset) 38 | assert dataset.data_files() 39 | if tf.gfile.Exists(FLAGS.eval_dir): 40 | tf.gfile.DeleteRecursively(FLAGS.eval_dir) 41 | tf.gfile.MakeDirs(FLAGS.eval_dir) 42 | inception_eval.evaluate(dataset) 43 | 44 | 45 | if __name__ == '__main__': 46 | tf.app.run() 47 | -------------------------------------------------------------------------------- /inception/inception/imagenet_train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """A binary to train Inception on the ImageNet data set. 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from inception import inception_train 26 | from inception.imagenet_data import ImagenetData 27 | 28 | FLAGS = tf.app.flags.FLAGS 29 | 30 | 31 | def main(_): 32 | dataset = ImagenetData(subset=FLAGS.subset) 33 | assert dataset.data_files() 34 | if tf.gfile.Exists(FLAGS.train_dir): 35 | tf.gfile.DeleteRecursively(FLAGS.train_dir) 36 | tf.gfile.MakeDirs(FLAGS.train_dir) 37 | inception_train.train(dataset) 38 | 39 | 40 | if __name__ == '__main__': 41 | tf.app.run() 42 | -------------------------------------------------------------------------------- /inception/inception/slim/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # Contains the operations and nets for building TensorFlow-Slim models. 
3 | 4 | package(default_visibility = ["//inception:internal"]) 5 | 6 | licenses(["notice"]) # Apache 2.0 7 | 8 | exports_files(["LICENSE"]) 9 | 10 | py_library( 11 | name = "scopes", 12 | srcs = ["scopes.py"], 13 | ) 14 | 15 | py_test( 16 | name = "scopes_test", 17 | size = "small", 18 | srcs = ["scopes_test.py"], 19 | deps = [ 20 | ":scopes", 21 | ], 22 | ) 23 | 24 | py_library( 25 | name = "variables", 26 | srcs = ["variables.py"], 27 | deps = [ 28 | ":scopes", 29 | ], 30 | ) 31 | 32 | py_test( 33 | name = "variables_test", 34 | size = "small", 35 | srcs = ["variables_test.py"], 36 | deps = [ 37 | ":variables", 38 | ], 39 | ) 40 | 41 | py_library( 42 | name = "losses", 43 | srcs = ["losses.py"], 44 | ) 45 | 46 | py_test( 47 | name = "losses_test", 48 | size = "small", 49 | srcs = ["losses_test.py"], 50 | deps = [ 51 | ":losses", 52 | ], 53 | ) 54 | 55 | py_library( 56 | name = "ops", 57 | srcs = ["ops.py"], 58 | deps = [ 59 | ":losses", 60 | ":scopes", 61 | ":variables", 62 | ], 63 | ) 64 | 65 | py_test( 66 | name = "ops_test", 67 | size = "small", 68 | srcs = ["ops_test.py"], 69 | deps = [ 70 | ":ops", 71 | ":variables", 72 | ], 73 | ) 74 | 75 | py_library( 76 | name = "inception", 77 | srcs = ["inception_model.py"], 78 | deps = [ 79 | ":ops", 80 | ":scopes", 81 | ], 82 | ) 83 | 84 | py_test( 85 | name = "inception_test", 86 | size = "medium", 87 | srcs = ["inception_test.py"], 88 | deps = [ 89 | ":inception", 90 | ], 91 | ) 92 | 93 | py_library( 94 | name = "slim", 95 | srcs = ["slim.py"], 96 | deps = [ 97 | ":inception", 98 | ":losses", 99 | ":ops", 100 | ":scopes", 101 | ":variables", 102 | ], 103 | ) 104 | 105 | py_test( 106 | name = "collections_test", 107 | size = "small", 108 | srcs = ["collections_test.py"], 109 | deps = [ 110 | ":slim", 111 | ], 112 | ) 113 | -------------------------------------------------------------------------------- /inception/inception/slim/slim.py: -------------------------------------------------------------------------------- 
1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """TF-Slim grouped API. Please see README.md for details and usage.""" 16 | # pylint: disable=unused-import 17 | 18 | # Collapse tf-slim into a single namespace. 19 | from inception.slim import inception_model as inception 20 | from inception.slim import losses 21 | from inception.slim import ops 22 | from inception.slim import scopes 23 | from inception.slim import variables 24 | from inception.slim.scopes import arg_scope 25 | -------------------------------------------------------------------------------- /lm_1b/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = [":internal"]) 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | exports_files(["LICENSE"]) 6 | 7 | package_group( 8 | name = "internal", 9 | packages = [ 10 | "//lm_1b/...", 11 | ], 12 | ) 13 | 14 | py_library( 15 | name = "data_utils", 16 | srcs = ["data_utils.py"], 17 | ) 18 | 19 | py_binary( 20 | name = "lm_1b_eval", 21 | srcs = [ 22 | "lm_1b_eval.py", 23 | ], 24 | deps = [ 25 | ":data_utils", 26 | ], 27 | ) 28 | -------------------------------------------------------------------------------- /namignizer/.gitignore: -------------------------------------------------------------------------------- 1 | # 
Remove the pyc files 2 | *.pyc 3 | 4 | # Ignore the model and the data 5 | model/ 6 | data/ 7 | -------------------------------------------------------------------------------- /namignizer/README.md: -------------------------------------------------------------------------------- 1 | # Namignizer 2 | 3 | Use a variation of the [PTB](https://www.tensorflow.org/versions/r0.8/tutorials/recurrent/index.html#recurrent-neural-networks) model to recognize and generate names using the [Kaggle Baby Name Database](https://www.kaggle.com/kaggle/us-baby-names). 4 | 5 | ### API 6 | Namignizer is implemented in Tensorflow 0.8r and uses the python package `pandas` for some data processing. 7 | 8 | #### How to use 9 | Download the data from Kaggle and place it in your data directory (or use the small training data provided). The example data looks like so: 10 | 11 | ``` 12 | Id,Name,Year,Gender,Count 13 | 1,Mary,1880,F,7065 14 | 2,Anna,1880,F,2604 15 | 3,Emma,1880,F,2003 16 | 4,Elizabeth,1880,F,1939 17 | 5,Minnie,1880,F,1746 18 | 6,Margaret,1880,F,1578 19 | 7,Ida,1880,F,1472 20 | 8,Alice,1880,F,1414 21 | 9,Bertha,1880,F,1320 22 | ``` 23 | 24 | But any data with the two columns: `Name` and `Count` will work. 25 | 26 | With the data, we can then train the model: 27 | 28 | ```python 29 | train("data/SmallNames.txt", "model/namignizer", SmallConfig) 30 | ``` 31 | 32 | And you will get the output: 33 | 34 | ``` 35 | Reading Name data in data/SmallNames.txt 36 | Epoch: 1 Learning rate: 1.000 37 | 0.090 perplexity: 18.539 speed: 282 lps 38 | ... 39 | 0.890 perplexity: 1.478 speed: 285 lps 40 | 0.990 perplexity: 1.477 speed: 284 lps 41 | Epoch: 13 Train Perplexity: 1.477 42 | ``` 43 | 44 | This will as a side effect write model checkpoints to the `model` directory. 
With this you will be able to determine the perplexity your model will give you for any arbitrary set of names like so: 45 | 46 | ```python 47 | namignize(["mary", "ida", "gazorpazorp", "houyhnhnms", "bob"], 48 | tf.train.latest_checkpoint("model"), SmallConfig) 49 | ``` 50 | You will provide the same config and the same checkpoint directory. This will allow you to use a the model you just trained. You will then get a perplexity output for each name like so: 51 | 52 | ``` 53 | Name mary gives us a perplexity of 1.03105580807 54 | Name ida gives us a perplexity of 1.07770049572 55 | Name gazorpazorp gives us a perplexity of 175.940353394 56 | Name houyhnhnms gives us a perplexity of 9.53870773315 57 | Name bob gives us a perplexity of 6.03938627243 58 | ``` 59 | 60 | Finally, you will also be able generate names using the model like so: 61 | 62 | ```python 63 | namignator(tf.train.latest_checkpoint("model"), SmallConfig) 64 | ``` 65 | 66 | Again, you will need to provide the same config and the same checkpoint directory. This will allow you to use a the model you just trained. You will then get a single generated name. Examples of output that I got when using the provided data are: 67 | 68 | ``` 69 | ['b', 'e', 'r', 't', 'h', 'a', '`'] 70 | ['m', 'a', 'r', 'y', '`'] 71 | ['a', 'n', 'n', 'a', '`'] 72 | ['m', 'a', 'r', 'y', '`'] 73 | ['b', 'e', 'r', 't', 'h', 'a', '`'] 74 | ['a', 'n', 'n', 'a', '`'] 75 | ['e', 'l', 'i', 'z', 'a', 'b', 'e', 't', 'h', '`'] 76 | ``` 77 | 78 | Notice that each name ends with a backtick. This marks the end of the name. 79 | 80 | ### Contact Info 81 | 82 | Feel free to reach out to me at knt(at google) or k.nathaniel.tucker(at gmail) 83 | -------------------------------------------------------------------------------- /neural_gpu/README.md: -------------------------------------------------------------------------------- 1 | # NeuralGPU 2 | Code for the Neural GPU model as described 3 | in [[http://arxiv.org/abs/1511.08228]]. 
4 | 5 | Requirements: 6 | * TensorFlow (see tensorflow.org for how to install) 7 | * Matplotlib for Python (sudo apt-get install python-matplotlib) 8 | 9 | The model can be trained on the following algorithmic tasks: 10 | 11 | * `sort` - Sort a symbol list 12 | * `kvsort` - Sort symbol keys in dictionary 13 | * `id` - Return the same symbol list 14 | * `rev` - Reverse a symbol list 15 | * `rev2` - Reverse a symbol dictionary by key 16 | * `incr` - Add one to a symbol value 17 | * `add` - Long decimal addition 18 | * `left` - First symbol in list 19 | * `right` - Last symbol in list 20 | * `left-shift` - Left shift a symbol list 21 | * `right-shift` - Right shift a symbol list 22 | * `bmul` - Long binary multiplication 23 | * `mul` - Long decimal multiplication 24 | * `dup` - Duplicate a symbol list with padding 25 | * `badd` - Long binary addition 26 | * `qadd` - Long quaternary addition 27 | * `search` - Search for symbol key in dictionary 28 | 29 | The value range for symbols are defined by the `niclass` and `noclass` flags. 30 | In particular, the values are in the range `min(--niclass, noclass) - 1`. 31 | So if you set `--niclass=33` and `--noclass=33` (the default) then `--task=rev` 32 | will be reversing lists of 32 symbols, and `--task=id` will be identity on a 33 | list of up to 32 symbols. 34 | 35 | 36 | To train the model on the reverse task run: 37 | 38 | ``` 39 | python neural_gpu_trainer.py --task=rev 40 | ``` 41 | 42 | While training, interim / checkpoint model parameters will be 43 | written to `/tmp/neural_gpu/`. 44 | 45 | Once the amount of error gets down to what you're comfortable 46 | with, hit `Ctrl-C` to stop the training process. The latest 47 | model parameters will be in `/tmp/neural_gpu/neural_gpu.ckpt-` 48 | and used on any subsequent run. 
49 | 50 | To test a trained model on how well it decodes run: 51 | 52 | ``` 53 | python neural_gpu_trainer.py --task=rev --mode=1 54 | ``` 55 | 56 | To produce an animation of the result run: 57 | 58 | ``` 59 | python neural_gpu_trainer.py --task=rev --mode=1 --animate=True 60 | ``` 61 | 62 | Maintained by Lukasz Kaiser (lukaszkaiser) 63 | -------------------------------------------------------------------------------- /privacy/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import numpy as np 21 | 22 | 23 | def accuracy(logits, labels): 24 | """ 25 | Return accuracy of the array of logits (or label predictions) wrt the labels 26 | :param logits: this can either be logits, probabilities, or a single label 27 | :param labels: the correct labels to match against 28 | :return: the accuracy as a float 29 | """ 30 | assert len(logits) == len(labels) 31 | 32 | if len(np.shape(logits)) > 1: 33 | # Predicted labels are the argmax over axis 1 34 | predicted_labels = np.argmax(logits, axis=1) 35 | else: 36 | # Input was already labels 37 | assert len(np.shape(logits)) == 1 38 | predicted_labels = logits 39 | 40 | # Check against correct labels to compute correct guesses 41 | correct = np.sum(predicted_labels == labels.reshape(len(labels))) 42 | 43 | # Divide by number of labels to obtain accuracy 44 | accuracy = float(correct) / len(labels) 45 | 46 | # Return float value 47 | return accuracy 48 | 49 | 50 | -------------------------------------------------------------------------------- /privacy/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | 17 | def batch_indices(batch_nb, data_length, batch_size): 18 | """ 19 | This helper function computes a batch start and end index 20 | :param batch_nb: the batch number 21 | :param data_length: the total length of the data being parsed by batches 22 | :param batch_size: the number of inputs in each batch 23 | :return: pair of (start, end) indices 24 | """ 25 | # Batch start and end index 26 | start = int(batch_nb * batch_size) 27 | end = int((batch_nb + 1) * batch_size) 28 | 29 | # When there are not enough inputs left, we reuse some to complete the batch 30 | if end > data_length: 31 | shift = end - data_length 32 | start -= shift 33 | end -= shift 34 | 35 | return start, end 36 | -------------------------------------------------------------------------------- /resnet/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = [":internal"]) 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | exports_files(["LICENSE"]) 6 | 7 | package_group( 8 | name = "internal", 9 | packages = [ 10 | "//resnet/...", 11 | ], 12 | ) 13 | 14 | filegroup( 15 | name = "py_srcs", 16 | data = glob([ 17 | "**/*.py", 18 | ]), 19 | ) 20 | 21 | py_library( 22 | name = "resnet_model", 23 | srcs = ["resnet_model.py"], 24 | ) 25 | 26 | py_binary( 27 | name = "resnet_main", 28 | srcs = [ 29 | "resnet_main.py", 30 | ], 31 | deps = [ 32 | ":cifar_input", 33 | ":resnet_model", 34 | ], 35 | ) 36 | 37 | py_library( 38 | name = "cifar_input", 39 | srcs = ["cifar_input.py"], 40 | ) 41 | -------------------------------------------------------------------------------- /resnet/README.md: -------------------------------------------------------------------------------- 1 | Reproduced ResNet on CIFAR-10 and CIFAR-100 dataset. 
2 | 3 | contact: panyx0718 (xpan@google.com) 4 | 5 | Dataset: 6 | 7 | https://www.cs.toronto.edu/~kriz/cifar.html 8 | 9 | Related papers: 10 | 11 | Identity Mappings in Deep Residual Networks 12 | 13 | https://arxiv.org/pdf/1603.05027v2.pdf 14 | 15 | Deep Residual Learning for Image Recognition 16 | 17 | https://arxiv.org/pdf/1512.03385v1.pdf 18 | 19 | Wide Residual Networks 20 | 21 | https://arxiv.org/pdf/1605.07146v1.pdf 22 | 23 | Settings: 24 | 25 | * Random split 50k training set into 45k/5k train/eval split. 26 | * Pad to 36x36 and random crop. Horizontal flip. Per-image whitenting. 27 | * Momentum optimizer 0.9. 28 | * Learning rate schedule: 0.1 (40k), 0.01 (60k), 0.001 (>60k). 29 | * L2 weight decay: 0.002. 30 | * Batch size: 128. (28-10 wide and 1001 layer bottleneck use 64) 31 | 32 | Results: 33 | 34 | 35 | ![Precisions](g3doc/cifar_resnet.gif) 36 | 37 | 38 | ![Precisions Legends](g3doc/cifar_resnet_legends.gif) 39 | 40 | 41 | 42 | CIFAR-10 Model|Best Precision|Steps 43 | --------------|--------------|------ 44 | 32 layer|92.5%|~80k 45 | 110 layer|93.6%|~80k 46 | 164 layer bottleneck|94.5%|~80k 47 | 1001 layer bottleneck|94.9%|~80k 48 | 28-10 wide|95%|~90k 49 | 50 | CIFAR-100 Model|Best Precision|Steps 51 | ---------------|--------------|----- 52 | 32 layer|68.1%|~45k 53 | 110 layer|71.3%|~60k 54 | 164 layer bottleneck|75.7%|~50k 55 | 1001 layer bottleneck|78.2%|~70k 56 | 28-10 wide|78.3%|~70k 57 | 58 | Prerequisite: 59 | 60 | 1. Install TensorFlow, Bazel. 61 | 62 | 2. Download CIFAR-10/CIFAR-100 dataset. 63 | 64 | ```shell 65 | curl -o cifar-10-binary.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz 66 | curl -o cifar-100-binary.tar.gz https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz 67 | ``` 68 | 69 | How to run: 70 | 71 | ```shell 72 | # cd to the your workspace. 73 | # It contains an empty WORKSPACE file, resnet codes and cifar10 dataset. 
74 | ls -R 75 | .: 76 | cifar10 resnet WORKSPACE 77 | 78 | ./cifar10: 79 | test.bin train.bin validation.bin 80 | 81 | ./resnet: 82 | BUILD cifar_input.py g3doc README.md resnet_main.py resnet_model.py 83 | 84 | # Build everything for GPU. 85 | bazel build -c opt --config=cuda resnet/... 86 | 87 | # Train the model. 88 | bazel-bin/resnet/resnet_main --train_data_path=cifar10/train.bin \ 89 | --log_root=/tmp/resnet_model \ 90 | --train_dir=/tmp/resnet_model/train \ 91 | --dataset='cifar10' \ 92 | --num_gpus=1 93 | 94 | # Evaluate the model. 95 | # Avoid running on the same GPU as the training job at the same time, 96 | # otherwise, you might run out of memory. 97 | bazel-bin/resnet/resnet_main --eval_data_path=cifar10/test.bin \ 98 | --log_root=/tmp/resnet_model \ 99 | --eval_dir=/tmp/resnet_model/test \ 100 | --mode=eval \ 101 | --dataset='cifar10' \ 102 | --num_gpus=0 103 | ``` 104 | -------------------------------------------------------------------------------- /resnet/g3doc/cifar_resnet.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/resnet/g3doc/cifar_resnet.gif -------------------------------------------------------------------------------- /resnet/g3doc/cifar_resnet_legends.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/resnet/g3doc/cifar_resnet_legends.gif -------------------------------------------------------------------------------- /slim/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /slim/datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow 
Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A factory-pattern class which returns classification image/label pairs.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from datasets import cifar10 22 | from datasets import flowers 23 | from datasets import imagenet 24 | from datasets import mnist 25 | 26 | datasets_map = { 27 | 'cifar10': cifar10, 28 | 'flowers': flowers, 29 | 'imagenet': imagenet, 30 | 'mnist': mnist, 31 | } 32 | 33 | 34 | def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None): 35 | """Given a dataset name and a split_name returns a Dataset. 36 | 37 | Args: 38 | name: String, the name of the dataset. 39 | split_name: A train/test split name. 40 | dataset_dir: The directory where the dataset files are stored. 41 | file_pattern: The file pattern to use for matching the dataset source files. 42 | reader: The subclass of tf.ReaderBase. If left as `None`, then the default 43 | reader defined by each dataset is used. 44 | 45 | Returns: 46 | A `Dataset` class. 47 | 48 | Raises: 49 | ValueError: If the dataset `name` is unknown. 
50 | """ 51 | if name not in datasets_map: 52 | raise ValueError('Name of dataset unknown %s' % name) 53 | return datasets_map[name].get_split( 54 | split_name, 55 | dataset_dir, 56 | file_pattern, 57 | reader) 58 | -------------------------------------------------------------------------------- /slim/deployment/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /slim/download_and_convert_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Downloads and converts a particular dataset. 
16 | 17 | Usage: 18 | ```shell 19 | 20 | $ python download_and_convert_data.py \ 21 | --dataset_name=mnist \ 22 | --dataset_dir=/tmp/mnist 23 | 24 | $ python download_and_convert_data.py \ 25 | --dataset_name=cifar10 \ 26 | --dataset_dir=/tmp/cifar10 27 | 28 | $ python download_and_convert_data.py \ 29 | --dataset_name=flowers \ 30 | --dataset_dir=/tmp/flowers 31 | ``` 32 | """ 33 | from __future__ import absolute_import 34 | from __future__ import division 35 | from __future__ import print_function 36 | 37 | import tensorflow as tf 38 | 39 | from datasets import download_and_convert_cifar10 40 | from datasets import download_and_convert_flowers 41 | from datasets import download_and_convert_mnist 42 | 43 | FLAGS = tf.app.flags.FLAGS 44 | 45 | tf.app.flags.DEFINE_string( 46 | 'dataset_name', 47 | None, 48 | 'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".') 49 | 50 | tf.app.flags.DEFINE_string( 51 | 'dataset_dir', 52 | None, 53 | 'The directory where the output TFRecords and temporary files are saved.') 54 | 55 | 56 | def main(_): 57 | if not FLAGS.dataset_name: 58 | raise ValueError('You must supply the dataset name with --dataset_name') 59 | if not FLAGS.dataset_dir: 60 | raise ValueError('You must supply the dataset directory with --dataset_dir') 61 | 62 | if FLAGS.dataset_name == 'cifar10': 63 | download_and_convert_cifar10.run(FLAGS.dataset_dir) 64 | elif FLAGS.dataset_name == 'flowers': 65 | download_and_convert_flowers.run(FLAGS.dataset_dir) 66 | elif FLAGS.dataset_name == 'mnist': 67 | download_and_convert_mnist.run(FLAGS.dataset_dir) 68 | else: 69 | raise ValueError( 70 | 'dataset_name [%s] was not recognized.' 
% FLAGS.dataset_dir) 71 | 72 | if __name__ == '__main__': 73 | tf.app.run() 74 | 75 | -------------------------------------------------------------------------------- /slim/nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /slim/nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Brings inception_v1, inception_v2 and inception_v3 under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_v1 import inception_v1 25 | from nets.inception_v1 import inception_v1_arg_scope 26 | from nets.inception_v1 import inception_v1_base 27 | from nets.inception_v2 import inception_v2 28 | from nets.inception_v2 import inception_v2_arg_scope 29 | from nets.inception_v2 import inception_v2_base 30 | from nets.inception_v3 import inception_v3 31 | from nets.inception_v3 import inception_v3_arg_scope 32 | from nets.inception_v3 import inception_v3_base 33 | # pylint: enable=unused-import 34 | -------------------------------------------------------------------------------- /slim/nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for slim.inception.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from nets import nets_factory 26 | 27 | 28 | class NetworksTest(tf.test.TestCase): 29 | 30 | def testGetNetworkFn(self): 31 | batch_size = 5 32 | num_classes = 1000 33 | for net in nets_factory.networks_map: 34 | with self.test_session(): 35 | net_fn = nets_factory.get_network_fn(net, num_classes) 36 | # Most networks use 224 as their default_image_size 37 | image_size = getattr(net_fn, 'default_image_size', 224) 38 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 39 | logits, end_points = net_fn(inputs) 40 | self.assertTrue(isinstance(logits, tf.Tensor)) 41 | self.assertTrue(isinstance(end_points, dict)) 42 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 43 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 44 | 45 | if __name__ == '__main__': 46 | tf.test.main() 47 | -------------------------------------------------------------------------------- /slim/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /slim/preprocessing/lenet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities for preprocessing.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def preprocess_image(image, output_height, output_width, is_training): 27 | """Preprocesses the given image. 28 | 29 | Args: 30 | image: A `Tensor` representing an image of arbitrary size. 31 | output_height: The height of the image after preprocessing. 32 | output_width: The width of the image after preprocessing. 33 | is_training: `True` if we're preprocessing the image for training and 34 | `False` otherwise. 35 | 36 | Returns: 37 | A preprocessed image. 38 | """ 39 | image = tf.to_float(image) 40 | image = tf.image.resize_image_with_crop_or_pad( 41 | image, output_width, output_height) 42 | image = tf.sub(image, 128.0) 43 | image = tf.div(image, 128.0) 44 | return image 45 | -------------------------------------------------------------------------------- /slim/preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building image preprocessing functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import cifarnet_preprocessing 24 | from preprocessing import inception_preprocessing 25 | from preprocessing import lenet_preprocessing 26 | from preprocessing import vgg_preprocessing 27 | 28 | slim = tf.contrib.slim 29 | 30 | 31 | def get_preprocessing(name, is_training=False): 32 | """Returns preprocessing_fn(image, height, width, **kwargs). 33 | 34 | Args: 35 | name: The name of the preprocessing function, e.g. 'inception_v3' or 'vgg_16'. 36 | is_training: `True` if the model is being used for training and `False` 37 | otherwise. The value is bound into the returned function. 38 | 39 | Returns: 40 | preprocessing_fn: A function that preprocesses a single image (pre-batch). 41 | It has the following signature: 42 | image = preprocessing_fn(image, output_height, output_width, ...). 43 | 44 | Raises: 45 | ValueError: If preprocessing `name` is not recognized.
46 | """ 47 | preprocessing_fn_map = { 48 | 'cifarnet': cifarnet_preprocessing, 49 | 'inception': inception_preprocessing, 50 | 'inception_v1': inception_preprocessing, 51 | 'inception_v2': inception_preprocessing, 52 | 'inception_v3': inception_preprocessing, 53 | 'inception_resnet_v2': inception_preprocessing, 54 | 'lenet': lenet_preprocessing, 55 | 'resnet_v1_50': vgg_preprocessing, 56 | 'resnet_v1_101': vgg_preprocessing, 57 | 'resnet_v1_152': vgg_preprocessing, 58 | 'vgg': vgg_preprocessing, 59 | 'vgg_a': vgg_preprocessing, 60 | 'vgg_16': vgg_preprocessing, 61 | 'vgg_19': vgg_preprocessing, 62 | } 63 | 64 | if name not in preprocessing_fn_map: 65 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 66 | 67 | def preprocessing_fn(image, output_height, output_width, **kwargs):  # Closes over `name` and `is_training`. 68 | return preprocessing_fn_map[name].preprocess_image( 69 | image, output_height, output_width, is_training=is_training, **kwargs) 70 | 71 | return preprocessing_fn 72 | -------------------------------------------------------------------------------- /slim/scripts/finetune_inception_v1_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script performs the following operations: 4 | # 1. Downloads the Flowers dataset 5 | # 2. Fine-tunes an InceptionV1 model on the Flowers training set. 6 | # 3. Evaluates the model on the Flowers validation set. 7 | # 8 | # Usage: 9 | # cd slim 10 | # ./scripts/finetune_inception_v1_on_flowers.sh 11 | 12 | # Where the pre-trained InceptionV1 checkpoint is saved to. 13 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 14 | 15 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 16 | TRAIN_DIR=/tmp/flowers-models/inception_v1 17 | 18 | # Where the dataset is saved to. 19 | DATASET_DIR=/tmp/flowers 20 | 21 | # Download the pre-trained checkpoint. 22 | if [ !
-d "$PRETRAINED_CHECKPOINT_DIR" ]; then 23 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 24 | fi 25 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt ]; then 26 | wget http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz 27 | tar -xvf inception_v1_2016_08_28.tar.gz 28 | mv inception_v1.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt 29 | rm inception_v1_2016_08_28.tar.gz 30 | fi 31 | 32 | # Download the dataset 33 | python download_and_convert_data.py \ 34 | --dataset_name=flowers \ 35 | --dataset_dir=${DATASET_DIR} 36 | 37 | # Fine-tune only the new layers for 3000 steps. 38 | python train_image_classifier.py \ 39 | --train_dir=${TRAIN_DIR} \ 40 | --dataset_name=flowers \ 41 | --dataset_split_name=train \ 42 | --dataset_dir=${DATASET_DIR} \ 43 | --model_name=inception_v1 \ 44 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt \ 45 | --checkpoint_exclude_scopes=InceptionV1/Logits \ 46 | --trainable_scopes=InceptionV1/Logits \ 47 | --max_number_of_steps=3000 \ 48 | --batch_size=32 \ 49 | --learning_rate=0.01 \ 50 | --save_interval_secs=60 \ 51 | --save_summaries_secs=60 \ 52 | --log_every_n_steps=100 \ 53 | --optimizer=rmsprop \ 54 | --weight_decay=0.00004 55 | 56 | # Run evaluation. 57 | python eval_image_classifier.py \ 58 | --checkpoint_path=${TRAIN_DIR} \ 59 | --eval_dir=${TRAIN_DIR} \ 60 | --dataset_name=flowers \ 61 | --dataset_split_name=validation \ 62 | --dataset_dir=${DATASET_DIR} \ 63 | --model_name=inception_v1 64 | 65 | # Fine-tune all the layers for 1000 steps.
66 | python train_image_classifier.py \ 67 | --train_dir=${TRAIN_DIR}/all \ 68 | --dataset_name=flowers \ 69 | --dataset_split_name=train \ 70 | --dataset_dir=${DATASET_DIR} \ 71 | --checkpoint_path=${TRAIN_DIR} \ 72 | --model_name=inception_v1 \ 73 | --max_number_of_steps=1000 \ 74 | --batch_size=32 \ 75 | --learning_rate=0.001 \ 76 | --save_interval_secs=60 \ 77 | --save_summaries_secs=60 \ 78 | --log_every_n_steps=100 \ 79 | --optimizer=rmsprop \ 80 | --weight_decay=0.00004 81 | 82 | # Run evaluation. 83 | python eval_image_classifier.py \ 84 | --checkpoint_path=${TRAIN_DIR}/all \ 85 | --eval_dir=${TRAIN_DIR}/all \ 86 | --dataset_name=flowers \ 87 | --dataset_split_name=validation \ 88 | --dataset_dir=${DATASET_DIR} \ 89 | --model_name=inception_v1 90 | -------------------------------------------------------------------------------- /slim/scripts/finetune_inception_v3_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script performs the following operations: 4 | # 1. Downloads the Flowers dataset 5 | # 2. Fine-tunes an InceptionV3 model on the Flowers training set. 6 | # 3. Evaluates the model on the Flowers validation set. 7 | # 8 | # Usage: 9 | # cd slim 10 | # ./scripts/finetune_inception_v3_on_flowers.sh 11 | 12 | # Where the pre-trained InceptionV3 checkpoint is saved to. 13 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 14 | 15 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 16 | TRAIN_DIR=/tmp/flowers-models/inception_v3 17 | 18 | # Where the dataset is saved to. 19 | DATASET_DIR=/tmp/flowers 20 | 21 | # Download the pre-trained checkpoint. 22 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 23 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 24 | fi 25 | if [ !
-f ${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt ]; then 26 | wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz 27 | tar -xvf inception_v3_2016_08_28.tar.gz 28 | mv inception_v3.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt 29 | rm inception_v3_2016_08_28.tar.gz 30 | fi 31 | 32 | # Download the dataset 33 | python download_and_convert_data.py \ 34 | --dataset_name=flowers \ 35 | --dataset_dir=${DATASET_DIR} 36 | 37 | # Fine-tune only the new layers for 1000 steps. 38 | python train_image_classifier.py \ 39 | --train_dir=${TRAIN_DIR} \ 40 | --dataset_name=flowers \ 41 | --dataset_split_name=train \ 42 | --dataset_dir=${DATASET_DIR} \ 43 | --model_name=inception_v3 \ 44 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt \ 45 | --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \ 46 | --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \ 47 | --max_number_of_steps=1000 \ 48 | --batch_size=32 \ 49 | --learning_rate=0.01 \ 50 | --learning_rate_decay_type=fixed \ 51 | --save_interval_secs=60 \ 52 | --save_summaries_secs=60 \ 53 | --log_every_n_steps=100 \ 54 | --optimizer=rmsprop \ 55 | --weight_decay=0.00004 56 | 57 | # Run evaluation. 58 | python eval_image_classifier.py \ 59 | --checkpoint_path=${TRAIN_DIR} \ 60 | --eval_dir=${TRAIN_DIR} \ 61 | --dataset_name=flowers \ 62 | --dataset_split_name=validation \ 63 | --dataset_dir=${DATASET_DIR} \ 64 | --model_name=inception_v3 65 | 66 | # Fine-tune all the layers for 500 steps.
67 | python train_image_classifier.py \ 68 | --train_dir=${TRAIN_DIR}/all \ 69 | --dataset_name=flowers \ 70 | --dataset_split_name=train \ 71 | --dataset_dir=${DATASET_DIR} \ 72 | --model_name=inception_v3 \ 73 | --checkpoint_path=${TRAIN_DIR} \ 74 | --max_number_of_steps=500 \ 75 | --batch_size=32 \ 76 | --learning_rate=0.0001 \ 77 | --learning_rate_decay_type=fixed \ 78 | --save_interval_secs=60 \ 79 | --save_summaries_secs=60 \ 80 | --log_every_n_steps=10 \ 81 | --optimizer=rmsprop \ 82 | --weight_decay=0.00004 83 | 84 | # Run evaluation. 85 | python eval_image_classifier.py \ 86 | --checkpoint_path=${TRAIN_DIR}/all \ 87 | --eval_dir=${TRAIN_DIR}/all \ 88 | --dataset_name=flowers \ 89 | --dataset_split_name=validation \ 90 | --dataset_dir=${DATASET_DIR} \ 91 | --model_name=inception_v3 92 | -------------------------------------------------------------------------------- /slim/scripts/train_cifarnet_on_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script performs the following operations: 4 | # 1. Downloads the Cifar10 dataset 5 | # 2. Trains a CifarNet model on the Cifar10 training set. 6 | # 3. Evaluates the model on the Cifar10 testing set. 7 | # 8 | # Usage: 9 | # cd slim 10 | # ./scripts/train_cifarnet_on_cifar10.sh 11 | 12 | # Where the checkpoint and logs will be saved to. 13 | TRAIN_DIR=/tmp/cifarnet-model 14 | 15 | # Where the dataset is saved to. 16 | DATASET_DIR=/tmp/cifar10 17 | 18 | # Download the dataset 19 | python download_and_convert_data.py \ 20 | --dataset_name=cifar10 \ 21 | --dataset_dir=${DATASET_DIR} 22 | 23 | # Run training.
24 | python train_image_classifier.py \ 25 | --train_dir=${TRAIN_DIR} \ 26 | --dataset_name=cifar10 \ 27 | --dataset_split_name=train \ 28 | --dataset_dir=${DATASET_DIR} \ 29 | --model_name=cifarnet \ 30 | --preprocessing_name=cifarnet \ 31 | --max_number_of_steps=100000 \ 32 | --batch_size=128 \ 33 | --save_interval_secs=120 \ 34 | --save_summaries_secs=120 \ 35 | --log_every_n_steps=100 \ 36 | --optimizer=sgd \ 37 | --learning_rate=0.1 \ 38 | --learning_rate_decay_factor=0.1 \ 39 | --num_epochs_per_decay=200 \ 40 | --weight_decay=0.004 41 | 42 | # Run evaluation. 43 | python eval_image_classifier.py \ 44 | --checkpoint_path=${TRAIN_DIR} \ 45 | --eval_dir=${TRAIN_DIR} \ 46 | --dataset_name=cifar10 \ 47 | --dataset_split_name=test \ 48 | --dataset_dir=${DATASET_DIR} \ 49 | --model_name=cifarnet 50 | -------------------------------------------------------------------------------- /slim/scripts/train_lenet_on_mnist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script performs the following operations: 4 | # 1. Downloads the MNIST dataset 5 | # 2. Trains a LeNet model on the MNIST training set. 6 | # 3. Evaluates the model on the MNIST testing set. 7 | # 8 | # Usage: 9 | # cd slim 10 | # ./scripts/train_lenet_on_mnist.sh 11 | 12 | # Where the checkpoint and logs will be saved to. 13 | TRAIN_DIR=/tmp/lenet-model 14 | 15 | # Where the dataset is saved to. 16 | DATASET_DIR=/tmp/mnist 17 | 18 | # Download the dataset 19 | python download_and_convert_data.py \ 20 | --dataset_name=mnist \ 21 | --dataset_dir=${DATASET_DIR} 22 | 23 | # Run training.
24 | python train_image_classifier.py \ 25 | --train_dir=${TRAIN_DIR} \ 26 | --dataset_name=mnist \ 27 | --dataset_split_name=train \ 28 | --dataset_dir=${DATASET_DIR} \ 29 | --model_name=lenet \ 30 | --preprocessing_name=lenet \ 31 | --max_number_of_steps=20000 \ 32 | --batch_size=50 \ 33 | --learning_rate=0.01 \ 34 | --save_interval_secs=60 \ 35 | --save_summaries_secs=60 \ 36 | --log_every_n_steps=100 \ 37 | --optimizer=sgd \ 38 | --learning_rate_decay_type=fixed \ 39 | --weight_decay=0 40 | 41 | # Run evaluation. 42 | python eval_image_classifier.py \ 43 | --checkpoint_path=${TRAIN_DIR} \ 44 | --eval_dir=${TRAIN_DIR} \ 45 | --dataset_name=mnist \ 46 | --dataset_split_name=test \ 47 | --dataset_dir=${DATASET_DIR} \ 48 | --model_name=lenet 49 | -------------------------------------------------------------------------------- /swivel/.gitignore: -------------------------------------------------------------------------------- 1 | *.an.tab 2 | *.pyc 3 | *.ws.tab 4 | MEN.tar.gz 5 | Mtruk.csv 6 | SimLex-999.zip 7 | analogy 8 | fastprep 9 | myz_naacl13_test_set.tgz 10 | questions-words.txt 11 | rw.zip 12 | ws353simrel.tar.gz 13 | -------------------------------------------------------------------------------- /swivel/eval.mk: -------------------------------------------------------------------------------- 1 | # -*- Mode: Makefile -*- 2 | # 3 | # Copyright 2016 Google Inc. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # This makefile pulls down the evaluation datasets and formats them uniformly. 18 | # Word similarity evaluations are formatted to contain exactly three columns: 19 | # the two words being compared and the human judgement. 20 | # 21 | # Use wordsim.py and analogy to run the actual evaluations. 22 | 23 | CXXFLAGS=-std=c++11 -m64 -mavx -g -Ofast -Wall 24 | LDLIBS=-lpthread -lm 25 | 26 | WORDSIM_EVALS= ws353sim.ws.tab \ 27 | ws353rel.ws.tab \ 28 | men.ws.tab \ 29 | mturk.ws.tab \ 30 | rarewords.ws.tab \ 31 | simlex999.ws.tab \ 32 | $(NULL) 33 | 34 | ANALOGY_EVALS= mikolov.an.tab \ 35 | msr.an.tab \ 36 | $(NULL) 37 | 38 | all: $(WORDSIM_EVALS) $(ANALOGY_EVALS) analogy 39 | 40 | ws353sim.ws.tab: ws353simrel.tar.gz 41 | tar Oxfz $^ wordsim353_sim_rel/wordsim_similarity_goldstandard.txt > $@ 42 | 43 | ws353rel.ws.tab: ws353simrel.tar.gz 44 | tar Oxfz $^ wordsim353_sim_rel/wordsim_relatedness_goldstandard.txt > $@ 45 | 46 | men.ws.tab: MEN.tar.gz 47 | tar Oxfz $^ MEN/MEN_dataset_natural_form_full | tr ' ' '\t' > $@ 48 | 49 | mturk.ws.tab: Mtruk.csv 50 | cat $^ | tr -d '\r' | tr ',' '\t' > $@ 51 | # cut splits fields on TAB by default, so no bash-only $'\t' delimiter (unsupported by some /bin/sh) is needed. 52 | rarewords.ws.tab: rw.zip 53 | unzip -p $^ rw/rw.txt | cut -f1-3 > $@ 54 | 55 | simlex999.ws.tab: SimLex-999.zip 56 | unzip -p $^ SimLex-999/SimLex-999.txt \ 57 | | tail -n +2 | cut -f1,2,4 > $@ 58 | 59 | mikolov.an.tab: questions-words.txt 60 | grep -v '^:' $^ | tr '[A-Z] ' '[a-z]\t' > $@ 61 | # Scratch files are named after the target rather than fixed /tmp paths, so concurrent builds cannot clobber them. 62 | msr.an.tab: myz_naacl13_test_set.tgz 63 | tar Oxfz $^ test_set/word_relationship.questions | tr ' ' '\t' > $@.q 64 | tar Oxfz $^ test_set/word_relationship.answers | cut -f2 -d ' ' > $@.a 65 | paste $@.q $@.a > $@ 66 | rm -f $@.q $@.a 67 | 68 | 69 | # wget commands to fetch the datasets. Please see the original datasets for 70 | # appropriate references if you use these.
71 | ws353simrel.tar.gz: 72 | wget http://alfonseca.org/pubs/ws353simrel.tar.gz 73 | 74 | MEN.tar.gz: 75 | wget http://clic.cimec.unitn.it/~elia.bruni/resources/MEN.tar.gz 76 | 77 | Mtruk.csv: 78 | wget http://tx.technion.ac.il/~kirar/files/Mtruk.csv 79 | 80 | rw.zip: 81 | wget http://www-nlp.stanford.edu/~lmthang/morphoNLM/rw.zip 82 | 83 | SimLex-999.zip: 84 | wget http://www.cl.cam.ac.uk/~fh295/SimLex-999.zip 85 | 86 | questions-words.txt: 87 | wget http://word2vec.googlecode.com/svn/trunk/questions-words.txt 88 | 89 | myz_naacl13_test_set.tgz: 90 | wget http://research.microsoft.com/en-us/um/people/gzweig/Pubs/myz_naacl13_test_set.tgz 91 | 92 | analogy: analogy.cc 93 | .PHONY: all clean distclean 94 | clean: 95 | rm -f *.ws.tab *.an.tab analogy *.pyc 96 | 97 | distclean: clean 98 | rm -f *.tgz *.tar.gz *.zip Mtruk.csv questions-words.txt 99 | -------------------------------------------------------------------------------- /swivel/fastprep.mk: -------------------------------------------------------------------------------- 1 | # -*- Mode: Makefile -*- 2 | 3 | # 4 | # Copyright 2016 Google Inc. All Rights Reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | 19 | 20 | # This makefile builds "fastprep", a faster version of prep.py that can be used 21 | # to build training data for Swivel.
Building "fastprep" is a bit more 22 | # involved: you'll need to pull and build the Tensorflow source, and then build 23 | # and install compatible protobuf software. We've tested this with Tensorflow 24 | # version 0.7. 25 | # 26 | # = Step 1. Pull and Build Tensorflow. = 27 | # 28 | # These instructions are somewhat abridged; for pre-requisites and the most 29 | # up-to-date instructions, refer to: 30 | # 31 | # 32 | # 33 | # To build the Tensorflow components required for "fastprep", you'll need to 34 | # install Bazel, Numpy, Swig, and Python development headers as described at 35 | # the above URL. Run the "configure" script as appropriate for your 36 | # environment and then build the "build_pip_package" target: 37 | # 38 | # bazel build -c opt [--config=cuda] //tensorflow/tools/pip_package:build_pip_package 39 | # 40 | # This will generate the Tensorflow headers and libraries necessary for 41 | # "fastprep". 42 | # 43 | # 44 | # = Step 2. Build and Install Compatible Protobuf Libraries = 45 | # 46 | # "fastprep" also needs compatible protocol buffer libraries, which you can 47 | # build from the protobuf implementation included with the Tensorflow 48 | # distribution: 49 | # 50 | # cd ${TENSORFLOW_SRCDIR}/google/protobuf 51 | # ./autogen.sh 52 | # ./configure --prefix=${HOME} # ...or whatever 53 | # make 54 | # make install # ...or maybe "sudo make install" 55 | # 56 | # This will install the headers and libraries appropriately. 57 | # 58 | # 59 | # = Step 3. Build "fastprep". = 60 | # 61 | # Finally modify this file (if necessary) to update PB_DIR and TF_DIR to refer 62 | # to appropriate locations, and: 63 | # 64 | # make -f fastprep.mk 65 | # 66 | # If all goes well, you should have a program that is "flag compatible" with 67 | # "prep.py" and runs significantly faster. Use it to generate the co-occurrence 68 | # matrices and other files necessary to train a Swivel matrix.
69 | 70 | 71 | # The root directory where the Google Protobuf software is installed. 72 | # Alternative locations might be "/usr" or "/usr/local". 73 | PB_DIR=$(HOME) 74 | 75 | # Assuming you've got the Tensorflow source unpacked and built in ${HOME}/src: 76 | TF_DIR=$(HOME)/src/tensorflow 77 | 78 | PB_INCLUDE=$(PB_DIR)/include 79 | TF_INCLUDE=$(TF_DIR)/bazel-genfiles 80 | CXXFLAGS=-std=c++11 -m64 -mavx -g -Ofast -Wall -I$(TF_INCLUDE) -I$(PB_INCLUDE) 81 | 82 | PB_LIBDIR=$(PB_DIR)/lib 83 | TF_LIBDIR=$(TF_DIR)/bazel-bin/tensorflow/core 84 | LDFLAGS=-L$(TF_LIBDIR) -L$(PB_LIBDIR) 85 | LDLIBS=-lprotos_all_cc -lprotobuf -lpthread -lm 86 | 87 | fastprep: fastprep.cc 88 | -------------------------------------------------------------------------------- /swivel/nearest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright 2016 Google Inc. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
16 | 17 | """Simple tool for inspecting nearest neighbors and analogies.""" 18 | 19 | import re 20 | import sys 21 | from getopt import GetoptError, getopt 22 | 23 | from vecs import Vecs 24 | 25 | try: 26 | opts, args = getopt(sys.argv[1:], 'v:e:', ['vocab=', 'embeddings=']) 27 | except GetoptError, e: 28 | print >> sys.stderr, e 29 | sys.exit(2) 30 | 31 | opt_vocab = 'vocab.txt' 32 | opt_embeddings = None 33 | 34 | for o, a in opts: 35 | if o in ('-v', '--vocab'): 36 | opt_vocab = a 37 | if o in ('-e', '--embeddings'): 38 | opt_embeddings = a 39 | 40 | vecs = Vecs(opt_vocab, opt_embeddings) 41 | 42 | while True: 43 | sys.stdout.write('query> ') 44 | sys.stdout.flush() 45 | 46 | query = sys.stdin.readline().strip() 47 | if not query: 48 | break 49 | 50 | parts = re.split(r'\s+', query) 51 | 52 | if len(parts) == 1: 53 | res = vecs.neighbors(parts[0]) 54 | 55 | elif len(parts) == 3: 56 | vs = [vecs.lookup(w) for w in parts] 57 | if any(v is None for v in vs): 58 | print 'not in vocabulary: %s' % ( 59 | ', '.join(tok for tok, v in zip(parts, vs) if v is None)) 60 | 61 | continue 62 | 63 | res = vecs.neighbors(vs[2] - vs[0] + vs[1]) 64 | 65 | else: 66 | print 'use a single word to query neighbors, or three words for analogy' 67 | continue 68 | 69 | if not res: 70 | continue 71 | 72 | for word, sim in res[:20]: 73 | print '%0.4f: %s' % (sim, word) 74 | 75 | print 76 | -------------------------------------------------------------------------------- /swivel/text2bin.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright 2016 Google Inc. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Converts vectors from text to a binary format for quicker manipulation. 18 | 19 | Usage: 20 | 21 | text2bin.py -o -v vec1.txt [vec2.txt ...] 22 | 23 | Optiona: 24 | 25 | -o , --output 26 | The name of the file into which the binary vectors are written. 27 | 28 | -v , --vocab 29 | The name of the file into which the vocabulary is written. 30 | 31 | Description 32 | 33 | This program merges one or more whitespace separated vector files into a single 34 | binary vector file that can be used by downstream evaluation tools in this 35 | directory ("wordsim.py" and "analogy"). 36 | 37 | If more than one vector file is specified, then the files must be aligned 38 | row-wise (i.e., each line must correspond to the same embedding), and they must 39 | have the same number of columns (i.e., be the same dimension). 
40 | 41 | """ 42 | 43 | from itertools import izip 44 | from getopt import GetoptError, getopt 45 | import os 46 | import struct 47 | import sys 48 | 49 | try: 50 | opts, args = getopt( 51 | sys.argv[1:], 'o:v:', ['output=', 'vocab=']) 52 | except GetoptError, e: 53 | print >> sys.stderr, e 54 | sys.exit(2) 55 | 56 | opt_output = 'vecs.bin' 57 | opt_vocab = 'vocab.txt' 58 | for o, a in opts: 59 | if o in ('-o', '--output'): 60 | opt_output = a 61 | if o in ('-v', '--vocab'): 62 | opt_vocab = a 63 | 64 | def go(fhs): 65 | fmt = None 66 | with open(opt_vocab, 'w') as vocab_out: 67 | with open(opt_output, 'w') as vecs_out: 68 | for lines in izip(*fhs): 69 | parts = [line.split() for line in lines] 70 | token = parts[0][0] 71 | if any(part[0] != token for part in parts[1:]): 72 | raise IOError('vector files must be aligned') 73 | 74 | print >> vocab_out, token 75 | 76 | vec = [sum(float(x) for x in xs) for xs in zip(*parts)[1:]] 77 | if not fmt: 78 | fmt = struct.Struct('%df' % len(vec)) 79 | 80 | vecs_out.write(fmt.pack(*vec)) 81 | 82 | if args: 83 | fhs = [open(filename) for filename in args] 84 | go(fhs) 85 | for fh in fhs: 86 | fh.close() 87 | else: 88 | go([sys.stdin]) 89 | -------------------------------------------------------------------------------- /swivel/vecs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import mmap 16 | import numpy as np 17 | import os 18 | import struct 19 | 20 | class Vecs(object): 21 | def __init__(self, vocab_filename, rows_filename, cols_filename=None): 22 | """Initializes the vectors from a text vocabulary and binary data.""" 23 | with open(vocab_filename, 'r') as lines: 24 | self.vocab = [line.split()[0] for line in lines] 25 | self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)} 26 | 27 | n = len(self.vocab) 28 | 29 | with open(rows_filename, 'r') as rows_fh: 30 | rows_fh.seek(0, os.SEEK_END) 31 | size = rows_fh.tell() 32 | 33 | # Make sure that the file size seems reasonable. 34 | if size % (4 * n) != 0: 35 | raise IOError( 36 | 'unexpected file size for binary vector file %s' % rows_filename) 37 | 38 | # Memory map the rows. 39 | dim = size / (4 * n) 40 | rows_mm = mmap.mmap(rows_fh.fileno(), 0, prot=mmap.PROT_READ) 41 | rows = np.matrix( 42 | np.frombuffer(rows_mm, dtype=np.float32).reshape(n, dim)) 43 | 44 | # If column vectors were specified, then open them and add them to the row 45 | # vectors. 46 | if cols_filename: 47 | with open(cols_filename, 'r') as cols_fh: 48 | cols_mm = mmap.mmap(cols_fh.fileno(), 0, prot=mmap.PROT_READ) 49 | cols_fh.seek(0, os.SEEK_END) 50 | if cols_fh.tell() != size: 51 | raise IOError('row and column vector files have different sizes') 52 | 53 | cols = np.matrix( 54 | np.frombuffer(cols_mm, dtype=np.float32).reshape(n, dim)) 55 | 56 | rows += cols 57 | cols_mm.close() 58 | 59 | # Normalize so that dot products are just cosine similarity. 
60 | self.vecs = rows / np.linalg.norm(rows, axis=1).reshape(n, 1) 61 | rows_mm.close() 62 | 63 | def similarity(self, word1, word2): 64 | """Computes the similarity of two tokens.""" 65 | idx1 = self.word_to_idx.get(word1) 66 | idx2 = self.word_to_idx.get(word2) 67 | if not idx1 or not idx2: 68 | return None 69 | 70 | return float(self.vecs[idx1] * self.vecs[idx2].transpose()) 71 | 72 | def neighbors(self, query): 73 | """Returns the nearest neighbors to the query (a word or vector).""" 74 | if isinstance(query, basestring): 75 | idx = self.word_to_idx.get(query) 76 | if idx is None: 77 | return None 78 | 79 | query = self.vecs[idx] 80 | 81 | neighbors = self.vecs * query.transpose() 82 | 83 | return sorted( 84 | zip(self.vocab, neighbors.flat), 85 | key=lambda kv: kv[1], reverse=True) 86 | 87 | def lookup(self, word): 88 | """Returns the embedding for a token, or None if no embedding exists.""" 89 | idx = self.word_to_idx.get(word) 90 | return None if idx is None else self.vecs[idx] 91 | -------------------------------------------------------------------------------- /swivel/wordsim.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright 2016 Google Inc. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Computes Spearman's rho with respect to human judgements. 
18 | 19 | Given a set of row (and potentially column) embeddings, this computes Spearman's 20 | rho between the rank ordering of predicted word similarity and human judgements. 21 | 22 | Usage: 23 | 24 | wordim.py --embeddings= --vocab= eval1.tab eval2.tab ... 25 | 26 | Options: 27 | 28 | --embeddings=: the vectors to test 29 | --vocab=: the vocabulary file 30 | 31 | Evaluation files are assumed to be tab-separated files with exactly three 32 | columns. The first two columns contain the words, and the third column contains 33 | the scored human judgement. 34 | 35 | """ 36 | 37 | import scipy.stats 38 | import sys 39 | from getopt import GetoptError, getopt 40 | 41 | from vecs import Vecs 42 | 43 | try: 44 | opts, args = getopt(sys.argv[1:], '', ['embeddings=', 'vocab=']) 45 | except GetoptError, e: 46 | print >> sys.stderr, e 47 | sys.exit(2) 48 | 49 | opt_embeddings = None 50 | opt_vocab = None 51 | 52 | for o, a in opts: 53 | if o == '--embeddings': 54 | opt_embeddings = a 55 | if o == '--vocab': 56 | opt_vocab = a 57 | 58 | if not opt_vocab: 59 | print >> sys.stderr, 'please specify a vocabulary file with "--vocab"' 60 | sys.exit(2) 61 | 62 | if not opt_embeddings: 63 | print >> sys.stderr, 'please specify the embeddings with "--embeddings"' 64 | sys.exit(2) 65 | 66 | try: 67 | vecs = Vecs(opt_vocab, opt_embeddings) 68 | except IOError, e: 69 | print >> sys.stderr, e 70 | sys.exit(1) 71 | 72 | def evaluate(lines): 73 | acts, preds = [], [] 74 | 75 | with open(filename, 'r') as lines: 76 | for line in lines: 77 | w1, w2, act = line.strip().split('\t') 78 | pred = vecs.similarity(w1, w2) 79 | if pred is None: 80 | continue 81 | 82 | acts.append(float(act)) 83 | preds.append(pred) 84 | 85 | rho, _ = scipy.stats.spearmanr(acts, preds) 86 | return rho 87 | 88 | for filename in args: 89 | with open(filename, 'r') as lines: 90 | print '%0.3f %s' % (evaluate(lines), filename) 91 | -------------------------------------------------------------------------------- 
/syntaxnet/.gitignore: -------------------------------------------------------------------------------- 1 | /bazel-bin 2 | /bazel-genfiles 3 | /bazel-out 4 | /bazel-tensorflow 5 | /bazel-testlogs 6 | /bazel-tf 7 | /bazel-syntaxnet 8 | -------------------------------------------------------------------------------- /syntaxnet/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM java:8 2 | 3 | ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin 4 | 5 | RUN mkdir -p $SYNTAXNETDIR \ 6 | && cd $SYNTAXNETDIR \ 7 | && apt-get update \ 8 | && apt-get install git zlib1g-dev file swig python2.7 python-dev python-pip python-mock -y \ 9 | && pip install --upgrade pip \ 10 | && pip install -U protobuf==3.0.0 \ 11 | && pip install asciitree \ 12 | && pip install numpy \ 13 | && wget https://github.com/bazelbuild/bazel/releases/download/0.3.1/bazel-0.3.1-installer-linux-x86_64.sh \ 14 | && chmod +x bazel-0.3.1-installer-linux-x86_64.sh \ 15 | && ./bazel-0.3.1-installer-linux-x86_64.sh --user \ 16 | && git clone --recursive https://github.com/tensorflow/models.git \ 17 | && cd $SYNTAXNETDIR/models/syntaxnet/tensorflow \ 18 | && echo "\n\n\n\n" | ./configure \ 19 | && apt-get autoremove -y \ 20 | && apt-get clean 21 | # Build SyntaxNet and run its test suite while the image is being built. 22 | RUN cd $SYNTAXNETDIR/models/syntaxnet \ 23 | && bazel test --genrule_strategy=standalone syntaxnet/... util/utf8/... 24 | 25 | WORKDIR $SYNTAXNETDIR/models/syntaxnet 26 | 27 | CMD [ "sh", "-c", "echo 'Bob brought the pizza to Alice.' | syntaxnet/demo.sh" ] 28 | 29 | # Commands to build and run the image: 30 | # =============================== 31 | # mkdir build && cp Dockerfile build/ && cd build 32 | # docker build -t syntaxnet .
33 | # docker run syntaxnet 34 | -------------------------------------------------------------------------------- /syntaxnet/WORKSPACE: -------------------------------------------------------------------------------- 1 | local_repository( 2 | name = "org_tensorflow", 3 | path = "tensorflow", 4 | ) 5 | 6 | load('@org_tensorflow//tensorflow:workspace.bzl', 'tf_workspace') 7 | tf_workspace() 8 | 9 | # Specify the minimum required Bazel version. 10 | load("@org_tensorflow//tensorflow:tensorflow.bzl", "check_version") 11 | check_version("0.3.0") 12 | -------------------------------------------------------------------------------- /syntaxnet/beam_search_training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/syntaxnet/beam_search_training.png -------------------------------------------------------------------------------- /syntaxnet/ff_nn_schematic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/syntaxnet/ff_nn_schematic.png -------------------------------------------------------------------------------- /syntaxnet/looping-parser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/syntaxnet/looping-parser.gif -------------------------------------------------------------------------------- /syntaxnet/sawman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/syntaxnet/sawman.png -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/base.h:
-------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef SYNTAXNET_BASE_H_ 17 | #define SYNTAXNET_BASE_H_ 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include "tensorflow/core/lib/core/status.h" 25 | #include "tensorflow/core/lib/strings/strcat.h" 26 | #include "tensorflow/core/lib/strings/stringprintf.h" 27 | #include "tensorflow/core/platform/default/integral_types.h" 28 | #include "tensorflow/core/platform/mutex.h" 29 | #include "tensorflow/core/platform/protobuf.h" 30 | 31 | 32 | 33 | using tensorflow::int32; 34 | using tensorflow::int64; 35 | using tensorflow::uint64; 36 | using tensorflow::uint32; 37 | using tensorflow::uint32; 38 | using tensorflow::protobuf::TextFormat; 39 | using tensorflow::mutex_lock; 40 | using tensorflow::mutex; 41 | using std::map; 42 | using std::pair; 43 | using std::vector; 44 | using std::unordered_map; 45 | using std::unordered_set; 46 | typedef signed int char32; 47 | 48 | using tensorflow::StringPiece; 49 | using std::string; 50 | 51 | // NOTE(review): the aliases above are declared at global scope (this header opens no namespace); the duplicate uint32 using-declaration is redundant but legal. 52 | 53 | #endif // SYNTAXNET_BASE_H_ 54 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/demo.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | # A script that runs a tokenizer, a part-of-speech tagger and a dependency 18 | # parser on an English text file, with one sentence per line. 19 | # 20 | # Example usage: 21 | # echo "Parsey McParseface is my favorite parser!" | syntaxnet/demo.sh 22 | 23 | # To run on a conll formatted file, add the --conll command line argument.
24 | # Pipeline below: tagger (brain_tagger) | parser (brain_parser) | conll2tree. 25 | 26 | PARSER_EVAL=bazel-bin/syntaxnet/parser_eval 27 | MODEL_DIR=syntaxnet/models/parsey_mcparseface 28 | [[ "$1" == "--conll" ]] && INPUT_FORMAT=stdin-conll || INPUT_FORMAT=stdin 29 | 30 | $PARSER_EVAL \ 31 | --input=$INPUT_FORMAT \ 32 | --output=stdout-conll \ 33 | --hidden_layer_sizes=64 \ 34 | --arg_prefix=brain_tagger \ 35 | --graph_builder=structured \ 36 | --task_context=$MODEL_DIR/context.pbtxt \ 37 | --model_path=$MODEL_DIR/tagger-params \ 38 | --slim_model \ 39 | --batch_size=1024 \ 40 | --alsologtostderr \ 41 | | \ 42 | $PARSER_EVAL \ 43 | --input=stdin-conll \ 44 | --output=stdout-conll \ 45 | --hidden_layer_sizes=512,512 \ 46 | --arg_prefix=brain_parser \ 47 | --graph_builder=structured \ 48 | --task_context=$MODEL_DIR/context.pbtxt \ 49 | --model_path=$MODEL_DIR/parser-params \ 50 | --slim_model \ 51 | --batch_size=1024 \ 52 | --alsologtostderr \ 53 | | \ 54 | bazel-bin/syntaxnet/conll2tree \ 55 | --task_context=$MODEL_DIR/context.pbtxt \ 56 | --alsologtostderr 57 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/dictionary.proto: -------------------------------------------------------------------------------- 1 | // Protocol buffers for serializing string<=>index dictionaries. 2 | 3 | syntax = "proto2"; 4 | 5 | package syntaxnet; 6 | 7 | // Serializable representation of a string=>string pair. 8 | message StringToStringPair { 9 | // String representing the key. 10 | required string key = 1; 11 | 12 | // String representing the value. 13 | required string value = 2; 14 | } 15 | 16 | // Serializable representation of a string=>string mapping. 17 | message StringToStringMap { 18 | // Key=>value pairs. 19 | repeated StringToStringPair pair = 1; 20 | } 21 | 22 | // Affix table entry, for serialization of the affix tables. 23 | message AffixTableEntry { 24 | // Nested message for serializing a single affix. 25 | message AffixEntry { 26 | // The affix as a string.
27 | required string form = 1; 28 | 29 | // The length of the affix (this is non-trivial to compute due to UTF-8). 30 | required int32 length = 2; 31 | 32 | // The ID of the affix that is one character shorter, or -1 if none exists. 33 | required int32 shorter_id = 3; 34 | } 35 | 36 | // The type of affix table, as a string. 37 | required string type = 1; 38 | 39 | // The maximum affix length. 40 | required int32 max_length = 2; 41 | 42 | // The list of affixes, in order of affix ID. 43 | repeated AffixEntry affix = 3; 44 | } 45 | 46 | // A light-weight proto to store vectors in binary format. 47 | message TokenEmbedding { 48 | required bytes token = 1; // can be word or phrase, or URL, etc. 49 | 50 | // If available, raw count of this token in the training corpus. 51 | optional int64 count = 3; 52 | 53 | message Vector { 54 | repeated float values = 1 [packed = true]; 55 | } 56 | optional Vector vector = 2; 57 | }; 58 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/document_format.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "syntaxnet/document_format.h" 17 | 18 | namespace syntaxnet { 19 | 20 | // Component registry for document formatters.
21 | REGISTER_CLASS_REGISTRY("document format", DocumentFormat); 22 | 23 | } // namespace syntaxnet 24 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/document_format.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | // An interface for document formats. 17 | 18 | #ifndef SYNTAXNET_DOCUMENT_FORMAT_H__ 19 | #define SYNTAXNET_DOCUMENT_FORMAT_H__ 20 | 21 | #include 22 | #include 23 | 24 | #include "syntaxnet/utils.h" 25 | #include "syntaxnet/registry.h" 26 | #include "syntaxnet/sentence.pb.h" 27 | #include "syntaxnet/task_context.h" 28 | #include "tensorflow/core/lib/io/buffered_inputstream.h" 29 | 30 | namespace syntaxnet { 31 | 32 | // A document format component converts a key/value pair from a record to one or 33 | // more documents. The record format is used for selecting the document format 34 | // component. A document format component can be registered with the 35 | // REGISTER_DOCUMENT_FORMAT macro. 36 | class DocumentFormat : public RegisterableClass { 37 | public: 38 | DocumentFormat() {} 39 | virtual ~DocumentFormat() {} 40 | 41 | virtual void Setup(TaskContext *context) {} 42 | 43 | // Reads a record from the given input buffer with format specific logic.
44 | // Returns false if no record could be read because we reached end of file. 45 | virtual bool ReadRecord(tensorflow::io::BufferedInputStream *buffer, 46 | string *record) = 0; 47 | 48 | // Converts a key/value pair to one or more documents. 49 | virtual void ConvertFromString(const string &key, const string &value, 50 | vector *documents) = 0; 51 | 52 | // Converts a document to a key/value pair. 53 | virtual void ConvertToString(const Sentence &document, 54 | string *key, string *value) = 0; 55 | 56 | private: 57 | TF_DISALLOW_COPY_AND_ASSIGN(DocumentFormat); 58 | }; 59 | 60 | #define REGISTER_DOCUMENT_FORMAT(type, component) \ 61 | REGISTER_CLASS_COMPONENT(DocumentFormat, type, component) 62 | 63 | } // namespace syntaxnet 64 | 65 | #endif // SYNTAXNET_DOCUMENT_FORMAT_H__ 66 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/feature_extractor.proto: -------------------------------------------------------------------------------- 1 | // Protocol buffers for feature extractor. 2 | 3 | syntax = "proto2"; 4 | 5 | package syntaxnet; 6 | // Named string parameter; used by FeatureFunctionDescriptor.parameter below. 7 | message Parameter { 8 | optional string name = 1; 9 | optional string value = 2; 10 | } 11 | 12 | // Descriptor for feature function. 13 | message FeatureFunctionDescriptor { 14 | // Feature function type. 15 | required string type = 1; 16 | 17 | // Feature function name. 18 | optional string name = 2; 19 | 20 | // Default argument for feature function. 21 | optional int32 argument = 3 [default = 0]; 22 | 23 | // Named parameters for feature descriptor. 24 | repeated Parameter parameter = 4; 25 | 26 | // Nested sub-feature function descriptors. 27 | repeated FeatureFunctionDescriptor feature = 7; 28 | }; 29 | 30 | // Descriptor for feature extractor. 31 | message FeatureExtractorDescriptor { 32 | // Top-level feature function for extractor.
33 | repeated FeatureFunctionDescriptor feature = 1; 34 | }; 35 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/load_parser_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Loads parser_ops shared library.""" 17 | 18 | import os.path 19 | import tensorflow as tf 20 | # Loaded at import time, so importing this module makes the parser ops available. 21 | tf.load_op_library( 22 | os.path.join(tf.resource_loader.get_data_files_path(), 23 | 'parser_ops.so')) 24 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/models/parsey_mcparseface/fine-to-universal.map: -------------------------------------------------------------------------------- 1 | # . 2 | $ . 3 | '' . 4 | -LRB- . 5 | -RRB- . 6 | , . 7 | . . 8 | : . 9 | ADD X 10 | AFX PRT 11 | CC CONJ 12 | CD NUM 13 | DT DET 14 | EX DET 15 | FW X 16 | GW X 17 | HYPH . 18 | IN ADP 19 | JJ ADJ 20 | JJR ADJ 21 | JJS ADJ 22 | LS X 23 | MD VERB 24 | NFP .
25 | NN NOUN 26 | NNP NOUN 27 | NNPS NOUN 28 | NNS NOUN 29 | PDT DET 30 | POS PRT 31 | PRP PRON 32 | PRP$ PRON 33 | RB ADV 34 | RBR ADV 35 | RBS ADV 36 | RP PRT 37 | SYM X 38 | TO PRT 39 | UH X 40 | VB VERB 41 | VBD VERB 42 | VBG VERB 43 | VBN VERB 44 | VBP VERB 45 | VBZ VERB 46 | WDT DET 47 | WP PRON 48 | WP$ PRON 49 | WRB ADV 50 | `` . 51 | X X 52 | XX X 53 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/models/parsey_mcparseface/label-map: -------------------------------------------------------------------------------- 1 | 46 2 | punct 243160 3 | prep 194627 4 | pobj 186958 5 | det 170592 6 | nsubj 144821 7 | nn 144800 8 | amod 117242 9 | ROOT 90592 10 | dobj 88551 11 | aux 76523 12 | advmod 72893 13 | conj 59384 14 | cc 57532 15 | num 36350 16 | poss 35117 17 | dep 34986 18 | ccomp 29470 19 | cop 25991 20 | mark 25141 21 | xcomp 25111 22 | rcmod 16234 23 | auxpass 15740 24 | advcl 14996 25 | possessive 14866 26 | nsubjpass 14133 27 | pcomp 12488 28 | appos 11112 29 | partmod 11106 30 | neg 11090 31 | number 10658 32 | prt 7123 33 | quantmod 6653 34 | tmod 5418 35 | infmod 5134 36 | npadvmod 3213 37 | parataxis 3012 38 | mwe 2793 39 | expl 2712 40 | iobj 1642 41 | acomp 1632 42 | discourse 1381 43 | csubj 1225 44 | predet 1160 45 | preconj 749 46 | goeswith 146 47 | csubjpass 41 48 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/models/parsey_mcparseface/parser-params: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/syntaxnet/syntaxnet/models/parsey_mcparseface/parser-params -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/models/parsey_mcparseface/prefix-table: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/syntaxnet/syntaxnet/models/parsey_mcparseface/prefix-table -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/models/parsey_mcparseface/suffix-table: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/syntaxnet/syntaxnet/models/parsey_mcparseface/suffix-table -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/models/parsey_mcparseface/tag-map: -------------------------------------------------------------------------------- 1 | 49 2 | NN 285194 3 | IN 228165 4 | DT 179147 5 | NNP 175147 6 | JJ 125667 7 | NNS 115732 8 | , 97481 9 | . 85938 10 | RB 78513 11 | VB 63952 12 | CC 57554 13 | VBD 56635 14 | CD 55674 15 | PRP 55244 16 | VBZ 48126 17 | VBN 44458 18 | VBG 34524 19 | VBP 33669 20 | TO 28772 21 | MD 22364 22 | PRP$ 20706 23 | HYPH 18526 24 | POS 14905 25 | `` 12193 26 | '' 12154 27 | WDT 10267 28 | : 8713 29 | $ 7993 30 | WP 7336 31 | RP 7335 32 | WRB 6634 33 | JJR 6295 34 | NNPS 5917 35 | -RRB- 3904 36 | -LRB- 3840 37 | JJS 3596 38 | RBR 3186 39 | EX 2733 40 | UH 1521 41 | RBS 1467 42 | PDT 1271 43 | FW 928 44 | NFP 844 45 | SYM 652 46 | ADD 476 47 | LS 392 48 | WP$ 332 49 | GW 184 50 | AFX 42 51 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/models/parsey_mcparseface/tagger-params: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/syntaxnet/syntaxnet/models/parsey_mcparseface/tagger-params -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/models/parsey_universal/context-tokenize-zh.pbtxt: 
-------------------------------------------------------------------------------- 1 | Parameter { 2 | name: "brain_tokenizer_zh_embedding_dims" 3 | value: "32;32" 4 | } 5 | Parameter { 6 | name: "brain_tokenizer_zh_embedding_names" 7 | value: "chars;words" 8 | } 9 | Parameter { 10 | name: "brain_tokenizer_zh_features" 11 | value: "input.char " 12 | "input(1).char " 13 | "input(2).char " 14 | "input(3).char " 15 | "input(-1).char " 16 | "input(-2).char " 17 | "input(-3).char " 18 | "stack.char " 19 | "stack.offset(1).char " 20 | "stack.offset(-1).char " 21 | "stack(1).char " 22 | "stack(1).offset(1).char " 23 | "stack(1).offset(-1).char " 24 | "stack(2).char; " 25 | "last-word(1,min-freq=2) " 26 | "last-word(2,min-freq=2) " 27 | "last-word(3,min-freq=2)" 28 | } 29 | Parameter { 30 | name: "brain_tokenizer_zh_transition_system" 31 | value: "binary-segment-transitions" 32 | } 33 | input { 34 | name: "word-map" 35 | Part { 36 | file_pattern: "last-word-map" 37 | } 38 | } 39 | input { 40 | name: "char-map" 41 | Part { 42 | file_pattern: "char-map" 43 | } 44 | } 45 | input { 46 | name: "label-map" 47 | Part { 48 | file_pattern: "label-map" 49 | } 50 | } 51 | input { 52 | name: 'stdin-untoken' 53 | record_format: 'untokenized-text' 54 | Part { 55 | file_pattern: '-' 56 | } 57 | } 58 | input { 59 | name: 'stdout-conll' 60 | record_format: 'conll-sentence' 61 | Part { 62 | file_pattern: '-' 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/models/parsey_universal/parse.sh: -------------------------------------------------------------------------------- 1 | # A script that runs a morphological analyzer, a part-of-speech tagger and a 2 | # dependency parser on a text file, with one sentence per line.
3 | # 4 | # Example usage: 5 | # bazel build syntaxnet:parser_eval 6 | # cat sentences.txt | 7 | # syntaxnet/models/parsey_universal/parse.sh \ 8 | # $MODEL_DIRECTORY > output.conll 9 | # 10 | # To run on a conll formatted file, add the --conll command line argument: 11 | # cat sentences.conll | 12 | # syntaxnet/models/parsey_universal/parse.sh \ 13 | # --conll $MODEL_DIRECTORY > output.conll 14 | # 15 | # Models can be downloaded from 16 | # http://download.tensorflow.org/models/parsey_universal/LANGUAGE.zip 17 | # for the languages listed at 18 | # https://github.com/tensorflow/models/blob/master/syntaxnet/universal.md 19 | # The pipeline chains three parser_eval stages (morpher, tagger, parser) via CoNLL on stdout/stdin. 20 | 21 | PARSER_EVAL=bazel-bin/syntaxnet/parser_eval 22 | CONTEXT=syntaxnet/models/parsey_universal/context.pbtxt 23 | if [[ "$1" == "--conll" ]]; then 24 | INPUT_FORMAT=stdin-conll 25 | shift 26 | else 27 | INPUT_FORMAT=stdin 28 | fi 29 | MODEL_DIR=$1 30 | 31 | $PARSER_EVAL \ 32 | --input=$INPUT_FORMAT \ 33 | --output=stdout-conll \ 34 | --hidden_layer_sizes=64 \ 35 | --arg_prefix=brain_morpher \ 36 | --graph_builder=structured \ 37 | --task_context=$CONTEXT \ 38 | --resource_dir=$MODEL_DIR \ 39 | --model_path=$MODEL_DIR/morpher-params \ 40 | --slim_model \ 41 | --batch_size=1024 \ 42 | --alsologtostderr \ 43 | | \ 44 | $PARSER_EVAL \ 45 | --input=stdin-conll \ 46 | --output=stdout-conll \ 47 | --hidden_layer_sizes=64 \ 48 | --arg_prefix=brain_tagger \ 49 | --graph_builder=structured \ 50 | --task_context=$CONTEXT \ 51 | --resource_dir=$MODEL_DIR \ 52 | --model_path=$MODEL_DIR/tagger-params \ 53 | --slim_model \ 54 | --batch_size=1024 \ 55 | --alsologtostderr \ 56 | | \ 57 | $PARSER_EVAL \ 58 | --input=stdin-conll \ 59 | --output=stdout-conll \ 60 | --hidden_layer_sizes=512,512 \ 61 | --arg_prefix=brain_parser \ 62 | --graph_builder=structured \ 63 | --task_context=$CONTEXT \ 64 | --resource_dir=$MODEL_DIR \ 65 | --model_path=$MODEL_DIR/parser-params \ 66 | --slim_model \ 67 | --batch_size=1024 \ 68 | --alsologtostderr 69 |
-------------------------------------------------------------------------------- /syntaxnet/syntaxnet/models/parsey_universal/tokenize.sh: -------------------------------------------------------------------------------- 1 | # A script that runs a tokenizer on a text file with one sentence per line. 2 | # 3 | # Example usage: 4 | # bazel build syntaxnet:parser_eval 5 | # cat untokenized-sentences.txt | 6 | # syntaxnet/models/parsey_universal/tokenize.sh \ 7 | # $MODEL_DIRECTORY > output.conll 8 | # 9 | # Models can be downloaded from 10 | # http://download.tensorflow.org/models/parsey_universal/LANGUAGE.zip 11 | # for the languages listed at 12 | # https://github.com/tensorflow/models/blob/master/syntaxnet/universal.md 13 | # 14 | 15 | PARSER_EVAL=bazel-bin/syntaxnet/parser_eval 16 | CONTEXT=syntaxnet/models/parsey_universal/context.pbtxt 17 | INPUT_FORMAT=stdin-untoken 18 | MODEL_DIR=$1 19 | 20 | $PARSER_EVAL \ 21 | --input=$INPUT_FORMAT \ 22 | --output=stdin-untoken \ 23 | --hidden_layer_sizes=128,128 \ 24 | --arg_prefix=brain_tokenizer \ 25 | --graph_builder=greedy \ 26 | --task_context=$CONTEXT \ 27 | --resource_dir=$MODEL_DIR \ 28 | --model_path=$MODEL_DIR/tokenizer-params \ 29 | --batch_size=32 \ 30 | --alsologtostderr \ 31 | --slim_model 32 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/models/parsey_universal/tokenize_zh.sh: -------------------------------------------------------------------------------- 1 | # A script that runs a traditional Chinese tokenizer on a text file with one 2 | # sentence per line.
3 | # 4 | # Example usage: 5 | # bazel build syntaxnet:parser_eval 6 | # cat untokenized-sentences.txt | 7 | # syntaxnet/models/parsey_universal/tokenize_zh.sh \ 8 | # $MODEL_DIRECTORY > output.conll 9 | # 10 | # The traditional Chinese model can be downloaded from 11 | # http://download.tensorflow.org/models/parsey_universal/Chinese.zip 12 | # NOTE(review): --output=stdin-untoken maps to record_format 'untokenized-text' in context-tokenize-zh.pbtxt, not CoNLL; confirm the output.conll name above. 13 | 14 | PARSER_EVAL=bazel-bin/syntaxnet/parser_eval 15 | CONTEXT=syntaxnet/models/parsey_universal/context-tokenize-zh.pbtxt 16 | INPUT_FORMAT=stdin-untoken 17 | MODEL_DIR=$1 18 | 19 | $PARSER_EVAL \ 20 | --input=$INPUT_FORMAT \ 21 | --output=stdin-untoken \ 22 | --hidden_layer_sizes=256,256 \ 23 | --arg_prefix=brain_tokenizer_zh \ 24 | --graph_builder=structured \ 25 | --task_context=$CONTEXT \ 26 | --resource_dir=$MODEL_DIR \ 27 | --model_path=$MODEL_DIR/tokenizer-params \ 28 | --batch_size=1024 \ 29 | --alsologtostderr \ 30 | --slim_model 31 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/morphology_label_set.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License.
14 | ==============================================================================*/ 15 | 16 | #include "syntaxnet/morphology_label_set.h" 17 | 18 | namespace syntaxnet { 19 | 20 | const char MorphologyLabelSet::kSeparator[] = "\t"; 21 | // Adds morph unless an equal label (by StringForMatch) exists; returns its index either way. 22 | int MorphologyLabelSet::Add(const TokenMorphology &morph) { 23 | string repr = StringForMatch(morph); 24 | auto it = fast_lookup_.find(repr); 25 | if (it != fast_lookup_.end()) return it->second; 26 | fast_lookup_[repr] = label_set_.size(); 27 | label_set_.push_back(morph); 28 | return label_set_.size() - 1; 29 | } 30 | 31 | // Look up an existing TokenMorphology. If it is not present, return -1. 32 | int MorphologyLabelSet::LookupExisting(const TokenMorphology &morph) const { 33 | string repr = StringForMatch(morph); 34 | auto it = fast_lookup_.find(repr); 35 | if (it != fast_lookup_.end()) return it->second; 36 | return -1; 37 | } 38 | 39 | // Return the TokenMorphology at position i. The input i should be in the range 40 | // 0..size()-1; out-of-range values fail a CHECK. 41 | const TokenMorphology &MorphologyLabelSet::Lookup(int i) const { 42 | CHECK_GE(i, 0); 43 | CHECK_LT(i, label_set_.size()); 44 | return label_set_[i]; 45 | } 46 | // Reads labels from a record file; a duplicate record fails the CHECK in Read(reader). 47 | void MorphologyLabelSet::Read(const string &filename) { 48 | ProtoRecordReader reader(filename); 49 | Read(&reader); 50 | } 51 | 52 | void MorphologyLabelSet::Read(ProtoRecordReader *reader) { 53 | TokenMorphology morph; 54 | while (reader->Read(&morph).ok()) { 55 | CHECK_EQ(-1, LookupExisting(morph)); 56 | Add(morph); 57 | } 58 | } 59 | 60 | void MorphologyLabelSet::Write(const string &filename) const { 61 | ProtoRecordWriter writer(filename); 62 | Write(&writer); 63 | } 64 | 65 | void MorphologyLabelSet::Write(ProtoRecordWriter *writer) const { 66 | for (const TokenMorphology &morph : label_set_) { 67 | writer->Write(morph); 68 | } 69 | } 70 | // Canonical match key: sorted "name<kSeparator>value" strings joined by kSeparator. 71 | string MorphologyLabelSet::StringForMatch(const TokenMorphology &morph) const { 72 | vector attributes; 73 | for (const auto &a : morph.attribute()) {
74 | attributes.push_back( 75 | tensorflow::strings::StrCat(a.name(), kSeparator, a.value())); 76 | } 77 | std::sort(attributes.begin(), attributes.end()); 78 | return utils::Join(attributes, kSeparator); 79 | } 80 | // Display name: sorted "name:value" strings joined by commas. 81 | string FullLabelFeatureType::GetFeatureValueName(FeatureValue value) const { 82 | const TokenMorphology &morph = label_set_->Lookup(value); 83 | vector attributes; 84 | for (const auto &a : morph.attribute()) { 85 | attributes.push_back(tensorflow::strings::StrCat(a.name(), ":", a.value())); 86 | } 87 | std::sort(attributes.begin(), attributes.end()); 88 | return utils::Join(attributes, ","); 89 | } 90 | 91 | } // namespace syntaxnet 92 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/parser_transitions.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "syntaxnet/parser_transitions.h" 17 | 18 | #include "syntaxnet/parser_state.h" 19 | 20 | namespace syntaxnet { 21 | 22 | // Transition system registry.
23 | REGISTER_CLASS_REGISTRY("transition system", ParserTransitionSystem); 24 | // Default implementation: delegates to PerformActionWithoutHistory(). 25 | void ParserTransitionSystem::PerformAction(ParserAction action, 26 | ParserState *state) const { 27 | PerformActionWithoutHistory(action, state); 28 | } 29 | 30 | } // namespace syntaxnet 31 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/registry.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "syntaxnet/registry.h" 17 | 18 | namespace syntaxnet { 19 | 20 | // Global linked list of all component registries; Register() prepends via set_link(). 21 | RegistryMetadata *global_registry_list = nullptr; 22 | 23 | void RegistryMetadata::Register(RegistryMetadata *registry) { 24 | registry->set_link(global_registry_list); 25 | global_registry_list = registry; 26 | } 27 | 28 | } // namespace syntaxnet 29 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/segmenter_utils.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "syntaxnet/segmenter_utils.h" 17 | #include "util/utf8/unicodetext.h" 18 | #include "util/utf8/unilib.h" 19 | #include "util/utf8/unilib_utf8_utils.h" 20 | 21 | namespace syntaxnet { 22 | 23 | // Separators, mostly code Zs (plus Zl/Zp for U+2028/U+2029) from http://www.unicode.org/Public/UNIDATA/PropList.txt 24 | // NB: This list is not necessarily exhaustive. 25 | const std::unordered_set SegmenterUtils::kBreakChars({ 26 | 0x2028, // line separator 27 | 0x2029, // paragraph separator 28 | 0x0020, // space 29 | 0x00a0, // no-break space 30 | 0x1680, // Ogham space mark 31 | 0x180e, // Mongolian vowel separator 32 | 0x202f, // narrow no-break space 33 | 0x205f, // medium mathematical space 34 | 0x3000, // ideographic space 35 | 0xe5e5, // Google addition 36 | 0x2000, 0x2001, 0x2002, 0x2003, 0x2004, 0x2005, 0x2006, 0x2007, 0x2008, 37 | 0x2009, 0x200a 38 | }); 39 | // Appends a (pointer, length) view per UTF-8 char; NOTE(review): a truncated trailing sequence appears able to yield a view past text.size(). 40 | void SegmenterUtils::GetUTF8Chars(const string &text, 41 | vector *chars) { 42 | const char *start = text.c_str(); 43 | const char *end = text.c_str() + text.size(); 44 | while (start < end) { 45 | int char_length = UniLib::OneCharLen(start); 46 | chars->emplace_back(start, char_length); 47 | start += char_length; 48 | } 49 | } 50 | // Replaces sentence's tokens with one token per element of chars, with start/end set from GetCharStartEndBytes(). 51 | void SegmenterUtils::SetCharsAsTokens( 52 | const string &text, 53 | const vector &chars, 54 | Sentence *sentence) { 55 | sentence->clear_token(); 56 | sentence->set_text(text); 57 | for (int i = 0; i < chars.size(); ++i) { 58 | Token *tok = sentence->add_token(); 59 |
tok->set_word(chars[i].ToString()); // NOLINT 60 | int start_byte, end_byte; 61 | GetCharStartEndBytes(text, chars[i], &start_byte, &end_byte); 62 | tok->set_start(start_byte); 63 | tok->set_end(end_byte); 64 | } 65 | } 66 | // Returns whether token is a non-empty, in-bounds span of sentence.text() that starts and ends on UTF-8 character boundaries. 67 | bool SegmenterUtils::IsValidSegment(const Sentence &sentence, 68 | const Token &token) { 69 | // Check that the token is not empty, both by string and by bytes. 70 | if (token.word().empty()) return false; 71 | if (token.start() > token.end()) return false; 72 | 73 | // Check token boundaries inside of text. 74 | if (token.start() < 0) return false; 75 | if (token.end() >= sentence.text().size()) return false; 76 | 77 | // Check UTF-8 char boundaries by bytes; if end() + 1 == text().size(), const string operator[] yields '\0' (C++11). 78 | const char s = sentence.text()[token.start()]; 79 | const char e = sentence.text()[token.end() + 1]; 80 | if (UniLib::IsTrailByte(s)) return false; 81 | if (UniLib::IsTrailByte(e)) return false; 82 | return true; 83 | } 84 | 85 | } // namespace syntaxnet 86 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/sentence.proto: -------------------------------------------------------------------------------- 1 | // Protocol buffer specification for document analysis. 2 | 3 | syntax = "proto2"; 4 | 5 | package syntaxnet; 6 | 7 | // A Sentence contains the raw text contents of a sentence, as well as an 8 | // analysis. 9 | message Sentence { 10 | // Identifier for document. 11 | optional string docid = 1; 12 | 13 | // Raw text contents of the sentence. 14 | optional string text = 2; 15 | 16 | // Tokenization of the sentence. 17 | repeated Token token = 3; 18 | 19 | extensions 1000 to max; 20 | } 21 | 22 | // A document token marks a span of bytes in the document text as a token 23 | // or word. 24 | message Token { 25 | // Token word form. 26 | required string word = 1; 27 | 28 | // Start position of token in text. 29 | required int32 start = 2; 30 | 31 | // End position of token in text.
Gives index of last byte, not one past 32 | // the last byte. If token came from lexer, excludes any trailing HTML tags. 33 | required int32 end = 3; 34 | 35 | // Head of this token in the dependency tree: the id of the token which has an 36 | // arc going to this one. If it is the root token of a sentence, then it is 37 | // set to -1. 38 | optional int32 head = 4 [default = -1]; 39 | 40 | // Part-of-speech tag for token. 41 | optional string tag = 5; 42 | 43 | // Coarse-grained word category for token. 44 | optional string category = 6; 45 | 46 | // Label for dependency relation between this token and its head. 47 | optional string label = 7; 48 | 49 | // Break level for tokens that indicates how it was separated from the 50 | // previous token in the text. 51 | enum BreakLevel { 52 | NO_BREAK = 0; // No separation between tokens. 53 | SPACE_BREAK = 1; // Tokens separated by space. 54 | LINE_BREAK = 2; // Tokens separated by line break. 55 | SENTENCE_BREAK = 3; // Tokens separated by sentence break. 56 | } 57 | 58 | optional BreakLevel break_level = 8 [default = SPACE_BREAK]; 59 | 60 | extensions 1000 to max; 61 | } 62 | 63 | // Stores information about the morphology of a token. 64 | message TokenMorphology { 65 | extend Token { 66 | optional TokenMorphology morphology = 63949837; 67 | } 68 | 69 | // Morphology is represented by a set of attribute values. 70 | message Attribute { 71 | required string name = 1; 72 | required string value = 2; 73 | } 74 | // This attribute field is designated to hold a single disambiguated analysis. 75 | repeated Attribute attribute = 3; 76 | }; 77 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/sentence_batch.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. All Rights Reserved. 
2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "syntaxnet/sentence_batch.h" 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | #include "syntaxnet/task_context.h" 23 | 24 | namespace syntaxnet { 25 | 26 | void SentenceBatch::Init(TaskContext *context) { 27 | reader_.reset(new TextReader(*context->GetInput(input_name_), context)); 28 | size_ = 0; 29 | } 30 | 31 | bool SentenceBatch::AdvanceSentence(int index) { 32 | if (sentences_[index] == nullptr) ++size_; 33 | sentences_[index].reset(); 34 | std::unique_ptr sentence(reader_->Read()); 35 | if (sentence == nullptr) { 36 | --size_; 37 | return false; 38 | } 39 | 40 | // Preprocess the new sentence for the parser state. 41 | sentences_[index] = std::move(sentence); 42 | return true; 43 | } 44 | 45 | } // namespace syntaxnet 46 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/sentence_batch.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef SYNTAXNET_SENTENCE_BATCH_H_ 17 | #define SYNTAXNET_SENTENCE_BATCH_H_ 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | #include "syntaxnet/embedding_feature_extractor.h" 25 | #include "syntaxnet/feature_extractor.h" 26 | #include "syntaxnet/parser_state.h" 27 | #include "syntaxnet/parser_transitions.h" 28 | #include "syntaxnet/sentence.pb.h" 29 | #include "syntaxnet/sparse.pb.h" 30 | #include "syntaxnet/task_context.h" 31 | #include "syntaxnet/task_spec.pb.h" 32 | #include "syntaxnet/term_frequency_map.h" 33 | 34 | namespace syntaxnet { 35 | 36 | // Helper class to manage generating batches of preprocessed ParserState objects 37 | // by reading in multiple sentences in parallel. 38 | class SentenceBatch { 39 | public: 40 | SentenceBatch(int batch_size, string input_name) 41 | : batch_size_(batch_size), 42 | input_name_(std::move(input_name)), 43 | sentences_(batch_size) {} 44 | 45 | // Initializes all resources and opens the corpus file. 46 | void Init(TaskContext *context); 47 | 48 | // Advances the index'th sentence in the batch to the next sentence. This will 49 | // create and preprocess a new ParserState for that element. Returns false if 50 | // EOF is reached (if EOF, also sets the state to be nullptr.) 51 | bool AdvanceSentence(int index); 52 | 53 | // Rewinds the corpus reader. 
54 | void Rewind() { reader_->Reset(); } 55 | 56 | int size() const { return size_; } 57 | 58 | Sentence *sentence(int index) { return sentences_[index].get(); } 59 | 60 | private: 61 | // Running tally of non-nullptr states in the batch. 62 | int size_; 63 | 64 | // Maximum number of states in the batch. 65 | int batch_size_; 66 | 67 | // Input to read from the TaskContext. 68 | string input_name_; 69 | 70 | // Reader for the corpus. 71 | std::unique_ptr reader_; 72 | 73 | // Batch: Sentence objects. 74 | std::vector> sentences_; 75 | }; 76 | 77 | } // namespace syntaxnet 78 | 79 | #endif // SYNTAXNET_SENTENCE_BATCH_H_ 80 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/shared_store.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "syntaxnet/shared_store.h" 17 | 18 | #include 19 | 20 | #include "tensorflow/core/lib/strings/stringprintf.h" 21 | 22 | namespace syntaxnet { 23 | 24 | SharedStore::SharedObjectMap *SharedStore::shared_object_map_ = 25 | new SharedObjectMap; 26 | 27 | mutex SharedStore::shared_object_map_mutex_(tensorflow::LINKER_INITIALIZED); 28 | 29 | SharedStore::SharedObjectMap *SharedStore::shared_object_map() { 30 | return shared_object_map_; 31 | } 32 | 33 | bool SharedStore::Release(const void *object) { 34 | if (object == nullptr) { 35 | return true; 36 | } 37 | mutex_lock l(shared_object_map_mutex_); 38 | for (SharedObjectMap::iterator it = shared_object_map()->begin(); 39 | it != shared_object_map()->end(); ++it) { 40 | if (it->second.object == object) { 41 | // Check the invariant that reference counts are positive. A violation 42 | // likely implies memory corruption. 43 | CHECK_GE(it->second.refcount, 1); 44 | it->second.refcount--; 45 | if (it->second.refcount == 0) { 46 | it->second.delete_callback(); 47 | shared_object_map()->erase(it); 48 | } 49 | return true; 50 | } 51 | } 52 | return false; 53 | } 54 | 55 | void SharedStore::Clear() { 56 | mutex_lock l(shared_object_map_mutex_); 57 | for (SharedObjectMap::iterator it = shared_object_map()->begin(); 58 | it != shared_object_map()->end(); ++it) { 59 | it->second.delete_callback(); 60 | } 61 | shared_object_map()->clear(); 62 | } 63 | 64 | string SharedStoreUtils::CreateDefaultName() { return string(); } 65 | 66 | string SharedStoreUtils::ToString(const string &input) { 67 | return ToString(tensorflow::StringPiece(input)); 68 | } 69 | 70 | string SharedStoreUtils::ToString(const char *input) { 71 | return ToString(tensorflow::StringPiece(input)); 72 | } 73 | 74 | string SharedStoreUtils::ToString(tensorflow::StringPiece input) { 75 | return tensorflow::strings::StrCat("\"", utils::CEscape(input.ToString()), 76 | 
"\""); 77 | } 78 | 79 | string SharedStoreUtils::ToString(bool input) { 80 | return input ? "true" : "false"; 81 | } 82 | 83 | string SharedStoreUtils::ToString(float input) { 84 | return tensorflow::strings::Printf("%af", input); 85 | } 86 | 87 | string SharedStoreUtils::ToString(double input) { 88 | return tensorflow::strings::Printf("%a", input); 89 | } 90 | 91 | } // namespace syntaxnet 92 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/sparse.proto: -------------------------------------------------------------------------------- 1 | // Protocol for passing around sparse sets of features. 2 | 3 | syntax = "proto2"; 4 | 5 | package syntaxnet; 6 | 7 | // A sparse set of features. 8 | // 9 | // If using SparseStringToIdTransformer, description is required and id should 10 | // be omitted; otherwise, id is required and description optional. 11 | // 12 | // id, weight, and description fields are all aligned if present (ie, any of 13 | // these that are non-empty should have the same # items). If weight is omitted, 14 | // 1.0 is used. 15 | message SparseFeatures { 16 | repeated uint64 id = 1; 17 | repeated float weight = 2; 18 | repeated string description = 3; 19 | }; 20 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/task_context.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef SYNTAXNET_TASK_CONTEXT_H_ 17 | #define SYNTAXNET_TASK_CONTEXT_H_ 18 | 19 | #include 20 | #include 21 | 22 | #include "syntaxnet/task_spec.pb.h" 23 | #include "syntaxnet/utils.h" 24 | 25 | namespace syntaxnet { 26 | 27 | // A task context holds configuration information for a task. It is basically a 28 | // wrapper around a TaskSpec protocol buffer. 29 | class TaskContext { 30 | public: 31 | // Returns the underlying task specification protocol buffer for the context. 32 | const TaskSpec &spec() const { return spec_; } 33 | TaskSpec *mutable_spec() { return &spec_; } 34 | 35 | // Returns a named input descriptor for the task. A new input is created if 36 | // the task context does not already have an input with that name. 37 | TaskInput *GetInput(const string &name); 38 | TaskInput *GetInput(const string &name, const string &file_format, 39 | const string &record_format); 40 | 41 | // Sets task parameter. 42 | void SetParameter(const string &name, const string &value); 43 | 44 | // Returns task parameter. If the parameter is not in the task configuration 45 | // the (default) value of the corresponding command line flag is returned. 46 | string GetParameter(const string &name) const; 47 | int GetIntParameter(const string &name) const; 48 | int64 GetInt64Parameter(const string &name) const; 49 | bool GetBoolParameter(const string &name) const; 50 | double GetFloatParameter(const string &name) const; 51 | 52 | // Returns task parameter. If the parameter is not in the task configuration 53 | // the default value is returned. Parameters retrieved using these methods 54 | // don't need to be defined with a DEFINE_*() macro. 
55 | string Get(const string &name, const string &defval) const; 56 | string Get(const string &name, const char *defval) const; 57 | int Get(const string &name, int defval) const; 58 | int64 Get(const string &name, int64 defval) const; 59 | double Get(const string &name, double defval) const; 60 | bool Get(const string &name, bool defval) const; 61 | 62 | // Returns input file name for a single-file task input. 63 | static string InputFile(const TaskInput &input); 64 | 65 | // Returns true if task input supports the file and record format. 66 | static bool Supports(const TaskInput &input, const string &file_format, 67 | const string &record_format); 68 | 69 | private: 70 | // Underlying task specification protocol buffer. 71 | TaskSpec spec_; 72 | 73 | // Vector of parameters required by this task. These must be specified in the 74 | // task rather than relying on default values. 75 | vector required_parameters_; 76 | }; 77 | 78 | } // namespace syntaxnet 79 | 80 | #endif // SYNTAXNET_TASK_CONTEXT_H_ 81 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/task_spec.proto: -------------------------------------------------------------------------------- 1 | // LINT: ALLOW_GROUPS 2 | // Protocol buffer specifications for task configuration. 3 | 4 | syntax = "proto2"; 5 | 6 | package syntaxnet; 7 | 8 | // Task input descriptor. 9 | message TaskInput { 10 | // Name of input resource. 11 | required string name = 1; 12 | 13 | // Name of stage responsible of creating this resource. 14 | optional string creator = 2; 15 | 16 | // File format for resource. 17 | repeated string file_format = 3; 18 | 19 | // Record format for resource. 20 | repeated string record_format = 4; 21 | 22 | // Is this resource multi-file? 23 | optional bool multi_file = 5 [default = false]; 24 | 25 | // An input can consist of multiple file sets. 26 | repeated group Part = 6 { 27 | // File pattern for file set. 
28 | optional string file_pattern = 7; 29 | 30 | // File format for file set. 31 | optional string file_format = 8; 32 | 33 | // Record format for file set. 34 | optional string record_format = 9; 35 | } 36 | } 37 | 38 | // Task output descriptor. 39 | message TaskOutput { 40 | // Name of output resource. 41 | required string name = 1; 42 | 43 | // File format for output resource. 44 | optional string file_format = 2; 45 | 46 | // Record format for output resource. 47 | optional string record_format = 3; 48 | 49 | // Number of shards in output. If it is different from zero this output is 50 | // sharded. If the number of shards is set to -1 this means that the output is 51 | // sharded, but the number of shard is unknown. The files are then named 52 | // 'base-*-of-*'. 53 | optional int32 shards = 4 [default = 0]; 54 | 55 | // Base file name for output resource. If this is not set by the task 56 | // component it is set to a default value by the workflow engine. 57 | optional string file_base = 5; 58 | 59 | // Optional extension added to the file name. 60 | optional string file_extension = 6; 61 | } 62 | 63 | // A task specification is used for describing executing parameters. 64 | message TaskSpec { 65 | // Name of task. 66 | optional string task_name = 1; 67 | 68 | // Workflow task type. 69 | optional string task_type = 2; 70 | 71 | // Task parameters. 72 | repeated group Parameter = 3 { 73 | required string name = 4; 74 | optional string value = 5; 75 | } 76 | 77 | // Task inputs. 78 | repeated TaskInput input = 6; 79 | 80 | // Task outputs. 81 | repeated TaskOutput output = 7; 82 | } 83 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/test_main.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. All Rights Reserved. 
2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | // A program with a main that is suitable for unittests, including those 17 | // that also define microbenchmarks. Based on whether the user specified 18 | // the --benchmark_filter flag which specifies which benchmarks to run, 19 | // we will either run benchmarks or run the gtest tests in the program. 20 | 21 | #include "tensorflow/core/platform/platform.h" 22 | #include "tensorflow/core/platform/types.h" 23 | 24 | #if defined(PLATFORM_GOOGLE) || defined(__ANDROID__) 25 | 26 | // main() is supplied by gunit_main 27 | #else 28 | #include "gtest/gtest.h" 29 | #include "tensorflow/core/lib/core/stringpiece.h" 30 | #include "tensorflow/core/platform/test_benchmark.h" 31 | 32 | GTEST_API_ int main(int argc, char **argv) { 33 | std::cout << "Running main() from test_main.cc\n"; 34 | 35 | testing::InitGoogleTest(&argc, argv); 36 | for (int i = 1; i < argc; i++) { 37 | if (tensorflow::StringPiece(argv[i]).starts_with("--benchmarks=")) { 38 | const char *pattern = argv[i] + strlen("--benchmarks="); 39 | tensorflow::testing::Benchmark::Run(pattern); 40 | return 0; 41 | } 42 | } 43 | return RUN_ALL_TESTS(); 44 | } 45 | #endif 46 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/testdata/context.pbtxt: 
-------------------------------------------------------------------------------- 1 | Parameter { 2 | name: 'brain_parser_embedding_dims' 3 | value: '8;8;8' 4 | } 5 | Parameter { 6 | name: 'brain_parser_features' 7 | value: 'input.token.word input(1).token.word input(2).token.word stack.token.word stack(1).token.word stack(2).token.word;input.tag input(1).tag input(2).tag stack.tag stack(1).tag stack(2).tag;stack.child(1).label stack.child(1).sibling(-1).label stack.child(-1).label stack.child(-1).sibling(1).label' 8 | } 9 | Parameter { 10 | name: 'brain_parser_embedding_names' 11 | value: 'words;tags;labels' 12 | } 13 | input { 14 | name: 'training-corpus' 15 | record_format: 'conll-sentence' 16 | Part { 17 | file_pattern: 'syntaxnet/testdata/mini-training-set' 18 | } 19 | } 20 | input { 21 | name: 'tuning-corpus' 22 | record_format: 'conll-sentence' 23 | Part { 24 | file_pattern: 'syntaxnet/testdata/mini-training-set' 25 | } 26 | } 27 | input { 28 | name: 'parsed-tuning-corpus' 29 | creator: 'brain_parser/greedy' 30 | record_format: 'conll-sentence' 31 | } 32 | input { 33 | name: 'label-map' 34 | file_format: 'text' 35 | Part { 36 | file_pattern: 'OUTPATH/label-map' 37 | } 38 | } 39 | input { 40 | name: 'word-map' 41 | Part { 42 | file_pattern: 'OUTPATH/word-map' 43 | } 44 | } 45 | input { 46 | name: 'lcword-map' 47 | Part { 48 | file_pattern: 'OUTPATH/lcword-map' 49 | } 50 | } 51 | input { 52 | name: 'tag-map' 53 | Part { 54 | file_pattern: 'OUTPATH/tag-map' 55 | } 56 | } 57 | input { 58 | name: 'category-map' 59 | Part { 60 | file_pattern: 'OUTPATH/category-map' 61 | } 62 | } 63 | input { 64 | name: 'char-map' 65 | Part { 66 | file_pattern: 'OUTPATH/char-map' 67 | } 68 | } 69 | input { 70 | name: 'prefix-table' 71 | Part { 72 | file_pattern: 'OUTPATH/prefix-table' 73 | } 74 | } 75 | input { 76 | name: 'suffix-table' 77 | Part { 78 | file_pattern: 'OUTPATH/suffix-table' 79 | } 80 | } 81 | input { 82 | name: 'tag-to-category' 83 | Part { 84 | file_pattern: 
'OUTPATH/tag-to-category' 85 | } 86 | } 87 | input { 88 | name: 'stdout' 89 | record_format: 'conll-sentence' 90 | Part { 91 | file_pattern: '-' 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/testdata/document: -------------------------------------------------------------------------------- 1 | text : "I can not recall any disorder in currency markets since the 1974 guidelines were adopted ." 2 | token: { 3 | word : "I" 4 | start : 0 5 | end : 0 6 | head : 3 7 | tag : "PRP" 8 | category: "PRON" 9 | label : "nsubj" 10 | break_level : SENTENCE_BREAK 11 | } 12 | token: { 13 | word : "can" 14 | start : 2 15 | end : 4 16 | head : 3 17 | tag : "MD" 18 | category: "VERB" 19 | label : "aux" 20 | } 21 | token: { 22 | word : "not" 23 | start : 6 24 | end : 8 25 | head : 3 26 | tag : "RB" 27 | category: "ADV" 28 | label : "neg" 29 | } 30 | token: { 31 | word : "recall" 32 | start : 10 33 | end : 15 34 | tag : "VB" 35 | category: "VERB" 36 | label : "ROOT" 37 | } 38 | token: { 39 | word : "any" 40 | start : 17 41 | end : 19 42 | head : 5 43 | tag : "DT" 44 | category: "DET" 45 | label : "det" 46 | } 47 | token: { 48 | word : "disorder" 49 | start : 21 50 | end : 28 51 | head : 3 52 | tag : "NN" 53 | category: "NOUN" 54 | label : "dobj" 55 | } 56 | token: { 57 | word : "in" 58 | start : 30 59 | end : 31 60 | head : 5 61 | tag : "IN" 62 | category: "ADP" 63 | label : "prep" 64 | } 65 | token: { 66 | word : "currency" 67 | start : 33 68 | end : 40 69 | head : 8 70 | tag : "NN" 71 | category: "NOUN" 72 | label : "nn" 73 | } 74 | token: { 75 | word : "markets" 76 | start : 42 77 | end : 48 78 | head : 6 79 | tag : "NNS" 80 | category: "NOUN" 81 | label : "pobj" 82 | } 83 | token: { 84 | word : "since" 85 | start : 50 86 | end : 54 87 | head : 14 88 | tag : "IN" 89 | category: "ADP" 90 | label : "mark" 91 | } 92 | token: { 93 | word : "the" 94 | start : 56 95 | end : 58 96 | head : 12 97 | tag : "DT" 98 | 
category: "DET" 99 | label : "det" 100 | } 101 | token: { 102 | word : "1974" 103 | start : 60 104 | end : 63 105 | head : 12 106 | tag : "CD" 107 | category: "NUM" 108 | label : "num" 109 | } 110 | token: { 111 | word : "guidelines" 112 | start : 65 113 | end : 74 114 | head : 14 115 | tag : "NNS" 116 | category: "NOUN" 117 | label : "nsubjpass" 118 | } 119 | token: { 120 | word : "were" 121 | start : 76 122 | end : 79 123 | head : 14 124 | tag : "VBD" 125 | category: "VERB" 126 | label : "auxpass" 127 | } 128 | token: { 129 | word : "adopted" 130 | start : 81 131 | end : 87 132 | head : 3 133 | tag : "VBN" 134 | category: "VERB" 135 | label : "advcl" 136 | } 137 | token: { 138 | word : "." 139 | start : 89 140 | end : 89 141 | head : 3 142 | tag : "." 143 | category: "." 144 | label : "p" 145 | } 146 | -------------------------------------------------------------------------------- /syntaxnet/syntaxnet/workspace.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "syntaxnet/workspace.h" 17 | 18 | #include "tensorflow/core/lib/strings/strcat.h" 19 | 20 | namespace syntaxnet { 21 | 22 | string WorkspaceRegistry::DebugString() const { 23 | string str; 24 | for (auto &it : workspace_names_) { 25 | const string &type_name = workspace_types_.at(it.first); 26 | for (size_t index = 0; index < it.second.size(); ++index) { 27 | const string &workspace_name = it.second[index]; 28 | tensorflow::strings::StrAppend(&str, "\n ", type_name, " :: ", 29 | workspace_name); 30 | } 31 | } 32 | return str; 33 | } 34 | 35 | VectorIntWorkspace::VectorIntWorkspace(int size) : elements_(size) {} 36 | 37 | VectorIntWorkspace::VectorIntWorkspace(int size, int value) 38 | : elements_(size, value) {} 39 | 40 | VectorIntWorkspace::VectorIntWorkspace(const vector &elements) 41 | : elements_(elements) {} 42 | 43 | string VectorIntWorkspace::TypeName() { return "Vector"; } 44 | 45 | VectorVectorIntWorkspace::VectorVectorIntWorkspace(int size) 46 | : elements_(size) {} 47 | 48 | string VectorVectorIntWorkspace::TypeName() { return "VectorVector"; } 49 | 50 | } // namespace syntaxnet 51 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | 3 | cc_library( 4 | name = "utf", 5 | srcs = [ 6 | "rune.c", 7 | "runestrcat.c", 8 | "runestrchr.c", 9 | "runestrcmp.c", 10 | "runestrcpy.c", 11 | "runestrdup.c", 12 | "runestrecpy.c", 13 | "runestrlen.c", 14 | "runestrncat.c", 15 | "runestrncmp.c", 16 | "runestrncpy.c", 17 | "runestrrchr.c", 18 | "runestrstr.c", 19 | "runetype.c", 20 | "utfecpy.c", 21 | "utflen.c", 22 | "utfnlen.c", 23 | "utfrrune.c", 24 | "utfrune.c", 25 | "utfutf.c", 26 | ], 27 | hdrs = [ 28 | "runetypebody.c", 29 | "utf.h", 30 | "utfdef.h", 31 | ], 32 | includes = ["."], 33 | 
visibility = ["//visibility:public"], 34 | ) 35 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/README: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 1998-2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 13 | */ 14 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/runestrcat.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. 
IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | Rune* 20 | runestrcat(Rune *s1, const Rune *s2) 21 | { 22 | 23 | runestrcpy((Rune*)runestrchr(s1, 0), s2); 24 | return s1; 25 | } 26 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/runestrchr.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 
13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | const 20 | Rune* 21 | runestrchr(const Rune *s, Rune c) 22 | { 23 | Rune c0 = c; 24 | Rune c1; 25 | 26 | if(c == 0) { 27 | while(*s++) 28 | ; 29 | return s-1; 30 | } 31 | 32 | while((c1 = *s++) != 0) 33 | if(c1 == c0) 34 | return s-1; 35 | return 0; 36 | } 37 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/runestrcmp.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 
13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | int 20 | runestrcmp(const Rune *s1, const Rune *s2) 21 | { 22 | Rune c1, c2; 23 | 24 | for(;;) { 25 | c1 = *s1++; 26 | c2 = *s2++; 27 | if(c1 != c2) { 28 | if(c1 > c2) 29 | return 1; 30 | return -1; 31 | } 32 | if(c1 == 0) 33 | return 0; 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/runestrcpy.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | Rune* 20 | runestrcpy(Rune *s1, const Rune *s2) 21 | { 22 | Rune *os1; 23 | 24 | os1 = s1; 25 | while((*s1++ = *s2++) != 0) 26 | ; 27 | return os1; 28 | } 29 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/runestrdup.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 
4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 13 | */ 14 | #include 15 | #include 16 | #include 17 | #include "third_party/utf/utf.h" 18 | #include "third_party/utf/utfdef.h" 19 | 20 | Rune* 21 | runestrdup(const Rune *s) 22 | { 23 | Rune *ns; 24 | 25 | ns = (Rune*)malloc(sizeof(Rune)*(runestrlen(s) + 1)); 26 | if(ns == 0) 27 | return 0; 28 | 29 | return runestrcpy(ns, s); 30 | } 31 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/runestrecpy.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 
13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | Rune* 20 | runestrecpy(Rune *s1, Rune *es1, const Rune *s2) 21 | { 22 | if(s1 >= es1) 23 | return s1; 24 | 25 | while((*s1++ = *s2++) != 0){ 26 | if(s1 == es1){ 27 | *--s1 = '\0'; 28 | break; 29 | } 30 | } 31 | return s1; 32 | } 33 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/runestrlen.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | long 20 | runestrlen(const Rune *s) 21 | { 22 | 23 | return runestrchr(s, 0) - s; 24 | } 25 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/runestrncat.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 
4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | Rune* 20 | runestrncat(Rune *s1, const Rune *s2, long n) 21 | { 22 | Rune *os1; 23 | 24 | os1 = s1; 25 | s1 = (Rune*)runestrchr(s1, 0); 26 | while((*s1++ = *s2++) != 0) 27 | if(--n < 0) { 28 | s1[-1] = 0; 29 | break; 30 | } 31 | return os1; 32 | } 33 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/runestrncmp.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. 
IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | int 20 | runestrncmp(const Rune *s1, const Rune *s2, long n) 21 | { 22 | Rune c1, c2; 23 | 24 | while(n > 0) { 25 | c1 = *s1++; 26 | c2 = *s2++; 27 | n--; 28 | if(c1 != c2) { 29 | if(c1 > c2) 30 | return 1; 31 | return -1; 32 | } 33 | if(c1 == 0) 34 | break; 35 | } 36 | return 0; 37 | } 38 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/runestrncpy.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 
13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | Rune* 20 | runestrncpy(Rune *s1, const Rune *s2, long n) 21 | { 22 | int i; 23 | Rune *os1; 24 | 25 | os1 = s1; 26 | for(i = 0; i < n; i++) 27 | if((*s1++ = *s2++) == 0) { 28 | while(++i < n) 29 | *s1++ = 0; 30 | return os1; 31 | } 32 | return os1; 33 | } 34 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/runestrrchr.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | const 20 | Rune* 21 | runestrrchr(const Rune *s, Rune c) 22 | { 23 | const Rune *r; 24 | 25 | if(c == 0) 26 | return runestrchr(s, 0); 27 | r = 0; 28 | while((s = runestrchr(s, c)) != 0) 29 | r = s++; 30 | return r; 31 | } 32 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/runestrstr.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 
3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | /* 20 | * Return pointer to first occurrence of s2 in s1, 21 | * 0 if none 22 | */ 23 | const 24 | Rune* 25 | runestrstr(const Rune *s1, const Rune *s2) 26 | { 27 | const Rune *p, *pa, *pb; 28 | int c0, c; 29 | 30 | c0 = *s2; 31 | if(c0 == 0) 32 | return s1; 33 | s2++; 34 | for(p=runestrchr(s1, c0); p; p=runestrchr(p+1, c0)) { 35 | pa = p; 36 | for(pb=s2;; pb++) { 37 | c = *pb; 38 | if(c == 0) 39 | return p; 40 | if(c != *++pa) 41 | break; 42 | } 43 | } 44 | return 0; 45 | } 46 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/runetype.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 
9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 13 | */ 14 | #include "third_party/utf/utf.h" 15 | #include "third_party/utf/utfdef.h" 16 | 17 | static 18 | Rune* 19 | rbsearch(Rune c, Rune *t, int n, int ne) 20 | { 21 | Rune *p; 22 | int m; 23 | 24 | while(n > 1) { 25 | m = n >> 1; 26 | p = t + m*ne; 27 | if(c >= p[0]) { 28 | t = p; 29 | n = n-m; 30 | } else 31 | n = m; 32 | } 33 | if(n && c >= t[0]) 34 | return t; 35 | return 0; 36 | } 37 | 38 | /* 39 | * The "ideographic" property is hard to extract from UnicodeData.txt, 40 | * so it is hard coded here. 41 | * 42 | * It is defined in the Unicode PropList.txt file, for example 43 | * PropList-3.0.0.txt. Unlike the UnicodeData.txt file, the format of 44 | * PropList changes between versions. This property appears relatively static; 45 | * it is the same in version 4.0.1, except that version defines some >16 bit 46 | * chars as ideographic as well: 20000..2a6d6, and 2f800..2Fa1d. 47 | */ 48 | static Rune __isideographicr[] = { 49 | 0x3006, 0x3007, /* 3006 not in Unicode 2, in 2.1 */ 50 | 0x3021, 0x3029, 51 | 0x3038, 0x303a, /* not in Unicode 2 or 2.1 */ 52 | 0x3400, 0x4db5, /* not in Unicode 2 or 2.1 */ 53 | 0x4e00, 0x9fbb, /* 0x9FA6..0x9FBB added for 4.1.0? 
*/ 54 | 0xf900, 0xfa2d, 55 | 0x20000, 0x2A6D6, 56 | 0x2F800, 0x2FA1D, 57 | }; 58 | 59 | int 60 | isideographicrune(Rune c) 61 | { 62 | Rune *p; 63 | 64 | p = rbsearch(c, __isideographicr, nelem(__isideographicr)/2, 2); 65 | if(p && c >= p[0] && c <= p[1]) 66 | return 1; 67 | return 0; 68 | } 69 | 70 | #include "third_party/utf/runetypebody.c" 71 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/utfdef.h: -------------------------------------------------------------------------------- 1 | #define uchar _utfuchar 2 | #define ushort _utfushort 3 | #define uint _utfuint 4 | #define ulong _utfulong 5 | #define vlong _utfvlong 6 | #define uvlong _utfuvlong 7 | 8 | typedef unsigned char uchar; 9 | typedef unsigned short ushort; 10 | typedef unsigned int uint; 11 | typedef unsigned long ulong; 12 | 13 | #define nelem(x) (sizeof(x)/sizeof((x)[0])) 14 | #define nil ((void*)0) 15 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/utfecpy.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 
13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | char* 20 | utfecpy(char *to, char *e, const char *from) 21 | { 22 | char *end; 23 | 24 | if(to >= e) 25 | return to; 26 | end = (char*)memccpy(to, from, '\0', e - to); 27 | if(end == nil){ 28 | end = e-1; 29 | while(end>to && (*--end&0xC0)==0x80) 30 | ; 31 | *end = '\0'; 32 | }else{ 33 | end--; 34 | } 35 | return end; 36 | } 37 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/utflen.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 
13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | int 20 | utflen(const char *s) 21 | { 22 | int c; 23 | long n; 24 | Rune rune; 25 | 26 | n = 0; 27 | for(;;) { 28 | c = *(uchar*)s; 29 | if(c < Runeself) { 30 | if(c == 0) 31 | return n; 32 | s++; 33 | } else 34 | s += chartorune(&rune, s); 35 | n++; 36 | } 37 | return 0; 38 | } 39 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/utfnlen.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 
13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | int 20 | utfnlen(const char *s, long m) 21 | { 22 | int c; 23 | long n; 24 | Rune rune; 25 | const char *es; 26 | 27 | es = s + m; 28 | for(n = 0; s < es; n++) { 29 | c = *(uchar*)s; 30 | if(c < Runeself){ 31 | if(c == '\0') 32 | break; 33 | s++; 34 | continue; 35 | } 36 | if(!fullrune(s, es-s)) 37 | break; 38 | s += chartorune(&rune, s); 39 | } 40 | return n; 41 | } 42 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/utfrrune.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 
13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | const 20 | char* 21 | utfrrune(const char *s, Rune c) 22 | { 23 | long c1; 24 | Rune r; 25 | const char *s1; 26 | 27 | if(c < Runesync) /* not part of utf sequence */ 28 | return strrchr(s, c); 29 | 30 | s1 = 0; 31 | for(;;) { 32 | c1 = *(uchar*)s; 33 | if(c1 < Runeself) { /* one byte rune */ 34 | if(c1 == 0) 35 | return s1; 36 | if(c1 == c) 37 | s1 = s; 38 | s++; 39 | continue; 40 | } 41 | c1 = chartorune(&r, s); 42 | if(r == c) 43 | s1 = s; 44 | s += c1; 45 | } 46 | return 0; 47 | } 48 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/utfrune.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 
13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | const 20 | char* 21 | utfrune(const char *s, Rune c) 22 | { 23 | long c1; 24 | Rune r; 25 | int n; 26 | 27 | if(c < Runesync) /* not part of utf sequence */ 28 | return strchr(s, c); 29 | 30 | for(;;) { 31 | c1 = *(uchar*)s; 32 | if(c1 < Runeself) { /* one byte rune */ 33 | if(c1 == 0) 34 | return 0; 35 | if(c1 == c) 36 | return s; 37 | s++; 38 | continue; 39 | } 40 | n = chartorune(&r, s); 41 | if(r == c) 42 | return s; 43 | s += n; 44 | } 45 | return 0; 46 | } 47 | -------------------------------------------------------------------------------- /syntaxnet/third_party/utf/utfutf.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The authors of this software are Rob Pike and Ken Thompson. 3 | * Copyright (c) 2002 by Lucent Technologies. 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose without fee is hereby granted, provided that this entire notice 6 | * is included in all copies of any software which is or includes a copy 7 | * or modification of this software and in all copies of the supporting 8 | * documentation for such software. 9 | * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED 10 | * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY 11 | * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY 12 | * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 
13 | */ 14 | #include 15 | #include 16 | #include "third_party/utf/utf.h" 17 | #include "third_party/utf/utfdef.h" 18 | 19 | 20 | /* 21 | * Return pointer to first occurrence of s2 in s1, 22 | * 0 if none 23 | */ 24 | const 25 | char* 26 | utfutf(const char *s1, const char *s2) 27 | { 28 | const char *p; 29 | long f, n1, n2; 30 | Rune r; 31 | 32 | n1 = chartorune(&r, s2); 33 | f = r; 34 | if(f <= Runesync) /* represents self */ 35 | return strstr(s1, s2); 36 | 37 | n2 = strlen(s2); 38 | for(p=s1; (p=utfrune(p, f)) != 0; p+=n1) 39 | if(strncmp(p, s2, n2) == 0) 40 | return p; 41 | return 0; 42 | } 43 | -------------------------------------------------------------------------------- /syntaxnet/tools/bazel.rc: -------------------------------------------------------------------------------- 1 | build:cuda --crosstool_top=//third_party/gpus/crosstool 2 | 3 | build --define=use_fast_cpp_protos=true 4 | build --define=allow_oversize_protos=true 5 | build --copt -funsigned-char 6 | build -c opt 7 | 8 | build --spawn_strategy=standalone 9 | test --spawn_strategy=standalone 10 | run --spawn_strategy=standalone 11 | -------------------------------------------------------------------------------- /syntaxnet/util/utf8/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | 3 | # Requires --copt -funsigned-char when compiling (unsigned chars). 
4 | 5 | cc_library( 6 | name = "unicodetext", 7 | srcs = [ 8 | "unicodetext.cc", 9 | "unilib.cc", 10 | ], 11 | hdrs = [ 12 | "unicodetext.h", 13 | "unilib.h", 14 | "unilib_utf8_utils.h", 15 | ], 16 | visibility = ["//visibility:public"], 17 | deps = [ 18 | "//syntaxnet:base", 19 | "//third_party/utf", 20 | ], 21 | ) 22 | 23 | cc_test( 24 | name = "unicodetext_unittest", 25 | srcs = [ 26 | "gtest_main.cc", 27 | "unicodetext_unittest.cc", 28 | ], 29 | deps = [ 30 | "@org_tensorflow//tensorflow/core:testlib", 31 | ":unicodetext", 32 | ], 33 | ) 34 | 35 | cc_binary( 36 | name = "unicodetext_main", 37 | srcs = ["unicodetext_main.cc"], 38 | deps = [":unicodetext"], 39 | ) 40 | -------------------------------------------------------------------------------- /syntaxnet/util/utf8/gtest_main.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2010 Google Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | // Author: sligocki@google.com (Shawn Ligocki) 18 | // 19 | // Build all tests with this main to run all tests. 
20 | 21 | #include "gtest/gtest.h" 22 | 23 | int main(int argc, char **argv) { 24 | ::testing::InitGoogleTest(&argc, argv); 25 | return RUN_ALL_TESTS(); 26 | } 27 | -------------------------------------------------------------------------------- /syntaxnet/util/utf8/unicodetext_main.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2010 Google Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | // Author: sligocki@google.com (Shawn Ligocki) 18 | // 19 | // A basic main function to test that UnicodeText builds. 20 | 21 | #include 22 | #include 23 | 24 | #include 25 | 26 | #include "util/utf8/unicodetext.h" 27 | 28 | int main(int argc, char** argv) { 29 | if (argc > 1) { 30 | printf("Bytes:\n"); 31 | std::string bytes(argv[1]); 32 | for (std::string::const_iterator iter = bytes.begin(); 33 | iter < bytes.end(); ++iter) { 34 | printf(" 0x%02X\n", *iter); 35 | } 36 | 37 | printf("Unicode codepoints:\n"); 38 | UnicodeText text(UTF8ToUnicodeText(bytes)); 39 | for (UnicodeText::const_iterator iter = text.begin(); 40 | iter < text.end(); ++iter) { 41 | printf(" U+%X\n", *iter); 42 | } 43 | } 44 | return EXIT_SUCCESS; 45 | } 46 | -------------------------------------------------------------------------------- /syntaxnet/util/utf8/unilib.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2010 Google Inc. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | // Author: sligocki@google.com (Shawn Ligocki) 18 | 19 | #include "util/utf8/unilib.h" 20 | 21 | #include "syntaxnet/base.h" 22 | #include "third_party/utf/utf.h" 23 | 24 | namespace UniLib { 25 | 26 | // Codepoints not allowed for interchange are: 27 | // C0 (ASCII) controls: U+0000 to U+001F excluding Space (SP, U+0020), 28 | // Horizontal Tab (HT, U+0009), Line-Feed (LF, U+000A), 29 | // Form Feed (FF, U+000C) and Carriage-Return (CR, U+000D) 30 | // C1 controls: U+007F to U+009F 31 | // Surrogates: U+D800 to U+DFFF 32 | // Non-characters: U+FDD0 to U+FDEF and U+xxFFFE to U+xxFFFF for all xx 33 | bool IsInterchangeValid(char32 c) { 34 | return !((c >= 0x00 && c <= 0x08) || c == 0x0B || (c >= 0x0E && c <= 0x1F) || 35 | (c >= 0x7F && c <= 0x9F) || 36 | (c >= 0xD800 && c <= 0xDFFF) || 37 | (c >= 0xFDD0 && c <= 0xFDEF) || (c&0xFFFE) == 0xFFFE); 38 | } 39 | 40 | int SpanInterchangeValid(const char* begin, int byte_length) { 41 | char32 rune; 42 | const char* p = begin; 43 | const char* end = begin + byte_length; 44 | while (p < end) { 45 | int bytes_consumed = charntorune(&rune, p, end - p); 46 | // We want to accept Runeerror == U+FFFD as a valid char, but it is used 47 | // by chartorune to indicate error. Luckily, the real codepoint is size 3 48 | // while errors return bytes_consumed <= 1. 
49 | if ((rune == Runeerror && bytes_consumed <= 1) || 50 | !IsInterchangeValid(rune)) { 51 | break; // Found 52 | } 53 | p += bytes_consumed; 54 | } 55 | return p - begin; 56 | } 57 | 58 | } // namespace UniLib 59 | -------------------------------------------------------------------------------- /syntaxnet/util/utf8/unilib.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2010 Google Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | // Routines to do manipulation of Unicode characters or text 18 | // 19 | // The StructurallyValid routines accept buffers of arbitrary bytes. 20 | // For CoerceToStructurallyValid(), the input buffer and output buffers may 21 | // point to exactly the same memory. 22 | // 23 | // In all other cases, the UTF-8 string must be structurally valid and 24 | // have all codepoints in the range U+0000 to U+D7FF or U+E000 to U+10FFFF. 25 | // Debug builds take a fatal error for invalid UTF-8 input. 26 | // The input and output buffers may not overlap at all. 27 | // 28 | // The char32 routines are here only for convenience; they convert to UTF-8 29 | // internally and use the UTF-8 routines. 
30 | 31 | #ifndef UTIL_UTF8_UNILIB_H__ 32 | #define UTIL_UTF8_UNILIB_H__ 33 | 34 | #include 35 | #include "syntaxnet/base.h" 36 | 37 | // We export OneCharLen, IsValidCodepoint, and IsTrailByte from here, 38 | // but they are defined in unilib_utf8_utils.h. 39 | //#include "util/utf8/public/unilib_utf8_utils.h" // IWYU pragma: export 40 | 41 | namespace UniLib { 42 | 43 | // Returns the length in bytes of the prefix of src that is all 44 | // interchange valid UTF-8 45 | int SpanInterchangeValid(const char* src, int byte_length); 46 | inline int SpanInterchangeValid(const std::string& src) { 47 | return SpanInterchangeValid(src.data(), src.size()); 48 | } 49 | 50 | // Returns true if the source is all interchange valid UTF-8 51 | // "Interchange valid" is a stronger than structurally valid -- 52 | // no C0 or C1 control codes (other than CR LF HT FF) and no non-characters. 53 | bool IsInterchangeValid(char32 codepoint); 54 | inline bool IsInterchangeValid(const char* src, int byte_length) { 55 | return (byte_length == SpanInterchangeValid(src, byte_length)); 56 | } 57 | inline bool IsInterchangeValid(const std::string& src) { 58 | return IsInterchangeValid(src.data(), src.size()); 59 | } 60 | 61 | } // namespace UniLib 62 | 63 | #endif // UTIL_UTF8_PUBLIC_UNILIB_H_ 64 | -------------------------------------------------------------------------------- /syntaxnet/util/utf8/unilib_utf8_utils.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2010 Google Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef UTIL_UTF8_PUBLIC_UNILIB_UTF8_UTILS_H_ 18 | #define UTIL_UTF8_PUBLIC_UNILIB_UTF8_UTILS_H_ 19 | 20 | // These definitions are self-contained and have no dependencies. 21 | // They are also exported from unilib.h for legacy reasons. 22 | 23 | #include "syntaxnet/base.h" 24 | #include "third_party/utf/utf.h" 25 | 26 | namespace UniLib { 27 | 28 | // Returns true if 'c' is in the range [0, 0xD800) or [0xE000, 0x10FFFF] 29 | // (i.e., is not a surrogate codepoint). See also 30 | // IsValidCodepoint(const char* src) in util/utf8/public/unilib.h. 31 | inline bool IsValidCodepoint(char32 c) { 32 | return (static_cast(c) < 0xD800) 33 | || (c >= 0xE000 && c <= 0x10FFFF); 34 | } 35 | 36 | // Returns true if 'str' is the start of a structurally valid UTF-8 37 | // sequence and is not a surrogate codepoint. Returns false if str.empty() 38 | // or if str.length() < UniLib::OneCharLen(str[0]). Otherwise, this function 39 | // will access 1-4 bytes of src, where n is UniLib::OneCharLen(src[0]). 40 | inline bool IsUTF8ValidCodepoint(StringPiece str) { 41 | char32 c; 42 | int consumed; 43 | // It's OK if str.length() > consumed. 44 | return !str.empty() 45 | && isvalidcharntorune(str.data(), str.size(), &c, &consumed) 46 | && IsValidCodepoint(c); 47 | } 48 | 49 | // Returns the length (number of bytes) of the Unicode code point 50 | // starting at src, based on inspecting just that one byte. This 51 | // requires that src point to a well-formed UTF-8 string; the result 52 | // is undefined otherwise. 
53 | inline int OneCharLen(const char* src) { 54 | return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4]; 55 | } 56 | 57 | // Returns true if this byte is a trailing UTF-8 byte (10xx xxxx) 58 | inline bool IsTrailByte(char x) { 59 | // return (x & 0xC0) == 0x80; 60 | // Since trail bytes are always in [0x80, 0xBF], we can optimize: 61 | return static_cast(x) < -0x40; 62 | } 63 | 64 | } // namespace UniLib 65 | 66 | #endif // UTIL_UTF8_PUBLIC_UNILIB_UTF8_UTILS_H_ 67 | -------------------------------------------------------------------------------- /textsum/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = [":internal"]) 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | exports_files(["LICENSE"]) 6 | 7 | package_group( 8 | name = "internal", 9 | packages = [ 10 | "//textsum/...", 11 | ], 12 | ) 13 | 14 | py_library( 15 | name = "seq2seq_attention_model", 16 | srcs = ["seq2seq_attention_model.py"], 17 | deps = [ 18 | ":seq2seq_lib", 19 | ], 20 | ) 21 | 22 | py_library( 23 | name = "seq2seq_lib", 24 | srcs = ["seq2seq_lib.py"], 25 | ) 26 | 27 | py_binary( 28 | name = "seq2seq_attention", 29 | srcs = ["seq2seq_attention.py"], 30 | deps = [ 31 | ":batch_reader", 32 | ":data", 33 | ":seq2seq_attention_decode", 34 | ":seq2seq_attention_model", 35 | ], 36 | ) 37 | 38 | py_library( 39 | name = "batch_reader", 40 | srcs = ["batch_reader.py"], 41 | deps = [ 42 | ":data", 43 | ":seq2seq_attention_model", 44 | ], 45 | ) 46 | 47 | py_library( 48 | name = "beam_search", 49 | srcs = ["beam_search.py"], 50 | ) 51 | 52 | py_library( 53 | name = "seq2seq_attention_decode", 54 | srcs = ["seq2seq_attention_decode.py"], 55 | deps = [ 56 | ":beam_search", 57 | ":data", 58 | ], 59 | ) 60 | 61 | py_library( 62 | name = "data", 63 | srcs = ["data.py"], 64 | ) 65 | -------------------------------------------------------------------------------- /textsum/data/data: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/tf-models/43dad80002d008f8c0d70718a9c88dce5a56eb21/textsum/data/data -------------------------------------------------------------------------------- /textsum/data_convert_example.py: -------------------------------------------------------------------------------- 1 | """Example of Converting TextSum model data. 2 | Usage: 3 | python data_convert_example.py --command binary_to_text --in_file data/data --out_file data/text_data 4 | python data_convert_example.py --command text_to_binary --in_file data/text_data --out_file data/binary_data 5 | python data_convert_example.py --command binary_to_text --in_file data/binary_data --out_file data/text_data2 6 | diff data/text_data2 data/text_data 7 | """ 8 | 9 | import struct 10 | import sys 11 | 12 | import tensorflow as tf 13 | from tensorflow.core.example import example_pb2 14 | 15 | FLAGS = tf.app.flags.FLAGS 16 | tf.app.flags.DEFINE_string('command', 'binary_to_text', 17 | 'Either binary_to_text or text_to_binary.' 
18 | 'Specify FLAGS.in_file accordingly.') 19 | tf.app.flags.DEFINE_string('in_file', '', 'path to file') 20 | tf.app.flags.DEFINE_string('out_file', '', 'path to file') 21 | 22 | def _binary_to_text(): 23 | reader = open(FLAGS.in_file, 'rb') 24 | writer = open(FLAGS.out_file, 'w') 25 | while True: 26 | len_bytes = reader.read(8) 27 | if not len_bytes: 28 | sys.stderr.write('Done reading\n') 29 | return 30 | str_len = struct.unpack('q', len_bytes)[0] 31 | tf_example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0] 32 | tf_example = example_pb2.Example.FromString(tf_example_str) 33 | examples = [] 34 | for key in tf_example.features.feature: 35 | examples.append('%s=%s' % (key, tf_example.features.feature[key].bytes_list.value[0])) 36 | writer.write('%s\n' % '\t'.join(examples)) 37 | reader.close() 38 | writer.close() 39 | 40 | 41 | def _text_to_binary(): 42 | inputs = open(FLAGS.in_file, 'r').readlines() 43 | writer = open(FLAGS.out_file, 'wb') 44 | for inp in inputs: 45 | tf_example = example_pb2.Example() 46 | for feature in inp.strip().split('\t'): 47 | (k, v) = feature.split('=') 48 | tf_example.features.feature[k].bytes_list.value.extend([v]) 49 | tf_example_str = tf_example.SerializeToString() 50 | str_len = len(tf_example_str) 51 | writer.write(struct.pack('q', str_len)) 52 | writer.write(struct.pack('%ds' % str_len, tf_example_str)) 53 | writer.close() 54 | 55 | 56 | def main(unused_argv): 57 | assert FLAGS.command and FLAGS.in_file and FLAGS.out_file 58 | if FLAGS.command == 'binary_to_text': 59 | _binary_to_text() 60 | elif FLAGS.command == 'text_to_binary': 61 | _text_to_binary() 62 | 63 | 64 | if __name__ == '__main__': 65 | tf.app.run() 66 | -------------------------------------------------------------------------------- /transformer/README.md: -------------------------------------------------------------------------------- 1 | # Spatial Transformer Network 2 | 3 | The Spatial Transformer Network [1] allows the spatial manipulation of data 
within the network. 4 | 5 |
6 |

7 |
8 | 9 | ### API 10 | 11 | A Spatial Transformer Network implemented in Tensorflow 0.7 and based on [2]. 12 | 13 | #### How to use 14 | 15 |
16 |

17 |
18 | 19 | ```python 20 | transformer(U, theta, out_size) 21 | ``` 22 | 23 | #### Parameters 24 | 25 | U : float 26 | The output of a convolutional net should have the 27 | shape [num_batch, height, width, num_channels]. 28 | theta: float 29 | The output of the 30 | localisation network should be [num_batch, 6]. 31 | out_size: tuple of two ints 32 | The size of the output of the network 33 | 34 | 35 | #### Notes 36 | To initialize the network to the identity transform init ``theta`` to : 37 | 38 | ```python 39 | identity = np.array([[1., 0., 0.], 40 | [0., 1., 0.]]) 41 | identity = identity.flatten() 42 | theta = tf.Variable(initial_value=identity) 43 | ``` 44 | 45 | #### Experiments 46 | 47 |
48 |

49 |
50 | 51 | We used cluttered MNIST. Left column are the input images, right are the attended parts of the image by an STN. 52 | 53 | All experiments were run in Tensorflow 0.7. 54 | 55 | ### References 56 | 57 | [1] Jaderberg, Max, et al. "Spatial Transformer Networks." arXiv preprint arXiv:1506.02025 (2015) 58 | 59 | [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py 60 | -------------------------------------------------------------------------------- /transformer/data/README.md: -------------------------------------------------------------------------------- 1 | ### How to get the data 2 | 3 | #### Cluttered MNIST 4 | 5 | The cluttered MNIST dataset can be found here [1] or can be generated via [2]. 6 | 7 | Settings used for `cluttered_mnist.py` : 8 | 9 | ```python 10 | 11 | ORG_SHP = [28, 28] 12 | OUT_SHP = [40, 40] 13 | NUM_DISTORTIONS = 8 14 | dist_size = (5, 5) 15 | 16 | ``` 17 | 18 | [1] https://github.com/daviddao/spatial-transformer-tensorflow 19 | 20 | [2] https://github.com/skaae/recurrent-spatial-transformer-code/blob/master/MNIST_SEQUENCE/create_mnist_sequence.py -------------------------------------------------------------------------------- /transformer/example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | from scipy import ndimage 16 | import tensorflow as tf 17 | from spatial_transformer import transformer 18 | import numpy as np 19 | import matplotlib.pyplot as plt 20 | 21 | # %% Create a batch of three images (1600 x 1200) 22 | # %% Image retrieved from: 23 | # %% https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg 24 | im = ndimage.imread('cat.jpg') 25 | im = im / 255. 26 | im = im.reshape(1, 1200, 1600, 3) 27 | im = im.astype('float32') 28 | 29 | # %% Let the output size of the transformer be half the image size. 30 | out_size = (600, 800) 31 | 32 | # %% Simulate batch 33 | batch = np.append(im, im, axis=0) 34 | batch = np.append(batch, im, axis=0) 35 | num_batch = 3 36 | 37 | x = tf.placeholder(tf.float32, [None, 1200, 1600, 3]) 38 | x = tf.cast(batch, 'float32') 39 | 40 | # %% Create localisation network and convolutional layer 41 | with tf.variable_scope('spatial_transformer_0'): 42 | 43 | # %% Create a fully-connected layer with 6 output nodes 44 | n_fc = 6 45 | W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1') 46 | 47 | # %% Zoom into the image 48 | initial = np.array([[0.5, 0, 0], [0, 0.5, 0]]) 49 | initial = initial.astype('float32') 50 | initial = initial.flatten() 51 | 52 | b_fc1 = tf.Variable(initial_value=initial, name='b_fc1') 53 | h_fc1 = tf.matmul(tf.zeros([num_batch, 1200 * 1600 * 3]), W_fc1) + b_fc1 54 | h_trans = transformer(x, h_fc1, out_size) 55 | 56 | # %% Run session 57 | sess = tf.Session() 58 | sess.run(tf.initialize_all_variables()) 59 | y = sess.run(h_trans, feed_dict={x: batch}) 60 | 61 | # plt.imshow(y[0]) 62 | -------------------------------------------------------------------------------- /video_prediction/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2016 The TensorFlow Authors All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | 18 | # Example: 19 | # 20 | # download_dataset.sh datafiles.txt ./tmp 21 | # 22 | # will download all of the files listed in the file, datafiles.txt, into 23 | # a directory, "./tmp". 24 | # 25 | # Each line of the datafiles.txt file should contain the path from the 26 | # bucket root to a file. 27 | 28 | ARGC="$#" 29 | LISTING_FILE=push_datafiles.txt 30 | if [ "${ARGC}" -ge 1 ]; then 31 | LISTING_FILE=$1 32 | fi 33 | OUTPUT_DIR="./" 34 | if [ "${ARGC}" -ge 2 ]; then 35 | OUTPUT_DIR=$2 36 | fi 37 | 38 | echo "OUTPUT_DIR=$OUTPUT_DIR" 39 | 40 | mkdir "${OUTPUT_DIR}" 41 | 42 | function download_file { 43 | FILE=$1 44 | BUCKET="https://storage.googleapis.com/brain-robotics-data" 45 | URL="${BUCKET}/${FILE}" 46 | OUTPUT_FILE="${OUTPUT_DIR}/${FILE}" 47 | DIRECTORY=`dirname ${OUTPUT_FILE}` 48 | echo DIRECTORY=$DIRECTORY 49 | mkdir -p "${DIRECTORY}" 50 | curl --output ${OUTPUT_FILE} ${URL} 51 | } 52 | 53 | while read filename; do 54 | download_file $filename 55 | done <${LISTING_FILE} 56 | --------------------------------------------------------------------------------