├── .gitignore
├── Dockerfile
├── LICENSE
├── NOTICE
├── README.md
├── examples
│   ├── predict.py
│   └── speedup.py
├── requirements.txt
├── setup.py
└── threaded_estimator
    ├── __init__.py
    ├── iris_data.py
    ├── models.py
    └── tests
        ├── __init__.py
        └── test_flower_estimator.py

/.gitignore:
--------------------------------------------------------------------------------
.cache
trained_models

--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM python:3.6.2

WORKDIR /eai

COPY ./requirements.txt ./requirements.txt
RUN pip install -r requirements.txt

COPY ./threaded_estimator ./threaded_estimator
COPY ./setup.py ./setup.py
RUN pip install .

CMD bash -c "pytest -s --full-trace /eai/threaded_estimator/tests"

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and
distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the
copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other
entities that control, are controlled by, or are under common control with
that entity. For the purposes of this definition, "control" means (i) the
power, direct or indirect, to cause the direction or management of such
entity, whether by contract or otherwise, or (ii) ownership of fifty percent
(50%) or more of the outstanding shares, or (iii) beneficial ownership of
such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising
permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation source,
and configuration files.

"Object" form shall mean any form resulting from mechanical transformation
or translation of a Source form, including but not limited to compiled
object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form,
made available under the License, as indicated by a copyright notice that is
included in or attached to the work (an example is provided in the Appendix
below).

"Derivative Works" shall mean any work, whether in Source or Object form,
that is based on (or derived from) the Work and for which the editorial
revisions, annotations, elaborations, or other modifications represent, as
a whole, an original work of authorship. For the purposes of this License,
Derivative Works shall not include works that remain separable from, or
merely link (or bind by name) to the interfaces of, the Work and Derivative
Works thereof.

"Contribution" shall mean any work of authorship, including the original
version of the Work and any modifications or additions to that Work or
Derivative Works thereof, that is intentionally submitted to Licensor for
inclusion in the Work by the copyright owner or by an individual or Legal
Entity authorized to submit on behalf of the copyright owner. For the
purposes of this definition, "submitted" means any form of electronic,
verbal, or written communication sent to the Licensor or its
representatives, including but not limited to communication on electronic
mailing lists, source code control systems, and issue tracking systems that
are managed by, or on behalf of, the Licensor for the purpose of discussing
and improving the Work, but excluding communication that is conspicuously
marked or otherwise designated in writing by the copyright owner as "Not a
Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on
behalf of whom a Contribution has been received by Licensor and subsequently
incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this
License, each Contributor hereby grants to You a perpetual, worldwide,
non-exclusive, no-charge, royalty-free, irrevocable copyright license to
reproduce, prepare Derivative Works of, publicly display, publicly perform,
sublicense, and distribute the Work and such Derivative Works in Source or
Object form.

3. Grant of Patent License. Subject to the terms and conditions of this
License, each Contributor hereby grants to You a perpetual, worldwide,
non-exclusive, no-charge, royalty-free, irrevocable (except as stated in
this section) patent license to make, have made, use, offer to sell, sell,
import, and otherwise transfer the Work, where such license applies only to
those patent claims licensable by such Contributor that are necessarily
infringed by their Contribution(s) alone or by combination of their
Contribution(s) with the Work to which such Contribution(s) was submitted.
If You institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work or a
Contribution incorporated within the Work constitutes direct or contributory
patent infringement, then any patent licenses granted to You under this
License for that Work shall terminate as of the date such litigation is
filed.

4. Redistribution. You may reproduce and distribute copies of the Work or
Derivative Works thereof in any medium, with or without modifications, and
in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a
copy of this License; and

(b) You must cause any modified files to carry prominent notices stating
that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You
distribute, all copyright, patent, trademark, and attribution notices from
the Source form of the Work, excluding those notices that do not pertain to
any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution,
then any Derivative Works that You distribute must include a readable copy
of the attribution notices contained within such NOTICE file, excluding
those notices that do not pertain to any part of the Derivative Works, in at
least one of the following places: within a NOTICE text file distributed as
part of the Derivative Works; within the Source form or documentation, if
provided along with the Derivative Works; or, within a display generated by
the Derivative Works, if and wherever such third-party notices normally
appear. The contents of the NOTICE file are for informational purposes only
and do not modify the License. You may add Your own attribution notices
within Derivative Works that You distribute, alongside or as an addendum to
the NOTICE text from the Work, provided that such additional attribution
notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may
provide additional or different license terms and conditions for use,
reproduction, or distribution of Your modifications, or for any such
Derivative Works as a whole, provided Your use, reproduction, and
distribution of the Work otherwise complies with the conditions stated in
this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any
Contribution intentionally submitted for inclusion in the Work by You to the
Licensor shall be under the terms and conditions of this License, without
any additional terms or conditions. Notwithstanding the above, nothing
herein shall supersede or modify the terms of any separate license agreement
you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor, except
as required for reasonable and customary use in describing the origin of the
Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in
writing, Licensor provides the Work (and each Contributor provides its
Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied, including, without limitation, any
warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or
FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining
the appropriateness of using or redistributing the Work and assume any risks
associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether
in tort (including negligence), contract, or otherwise, unless required by
applicable law (such as deliberate and grossly negligent acts) or agreed to
in writing, shall any Contributor be liable to You for damages, including
any direct, indirect, special, incidental, or consequential damages of any
character arising as a result of this License or out of the use or inability
to use the Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all other
commercial damages or losses), even if such Contributor has been advised of
the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work
or Derivative Works thereof, You may choose to offer, and charge a fee for,
acceptance of support, warranty, indemnity, or other liability obligations
and/or rights consistent with this License. However, in accepting such
obligations, You may act only on Your own behalf and on Your sole
responsibility, not on behalf of any other Contributor, and only if You
agree to indemnify, defend, and hold each Contributor harmless for any
liability incurred by, or claims asserted against, such Contributor by
reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate
notice, with the fields enclosed by brackets "[]" replaced with your own
identifying information. (Don't include the brackets!) The text should be
enclosed in the appropriate comment syntax for the file format. We also
recommend that a file or class name and description of purpose be included
on the same "printed page" as the copyright notice for easier identification
within third-party archives.

Copyright 2021 ServiceNow

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
Copyright 2021 ServiceNow, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
*ServiceNow completed its acquisition of Element AI on January 8, 2021. All references to Element AI in the materials that are part of this project should refer to ServiceNow.*

# Multithreaded-estimators

Code demonstrating how to use multithreading to speed up inference for TensorFlow estimators.

## Installation

A Dockerfile is provided. First build the image from the root directory:

```
docker build . -t threaded
```

Then run the tests:

```
docker run threaded
```
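
## Usage

A minimal sketch of the two classes, adapted from the scripts in `examples/` (the checkpoint directory and step count are illustrative). `FlowerClassifier` rebuilds the graph and restores the checkpoint on every `predict` call, while `FlowerClassifierThreaded` keeps a single prediction thread alive and feeds it through queues, so repeated calls avoid the reload.

```
from threaded_estimator.models import FlowerClassifier, FlowerClassifierThreaded

predict_x = {
    'SepalLength': [5.1],
    'SepalWidth': [3.3],
    'PetalLength': [1.7],
    'PetalWidth': [0.5],
}

# Plain estimator: every predict() call reloads the model from model_path.
fc = FlowerClassifier(model_path='./trained_models')
fc.train(steps=1000)
print(list(fc.predict(predict_x)))

# Threaded estimator: the model is loaded once; inputs go through an input
# queue and predictions come back on an output queue.
fct = FlowerClassifierThreaded(threaded=True)
print(fct.predict(predict_x))
```

See `examples/speedup.py` for a timing comparison between the two.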

## License

This code is released under an Apache 2 license. See [the license in full](LICENSE).

--------------------------------------------------------------------------------
/examples/predict.py:
--------------------------------------------------------------------------------
from threaded_estimator.models import FlowerClassifier
import tensorflow as tf

tf.logging.set_verbosity(tf.logging.INFO)
fc = FlowerClassifier(model_path='../trained_models')

fc.train(steps=1000)

predict_x = {
    'SepalLength': [5.1],
    'SepalWidth': [3.3],
    'PetalLength': [1.7],
    'PetalWidth': [0.5],
}

p1 = list(fc.predict(predict_x))
# INFO:tensorflow:Restoring parameters from ./trained_models/model.ckpt-5000

p2 = list(fc.predict(predict_x))
# INFO:tensorflow:Restoring parameters from ./trained_models/model.ckpt-5000

--------------------------------------------------------------------------------
/examples/speedup.py:
--------------------------------------------------------------------------------
import time
import tensorflow as tf
from threaded_estimator.models import FlowerClassifier, FlowerClassifierThreaded


tf.logging.set_verbosity(tf.logging.INFO)

predict_x = {
    'SepalLength': [5.1],
    'SepalWidth': [3.3],
    'PetalLength': [1.7],
    'PetalWidth': [0.5],
}

fe_threaded = FlowerClassifierThreaded(threaded=True)
fe_unthreaded = FlowerClassifier()

n_epochs = 100

print('starting unthreaded')
t1 = time.time()
for _ in range(n_epochs):
    predictions = list(fe_unthreaded.predict(features=predict_x))

print('starting threaded')
t2 = time.time()
for _ in range(n_epochs):
    predictions = list(fe_threaded.predict(features=predict_x))

t3 = time.time()

unthreaded_time = (t2 - t1)
threaded_time = (t3 - t2)

assert unthreaded_time > threaded_time

print(f'Threaded time was {threaded_time}s;\n'
      f'Unthreaded time was {unthreaded_time}s;\n'
      f'Threaded was {unthreaded_time/threaded_time} times faster!')

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tensorflow==1.5.0
numpy==1.14.2
pandas==0.22.0
pytest==3.2.1

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

setup(
    name='multithreading-estimators',
    version='',
    packages=['threaded_estimator'],
    url='',
    license='',
    author='archydeberker',
    author_email='archy@elementai.com',
    description='',
    install_requires=['tensorflow'],
)

--------------------------------------------------------------------------------
/threaded_estimator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ServiceNow/multithreaded-estimators/7161a21af00adce86ee2f360cbdf6729db32fda1/threaded_estimator/__init__.py

--------------------------------------------------------------------------------
/threaded_estimator/iris_data.py:
--------------------------------------------------------------------------------
import pandas as pd
import tensorflow as tf

TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',
                    'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']


def maybe_download():
    train_path = tf.keras.utils.get_file(TRAIN_URL.split('/')[-1], TRAIN_URL)
    test_path = tf.keras.utils.get_file(TEST_URL.split('/')[-1], TEST_URL)

    return train_path, test_path


def load_data(y_name='Species'):
    """Returns the iris dataset as (train_x, train_y), (test_x, test_y)."""
    train_path, test_path = maybe_download()

    train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
    train_x, train_y = train, train.pop(y_name)

    test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
    test_x, test_y = test, test.pop(y_name)

    return (train_x, train_y), (test_x, test_y)


# The remainder of this file contains a simple example of a csv parser,
# implemented using the `Dataset` class.

# `tf.decode_csv` sets the types of the outputs to match the examples given in
# the `record_defaults` argument.
CSV_TYPES = [[0.0], [0.0], [0.0], [0.0], [0]]


def _parse_line(line):
    # Decode the line into its fields
    fields = tf.decode_csv(line, record_defaults=CSV_TYPES)

    # Pack the result into a dictionary
    features = dict(zip(CSV_COLUMN_NAMES, fields))

    # Separate the label from the features
    label = features.pop('Species')

    return features, label


def csv_input_fn(csv_path, batch_size):
    # Create a dataset containing the text lines.
    dataset = tf.data.TextLineDataset(csv_path).skip(1)

    # Parse each line.
    dataset = dataset.map(_parse_line)

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    # Return the dataset.
    return dataset
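
# Hypothetical usage sketch (not part of the original module): the parser above
# can be plugged into an estimator's train() call as an input_fn, e.g.
#
#     train_path, _ = maybe_download()
#     classifier.train(input_fn=lambda: csv_input_fn(train_path, batch_size=32))
#
# where `classifier` is any tf.estimator.Estimator instance.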

--------------------------------------------------------------------------------
/threaded_estimator/models.py:
--------------------------------------------------------------------------------
""" Module to expose trained models for inference."""

import tensorflow as tf

from tensorflow.contrib.learn import RunConfig
from queue import Queue
from threading import Thread

from threaded_estimator import iris_data


class FlowerClassifier:
    """ A light wrapper to handle training and inference with the Iris classifier here:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/learn/iris.py

    Based upon a simple implementation of the canned TF DNNClassifier Estimator, see the tutorial here:
    https://www.tensorflow.org/versions/r1.3/get_started/estimator#construct_a_deep_neural_network_classifier

    """

    def __init__(self, model_path='./trained_models/',
                 verbose=False):
        """
        Parameters
        ----------
        model_path: str
            Location from which to load the model.

        verbose: Bool
            Whether to print various messages.
        """

        (self.train_x, self.train_y), (self.test_x, self.test_y) = iris_data.load_data()
        self.batch_size = 32

        self.model_path = model_path

        self.estimator = self.load_estimator()

        self.verbose = verbose

    def predict(self, features):
        """
        Vanilla prediction function. Returns a generator.

        Intended for single-shot usage.

        Parameters
        ----------
        features: dict
            dict of input features, containing keys 'SepalLength'
                                                     'SepalWidth'
                                                     'PetalLength'
                                                     'PetalWidth'

        Returns
        -------
        predictions: generator
            Yields dictionaries of prediction tensors, e.g. 'logits'
                                                            'probabilities'
                                                            'class_ids'

        """

        return self.estimator.predict(input_fn=lambda: self.predict_input_fn(features))

    def load_estimator(self):
        """

        Returns
        -------
        estimator
            A tf.estimator.DNNClassifier

        """

        # Feature columns describe how to use the input.
        my_feature_columns = []
        for key in self.train_x.keys():
            my_feature_columns.append(tf.feature_column.numeric_column(key=key))

        run_config = RunConfig()
        run_config = run_config.replace(model_dir=self.model_path)

        return tf.estimator.DNNClassifier(
            feature_columns=my_feature_columns,
            # Two hidden layers of 10 nodes each.
            hidden_units=[10, 10],
            # The model must choose between 3 classes.
            n_classes=3,
            # Use the RunConfig to load the model.
            config=run_config,
            model_dir=self.model_path)

    def train(self, steps):
        """
        Parameters
        ----------
        steps: int
            Number of steps to train for.

        """
        self.estimator.train(
            input_fn=lambda: self.train_input_fn(self.train_x, self.train_y),
            steps=steps,
        )

    def train_input_fn(self, features, labels):
        """
        For background on the data, see iris_data.py

        Parameters
        ----------
        features: pandas dataframe
            With columns 'SepalLength'
                         'SepalWidth'
                         'PetalLength'
                         'PetalWidth'

        labels: array
            Flower names

        Returns
        -------
        dataset: tf.data.Dataset
            Yields batches of size self.batch_size

        """

        # Convert the inputs to a Dataset.
        dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

        # Shuffle, repeat, and batch the examples.
        dataset = dataset.shuffle(1000).repeat().batch(self.batch_size)

        # Return the dataset.
        return dataset

    def predict_input_fn(self, features):
        """

        Parameters
        ----------
        features: pandas dataframe or dict
            with columns or keys 'SepalLength'
                                 'SepalWidth'
                                 'PetalLength'
                                 'PetalWidth'

        Returns
        -------
        dataset: tf.data.Dataset
            Yields batches of size self.batch_size

        """

        if self.verbose:
            print("Standard predict_input_fn call")

        # Convert the inputs to a Dataset.
        dataset = tf.data.Dataset.from_tensor_slices(dict(features))

        # Batch the examples
        assert self.batch_size is not None, "batch_size must not be None"
        dataset = dataset.batch(self.batch_size)

        # Return the batched dataset. This is an iterator returning batches.
        return dataset


class FlowerClassifierThreaded(FlowerClassifier):

    def __init__(self, model_path='./trained_models/',
                 threaded=True,
                 verbose=False):
        """
        Parameters
        ----------
        model_path: str
            Location from which to load the model.

        threaded: Boolean [True]
            Whether to use multi-threaded execution for inference.
            If False, the model will use a new generator for each sample that is passed to it, and reload the entire
            model each time.

        verbose: Bool
            Whether to print various messages.

        """

        super(FlowerClassifierThreaded, self).__init__(model_path=model_path,
                                                       verbose=verbose)

        self.input_queue = Queue(maxsize=1)
        self.output_queue = Queue(maxsize=1)

        self.threaded = threaded

        if self.threaded:
            # We set the generator thread as daemon
            # (see https://docs.python.org/3/library/threading.html#threading.Thread.daemon)
            # This means that when all other threads are dead,
            # this thread will not prevent the Python program from exiting
            self.prediction_thread = Thread(target=self.predict_from_queue, daemon=True)
            self.prediction_thread.start()

    def generate_from_queue(self):
        """ Generator which yields items from the input queue.
        This lives within our 'prediction thread'.

        """

        while True:
            if self.verbose:
                print('Yielding from input queue')
            yield self.input_queue.get()

    def predict_from_queue(self):
        """ Adds a prediction from the model to the output_queue.

        This lives within our 'prediction thread'.

        Note: estimators accept generators as inputs and return generators as output. Here, we are
        iterating through the output generator, which will be populated in lock-step with the input
        generator.

        """

        for i in self.estimator.predict(input_fn=self.queued_predict_input_fn):
            if self.verbose:
                print('Putting in output queue')
            self.output_queue.put(i)

    def predict(self, features):
        """
        Overrides .predict in FlowerClassifier.

        Calls either the vanilla or multi-threaded prediction method based upon self.threaded.

        Parameters
        ----------
        features: dict
            dict of input features, containing keys 'SepalLength'
                                                     'SepalWidth'
                                                     'PetalLength'
                                                     'PetalWidth'

        Returns
        -------
        predictions: dict
            Dictionary of prediction tensors, e.g. 'logits'
                                                   'probabilities'
                                                   'class_ids'

        """

        # Get predictions dictionary
        if self.threaded:
            features = dict(features)
            self.input_queue.put(features)
            predictions = self.output_queue.get()  # The latest predictions generator
        else:
            predictions = self.estimator.predict(input_fn=lambda: self.predict_input_fn(features))

        # TODO: list vs. generator vs. dict handling
        return predictions

    def queued_predict_input_fn(self):
        """
        Queued version of the `predict_input_fn` in FlowerClassifier.

        Instead of building a dataset from data passed as a parameter, we construct a Dataset from a generator,
        which yields from the input queue.

        """

        if self.verbose:
            print("QUEUED INPUT FUNCTION CALLED")

        # Fetch the inputs from the input queue
        dataset = tf.data.Dataset.from_generator(self.generate_from_queue,
                                                 output_types={'SepalLength': tf.float32,
                                                               'SepalWidth': tf.float32,
                                                               'PetalLength': tf.float32,
                                                               'PetalWidth': tf.float32})

        return dataset
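

if __name__ == '__main__':
    # Illustrative smoke test (not part of the original module): push a single
    # example through the threaded predict path and print the resulting
    # prediction dictionary. The feature values below are arbitrary.
    example = {'SepalLength': [5.1],
               'SepalWidth': [3.3],
               'PetalLength': [1.7],
               'PetalWidth': [0.5]}

    classifier = FlowerClassifierThreaded(threaded=True, verbose=True)
    print(classifier.predict(example))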

--------------------------------------------------------------------------------
/threaded_estimator/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ServiceNow/multithreaded-estimators/7161a21af00adce86ee2f360cbdf6729db32fda1/threaded_estimator/tests/__init__.py

--------------------------------------------------------------------------------
/threaded_estimator/tests/test_flower_estimator.py:
--------------------------------------------------------------------------------
import time

from threaded_estimator import models
import numpy as np
import pytest
import tensorflow as tf


def test_iris_estimator_trains():
    fe = models.FlowerClassifierThreaded(threaded=False)
    fe.train(steps=10)


predict_x = {
    'SepalLength': [5.1],
    'SepalWidth': [3.3],
    'PetalLength': [1.7],
    'PetalWidth': [0.5],
}


def test_normal_input_fn():
    fe = models.FlowerClassifierThreaded(threaded=False)
    ds = fe.predict_input_fn(predict_x)
    value = ds.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        features = sess.run(value)

    assert isinstance(features, dict)


def test_predictions_change_with_training():
    fe = models.FlowerClassifierThreaded(threaded=False)
    predictions1 = list(fe.predict(features=predict_x))
    fe.train(steps=100)
    predictions2 = list(fe.predict(features=predict_x))

    with pytest.raises(AssertionError):
        np.testing.assert_array_equal(predictions1[0]['logits'],
                                      predictions2[0]['logits'])


@pytest.mark.parametrize('threaded', [False, True])
def test_iris_estimator_predict_deterministic(threaded):
    fe = models.FlowerClassifierThreaded(threaded=threaded)
    predictions1 = fe.predict(features=predict_x)
    predictions2 = fe.predict(features=predict_x)

    if not threaded:
        predictions1 = list(predictions1)[0]
        predictions2 = list(predictions2)[0]

    print(threaded, predictions1)
    print(threaded, predictions2)

    np.testing.assert_array_equal(predictions1['logits'],
                                  predictions2['logits'])


def test_threaded_faster_than_non_threaded():

    fe_threaded = models.FlowerClassifierThreaded(threaded=True)
    fe_unthreaded = models.FlowerClassifier()

    n_epochs = 100

    print('starting unthreaded')
    t1 = time.time()
    for _ in range(n_epochs):
        predictions = list(fe_unthreaded.predict(features=predict_x))

    print('starting threaded')
    t2 = time.time()
    for _ in range(n_epochs):
        predictions = list(fe_threaded.predict(features=predict_x))

    t3 = time.time()

    unthreaded_time = (t2 - t1)
    threaded_time = (t3 - t2)

    assert unthreaded_time > threaded_time

    print(f'Threaded time was {threaded_time}s;\n'
          f'Unthreaded time was {unthreaded_time}s;\n'
          f'Threaded was {unthreaded_time/threaded_time} times faster!')

--------------------------------------------------------------------------------