├── .gitignore ├── LICENSE ├── README.md ├── fake_display.sh ├── gym_examples ├── test_cart_pole.py ├── test_mountain_car.py └── test_pacman.py ├── gym_models ├── README.md └── cartpole_model │ └── 00000001 │ ├── checkpoint │ ├── export.data-00000-of-00001 │ ├── export.index │ └── export.meta ├── play_game.py └── python_predict_client ├── README.md ├── gym_agent.py ├── model_pb2.py ├── predict_client.py ├── predict_pb2.py └── prediction_service_pb2.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Q 2 | 3 | ## Introduction 4 | 5 | The deep reinforcement learning example with TensorFlow. 6 | 7 | It's based on gym and Q-learning algorithm. It provides the trainable example with native TensorFlow APIs and you can use it for all `gym` games. 
8 | 9 | ## Usage 10 | 11 | ### CartPole 12 | 13 | ``` 14 | ./play_game.py 15 | ``` 16 | 17 | ### MountainCar 18 | 19 | ``` 20 | ./play_game.py --mode train --gym_env MountainCar-v0 --checkpoint_dir ./checkpoint_mountain 21 | ``` 22 | 23 | ### Pacman 24 | 25 | ``` 26 | ./play_game.py --mode train --gym_env MsPacman-v0 --checkpoint_dir ./checkpoint_pacman --model cnn 27 | ``` 28 | 29 | ## Test 30 | 31 | ``` 32 | ./play_game.py --mode untrained 33 | ``` 34 | 35 | ``` 36 | ./play_game.py --mode inference 37 | ``` 38 | -------------------------------------------------------------------------------- /fake_display.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | set -e 5 | 6 | sudo apt-get install -y xvfb 7 | 8 | xvfb-run -s "-screen 0 1400x900x24" bash 9 | -------------------------------------------------------------------------------- /gym_examples/test_cart_pole.py: -------------------------------------------------------------------------------- 1 | import gym 2 | env = gym.make('CartPole-v0') 3 | env.reset() 4 | for _ in range(1000): 5 | env.render() 6 | env.step(env.action_space.sample()) # take a random action 7 | -------------------------------------------------------------------------------- /gym_examples/test_mountain_car.py: -------------------------------------------------------------------------------- 1 | import gym 2 | env = gym.make('MountainCar-v0') 3 | env.reset() 4 | for _ in range(1000): 5 | env.render() 6 | env.step(env.action_space.sample()) # take a random action 7 | -------------------------------------------------------------------------------- /gym_examples/test_pacman.py: -------------------------------------------------------------------------------- 1 | import gym 2 | env = gym.make('MsPacman-v0') 3 | env.reset() 4 | for _ in range(1000): 5 | env.render() 6 | env.step(env.action_space.sample()) # take a random action 7 | 
-------------------------------------------------------------------------------- /gym_models/README.md: -------------------------------------------------------------------------------- 1 | # Gym Models 2 | 3 | ## Introduction 4 | 5 | We provide the model zoo for gym models. You can access the trained models with generic gym agent. 6 | 7 | ## Start server 8 | 9 | ``` 10 | nohup ./tensorflow_model_server --port=9001 --model_name=cartpole --model_base_path=./cartpole_model/ & 11 | ``` 12 | 13 | ## Play CartPole 14 | 15 | ``` 16 | ./gym_agent.py --host 139.162.72.39 --port 9001 --model_name cartpole --gym_env CartPole-v1 17 | ``` 18 | -------------------------------------------------------------------------------- /gym_models/cartpole_model/00000001/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "export" 2 | all_model_checkpoint_paths: "export" 3 | -------------------------------------------------------------------------------- /gym_models/cartpole_model/00000001/export.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tobegit3hub/deep_q/2cd7cc726f3cb6b3316fdc2b8c4c1800c7de6222/gym_models/cartpole_model/00000001/export.data-00000-of-00001 -------------------------------------------------------------------------------- /gym_models/cartpole_model/00000001/export.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tobegit3hub/deep_q/2cd7cc726f3cb6b3316fdc2b8c4c1800c7de6222/gym_models/cartpole_model/00000001/export.index -------------------------------------------------------------------------------- /gym_models/cartpole_model/00000001/export.meta: -------------------------------------------------------------------------------- 
#!/usr/bin/env python
"""Train and play gym games with a Q-learning network built on TensorFlow.

The behaviour is selected with the ``mode`` flag:

* ``train``: train with experience replay, validate with the greedy policy,
  save checkpoints and export the model once training ends.
* ``untrained``: play one episode with randomly sampled actions.
* ``inference``: play one episode with the latest checkpoint.
"""

from collections import deque
import os
import random
import sys
import time

import gym
import numpy as np
import tensorflow as tf
from tensorflow.contrib.session_bundle import exporter

# Define parameters
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.')
flags.DEFINE_integer('episode_number', 100,
                     'Number of episode to run trainer.')
flags.DEFINE_integer('episode_step_number', 10000,
                     'Number of steps for each episode.')
flags.DEFINE_integer("batch_size", 32, "The batch size for training")
flags.DEFINE_string("checkpoint_dir", "./checkpoint/",
                    "indicates the checkpoint dirctory")
flags.DEFINE_string("tensorboard_dir", "./tensorboard/",
                    "indicates training output")
flags.DEFINE_string("optimizer", "adam", "optimizer to train")
flags.DEFINE_integer('episode_to_validate', 1,
                     'Steps to validate and print loss')
