├── .gitignore
├── README.md
├── bench.sh
├── model_conv.py
├── requirements.txt
├── run_onnx.py
├── run_sb3.py
├── run_tflite.py
├── tflite_benchmark.py
└── train.py


/.gitignore:
--------------------------------------------------------------------------------
*.zip
*.onnx
*.pb
model
*.tflite
*.log
venv
.idea
*.swp

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Running SB3-developed agents on TFLite or Coral

## Introduction

I've been using [Stable-Baselines3](https://stable-baselines3.readthedocs.io)
to train agents against some custom Gyms, some of which require fairly
large NNs in order to be effective.

I want those agents to eventually run on a Pi or similar, so I need
to export all the way to [TFLite](https://www.tensorflow.org/lite)
and ideally a [Coral](https://coral.ai/).

## How to use

### Setup

You will need the system-wide Coral pieces (the Edge TPU runtime and
`edgetpu_compiler`) installed and configured first.

Build a venv:

```shell
python3 -m venv venv
source venv/bin/activate
python3 -m pip install -r requirements.txt
```

### Running

This comes with enough defaults to do a cradle-to-grave demonstration,
but all the pieces take command-line arguments so I can adjust to taste
for my actual use case.
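
The training, conversion and run scripts share one positional-argument
convention: with fewer than two arguments they fall back to their defaults,
otherwise they read the gym name and the model prefix from `sys.argv`.
A minimal sketch of that pattern follows; `parse_cli` is a hypothetical helper
written for illustration only (the scripts inline this logic rather than
sharing a function):

```python
def parse_cli(argv, env_name="MountainCarContinuous-v0", model_prefix="model"):
    """Return (env_name, model_prefix), falling back to the given defaults.

    Hypothetical helper for illustration; not part of this repo.
    """
    if len(argv) < 3:
        # Too few arguments: keep the defaults, as the scripts do.
        return env_name, model_prefix
    return argv[1], argv[2]


# No arguments: defaults are used.
print(parse_cli(["train.py"]))
# Both arguments given: they override the defaults.
print(parse_cli(["train.py", "MountainCarContinuous-v0", "bench/bench_w64xd2"]))
```

The model prefix is a path without an extension; each script appends its own
suffix (`.zip`, `.onnx`, `.tflite`) to it.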

```shell
# Train an agent with SB3
python3 ./train.py

# Convert model
python3 ./model_conv.py

# Run original SB3 model
python3 ./run_sb3.py
# Run the onnx model
python3 ./run_onnx.py
# Run the TFLite model
python3 ./run_tflite.py
# Run the Coral model ["edgetpu" in the name will attempt to load Coral]
python3 ./run_tflite.py MountainCarContinuous-v0 model_quant_edgetpu
```

## Benchmarking

I was curious to explore how the Coral actually performs. `bench.sh` should
build models across a variety of NN sizes, then benchmark them all into a
single CSV file.

A few things about the benchmark:
* For completeness, there's a non-quantised "edgetpu" file built; it
  should perform exactly the same as the CPU non-quantised one [because
  it can't run on the Coral].
* The benchmark simply samples the observation space for pushing through
  TFLite, but doesn't actually execute the Gym. One can imagine perverse
  edge cases here.
* This manufactures NNs, but they aren't trained to completion. One can
  imagine perverse edge cases here, too.
* Simple fully-connected NNs such as these RL models enjoy may not be
  a great use case for the Coral.
* The bench.sh script creates some deliberately poorly-dimensioned NNs;
  either they cannot possibly fit on the Coral, or couldn't possibly
  be useful.

## Extras

The full chain, implemented here, to go from SB3 (Torch) to Coral is:
```
Torch => ONNX => Tensorflow => TFLite (normal) => TFLite (quantised) => Coral
```

When this code quantises the network, it explicitly leaves the inputs and
outputs as floats; this means there's some work that gets done on the CPU,
but the observation and action spaces of a gym would mean that work needs
doing, anyways.
So although edgetpu\_compiler warns that this may be less
efficient when run on the actual device, in practice it is not.

The torch-to-ONNX step is a separate beast, specific to Stable-Baselines3, that
warrants discussion; you can find more information on the SB3 docs page, here:
https://stable-baselines3.readthedocs.io/en/master/guide/export.html

Cheers,
Gary


--------------------------------------------------------------------------------
/bench.sh:
--------------------------------------------------------------------------------
#!/bin/sh

workdir=bench
# Create the directory that the variable names
mkdir -p "${workdir}"

gym=MountainCarContinuous-v0

out_csv=${workdir}/bench.csv

for n_nodes_per_layer in 64 128 256 512 1024 2700
do
    for n_layers in 2 4 8 16 32
    do
        echo "Layers: ${n_layers} Width: ${n_nodes_per_layer}"
        modelprefix=bench_w${n_nodes_per_layer}xd${n_layers}
        python3 ./train.py ${gym} ${workdir}/${modelprefix} ${n_layers} ${n_nodes_per_layer}
        python3 ./model_conv.py ${gym} ${workdir}/${modelprefix}
    done
done

for modelfile in ${workdir}/*.tflite
do
    echo "Benchmarking ${modelfile}"
    python3 ./tflite_benchmark.py ${gym} ${modelfile} ${out_csv}
done


--------------------------------------------------------------------------------
/model_conv.py:
--------------------------------------------------------------------------------
import os.path
import sys
from os import system

import gym
import torch
import torchsummary
import onnx
import onnx_tf.backend
import tensorflow as tf

from stable_baselines3 import SAC


class OnnxablePolicy(torch.nn.Module):
    def __init__(self, actor):
        super(OnnxablePolicy, self).__init__()
        self.actor = torch.nn.Sequential(actor.latent_pi, actor.mu)

    def forward(self, observation):
        # NOTE: You may have to process (normalize) observation in the correct
        # way before using this. See `common.preprocessing.preprocess_obs`
        return self.actor(observation)


if __name__ == '__main__':
    env_name = 'MountainCarContinuous-v0'
    model_prefix = 'model'

    if len(sys.argv) < 3:
        print("Usage: " + str(sys.argv[0]) + " ")
        print(" Defaulting to env: " + env_name + ", model_prefix: " + model_prefix)
    else:
        env_name = sys.argv[1]
        model_prefix = sys.argv[2]

    model_save_file = model_prefix + '.zip'
    onnx_save_file = model_prefix + '.onnx'
    tflite_save_file = model_prefix + '.tflite'
    tflite_quant_save_file = model_prefix + '_quant.tflite'

    print('Creating gym to gather observation sample...')
    env = gym.make(env_name)
    obs = env.observation_space
    # Awkward reshape: https://github.com/onnx/onnx-tensorflow/issues/400
    dummy_input = torch.FloatTensor(obs.sample().reshape(1, -1))

    print('Loading existing SB3 model...')
    model = SAC.load(model_save_file, env, verbose=True)

    print('Exporting to ONNX...')
    onnxable_model = OnnxablePolicy(model.policy.actor)
    model.policy.to("cpu")
    model.policy.eval()
    print(str(onnxable_model.actor))
    # torchsummary.summary(model.policy.actor, input_size=len(dummy_input))

    torch.onnx.export(onnxable_model, dummy_input, onnx_save_file,
                      input_names=['input'],
                      output_names=['output'],
                      opset_version=9, verbose=True)

    print('Loading ONNX and checking...')
    onnx_model = onnx.load(onnx_save_file)
    onnx.checker.check_model(onnx_model)
    print(onnx.helper.printable_graph(onnx_model.graph))

    print('Converting ONNX to TF...')
    tf_rep = onnx_tf.backend.prepare(onnx_model)
    tf_rep.export_graph(model_prefix)

    print('Converting TF to TFLite...')
    converter = tf.lite.TFLiteConverter.from_saved_model(model_prefix)
    tflite_model = converter.convert()
    with open(tflite_save_file, 'wb') as f:
        f.write(tflite_model)

    print('Converting TF to Quantised TFLite...')

    def representative_data_gen():
        global obs
        for i in range(100000):
            yield [obs.sample().reshape(1, -1)]

    converter_quant = tf.lite.TFLiteConverter.from_saved_model(model_prefix)
    converter_quant.optimizations = [tf.lite.Optimize.DEFAULT]
    converter_quant.representative_dataset = representative_data_gen
    converter_quant.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter_quant.target_spec.supported_types = [tf.int8]
    # Just accept that observations and actions are inherently floaty, let Coral handle that on the CPU
    converter_quant.inference_input_type = tf.float32
    converter_quant.inference_output_type = tf.float32
    tflite_quant_model = converter_quant.convert()
    with open(tflite_quant_save_file, 'wb') as f:
        f.write(tflite_quant_model)

    # dirname() is '' for a bare prefix such as 'model', which would leave -o
    # without a value; fall back to the current directory in that case
    out_dir = os.path.dirname(model_prefix) or '.'

    print('Converting TFLite [nonquant] to Coral...')
    system('edgetpu_compiler --show_operations -o ' + out_dir + ' ' + tflite_save_file)

    print('Converting TFLite [quant] to Coral...')
    system('edgetpu_compiler --show_operations -o ' + out_dir + ' ' + tflite_quant_save_file)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
# Create env:
#    python3 -m venv venv
#    source venv/bin/activate
#    python3 -m pip install -r requirements.txt
# Due to some weird dependencies, Gym may be old:
#    python3 -m pip install -U gym

pip
gym
pyglet
onnx
onnxruntime
onnx-tf
torch
torchsummary
stable-baselines3
stable-baselines3[extra]
tensorflow
# The Coral packages come from Google's own package index
--extra-index-url https://google-coral.github.io/py-repo/
pycoral~=2.0
tflite_runtime


--------------------------------------------------------------------------------
/run_onnx.py:
--------------------------------------------------------------------------------
import sys
import gym
import onnxruntime as ort

if __name__ == '__main__':
    env_name = 'MountainCarContinuous-v0'
    model_prefix = 'model'
    if len(sys.argv) < 3:
        print("Usage: " + str(sys.argv[0]) + " ")
        print(" Defaulting to env: " + env_name + ", model_prefix: " + model_prefix)
    else:
        env_name = sys.argv[1]
        model_prefix = sys.argv[2]
    model_save_file = model_prefix + ".onnx"

    env = gym.make(env_name)
    obs = env.reset()

    ort_session = ort.InferenceSession(model_save_file)

    for i in range(100000):
        # 'input' is the name given to the graph input in model_conv.py
        outputs = ort_session.run(
            None,
            {'input': obs.reshape([1, -1])}
        )
        obs, reward, done, info = env.step(outputs[0])
        env.render()
        if done:
            obs = env.reset()



--------------------------------------------------------------------------------
/run_sb3.py:
--------------------------------------------------------------------------------
import sys
import gym

from stable_baselines3 import SAC

if __name__ == '__main__':
    env_name = 'MountainCarContinuous-v0'
    model_prefix = 'model'
    if len(sys.argv) < 3:
        print("Usage: " + str(sys.argv[0]) + " ")
        print(" Defaulting to env: " + env_name + ", model_prefix: " + model_prefix)
    else:
        env_name = sys.argv[1]
        model_prefix = sys.argv[2]
    model_save_file = model_prefix + ".zip"

    env = gym.make(env_name)
    obs = env.reset()

    model = SAC.load(model_save_file, env)

    for i in range(100000):
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        env.render()
        if done:
            obs = env.reset()


--------------------------------------------------------------------------------
/run_tflite.py:
--------------------------------------------------------------------------------
import sys
import gym
import tflite_runtime.interpreter as tflite

if __name__ == '__main__':
    env_name = 'MountainCarContinuous-v0'
    model_prefix = 'model_quant'
    if len(sys.argv) < 3:
        print("Usage: " + str(sys.argv[0]) + " ")
        print(" Defaulting to env: " + env_name + ", model_prefix: " + model_prefix)
    else:
        env_name = sys.argv[1]
        model_prefix = sys.argv[2]
    model_save_file = model_prefix + ".tflite"

    delegates = None
    if 'edgetpu' in model_save_file:
        delegates = [tflite.load_delegate('libedgetpu.so.1')]

    env = gym.make(env_name)
    obs = env.reset()

    interpreter = tflite.Interpreter(model_path=model_save_file, experimental_delegates=delegates)
    interpreter.allocate_tensors()

    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    for i in range(100000):

        input_data = obs.reshape(1, -1)
        interpreter.set_tensor(input_details[0]['index'], input_data)

        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]['index'])

        obs, reward, done, info = env.step(output_data)
        env.render()
        if done:
            obs = env.reset()



--------------------------------------------------------------------------------
/tflite_benchmark.py:
--------------------------------------------------------------------------------
import os.path
import sys
import gym
import time
# NOTE: gym_rtam is not listed in requirements.txt; drop this import if it is unavailable
import gym_rtam
import socket
import re
import tflite_runtime.interpreter as tflite

if __name__ == '__main__':
    if len(sys.argv) < 4:
        print("Usage: " + str(sys.argv[0]) + " ")
        # All three arguments are required, so report the failure to the caller
        sys.exit(1)

    env_name = sys.argv[1]
    tflite_model = sys.argv[2]
    output_csv = sys.argv[3]

    dev = os.getenv("EDGETPU_DEVICE", ":0")
    device_description = "CPU"
    if 'edgetpu' in tflite_model:
        from pycoral.utils import edgetpu

        edge_tpus_available = edgetpu.list_edge_tpus()
        print("Coral TPUs available: {}, using {}".format(edge_tpus_available, dev))
        interpreter = edgetpu.make_interpreter(tflite_model, device=dev)
        device_description = "TPU: " + dev
    else:
        interpreter = tflite.Interpreter(model_path=tflite_model)

    # Average over this many inferences
    bench_inference_cnt = 100000
    # Stop benchmarking if it takes longer than this
    max_bench_time_ns = 240 * 1e9

    env = gym.make(env_name)
    obs_space = env.observation_space

    interpreter.allocate_tensors()

    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    start_time_ns = time.time_ns()
    inference_cnt = 0
    for i in range(bench_inference_cnt):
        # Skip the actual simulation, just grab a random observation
        input_data = obs_space.sample().reshape(1, -1)
        interpreter.set_tensor(input_details[0]['index'], input_data)
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]['index'])
        inference_cnt += 1
        time_now = time.time_ns()
        if time_now - start_time_ns > max_bench_time_ns:
            break

    add_header = not os.path.exists(output_csv)
    # Recover width and depth from the bench_w<width>xd<depth> names that bench.sh generates
    pat = re.compile(r'w(?P<n_nodes_per_layer>[0-9]+)xd(?P<n_hidden_layers>[0-9]+)')
    model_params = pat.search(tflite_model)
    n_nodes_per_layer = model_params.group("n_nodes_per_layer") if model_params is not None else ''
    n_hidden_layers = model_params.group("n_hidden_layers") if model_params is not None else ''
    with open(output_csv, 'a') as out_f:
        if add_header:
            out_f.write("env,file,file_size,inference_cnt,ms_per_inf,inf_per_s,hostname,execute_on,"
                        "is_quantised,n_nodes_per_layer,n_hidden_layers\n")
        model_size = os.path.getsize(tflite_model)
        ns_per_inference = (time_now - start_time_ns) / float(inference_cnt)
        ms_per_inference = ns_per_inference / 1e6
        out_f.write(f"{env_name},{tflite_model},{model_size},{inference_cnt},"
                    f"{ms_per_inference},{1000/ms_per_inference},"
                    f"{socket.gethostname()},{device_description},{'quant' in tflite_model},"
                    f"{n_nodes_per_layer},{n_hidden_layers}\n")

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import sys
import gym

from stable_baselines3 import SAC

# This is here so as to generate a model.zip file; I don't tune the parameters or even expect it to
# generate an agent that works usefully, I just want a complete-and-intact saved model.
# At the very least, increase number of timesteps to model.learn if you need something useful

if __name__ == '__main__':
    env_name = 'MountainCarContinuous-v0'
    model_prefix = 'model'
    n_hidden_layers = 4
    n_nodes_per_layer = 64
    if len(sys.argv) < 3:
        print("Usage: " + str(sys.argv[0]) + " [ ]")
        print(" Defaulting to env: " + env_name + ", model_prefix: " + model_prefix)
    else:
        env_name = sys.argv[1]
        model_prefix = sys.argv[2]
        if len(sys.argv) >= 5:
            n_hidden_layers = int(sys.argv[3])
            n_nodes_per_layer = int(sys.argv[4])

    model_save_file = model_prefix + ".zip"
    env = gym.make(env_name)
    env.reset()
    env.render()

    nn = [n_nodes_per_layer for i in range(n_hidden_layers)]
    print("nn: {}".format(nn))
    # "pi=[]" is an array of widths for the created policy/actor network, qf is for critic
    model = SAC('MlpPolicy', env, verbose=1,
                policy_kwargs=dict(net_arch=dict(pi=nn, qf=[64, 64]))
                )
    model.learn(total_timesteps=250)
    # model.learn(total_timesteps=250_000)
    model.save(model_save_file)


--------------------------------------------------------------------------------
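
The `nn` list built in train.py fixes the size of the actor that later gets
exported, so its parameter count can be estimated by hand before training
anything. A rough sketch, under these assumptions: the exported policy is a
plain fully-connected net (the `pi` hidden layers followed by a single linear
output layer, as wrapped by `OnnxablePolicy`), the observation size is 2 and
the action size is 1 (MountainCarContinuous-v0), and one byte per parameter
after int8 quantisation is only an approximation. `mlp_param_count` is a
hypothetical helper, not part of this repo:

```python
def mlp_param_count(obs_dim, hidden, act_dim):
    """Weights plus biases of a fully-connected net obs_dim -> hidden... -> act_dim."""
    sizes = [obs_dim] + list(hidden) + [act_dim]
    # Each layer has n_in * n_out weights and n_out biases.
    return sum((n_in + 1) * n_out for n_in, n_out in zip(sizes, sizes[1:]))


# Same construction as train.py's defaults: 4 hidden layers of 64 nodes each.
n_hidden_layers, n_nodes_per_layer = 4, 64
nn = [n_nodes_per_layer] * n_hidden_layers

params = mlp_param_count(obs_dim=2, hidden=nn, act_dim=1)
print(f"{params} parameters, roughly {params / 1024:.1f} KiB at one byte each")
```

By the same arithmetic, the widest and deepest network that bench.sh builds
(width 2700, depth 32) comes out at about 226 million parameters, which gives
a sense of what "deliberately poorly-dimensioned" means in the README.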