flags.DEFINE_string("model", "dnn", "The model to train, dnn or cnn")
flags.DEFINE_boolean("enable_bn", False, "Enable batch normalization or not")
flags.DEFINE_float("bn_epsilon", 0.001, "The epsilon of batch normalization")
# main() also accepts the mode "untrained" (random actions, no model needed).
flags.DEFINE_string("mode", "train", "Opetion mode: train, inference")
flags.DEFINE_string("gym_env", "CartPole-v0",
                    "The gym env, like 'CartPole-v0' or 'MountainCar-v0'")
flags.DEFINE_float("discount_factor", 0.9, "Discount factor for Q-learning")
flags.DEFINE_integer("experience_replay_size", 10000, "Relay buffer size")
flags.DEFINE_float("exploration_exploitation_epsilon", 0.5,
                   "The epsilon to select action")
flags.DEFINE_boolean("render_game", True, "Render the gym in window or not")
flags.DEFINE_float("render_sleep_time", 0.0,
                   "Sleep time when render each frame")
flags.DEFINE_string("model_path", "./model/", "The output path of the model")
flags.DEFINE_integer("export_version", 1, "The version number of the model")


def _create_optimizer(name, learning_rate):
    """Return the tf.train optimizer selected by `name`.

    Prints a message and exits with status 1 when `name` is not one of
    sgd, adadelta, adagrad, adam, ftrl or rmsprop.
    """
    optimizer_classes = {
        "sgd": tf.train.GradientDescentOptimizer,
        "adadelta": tf.train.AdadeltaOptimizer,
        "adagrad": tf.train.AdagradOptimizer,
        "adam": tf.train.AdamOptimizer,
        "ftrl": tf.train.FtrlOptimizer,
        "rmsprop": tf.train.RMSPropOptimizer,
    }
    if name not in optimizer_classes:
        print("Unknow optimizer: {}, exit now".format(name))
        sys.exit(1)
    return optimizer_classes[name](learning_rate)


def _restore_latest_checkpoint(sess, saver):
    """Restore the checkpoint recorded in FLAGS.checkpoint_dir, if any.

    Returns True when a checkpoint was found and restored, else False.
    """
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        print("Restore model from the file {}".format(
            ckpt.model_checkpoint_path))
        saver.restore(sess, ckpt.model_checkpoint_path)
        return True
    return False


def _play_greedy_episode(env, choose_action):
    """Play one episode of at most FLAGS.episode_step_number steps.

    `choose_action` maps the current state to the action to take. The
    environment is rendered when FLAGS.render_game is set. Returns the sum
    of the rewards collected.
    """
    state = env.reset()
    total_reward = 0
    for _ in range(FLAGS.episode_step_number):
        if FLAGS.render_game:
            time.sleep(FLAGS.render_sleep_time)
            env.render()
        state, reward, done, _info = env.step(choose_action(state))
        total_reward += reward
        if done:
            break
    return total_reward


def main():
    """Build the Q-network for FLAGS.gym_env and run the selected mode.

    Exits with status 1 for an unknown model or optimizer, and when
    inference mode finds no checkpoint to restore.
    """
    print("Start playing game")

    # Initial Gym environment
    env = gym.make(FLAGS.gym_env)
    # A bounded deque discards the oldest experience once the buffer is full.
    experience_replay_queue = deque(maxlen=FLAGS.experience_replay_size)
    action_number = env.action_space.n
    # The observation shape of CartPole is (4,), Pacman is (210, 160, 3)
    observation_shape = env.observation_space.shape
    state_number = observation_shape[0]
    if len(observation_shape) >= 3:
        state_number2 = observation_shape[1]
        state_number3 = observation_shape[2]
    else:
        state_number2 = state_number
        state_number3 = state_number

    # Define dnn model
    def dnn_inference(inputs, is_train=True):
        # The inputs is [BATCH_SIZE, state_number], outputs is
        # [BATCH_SIZE, action_number]
        hidden1_unit_number = 20
        with tf.variable_scope("fc1"):
            weights = tf.get_variable(
                "weight", [state_number, hidden1_unit_number],
                initializer=tf.random_normal_initializer())
            bias = tf.get_variable(
                "bias", [hidden1_unit_number],
                initializer=tf.random_normal_initializer())
            layer = tf.add(tf.matmul(inputs, weights), bias)

            if FLAGS.enable_bn and is_train:
                mean, var = tf.nn.moments(layer, axes=[0])
                scale = tf.get_variable(
                    "scale", hidden1_unit_number,
                    initializer=tf.random_normal_initializer())
                shift = tf.get_variable(
                    "shift", hidden1_unit_number,
                    initializer=tf.random_normal_initializer())
                layer = tf.nn.batch_normalization(layer, mean, var, shift,
                                                  scale, FLAGS.bn_epsilon)

            layer = tf.nn.relu(layer)

        with tf.variable_scope("fc2"):
            weights = tf.get_variable(
                "weight", [hidden1_unit_number, action_number],
                initializer=tf.random_normal_initializer())
            bias = tf.get_variable(
                "bias", [action_number],
                initializer=tf.random_normal_initializer())
            layer = tf.add(tf.matmul(layer, weights), bias)

        return layer

    # Define cnn model
    def cnn_inference(inputs, is_train=True):
        # The inputs is [BATCH_SIZE, height, width, channels] as reported by
        # the environment ([210, 160, 3] for Pacman), outputs is
        # [BATCH_SIZE, action_number]
        label_size = action_number
        with tf.variable_scope("conv1"):
            weights = tf.get_variable(
                "weights", [3, 3, state_number3, 32],
                initializer=tf.random_normal_initializer())
            bias = tf.get_variable(
                "bias", [32], initializer=tf.random_normal_initializer())

            # Should not use pooling
            layer = tf.nn.conv2d(inputs,
                                 weights,
                                 strides=[1, 1, 1, 1],
                                 padding="SAME")
            layer = tf.nn.bias_add(layer, bias)
            layer = tf.nn.relu(layer)

        # Stride 1 with SAME padding keeps height and width, only the channel
        # count changes: 32 in, 64 out
        with tf.variable_scope("conv2"):
            weights = tf.get_variable(
                "weights", [3, 3, 32, 64],
                initializer=tf.random_normal_initializer())
            bias = tf.get_variable(
                "bias", [64], initializer=tf.random_normal_initializer())

            layer = tf.nn.conv2d(layer,
                                 weights,
                                 strides=[1, 1, 1, 1],
                                 padding="SAME")
            layer = tf.nn.bias_add(layer, bias)
            layer = tf.nn.relu(layer)

        # Reshape for the fully connected layer
        flattened_size = state_number * state_number2 * 64
        layer = tf.reshape(layer, [-1, flattened_size])

        # Fully connected layer result: [BATCH_SIZE, label_size]
        with tf.variable_scope("fc1"):
            weights = tf.get_variable(
                "weights", [flattened_size, label_size],
                initializer=tf.random_normal_initializer())
            bias = tf.get_variable(
                "bias", [label_size],
                initializer=tf.random_normal_initializer())
            layer = tf.add(tf.matmul(layer, weights), bias)

        return layer

    # Define train op
    model = FLAGS.model
    print("Use the model: {}".format(model))
    if model == "dnn":
        states_placeholder = tf.placeholder(tf.float32, [None, state_number])
        inference = dnn_inference
    elif model == "cnn":
        states_placeholder = tf.placeholder(
            tf.float32, [None, state_number, state_number2, state_number3])
        inference = cnn_inference
    else:
        print("Unknow model, exit now")
        sys.exit(1)

    logit = inference(states_placeholder, True)
    actions_placeholder = tf.placeholder(tf.float32, [None, action_number])
    # The one-hot action mask keeps only the Q-value of the action taken
    predict_rewords = tf.reduce_sum(
        tf.multiply(logit, actions_placeholder),
        reduction_indices=1)
    rewards_placeholder = tf.placeholder(tf.float32, [None])
    loss = tf.reduce_mean(tf.square(rewards_placeholder - predict_rewords))

    print("Use the optimizer: {}".format(FLAGS.optimizer))
    optimizer = _create_optimizer(FLAGS.optimizer, FLAGS.learning_rate)

    global_step = tf.Variable(0, name="global_step", trainable=False)
    train_op = optimizer.minimize(loss, global_step=global_step)
    # Get the action with most reward when giving the state
    batch_best_actions = tf.argmax(logit, 1)
    best_action = batch_best_actions[0]
    batch_best_q = tf.reduce_max(logit, 1)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    checkpoint_file = os.path.join(FLAGS.checkpoint_dir, "checkpoint.ckpt")
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
    tf.summary.scalar("loss", loss)

    # Create session
    with tf.Session() as sess:
        # NOTE: summary_op is built but never evaluated in this function;
        # the writer only records the graph.
        summary_op = tf.summary.merge_all()
        writer = tf.summary.FileWriter(FLAGS.tensorboard_dir, sess.graph)
        sess.run(init_op)

        def greedy_action(current_state):
            # Action with the highest predicted Q-value for one state
            return sess.run(best_action,
                            feed_dict={states_placeholder: [current_state]})

        if FLAGS.mode == "train":
            # Restore from checkpoint if it exists
            _restore_latest_checkpoint(sess, saver)
            # Tracked separately from the per-episode step loop so that the
            # checkpoint suffix is always the real global step
            global_step_value = sess.run(global_step)

            for episode in range(FLAGS.episode_number):
                # Start new episode to train
                state = env.reset()
                loss_value = -1

                for _ in range(FLAGS.episode_step_number):
                    # Get action from exploration and exploitation
                    if (random.random() <=
                            FLAGS.exploration_exploitation_epsilon):
                        action = random.randint(0, action_number - 1)
                    else:
                        action = greedy_action(state)

                    # Run this action on this state
                    next_state, reward, done, _info = env.step(action)

                    # Add the new transition to the replay experience queue
                    one_hot_action = np.zeros(action_number)
                    one_hot_action[action] = 1
                    experience_replay_queue.append(
                        (state, one_hot_action, reward, next_state, done))

                    # Get enough data to train with batch
                    if len(experience_replay_queue) > FLAGS.batch_size:
                        batch_data = random.sample(experience_replay_queue,
                                                   FLAGS.batch_size)
                        batch_states = [sample[0] for sample in batch_data]
                        batch_actions = [sample[1] for sample in batch_data]
                        batch_rewards = [sample[2] for sample in batch_data]
                        batch_next_states = [sample[3] for sample in batch_data]
                        batch_dones = [sample[4] for sample in batch_data]

                        # One session call for the whole batch instead of one
                        # call per sampled transition
                        next_q_values = sess.run(
                            batch_best_q,
                            feed_dict={states_placeholder: batch_next_states})

                        # Terminal transitions get the plain reward, the
                        # others add the discounted best next Q-value. The
                        # sample_* names keep the current step's `reward`
                        # and `done` untouched (Python 2 leaks
                        # comprehension variables).
                        expected_rewards = [
                            sample_reward if sample_done else
                            sample_reward +
                            FLAGS.discount_factor * sample_next_q
                            for sample_reward, sample_done, sample_next_q in
                            zip(batch_rewards, batch_dones, next_q_values)
                        ]

                        _, loss_value, global_step_value = sess.run(
                            [train_op, loss, global_step],
                            feed_dict={
                                rewards_placeholder: expected_rewards,
                                actions_placeholder: batch_actions,
                                states_placeholder: batch_states
                            })
                    else:
                        print("Add more data to train with batch")

                    state = next_state
                    # `done` is the flag of the step just played
                    if done:
                        break

                # Validate for some episode
                if episode % FLAGS.episode_to_validate == 0:
                    print("Episode: {}, global step: {}, the loss: {}".format(
                        episode, global_step_value, loss_value))

                    total_reward = _play_greedy_episode(env, greedy_action)

                    print("Eposide: {}, total reward: {}".format(
                        episode, total_reward))
                    saver.save(sess, checkpoint_file,
                               global_step=global_step_value)

            # End of training process
            model_exporter = exporter.Exporter(saver)
            model_exporter.init(sess.graph.as_graph_def(),
                                named_graph_signatures={
                                    'inputs': exporter.generic_signature({
                                        "states": states_placeholder
                                    }),
                                    'outputs': exporter.generic_signature({
                                        "actions": batch_best_actions
                                    })
                                })
            model_exporter.export(FLAGS.model_path,
                                  tf.constant(FLAGS.export_version), sess)
            print("Done exporting!")

        elif FLAGS.mode == "untrained":
            total_reward = 0
            state = env.reset()

            for _ in range(FLAGS.episode_step_number):
                if FLAGS.render_game:
                    time.sleep(FLAGS.render_sleep_time)
                    env.render()
                action = env.action_space.sample()
                next_state, reward, done, _info = env.step(action)
                total_reward += reward

                if done:
                    print("End of untrained because of done, reword: {}".format(
                        total_reward))
                    break

        elif FLAGS.mode == "inference":
            # A trained model is required here, so a missing checkpoint is
            # reported with a non-zero exit status
            if not _restore_latest_checkpoint(sess, saver):
                print("Model not found, exit now")
                sys.exit(1)

            total_reward = 0
            state = env.reset()

            # Report progress each time the reward passes another 100
            index = 1
            while True:
                time.sleep(FLAGS.render_sleep_time)
                if FLAGS.render_game:
                    env.render()

                action = greedy_action(state)
                next_state, reward, done, _info = env.step(action)
                state = next_state
                total_reward += reward

                if done:
                    print("End of inference because of done, reword: {}".format(
                        total_reward))
                    break
                if total_reward > index * 100:
                    print("Not done yet, current reword: {}".format(
                        total_reward))
                    index += 1

        else:
            print("Unknown mode: {}".format(FLAGS.mode))

        # Close the writer so the recorded graph is written out
        writer.close()

    print("End of playing game")
print("End of playing game") 381 | 382 | 383 | if __name__ == "__main__": 384 | main() 385 | -------------------------------------------------------------------------------- /python_predict_client/README.md: -------------------------------------------------------------------------------- 1 | # Generic Gym Agent 2 | 3 | ## Introduction 4 | 5 | It is the generic gym agent for any gym environments. 6 | 7 | You can export TensorFlow models and run with TensorFlow serving. Use `gym_agent.py` to play games with the trained models. 8 | 9 | ## Run server 10 | 11 | ``` 12 | ./tensorflow_model_server --port=9000 --model_name=cartpole --model_base_path=/home/tobe/code/deep_q/model 13 | ``` 14 | 15 | ## Predict client 16 | 17 | ``` 18 | ./predict_client.py --host 127.0.0.1 --port 9000 --model_name cartpole 19 | ``` 20 | 21 | ## Gym agent 22 | 23 | ``` 24 | ./gym_agent.py --host 127.0.0.1 --port 9000 --model_name cartpole --render_game False --gym_env CartPole-v0 25 | ``` 26 | -------------------------------------------------------------------------------- /python_predict_client/gym_agent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from grpc.beta import implementations 4 | import gym 5 | import numpy 6 | import tensorflow as tf 7 | import time 8 | 9 | import predict_pb2 10 | import prediction_service_pb2 11 | 12 | flags = tf.app.flags 13 | FLAGS = flags.FLAGS 14 | flags.DEFINE_string("host", "127.0.0.1", "gRPC server host") 15 | flags.DEFINE_integer("port", 9000, "gRPC server port") 16 | flags.DEFINE_string("model_name", "deep_q", "TensorFlow model name") 17 | flags.DEFINE_integer("model_version", 1, "TensorFlow model version") 18 | flags.DEFINE_float("request_timeout", 10.0, "Timeout of gRPC request") 19 | flags.DEFINE_string("gym_env", "CartPole-v0", 20 | "The gym env, like 'CartPole-v0' or 'MountainCar-v0'") 21 | flags.DEFINE_boolean("render_game", True, "Render the gym in window or not") 22 | 23 | 24 
def main():
    """Play one episode of ``FLAGS.gym_env`` using a served TensorFlow model.

    A gRPC stub for the prediction service at ``FLAGS.host``:``FLAGS.port``
    is created once, and a single ``PredictRequest`` is reused for every
    step.  Each step sends the current state as a batch of one under the
    ``states`` input and applies the first value of the ``actions`` output
    to the environment.  When the environment reports ``done`` the
    accumulated reward is printed and the function returns.

    The environment is closed on every exit path, including a failed
    prediction request.
    """
    host = FLAGS.host
    port = FLAGS.port
    model_name = FLAGS.model_name
    model_version = FLAGS.model_version
    request_timeout = FLAGS.request_timeout

    # Create gRPC client and request
    channel = implementations.insecure_channel(host, port)
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = model_name
    # The version is only pinned for positive values; otherwise the field
    # is left unset on the request.
    if model_version > 0:
        request.model_spec.version.value = model_version

    env = gym.make(FLAGS.gym_env)
    try:
        state = env.reset()
        total_reward = 0

        while True:
            if FLAGS.render_game:
                # Pause before drawing each frame; presumably to keep the
                # rendered game watchable.
                time.sleep(0.1)
                env.render()

            # Generate inference data: wrap the state in a list so the
            # tensor has a leading axis of size one.
            features = numpy.asarray([state])
            features_tensor_proto = tf.contrib.util.make_tensor_proto(
                features, dtype=tf.float32)
            request.inputs['states'].CopyFrom(features_tensor_proto)

            # Send request
            result = stub.Predict(request, request_timeout)
            action = int(result.outputs.get("actions").int64_val[0])

            state, reward, done, _ = env.step(action)
            total_reward += reward

            if done:
                print("End of the game, reward: {}".format(total_reward))
                break
    finally:
        # Release the environment (and any render window) even when the
        # prediction request or the environment step raised.
        env.close()


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /python_predict_client/model_pb2.py:
# --------------------------------------------------------------------------------
# Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: model.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from google.protobuf import wrappers_pb2 as google_dot_protobuf_dot_wrappers__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='model.proto', 21 | package='tensorflow.serving', 22 | syntax='proto3', 23 | serialized_pb=_b('\n\x0bmodel.proto\x12\x12tensorflow.serving\x1a\x1egoogle/protobuf/wrappers.proto\"_\n\tModelSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12,\n\x07version\x18\x02 \x01(\x0b\x32\x1b.google.protobuf.Int64Value\x12\x16\n\x0esignature_name\x18\x03 \x01(\tB\x03\xf8\x01\x01\x62\x06proto3') 24 | , 25 | dependencies=[google_dot_protobuf_dot_wrappers__pb2.DESCRIPTOR,]) 26 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 27 | 28 | 29 | 30 | 31 | _MODELSPEC = _descriptor.Descriptor( 32 | name='ModelSpec', 33 | full_name='tensorflow.serving.ModelSpec', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='name', full_name='tensorflow.serving.ModelSpec.name', index=0, 40 | number=1, type=9, cpp_type=9, label=1, 41 | has_default_value=False, default_value=_b("").decode('utf-8'), 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None), 45 | _descriptor.FieldDescriptor( 46 | name='version', full_name='tensorflow.serving.ModelSpec.version', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 
| is_extension=False, extension_scope=None, 51 | options=None), 52 | _descriptor.FieldDescriptor( 53 | name='signature_name', full_name='tensorflow.serving.ModelSpec.signature_name', index=2, 54 | number=3, type=9, cpp_type=9, label=1, 55 | has_default_value=False, default_value=_b("").decode('utf-8'), 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None), 59 | ], 60 | extensions=[ 61 | ], 62 | nested_types=[], 63 | enum_types=[ 64 | ], 65 | options=None, 66 | is_extendable=False, 67 | syntax='proto3', 68 | extension_ranges=[], 69 | oneofs=[ 70 | ], 71 | serialized_start=67, 72 | serialized_end=162, 73 | ) 74 | 75 | _MODELSPEC.fields_by_name['version'].message_type = google_dot_protobuf_dot_wrappers__pb2._INT64VALUE 76 | DESCRIPTOR.message_types_by_name['ModelSpec'] = _MODELSPEC 77 | 78 | ModelSpec = _reflection.GeneratedProtocolMessageType('ModelSpec', (_message.Message,), dict( 79 | DESCRIPTOR = _MODELSPEC, 80 | __module__ = 'model_pb2' 81 | # @@protoc_insertion_point(class_scope:tensorflow.serving.ModelSpec) 82 | )) 83 | _sym_db.RegisterMessage(ModelSpec) 84 | 85 | 86 | DESCRIPTOR.has_options = True 87 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\370\001\001')) 88 | import grpc 89 | from grpc.beta import implementations as beta_implementations 90 | from grpc.beta import interfaces as beta_interfaces 91 | from grpc.framework.common import cardinality 92 | from grpc.framework.interfaces.face import utilities as face_utilities 93 | # @@protoc_insertion_point(module_scope) 94 | -------------------------------------------------------------------------------- /python_predict_client/predict_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import numpy 4 | 5 | from grpc.beta import implementations 6 | import tensorflow as tf 7 | 8 | import predict_pb2 9 | import 
import prediction_service_pb2

# Command-line flags: server address, model selection and request timeout.
tf.app.flags.DEFINE_string("host", "127.0.0.1", "gRPC server host")
tf.app.flags.DEFINE_integer("port", 9000, "gRPC server port")
tf.app.flags.DEFINE_string("model_name", "deep_q", "TensorFlow model name")
tf.app.flags.DEFINE_integer("model_version", 1, "TensorFlow model version")
tf.app.flags.DEFINE_float("request_timeout", 10.0, "Timeout of gRPC request")
FLAGS = tf.app.flags.FLAGS


def main():
    """Send one hard-coded batch to the prediction service and print the reply.

    The batch is two rows of four values, converted to a float32 tensor and
    sent under the ``state`` input of the model named by
    ``FLAGS.model_name``.  The model version is only set on the request when
    ``FLAGS.model_version`` is positive.
    """
    server_host, server_port = FLAGS.host, FLAGS.port
    name, version = FLAGS.model_name, FLAGS.model_version
    timeout = FLAGS.request_timeout

    # Fixed sample input for the request.
    sample_batch = numpy.asarray([[1, 2, 3, 4], [5, 6, 7, 8]])
    sample_proto = tf.contrib.util.make_tensor_proto(sample_batch,
                                                     dtype=tf.float32)

    # Open the channel and wrap it in the generated prediction-service stub.
    stub = prediction_service_pb2.beta_create_PredictionService_stub(
        implementations.insecure_channel(server_host, server_port))

    predict_request = predict_pb2.PredictRequest()
    predict_request.model_spec.name = name
    if version > 0:
        predict_request.model_spec.version.value = version
    # NOTE(review): gym_agent.py sends its input under 'states' (plural);
    # confirm which key the served model's signature expects.
    predict_request.inputs['state'].CopyFrom(sample_proto)

    # Issue the RPC and dump the raw response.
    print(stub.Predict(predict_request, timeout))


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /python_predict_client/predict_pb2.py:
# --------------------------------------------------------------------------------
# Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: predict.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from tensorflow.core.framework import tensor_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__pb2 17 | import model_pb2 as model__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='predict.proto', 22 | package='tensorflow.serving', 23 | syntax='proto3', 24 | serialized_pb=_b('\n\rpredict.proto\x12\x12tensorflow.serving\x1a&tensorflow/core/framework/tensor.proto\x1a\x0bmodel.proto\"\xe2\x01\n\x0ePredictRequest\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12>\n\x06inputs\x18\x02 \x03(\x0b\x32..tensorflow.serving.PredictRequest.InputsEntry\x12\x15\n\routput_filter\x18\x03 \x03(\t\x1a\x46\n\x0bInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\"\x9d\x01\n\x0fPredictResponse\x12\x41\n\x07outputs\x18\x01 \x03(\x0b\x32\x30.tensorflow.serving.PredictResponse.OutputsEntry\x1aG\n\x0cOutputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3') 25 | , 26 | dependencies=[tensorflow_dot_core_dot_framework_dot_tensor__pb2.DESCRIPTOR,model__pb2.DESCRIPTOR,]) 27 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 28 | 29 | 30 | 31 | 32 | _PREDICTREQUEST_INPUTSENTRY = _descriptor.Descriptor( 33 | name='InputsEntry', 34 | full_name='tensorflow.serving.PredictRequest.InputsEntry', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 
| _descriptor.FieldDescriptor( 40 | name='key', full_name='tensorflow.serving.PredictRequest.InputsEntry.key', index=0, 41 | number=1, type=9, cpp_type=9, label=1, 42 | has_default_value=False, default_value=_b("").decode('utf-8'), 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None), 46 | _descriptor.FieldDescriptor( 47 | name='value', full_name='tensorflow.serving.PredictRequest.InputsEntry.value', index=1, 48 | number=2, type=11, cpp_type=10, label=1, 49 | has_default_value=False, default_value=None, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | options=None), 53 | ], 54 | extensions=[ 55 | ], 56 | nested_types=[], 57 | enum_types=[ 58 | ], 59 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 60 | is_extendable=False, 61 | syntax='proto3', 62 | extension_ranges=[], 63 | oneofs=[ 64 | ], 65 | serialized_start=247, 66 | serialized_end=317, 67 | ) 68 | 69 | _PREDICTREQUEST = _descriptor.Descriptor( 70 | name='PredictRequest', 71 | full_name='tensorflow.serving.PredictRequest', 72 | filename=None, 73 | file=DESCRIPTOR, 74 | containing_type=None, 75 | fields=[ 76 | _descriptor.FieldDescriptor( 77 | name='model_spec', full_name='tensorflow.serving.PredictRequest.model_spec', index=0, 78 | number=1, type=11, cpp_type=10, label=1, 79 | has_default_value=False, default_value=None, 80 | message_type=None, enum_type=None, containing_type=None, 81 | is_extension=False, extension_scope=None, 82 | options=None), 83 | _descriptor.FieldDescriptor( 84 | name='inputs', full_name='tensorflow.serving.PredictRequest.inputs', index=1, 85 | number=2, type=11, cpp_type=10, label=3, 86 | has_default_value=False, default_value=[], 87 | message_type=None, enum_type=None, containing_type=None, 88 | is_extension=False, extension_scope=None, 89 | options=None), 90 | _descriptor.FieldDescriptor( 91 | 
name='output_filter', full_name='tensorflow.serving.PredictRequest.output_filter', index=2, 92 | number=3, type=9, cpp_type=9, label=3, 93 | has_default_value=False, default_value=[], 94 | message_type=None, enum_type=None, containing_type=None, 95 | is_extension=False, extension_scope=None, 96 | options=None), 97 | ], 98 | extensions=[ 99 | ], 100 | nested_types=[_PREDICTREQUEST_INPUTSENTRY, ], 101 | enum_types=[ 102 | ], 103 | options=None, 104 | is_extendable=False, 105 | syntax='proto3', 106 | extension_ranges=[], 107 | oneofs=[ 108 | ], 109 | serialized_start=91, 110 | serialized_end=317, 111 | ) 112 | 113 | 114 | _PREDICTRESPONSE_OUTPUTSENTRY = _descriptor.Descriptor( 115 | name='OutputsEntry', 116 | full_name='tensorflow.serving.PredictResponse.OutputsEntry', 117 | filename=None, 118 | file=DESCRIPTOR, 119 | containing_type=None, 120 | fields=[ 121 | _descriptor.FieldDescriptor( 122 | name='key', full_name='tensorflow.serving.PredictResponse.OutputsEntry.key', index=0, 123 | number=1, type=9, cpp_type=9, label=1, 124 | has_default_value=False, default_value=_b("").decode('utf-8'), 125 | message_type=None, enum_type=None, containing_type=None, 126 | is_extension=False, extension_scope=None, 127 | options=None), 128 | _descriptor.FieldDescriptor( 129 | name='value', full_name='tensorflow.serving.PredictResponse.OutputsEntry.value', index=1, 130 | number=2, type=11, cpp_type=10, label=1, 131 | has_default_value=False, default_value=None, 132 | message_type=None, enum_type=None, containing_type=None, 133 | is_extension=False, extension_scope=None, 134 | options=None), 135 | ], 136 | extensions=[ 137 | ], 138 | nested_types=[], 139 | enum_types=[ 140 | ], 141 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 142 | is_extendable=False, 143 | syntax='proto3', 144 | extension_ranges=[], 145 | oneofs=[ 146 | ], 147 | serialized_start=406, 148 | serialized_end=477, 149 | ) 150 | 151 | _PREDICTRESPONSE = _descriptor.Descriptor( 152 | 
name='PredictResponse', 153 | full_name='tensorflow.serving.PredictResponse', 154 | filename=None, 155 | file=DESCRIPTOR, 156 | containing_type=None, 157 | fields=[ 158 | _descriptor.FieldDescriptor( 159 | name='outputs', full_name='tensorflow.serving.PredictResponse.outputs', index=0, 160 | number=1, type=11, cpp_type=10, label=3, 161 | has_default_value=False, default_value=[], 162 | message_type=None, enum_type=None, containing_type=None, 163 | is_extension=False, extension_scope=None, 164 | options=None), 165 | ], 166 | extensions=[ 167 | ], 168 | nested_types=[_PREDICTRESPONSE_OUTPUTSENTRY, ], 169 | enum_types=[ 170 | ], 171 | options=None, 172 | is_extendable=False, 173 | syntax='proto3', 174 | extension_ranges=[], 175 | oneofs=[ 176 | ], 177 | serialized_start=320, 178 | serialized_end=477, 179 | ) 180 | 181 | _PREDICTREQUEST_INPUTSENTRY.fields_by_name['value'].message_type = tensorflow_dot_core_dot_framework_dot_tensor__pb2._TENSORPROTO 182 | _PREDICTREQUEST_INPUTSENTRY.containing_type = _PREDICTREQUEST 183 | _PREDICTREQUEST.fields_by_name['model_spec'].message_type = model__pb2._MODELSPEC 184 | _PREDICTREQUEST.fields_by_name['inputs'].message_type = _PREDICTREQUEST_INPUTSENTRY 185 | _PREDICTRESPONSE_OUTPUTSENTRY.fields_by_name['value'].message_type = tensorflow_dot_core_dot_framework_dot_tensor__pb2._TENSORPROTO 186 | _PREDICTRESPONSE_OUTPUTSENTRY.containing_type = _PREDICTRESPONSE 187 | _PREDICTRESPONSE.fields_by_name['outputs'].message_type = _PREDICTRESPONSE_OUTPUTSENTRY 188 | DESCRIPTOR.message_types_by_name['PredictRequest'] = _PREDICTREQUEST 189 | DESCRIPTOR.message_types_by_name['PredictResponse'] = _PREDICTRESPONSE 190 | 191 | PredictRequest = _reflection.GeneratedProtocolMessageType('PredictRequest', (_message.Message,), dict( 192 | 193 | InputsEntry = _reflection.GeneratedProtocolMessageType('InputsEntry', (_message.Message,), dict( 194 | DESCRIPTOR = _PREDICTREQUEST_INPUTSENTRY, 195 | __module__ = 'predict_pb2' 196 | # 
@@protoc_insertion_point(class_scope:tensorflow.serving.PredictRequest.InputsEntry) 197 | )) 198 | , 199 | DESCRIPTOR = _PREDICTREQUEST, 200 | __module__ = 'predict_pb2' 201 | # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictRequest) 202 | )) 203 | _sym_db.RegisterMessage(PredictRequest) 204 | _sym_db.RegisterMessage(PredictRequest.InputsEntry) 205 | 206 | PredictResponse = _reflection.GeneratedProtocolMessageType('PredictResponse', (_message.Message,), dict( 207 | 208 | OutputsEntry = _reflection.GeneratedProtocolMessageType('OutputsEntry', (_message.Message,), dict( 209 | DESCRIPTOR = _PREDICTRESPONSE_OUTPUTSENTRY, 210 | __module__ = 'predict_pb2' 211 | # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictResponse.OutputsEntry) 212 | )) 213 | , 214 | DESCRIPTOR = _PREDICTRESPONSE, 215 | __module__ = 'predict_pb2' 216 | # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictResponse) 217 | )) 218 | _sym_db.RegisterMessage(PredictResponse) 219 | _sym_db.RegisterMessage(PredictResponse.OutputsEntry) 220 | 221 | 222 | DESCRIPTOR.has_options = True 223 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\370\001\001')) 224 | _PREDICTREQUEST_INPUTSENTRY.has_options = True 225 | _PREDICTREQUEST_INPUTSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 226 | _PREDICTRESPONSE_OUTPUTSENTRY.has_options = True 227 | _PREDICTRESPONSE_OUTPUTSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 228 | import grpc 229 | from grpc.beta import implementations as beta_implementations 230 | from grpc.beta import interfaces as beta_interfaces 231 | from grpc.framework.common import cardinality 232 | from grpc.framework.interfaces.face import utilities as face_utilities 233 | # @@protoc_insertion_point(module_scope) 234 | -------------------------------------------------------------------------------- 
# /python_predict_client/prediction_service_pb2.py:
# --------------------------------------------------------------------------------
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: prediction_service.proto

import sys
# Identity on Python 2, latin1-encode on Python 3, applied to the literal
# descriptor strings below.
_b = (lambda x: x) if sys.version_info[0] < 3 else (lambda x: x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


import predict_pb2 as predict__pb2


# File descriptor for prediction_service.proto; it depends on predict.proto.
DESCRIPTOR = _descriptor.FileDescriptor(
    name='prediction_service.proto',
    package='tensorflow.serving',
    syntax='proto3',
    serialized_pb=_b('\n\x18prediction_service.proto\x12\x12tensorflow.serving\x1a\rpredict.proto2g\n\x11PredictionService\x12R\n\x07Predict\x12\".tensorflow.serving.PredictRequest\x1a#.tensorflow.serving.PredictResponseB\x03\xf8\x01\x01\x62\x06proto3'),
    dependencies=[predict__pb2.DESCRIPTOR, ])
_sym_db.RegisterFileDescriptor(DESCRIPTOR)


DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\370\001\001'))
import grpc
from grpc.beta import implementations as beta_implementations
from grpc.beta import interfaces as beta_interfaces
from grpc.framework.common import cardinality
from grpc.framework.interfaces.face import utilities as face_utilities


class PredictionServiceStub(object):
    """Client stub for PredictionService.

    PredictionService provides access to machine-learned models loaded by
    model_servers.
    """

    def __init__(self, channel):
        """Bind the ``Predict`` unary-unary callable to *channel*.

        Args:
          channel: A grpc.Channel.
        """
        predict_callable = channel.unary_unary(
            '/tensorflow.serving.PredictionService/Predict',
            request_serializer=predict__pb2.PredictRequest.SerializeToString,
            response_deserializer=predict__pb2.PredictResponse.FromString,
        )
        self.Predict = predict_callable


class PredictionServiceServicer(object):
    """Server-side base class for PredictionService.

    PredictionService provides access to machine-learned models loaded by
    model_servers.
    """

    def Predict(self, request, context):
        """Predict -- provides access to loaded TensorFlow model.

        This base implementation marks the call UNIMPLEMENTED on *context*
        and raises NotImplementedError.
        """
        not_implemented = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(not_implemented)
        raise NotImplementedError(not_implemented)


def add_PredictionServiceServicer_to_server(servicer, server):
    """Register *servicer*'s ``Predict`` method on *server*."""
    predict_handler = grpc.unary_unary_rpc_method_handler(
        servicer.Predict,
        request_deserializer=predict__pb2.PredictRequest.FromString,
        response_serializer=predict__pb2.PredictResponse.SerializeToString,
    )
    service_handler = grpc.method_handlers_generic_handler(
        'tensorflow.serving.PredictionService', {'Predict': predict_handler})
    server.add_generic_rpc_handlers((service_handler,))


class BetaPredictionServiceServicer(object):
    """Servicer base class for the ``grpc.beta`` flavour of PredictionService.

    PredictionService provides access to machine-learned models loaded by
    model_servers.
    """

    def Predict(self, request, context):
        """Predict -- provides access to loaded TensorFlow model.

        This base implementation only reports UNIMPLEMENTED on *context*.
        """
        context.code(beta_interfaces.StatusCode.UNIMPLEMENTED)


class BetaPredictionServiceStub(object):
    """Stub interface for the ``grpc.beta`` flavour of PredictionService.

    PredictionService provides access to machine-learned models loaded by
    model_servers.
    """

    def Predict(self, request, timeout, metadata=None, with_call=False, protocol_options=None):
        """Predict -- provides access to loaded TensorFlow model."""
        raise NotImplementedError()
    # Placeholder attribute on the function object.
    Predict.future = None


def beta_create_PredictionService_server(servicer, pool=None, pool_size=None, default_timeout=None, maximum_timeout=None):
    """Build a ``grpc.beta`` server that dispatches ``Predict`` to *servicer*."""
    predict_key = ('tensorflow.serving.PredictionService', 'Predict')
    deserializers = {predict_key: predict__pb2.PredictRequest.FromString}
    serializers = {predict_key: predict__pb2.PredictResponse.SerializeToString}
    implementations_by_method = {
        predict_key: face_utilities.unary_unary_inline(servicer.Predict),
    }
    options = beta_implementations.server_options(
        request_deserializers=deserializers,
        response_serializers=serializers,
        thread_pool=pool,
        thread_pool_size=pool_size,
        default_timeout=default_timeout,
        maximum_timeout=maximum_timeout)
    return beta_implementations.server(implementations_by_method, options=options)


def beta_create_PredictionService_stub(channel, host=None, metadata_transformer=None, pool=None, pool_size=None):
    """Build a ``grpc.beta`` dynamic stub for PredictionService on *channel*."""
    predict_key = ('tensorflow.serving.PredictionService', 'Predict')
    serializers = {predict_key: predict__pb2.PredictRequest.SerializeToString}
    deserializers = {predict_key: predict__pb2.PredictResponse.FromString}
    method_cardinalities = {'Predict': cardinality.Cardinality.UNARY_UNARY}
    options = beta_implementations.stub_options(
        host=host,
        metadata_transformer=metadata_transformer,
        request_serializers=serializers,
        response_deserializers=deserializers,
        thread_pool=pool,
        thread_pool_size=pool_size)
    return beta_implementations.dynamic_stub(
        channel, 'tensorflow.serving.PredictionService',
        method_cardinalities, options=options)
# @@protoc_insertion_point(module_scope)
--------------------------------------------------------------------------------