├── .gitignore ├── 3dball ├── ppo │ ├── config.yaml │ ├── config_addition.yaml │ ├── main.py │ └── ppo.py └── ppo_sep_critic │ ├── config.yaml │ ├── config_addition.yaml │ ├── main.py │ └── ppo.py ├── README.md ├── algorithm ├── agent.py ├── ppo_base.py ├── ppo_main.py ├── ppo_sep_critic_base.py ├── ppo_sep_critic_main.py └── saver.py ├── mlagents └── envs │ ├── __init__.py │ ├── base_unity_environment.py │ ├── brain.py │ ├── communicator.py │ ├── communicator_objects │ ├── __init__.py │ ├── agent_action_proto_pb2.py │ ├── agent_info_proto_pb2.py │ ├── brain_parameters_proto_pb2.py │ ├── command_proto_pb2.py │ ├── custom_action_pb2.py │ ├── custom_observation_pb2.py │ ├── custom_reset_parameters_pb2.py │ ├── demonstration_meta_proto_pb2.py │ ├── engine_configuration_proto_pb2.py │ ├── environment_parameters_proto_pb2.py │ ├── header_pb2.py │ ├── resolution_proto_pb2.py │ ├── space_type_proto_pb2.py │ ├── unity_input_pb2.py │ ├── unity_message_pb2.py │ ├── unity_output_pb2.py │ ├── unity_rl_initialization_input_pb2.py │ ├── unity_rl_initialization_output_pb2.py │ ├── unity_rl_input_pb2.py │ ├── unity_rl_output_pb2.py │ ├── unity_to_external_pb2.py │ └── unity_to_external_pb2_grpc.py │ ├── environment.py │ ├── exception.py │ ├── mock_communicator.py │ ├── rpc_communicator.py │ ├── socket_communicator.py │ ├── subprocess_environment.py │ └── tests │ ├── __init__.py │ ├── test_envs.py │ ├── test_rpc_communicator.py │ └── test_subprocess_unity_environment.py ├── simple_boat ├── ppo │ ├── config.yaml │ ├── config_addition.yaml │ ├── main.py │ └── ppo.py └── ppo_sep_critic │ ├── config.yaml │ ├── config_addition.yaml │ ├── main.py │ └── ppo.py └── simple_roller ├── ppo ├── config.yaml ├── config_addition.yaml ├── main.py └── ppo.py └── ppo_sep_critic ├── config.yaml ├── config_addition.yaml ├── main.py └── ppo.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # IPython 78 | profile_default/ 79 | ipython_config.py 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | .dmypy.json 112 | dmypy.json 113 | 114 | # Pyre type checker 115 | .pyre/ 116 | 117 | .vscode 118 | models -------------------------------------------------------------------------------- /3dball/ppo/config.yaml: -------------------------------------------------------------------------------- 1 | build_path: 2 | win32: C:\Users\Fisher\Documents\Unity\build-RL-Envs\RL-Envs.exe 3 | scene: 3DBall 4 | 5 | # lambda: 1 6 | # gamma: 0.99 7 | max_iter: 5000 8 | # policies_num: 1 9 | # agents_num_p_policy: 1 10 | # reset_on_iteration: true 11 | seed: 100 12 | # std: true 13 | # mix: true 14 | # aux_cumulative_ratio: 0.4 15 | # good_trans_ratio: 1 16 | # addition_objective: false 17 | 18 | ppo_config: 19 | # save_per_iter: 1000 20 | write_summary_graph: true 21 | 22 | # batch_size: 2048 23 | # epoch_size: 10 24 | 25 | # init_td_threshold: 0.0 26 | # td_threshold_decay_steps: 100 27 | # td_threshold_rate: 0.5 28 | 29 | # beta: 0.001 30 | epsilon: 0.1 31 | 32 | # init_lr: 0.00005 33 | # min_lr: 0.00001 34 | decay_steps: 50 35 | decay_rate: 0.7 36 | -------------------------------------------------------------------------------- /3dball/ppo/config_addition.yaml: -------------------------------------------------------------------------------- 1 | build_path: 2 | win32: C:\Users\Fisher\Documents\Unity\build-RL-Envs\RL-Envs.exe 3 | scene: 3DBall 4 | 5 | # lambda: 1 6 | # gamma: 0.99 7 | max_iter: 5000 8 | # policies_num: 1 9 | # agents_num_p_policy: 1 10 | # reset_on_iteration: true 11 | seed: 100 12 | # std: true 13 | # mix: true 14 | # aux_cumulative_ratio: 0.4 15 | # good_trans_ratio: 1 16 | addition_objective: true 17 | 18 | ppo_config: 19 | # save_per_iter: 1000 20 | write_summary_graph: true 21 | 22 | # batch_size: 2048 23 | # epoch_size: 10 24 | 25 | # init_td_threshold: 0.0 26 | # td_threshold_decay_steps: 100 27 | # td_threshold_rate: 0.5 28 | 29 | # beta: 0.001 30 | epsilon: 0.02 31 | 32 | # init_lr: 0.00005 33 | # min_lr: 0.00001 34 | decay_steps: 50 35 | decay_rate: 0.7 36 | -------------------------------------------------------------------------------- /3dball/ppo/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | 4 | import numpy as np 5 | 6 | sys.path.append('../..') 7 | from algorithm.ppo_main import Main 8 | from 
algorithm.agent import Agent 9 | 10 | if __name__ == '__main__': 11 | logging.basicConfig(level=logging.INFO, format='[%(levelname)s] - [%(name)s] - %(message)s') 12 | 13 | _log = logging.getLogger('tensorflow') 14 | _log.setLevel(logging.ERROR) 15 | 16 | logger = logging.getLogger('ppo') 17 | 18 | Main(sys.argv[1:], Agent) 19 | -------------------------------------------------------------------------------- /3dball/ppo/ppo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow_probability as tfp 4 | 5 | import warnings 6 | warnings.filterwarnings("ignore", category=DeprecationWarning) 7 | 8 | initializer_helper = { 9 | 'kernel_initializer': tf.truncated_normal_initializer(0, .1), 10 | 'bias_initializer': tf.constant_initializer(.1) 11 | } 12 | 13 | 14 | class PPO_Sep_Custom(object): 15 | def _build_net(self, s_inputs, scope, trainable, reuse=False): 16 | with tf.variable_scope(scope, reuse=reuse): 17 | policy, policy_variables = self._build_actor_net(s_inputs, 'actor', trainable) 18 | v, v_variables = self._build_critic_net(s_inputs, 'critic', trainable) 19 | 20 | return policy, v, policy_variables + v_variables 21 | 22 | def _build_critic_net(self, s_inputs, scope, trainable, reuse=False): 23 | with tf.variable_scope(scope, reuse=reuse): 24 | l = tf.layers.dense(s_inputs, 512, tf.nn.relu, trainable=trainable, **initializer_helper) 25 | l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper) 26 | l = tf.layers.dense(l, 128, tf.nn.relu, trainable=trainable, **initializer_helper) 27 | l = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper) 28 | v = tf.layers.dense(l, 1, trainable=trainable, **initializer_helper) 29 | 30 | variables = tf.get_variable_scope().global_variables() 31 | 32 | return v, variables 33 | 34 | def _build_actor_net(self, s_inputs, scope, trainable, reuse=False): 35 | with tf.variable_scope(scope, reuse=reuse): 36 | l = tf.layers.dense(s_inputs, 512, tf.nn.relu, trainable=trainable, **initializer_helper) 37 | l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper) 38 | l = tf.layers.dense(l, 128, tf.nn.relu, trainable=trainable, **initializer_helper) 39 | l = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper) 40 | 41 | mu = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper) 42 | mu = tf.layers.dense(mu, self.a_dim, tf.nn.tanh, trainable=trainable, **initializer_helper) 43 | sigma = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper) 44 | sigma = tf.layers.dense(sigma, self.a_dim, tf.nn.sigmoid, trainable=trainable, **initializer_helper) 45 | 46 | mu, sigma = mu, sigma + .1 47 | 48 | policy = tf.distributions.Normal(loc=mu, scale=sigma) 49 | 50 | variables = tf.get_variable_scope().global_variables() 51 | 52 | return policy, variables 53 | 54 | 55 | class PPO_Std_Custom(object): 56 | def _build_net(self, s_inputs, scope, trainable, reuse=False): 57 | with tf.variable_scope(scope, reuse=reuse): 58 | l = tf.layers.dense(s_inputs, 512, tf.nn.relu, trainable=trainable, **initializer_helper) 59 | l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper) 60 | l = tf.layers.dense(l, 128, tf.nn.relu, trainable=trainable, **initializer_helper) 61 | 62 | prob_l = tf.layers.dense(l, 128, tf.nn.relu, trainable=trainable, **initializer_helper) 63 | mu = tf.layers.dense(prob_l, 32, tf.nn.relu, 
trainable=trainable, **initializer_helper) 64 | mu = tf.layers.dense(mu, self.a_dim, tf.nn.tanh, trainable=trainable, **initializer_helper) 65 | sigma = tf.layers.dense(prob_l, 32, tf.nn.relu, trainable=trainable, **initializer_helper) 66 | sigma = tf.layers.dense(sigma, self.a_dim, tf.nn.sigmoid, trainable=trainable, **initializer_helper) 67 | mu, sigma = mu, sigma + .1 68 | 69 | policy = tf.distributions.Normal(loc=mu, scale=sigma) 70 | 71 | v_l = tf.layers.dense(l, 128, tf.nn.relu, trainable=trainable, **initializer_helper) 72 | v_l = tf.layers.dense(v_l, 32, tf.nn.relu, trainable=trainable, **initializer_helper) 73 | v = tf.layers.dense(v_l, 1, trainable=trainable, **initializer_helper) 74 | 75 | variables = tf.get_variable_scope().global_variables() 76 | 77 | return policy, v, variables 78 | -------------------------------------------------------------------------------- /3dball/ppo_sep_critic/config.yaml: -------------------------------------------------------------------------------- 1 | build_path: 2 | win32: C:\Users\Fisher\Documents\Unity\build-RL-Envs\RL-Envs.exe 3 | scene: 3DBall 4 | 5 | # lambda: 1 6 | # gamma: 0.99 7 | max_iter: 5000 8 | # policies_num: 1 9 | # agents_num_p_policy: 1 10 | # reset_on_iteration: true 11 | seed: 100 12 | # mix: true 13 | # aux_cumulative_ratio: 0.4 14 | # good_trans_ratio: 1 15 | # addition_objective: false 16 | 17 | critic_config: 18 | # save_per_iter: 1000 19 | write_summary_graph: true 20 | 21 | # batch_size: 2048 22 | # epoch_size: 10 23 | 24 | # init_td_threshold: 0.01 25 | # td_threshold_decay_steps: 100 26 | # td_threshold_rate: 0.9 27 | 28 | # init_lr: 0.00005 29 | decay_steps: 50 30 | decay_rate: 0.7 31 | 32 | 33 | ppo_config: 34 | # save_per_iter: 1000 35 | write_summary_graph: true 36 | 37 | # batch_size: 2048 38 | # epoch_size: 10 39 | 40 | beta: 0.002 41 | epsilon: 0.1 42 | 43 | # init_lr: 0.00005 44 | decay_steps: 50 45 | decay_rate: 0.7 46 | -------------------------------------------------------------------------------- /3dball/ppo_sep_critic/config_addition.yaml: -------------------------------------------------------------------------------- 1 | build_path: 2 | win32: C:\Users\Fisher\Documents\Unity\build-RL-Envs\RL-Envs.exe 3 | scene: 3DBall 4 | 5 | # lambda: 1 6 | # gamma: 0.99 7 | max_iter: 5000 8 | # policies_num: 1 9 | # agents_num_p_policy: 1 10 | # reset_on_iteration: true 11 | seed: 100 12 | # mix: true 13 | # aux_cumulative_ratio: 0.4 14 | # good_trans_ratio: 1 15 | addition_objective: true 16 | 17 | critic_config: 18 | # save_per_iter: 1000 19 | write_summary_graph: true 20 | 21 | # batch_size: 2048 22 | # epoch_size: 10 23 | 24 | init_td_threshold: 0.01 25 | # td_threshold_decay_steps: 100 26 | td_threshold_rate: 0.9 27 | 28 | # init_lr: 0.00005 29 | decay_steps: 50 30 | decay_rate: 0.7 31 | 32 | 33 | ppo_config: 34 | # save_per_iter: 1000 35 | write_summary_graph: true 36 | 37 | # batch_size: 2048 38 | # epoch_size: 10 39 | 40 | beta: 0.002 41 | epsilon: 0.02 42 | 43 | # init_lr: 0.00005 44 | decay_steps: 50 45 | decay_rate: 0.7 46 | -------------------------------------------------------------------------------- /3dball/ppo_sep_critic/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | 4 | import numpy as np 5 | 6 | sys.path.append('../..') 7 | from algorithm.ppo_sep_critic_main import Main 8 | from algorithm.agent import Agent 9 | 10 | if __name__ == '__main__': 11 | logging.basicConfig(level=logging.INFO, format='[%(levelname)s] - 
[%(name)s] - %(message)s')
12 | 
13 |     _log = logging.getLogger('tensorflow')
14 |     _log.setLevel(logging.ERROR)
15 | 
16 |     logger = logging.getLogger('ppo')
17 | 
18 |     Main(sys.argv[1:], Agent)
19 | 
--------------------------------------------------------------------------------
/3dball/ppo_sep_critic/ppo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import tensorflow_probability as tfp
4 | 
5 | import warnings
6 | warnings.filterwarnings("ignore", category=DeprecationWarning)
7 | 
8 | initializer_helper = {
9 |     'kernel_initializer': tf.truncated_normal_initializer(0, .1),
10 |     'bias_initializer': tf.constant_initializer(.1)
11 | }
12 | 
13 | 
14 | class Critic_Custom(object):
15 |     def _build_net(self, s_inputs, scope, trainable, reuse=False):
16 |         with tf.variable_scope(scope, reuse=reuse):
17 |             l = tf.layers.dense(s_inputs, 512, tf.nn.relu, trainable=trainable, **initializer_helper)
18 |             l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
19 |             l = tf.layers.dense(l, 128, tf.nn.relu, trainable=trainable, **initializer_helper)
20 |             l = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
21 |             v = tf.layers.dense(l, 1, trainable=trainable, **initializer_helper)
22 | 
23 |         return v
24 | 
25 | 
26 | class PPO_Custom(object):
27 |     def _build_net(self, s_inputs, scope, trainable, reuse=False):
28 |         with tf.variable_scope(scope, reuse=reuse):
29 |             l = tf.layers.dense(s_inputs, 512, tf.nn.relu, trainable=trainable, **initializer_helper)
30 |             l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
31 |             l = tf.layers.dense(l, 128, tf.nn.relu, trainable=trainable, **initializer_helper)
32 |             l = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
33 | 
34 |             mu = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
35 |             mu = tf.layers.dense(mu, self.a_dim, tf.nn.tanh, trainable=trainable, **initializer_helper)
36 |             sigma = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
37 |             sigma = tf.layers.dense(sigma, self.a_dim, tf.nn.sigmoid, trainable=trainable, **initializer_helper)
38 | 
39 |             mu, sigma = mu, sigma + .1
40 | 
41 |             policy = tf.distributions.Normal(loc=mu, scale=sigma)
42 | 
43 |             variables = tf.get_variable_scope().global_variables()
44 | 
45 |         return policy, variables
46 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RL-PPO-with-Unity
2 | An implementation of the PPO algorithm built on Unity3D environments.
3 | 
4 | We use [ml-agents](https://github.com/Unity-Technologies/ml-agents) to connect the Unity game environment to the Python-side learning algorithm. For both training and inference, we use [TensorFlow](https://github.com/tensorflow/tensorflow) to build our neural networks.
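
The per-environment `ppo.py` modules above plug into `algorithm/ppo_base.py` (listed below) through the `_build_net` hook, which returns a policy distribution, a value head, and the variables of the scope. A minimal sketch of that contract; the `PPO_Toy` class, the 64-unit layer, and the usage dimensions are illustrative assumptions, not part of the repo:

import tensorflow as tf

from algorithm.ppo_base import PPO_Base


class PPO_Toy(PPO_Base):
    # Toy override of the _build_net hook; mirrors the PPO_Std_Custom
    # classes in this repo but with a single hidden layer.
    def _build_net(self, s_inputs, scope, trainable, reuse=False):
        with tf.variable_scope(scope, reuse=reuse):
            l = tf.layers.dense(s_inputs, 64, tf.nn.relu, trainable=trainable)
            mu = tf.layers.dense(l, self.a_dim, tf.nn.tanh, trainable=trainable)
            sigma = tf.layers.dense(l, self.a_dim, tf.nn.sigmoid, trainable=trainable) + .1
            policy = tf.distributions.Normal(loc=mu, scale=sigma)
            v = tf.layers.dense(l, 1, trainable=trainable)
            variables = tf.get_variable_scope().global_variables()
        return policy, v, variables


# Hypothetical usage; state/action dimensions depend on the Unity scene.
ppo = PPO_Toy(state_dim=8, action_dim=2, model_root_path='models/toy')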
-------------------------------------------------------------------------------- /algorithm/agent.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import numpy as np 4 | 5 | 6 | class Agent(object): 7 | reward = 0 8 | ppo = None 9 | 10 | done = False 11 | _curr_cumulative_reward = 0 12 | 13 | def __init__(self, agent_id, gamma, lambda_): 14 | self.agent_id = agent_id 15 | self.gamma = gamma 16 | self.lambda_ = lambda_ 17 | 18 | self._tmp_trans = list() 19 | self.trajectories = list() 20 | self.good_trajectories = list() 21 | self.aux_trajectories = list() 22 | 23 | def add_transition(self, 24 | state, 25 | action, 26 | reward, 27 | local_done, 28 | max_reached, 29 | state_): 30 | self._curr_cumulative_reward += reward 31 | self._tmp_trans.append({ 32 | 'state': state, 33 | 'action': action, 34 | 'reward': np.array([reward]), 35 | 'local_done': local_done, 36 | 'max_reached': max_reached, 37 | 'state_': state_, 38 | 'cumulative_reward': self._curr_cumulative_reward 39 | }) 40 | 41 | if not self.done: 42 | self.reward += reward 43 | 44 | self._extra_log(state, 45 | action, 46 | reward, 47 | local_done, 48 | max_reached, 49 | state_) 50 | 51 | if local_done: 52 | self.done = True 53 | self.fill_reset_tmp_trans() 54 | 55 | def _extra_log(self, 56 | state, 57 | action, 58 | reward, 59 | local_done, 60 | max_reached, 61 | state_): 62 | pass 63 | 64 | def fill_reset_tmp_trans(self): 65 | if len(self._tmp_trans) != 0: 66 | self.trajectories.append(self._tmp_trans) 67 | self._curr_cumulative_reward = 0 68 | self._tmp_trans = list() 69 | 70 | def get_cumulative_rewards(self): 71 | return [t[-1]['cumulative_reward'] for t in self.trajectories] 72 | 73 | def get_trans_combined(self): 74 | return [] if len(self.trajectories) == 0 else \ 75 | reduce(lambda x, y: x + y, self.trajectories) 76 | 77 | def get_good_trans_combined(self): 78 | return [] if len(self.good_trajectories) == 0 else \ 79 | reduce(lambda x, y: x + y, self.good_trajectories) 80 | 81 | def get_aux_trans_combined(self): 82 | return [] if len(self.aux_trajectories) == 0 else \ 83 | reduce(lambda x, y: x + y, self.aux_trajectories) 84 | 85 | def compute_discounted_return(self): 86 | for trans in self.trajectories: 87 | if (not trans[-1]['max_reached']) and trans[-1]['local_done']: 88 | v_tmp = 0 89 | else: 90 | v_tmp = self.ppo.get_v(trans[-1]['state_'][np.newaxis, :])[0] 91 | for tran in trans[::-1]: 92 | v_tmp = tran['reward'] + self.gamma * v_tmp 93 | tran['discounted_return'] = v_tmp 94 | 95 | def compute_advantage(self): 96 | for trans in self.trajectories: 97 | if self.lambda_ == 1: 98 | s = np.array([t['state'] for t in trans]) 99 | v_s = self.ppo.get_v(s) 100 | for i, tran in enumerate(trans): 101 | tran['advantage'] = tran['discounted_return'] - v_s[i] 102 | else: 103 | s, r, s_, done, max_reached = [np.array(e) for e in zip(*[(t['state'], 104 | t['reward'], 105 | t['state_'], 106 | [t['local_done']], 107 | [t['max_reached']]) for t in trans])] 108 | v_s = self.ppo.get_v(s) 109 | v_s_ = self.ppo.get_v(s_) 110 | td_errors = r + self.gamma * v_s_ * (~(done ^ max_reached)) - v_s 111 | for i, td_error in enumerate(td_errors): 112 | trans[i]['td_error'] = td_error 113 | 114 | td_error_tmp = 0 115 | for tran in trans[::-1]: 116 | td_error_tmp = tran['td_error'] + self.gamma * self.lambda_ * td_error_tmp 117 | tran['advantage'] = td_error_tmp 118 | 119 | def compute_good_trans(self, aux_cumulative_reward): 120 | for trans in self.trajectories: 121 | if 
trans[-1]['local_done']: 122 | if trans[-1]['reward'] >= 1: 123 | self.good_trajectories.append(trans) 124 | elif trans[-1]['cumulative_reward'] >= aux_cumulative_reward: 125 | self.aux_trajectories.append(trans) 126 | -------------------------------------------------------------------------------- /algorithm/ppo_base.py: -------------------------------------------------------------------------------- 1 | import time 2 | import sys 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from .saver import Saver 8 | 9 | 10 | class PPO_Base(object): 11 | def __init__(self, 12 | state_dim, 13 | action_dim, 14 | model_root_path, 15 | 16 | save_per_iter=1000, 17 | write_summary_graph=False, 18 | seed=None, 19 | std=True, 20 | addition_objective=False, 21 | 22 | batch_size=2048, 23 | epoch_size=10, # train K epochs 24 | 25 | init_td_threshold=0.0, 26 | td_threshold_decay_steps=100, 27 | td_threshold_rate=0.5, 28 | 29 | beta=0.001, # entropy coefficient 30 | epsilon=0.2, # clip bound 31 | 32 | init_lr=5e-5, 33 | min_lr=1e-5, 34 | decay_steps=50, 35 | decay_rate=0.9): 36 | 37 | self.graph = tf.Graph() 38 | gpu_options = tf.GPUOptions(allow_growth=True) 39 | self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options), 40 | graph=self.graph) 41 | 42 | self.s_dim = state_dim 43 | self.a_dim = action_dim 44 | self.save_per_iter = save_per_iter 45 | self.batch_size = batch_size 46 | self.epoch_size = epoch_size 47 | 48 | 49 | with self.graph.as_default(): 50 | if seed is not None: 51 | tf.random.set_random_seed(seed) 52 | 53 | self._build_model(std, addition_objective, 54 | init_td_threshold, td_threshold_decay_steps, td_threshold_rate, 55 | beta, epsilon, 56 | init_lr, min_lr, decay_steps, decay_rate) 57 | 58 | self.saver = Saver(f'{model_root_path}/model', self.sess) 59 | self.init_iteration = self.saver.restore_or_init() 60 | 61 | summary_path = f'{model_root_path}/log' 62 | if write_summary_graph: 63 | writer = tf.summary.FileWriter(summary_path, self.graph) 64 | writer.close() 65 | self.summary_writer = tf.summary.FileWriter(summary_path) 66 | 67 | def _build_model(self, std, addition_objective, 68 | init_td_threshold, td_threshold_decay_steps, td_threshold_rate, 69 | beta, epsilon, 70 | init_lr, min_lr, decay_steps, decay_rate): 71 | self.pl_s = tf.placeholder(tf.float32, shape=(None, self.s_dim), name='state') 72 | self.policy, self.v, policy_v_variables = self._build_net(self.pl_s, 'actor_critic', True) 73 | old_policy, old_v, old_policy_v_variables = self._build_net(self.pl_s, 'old_actor_critic', False) 74 | 75 | with tf.name_scope('objective_and_value_function_loss'): 76 | self.pl_a = tf.placeholder(tf.float32, shape=(None, self.a_dim), name='action') 77 | self.pl_advantage = tf.placeholder(tf.float32, shape=(None, 1), name='advantage') 78 | self.pl_discounted_r = tf.placeholder(tf.float32, shape=(None, 1), name='discounted_reward') 79 | 80 | self.policy_prob = self.policy.prob(self.pl_a) 81 | if addition_objective: 82 | ratio = self.policy_prob - old_policy.prob(self.pl_a) 83 | L_clip = tf.math.reduce_mean(tf.math.minimum( 84 | ratio * self.pl_advantage, # surrogate objective 85 | tf.clip_by_value(ratio, -epsilon, epsilon) * self.pl_advantage 86 | ), name='clipped_objective') 87 | else: 88 | ratio = self.policy_prob / old_policy.prob(self.pl_a) 89 | L_clip = tf.math.reduce_mean(tf.math.minimum( 90 | ratio * self.pl_advantage, # surrogate objective 91 | tf.clip_by_value(ratio, 1. - epsilon, 1. 
+ epsilon) * self.pl_advantage
92 |                 ), name='clipped_objective')
93 | 
94 |             L_vf = tf.reduce_mean(tf.square(self.pl_discounted_r - self.v), name='value_function_loss')
95 |             S = tf.reduce_mean(self.policy.entropy(), name='entropy')
96 | 
97 |         self.choose_action_op = tf.squeeze(self.policy.sample(1), axis=0)
98 | 
99 |         with tf.name_scope('optimizer'):
100 |             self.global_iter = tf.get_variable('global_iter', shape=(), initializer=tf.constant_initializer(0), trainable=False)
101 |             self.lr = tf.math.maximum(tf.train.exponential_decay(learning_rate=init_lr,
102 |                                                                  global_step=self.global_iter,
103 |                                                                  decay_steps=decay_steps,
104 |                                                                  decay_rate=decay_rate,
105 |                                                                  staircase=True), min_lr)
106 | 
107 |             self.td_threshold = tf.train.exponential_decay(init_td_threshold,
108 |                                                            global_step=self.global_iter,
109 |                                                            decay_steps=td_threshold_decay_steps,
110 |                                                            decay_rate=td_threshold_rate,
111 |                                                            staircase=True)
112 |             if std:
113 |                 L = L_clip - L_vf + beta * S
114 |                 self.train_op = tf.train.AdamOptimizer(self.lr).minimize(-L)
115 |             else:
116 |                 L = L_clip + beta * S
117 |                 self.train_op = [tf.train.AdamOptimizer(self.lr).minimize(-L),
118 |                                  tf.train.AdamOptimizer(self.lr).minimize(L_vf)]
119 | 
120 |         self.update_variables_op = [tf.assign(r, v) for r, v in
121 |                                     zip(old_policy_v_variables, policy_v_variables)]
122 | 
123 |         tf.summary.scalar('loss/value_function', L_vf)
124 |         tf.summary.scalar('loss/-entropy', S)
125 |         tf.summary.scalar('loss/lr', self.lr)
126 |         self.summaries = tf.summary.merge_all()
127 | 
128 |     def _build_net(self, s_inputs, scope, trainable, reuse=False):
129 |         # return policy, v, variables
130 |         raise NotImplementedError('PPO_Base._build_net not implemented')
131 | 
132 |     def get_td_error(self, s, r, s_, done):
133 |         assert len(s.shape) == 2
134 |         assert len(r.shape) == 2
135 |         assert len(s_.shape) == 2
136 |         assert len(done.shape) == 2
137 | 
138 |         return self.sess.run(self.td_error, {
139 |             self.pl_s: s,
140 |             self.pl_r: r,
141 |             self.pl_s_: s_,
142 |             self.pl_done: done
143 |         })
144 | 
145 |     def get_v(self, s):
146 |         assert len(s.shape) == 2
147 | 
148 |         return self.sess.run(self.v, {
149 |             self.pl_s: s
150 |         })
151 | 
152 |     def choose_action(self, s):
153 |         assert len(s.shape) == 2
154 | 
155 |         a = self.sess.run(self.choose_action_op, {
156 |             self.pl_s: s
157 |         })
158 | 
159 |         return np.clip(a, -1, 1)
160 | 
161 |     def get_policy(self, s):
162 |         assert len(s.shape) == 2
163 | 
164 |         return self.sess.run([self.policy.loc, self.policy.scale], {
165 |             self.pl_s: s
166 |         })
167 | 
168 |     def write_constant_summaries(self, constant_summaries, iteration):
169 |         if self.summary_writer is not None:
170 |             summaries = tf.Summary(value=[tf.Summary.Value(tag=i['tag'],
171 |                                                            simple_value=i['simple_value'])
172 |                                           for i in constant_summaries])
173 |             self.summary_writer.add_summary(summaries, iteration + self.init_iteration)
174 | 
175 |     def get_not_zero_prob_bool_mask(self, s, a):
176 |         policy_prob = self.sess.run(self.policy_prob, {
177 |             self.pl_s: s,
178 |             self.pl_a: a
179 |         })
180 |         bool_mask = ~np.any(policy_prob <= 1.e-7, axis=1)
181 |         return bool_mask
182 | 
183 |     def train(self, s, a, adv, discounted_r, iteration):
184 |         assert len(s.shape) == 2
185 |         assert len(a.shape) == 2
186 |         assert len(adv.shape) == 2
187 |         assert len(discounted_r.shape) == 2
188 |         assert s.shape[0] == a.shape[0] == adv.shape[0] == discounted_r.shape[0]
189 | 
190 |         global_iter = iteration + self.init_iteration
191 |         self.global_iter.load(global_iter, self.sess)
192 | 
193 |         td_error = np.square(self.get_v(s) - discounted_r)
194 |         bool_mask = np.all(td_error > self.sess.run(self.td_threshold), axis=1)
195 |         if np.any(bool_mask):
196 |             s, a, adv, discounted_r = s[bool_mask], a[bool_mask], adv[bool_mask], discounted_r[bool_mask]
197 | 
198 |         self.sess.run(self.update_variables_op)  # TODO
199 | 
200 |         if iteration % self.save_per_iter == 0:
201 |             self.saver.save(global_iter)
202 | 
203 |         if self.summary_writer is not None:
204 |             summaries = self.sess.run(self.summaries, {
205 |                 self.pl_s: s,
206 |                 self.pl_a: a,
207 |                 self.pl_advantage: adv,
208 |                 self.pl_discounted_r: discounted_r
209 |             })
210 |             self.summary_writer.add_summary(summaries, global_iter)
211 | 
212 |         for i in range(0, s.shape[0], self.batch_size):
213 |             _s, _a, _adv, _discounted_r = (s[i:i + self.batch_size],
214 |                                            a[i:i + self.batch_size],
215 |                                            adv[i:i + self.batch_size],
216 |                                            discounted_r[i:i + self.batch_size])
217 |             for _ in range(self.epoch_size):
218 |                 self.sess.run(self.train_op, {
219 |                     self.pl_s: _s,
220 |                     self.pl_a: _a,
221 |                     self.pl_advantage: _adv,
222 |                     self.pl_discounted_r: _discounted_r
223 |                 })
224 | 
225 |     def dispose(self):
226 |         self.summary_writer.close()
227 |         self.sess.close()
228 | 
--------------------------------------------------------------------------------
/algorithm/saver.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import tensorflow as tf
4 | 
5 | 
6 | class Saver(object):
7 |     def __init__(self, model_path, sess, var_list=None):
8 |         self.model_path = model_path
9 |         self.sess = sess
10 | 
11 |         # create model path if not exists
12 |         is_exists = os.path.exists(model_path)
13 |         if not is_exists:
14 |             os.makedirs(model_path)
15 | 
16 |         if var_list is None:
17 |             self.saver = tf.train.Saver(max_to_keep=10)
18 |         else:
19 |             self.saver = tf.train.Saver(var_list, max_to_keep=10)
20 | 
21 |     def restore_or_init(self, step=None):
22 |         last_step = 0
23 |         ckpt = tf.train.get_checkpoint_state(self.model_path)
24 |         if ckpt is None:
25 |             self.sess.run(tf.global_variables_initializer())
26 |         else:
27 |             if step is None:
28 |                 self.saver.restore(self.sess, ckpt.model_checkpoint_path)
29 |                 last_step = int(ckpt.model_checkpoint_path.split('-')[1].split('.')[0])
30 |             else:
31 |                 for c in ckpt.all_model_checkpoint_paths:
32 |                     if f'model-{step}' in c:
33 |                         self.saver.restore(self.sess, c)
34 |                         last_step = step
35 |                         break
36 |                 else:
37 |                     paths = ', '.join(ckpt.all_model_checkpoint_paths)
38 |                     raise Exception(f'No checkpoint step [{step}], available paths are [{paths}]')
39 |         return last_step
40 | 
41 |     def save_graph(self, model_name=None):
42 |         if model_name is None:
43 |             model_name = 'raw_graph_def.pb'
44 |         tf.train.write_graph(self.sess.graph_def, self.model_path, model_name, as_text=False)
45 | 
46 |     def save(self, step=None):
47 |         if step is None:
48 |             self.saver.save(self.sess, f'{self.model_path}/model.ckpt')
49 |         else:
50 |             self.saver.save(self.sess, f'{self.model_path}/model-{int(step)}.ckpt')
51 | 
--------------------------------------------------------------------------------
/mlagents/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .brain import *
2 | from .environment import *
3 | from .exception import *
4 | 
--------------------------------------------------------------------------------
/mlagents/envs/base_unity_environment.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict
3 | 
4 | from mlagents.envs import AllBrainInfo, BrainParameters
5 | 
6 | 
7 | 
class BaseUnityEnvironment(ABC): 8 | @abstractmethod 9 | def step( 10 | self, vector_action=None, memory=None, text_action=None, value=None 11 | ) -> AllBrainInfo: 12 | pass 13 | 14 | @abstractmethod 15 | def reset(self, config=None, train_mode=True) -> AllBrainInfo: 16 | pass 17 | 18 | @property 19 | @abstractmethod 20 | def global_done(self): 21 | pass 22 | 23 | @property 24 | @abstractmethod 25 | def external_brains(self) -> Dict[str, BrainParameters]: 26 | pass 27 | 28 | @property 29 | @abstractmethod 30 | def reset_parameters(self) -> Dict[str, str]: 31 | pass 32 | 33 | @abstractmethod 34 | def close(self): 35 | pass 36 | -------------------------------------------------------------------------------- /mlagents/envs/brain.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import io 4 | 5 | from typing import Dict, List, Optional 6 | from PIL import Image 7 | 8 | logger = logging.getLogger("mlagents.envs") 9 | 10 | 11 | class BrainInfo: 12 | def __init__( 13 | self, 14 | visual_observation, 15 | vector_observation, 16 | text_observations, 17 | memory=None, 18 | reward=None, 19 | agents=None, 20 | local_done=None, 21 | vector_action=None, 22 | text_action=None, 23 | max_reached=None, 24 | action_mask=None, 25 | custom_observations=None, 26 | ): 27 | """ 28 | Describes experience at current step of all agents linked to a brain. 29 | """ 30 | self.visual_observations = visual_observation 31 | self.vector_observations = vector_observation 32 | self.text_observations = text_observations 33 | self.memories = memory 34 | self.rewards = reward 35 | self.local_done = local_done 36 | self.max_reached = max_reached 37 | self.agents = agents 38 | self.previous_vector_actions = vector_action 39 | self.previous_text_actions = text_action 40 | self.action_masks = action_mask 41 | self.custom_observations = custom_observations 42 | 43 | def merge(self, other): 44 | for i in range(len(self.visual_observations)): 45 | self.visual_observations[i].extend(other.visual_observations[i]) 46 | self.vector_observations = np.append( 47 | self.vector_observations, other.vector_observations, axis=0 48 | ) 49 | self.text_observations.extend(other.text_observations) 50 | self.memories = self.merge_memories( 51 | self.memories, other.memories, self.agents, other.agents 52 | ) 53 | self.rewards = safe_concat_lists(self.rewards, other.rewards) 54 | self.local_done = safe_concat_lists(self.local_done, other.local_done) 55 | self.max_reached = safe_concat_lists(self.max_reached, other.max_reached) 56 | self.agents = safe_concat_lists(self.agents, other.agents) 57 | self.previous_vector_actions = safe_concat_np_ndarray( 58 | self.previous_vector_actions, other.previous_vector_actions 59 | ) 60 | self.previous_text_actions = safe_concat_lists( 61 | self.previous_text_actions, other.previous_text_actions 62 | ) 63 | self.action_masks = safe_concat_np_ndarray( 64 | self.action_masks, other.action_masks 65 | ) 66 | self.custom_observations = safe_concat_lists( 67 | self.custom_observations, other.custom_observations 68 | ) 69 | 70 | @staticmethod 71 | def merge_memories(m1, m2, agents1, agents2): 72 | if len(m1) == 0 and len(m2) != 0: 73 | m1 = np.zeros((len(agents1), m2.shape[1])) 74 | elif len(m2) == 0 and len(m1) != 0: 75 | m2 = np.zeros((len(agents2), m1.shape[1])) 76 | elif m2.shape[1] > m1.shape[1]: 77 | new_m1 = np.zeros((m1.shape[0], m2.shape[1])) 78 | new_m1[0 : m1.shape[0], 0 : m1.shape[1]] = m1 79 | return np.append(new_m1, m2, axis=0) 
80 | elif m1.shape[1] > m2.shape[1]: 81 | new_m2 = np.zeros((m2.shape[0], m1.shape[1])) 82 | new_m2[0 : m2.shape[0], 0 : m2.shape[1]] = m2 83 | return np.append(m1, new_m2, axis=0) 84 | return np.append(m1, m2, axis=0) 85 | 86 | @staticmethod 87 | def process_pixels(image_bytes, gray_scale): 88 | """ 89 | Converts byte array observation image into numpy array, re-sizes it, 90 | and optionally converts it to grey scale 91 | :param gray_scale: Whether to convert the image to grayscale. 92 | :param image_bytes: input byte array corresponding to image 93 | :return: processed numpy array of observation from environment 94 | """ 95 | s = bytearray(image_bytes) 96 | image = Image.open(io.BytesIO(s)) 97 | s = np.array(image) / 255.0 98 | if gray_scale: 99 | s = np.mean(s, axis=2) 100 | s = np.reshape(s, [s.shape[0], s.shape[1], 1]) 101 | return s 102 | 103 | @staticmethod 104 | def from_agent_proto(agent_info_list, brain_params): 105 | """ 106 | Converts list of agent infos to BrainInfo. 107 | """ 108 | vis_obs = [] 109 | for i in range(brain_params.number_visual_observations): 110 | obs = [ 111 | BrainInfo.process_pixels( 112 | x.visual_observations[i], 113 | brain_params.camera_resolutions[i]["blackAndWhite"], 114 | ) 115 | for x in agent_info_list 116 | ] 117 | vis_obs += [obs] 118 | if len(agent_info_list) == 0: 119 | memory_size = 0 120 | else: 121 | memory_size = max([len(x.memories) for x in agent_info_list]) 122 | if memory_size == 0: 123 | memory = np.zeros((0, 0)) 124 | else: 125 | [ 126 | x.memories.extend([0] * (memory_size - len(x.memories))) 127 | for x in agent_info_list 128 | ] 129 | memory = np.array([list(x.memories) for x in agent_info_list]) 130 | total_num_actions = sum(brain_params.vector_action_space_size) 131 | mask_actions = np.ones((len(agent_info_list), total_num_actions)) 132 | for agent_index, agent_info in enumerate(agent_info_list): 133 | if agent_info.action_mask is not None: 134 | if len(agent_info.action_mask) == total_num_actions: 135 | mask_actions[agent_index, :] = [ 136 | 0 if agent_info.action_mask[k] else 1 137 | for k in range(total_num_actions) 138 | ] 139 | if any([np.isnan(x.reward) for x in agent_info_list]): 140 | logger.warning( 141 | "An agent had a NaN reward for brain " + brain_params.brain_name 142 | ) 143 | if any([np.isnan(x.stacked_vector_observation).any() for x in agent_info_list]): 144 | logger.warning( 145 | "An agent had a NaN observation for brain " + brain_params.brain_name 146 | ) 147 | 148 | if len(agent_info_list) == 0: 149 | vector_obs = np.zeros( 150 | ( 151 | 0, 152 | brain_params.vector_observation_space_size 153 | * brain_params.num_stacked_vector_observations, 154 | ) 155 | ) 156 | else: 157 | vector_obs = np.nan_to_num( 158 | np.array([x.stacked_vector_observation for x in agent_info_list]) 159 | ) 160 | brain_info = BrainInfo( 161 | visual_observation=vis_obs, 162 | vector_observation=vector_obs, 163 | text_observations=[x.text_observation for x in agent_info_list], 164 | memory=memory, 165 | reward=[x.reward if not np.isnan(x.reward) else 0 for x in agent_info_list], 166 | agents=[x.id for x in agent_info_list], 167 | local_done=[x.done for x in agent_info_list], 168 | vector_action=np.array([x.stored_vector_actions for x in agent_info_list]), 169 | text_action=[list(x.stored_text_actions) for x in agent_info_list], 170 | max_reached=[x.max_step_reached for x in agent_info_list], 171 | custom_observations=[x.custom_observation for x in agent_info_list], 172 | action_mask=mask_actions, 173 | ) 174 | return brain_info 175 | 176 
| 177 | def safe_concat_lists(l1: Optional[List], l2: Optional[List]): 178 | if l1 is None and l2 is None: 179 | return None 180 | if l1 is None and l2 is not None: 181 | return l2.copy() 182 | if l1 is not None and l2 is None: 183 | return l1.copy() 184 | else: 185 | copy = l1.copy() 186 | copy.extend(l2) 187 | return copy 188 | 189 | 190 | def safe_concat_np_ndarray(a1: Optional[np.ndarray], a2: Optional[np.ndarray]): 191 | if a1 is not None and a1.size != 0: 192 | if a2 is not None and a2.size != 0: 193 | return np.append(a1, a2, axis=0) 194 | else: 195 | return a1.copy() 196 | elif a2 is not None and a2.size != 0: 197 | return a2.copy() 198 | return None 199 | 200 | 201 | # Renaming of dictionary of brain name to BrainInfo for clarity 202 | AllBrainInfo = Dict[str, BrainInfo] 203 | 204 | 205 | class BrainParameters: 206 | def __init__( 207 | self, 208 | brain_name: str, 209 | vector_observation_space_size: int, 210 | num_stacked_vector_observations: int, 211 | camera_resolutions: List[Dict], 212 | vector_action_space_size: List[int], 213 | vector_action_descriptions: List[str], 214 | vector_action_space_type: int, 215 | ): 216 | """ 217 | Contains all brain-specific parameters. 218 | """ 219 | self.brain_name = brain_name 220 | self.vector_observation_space_size = vector_observation_space_size 221 | self.num_stacked_vector_observations = num_stacked_vector_observations 222 | self.number_visual_observations = len(camera_resolutions) 223 | self.camera_resolutions = camera_resolutions 224 | self.vector_action_space_size = vector_action_space_size 225 | self.vector_action_descriptions = vector_action_descriptions 226 | self.vector_action_space_type = ["discrete", "continuous"][ 227 | vector_action_space_type 228 | ] 229 | 230 | def __str__(self): 231 | return """Unity brain name: {} 232 | Number of Visual Observations (per agent): {} 233 | Vector Observation space size (per agent): {} 234 | Number of stacked Vector Observation: {} 235 | Vector Action space type: {} 236 | Vector Action space size (per agent): {} 237 | Vector Action descriptions: {}""".format( 238 | self.brain_name, 239 | str(self.number_visual_observations), 240 | str(self.vector_observation_space_size), 241 | str(self.num_stacked_vector_observations), 242 | self.vector_action_space_type, 243 | str(self.vector_action_space_size), 244 | ", ".join(self.vector_action_descriptions), 245 | ) 246 | 247 | @staticmethod 248 | def from_proto(brain_param_proto): 249 | """ 250 | Converts brain parameter proto to BrainParameter object. 251 | :param brain_param_proto: protobuf object. 252 | :return: BrainParameter object. 
253 | """ 254 | resolution = [ 255 | {"height": x.height, "width": x.width, "blackAndWhite": x.gray_scale} 256 | for x in brain_param_proto.camera_resolutions 257 | ] 258 | brain_params = BrainParameters( 259 | brain_param_proto.brain_name, 260 | brain_param_proto.vector_observation_size, 261 | brain_param_proto.num_stacked_vector_observations, 262 | resolution, 263 | list(brain_param_proto.vector_action_size), 264 | list(brain_param_proto.vector_action_descriptions), 265 | brain_param_proto.vector_action_space_type, 266 | ) 267 | return brain_params 268 | -------------------------------------------------------------------------------- /mlagents/envs/communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .communicator_objects import UnityOutput, UnityInput 4 | 5 | logger = logging.getLogger("mlagents.envs") 6 | 7 | 8 | class Communicator(object): 9 | def __init__(self, worker_id=0, base_port=5005): 10 | """ 11 | Python side of the communication. Must be used in pair with the right Unity Communicator equivalent. 12 | 13 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 14 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 15 | """ 16 | 17 | def initialize(self, inputs: UnityInput) -> UnityOutput: 18 | """ 19 | Used to exchange initialization parameters between Python and the Environment 20 | :param inputs: The initialization input that will be sent to the environment. 21 | :return: UnityOutput: The initialization output sent by Unity 22 | """ 23 | 24 | def exchange(self, inputs: UnityInput) -> UnityOutput: 25 | """ 26 | Used to send an input and receive an output from the Environment 27 | :param inputs: The UnityInput that needs to be sent the Environment 28 | :return: The UnityOutputs generated by the Environment 29 | """ 30 | 31 | def close(self): 32 | """ 33 | Sends a shutdown signal to the unity environment, and closes the connection. 34 | """ 35 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent_action_proto_pb2 import * 2 | from .agent_info_proto_pb2 import * 3 | from .brain_parameters_proto_pb2 import * 4 | from .command_proto_pb2 import * 5 | from .custom_action_pb2 import * 6 | from .custom_observation_pb2 import * 7 | from .custom_reset_parameters_pb2 import * 8 | from .demonstration_meta_proto_pb2 import * 9 | from .engine_configuration_proto_pb2 import * 10 | from .environment_parameters_proto_pb2 import * 11 | from .header_pb2 import * 12 | from .resolution_proto_pb2 import * 13 | from .space_type_proto_pb2 import * 14 | from .unity_input_pb2 import * 15 | from .unity_message_pb2 import * 16 | from .unity_output_pb2 import * 17 | from .unity_rl_initialization_input_pb2 import * 18 | from .unity_rl_initialization_output_pb2 import * 19 | from .unity_rl_input_pb2 import * 20 | from .unity_rl_output_pb2 import * 21 | from .unity_to_external_pb2 import * 22 | from .unity_to_external_pb2_grpc import * 23 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/agent_action_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/agent_action_proto.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | from mlagents.envs.communicator_objects import ( 19 | custom_action_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_custom__action__pb2, 20 | ) 21 | 22 | 23 | DESCRIPTOR = _descriptor.FileDescriptor( 24 | name="mlagents/envs/communicator_objects/agent_action_proto.proto", 25 | package="communicator_objects", 26 | syntax="proto3", 27 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 28 | serialized_pb=_b( 29 | '\n;mlagents/envs/communicator_objects/agent_action_proto.proto\x12\x14\x63ommunicator_objects\x1a\x36mlagents/envs/communicator_objects/custom_action.proto"\x9c\x01\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\x14\n\x0ctext_actions\x18\x02 \x01(\t\x12\x10\n\x08memories\x18\x03 \x03(\x02\x12\r\n\x05value\x18\x04 \x01(\x02\x12\x39\n\rcustom_action\x18\x05 \x01(\x0b\x32".communicator_objects.CustomActionB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 30 | ), 31 | dependencies=[ 32 | mlagents_dot_envs_dot_communicator__objects_dot_custom__action__pb2.DESCRIPTOR 33 | ], 34 | ) 35 | 36 | 37 | _AGENTACTIONPROTO = _descriptor.Descriptor( 38 | name="AgentActionProto", 39 | full_name="communicator_objects.AgentActionProto", 40 | filename=None, 41 | file=DESCRIPTOR, 42 | containing_type=None, 43 | fields=[ 44 | _descriptor.FieldDescriptor( 45 | name="vector_actions", 46 | full_name="communicator_objects.AgentActionProto.vector_actions", 47 | index=0, 48 | number=1, 49 | type=2, 50 | cpp_type=6, 51 | label=3, 52 | has_default_value=False, 53 | default_value=[], 54 | message_type=None, 55 | enum_type=None, 56 | containing_type=None, 57 | is_extension=False, 58 | extension_scope=None, 59 | serialized_options=None, 60 | file=DESCRIPTOR, 61 | ), 62 | _descriptor.FieldDescriptor( 63 | name="text_actions", 64 | full_name="communicator_objects.AgentActionProto.text_actions", 65 | index=1, 66 | number=2, 67 | type=9, 68 | cpp_type=9, 69 | label=1, 70 | has_default_value=False, 71 | default_value=_b("").decode("utf-8"), 72 | message_type=None, 73 | enum_type=None, 74 | containing_type=None, 75 | is_extension=False, 76 | extension_scope=None, 77 | serialized_options=None, 78 | file=DESCRIPTOR, 79 | ), 80 | _descriptor.FieldDescriptor( 81 | name="memories", 82 | full_name="communicator_objects.AgentActionProto.memories", 83 | index=2, 84 | number=3, 85 | type=2, 86 | cpp_type=6, 87 | label=3, 88 | has_default_value=False, 89 | default_value=[], 90 | message_type=None, 91 | enum_type=None, 92 | containing_type=None, 93 | is_extension=False, 94 | extension_scope=None, 95 | serialized_options=None, 96 | file=DESCRIPTOR, 97 | ), 98 | _descriptor.FieldDescriptor( 99 | name="value", 100 | full_name="communicator_objects.AgentActionProto.value", 101 | index=3, 102 | number=4, 103 | type=2, 104 | cpp_type=6, 105 | label=1, 106 | has_default_value=False, 107 | default_value=float(0), 108 | message_type=None, 109 | enum_type=None, 110 | containing_type=None, 111 | is_extension=False, 112 | extension_scope=None, 113 | 
serialized_options=None, 114 | file=DESCRIPTOR, 115 | ), 116 | _descriptor.FieldDescriptor( 117 | name="custom_action", 118 | full_name="communicator_objects.AgentActionProto.custom_action", 119 | index=4, 120 | number=5, 121 | type=11, 122 | cpp_type=10, 123 | label=1, 124 | has_default_value=False, 125 | default_value=None, 126 | message_type=None, 127 | enum_type=None, 128 | containing_type=None, 129 | is_extension=False, 130 | extension_scope=None, 131 | serialized_options=None, 132 | file=DESCRIPTOR, 133 | ), 134 | ], 135 | extensions=[], 136 | nested_types=[], 137 | enum_types=[], 138 | serialized_options=None, 139 | is_extendable=False, 140 | syntax="proto3", 141 | extension_ranges=[], 142 | oneofs=[], 143 | serialized_start=142, 144 | serialized_end=298, 145 | ) 146 | 147 | _AGENTACTIONPROTO.fields_by_name[ 148 | "custom_action" 149 | ].message_type = ( 150 | mlagents_dot_envs_dot_communicator__objects_dot_custom__action__pb2._CUSTOMACTION 151 | ) 152 | DESCRIPTOR.message_types_by_name["AgentActionProto"] = _AGENTACTIONPROTO 153 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 154 | 155 | AgentActionProto = _reflection.GeneratedProtocolMessageType( 156 | "AgentActionProto", 157 | (_message.Message,), 158 | dict( 159 | DESCRIPTOR=_AGENTACTIONPROTO, 160 | __module__="mlagents.envs.communicator_objects.agent_action_proto_pb2" 161 | # @@protoc_insertion_point(class_scope:communicator_objects.AgentActionProto) 162 | ), 163 | ) 164 | _sym_db.RegisterMessage(AgentActionProto) 165 | 166 | 167 | DESCRIPTOR._options = None 168 | # @@protoc_insertion_point(module_scope) 169 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/agent_info_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/agent_info_proto.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | from mlagents.envs.communicator_objects import ( 19 | custom_observation_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_custom__observation__pb2, 20 | ) 21 | 22 | 23 | DESCRIPTOR = _descriptor.FileDescriptor( 24 | name="mlagents/envs/communicator_objects/agent_info_proto.proto", 25 | package="communicator_objects", 26 | syntax="proto3", 27 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 28 | serialized_pb=_b( 29 | '\n9mlagents/envs/communicator_objects/agent_info_proto.proto\x12\x14\x63ommunicator_objects\x1a;mlagents/envs/communicator_objects/custom_observation.proto"\xd7\x02\n\x0e\x41gentInfoProto\x12"\n\x1astacked_vector_observation\x18\x01 \x03(\x02\x12\x1b\n\x13visual_observations\x18\x02 \x03(\x0c\x12\x18\n\x10text_observation\x18\x03 \x01(\t\x12\x1d\n\x15stored_vector_actions\x18\x04 \x03(\x02\x12\x1b\n\x13stored_text_actions\x18\x05 \x01(\t\x12\x10\n\x08memories\x18\x06 \x03(\x02\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12\x43\n\x12\x63ustom_observation\x18\x0c \x01(\x0b\x32\'.communicator_objects.CustomObservationB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 30 | ), 31 | dependencies=[ 32 | mlagents_dot_envs_dot_communicator__objects_dot_custom__observation__pb2.DESCRIPTOR 33 | ], 34 | ) 35 | 36 | 37 | _AGENTINFOPROTO = _descriptor.Descriptor( 38 | name="AgentInfoProto", 39 | full_name="communicator_objects.AgentInfoProto", 40 | filename=None, 41 | file=DESCRIPTOR, 42 | containing_type=None, 43 | fields=[ 44 | _descriptor.FieldDescriptor( 45 | name="stacked_vector_observation", 46 | full_name="communicator_objects.AgentInfoProto.stacked_vector_observation", 47 | index=0, 48 | number=1, 49 | type=2, 50 | cpp_type=6, 51 | label=3, 52 | has_default_value=False, 53 | default_value=[], 54 | message_type=None, 55 | enum_type=None, 56 | containing_type=None, 57 | is_extension=False, 58 | extension_scope=None, 59 | serialized_options=None, 60 | file=DESCRIPTOR, 61 | ), 62 | _descriptor.FieldDescriptor( 63 | name="visual_observations", 64 | full_name="communicator_objects.AgentInfoProto.visual_observations", 65 | index=1, 66 | number=2, 67 | type=12, 68 | cpp_type=9, 69 | label=3, 70 | has_default_value=False, 71 | default_value=[], 72 | message_type=None, 73 | enum_type=None, 74 | containing_type=None, 75 | is_extension=False, 76 | extension_scope=None, 77 | serialized_options=None, 78 | file=DESCRIPTOR, 79 | ), 80 | _descriptor.FieldDescriptor( 81 | name="text_observation", 82 | full_name="communicator_objects.AgentInfoProto.text_observation", 83 | index=2, 84 | number=3, 85 | type=9, 86 | cpp_type=9, 87 | label=1, 88 | has_default_value=False, 89 | default_value=_b("").decode("utf-8"), 90 | message_type=None, 91 | enum_type=None, 92 | containing_type=None, 93 | is_extension=False, 94 | extension_scope=None, 95 | serialized_options=None, 96 | file=DESCRIPTOR, 97 | ), 98 | 
_descriptor.FieldDescriptor( 99 | name="stored_vector_actions", 100 | full_name="communicator_objects.AgentInfoProto.stored_vector_actions", 101 | index=3, 102 | number=4, 103 | type=2, 104 | cpp_type=6, 105 | label=3, 106 | has_default_value=False, 107 | default_value=[], 108 | message_type=None, 109 | enum_type=None, 110 | containing_type=None, 111 | is_extension=False, 112 | extension_scope=None, 113 | serialized_options=None, 114 | file=DESCRIPTOR, 115 | ), 116 | _descriptor.FieldDescriptor( 117 | name="stored_text_actions", 118 | full_name="communicator_objects.AgentInfoProto.stored_text_actions", 119 | index=4, 120 | number=5, 121 | type=9, 122 | cpp_type=9, 123 | label=1, 124 | has_default_value=False, 125 | default_value=_b("").decode("utf-8"), 126 | message_type=None, 127 | enum_type=None, 128 | containing_type=None, 129 | is_extension=False, 130 | extension_scope=None, 131 | serialized_options=None, 132 | file=DESCRIPTOR, 133 | ), 134 | _descriptor.FieldDescriptor( 135 | name="memories", 136 | full_name="communicator_objects.AgentInfoProto.memories", 137 | index=5, 138 | number=6, 139 | type=2, 140 | cpp_type=6, 141 | label=3, 142 | has_default_value=False, 143 | default_value=[], 144 | message_type=None, 145 | enum_type=None, 146 | containing_type=None, 147 | is_extension=False, 148 | extension_scope=None, 149 | serialized_options=None, 150 | file=DESCRIPTOR, 151 | ), 152 | _descriptor.FieldDescriptor( 153 | name="reward", 154 | full_name="communicator_objects.AgentInfoProto.reward", 155 | index=6, 156 | number=7, 157 | type=2, 158 | cpp_type=6, 159 | label=1, 160 | has_default_value=False, 161 | default_value=float(0), 162 | message_type=None, 163 | enum_type=None, 164 | containing_type=None, 165 | is_extension=False, 166 | extension_scope=None, 167 | serialized_options=None, 168 | file=DESCRIPTOR, 169 | ), 170 | _descriptor.FieldDescriptor( 171 | name="done", 172 | full_name="communicator_objects.AgentInfoProto.done", 173 | index=7, 174 | number=8, 175 | type=8, 176 | cpp_type=7, 177 | label=1, 178 | has_default_value=False, 179 | default_value=False, 180 | message_type=None, 181 | enum_type=None, 182 | containing_type=None, 183 | is_extension=False, 184 | extension_scope=None, 185 | serialized_options=None, 186 | file=DESCRIPTOR, 187 | ), 188 | _descriptor.FieldDescriptor( 189 | name="max_step_reached", 190 | full_name="communicator_objects.AgentInfoProto.max_step_reached", 191 | index=8, 192 | number=9, 193 | type=8, 194 | cpp_type=7, 195 | label=1, 196 | has_default_value=False, 197 | default_value=False, 198 | message_type=None, 199 | enum_type=None, 200 | containing_type=None, 201 | is_extension=False, 202 | extension_scope=None, 203 | serialized_options=None, 204 | file=DESCRIPTOR, 205 | ), 206 | _descriptor.FieldDescriptor( 207 | name="id", 208 | full_name="communicator_objects.AgentInfoProto.id", 209 | index=9, 210 | number=10, 211 | type=5, 212 | cpp_type=1, 213 | label=1, 214 | has_default_value=False, 215 | default_value=0, 216 | message_type=None, 217 | enum_type=None, 218 | containing_type=None, 219 | is_extension=False, 220 | extension_scope=None, 221 | serialized_options=None, 222 | file=DESCRIPTOR, 223 | ), 224 | _descriptor.FieldDescriptor( 225 | name="action_mask", 226 | full_name="communicator_objects.AgentInfoProto.action_mask", 227 | index=10, 228 | number=11, 229 | type=8, 230 | cpp_type=7, 231 | label=3, 232 | has_default_value=False, 233 | default_value=[], 234 | message_type=None, 235 | enum_type=None, 236 | containing_type=None, 237 | 
is_extension=False, 238 | extension_scope=None, 239 | serialized_options=None, 240 | file=DESCRIPTOR, 241 | ), 242 | _descriptor.FieldDescriptor( 243 | name="custom_observation", 244 | full_name="communicator_objects.AgentInfoProto.custom_observation", 245 | index=11, 246 | number=12, 247 | type=11, 248 | cpp_type=10, 249 | label=1, 250 | has_default_value=False, 251 | default_value=None, 252 | message_type=None, 253 | enum_type=None, 254 | containing_type=None, 255 | is_extension=False, 256 | extension_scope=None, 257 | serialized_options=None, 258 | file=DESCRIPTOR, 259 | ), 260 | ], 261 | extensions=[], 262 | nested_types=[], 263 | enum_types=[], 264 | serialized_options=None, 265 | is_extendable=False, 266 | syntax="proto3", 267 | extension_ranges=[], 268 | oneofs=[], 269 | serialized_start=145, 270 | serialized_end=488, 271 | ) 272 | 273 | _AGENTINFOPROTO.fields_by_name[ 274 | "custom_observation" 275 | ].message_type = ( 276 | mlagents_dot_envs_dot_communicator__objects_dot_custom__observation__pb2._CUSTOMOBSERVATION 277 | ) 278 | DESCRIPTOR.message_types_by_name["AgentInfoProto"] = _AGENTINFOPROTO 279 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 280 | 281 | AgentInfoProto = _reflection.GeneratedProtocolMessageType( 282 | "AgentInfoProto", 283 | (_message.Message,), 284 | dict( 285 | DESCRIPTOR=_AGENTINFOPROTO, 286 | __module__="mlagents.envs.communicator_objects.agent_info_proto_pb2" 287 | # @@protoc_insertion_point(class_scope:communicator_objects.AgentInfoProto) 288 | ), 289 | ) 290 | _sym_db.RegisterMessage(AgentInfoProto) 291 | 292 | 293 | DESCRIPTOR._options = None 294 | # @@protoc_insertion_point(module_scope) 295 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/brain_parameters_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/brain_parameters_proto.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | from mlagents.envs.communicator_objects import ( 19 | resolution_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_resolution__proto__pb2, 20 | ) 21 | from mlagents.envs.communicator_objects import ( 22 | space_type_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_space__type__proto__pb2, 23 | ) 24 | 25 | 26 | DESCRIPTOR = _descriptor.FileDescriptor( 27 | name="mlagents/envs/communicator_objects/brain_parameters_proto.proto", 28 | package="communicator_objects", 29 | syntax="proto3", 30 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 31 | serialized_pb=_b( 32 | '\n?mlagents/envs/communicator_objects/brain_parameters_proto.proto\x12\x14\x63ommunicator_objects\x1a\x39mlagents/envs/communicator_objects/resolution_proto.proto\x1a\x39mlagents/envs/communicator_objects/space_type_proto.proto"\xd4\x02\n\x14\x42rainParametersProto\x12\x1f\n\x17vector_observation_size\x18\x01 \x01(\x05\x12\'\n\x1fnum_stacked_vector_observations\x18\x02 \x01(\x05\x12\x1a\n\x12vector_action_size\x18\x03 \x03(\x05\x12\x41\n\x12\x63\x61mera_resolutions\x18\x04 \x03(\x0b\x32%.communicator_objects.ResolutionProto\x12"\n\x1avector_action_descriptions\x18\x05 \x03(\t\x12\x46\n\x18vector_action_space_type\x18\x06 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12\x12\n\nbrain_name\x18\x07 \x01(\t\x12\x13\n\x0bis_training\x18\x08 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 33 | ), 34 | dependencies=[ 35 | mlagents_dot_envs_dot_communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR, 36 | mlagents_dot_envs_dot_communicator__objects_dot_space__type__proto__pb2.DESCRIPTOR, 37 | ], 38 | ) 39 | 40 | 41 | _BRAINPARAMETERSPROTO = _descriptor.Descriptor( 42 | name="BrainParametersProto", 43 | full_name="communicator_objects.BrainParametersProto", 44 | filename=None, 45 | file=DESCRIPTOR, 46 | containing_type=None, 47 | fields=[ 48 | _descriptor.FieldDescriptor( 49 | name="vector_observation_size", 50 | full_name="communicator_objects.BrainParametersProto.vector_observation_size", 51 | index=0, 52 | number=1, 53 | type=5, 54 | cpp_type=1, 55 | label=1, 56 | has_default_value=False, 57 | default_value=0, 58 | message_type=None, 59 | enum_type=None, 60 | containing_type=None, 61 | is_extension=False, 62 | extension_scope=None, 63 | serialized_options=None, 64 | file=DESCRIPTOR, 65 | ), 66 | _descriptor.FieldDescriptor( 67 | name="num_stacked_vector_observations", 68 | full_name="communicator_objects.BrainParametersProto.num_stacked_vector_observations", 69 | index=1, 70 | number=2, 71 | type=5, 72 | cpp_type=1, 73 | label=1, 74 | has_default_value=False, 75 | default_value=0, 76 | message_type=None, 77 | enum_type=None, 78 | containing_type=None, 79 | is_extension=False, 80 | extension_scope=None, 81 | serialized_options=None, 82 | file=DESCRIPTOR, 83 | ), 84 | _descriptor.FieldDescriptor( 85 | name="vector_action_size", 86 | full_name="communicator_objects.BrainParametersProto.vector_action_size", 87 | index=2, 88 | number=3, 89 | 
type=5, 90 | cpp_type=1, 91 | label=3, 92 | has_default_value=False, 93 | default_value=[], 94 | message_type=None, 95 | enum_type=None, 96 | containing_type=None, 97 | is_extension=False, 98 | extension_scope=None, 99 | serialized_options=None, 100 | file=DESCRIPTOR, 101 | ), 102 | _descriptor.FieldDescriptor( 103 | name="camera_resolutions", 104 | full_name="communicator_objects.BrainParametersProto.camera_resolutions", 105 | index=3, 106 | number=4, 107 | type=11, 108 | cpp_type=10, 109 | label=3, 110 | has_default_value=False, 111 | default_value=[], 112 | message_type=None, 113 | enum_type=None, 114 | containing_type=None, 115 | is_extension=False, 116 | extension_scope=None, 117 | serialized_options=None, 118 | file=DESCRIPTOR, 119 | ), 120 | _descriptor.FieldDescriptor( 121 | name="vector_action_descriptions", 122 | full_name="communicator_objects.BrainParametersProto.vector_action_descriptions", 123 | index=4, 124 | number=5, 125 | type=9, 126 | cpp_type=9, 127 | label=3, 128 | has_default_value=False, 129 | default_value=[], 130 | message_type=None, 131 | enum_type=None, 132 | containing_type=None, 133 | is_extension=False, 134 | extension_scope=None, 135 | serialized_options=None, 136 | file=DESCRIPTOR, 137 | ), 138 | _descriptor.FieldDescriptor( 139 | name="vector_action_space_type", 140 | full_name="communicator_objects.BrainParametersProto.vector_action_space_type", 141 | index=5, 142 | number=6, 143 | type=14, 144 | cpp_type=8, 145 | label=1, 146 | has_default_value=False, 147 | default_value=0, 148 | message_type=None, 149 | enum_type=None, 150 | containing_type=None, 151 | is_extension=False, 152 | extension_scope=None, 153 | serialized_options=None, 154 | file=DESCRIPTOR, 155 | ), 156 | _descriptor.FieldDescriptor( 157 | name="brain_name", 158 | full_name="communicator_objects.BrainParametersProto.brain_name", 159 | index=6, 160 | number=7, 161 | type=9, 162 | cpp_type=9, 163 | label=1, 164 | has_default_value=False, 165 | default_value=_b("").decode("utf-8"), 166 | message_type=None, 167 | enum_type=None, 168 | containing_type=None, 169 | is_extension=False, 170 | extension_scope=None, 171 | serialized_options=None, 172 | file=DESCRIPTOR, 173 | ), 174 | _descriptor.FieldDescriptor( 175 | name="is_training", 176 | full_name="communicator_objects.BrainParametersProto.is_training", 177 | index=7, 178 | number=8, 179 | type=8, 180 | cpp_type=7, 181 | label=1, 182 | has_default_value=False, 183 | default_value=False, 184 | message_type=None, 185 | enum_type=None, 186 | containing_type=None, 187 | is_extension=False, 188 | extension_scope=None, 189 | serialized_options=None, 190 | file=DESCRIPTOR, 191 | ), 192 | ], 193 | extensions=[], 194 | nested_types=[], 195 | enum_types=[], 196 | serialized_options=None, 197 | is_extendable=False, 198 | syntax="proto3", 199 | extension_ranges=[], 200 | oneofs=[], 201 | serialized_start=208, 202 | serialized_end=548, 203 | ) 204 | 205 | _BRAINPARAMETERSPROTO.fields_by_name[ 206 | "camera_resolutions" 207 | ].message_type = ( 208 | mlagents_dot_envs_dot_communicator__objects_dot_resolution__proto__pb2._RESOLUTIONPROTO 209 | ) 210 | _BRAINPARAMETERSPROTO.fields_by_name[ 211 | "vector_action_space_type" 212 | ].enum_type = ( 213 | mlagents_dot_envs_dot_communicator__objects_dot_space__type__proto__pb2._SPACETYPEPROTO 214 | ) 215 | DESCRIPTOR.message_types_by_name["BrainParametersProto"] = _BRAINPARAMETERSPROTO 216 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 217 | 218 | BrainParametersProto = _reflection.GeneratedProtocolMessageType( 219 | 
"BrainParametersProto", 220 | (_message.Message,), 221 | dict( 222 | DESCRIPTOR=_BRAINPARAMETERSPROTO, 223 | __module__="mlagents.envs.communicator_objects.brain_parameters_proto_pb2" 224 | # @@protoc_insertion_point(class_scope:communicator_objects.BrainParametersProto) 225 | ), 226 | ) 227 | _sym_db.RegisterMessage(BrainParametersProto) 228 | 229 | 230 | DESCRIPTOR._options = None 231 | # @@protoc_insertion_point(module_scope) 232 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/command_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: mlagents/envs/communicator_objects/command_proto.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf.internal import enum_type_wrapper 9 | from google.protobuf import descriptor as _descriptor 10 | from google.protobuf import message as _message 11 | from google.protobuf import reflection as _reflection 12 | from google.protobuf import symbol_database as _symbol_database 13 | 14 | # @@protoc_insertion_point(imports) 15 | 16 | _sym_db = _symbol_database.Default() 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name="mlagents/envs/communicator_objects/command_proto.proto", 21 | package="communicator_objects", 22 | syntax="proto3", 23 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 24 | serialized_pb=_b( 25 | "\n6mlagents/envs/communicator_objects/command_proto.proto\x12\x14\x63ommunicator_objects*-\n\x0c\x43ommandProto\x12\x08\n\x04STEP\x10\x00\x12\t\n\x05RESET\x10\x01\x12\x08\n\x04QUIT\x10\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3" 26 | ), 27 | ) 28 | 29 | _COMMANDPROTO = _descriptor.EnumDescriptor( 30 | name="CommandProto", 31 | full_name="communicator_objects.CommandProto", 32 | filename=None, 33 | file=DESCRIPTOR, 34 | values=[ 35 | _descriptor.EnumValueDescriptor( 36 | name="STEP", index=0, number=0, serialized_options=None, type=None 37 | ), 38 | _descriptor.EnumValueDescriptor( 39 | name="RESET", index=1, number=1, serialized_options=None, type=None 40 | ), 41 | _descriptor.EnumValueDescriptor( 42 | name="QUIT", index=2, number=2, serialized_options=None, type=None 43 | ), 44 | ], 45 | containing_type=None, 46 | serialized_options=None, 47 | serialized_start=80, 48 | serialized_end=125, 49 | ) 50 | _sym_db.RegisterEnumDescriptor(_COMMANDPROTO) 51 | 52 | CommandProto = enum_type_wrapper.EnumTypeWrapper(_COMMANDPROTO) 53 | STEP = 0 54 | RESET = 1 55 | QUIT = 2 56 | 57 | 58 | DESCRIPTOR.enum_types_by_name["CommandProto"] = _COMMANDPROTO 59 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 60 | 61 | 62 | DESCRIPTOR._options = None 63 | # @@protoc_insertion_point(module_scope) 64 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/custom_action_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/custom_action.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name="mlagents/envs/communicator_objects/custom_action.proto", 20 | package="communicator_objects", 21 | syntax="proto3", 22 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 23 | serialized_pb=_b( 24 | '\n6mlagents/envs/communicator_objects/custom_action.proto\x12\x14\x63ommunicator_objects"\x0e\n\x0c\x43ustomActionB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 25 | ), 26 | ) 27 | 28 | 29 | _CUSTOMACTION = _descriptor.Descriptor( 30 | name="CustomAction", 31 | full_name="communicator_objects.CustomAction", 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[], 36 | extensions=[], 37 | nested_types=[], 38 | enum_types=[], 39 | serialized_options=None, 40 | is_extendable=False, 41 | syntax="proto3", 42 | extension_ranges=[], 43 | oneofs=[], 44 | serialized_start=80, 45 | serialized_end=94, 46 | ) 47 | 48 | DESCRIPTOR.message_types_by_name["CustomAction"] = _CUSTOMACTION 49 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 50 | 51 | CustomAction = _reflection.GeneratedProtocolMessageType( 52 | "CustomAction", 53 | (_message.Message,), 54 | dict( 55 | DESCRIPTOR=_CUSTOMACTION, 56 | __module__="mlagents.envs.communicator_objects.custom_action_pb2" 57 | # @@protoc_insertion_point(class_scope:communicator_objects.CustomAction) 58 | ), 59 | ) 60 | _sym_db.RegisterMessage(CustomAction) 61 | 62 | 63 | DESCRIPTOR._options = None 64 | # @@protoc_insertion_point(module_scope) 65 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/custom_observation_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/custom_observation.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name="mlagents/envs/communicator_objects/custom_observation.proto", 20 | package="communicator_objects", 21 | syntax="proto3", 22 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 23 | serialized_pb=_b( 24 | '\n;mlagents/envs/communicator_objects/custom_observation.proto\x12\x14\x63ommunicator_objects"\x13\n\x11\x43ustomObservationB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 25 | ), 26 | ) 27 | 28 | 29 | _CUSTOMOBSERVATION = _descriptor.Descriptor( 30 | name="CustomObservation", 31 | full_name="communicator_objects.CustomObservation", 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[], 36 | extensions=[], 37 | nested_types=[], 38 | enum_types=[], 39 | serialized_options=None, 40 | is_extendable=False, 41 | syntax="proto3", 42 | extension_ranges=[], 43 | oneofs=[], 44 | serialized_start=85, 45 | serialized_end=104, 46 | ) 47 | 48 | DESCRIPTOR.message_types_by_name["CustomObservation"] = _CUSTOMOBSERVATION 49 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 50 | 51 | CustomObservation = _reflection.GeneratedProtocolMessageType( 52 | "CustomObservation", 53 | (_message.Message,), 54 | dict( 55 | DESCRIPTOR=_CUSTOMOBSERVATION, 56 | __module__="mlagents.envs.communicator_objects.custom_observation_pb2" 57 | # @@protoc_insertion_point(class_scope:communicator_objects.CustomObservation) 58 | ), 59 | ) 60 | _sym_db.RegisterMessage(CustomObservation) 61 | 62 | 63 | DESCRIPTOR._options = None 64 | # @@protoc_insertion_point(module_scope) 65 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/custom_reset_parameters_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/custom_reset_parameters.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name="mlagents/envs/communicator_objects/custom_reset_parameters.proto", 20 | package="communicator_objects", 21 | syntax="proto3", 22 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 23 | serialized_pb=_b( 24 | '\n@mlagents/envs/communicator_objects/custom_reset_parameters.proto\x12\x14\x63ommunicator_objects"\x17\n\x15\x43ustomResetParametersB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 25 | ), 26 | ) 27 | 28 | 29 | _CUSTOMRESETPARAMETERS = _descriptor.Descriptor( 30 | name="CustomResetParameters", 31 | full_name="communicator_objects.CustomResetParameters", 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[], 36 | extensions=[], 37 | nested_types=[], 38 | enum_types=[], 39 | serialized_options=None, 40 | is_extendable=False, 41 | syntax="proto3", 42 | extension_ranges=[], 43 | oneofs=[], 44 | serialized_start=90, 45 | serialized_end=113, 46 | ) 47 | 48 | DESCRIPTOR.message_types_by_name["CustomResetParameters"] = _CUSTOMRESETPARAMETERS 49 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 50 | 51 | CustomResetParameters = _reflection.GeneratedProtocolMessageType( 52 | "CustomResetParameters", 53 | (_message.Message,), 54 | dict( 55 | DESCRIPTOR=_CUSTOMRESETPARAMETERS, 56 | __module__="mlagents.envs.communicator_objects.custom_reset_parameters_pb2" 57 | # @@protoc_insertion_point(class_scope:communicator_objects.CustomResetParameters) 58 | ), 59 | ) 60 | _sym_db.RegisterMessage(CustomResetParameters) 61 | 62 | 63 | DESCRIPTOR._options = None 64 | # @@protoc_insertion_point(module_scope) 65 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/demonstration_meta_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/demonstration_meta_proto.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name="mlagents/envs/communicator_objects/demonstration_meta_proto.proto", 20 | package="communicator_objects", 21 | syntax="proto3", 22 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 23 | serialized_pb=_b( 24 | '\nAmlagents/envs/communicator_objects/demonstration_meta_proto.proto\x12\x14\x63ommunicator_objects"\x8d\x01\n\x16\x44\x65monstrationMetaProto\x12\x13\n\x0b\x61pi_version\x18\x01 \x01(\x05\x12\x1a\n\x12\x64\x65monstration_name\x18\x02 \x01(\t\x12\x14\n\x0cnumber_steps\x18\x03 \x01(\x05\x12\x17\n\x0fnumber_episodes\x18\x04 \x01(\x05\x12\x13\n\x0bmean_reward\x18\x05 \x01(\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 25 | ), 26 | ) 27 | 28 | 29 | _DEMONSTRATIONMETAPROTO = _descriptor.Descriptor( 30 | name="DemonstrationMetaProto", 31 | full_name="communicator_objects.DemonstrationMetaProto", 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | _descriptor.FieldDescriptor( 37 | name="api_version", 38 | full_name="communicator_objects.DemonstrationMetaProto.api_version", 39 | index=0, 40 | number=1, 41 | type=5, 42 | cpp_type=1, 43 | label=1, 44 | has_default_value=False, 45 | default_value=0, 46 | message_type=None, 47 | enum_type=None, 48 | containing_type=None, 49 | is_extension=False, 50 | extension_scope=None, 51 | serialized_options=None, 52 | file=DESCRIPTOR, 53 | ), 54 | _descriptor.FieldDescriptor( 55 | name="demonstration_name", 56 | full_name="communicator_objects.DemonstrationMetaProto.demonstration_name", 57 | index=1, 58 | number=2, 59 | type=9, 60 | cpp_type=9, 61 | label=1, 62 | has_default_value=False, 63 | default_value=_b("").decode("utf-8"), 64 | message_type=None, 65 | enum_type=None, 66 | containing_type=None, 67 | is_extension=False, 68 | extension_scope=None, 69 | serialized_options=None, 70 | file=DESCRIPTOR, 71 | ), 72 | _descriptor.FieldDescriptor( 73 | name="number_steps", 74 | full_name="communicator_objects.DemonstrationMetaProto.number_steps", 75 | index=2, 76 | number=3, 77 | type=5, 78 | cpp_type=1, 79 | label=1, 80 | has_default_value=False, 81 | default_value=0, 82 | message_type=None, 83 | enum_type=None, 84 | containing_type=None, 85 | is_extension=False, 86 | extension_scope=None, 87 | serialized_options=None, 88 | file=DESCRIPTOR, 89 | ), 90 | _descriptor.FieldDescriptor( 91 | name="number_episodes", 92 | full_name="communicator_objects.DemonstrationMetaProto.number_episodes", 93 | index=3, 94 | number=4, 95 | type=5, 96 | cpp_type=1, 97 | label=1, 98 | has_default_value=False, 99 | default_value=0, 100 | message_type=None, 101 | enum_type=None, 102 | containing_type=None, 103 | is_extension=False, 104 | extension_scope=None, 105 | serialized_options=None, 106 | file=DESCRIPTOR, 107 | ), 108 | _descriptor.FieldDescriptor( 109 | name="mean_reward", 110 | full_name="communicator_objects.DemonstrationMetaProto.mean_reward", 111 | index=4, 112 | number=5, 113 | type=2, 114 | cpp_type=6, 115 | label=1, 116 | 
has_default_value=False, 117 | default_value=float(0), 118 | message_type=None, 119 | enum_type=None, 120 | containing_type=None, 121 | is_extension=False, 122 | extension_scope=None, 123 | serialized_options=None, 124 | file=DESCRIPTOR, 125 | ), 126 | ], 127 | extensions=[], 128 | nested_types=[], 129 | enum_types=[], 130 | serialized_options=None, 131 | is_extendable=False, 132 | syntax="proto3", 133 | extension_ranges=[], 134 | oneofs=[], 135 | serialized_start=92, 136 | serialized_end=233, 137 | ) 138 | 139 | DESCRIPTOR.message_types_by_name["DemonstrationMetaProto"] = _DEMONSTRATIONMETAPROTO 140 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 141 | 142 | DemonstrationMetaProto = _reflection.GeneratedProtocolMessageType( 143 | "DemonstrationMetaProto", 144 | (_message.Message,), 145 | dict( 146 | DESCRIPTOR=_DEMONSTRATIONMETAPROTO, 147 | __module__="mlagents.envs.communicator_objects.demonstration_meta_proto_pb2" 148 | # @@protoc_insertion_point(class_scope:communicator_objects.DemonstrationMetaProto) 149 | ), 150 | ) 151 | _sym_db.RegisterMessage(DemonstrationMetaProto) 152 | 153 | 154 | DESCRIPTOR._options = None 155 | # @@protoc_insertion_point(module_scope) 156 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/engine_configuration_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: mlagents/envs/communicator_objects/engine_configuration_proto.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name="mlagents/envs/communicator_objects/engine_configuration_proto.proto", 20 | package="communicator_objects", 21 | syntax="proto3", 22 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 23 | serialized_pb=_b( 24 | '\nCmlagents/envs/communicator_objects/engine_configuration_proto.proto\x12\x14\x63ommunicator_objects"\x95\x01\n\x18\x45ngineConfigurationProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x15\n\rquality_level\x18\x03 \x01(\x05\x12\x12\n\ntime_scale\x18\x04 \x01(\x02\x12\x19\n\x11target_frame_rate\x18\x05 \x01(\x05\x12\x14\n\x0cshow_monitor\x18\x06 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 25 | ), 26 | ) 27 | 28 | 29 | _ENGINECONFIGURATIONPROTO = _descriptor.Descriptor( 30 | name="EngineConfigurationProto", 31 | full_name="communicator_objects.EngineConfigurationProto", 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | _descriptor.FieldDescriptor( 37 | name="width", 38 | full_name="communicator_objects.EngineConfigurationProto.width", 39 | index=0, 40 | number=1, 41 | type=5, 42 | cpp_type=1, 43 | label=1, 44 | has_default_value=False, 45 | default_value=0, 46 | message_type=None, 47 | enum_type=None, 48 | containing_type=None, 49 | is_extension=False, 50 | extension_scope=None, 51 | serialized_options=None, 52 | file=DESCRIPTOR, 53 | ), 54 | _descriptor.FieldDescriptor( 55 | name="height", 56 | 
full_name="communicator_objects.EngineConfigurationProto.height", 57 | index=1, 58 | number=2, 59 | type=5, 60 | cpp_type=1, 61 | label=1, 62 | has_default_value=False, 63 | default_value=0, 64 | message_type=None, 65 | enum_type=None, 66 | containing_type=None, 67 | is_extension=False, 68 | extension_scope=None, 69 | serialized_options=None, 70 | file=DESCRIPTOR, 71 | ), 72 | _descriptor.FieldDescriptor( 73 | name="quality_level", 74 | full_name="communicator_objects.EngineConfigurationProto.quality_level", 75 | index=2, 76 | number=3, 77 | type=5, 78 | cpp_type=1, 79 | label=1, 80 | has_default_value=False, 81 | default_value=0, 82 | message_type=None, 83 | enum_type=None, 84 | containing_type=None, 85 | is_extension=False, 86 | extension_scope=None, 87 | serialized_options=None, 88 | file=DESCRIPTOR, 89 | ), 90 | _descriptor.FieldDescriptor( 91 | name="time_scale", 92 | full_name="communicator_objects.EngineConfigurationProto.time_scale", 93 | index=3, 94 | number=4, 95 | type=2, 96 | cpp_type=6, 97 | label=1, 98 | has_default_value=False, 99 | default_value=float(0), 100 | message_type=None, 101 | enum_type=None, 102 | containing_type=None, 103 | is_extension=False, 104 | extension_scope=None, 105 | serialized_options=None, 106 | file=DESCRIPTOR, 107 | ), 108 | _descriptor.FieldDescriptor( 109 | name="target_frame_rate", 110 | full_name="communicator_objects.EngineConfigurationProto.target_frame_rate", 111 | index=4, 112 | number=5, 113 | type=5, 114 | cpp_type=1, 115 | label=1, 116 | has_default_value=False, 117 | default_value=0, 118 | message_type=None, 119 | enum_type=None, 120 | containing_type=None, 121 | is_extension=False, 122 | extension_scope=None, 123 | serialized_options=None, 124 | file=DESCRIPTOR, 125 | ), 126 | _descriptor.FieldDescriptor( 127 | name="show_monitor", 128 | full_name="communicator_objects.EngineConfigurationProto.show_monitor", 129 | index=5, 130 | number=6, 131 | type=8, 132 | cpp_type=7, 133 | label=1, 134 | has_default_value=False, 135 | default_value=False, 136 | message_type=None, 137 | enum_type=None, 138 | containing_type=None, 139 | is_extension=False, 140 | extension_scope=None, 141 | serialized_options=None, 142 | file=DESCRIPTOR, 143 | ), 144 | ], 145 | extensions=[], 146 | nested_types=[], 147 | enum_types=[], 148 | serialized_options=None, 149 | is_extendable=False, 150 | syntax="proto3", 151 | extension_ranges=[], 152 | oneofs=[], 153 | serialized_start=94, 154 | serialized_end=243, 155 | ) 156 | 157 | DESCRIPTOR.message_types_by_name["EngineConfigurationProto"] = _ENGINECONFIGURATIONPROTO 158 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 159 | 160 | EngineConfigurationProto = _reflection.GeneratedProtocolMessageType( 161 | "EngineConfigurationProto", 162 | (_message.Message,), 163 | dict( 164 | DESCRIPTOR=_ENGINECONFIGURATIONPROTO, 165 | __module__="mlagents.envs.communicator_objects.engine_configuration_proto_pb2" 166 | # @@protoc_insertion_point(class_scope:communicator_objects.EngineConfigurationProto) 167 | ), 168 | ) 169 | _sym_db.RegisterMessage(EngineConfigurationProto) 170 | 171 | 172 | DESCRIPTOR._options = None 173 | # @@protoc_insertion_point(module_scope) 174 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/environment_parameters_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/environment_parameters_proto.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | from mlagents.envs.communicator_objects import ( 19 | custom_reset_parameters_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_custom__reset__parameters__pb2, 20 | ) 21 | 22 | 23 | DESCRIPTOR = _descriptor.FileDescriptor( 24 | name="mlagents/envs/communicator_objects/environment_parameters_proto.proto", 25 | package="communicator_objects", 26 | syntax="proto3", 27 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 28 | serialized_pb=_b( 29 | '\nEmlagents/envs/communicator_objects/environment_parameters_proto.proto\x12\x14\x63ommunicator_objects\x1a@mlagents/envs/communicator_objects/custom_reset_parameters.proto"\x83\x02\n\x1a\x45nvironmentParametersProto\x12_\n\x10\x66loat_parameters\x18\x01 \x03(\x0b\x32\x45.communicator_objects.EnvironmentParametersProto.FloatParametersEntry\x12L\n\x17\x63ustom_reset_parameters\x18\x02 \x01(\x0b\x32+.communicator_objects.CustomResetParameters\x1a\x36\n\x14\x46loatParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 30 | ), 31 | dependencies=[ 32 | mlagents_dot_envs_dot_communicator__objects_dot_custom__reset__parameters__pb2.DESCRIPTOR 33 | ], 34 | ) 35 | 36 | 37 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY = _descriptor.Descriptor( 38 | name="FloatParametersEntry", 39 | full_name="communicator_objects.EnvironmentParametersProto.FloatParametersEntry", 40 | filename=None, 41 | file=DESCRIPTOR, 42 | containing_type=None, 43 | fields=[ 44 | _descriptor.FieldDescriptor( 45 | name="key", 46 | full_name="communicator_objects.EnvironmentParametersProto.FloatParametersEntry.key", 47 | index=0, 48 | number=1, 49 | type=9, 50 | cpp_type=9, 51 | label=1, 52 | has_default_value=False, 53 | default_value=_b("").decode("utf-8"), 54 | message_type=None, 55 | enum_type=None, 56 | containing_type=None, 57 | is_extension=False, 58 | extension_scope=None, 59 | serialized_options=None, 60 | file=DESCRIPTOR, 61 | ), 62 | _descriptor.FieldDescriptor( 63 | name="value", 64 | full_name="communicator_objects.EnvironmentParametersProto.FloatParametersEntry.value", 65 | index=1, 66 | number=2, 67 | type=2, 68 | cpp_type=6, 69 | label=1, 70 | has_default_value=False, 71 | default_value=float(0), 72 | message_type=None, 73 | enum_type=None, 74 | containing_type=None, 75 | is_extension=False, 76 | extension_scope=None, 77 | serialized_options=None, 78 | file=DESCRIPTOR, 79 | ), 80 | ], 81 | extensions=[], 82 | nested_types=[], 83 | enum_types=[], 84 | serialized_options=_b("8\001"), 85 | is_extendable=False, 86 | syntax="proto3", 87 | extension_ranges=[], 88 | oneofs=[], 89 | serialized_start=367, 90 | serialized_end=421, 91 | ) 92 | 93 | _ENVIRONMENTPARAMETERSPROTO = _descriptor.Descriptor( 94 | name="EnvironmentParametersProto", 95 | full_name="communicator_objects.EnvironmentParametersProto", 96 | filename=None, 97 | file=DESCRIPTOR, 98 | containing_type=None, 99 | fields=[ 100 | _descriptor.FieldDescriptor( 101 | 
name="float_parameters", 102 | full_name="communicator_objects.EnvironmentParametersProto.float_parameters", 103 | index=0, 104 | number=1, 105 | type=11, 106 | cpp_type=10, 107 | label=3, 108 | has_default_value=False, 109 | default_value=[], 110 | message_type=None, 111 | enum_type=None, 112 | containing_type=None, 113 | is_extension=False, 114 | extension_scope=None, 115 | serialized_options=None, 116 | file=DESCRIPTOR, 117 | ), 118 | _descriptor.FieldDescriptor( 119 | name="custom_reset_parameters", 120 | full_name="communicator_objects.EnvironmentParametersProto.custom_reset_parameters", 121 | index=1, 122 | number=2, 123 | type=11, 124 | cpp_type=10, 125 | label=1, 126 | has_default_value=False, 127 | default_value=None, 128 | message_type=None, 129 | enum_type=None, 130 | containing_type=None, 131 | is_extension=False, 132 | extension_scope=None, 133 | serialized_options=None, 134 | file=DESCRIPTOR, 135 | ), 136 | ], 137 | extensions=[], 138 | nested_types=[_ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY], 139 | enum_types=[], 140 | serialized_options=None, 141 | is_extendable=False, 142 | syntax="proto3", 143 | extension_ranges=[], 144 | oneofs=[], 145 | serialized_start=162, 146 | serialized_end=421, 147 | ) 148 | 149 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.containing_type = ( 150 | _ENVIRONMENTPARAMETERSPROTO 151 | ) 152 | _ENVIRONMENTPARAMETERSPROTO.fields_by_name[ 153 | "float_parameters" 154 | ].message_type = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY 155 | _ENVIRONMENTPARAMETERSPROTO.fields_by_name[ 156 | "custom_reset_parameters" 157 | ].message_type = ( 158 | mlagents_dot_envs_dot_communicator__objects_dot_custom__reset__parameters__pb2._CUSTOMRESETPARAMETERS 159 | ) 160 | DESCRIPTOR.message_types_by_name[ 161 | "EnvironmentParametersProto" 162 | ] = _ENVIRONMENTPARAMETERSPROTO 163 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 164 | 165 | EnvironmentParametersProto = _reflection.GeneratedProtocolMessageType( 166 | "EnvironmentParametersProto", 167 | (_message.Message,), 168 | dict( 169 | FloatParametersEntry=_reflection.GeneratedProtocolMessageType( 170 | "FloatParametersEntry", 171 | (_message.Message,), 172 | dict( 173 | DESCRIPTOR=_ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, 174 | __module__="mlagents.envs.communicator_objects.environment_parameters_proto_pb2" 175 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto.FloatParametersEntry) 176 | ), 177 | ), 178 | DESCRIPTOR=_ENVIRONMENTPARAMETERSPROTO, 179 | __module__="mlagents.envs.communicator_objects.environment_parameters_proto_pb2" 180 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto) 181 | ), 182 | ) 183 | _sym_db.RegisterMessage(EnvironmentParametersProto) 184 | _sym_db.RegisterMessage(EnvironmentParametersProto.FloatParametersEntry) 185 | 186 | 187 | DESCRIPTOR._options = None 188 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY._options = None 189 | # @@protoc_insertion_point(module_scope) 190 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/header_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/header.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name="mlagents/envs/communicator_objects/header.proto", 20 | package="communicator_objects", 21 | syntax="proto3", 22 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 23 | serialized_pb=_b( 24 | '\n/mlagents/envs/communicator_objects/header.proto\x12\x14\x63ommunicator_objects")\n\x06Header\x12\x0e\n\x06status\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 25 | ), 26 | ) 27 | 28 | 29 | _HEADER = _descriptor.Descriptor( 30 | name="Header", 31 | full_name="communicator_objects.Header", 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | _descriptor.FieldDescriptor( 37 | name="status", 38 | full_name="communicator_objects.Header.status", 39 | index=0, 40 | number=1, 41 | type=5, 42 | cpp_type=1, 43 | label=1, 44 | has_default_value=False, 45 | default_value=0, 46 | message_type=None, 47 | enum_type=None, 48 | containing_type=None, 49 | is_extension=False, 50 | extension_scope=None, 51 | serialized_options=None, 52 | file=DESCRIPTOR, 53 | ), 54 | _descriptor.FieldDescriptor( 55 | name="message", 56 | full_name="communicator_objects.Header.message", 57 | index=1, 58 | number=2, 59 | type=9, 60 | cpp_type=9, 61 | label=1, 62 | has_default_value=False, 63 | default_value=_b("").decode("utf-8"), 64 | message_type=None, 65 | enum_type=None, 66 | containing_type=None, 67 | is_extension=False, 68 | extension_scope=None, 69 | serialized_options=None, 70 | file=DESCRIPTOR, 71 | ), 72 | ], 73 | extensions=[], 74 | nested_types=[], 75 | enum_types=[], 76 | serialized_options=None, 77 | is_extendable=False, 78 | syntax="proto3", 79 | extension_ranges=[], 80 | oneofs=[], 81 | serialized_start=73, 82 | serialized_end=114, 83 | ) 84 | 85 | DESCRIPTOR.message_types_by_name["Header"] = _HEADER 86 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 87 | 88 | Header = _reflection.GeneratedProtocolMessageType( 89 | "Header", 90 | (_message.Message,), 91 | dict( 92 | DESCRIPTOR=_HEADER, 93 | __module__="mlagents.envs.communicator_objects.header_pb2" 94 | # @@protoc_insertion_point(class_scope:communicator_objects.Header) 95 | ), 96 | ) 97 | _sym_db.RegisterMessage(Header) 98 | 99 | 100 | DESCRIPTOR._options = None 101 | # @@protoc_insertion_point(module_scope) 102 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/resolution_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/resolution_proto.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name="mlagents/envs/communicator_objects/resolution_proto.proto", 20 | package="communicator_objects", 21 | syntax="proto3", 22 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 23 | serialized_pb=_b( 24 | '\n9mlagents/envs/communicator_objects/resolution_proto.proto\x12\x14\x63ommunicator_objects"D\n\x0fResolutionProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x12\n\ngray_scale\x18\x03 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 25 | ), 26 | ) 27 | 28 | 29 | _RESOLUTIONPROTO = _descriptor.Descriptor( 30 | name="ResolutionProto", 31 | full_name="communicator_objects.ResolutionProto", 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | _descriptor.FieldDescriptor( 37 | name="width", 38 | full_name="communicator_objects.ResolutionProto.width", 39 | index=0, 40 | number=1, 41 | type=5, 42 | cpp_type=1, 43 | label=1, 44 | has_default_value=False, 45 | default_value=0, 46 | message_type=None, 47 | enum_type=None, 48 | containing_type=None, 49 | is_extension=False, 50 | extension_scope=None, 51 | serialized_options=None, 52 | file=DESCRIPTOR, 53 | ), 54 | _descriptor.FieldDescriptor( 55 | name="height", 56 | full_name="communicator_objects.ResolutionProto.height", 57 | index=1, 58 | number=2, 59 | type=5, 60 | cpp_type=1, 61 | label=1, 62 | has_default_value=False, 63 | default_value=0, 64 | message_type=None, 65 | enum_type=None, 66 | containing_type=None, 67 | is_extension=False, 68 | extension_scope=None, 69 | serialized_options=None, 70 | file=DESCRIPTOR, 71 | ), 72 | _descriptor.FieldDescriptor( 73 | name="gray_scale", 74 | full_name="communicator_objects.ResolutionProto.gray_scale", 75 | index=2, 76 | number=3, 77 | type=8, 78 | cpp_type=7, 79 | label=1, 80 | has_default_value=False, 81 | default_value=False, 82 | message_type=None, 83 | enum_type=None, 84 | containing_type=None, 85 | is_extension=False, 86 | extension_scope=None, 87 | serialized_options=None, 88 | file=DESCRIPTOR, 89 | ), 90 | ], 91 | extensions=[], 92 | nested_types=[], 93 | enum_types=[], 94 | serialized_options=None, 95 | is_extendable=False, 96 | syntax="proto3", 97 | extension_ranges=[], 98 | oneofs=[], 99 | serialized_start=83, 100 | serialized_end=151, 101 | ) 102 | 103 | DESCRIPTOR.message_types_by_name["ResolutionProto"] = _RESOLUTIONPROTO 104 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 105 | 106 | ResolutionProto = _reflection.GeneratedProtocolMessageType( 107 | "ResolutionProto", 108 | (_message.Message,), 109 | dict( 110 | DESCRIPTOR=_RESOLUTIONPROTO, 111 | __module__="mlagents.envs.communicator_objects.resolution_proto_pb2" 112 | # @@protoc_insertion_point(class_scope:communicator_objects.ResolutionProto) 113 | ), 114 | ) 115 | _sym_db.RegisterMessage(ResolutionProto) 116 | 117 | 118 | DESCRIPTOR._options = None 119 | # @@protoc_insertion_point(module_scope) 120 | 
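
# Illustrative usage (a minimal sketch, not part of the generated file above):
# the *_pb2 modules in this package define ordinary protobuf messages, so the
# standard google.protobuf message API (keyword constructors, SerializeToString,
# FromString) applies to all of them. ResolutionProto's fields -- width, height,
# gray_scale -- are exactly those declared in the descriptor above; the example
# values here are hypothetical.
from mlagents.envs.communicator_objects.resolution_proto_pb2 import ResolutionProto

res = ResolutionProto(width=84, height=84, gray_scale=False)
payload = res.SerializeToString()           # wire-format bytes, as exchanged with Unity
decoded = ResolutionProto.FromString(payload)
assert decoded.width == 84 and not decoded.gray_scale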
-------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/space_type_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: mlagents/envs/communicator_objects/space_type_proto.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf.internal import enum_type_wrapper 9 | from google.protobuf import descriptor as _descriptor 10 | from google.protobuf import message as _message 11 | from google.protobuf import reflection as _reflection 12 | from google.protobuf import symbol_database as _symbol_database 13 | 14 | # @@protoc_insertion_point(imports) 15 | 16 | _sym_db = _symbol_database.Default() 17 | 18 | 19 | from mlagents.envs.communicator_objects import ( 20 | resolution_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_resolution__proto__pb2, 21 | ) 22 | 23 | 24 | DESCRIPTOR = _descriptor.FileDescriptor( 25 | name="mlagents/envs/communicator_objects/space_type_proto.proto", 26 | package="communicator_objects", 27 | syntax="proto3", 28 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 29 | serialized_pb=_b( 30 | "\n9mlagents/envs/communicator_objects/space_type_proto.proto\x12\x14\x63ommunicator_objects\x1a\x39mlagents/envs/communicator_objects/resolution_proto.proto*.\n\x0eSpaceTypeProto\x12\x0c\n\x08\x64iscrete\x10\x00\x12\x0e\n\ncontinuous\x10\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3" 31 | ), 32 | dependencies=[ 33 | mlagents_dot_envs_dot_communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR 34 | ], 35 | ) 36 | 37 | _SPACETYPEPROTO = _descriptor.EnumDescriptor( 38 | name="SpaceTypeProto", 39 | full_name="communicator_objects.SpaceTypeProto", 40 | filename=None, 41 | file=DESCRIPTOR, 42 | values=[ 43 | _descriptor.EnumValueDescriptor( 44 | name="discrete", index=0, number=0, serialized_options=None, type=None 45 | ), 46 | _descriptor.EnumValueDescriptor( 47 | name="continuous", index=1, number=1, serialized_options=None, type=None 48 | ), 49 | ], 50 | containing_type=None, 51 | serialized_options=None, 52 | serialized_start=142, 53 | serialized_end=188, 54 | ) 55 | _sym_db.RegisterEnumDescriptor(_SPACETYPEPROTO) 56 | 57 | SpaceTypeProto = enum_type_wrapper.EnumTypeWrapper(_SPACETYPEPROTO) 58 | discrete = 0 59 | continuous = 1 60 | 61 | 62 | DESCRIPTOR.enum_types_by_name["SpaceTypeProto"] = _SPACETYPEPROTO 63 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 64 | 65 | 66 | DESCRIPTOR._options = None 67 | # @@protoc_insertion_point(module_scope) 68 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_input_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/unity_input.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | from mlagents.envs.communicator_objects import ( 19 | unity_rl_input_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__input__pb2, 20 | ) 21 | from mlagents.envs.communicator_objects import ( 22 | unity_rl_initialization_input_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__input__pb2, 23 | ) 24 | 25 | 26 | DESCRIPTOR = _descriptor.FileDescriptor( 27 | name="mlagents/envs/communicator_objects/unity_input.proto", 28 | package="communicator_objects", 29 | syntax="proto3", 30 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 31 | serialized_pb=_b( 32 | '\n4mlagents/envs/communicator_objects/unity_input.proto\x12\x14\x63ommunicator_objects\x1a\x37mlagents/envs/communicator_objects/unity_rl_input.proto\x1a\x46mlagents/envs/communicator_objects/unity_rl_initialization_input.proto"\x95\x01\n\nUnityInput\x12\x34\n\x08rl_input\x18\x01 \x01(\x0b\x32".communicator_objects.UnityRLInput\x12Q\n\x17rl_initialization_input\x18\x02 \x01(\x0b\x32\x30.communicator_objects.UnityRLInitializationInputB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 33 | ), 34 | dependencies=[ 35 | mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__input__pb2.DESCRIPTOR, 36 | mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__input__pb2.DESCRIPTOR, 37 | ], 38 | ) 39 | 40 | 41 | _UNITYINPUT = _descriptor.Descriptor( 42 | name="UnityInput", 43 | full_name="communicator_objects.UnityInput", 44 | filename=None, 45 | file=DESCRIPTOR, 46 | containing_type=None, 47 | fields=[ 48 | _descriptor.FieldDescriptor( 49 | name="rl_input", 50 | full_name="communicator_objects.UnityInput.rl_input", 51 | index=0, 52 | number=1, 53 | type=11, 54 | cpp_type=10, 55 | label=1, 56 | has_default_value=False, 57 | default_value=None, 58 | message_type=None, 59 | enum_type=None, 60 | containing_type=None, 61 | is_extension=False, 62 | extension_scope=None, 63 | serialized_options=None, 64 | file=DESCRIPTOR, 65 | ), 66 | _descriptor.FieldDescriptor( 67 | name="rl_initialization_input", 68 | full_name="communicator_objects.UnityInput.rl_initialization_input", 69 | index=1, 70 | number=2, 71 | type=11, 72 | cpp_type=10, 73 | label=1, 74 | has_default_value=False, 75 | default_value=None, 76 | message_type=None, 77 | enum_type=None, 78 | containing_type=None, 79 | is_extension=False, 80 | extension_scope=None, 81 | serialized_options=None, 82 | file=DESCRIPTOR, 83 | ), 84 | ], 85 | extensions=[], 86 | nested_types=[], 87 | enum_types=[], 88 | serialized_options=None, 89 | is_extendable=False, 90 | syntax="proto3", 91 | extension_ranges=[], 92 | oneofs=[], 93 | serialized_start=208, 94 | serialized_end=357, 95 | ) 96 | 97 | _UNITYINPUT.fields_by_name[ 98 | "rl_input" 99 | ].message_type = ( 100 | mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__input__pb2._UNITYRLINPUT 101 | ) 102 | _UNITYINPUT.fields_by_name[ 103 | "rl_initialization_input" 104 | ].message_type = ( 105 | 
mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__input__pb2._UNITYRLINITIALIZATIONINPUT 106 | ) 107 | DESCRIPTOR.message_types_by_name["UnityInput"] = _UNITYINPUT 108 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 109 | 110 | UnityInput = _reflection.GeneratedProtocolMessageType( 111 | "UnityInput", 112 | (_message.Message,), 113 | dict( 114 | DESCRIPTOR=_UNITYINPUT, 115 | __module__="mlagents.envs.communicator_objects.unity_input_pb2" 116 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityInput) 117 | ), 118 | ) 119 | _sym_db.RegisterMessage(UnityInput) 120 | 121 | 122 | DESCRIPTOR._options = None 123 | # @@protoc_insertion_point(module_scope) 124 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_message_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: mlagents/envs/communicator_objects/unity_message.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | from mlagents.envs.communicator_objects import ( 19 | unity_output_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__output__pb2, 20 | ) 21 | from mlagents.envs.communicator_objects import ( 22 | unity_input_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__input__pb2, 23 | ) 24 | from mlagents.envs.communicator_objects import ( 25 | header_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_header__pb2, 26 | ) 27 | 28 | 29 | DESCRIPTOR = _descriptor.FileDescriptor( 30 | name="mlagents/envs/communicator_objects/unity_message.proto", 31 | package="communicator_objects", 32 | syntax="proto3", 33 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 34 | serialized_pb=_b( 35 | '\n6mlagents/envs/communicator_objects/unity_message.proto\x12\x14\x63ommunicator_objects\x1a\x35mlagents/envs/communicator_objects/unity_output.proto\x1a\x34mlagents/envs/communicator_objects/unity_input.proto\x1a/mlagents/envs/communicator_objects/header.proto"\xac\x01\n\x0cUnityMessage\x12,\n\x06header\x18\x01 \x01(\x0b\x32\x1c.communicator_objects.Header\x12\x37\n\x0cunity_output\x18\x02 \x01(\x0b\x32!.communicator_objects.UnityOutput\x12\x35\n\x0bunity_input\x18\x03 \x01(\x0b\x32 .communicator_objects.UnityInputB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 36 | ), 37 | dependencies=[ 38 | mlagents_dot_envs_dot_communicator__objects_dot_unity__output__pb2.DESCRIPTOR, 39 | mlagents_dot_envs_dot_communicator__objects_dot_unity__input__pb2.DESCRIPTOR, 40 | mlagents_dot_envs_dot_communicator__objects_dot_header__pb2.DESCRIPTOR, 41 | ], 42 | ) 43 | 44 | 45 | _UNITYMESSAGE = _descriptor.Descriptor( 46 | name="UnityMessage", 47 | full_name="communicator_objects.UnityMessage", 48 | filename=None, 49 | file=DESCRIPTOR, 50 | containing_type=None, 51 | fields=[ 52 | _descriptor.FieldDescriptor( 53 | name="header", 54 | full_name="communicator_objects.UnityMessage.header", 55 | index=0, 56 | number=1, 57 | type=11, 58 | cpp_type=10, 59 | label=1, 60 | 
has_default_value=False, 61 | default_value=None, 62 | message_type=None, 63 | enum_type=None, 64 | containing_type=None, 65 | is_extension=False, 66 | extension_scope=None, 67 | serialized_options=None, 68 | file=DESCRIPTOR, 69 | ), 70 | _descriptor.FieldDescriptor( 71 | name="unity_output", 72 | full_name="communicator_objects.UnityMessage.unity_output", 73 | index=1, 74 | number=2, 75 | type=11, 76 | cpp_type=10, 77 | label=1, 78 | has_default_value=False, 79 | default_value=None, 80 | message_type=None, 81 | enum_type=None, 82 | containing_type=None, 83 | is_extension=False, 84 | extension_scope=None, 85 | serialized_options=None, 86 | file=DESCRIPTOR, 87 | ), 88 | _descriptor.FieldDescriptor( 89 | name="unity_input", 90 | full_name="communicator_objects.UnityMessage.unity_input", 91 | index=2, 92 | number=3, 93 | type=11, 94 | cpp_type=10, 95 | label=1, 96 | has_default_value=False, 97 | default_value=None, 98 | message_type=None, 99 | enum_type=None, 100 | containing_type=None, 101 | is_extension=False, 102 | extension_scope=None, 103 | serialized_options=None, 104 | file=DESCRIPTOR, 105 | ), 106 | ], 107 | extensions=[], 108 | nested_types=[], 109 | enum_types=[], 110 | serialized_options=None, 111 | is_extendable=False, 112 | syntax="proto3", 113 | extension_ranges=[], 114 | oneofs=[], 115 | serialized_start=239, 116 | serialized_end=411, 117 | ) 118 | 119 | _UNITYMESSAGE.fields_by_name[ 120 | "header" 121 | ].message_type = mlagents_dot_envs_dot_communicator__objects_dot_header__pb2._HEADER 122 | _UNITYMESSAGE.fields_by_name[ 123 | "unity_output" 124 | ].message_type = ( 125 | mlagents_dot_envs_dot_communicator__objects_dot_unity__output__pb2._UNITYOUTPUT 126 | ) 127 | _UNITYMESSAGE.fields_by_name[ 128 | "unity_input" 129 | ].message_type = ( 130 | mlagents_dot_envs_dot_communicator__objects_dot_unity__input__pb2._UNITYINPUT 131 | ) 132 | DESCRIPTOR.message_types_by_name["UnityMessage"] = _UNITYMESSAGE 133 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 134 | 135 | UnityMessage = _reflection.GeneratedProtocolMessageType( 136 | "UnityMessage", 137 | (_message.Message,), 138 | dict( 139 | DESCRIPTOR=_UNITYMESSAGE, 140 | __module__="mlagents.envs.communicator_objects.unity_message_pb2" 141 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityMessage) 142 | ), 143 | ) 144 | _sym_db.RegisterMessage(UnityMessage) 145 | 146 | 147 | DESCRIPTOR._options = None 148 | # @@protoc_insertion_point(module_scope) 149 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_output_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/unity_output.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | from mlagents.envs.communicator_objects import ( 19 | unity_rl_output_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__output__pb2, 20 | ) 21 | from mlagents.envs.communicator_objects import ( 22 | unity_rl_initialization_output_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__output__pb2, 23 | ) 24 | 25 | 26 | DESCRIPTOR = _descriptor.FileDescriptor( 27 | name="mlagents/envs/communicator_objects/unity_output.proto", 28 | package="communicator_objects", 29 | syntax="proto3", 30 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 31 | serialized_pb=_b( 32 | '\n5mlagents/envs/communicator_objects/unity_output.proto\x12\x14\x63ommunicator_objects\x1a\x38mlagents/envs/communicator_objects/unity_rl_output.proto\x1aGmlagents/envs/communicator_objects/unity_rl_initialization_output.proto"\x9a\x01\n\x0bUnityOutput\x12\x36\n\trl_output\x18\x01 \x01(\x0b\x32#.communicator_objects.UnityRLOutput\x12S\n\x18rl_initialization_output\x18\x02 \x01(\x0b\x32\x31.communicator_objects.UnityRLInitializationOutputB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 33 | ), 34 | dependencies=[ 35 | mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__output__pb2.DESCRIPTOR, 36 | mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__output__pb2.DESCRIPTOR, 37 | ], 38 | ) 39 | 40 | 41 | _UNITYOUTPUT = _descriptor.Descriptor( 42 | name="UnityOutput", 43 | full_name="communicator_objects.UnityOutput", 44 | filename=None, 45 | file=DESCRIPTOR, 46 | containing_type=None, 47 | fields=[ 48 | _descriptor.FieldDescriptor( 49 | name="rl_output", 50 | full_name="communicator_objects.UnityOutput.rl_output", 51 | index=0, 52 | number=1, 53 | type=11, 54 | cpp_type=10, 55 | label=1, 56 | has_default_value=False, 57 | default_value=None, 58 | message_type=None, 59 | enum_type=None, 60 | containing_type=None, 61 | is_extension=False, 62 | extension_scope=None, 63 | serialized_options=None, 64 | file=DESCRIPTOR, 65 | ), 66 | _descriptor.FieldDescriptor( 67 | name="rl_initialization_output", 68 | full_name="communicator_objects.UnityOutput.rl_initialization_output", 69 | index=1, 70 | number=2, 71 | type=11, 72 | cpp_type=10, 73 | label=1, 74 | has_default_value=False, 75 | default_value=None, 76 | message_type=None, 77 | enum_type=None, 78 | containing_type=None, 79 | is_extension=False, 80 | extension_scope=None, 81 | serialized_options=None, 82 | file=DESCRIPTOR, 83 | ), 84 | ], 85 | extensions=[], 86 | nested_types=[], 87 | enum_types=[], 88 | serialized_options=None, 89 | is_extendable=False, 90 | syntax="proto3", 91 | extension_ranges=[], 92 | oneofs=[], 93 | serialized_start=211, 94 | serialized_end=365, 95 | ) 96 | 97 | _UNITYOUTPUT.fields_by_name[ 98 | "rl_output" 99 | ].message_type = ( 100 | mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__output__pb2._UNITYRLOUTPUT 101 | ) 102 | _UNITYOUTPUT.fields_by_name[ 103 | "rl_initialization_output" 104 | ].message_type = ( 105 | 
mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__output__pb2._UNITYRLINITIALIZATIONOUTPUT 106 | ) 107 | DESCRIPTOR.message_types_by_name["UnityOutput"] = _UNITYOUTPUT 108 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 109 | 110 | UnityOutput = _reflection.GeneratedProtocolMessageType( 111 | "UnityOutput", 112 | (_message.Message,), 113 | dict( 114 | DESCRIPTOR=_UNITYOUTPUT, 115 | __module__="mlagents.envs.communicator_objects.unity_output_pb2" 116 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityOutput) 117 | ), 118 | ) 119 | _sym_db.RegisterMessage(UnityOutput) 120 | 121 | 122 | DESCRIPTOR._options = None 123 | # @@protoc_insertion_point(module_scope) 124 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_rl_initialization_input_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: mlagents/envs/communicator_objects/unity_rl_initialization_input.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name="mlagents/envs/communicator_objects/unity_rl_initialization_input.proto", 20 | package="communicator_objects", 21 | syntax="proto3", 22 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 23 | serialized_pb=_b( 24 | '\nFmlagents/envs/communicator_objects/unity_rl_initialization_input.proto\x12\x14\x63ommunicator_objects"*\n\x1aUnityRLInitializationInput\x12\x0c\n\x04seed\x18\x01 \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 25 | ), 26 | ) 27 | 28 | 29 | _UNITYRLINITIALIZATIONINPUT = _descriptor.Descriptor( 30 | name="UnityRLInitializationInput", 31 | full_name="communicator_objects.UnityRLInitializationInput", 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | _descriptor.FieldDescriptor( 37 | name="seed", 38 | full_name="communicator_objects.UnityRLInitializationInput.seed", 39 | index=0, 40 | number=1, 41 | type=5, 42 | cpp_type=1, 43 | label=1, 44 | has_default_value=False, 45 | default_value=0, 46 | message_type=None, 47 | enum_type=None, 48 | containing_type=None, 49 | is_extension=False, 50 | extension_scope=None, 51 | serialized_options=None, 52 | file=DESCRIPTOR, 53 | ) 54 | ], 55 | extensions=[], 56 | nested_types=[], 57 | enum_types=[], 58 | serialized_options=None, 59 | is_extendable=False, 60 | syntax="proto3", 61 | extension_ranges=[], 62 | oneofs=[], 63 | serialized_start=96, 64 | serialized_end=138, 65 | ) 66 | 67 | DESCRIPTOR.message_types_by_name[ 68 | "UnityRLInitializationInput" 69 | ] = _UNITYRLINITIALIZATIONINPUT 70 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 71 | 72 | UnityRLInitializationInput = _reflection.GeneratedProtocolMessageType( 73 | "UnityRLInitializationInput", 74 | (_message.Message,), 75 | dict( 76 | DESCRIPTOR=_UNITYRLINITIALIZATIONINPUT, 77 | __module__="mlagents.envs.communicator_objects.unity_rl_initialization_input_pb2" 78 | # 
@@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationInput) 79 | ), 80 | ) 81 | _sym_db.RegisterMessage(UnityRLInitializationInput) 82 | 83 | 84 | DESCRIPTOR._options = None 85 | # @@protoc_insertion_point(module_scope) 86 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_rl_initialization_output_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: mlagents/envs/communicator_objects/unity_rl_initialization_output.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | from mlagents.envs.communicator_objects import ( 19 | brain_parameters_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_brain__parameters__proto__pb2, 20 | ) 21 | from mlagents.envs.communicator_objects import ( 22 | environment_parameters_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_environment__parameters__proto__pb2, 23 | ) 24 | 25 | 26 | DESCRIPTOR = _descriptor.FileDescriptor( 27 | name="mlagents/envs/communicator_objects/unity_rl_initialization_output.proto", 28 | package="communicator_objects", 29 | syntax="proto3", 30 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 31 | serialized_pb=_b( 32 | '\nGmlagents/envs/communicator_objects/unity_rl_initialization_output.proto\x12\x14\x63ommunicator_objects\x1a?mlagents/envs/communicator_objects/brain_parameters_proto.proto\x1a\x45mlagents/envs/communicator_objects/environment_parameters_proto.proto"\xe6\x01\n\x1bUnityRLInitializationOutput\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x10\n\x08log_path\x18\x03 \x01(\t\x12\x44\n\x10\x62rain_parameters\x18\x05 \x03(\x0b\x32*.communicator_objects.BrainParametersProto\x12P\n\x16\x65nvironment_parameters\x18\x06 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 33 | ), 34 | dependencies=[ 35 | mlagents_dot_envs_dot_communicator__objects_dot_brain__parameters__proto__pb2.DESCRIPTOR, 36 | mlagents_dot_envs_dot_communicator__objects_dot_environment__parameters__proto__pb2.DESCRIPTOR, 37 | ], 38 | ) 39 | 40 | 41 | _UNITYRLINITIALIZATIONOUTPUT = _descriptor.Descriptor( 42 | name="UnityRLInitializationOutput", 43 | full_name="communicator_objects.UnityRLInitializationOutput", 44 | filename=None, 45 | file=DESCRIPTOR, 46 | containing_type=None, 47 | fields=[ 48 | _descriptor.FieldDescriptor( 49 | name="name", 50 | full_name="communicator_objects.UnityRLInitializationOutput.name", 51 | index=0, 52 | number=1, 53 | type=9, 54 | cpp_type=9, 55 | label=1, 56 | has_default_value=False, 57 | default_value=_b("").decode("utf-8"), 58 | message_type=None, 59 | enum_type=None, 60 | containing_type=None, 61 | is_extension=False, 62 | extension_scope=None, 63 | serialized_options=None, 64 | file=DESCRIPTOR, 65 | ), 66 | _descriptor.FieldDescriptor( 67 | name="version", 68 | full_name="communicator_objects.UnityRLInitializationOutput.version", 69 | index=1, 70 | 
number=2, 71 | type=9, 72 | cpp_type=9, 73 | label=1, 74 | has_default_value=False, 75 | default_value=_b("").decode("utf-8"), 76 | message_type=None, 77 | enum_type=None, 78 | containing_type=None, 79 | is_extension=False, 80 | extension_scope=None, 81 | serialized_options=None, 82 | file=DESCRIPTOR, 83 | ), 84 | _descriptor.FieldDescriptor( 85 | name="log_path", 86 | full_name="communicator_objects.UnityRLInitializationOutput.log_path", 87 | index=2, 88 | number=3, 89 | type=9, 90 | cpp_type=9, 91 | label=1, 92 | has_default_value=False, 93 | default_value=_b("").decode("utf-8"), 94 | message_type=None, 95 | enum_type=None, 96 | containing_type=None, 97 | is_extension=False, 98 | extension_scope=None, 99 | serialized_options=None, 100 | file=DESCRIPTOR, 101 | ), 102 | _descriptor.FieldDescriptor( 103 | name="brain_parameters", 104 | full_name="communicator_objects.UnityRLInitializationOutput.brain_parameters", 105 | index=3, 106 | number=5, 107 | type=11, 108 | cpp_type=10, 109 | label=3, 110 | has_default_value=False, 111 | default_value=[], 112 | message_type=None, 113 | enum_type=None, 114 | containing_type=None, 115 | is_extension=False, 116 | extension_scope=None, 117 | serialized_options=None, 118 | file=DESCRIPTOR, 119 | ), 120 | _descriptor.FieldDescriptor( 121 | name="environment_parameters", 122 | full_name="communicator_objects.UnityRLInitializationOutput.environment_parameters", 123 | index=4, 124 | number=6, 125 | type=11, 126 | cpp_type=10, 127 | label=1, 128 | has_default_value=False, 129 | default_value=None, 130 | message_type=None, 131 | enum_type=None, 132 | containing_type=None, 133 | is_extension=False, 134 | extension_scope=None, 135 | serialized_options=None, 136 | file=DESCRIPTOR, 137 | ), 138 | ], 139 | extensions=[], 140 | nested_types=[], 141 | enum_types=[], 142 | serialized_options=None, 143 | is_extendable=False, 144 | syntax="proto3", 145 | extension_ranges=[], 146 | oneofs=[], 147 | serialized_start=234, 148 | serialized_end=464, 149 | ) 150 | 151 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name[ 152 | "brain_parameters" 153 | ].message_type = ( 154 | mlagents_dot_envs_dot_communicator__objects_dot_brain__parameters__proto__pb2._BRAINPARAMETERSPROTO 155 | ) 156 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name[ 157 | "environment_parameters" 158 | ].message_type = ( 159 | mlagents_dot_envs_dot_communicator__objects_dot_environment__parameters__proto__pb2._ENVIRONMENTPARAMETERSPROTO 160 | ) 161 | DESCRIPTOR.message_types_by_name[ 162 | "UnityRLInitializationOutput" 163 | ] = _UNITYRLINITIALIZATIONOUTPUT 164 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 165 | 166 | UnityRLInitializationOutput = _reflection.GeneratedProtocolMessageType( 167 | "UnityRLInitializationOutput", 168 | (_message.Message,), 169 | dict( 170 | DESCRIPTOR=_UNITYRLINITIALIZATIONOUTPUT, 171 | __module__="mlagents.envs.communicator_objects.unity_rl_initialization_output_pb2" 172 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationOutput) 173 | ), 174 | ) 175 | _sym_db.RegisterMessage(UnityRLInitializationOutput) 176 | 177 | 178 | DESCRIPTOR._options = None 179 | # @@protoc_insertion_point(module_scope) 180 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_rl_output_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
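# [Editor's note -- annotation, not part of the generated file] UnityRLOutput
# (defined below) exposes agentInfos as a proto3 map<string, ListAgentInfoProto>
# keyed by brain name. A sketch of populating it, mirroring mock_communicator.py
# later in this repo; import path and field values are illustrative only.
from mlagents.envs.communicator_objects import UnityRLOutput, AgentInfoProto

rl_out = UnityRLOutput(global_done=False)
# Accessing a message-valued map key creates the entry; `value` is its repeated field.
rl_out.agentInfos["RealFakeBrain"].value.extend(
    [AgentInfoProto(reward=1.0, done=False, id=0)]
)
assert len(rl_out.agentInfos["RealFakeBrain"].value) == 1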
3 | # source: mlagents/envs/communicator_objects/unity_rl_output.proto 4 | 5 | import sys 6 | 7 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | 13 | # @@protoc_insertion_point(imports) 14 | 15 | _sym_db = _symbol_database.Default() 16 | 17 | 18 | from mlagents.envs.communicator_objects import ( 19 | agent_info_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_agent__info__proto__pb2, 20 | ) 21 | 22 | 23 | DESCRIPTOR = _descriptor.FileDescriptor( 24 | name="mlagents/envs/communicator_objects/unity_rl_output.proto", 25 | package="communicator_objects", 26 | syntax="proto3", 27 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 28 | serialized_pb=_b( 29 | '\n8mlagents/envs/communicator_objects/unity_rl_output.proto\x12\x14\x63ommunicator_objects\x1a\x39mlagents/envs/communicator_objects/agent_info_proto.proto"\xa3\x02\n\rUnityRLOutput\x12\x13\n\x0bglobal_done\x18\x01 \x01(\x08\x12G\n\nagentInfos\x18\x02 \x03(\x0b\x32\x33.communicator_objects.UnityRLOutput.AgentInfosEntry\x1aI\n\x12ListAgentInfoProto\x12\x33\n\x05value\x18\x01 \x03(\x0b\x32$.communicator_objects.AgentInfoProto\x1ai\n\x0f\x41gentInfosEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x45\n\x05value\x18\x02 \x01(\x0b\x32\x36.communicator_objects.UnityRLOutput.ListAgentInfoProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 30 | ), 31 | dependencies=[ 32 | mlagents_dot_envs_dot_communicator__objects_dot_agent__info__proto__pb2.DESCRIPTOR 33 | ], 34 | ) 35 | 36 | 37 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO = _descriptor.Descriptor( 38 | name="ListAgentInfoProto", 39 | full_name="communicator_objects.UnityRLOutput.ListAgentInfoProto", 40 | filename=None, 41 | file=DESCRIPTOR, 42 | containing_type=None, 43 | fields=[ 44 | _descriptor.FieldDescriptor( 45 | name="value", 46 | full_name="communicator_objects.UnityRLOutput.ListAgentInfoProto.value", 47 | index=0, 48 | number=1, 49 | type=11, 50 | cpp_type=10, 51 | label=3, 52 | has_default_value=False, 53 | default_value=[], 54 | message_type=None, 55 | enum_type=None, 56 | containing_type=None, 57 | is_extension=False, 58 | extension_scope=None, 59 | serialized_options=None, 60 | file=DESCRIPTOR, 61 | ) 62 | ], 63 | extensions=[], 64 | nested_types=[], 65 | enum_types=[], 66 | serialized_options=None, 67 | is_extendable=False, 68 | syntax="proto3", 69 | extension_ranges=[], 70 | oneofs=[], 71 | serialized_start=253, 72 | serialized_end=326, 73 | ) 74 | 75 | _UNITYRLOUTPUT_AGENTINFOSENTRY = _descriptor.Descriptor( 76 | name="AgentInfosEntry", 77 | full_name="communicator_objects.UnityRLOutput.AgentInfosEntry", 78 | filename=None, 79 | file=DESCRIPTOR, 80 | containing_type=None, 81 | fields=[ 82 | _descriptor.FieldDescriptor( 83 | name="key", 84 | full_name="communicator_objects.UnityRLOutput.AgentInfosEntry.key", 85 | index=0, 86 | number=1, 87 | type=9, 88 | cpp_type=9, 89 | label=1, 90 | has_default_value=False, 91 | default_value=_b("").decode("utf-8"), 92 | message_type=None, 93 | enum_type=None, 94 | containing_type=None, 95 | is_extension=False, 96 | extension_scope=None, 97 | serialized_options=None, 98 | file=DESCRIPTOR, 99 | ), 100 | _descriptor.FieldDescriptor( 101 | name="value", 102 | full_name="communicator_objects.UnityRLOutput.AgentInfosEntry.value", 103 | 
index=1, 104 | number=2, 105 | type=11, 106 | cpp_type=10, 107 | label=1, 108 | has_default_value=False, 109 | default_value=None, 110 | message_type=None, 111 | enum_type=None, 112 | containing_type=None, 113 | is_extension=False, 114 | extension_scope=None, 115 | serialized_options=None, 116 | file=DESCRIPTOR, 117 | ), 118 | ], 119 | extensions=[], 120 | nested_types=[], 121 | enum_types=[], 122 | serialized_options=_b("8\001"), 123 | is_extendable=False, 124 | syntax="proto3", 125 | extension_ranges=[], 126 | oneofs=[], 127 | serialized_start=328, 128 | serialized_end=433, 129 | ) 130 | 131 | _UNITYRLOUTPUT = _descriptor.Descriptor( 132 | name="UnityRLOutput", 133 | full_name="communicator_objects.UnityRLOutput", 134 | filename=None, 135 | file=DESCRIPTOR, 136 | containing_type=None, 137 | fields=[ 138 | _descriptor.FieldDescriptor( 139 | name="global_done", 140 | full_name="communicator_objects.UnityRLOutput.global_done", 141 | index=0, 142 | number=1, 143 | type=8, 144 | cpp_type=7, 145 | label=1, 146 | has_default_value=False, 147 | default_value=False, 148 | message_type=None, 149 | enum_type=None, 150 | containing_type=None, 151 | is_extension=False, 152 | extension_scope=None, 153 | serialized_options=None, 154 | file=DESCRIPTOR, 155 | ), 156 | _descriptor.FieldDescriptor( 157 | name="agentInfos", 158 | full_name="communicator_objects.UnityRLOutput.agentInfos", 159 | index=1, 160 | number=2, 161 | type=11, 162 | cpp_type=10, 163 | label=3, 164 | has_default_value=False, 165 | default_value=[], 166 | message_type=None, 167 | enum_type=None, 168 | containing_type=None, 169 | is_extension=False, 170 | extension_scope=None, 171 | serialized_options=None, 172 | file=DESCRIPTOR, 173 | ), 174 | ], 175 | extensions=[], 176 | nested_types=[_UNITYRLOUTPUT_LISTAGENTINFOPROTO, _UNITYRLOUTPUT_AGENTINFOSENTRY], 177 | enum_types=[], 178 | serialized_options=None, 179 | is_extendable=False, 180 | syntax="proto3", 181 | extension_ranges=[], 182 | oneofs=[], 183 | serialized_start=142, 184 | serialized_end=433, 185 | ) 186 | 187 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO.fields_by_name[ 188 | "value" 189 | ].message_type = ( 190 | mlagents_dot_envs_dot_communicator__objects_dot_agent__info__proto__pb2._AGENTINFOPROTO 191 | ) 192 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO.containing_type = _UNITYRLOUTPUT 193 | _UNITYRLOUTPUT_AGENTINFOSENTRY.fields_by_name[ 194 | "value" 195 | ].message_type = _UNITYRLOUTPUT_LISTAGENTINFOPROTO 196 | _UNITYRLOUTPUT_AGENTINFOSENTRY.containing_type = _UNITYRLOUTPUT 197 | _UNITYRLOUTPUT.fields_by_name[ 198 | "agentInfos" 199 | ].message_type = _UNITYRLOUTPUT_AGENTINFOSENTRY 200 | DESCRIPTOR.message_types_by_name["UnityRLOutput"] = _UNITYRLOUTPUT 201 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 202 | 203 | UnityRLOutput = _reflection.GeneratedProtocolMessageType( 204 | "UnityRLOutput", 205 | (_message.Message,), 206 | dict( 207 | ListAgentInfoProto=_reflection.GeneratedProtocolMessageType( 208 | "ListAgentInfoProto", 209 | (_message.Message,), 210 | dict( 211 | DESCRIPTOR=_UNITYRLOUTPUT_LISTAGENTINFOPROTO, 212 | __module__="mlagents.envs.communicator_objects.unity_rl_output_pb2" 213 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.ListAgentInfoProto) 214 | ), 215 | ), 216 | AgentInfosEntry=_reflection.GeneratedProtocolMessageType( 217 | "AgentInfosEntry", 218 | (_message.Message,), 219 | dict( 220 | DESCRIPTOR=_UNITYRLOUTPUT_AGENTINFOSENTRY, 221 | __module__="mlagents.envs.communicator_objects.unity_rl_output_pb2" 222 | # 
@@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.AgentInfosEntry) 223 | ), 224 | ), 225 | DESCRIPTOR=_UNITYRLOUTPUT, 226 | __module__="mlagents.envs.communicator_objects.unity_rl_output_pb2" 227 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput) 228 | ), 229 | ) 230 | _sym_db.RegisterMessage(UnityRLOutput) 231 | _sym_db.RegisterMessage(UnityRLOutput.ListAgentInfoProto) 232 | _sym_db.RegisterMessage(UnityRLOutput.AgentInfosEntry) 233 | 234 | 235 | DESCRIPTOR._options = None 236 | _UNITYRLOUTPUT_AGENTINFOSENTRY._options = None 237 | # @@protoc_insertion_point(module_scope) 238 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_to_external_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: mlagents/envs/communicator_objects/unity_to_external.proto 3 | 4 | import sys 5 | 6 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | from mlagents.envs.communicator_objects import ( 18 | unity_message_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2, 19 | ) 20 | 21 | 22 | DESCRIPTOR = _descriptor.FileDescriptor( 23 | name="mlagents/envs/communicator_objects/unity_to_external.proto", 24 | package="communicator_objects", 25 | syntax="proto3", 26 | serialized_options=_b("\252\002\034MLAgents.CommunicatorObjects"), 27 | serialized_pb=_b( 28 | '\n:mlagents/envs/communicator_objects/unity_to_external.proto\x12\x14\x63ommunicator_objects\x1a\x36mlagents/envs/communicator_objects/unity_message.proto2g\n\x0fUnityToExternal\x12T\n\x08\x45xchange\x12".communicator_objects.UnityMessage\x1a".communicator_objects.UnityMessage"\x00\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3' 29 | ), 30 | dependencies=[ 31 | mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2.DESCRIPTOR 32 | ], 33 | ) 34 | 35 | 36 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 37 | 38 | 39 | DESCRIPTOR._options = None 40 | 41 | _UNITYTOEXTERNAL = _descriptor.ServiceDescriptor( 42 | name="UnityToExternal", 43 | full_name="communicator_objects.UnityToExternal", 44 | file=DESCRIPTOR, 45 | index=0, 46 | serialized_options=None, 47 | serialized_start=140, 48 | serialized_end=243, 49 | methods=[ 50 | _descriptor.MethodDescriptor( 51 | name="Exchange", 52 | full_name="communicator_objects.UnityToExternal.Exchange", 53 | index=0, 54 | containing_service=None, 55 | input_type=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, 56 | output_type=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, 57 | serialized_options=None, 58 | ) 59 | ], 60 | ) 61 | _sym_db.RegisterServiceDescriptor(_UNITYTOEXTERNAL) 62 | 63 | DESCRIPTOR.services_by_name["UnityToExternal"] = _UNITYTOEXTERNAL 64 | 65 | # @@protoc_insertion_point(module_scope) 66 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_to_external_pb2_grpc.py: 
--------------------------------------------------------------------------------
1 | # Generated by the gRPC Python protocol compiler. DO NOT EDIT!
2 | import grpc
3 | 
4 | from mlagents.envs.communicator_objects import (
5 |     unity_message_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2,
6 | )
7 | 
8 | 
9 | class UnityToExternalStub(object):
10 |     # missing associated documentation comment in .proto file
11 |     pass
12 | 
13 |     def __init__(self, channel):
14 |         """Constructor.
15 | 
16 |         Args:
17 |             channel: A grpc.Channel.
18 |         """
19 |         self.Exchange = channel.unary_unary(
20 |             "/communicator_objects.UnityToExternal/Exchange",
21 |             request_serializer=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString,
22 |             response_deserializer=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessage.FromString,
23 |         )
24 | 
25 | 
26 | class UnityToExternalServicer(object):
27 |     # missing associated documentation comment in .proto file
28 |     pass
29 | 
30 |     def Exchange(self, request, context):
31 |         """Sends the academy parameters
32 |         """
33 |         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
34 |         context.set_details("Method not implemented!")
35 |         raise NotImplementedError("Method not implemented!")
36 | 
37 | 
38 | def add_UnityToExternalServicer_to_server(servicer, server):
39 |     rpc_method_handlers = {
40 |         "Exchange": grpc.unary_unary_rpc_method_handler(
41 |             servicer.Exchange,
42 |             request_deserializer=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessage.FromString,
43 |             response_serializer=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString,
44 |         )
45 |     }
46 |     generic_handler = grpc.method_handlers_generic_handler(
47 |         "communicator_objects.UnityToExternal", rpc_method_handlers
48 |     )
49 |     server.add_generic_rpc_handlers((generic_handler,))
50 | 
--------------------------------------------------------------------------------
/mlagents/envs/exception.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | logger = logging.getLogger("mlagents.envs")
4 | 
5 | 
6 | class UnityException(Exception):
7 |     """
8 |     Any error related to ml-agents environment.
9 |     """
10 | 
11 |     pass
12 | 
13 | 
14 | class UnityEnvironmentException(UnityException):
15 |     """
16 |     Related to errors starting and closing environment.
17 |     """
18 | 
19 |     pass
20 | 
21 | 
22 | class UnityActionException(UnityException):
23 |     """
24 |     Related to errors with sending actions.
25 |     """
26 | 
27 |     pass
28 | 
29 | 
30 | class UnityTimeOutException(UnityException):
31 |     """
32 |     Related to errors with communication timeouts.
33 |     """
34 | 
35 |     def __init__(self, message, log_file_path=None):
36 |         if log_file_path is not None:
37 |             try:
38 |                 with open(log_file_path, "r") as f:
39 |                     printing = False
40 |                     unity_error = "\n"
41 |                     for l in f:
42 |                         l = l.strip()
43 |                         if (l == "Exception") or (l == "Error"):
44 |                             printing = True
45 |                             unity_error += "----------------------\n"
46 |                         if l == "":
47 |                             printing = False
48 |                         if printing:
49 |                             unity_error += l + "\n"
50 |                 logger.info(unity_error)
51 |                 logger.error(
52 |                     "An error might have occurred in the environment. "
53 |                     "You can check the logfile for more information at {}".format(
54 |                         log_file_path
55 |                     )
56 |                 )
57 |             except:
58 |                 logger.error(
59 |                     "An error might have occurred in the environment. "
60 |                     "No UnitySDK.log file could be found."
61 |                 )
62 |         super(UnityTimeOutException, self).__init__(message)
63 | 
64 | 
65 | class UnityWorkerInUseException(UnityException):
66 |     """
67 |     This error occurs when the port for a certain worker ID is already reserved.
68 |     """
69 | 
70 |     MESSAGE_TEMPLATE = (
71 |         "Couldn't start socket communication because worker number {} is still in use. "
72 |         "You may need to manually close a previously opened environment "
73 |         "or use a different worker number."
74 |     )
75 | 
76 |     def __init__(self, worker_id):
77 |         message = self.MESSAGE_TEMPLATE.format(str(worker_id))
78 |         super(UnityWorkerInUseException, self).__init__(message)
79 | 
--------------------------------------------------------------------------------
/mlagents/envs/mock_communicator.py:
--------------------------------------------------------------------------------
1 | from .communicator import Communicator
2 | from .communicator_objects import (
3 |     UnityOutput,
4 |     UnityInput,
5 |     ResolutionProto,
6 |     BrainParametersProto,
7 |     UnityRLInitializationOutput,
8 |     AgentInfoProto,
9 |     UnityRLOutput,
10 | )
11 | 
12 | 
13 | class MockCommunicator(Communicator):
14 |     def __init__(
15 |         self,
16 |         discrete_action=False,
17 |         visual_inputs=0,
18 |         stack=True,
19 |         num_agents=3,
20 |         brain_name="RealFakeBrain",
21 |         vec_obs_size=3,
22 |     ):
23 |         """
24 |         Mock of the Python side of the grpc communication, used in tests in place of a live Unity process.
25 | 
26 |         :bool discrete_action: Whether the fake brain reports a discrete action space.
27 |         :int visual_inputs: Number of mocked camera observations to report.
28 |         """
29 |         self.is_discrete = discrete_action
30 |         self.steps = 0
31 |         self.visual_inputs = visual_inputs
32 |         self.has_been_closed = False
33 |         self.num_agents = num_agents
34 |         self.brain_name = brain_name
35 |         self.vec_obs_size = vec_obs_size
36 |         if stack:
37 |             self.num_stacks = 2
38 |         else:
39 |             self.num_stacks = 1
40 | 
41 |     def initialize(self, inputs: UnityInput) -> UnityOutput:
42 |         resolutions = [
43 |             ResolutionProto(width=30, height=40, gray_scale=False)
44 |             for i in range(self.visual_inputs)
45 |         ]
46 |         bp = BrainParametersProto(
47 |             vector_observation_size=self.vec_obs_size,
48 |             num_stacked_vector_observations=self.num_stacks,
49 |             vector_action_size=[2],
50 |             camera_resolutions=resolutions,
51 |             vector_action_descriptions=["", ""],
52 |             vector_action_space_type=int(not self.is_discrete),
53 |             brain_name=self.brain_name,
54 |             is_training=True,
55 |         )
56 |         rl_init = UnityRLInitializationOutput(
57 |             name="RealFakeAcademy", version="API-8", log_path="", brain_parameters=[bp]
58 |         )
59 |         return UnityOutput(rl_initialization_output=rl_init)
60 | 
61 |     def exchange(self, inputs: UnityInput) -> UnityOutput:
62 |         dict_agent_info = {}
63 |         if self.is_discrete:
64 |             vector_action = [1]
65 |         else:
66 |             vector_action = [1, 2]
67 |         list_agent_info = []
68 |         if self.num_stacks == 1:
69 |             observation = [1, 2, 3]
70 |         else:
71 |             observation = [1, 2, 3, 1, 2, 3]
72 | 
73 |         for i in range(self.num_agents):
74 |             list_agent_info.append(
75 |                 AgentInfoProto(
76 |                     stacked_vector_observation=observation,
77 |                     reward=1,
78 |                     stored_vector_actions=vector_action,
79 |                     stored_text_actions="",
80 |                     text_observation="",
81 |                     memories=[],
82 |                     done=(i == 2),
83 |                     max_step_reached=False,
84 |                     id=i,
85 |                 )
86 |             )
87 |         dict_agent_info["RealFakeBrain"] = UnityRLOutput.ListAgentInfoProto(
88 |             value=list_agent_info
89 |         )
90 |         global_done = False
91 |         try:
92 |             fake_brain =
inputs.rl_input.agent_actions["RealFakeBrain"] 93 | global_done = fake_brain.value[0].vector_actions[0] == -1 94 | except: 95 | pass 96 | result = UnityRLOutput(global_done=global_done, agentInfos=dict_agent_info) 97 | return UnityOutput(rl_output=result) 98 | 99 | def close(self): 100 | """ 101 | Sends a shutdown signal to the unity environment, and closes the grpc connection. 102 | """ 103 | self.has_been_closed = True 104 | -------------------------------------------------------------------------------- /mlagents/envs/rpc_communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import grpc 3 | 4 | import socket 5 | from multiprocessing import Pipe 6 | from concurrent.futures import ThreadPoolExecutor 7 | 8 | from .communicator import Communicator 9 | from .communicator_objects import ( 10 | UnityToExternalServicer, 11 | add_UnityToExternalServicer_to_server, 12 | ) 13 | from .communicator_objects import UnityMessage, UnityInput, UnityOutput 14 | from .exception import UnityTimeOutException, UnityWorkerInUseException 15 | 16 | logger = logging.getLogger("mlagents.envs") 17 | 18 | 19 | class UnityToExternalServicerImplementation(UnityToExternalServicer): 20 | def __init__(self): 21 | self.parent_conn, self.child_conn = Pipe() 22 | 23 | def Initialize(self, request, context): 24 | self.child_conn.send(request) 25 | return self.child_conn.recv() 26 | 27 | def Exchange(self, request, context): 28 | self.child_conn.send(request) 29 | return self.child_conn.recv() 30 | 31 | 32 | class RpcCommunicator(Communicator): 33 | def __init__(self, worker_id=0, base_port=5005, timeout_wait=30): 34 | """ 35 | Python side of the grpc communication. Python is the server and Unity the client 36 | 37 | 38 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 39 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 40 | """ 41 | self.port = base_port + worker_id 42 | self.worker_id = worker_id 43 | self.timeout_wait = timeout_wait 44 | self.server = None 45 | self.unity_to_external = None 46 | self.is_open = False 47 | self.create_server() 48 | 49 | def create_server(self): 50 | """ 51 | Creates the GRPC server. 52 | """ 53 | self.check_port(self.port) 54 | 55 | try: 56 | # Establish communication grpc 57 | self.server = grpc.server(ThreadPoolExecutor(max_workers=10)) 58 | self.unity_to_external = UnityToExternalServicerImplementation() 59 | add_UnityToExternalServicer_to_server(self.unity_to_external, self.server) 60 | # Using unspecified address, which means that grpc is communicating on all IPs 61 | # This is so that the docker container can connect. 62 | self.server.add_insecure_port("[::]:" + str(self.port)) 63 | self.server.start() 64 | self.is_open = True 65 | except: 66 | raise UnityWorkerInUseException(self.worker_id) 67 | 68 | def check_port(self, port): 69 | """ 70 | Attempts to bind to the requested communicator port, checking if it is already in use. 71 | """ 72 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 73 | try: 74 | s.bind(("localhost", port)) 75 | except socket.error: 76 | raise UnityWorkerInUseException(self.worker_id) 77 | finally: 78 | s.close() 79 | 80 | def initialize(self, inputs: UnityInput) -> UnityOutput: 81 | if not self.unity_to_external.parent_conn.poll(self.timeout_wait): 82 | raise UnityTimeOutException( 83 | "The Unity environment took too long to respond. 
Make sure that :\n" 84 | "\t The environment does not need user interaction to launch\n" 85 | "\t The Academy's Broadcast Hub is configured correctly\n" 86 | "\t The Agents are linked to the appropriate Brains\n" 87 | "\t The environment and the Python interface have compatible versions." 88 | ) 89 | aca_param = self.unity_to_external.parent_conn.recv().unity_output 90 | message = UnityMessage() 91 | message.header.status = 200 92 | message.unity_input.CopyFrom(inputs) 93 | self.unity_to_external.parent_conn.send(message) 94 | self.unity_to_external.parent_conn.recv() 95 | return aca_param 96 | 97 | def exchange(self, inputs: UnityInput) -> UnityOutput: 98 | message = UnityMessage() 99 | message.header.status = 200 100 | message.unity_input.CopyFrom(inputs) 101 | self.unity_to_external.parent_conn.send(message) 102 | output = self.unity_to_external.parent_conn.recv() 103 | if output.header.status != 200: 104 | return None 105 | return output.unity_output 106 | 107 | def close(self): 108 | """ 109 | Sends a shutdown signal to the unity environment, and closes the grpc connection. 110 | """ 111 | if self.is_open: 112 | message_input = UnityMessage() 113 | message_input.header.status = 400 114 | self.unity_to_external.parent_conn.send(message_input) 115 | self.unity_to_external.parent_conn.close() 116 | self.server.stop(False) 117 | self.is_open = False 118 | -------------------------------------------------------------------------------- /mlagents/envs/socket_communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import socket 3 | import struct 4 | 5 | from .communicator import Communicator 6 | from .communicator_objects import UnityMessage, UnityOutput, UnityInput 7 | from .exception import UnityTimeOutException 8 | 9 | 10 | logger = logging.getLogger("mlagents.envs") 11 | 12 | 13 | class SocketCommunicator(Communicator): 14 | def __init__(self, worker_id=0, base_port=5005): 15 | """ 16 | Python side of the socket communication 17 | 18 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 19 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 20 | """ 21 | 22 | self.port = base_port + worker_id 23 | self._buffer_size = 12000 24 | self.worker_id = worker_id 25 | self._socket = None 26 | self._conn = None 27 | 28 | def initialize(self, inputs: UnityInput) -> UnityOutput: 29 | try: 30 | # Establish communication socket 31 | self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 32 | self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 33 | self._socket.bind(("localhost", self.port)) 34 | except: 35 | raise UnityTimeOutException( 36 | "Couldn't start socket communication because worker number {} is still in use. " 37 | "You may need to manually close a previously opened environment " 38 | "or use a different worker number.".format(str(self.worker_id)) 39 | ) 40 | try: 41 | self._socket.settimeout(30) 42 | self._socket.listen(1) 43 | self._conn, _ = self._socket.accept() 44 | self._conn.settimeout(30) 45 | except: 46 | raise UnityTimeOutException( 47 | "The Unity environment took too long to respond. Make sure that :\n" 48 | "\t The environment does not need user interaction to launch\n" 49 | "\t The Academy's Broadcast Hub is configured correctly\n" 50 | "\t The Agents are linked to the appropriate Brains\n" 51 | "\t The environment and the Python interface have compatible versions." 
52 |             )
53 |         message = UnityMessage()
54 |         message.header.status = 200
55 |         message.unity_input.CopyFrom(inputs)
56 |         self._communicator_send(message.SerializeToString())
57 |         initialization_output = UnityMessage()
58 |         initialization_output.ParseFromString(self._communicator_receive())
59 |         return initialization_output.unity_output
60 | 
61 |     def _communicator_receive(self):
62 |         try:
63 |             s = self._conn.recv(self._buffer_size)
64 |             message_length = struct.unpack("I", bytearray(s[:4]))[0]
65 |             s = s[4:]
66 |             while len(s) != message_length:
67 |                 s += self._conn.recv(self._buffer_size)
68 |         except socket.timeout:
69 |             raise UnityTimeOutException("The environment took too long to respond.")
70 |         return s
71 | 
72 |     def _communicator_send(self, message):
73 |         self._conn.send(struct.pack("I", len(message)) + message)
74 | 
75 |     def exchange(self, inputs: UnityInput) -> UnityOutput:
76 |         message = UnityMessage()
77 |         message.header.status = 200
78 |         message.unity_input.CopyFrom(inputs)
79 |         self._communicator_send(message.SerializeToString())
80 |         outputs = UnityMessage()
81 |         outputs.ParseFromString(self._communicator_receive())
82 |         if outputs.header.status != 200:
83 |             return None
84 |         return outputs.unity_output
85 | 
86 |     def close(self):
87 |         """
88 |         Sends a shutdown signal to the unity environment, and closes the socket connection.
89 |         """
90 |         if self._socket is not None and self._conn is not None:
91 |             message_input = UnityMessage()
92 |             message_input.header.status = 400
93 |             self._communicator_send(message_input.SerializeToString())
94 |         if self._socket is not None:
95 |             self._socket.close()
96 |             self._socket = None
97 |         if self._conn is not None:
98 |             self._conn.close()
99 |             self._conn = None
100 | 
--------------------------------------------------------------------------------
/mlagents/envs/subprocess_environment.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import copy
3 | import numpy as np
4 | import cloudpickle
5 | 
6 | from mlagents.envs import UnityEnvironment
7 | from multiprocessing import Process, Pipe
8 | from multiprocessing.connection import Connection
9 | from mlagents.envs.base_unity_environment import BaseUnityEnvironment
10 | from mlagents.envs import AllBrainInfo, UnityEnvironmentException
11 | 
12 | 
13 | class EnvironmentCommand(NamedTuple):
14 |     name: str
15 |     payload: Any = None
16 | 
17 | 
18 | class EnvironmentResponse(NamedTuple):
19 |     name: str
20 |     worker_id: int
21 |     payload: Any
22 | 
23 | 
24 | class UnityEnvWorker(NamedTuple):
25 |     process: Process
26 |     worker_id: int
27 |     conn: Connection
28 | 
29 |     def send(self, name: str, payload=None):
30 |         try:
31 |             cmd = EnvironmentCommand(name, payload)
32 |             self.conn.send(cmd)
33 |         except (BrokenPipeError, EOFError):
34 |             raise KeyboardInterrupt
35 | 
36 |     def recv(self) -> EnvironmentResponse:
37 |         try:
38 |             response: EnvironmentResponse = self.conn.recv()
39 |             return response
40 |         except (BrokenPipeError, EOFError):
41 |             raise KeyboardInterrupt
42 | 
43 |     def close(self):
44 |         try:
45 |             self.conn.send(EnvironmentCommand("close"))
46 |         except (BrokenPipeError, EOFError):
47 |             pass
48 |         self.process.join()
49 | 
50 | 
51 | def worker(parent_conn: Connection, pickled_env_factory: str, worker_id: int):
52 |     env_factory: Callable[[int], UnityEnvironment] = cloudpickle.loads(
53 |         pickled_env_factory
54 |     )
55 |     env = env_factory(worker_id)
56 | 
57 |     def _send_response(cmd_name, payload):
58 |         parent_conn.send(EnvironmentResponse(cmd_name, worker_id,
payload))
59 | 
60 |     try:
61 |         while True:
62 |             cmd: EnvironmentCommand = parent_conn.recv()
63 |             if cmd.name == "step":
64 |                 vector_action, memory, text_action, value = cmd.payload
65 |                 if env.global_done:
66 |                     all_brain_info = env.reset()
67 |                 else:
68 |                     all_brain_info = env.step(vector_action, memory, text_action, value)
69 |                 _send_response("step", all_brain_info)
70 |             elif cmd.name == "external_brains":
71 |                 _send_response("external_brains", env.external_brains)
72 |             elif cmd.name == "reset_parameters":
73 |                 _send_response("reset_parameters", env.reset_parameters)
74 |             elif cmd.name == "reset":
75 |                 all_brain_info = env.reset(cmd.payload[0], cmd.payload[1])
76 |                 _send_response("reset", all_brain_info)
77 |             elif cmd.name == "global_done":
78 |                 _send_response("global_done", env.global_done)
79 |             elif cmd.name == "close":
80 |                 break
81 |     except KeyboardInterrupt:
82 |         print("UnityEnvironment worker: keyboard interrupt")
83 |     finally:
84 |         env.close()
85 | 
86 | 
87 | class SubprocessUnityEnvironment(BaseUnityEnvironment):
88 |     def __init__(
89 |         self, env_factory: Callable[[int], BaseUnityEnvironment], n_env: int = 1
90 |     ):
91 |         self.envs = []
92 |         self.env_agent_counts = {}
93 |         self.waiting = False
94 |         for worker_id in range(n_env):
95 |             self.envs.append(self.create_worker(worker_id, env_factory))
96 | 
97 |     @staticmethod
98 |     def create_worker(
99 |         worker_id: int, env_factory: Callable[[int], BaseUnityEnvironment]
100 |     ) -> UnityEnvWorker:
101 |         parent_conn, child_conn = Pipe()
102 | 
103 |         # Need to use cloudpickle for the env factory function since function objects aren't picklable
104 |         # on Windows as of Python 3.6.
105 |         pickled_env_factory = cloudpickle.dumps(env_factory)
106 |         child_process = Process(
107 |             target=worker, args=(child_conn, pickled_env_factory, worker_id)
108 |         )
109 |         child_process.start()
110 |         return UnityEnvWorker(child_process, worker_id, parent_conn)
111 | 
112 |     def step_async(
113 |         self, vector_action, memory=None, text_action=None, value=None
114 |     ) -> None:
115 |         if self.waiting:
116 |             raise UnityEnvironmentException(
117 |                 "Tried to take an environment step before the previous step has completed."
118 |             )
119 | 
120 |         agent_counts_cum = {}
121 |         for brain_name in self.env_agent_counts.keys():
122 |             agent_counts_cum[brain_name] = np.cumsum(self.env_agent_counts[brain_name])
123 | 
124 |         # Split the actions provided by the previous set of agent counts, and send the step
125 |         # commands to the workers.
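        # [Editor's worked example] With self.env_agent_counts = {"MockBrain": [1, 3, 5]}
        # (the fixture used in test_subprocess_unity_environment.py below),
        # agent_counts_cum["MockBrain"] == [1, 4, 9]; the start_ind/end_ind slices
        # computed below then give worker 0 actions[0:1], worker 1 actions[1:4],
        # and worker 2 actions[4:9].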
126 | for worker_id, env in enumerate(self.envs): 127 | env_actions = {} 128 | env_memory = {} 129 | env_text_action = {} 130 | env_value = {} 131 | for brain_name in self.env_agent_counts.keys(): 132 | start_ind = 0 133 | if worker_id > 0: 134 | start_ind = agent_counts_cum[brain_name][worker_id - 1] 135 | end_ind = agent_counts_cum[brain_name][worker_id] 136 | if vector_action.get(brain_name) is not None: 137 | env_actions[brain_name] = vector_action[brain_name][ 138 | start_ind:end_ind 139 | ] 140 | if memory and memory.get(brain_name) is not None: 141 | env_memory[brain_name] = memory[brain_name][start_ind:end_ind] 142 | if text_action and text_action.get(brain_name) is not None: 143 | env_text_action[brain_name] = text_action[brain_name][ 144 | start_ind:end_ind 145 | ] 146 | if value and value.get(brain_name) is not None: 147 | env_value[brain_name] = value[brain_name][start_ind:end_ind] 148 | 149 | env.send("step", (env_actions, env_memory, env_text_action, env_value)) 150 | self.waiting = True 151 | 152 | def step_await(self) -> AllBrainInfo: 153 | if not self.waiting: 154 | raise UnityEnvironmentException( 155 | "Tried to await an environment step, but no async step was taken." 156 | ) 157 | 158 | steps = [self.envs[i].recv() for i in range(len(self.envs))] 159 | self._get_agent_counts(map(lambda s: s.payload, steps)) 160 | combined_brain_info = self._merge_step_info(steps) 161 | self.waiting = False 162 | return combined_brain_info 163 | 164 | def step( 165 | self, vector_action=None, memory=None, text_action=None, value=None 166 | ) -> AllBrainInfo: 167 | self.step_async(vector_action, memory, text_action, value) 168 | return self.step_await() 169 | 170 | def reset(self, config=None, train_mode=True) -> AllBrainInfo: 171 | self._broadcast_message("reset", (config, train_mode)) 172 | reset_results = [self.envs[i].recv() for i in range(len(self.envs))] 173 | self._get_agent_counts(map(lambda r: r.payload, reset_results)) 174 | 175 | return self._merge_step_info(reset_results) 176 | 177 | @property 178 | def global_done(self): 179 | self._broadcast_message("global_done") 180 | dones: List[EnvironmentResponse] = [ 181 | self.envs[i].recv().payload for i in range(len(self.envs)) 182 | ] 183 | return all(dones) 184 | 185 | @property 186 | def external_brains(self): 187 | self.envs[0].send("external_brains") 188 | return self.envs[0].recv().payload 189 | 190 | @property 191 | def reset_parameters(self): 192 | self.envs[0].send("reset_parameters") 193 | return self.envs[0].recv().payload 194 | 195 | def close(self): 196 | for env in self.envs: 197 | env.close() 198 | 199 | def _get_agent_counts(self, step_list: Iterable[AllBrainInfo]): 200 | for i, step in enumerate(step_list): 201 | for brain_name, brain_info in step.items(): 202 | if brain_name not in self.env_agent_counts.keys(): 203 | self.env_agent_counts[brain_name] = [0] * len(self.envs) 204 | self.env_agent_counts[brain_name][i] = len(brain_info.agents) 205 | 206 | @staticmethod 207 | def _merge_step_info(env_steps: List[EnvironmentResponse]) -> AllBrainInfo: 208 | accumulated_brain_info: AllBrainInfo = None 209 | for env_step in env_steps: 210 | all_brain_info: AllBrainInfo = env_step.payload 211 | for brain_name, brain_info in all_brain_info.items(): 212 | for i in range(len(brain_info.agents)): 213 | brain_info.agents[i] = ( 214 | str(env_step.worker_id) + "-" + str(brain_info.agents[i]) 215 | ) 216 | if accumulated_brain_info: 217 | accumulated_brain_info[brain_name].merge(brain_info) 218 | if not accumulated_brain_info: 219 
| accumulated_brain_info = copy.deepcopy(all_brain_info) 220 | return accumulated_brain_info 221 | 222 | def _broadcast_message(self, name: str, payload=None): 223 | for env in self.envs: 224 | env.send(name, payload) 225 | -------------------------------------------------------------------------------- /mlagents/envs/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueFisher/RL-PPO-with-Unity/7a51787c35dbeb39dfc8835f7eed189eb28e0e59/mlagents/envs/tests/__init__.py -------------------------------------------------------------------------------- /mlagents/envs/tests/test_envs.py: -------------------------------------------------------------------------------- 1 | import unittest.mock as mock 2 | import pytest 3 | 4 | import numpy as np 5 | 6 | from mlagents.envs import ( 7 | UnityEnvironment, 8 | UnityEnvironmentException, 9 | UnityActionException, 10 | BrainInfo, 11 | ) 12 | from mlagents.envs.mock_communicator import MockCommunicator 13 | 14 | 15 | @mock.patch("mlagents.envs.UnityEnvironment.get_communicator") 16 | def test_handles_bad_filename(get_communicator): 17 | with pytest.raises(UnityEnvironmentException): 18 | UnityEnvironment(" ") 19 | 20 | 21 | @mock.patch("mlagents.envs.UnityEnvironment.executable_launcher") 22 | @mock.patch("mlagents.envs.UnityEnvironment.get_communicator") 23 | def test_initialization(mock_communicator, mock_launcher): 24 | mock_communicator.return_value = MockCommunicator( 25 | discrete_action=False, visual_inputs=0 26 | ) 27 | env = UnityEnvironment(" ") 28 | with pytest.raises(UnityActionException): 29 | env.step([0]) 30 | assert env.brain_names[0] == "RealFakeBrain" 31 | env.close() 32 | 33 | 34 | @mock.patch("mlagents.envs.UnityEnvironment.executable_launcher") 35 | @mock.patch("mlagents.envs.UnityEnvironment.get_communicator") 36 | def test_reset(mock_communicator, mock_launcher): 37 | mock_communicator.return_value = MockCommunicator( 38 | discrete_action=False, visual_inputs=0 39 | ) 40 | env = UnityEnvironment(" ") 41 | brain = env.brains["RealFakeBrain"] 42 | brain_info = env.reset() 43 | env.close() 44 | assert not env.global_done 45 | assert isinstance(brain_info, dict) 46 | assert isinstance(brain_info["RealFakeBrain"], BrainInfo) 47 | assert isinstance(brain_info["RealFakeBrain"].visual_observations, list) 48 | assert isinstance(brain_info["RealFakeBrain"].vector_observations, np.ndarray) 49 | assert ( 50 | len(brain_info["RealFakeBrain"].visual_observations) 51 | == brain.number_visual_observations 52 | ) 53 | assert len(brain_info["RealFakeBrain"].vector_observations) == len( 54 | brain_info["RealFakeBrain"].agents 55 | ) 56 | assert ( 57 | len(brain_info["RealFakeBrain"].vector_observations[0]) 58 | == brain.vector_observation_space_size * brain.num_stacked_vector_observations 59 | ) 60 | 61 | 62 | @mock.patch("mlagents.envs.UnityEnvironment.executable_launcher") 63 | @mock.patch("mlagents.envs.UnityEnvironment.get_communicator") 64 | def test_step(mock_communicator, mock_launcher): 65 | mock_communicator.return_value = MockCommunicator( 66 | discrete_action=False, visual_inputs=0 67 | ) 68 | env = UnityEnvironment(" ") 69 | brain = env.brains["RealFakeBrain"] 70 | brain_info = env.reset() 71 | brain_info = env.step( 72 | [0] 73 | * brain.vector_action_space_size[0] 74 | * len(brain_info["RealFakeBrain"].agents) 75 | ) 76 | with pytest.raises(UnityActionException): 77 | env.step([0]) 78 | brain_info = env.step( 79 | [-1] 80 | * 
brain.vector_action_space_size[0] 81 | * len(brain_info["RealFakeBrain"].agents) 82 | ) 83 | with pytest.raises(UnityActionException): 84 | env.step( 85 | [0] 86 | * brain.vector_action_space_size[0] 87 | * len(brain_info["RealFakeBrain"].agents) 88 | ) 89 | env.close() 90 | assert env.global_done 91 | assert isinstance(brain_info, dict) 92 | assert isinstance(brain_info["RealFakeBrain"], BrainInfo) 93 | assert isinstance(brain_info["RealFakeBrain"].visual_observations, list) 94 | assert isinstance(brain_info["RealFakeBrain"].vector_observations, np.ndarray) 95 | assert ( 96 | len(brain_info["RealFakeBrain"].visual_observations) 97 | == brain.number_visual_observations 98 | ) 99 | assert len(brain_info["RealFakeBrain"].vector_observations) == len( 100 | brain_info["RealFakeBrain"].agents 101 | ) 102 | assert ( 103 | len(brain_info["RealFakeBrain"].vector_observations[0]) 104 | == brain.vector_observation_space_size * brain.num_stacked_vector_observations 105 | ) 106 | 107 | print("\n\n\n\n\n\n\n" + str(brain_info["RealFakeBrain"].local_done)) 108 | assert not brain_info["RealFakeBrain"].local_done[0] 109 | assert brain_info["RealFakeBrain"].local_done[2] 110 | 111 | 112 | @mock.patch("mlagents.envs.UnityEnvironment.executable_launcher") 113 | @mock.patch("mlagents.envs.UnityEnvironment.get_communicator") 114 | def test_close(mock_communicator, mock_launcher): 115 | comm = MockCommunicator(discrete_action=False, visual_inputs=0) 116 | mock_communicator.return_value = comm 117 | env = UnityEnvironment(" ") 118 | assert env._loaded 119 | env.close() 120 | assert not env._loaded 121 | assert comm.has_been_closed 122 | 123 | 124 | if __name__ == "__main__": 125 | pytest.main() 126 | -------------------------------------------------------------------------------- /mlagents/envs/tests/test_rpc_communicator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from mlagents.envs import RpcCommunicator 4 | from mlagents.envs import UnityWorkerInUseException 5 | 6 | 7 | def test_rpc_communicator_checks_port_on_create(): 8 | first_comm = RpcCommunicator() 9 | with pytest.raises(UnityWorkerInUseException): 10 | second_comm = RpcCommunicator() 11 | second_comm.close() 12 | first_comm.close() 13 | 14 | 15 | def test_rpc_communicator_close(): 16 | # Ensures it is possible to open a new RPC Communicators 17 | # after closing one on the same worker_id 18 | first_comm = RpcCommunicator() 19 | first_comm.close() 20 | second_comm = RpcCommunicator() 21 | second_comm.close() 22 | 23 | 24 | def test_rpc_communicator_create_multiple_workers(): 25 | # Ensures multiple RPC communicators can be created with 26 | # different worker_ids without causing an error. 
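    # [Editor's note] RpcCommunicator binds port = base_port + worker_id, with
    # base_port defaulting to 5005 (see rpc_communicator.py above), so the two
    # communicators below listen on 5005 and 5006 and therefore do not collide.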
27 | first_comm = RpcCommunicator() 28 | second_comm = RpcCommunicator(worker_id=1) 29 | first_comm.close() 30 | second_comm.close() 31 | -------------------------------------------------------------------------------- /mlagents/envs/tests/test_subprocess_unity_environment.py: -------------------------------------------------------------------------------- 1 | import unittest.mock as mock 2 | from unittest.mock import Mock, MagicMock 3 | import unittest 4 | 5 | from mlagents.envs.subprocess_environment import * 6 | from mlagents.envs import UnityEnvironmentException, BrainInfo 7 | 8 | 9 | def mock_env_factory(worker_id: int): 10 | return mock.create_autospec(spec=BaseUnityEnvironment) 11 | 12 | 13 | class MockEnvWorker: 14 | def __init__(self, worker_id): 15 | self.worker_id = worker_id 16 | self.process = None 17 | self.conn = None 18 | self.send = MagicMock() 19 | self.recv = MagicMock() 20 | 21 | 22 | class SubprocessEnvironmentTest(unittest.TestCase): 23 | def test_environments_are_created(self): 24 | SubprocessUnityEnvironment.create_worker = MagicMock() 25 | env = SubprocessUnityEnvironment(mock_env_factory, 2) 26 | # Creates two processes 27 | self.assertEqual( 28 | env.create_worker.call_args_list, 29 | [mock.call(0, mock_env_factory), mock.call(1, mock_env_factory)], 30 | ) 31 | self.assertEqual(len(env.envs), 2) 32 | 33 | def test_step_async_fails_when_waiting(self): 34 | env = SubprocessUnityEnvironment(mock_env_factory, 0) 35 | env.waiting = True 36 | with self.assertRaises(UnityEnvironmentException): 37 | env.step_async(vector_action=[]) 38 | 39 | @staticmethod 40 | def test_step_async_splits_input_by_agent_count(): 41 | env = SubprocessUnityEnvironment(mock_env_factory, 0) 42 | env.env_agent_counts = {"MockBrain": [1, 3, 5]} 43 | env.envs = [MockEnvWorker(0), MockEnvWorker(1), MockEnvWorker(2)] 44 | env_0_actions = [[1.0, 2.0]] 45 | env_1_actions = [[3.0, 4.0]] * 3 46 | env_2_actions = [[5.0, 6.0]] * 5 47 | vector_action = {"MockBrain": env_0_actions + env_1_actions + env_2_actions} 48 | env.step_async(vector_action=vector_action) 49 | env.envs[0].send.assert_called_with( 50 | "step", ({"MockBrain": env_0_actions}, {}, {}, {}) 51 | ) 52 | env.envs[1].send.assert_called_with( 53 | "step", ({"MockBrain": env_1_actions}, {}, {}, {}) 54 | ) 55 | env.envs[2].send.assert_called_with( 56 | "step", ({"MockBrain": env_2_actions}, {}, {}, {}) 57 | ) 58 | 59 | def test_step_async_sets_waiting(self): 60 | env = SubprocessUnityEnvironment(mock_env_factory, 0) 61 | env.step_async(vector_action=[]) 62 | self.assertTrue(env.waiting) 63 | 64 | def test_step_await_fails_if_not_waiting(self): 65 | env = SubprocessUnityEnvironment(mock_env_factory, 0) 66 | with self.assertRaises(UnityEnvironmentException): 67 | env.step_await() 68 | 69 | def test_step_await_combines_brain_info(self): 70 | all_brain_info_env0 = { 71 | "MockBrain": BrainInfo( 72 | [], [[1.0, 2.0], [1.0, 2.0]], [], agents=[1, 2], memory=np.zeros((0, 0)) 73 | ) 74 | } 75 | all_brain_info_env1 = { 76 | "MockBrain": BrainInfo( 77 | [], [[3.0, 4.0]], [], agents=[3], memory=np.zeros((0, 0)) 78 | ) 79 | } 80 | env_worker_0 = MockEnvWorker(0) 81 | env_worker_0.recv.return_value = EnvironmentResponse( 82 | "step", 0, all_brain_info_env0 83 | ) 84 | env_worker_1 = MockEnvWorker(1) 85 | env_worker_1.recv.return_value = EnvironmentResponse( 86 | "step", 1, all_brain_info_env1 87 | ) 88 | env = SubprocessUnityEnvironment(mock_env_factory, 0) 89 | env.envs = [env_worker_0, env_worker_1] 90 | env.waiting = True 91 | combined_braininfo = 
env.step_await()["MockBrain"] 92 | self.assertEqual( 93 | combined_braininfo.vector_observations.tolist(), 94 | [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0]], 95 | ) 96 | self.assertEqual(combined_braininfo.agents, ["0-1", "0-2", "1-3"]) 97 | 98 | def test_step_resets_on_global_done(self): 99 | env_mock = Mock() 100 | env_mock.reset = Mock(return_value="reset_data") 101 | env_mock.global_done = True 102 | 103 | def mock_global_done_env_factory(worker_id: int): 104 | return env_mock 105 | 106 | mock_parent_connection = Mock() 107 | step_command = EnvironmentCommand("step", (None, None, None, None)) 108 | close_command = EnvironmentCommand("close") 109 | mock_parent_connection.recv = Mock() 110 | mock_parent_connection.recv.side_effect = [step_command, close_command] 111 | mock_parent_connection.send = Mock() 112 | 113 | worker( 114 | mock_parent_connection, cloudpickle.dumps(mock_global_done_env_factory), 0 115 | ) 116 | 117 | # recv called twice to get step and close command 118 | self.assertEqual(mock_parent_connection.recv.call_count, 2) 119 | 120 | # worker returns the data from the reset 121 | mock_parent_connection.send.assert_called_with( 122 | EnvironmentResponse("step", 0, "reset_data") 123 | ) 124 | -------------------------------------------------------------------------------- /simple_boat/ppo/config.yaml: -------------------------------------------------------------------------------- 1 | build_path: 2 | win32: C:\Users\Fisher\Documents\Unity\build-RL-Envs\RL-Envs.exe 3 | scene: SimpleBoat 4 | 5 | # lambda: 1 6 | # gamma: 0.99 7 | max_iter: 5000 8 | # policies_num: 1 9 | # agents_num_p_policy: 1 10 | # reset_on_iteration: true 11 | seed: 100 12 | # std: true 13 | # mix: true 14 | # aux_cumulative_ratio: 0.4 15 | # good_trans_ratio: 1 16 | # addition_objective: false 17 | 18 | ppo_config: 19 | # save_per_iter: 1000 20 | write_summary_graph: true 21 | 22 | # batch_size: 2048 23 | # epoch_size: 10 24 | 25 | # init_td_threshold: 0.0 26 | # td_threshold_decay_steps: 100 27 | # td_threshold_rate: 0.5 28 | 29 | beta: 0.002 30 | # epsilon: 0.2 31 | 32 | # init_lr: 0.00005 33 | # min_lr: 0.00001 34 | # decay_steps: 500 35 | decay_rate: 1 36 | -------------------------------------------------------------------------------- /simple_boat/ppo/config_addition.yaml: -------------------------------------------------------------------------------- 1 | build_path: 2 | win32: C:\Users\Fisher\Documents\Unity\build-RL-Envs\RL-Envs.exe 3 | scene: SimpleBoat 4 | 5 | # lambda: 1 6 | # gamma: 0.99 7 | max_iter: 10000 8 | # policies_num: 1 9 | # agents_num_p_policy: 1 10 | # reset_on_iteration: true 11 | seed: 100 12 | # std: true 13 | # mix: true 14 | # aux_cumulative_ratio: 0.4 15 | # good_trans_ratio: 1 16 | addition_objective: true 17 | 18 | ppo_config: 19 | # save_per_iter: 1000 20 | write_summary_graph: true 21 | 22 | # batch_size: 2048 23 | # epoch_size: 10 24 | 25 | # init_td_threshold: 0.0 26 | # td_threshold_decay_steps: 100 27 | # td_threshold_rate: 0.5 28 | 29 | beta: 0.002 30 | epsilon: 0.05 31 | 32 | # init_lr: 0.00005 33 | # min_lr: 0.00001 34 | # decay_steps: 500 35 | decay_rate: 1 36 | -------------------------------------------------------------------------------- /simple_boat/ppo/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | 4 | import numpy as np 5 | 6 | sys.path.append('../..') 7 | from algorithm.ppo_main import Main 8 | from algorithm.agent import Agent 9 | 10 | if __name__ == '__main__': 11 | 
--------------------------------------------------------------------------------
/simple_boat/ppo/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 |
4 | import numpy as np
5 |
6 | sys.path.append('../..')
7 | from algorithm.ppo_main import Main
8 | from algorithm.agent import Agent
9 |
10 | if __name__ == '__main__':
11 |     logging.basicConfig(level=logging.INFO, format='[%(levelname)s] - [%(name)s] - %(message)s')
12 |
13 |     _log = logging.getLogger('tensorflow')
14 |     _log.setLevel(logging.ERROR)
15 |
16 |     logger = logging.getLogger('ppo')
17 |
18 |     class AgentHitted(Agent):
19 |         hitted = 0
20 |         hitted_real = 0
21 |
22 |         def _extra_log(self,
23 |                        state,
24 |                        action,
25 |                        reward,
26 |                        local_done,
27 |                        max_reached,
28 |                        state_):
29 |
30 |             if not self.done and reward >= 1:
31 |                 self.hitted_real += 1
32 |             if reward >= 1:
33 |                 self.hitted += 1
34 |
35 |     class MainHitted(Main):
36 |         def _log_episode_summaries(self, ppo, iteration, agents):
37 |             rewards = np.array([a.reward for a in agents])
38 |             hitted = sum([a.hitted for a in agents])
39 |             hitted_real = sum([a.hitted_real for a in agents])
40 |
41 |             ppo.write_constant_summaries([
42 |                 {'tag': 'reward/mean', 'simple_value': rewards.mean()},
43 |                 {'tag': 'reward/max', 'simple_value': rewards.max()},
44 |                 {'tag': 'reward/min', 'simple_value': rewards.min()},
45 |                 {'tag': 'reward/hitted', 'simple_value': hitted},
46 |                 {'tag': 'reward/hitted_real', 'simple_value': hitted_real}
47 |             ], iteration)
48 |
49 |         def _log_episode_info(self, ppo_i, iteration, agents):
50 |             rewards = [a.reward for a in agents]
51 |             hitted = sum([a.hitted for a in agents])
52 |             hitted_real = sum([a.hitted_real for a in agents])
53 |
54 |             rewards_sorted = ", ".join([f"{i:.1f}" for i in sorted(rewards)])
55 |             logger.info(f'{ppo_i}, iter {iteration}, rewards {rewards_sorted}, hitted {hitted}, hitted_real {hitted_real}')
56 |
57 |     MainHitted(sys.argv[1:], AgentHitted)
58 |
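main.py customizes training purely by subclassing: AgentHitted hooks _extra_log to count steps whose reward reaches 1 (hitted_real additionally ignores steps that arrive after the episode is already done), and MainHitted overrides the two logging hooks. Assuming the base classes call these hooks as their names suggest, the counting logic behaves like this (self-contained stub with only the fields the hook reads, not the repo's classes):

class StubAgent:
    # stand-in for algorithm.agent.Agent
    def __init__(self):
        self.done = False
        self.hitted = 0
        self.hitted_real = 0

    def _extra_log(self, reward):
        if not self.done and reward >= 1:
            self.hitted_real += 1   # hits scored before the episode ended
        if reward >= 1:
            self.hitted += 1        # every hit, including post-done steps

a = StubAgent()
a._extra_log(1.0)      # counts for both tallies
a.done = True
a._extra_log(1.0)      # counts for hitted only
assert (a.hitted, a.hitted_real) == (2, 1)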
--------------------------------------------------------------------------------
/simple_boat/ppo/ppo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import tensorflow_probability as tfp
4 |
5 | import warnings
6 | warnings.filterwarnings("ignore", category=DeprecationWarning)
7 |
8 | initializer_helper = {
9 |     'kernel_initializer': tf.truncated_normal_initializer(0, .1),
10 |     'bias_initializer': tf.constant_initializer(.1)
11 | }
12 |
13 |
14 | class PPO_Sep_Custom(object):
15 |     def _build_net(self, s_inputs, scope, trainable, reuse=False):
16 |         with tf.variable_scope(scope, reuse=reuse):
17 |             policy, policy_variables = self._build_actor_net(s_inputs, 'actor', trainable)
18 |             v, v_variables = self._build_critic_net(s_inputs, 'critic', trainable)
19 |
20 |         return policy, v, policy_variables + v_variables
21 |
22 |     def _build_critic_net(self, s_inputs, scope, trainable, reuse=False):
23 |         with tf.variable_scope(scope, reuse=reuse):
24 |             l = tf.layers.dense(s_inputs, 512, tf.nn.relu, trainable=trainable, **initializer_helper)
25 |             l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
26 |             l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
27 |             v = tf.layers.dense(l, 1, trainable=trainable, **initializer_helper)
28 |
29 |             variables = tf.get_variable_scope().global_variables()
30 |
31 |         return v, variables
32 |
33 |     def _build_actor_net(self, s_inputs, scope, trainable, reuse=False):
34 |         with tf.variable_scope(scope, reuse=reuse):
35 |             l = tf.layers.dense(s_inputs, 512, tf.nn.relu, trainable=trainable, **initializer_helper)
36 |             l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
37 |
38 |             mu = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
39 |             mu = tf.layers.dense(mu, self.a_dim, tf.nn.tanh, trainable=trainable, **initializer_helper)
40 |             sigma = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
41 |             sigma = tf.layers.dense(sigma, self.a_dim, tf.nn.sigmoid, trainable=trainable, **initializer_helper)
42 |
43 |             mu, sigma = mu, sigma + .1
44 |
45 |             policy = tf.distributions.Normal(loc=mu, scale=sigma)
46 |
47 |             variables = tf.get_variable_scope().global_variables()
48 |
49 |         return policy, variables
50 |
51 |
52 | class PPO_Std_Custom(object):
53 |     def _build_net(self, s_inputs, scope, trainable, reuse=False):
54 |         with tf.variable_scope(scope, reuse=reuse):
55 |             l = tf.layers.dense(s_inputs, 512, tf.nn.relu, trainable=trainable, **initializer_helper)
56 |             l = tf.layers.dense(l, 512, tf.nn.relu, trainable=trainable, **initializer_helper)
57 |
58 |             prob_l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
59 |             mu = tf.layers.dense(prob_l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
60 |             mu = tf.layers.dense(mu, self.a_dim, tf.nn.tanh, trainable=trainable, **initializer_helper)
61 |             sigma = tf.layers.dense(prob_l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
62 |             sigma = tf.layers.dense(sigma, self.a_dim, tf.nn.sigmoid, trainable=trainable, **initializer_helper)
63 |             mu, sigma = mu, sigma + .1
64 |
65 |             policy = tf.distributions.Normal(loc=mu, scale=sigma)
66 |
67 |             v_l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
68 |             v_l = tf.layers.dense(v_l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
69 |             v = tf.layers.dense(v_l, 1, trainable=trainable, **initializer_helper)
70 |
71 |             variables = tf.get_variable_scope().global_variables()
72 |
73 |         return policy, v, variables
--------------------------------------------------------------------------------
/simple_boat/ppo_sep_critic/config.yaml:
--------------------------------------------------------------------------------
1 | build_path:
2 |   win32: C:\Users\Fisher\Documents\Unity\build-RL-Envs\RL-Envs.exe
3 | scene: SimpleBoat
4 |
5 | # lambda: 1
6 | # gamma: 0.99
7 | max_iter: 5000
8 | # policies_num: 1
9 | # agents_num_p_policy: 1
10 | # reset_on_iteration: true
11 | seed: 100
12 | # mix: true
13 | # aux_cumulative_ratio: 0.4
14 | # good_trans_ratio: 1
15 | # addition_objective: false
16 |
17 | critic_config:
18 |   # save_per_iter: 1000
19 |   write_summary_graph: true
20 |
21 |   # batch_size: 2048
22 |   # epoch_size: 10
23 |
24 |   init_td_threshold: 0.01
25 |   # td_threshold_decay_steps: 100
26 |   td_threshold_rate: 0.9
27 |
28 |   # init_lr: 0.00005
29 |   decay_steps: 100
30 |   decay_rate: 0.9
31 |
32 |
33 | ppo_config:
34 |   # save_per_iter: 1000
35 |   write_summary_graph: true
36 |
37 |   # batch_size: 2048
38 |   # epoch_size: 10
39 |
40 |   beta: 0.002
41 |   epsilon: 0.3
42 |
43 |   # init_lr: 0.00005
44 |   decay_steps: 100
45 |   decay_rate: 0.9
46 |
--------------------------------------------------------------------------------
/simple_boat/ppo_sep_critic/config_addition.yaml:
--------------------------------------------------------------------------------
1 | build_path:
2 |   win32: C:\Users\Fisher\Documents\Unity\build-RL-Envs\RL-Envs.exe
3 | scene: SimpleBoat
4 |
5 | # lambda: 1
6 | # gamma: 0.99
7 | max_iter: 5000
8 | # policies_num: 1
9 | # agents_num_p_policy: 1
10 | # reset_on_iteration: true
11 | seed: 100
12 | # mix: true
13 | # aux_cumulative_ratio: 0.4
14 | # good_trans_ratio: 1
15 | addition_objective: true
16 |
17 | critic_config:
18 |   # save_per_iter: 1000
19 |   write_summary_graph: true
20 |
21 |   # batch_size: 2048
22 |   # epoch_size: 10
23 |
24 |   init_td_threshold: 0.01
25 |   # td_threshold_decay_steps: 100
26 |   td_threshold_rate: 0.9
27 |
28 |   # init_lr: 0.00005
29 |   decay_steps: 100
30 |   decay_rate: 0.9
31 |
32 |
33 | ppo_config:
34 |   # save_per_iter: 1000
35 |   write_summary_graph: true
36 |
37 |   # batch_size: 2048
38 |   # epoch_size: 10
39 |
40 |   beta: 0.002
41 |   epsilon: 0.02
42 |
43 |   # init_lr: 0.00005
44 |   decay_steps: 100
45 |   decay_rate: 0.9
46 |
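The separate-critic variants add a critic_config block with its own schedule, including a decaying TD-error threshold. How the threshold gates critic updates is decided in algorithm/ppo_sep_critic_base.py, which this listing does not show; a plausible reading is that transitions whose |TD error| falls below the threshold are dropped from the critic's batch. The sketch below only shows the schedule itself, built from the keys above (the 100-step default comes from the commented td_threshold_decay_steps):

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
# init_td_threshold, td_threshold_decay_steps, td_threshold_rate
td_threshold = tf.train.exponential_decay(0.01, global_step, 100, 0.9)

Note also that epsilon: 0.3 here is a much looser clip than the 0.1/0.05 used by the shared-network runs, while the addition-objective run tightens it to 0.02.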
--------------------------------------------------------------------------------
/simple_boat/ppo_sep_critic/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 |
4 | import numpy as np
5 |
6 | sys.path.append('../..')
7 | from algorithm.ppo_sep_critic_main import Main
8 | from algorithm.agent import Agent
9 |
10 | if __name__ == '__main__':
11 |     logging.basicConfig(level=logging.INFO, format='[%(levelname)s] - [%(name)s] - %(message)s')
12 |
13 |     _log = logging.getLogger('tensorflow')
14 |     _log.setLevel(logging.ERROR)
15 |
16 |     logger = logging.getLogger('ppo')
17 |
18 |     class AgentHitted(Agent):
19 |         hitted = 0
20 |         hitted_real = 0
21 |
22 |         def _extra_log(self,
23 |                        state,
24 |                        action,
25 |                        reward,
26 |                        local_done,
27 |                        max_reached,
28 |                        state_):
29 |
30 |             if not self.done and reward >= 1:
31 |                 self.hitted_real += 1
32 |             if reward >= 1:
33 |                 self.hitted += 1
34 |
35 |     class MainHitted(Main):
36 |         def _log_episode_summaries(self, ppo, iteration, agents):
37 |             rewards = np.array([a.reward for a in agents])
38 |             hitted = sum([a.hitted for a in agents])
39 |             hitted_real = sum([a.hitted_real for a in agents])
40 |
41 |             ppo.write_constant_summaries([
42 |                 {'tag': 'reward/mean', 'simple_value': rewards.mean()},
43 |                 {'tag': 'reward/max', 'simple_value': rewards.max()},
44 |                 {'tag': 'reward/min', 'simple_value': rewards.min()},
45 |                 {'tag': 'reward/hitted', 'simple_value': hitted},
46 |                 {'tag': 'reward/hitted_real', 'simple_value': hitted_real}
47 |             ], iteration)
48 |
49 |         def _log_episode_info(self, ppo_i, iteration, agents):
50 |             rewards = [a.reward for a in agents]
51 |             hitted = sum([a.hitted for a in agents])
52 |             hitted_real = sum([a.hitted_real for a in agents])
53 |
54 |             rewards_sorted = ", ".join([f"{i:.1f}" for i in sorted(rewards)])
55 |             logger.info(f'{ppo_i}, iter {iteration}, rewards {rewards_sorted}, hitted {hitted}, hitted_real {hitted_real}')
56 |
57 |     MainHitted(sys.argv[1:], AgentHitted)
58 |
--------------------------------------------------------------------------------
/simple_boat/ppo_sep_critic/ppo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import tensorflow_probability as tfp
4 |
5 | import warnings
6 | warnings.filterwarnings("ignore", category=DeprecationWarning)
7 |
8 | initializer_helper = {
9 |     'kernel_initializer': tf.truncated_normal_initializer(0, .1),
10 |     'bias_initializer': tf.constant_initializer(.1)
11 | }
12 |
13 |
14 | class Critic_Custom(object):
15 |     def _build_net(self, s_inputs, scope, trainable, reuse=False):
16 |         with tf.variable_scope(scope, reuse=reuse):
17 |             l = tf.layers.dense(s_inputs, 512, tf.nn.relu, trainable=trainable, **initializer_helper)
18 |             l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
19 |             l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
20 |             l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
21 |             v = tf.layers.dense(l, 1, trainable=trainable, **initializer_helper)
22 |
23 |         return v
24 |
25 |
26 | class PPO_Custom(object):
27 |     def _build_net(self, s_inputs, scope, trainable, reuse=False):
28 |         with tf.variable_scope(scope, reuse=reuse):
29 |             l = tf.layers.dense(s_inputs, 512, tf.nn.relu, trainable=trainable, **initializer_helper)
30 |             l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
31 |             l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
32 |             l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
33 |
34 |             mu = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
35 |             mu = tf.layers.dense(mu, self.a_dim, tf.nn.tanh, trainable=trainable, **initializer_helper)
36 |             sigma = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
37 |             sigma = tf.layers.dense(sigma, self.a_dim, tf.nn.sigmoid, trainable=trainable, **initializer_helper)
38 |
39 |             mu, sigma = mu, sigma + .1
40 |
41 |             policy = tf.distributions.Normal(loc=mu, scale=sigma)
42 |
43 |             variables = tf.get_variable_scope().global_variables()
44 |
45 |         return policy, variables
46 |
--------------------------------------------------------------------------------
/simple_roller/ppo/config.yaml:
--------------------------------------------------------------------------------
1 | build_path:
2 |   win32: C:\Users\Fisher\Documents\Unity\build-RL-Envs\RL-Envs.exe
3 | scene: ContinousSimpleRoller
4 |
5 | # lambda: 1
6 | # gamma: 0.99
7 | max_iter: 3000
8 | # policies_num: 1
9 | # agents_num_p_policy: 1
10 | # reset_on_iteration: true
11 | seed: 100
12 | # std: true
13 | # mix: true
14 | # aux_cumulative_ratio: 0.4
15 | # good_trans_ratio: 1
16 | # addition_objective: false
17 |
18 | ppo_config:
19 |   # save_per_iter: 1000
20 |   write_summary_graph: true
21 |
22 |   # batch_size: 2048
23 |   # epoch_size: 10
24 |
25 |   # init_td_threshold: 0.0
26 |   # td_threshold_decay_steps: 100
27 |   # td_threshold_rate: 0.5
28 |
29 |   # beta: 0.001
30 |   # epsilon: 0.2
31 |
32 |   # init_lr: 0.00005
33 |   # min_lr: 0.00001
34 |   decay_steps: 500
35 |   decay_rate: 0.7
36 |
--------------------------------------------------------------------------------
/simple_roller/ppo/config_addition.yaml:
--------------------------------------------------------------------------------
1 | build_path:
2 |   win32: C:\Users\Fisher\Documents\Unity\build-RL-Envs\RL-Envs.exe
3 | scene: ContinousSimpleRoller
4 |
5 | # lambda: 1
6 | # gamma: 0.99
7 | max_iter: 3000
8 | # policies_num: 1
9 | # agents_num_p_policy: 1
10 | # reset_on_iteration: true
11 | seed: 100
12 | # std: true
13 | # mix: true
14 | # aux_cumulative_ratio: 0.4
15 | # good_trans_ratio: 1
16 | addition_objective: true
17 |
18 | ppo_config:
19 |   # save_per_iter: 1000
20 |   write_summary_graph: true
21 |
22 |   # batch_size: 2048
23 |   # epoch_size: 10
24 |
25 |   # init_td_threshold: 0.0
26 |   # td_threshold_decay_steps: 100
27 |   # td_threshold_rate: 0.5
28 |
29 |   # beta: 0.001
30 |   epsilon: 0.05
31 |
32 |   # init_lr: 0.00005
33 |   # min_lr: 0.00001
34 |   decay_steps: 500
35 |   decay_rate: 0.7
36 |
--------------------------------------------------------------------------------
/simple_roller/ppo/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 |
4 | import numpy as np
5 |
6 | sys.path.append('../..')
7 | from algorithm.ppo_main import Main
8 | from algorithm.agent import Agent
9 |
10 | if __name__ == '__main__':
11 |     logging.basicConfig(level=logging.INFO, format='[%(levelname)s] - [%(name)s] - %(message)s')
12 |
13 |     _log = logging.getLogger('tensorflow')
14 |     _log.setLevel(logging.ERROR)
15 |
16 |     logger = logging.getLogger('ppo')
17 |
18 |     class AgentHitted(Agent):
19 |         hitted = 0
20 |         hitted_real = 0
21 |
22 |         def _extra_log(self,
23 |                        state,
24 |                        action,
25 |                        reward,
26 |                        local_done,
27 |                        max_reached,
28 |                        state_):
29 |
30 |             if not self.done and reward >= 1:
31 |                 self.hitted_real += 1
32 |             if reward >= 1:
33 |                 self.hitted += 1
34 |
35 |     class MainHitted(Main):
36 |         def _log_episode_summaries(self, ppo, iteration, agents):
37 |             rewards = np.array([a.reward for a in agents])
38 |             hitted = sum([a.hitted for a in agents])
39 |             hitted_real = sum([a.hitted_real for a in agents])
40 |
41 |             ppo.write_constant_summaries([
42 |                 {'tag': 'reward/mean', 'simple_value': rewards.mean()},
43 |                 {'tag': 'reward/max', 'simple_value': rewards.max()},
44 |                 {'tag': 'reward/min', 'simple_value': rewards.min()},
45 |                 {'tag': 'reward/hitted', 'simple_value': hitted},
46 |                 {'tag': 'reward/hitted_real', 'simple_value': hitted_real}
47 |             ], iteration)
48 |
49 |         def _log_episode_info(self, ppo_i, iteration, agents):
50 |             rewards = [a.reward for a in agents]
51 |             hitted = sum([a.hitted for a in agents])
52 |             hitted_real = sum([a.hitted_real for a in agents])
53 |
54 |             rewards_sorted = ", ".join([f"{i:.1f}" for i in sorted(rewards)])
55 |             logger.info(f'{ppo_i}, iter {iteration}, rewards {rewards_sorted}, hitted {hitted}, hitted_real {hitted_real}')
56 |
57 |     MainHitted(sys.argv[1:], AgentHitted)
58 |
--------------------------------------------------------------------------------
/simple_roller/ppo/ppo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import tensorflow_probability as tfp
4 |
5 | import warnings
6 | warnings.filterwarnings("ignore", category=DeprecationWarning)
7 |
8 | initializer_helper = {
9 |     'kernel_initializer': tf.truncated_normal_initializer(0, .1),
10 |     'bias_initializer': tf.constant_initializer(.1)
11 | }
12 |
13 |
14 | class PPO_Sep_Custom(object):
15 |     def _build_net(self, s_inputs, scope, trainable, reuse=False):
16 |         with tf.variable_scope(scope, reuse=reuse):
17 |             policy, policy_variables = self._build_actor_net(s_inputs, 'actor', trainable)
18 |             v, v_variables = self._build_critic_net(s_inputs, 'critic', trainable)
19 |
20 |         return policy, v, policy_variables + v_variables
21 |
22 |     def _build_critic_net(self, s_inputs, scope, trainable, reuse=False):
23 |         with tf.variable_scope(scope, reuse=reuse):
24 |             l = tf.layers.dense(s_inputs, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
25 |             l = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
26 |             l = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
27 |             v = tf.layers.dense(l, 1, trainable=trainable, **initializer_helper)
28 |
29 |             variables = tf.get_variable_scope().global_variables()
30 |
31 |         return v, variables
32 |
33 |     def _build_actor_net(self, s_inputs, scope, trainable, reuse=False):
34 |         with tf.variable_scope(scope, reuse=reuse):
35 |             l = tf.layers.dense(s_inputs, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
36 |
37 |             mu = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
38 |             mu = tf.layers.dense(mu, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
39 |             mu = tf.layers.dense(mu, self.a_dim, tf.nn.tanh, trainable=trainable, **initializer_helper)
40 |             sigma = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
41 |             sigma = tf.layers.dense(sigma, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
42 |             sigma = tf.layers.dense(sigma, self.a_dim, tf.nn.sigmoid, trainable=trainable, **initializer_helper)
43 |
44 |             mu, sigma = mu, sigma + .1
45 |
46 |             policy = tf.distributions.Normal(loc=mu, scale=sigma)
47 |
48 |             variables = tf.get_variable_scope().global_variables()
49 |
50 |         return policy, variables
51 |
52 |
53 | class PPO_Std_Custom(object):
54 |     def _build_net(self, s_inputs, scope, trainable, reuse=False):
55 |         with tf.variable_scope(scope, reuse=reuse):
56 |             l = tf.layers.dense(s_inputs, 64, tf.nn.relu, trainable=trainable, **initializer_helper)
57 |             l = tf.layers.dense(l, 64, tf.nn.relu, trainable=trainable, **initializer_helper)
58 |
59 |             prob_l = tf.layers.dense(l, 64, tf.nn.relu, trainable=trainable, **initializer_helper)
60 |             mu = tf.layers.dense(prob_l, 64, tf.nn.relu, trainable=trainable, **initializer_helper)
61 |             mu = tf.layers.dense(mu, self.a_dim, tf.nn.tanh, trainable=trainable, **initializer_helper)
62 |             sigma = tf.layers.dense(prob_l, 64, tf.nn.relu, trainable=trainable, **initializer_helper)
63 |             sigma = tf.layers.dense(sigma, self.a_dim, tf.nn.sigmoid, trainable=trainable, **initializer_helper)
64 |             mu, sigma = mu, sigma + .1
65 |
66 |             policy = tf.distributions.Normal(loc=mu, scale=sigma)
67 |
68 |             v_l = tf.layers.dense(l, 64, tf.nn.relu, trainable=trainable, **initializer_helper)
69 |             v_l = tf.layers.dense(v_l, 64, tf.nn.relu, trainable=trainable, **initializer_helper)
70 |             v_l = tf.layers.dense(v_l, 64, tf.nn.relu, trainable=trainable, **initializer_helper)
71 |             v = tf.layers.dense(v_l, 1, trainable=trainable, **initializer_helper)
72 |
73 |             variables = tf.get_variable_scope().global_variables()
74 |
75 |         return policy, v, variables
76 |
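Every ppo.py in this repo ends its actor with the same head: a tanh-bounded mean and a sigmoid standard deviation shifted by 0.1, wrapped in tf.distributions.Normal. The shift floors the std at 0.1, so log-probabilities stay finite and exploration never collapses to a deterministic policy. A minimal, self-contained version of just that head (toy layer sizes; the real inputs come from the trunks above):

import tensorflow as tf

a_dim = 2
s = tf.placeholder(tf.float32, [None, 8])               # toy state batch
l = tf.layers.dense(s, 32, tf.nn.relu)
mu = tf.layers.dense(l, a_dim, tf.nn.tanh)              # mean in (-1, 1)
sigma = tf.layers.dense(l, a_dim, tf.nn.sigmoid) + .1   # std in (0.1, 1.1)
policy = tf.distributions.Normal(loc=mu, scale=sigma)
action = policy.sample()            # rollout-time action
log_prob = policy.log_prob(action)  # feeds the PPO probability ratio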
--------------------------------------------------------------------------------
/simple_roller/ppo_sep_critic/config.yaml:
--------------------------------------------------------------------------------
1 | build_path:
2 |   win32: C:\Users\Fisher\Documents\Unity\build-RL-Envs\RL-Envs.exe
3 | scene: ContinousSimpleRoller
4 |
5 | # lambda: 1
6 | # gamma: 0.99
7 | max_iter: 3000
8 | # policies_num: 1
9 | # agents_num_p_policy: 1
10 | # reset_on_iteration: true
11 | seed: 100
12 | # mix: true
13 | # aux_cumulative_ratio: 0.4
14 | # good_trans_ratio: 1
15 | # addition_objective: false
16 |
17 | critic_config:
18 |   # save_per_iter: 1000
19 |   write_summary_graph: true
20 |
21 |   # batch_size: 2048
22 |   # epoch_size: 10
23 |
24 |   init_td_threshold: 0.01
25 |   # td_threshold_decay_steps: 100
26 |   # td_threshold_rate: 0.5
27 |
28 |   # init_lr: 0.00005
29 |   decay_steps: 70
30 |   decay_rate: 0.7
31 |
32 |
33 | ppo_config:
34 |   # save_per_iter: 1000
35 |   write_summary_graph: true
36 |
37 |   # batch_size: 2048
38 |   # epoch_size: 10
39 |
40 |   # beta: 0.001
41 |   # epsilon: 0.2
42 |
43 |   # init_lr: 0.00005
44 |   decay_steps: 70
45 |   decay_rate: 0.7
46 |
--------------------------------------------------------------------------------
/simple_roller/ppo_sep_critic/config_addition.yaml:
--------------------------------------------------------------------------------
1 | build_path:
2 |   win32: C:\Users\Fisher\Documents\Unity\build-RL-Envs\RL-Envs.exe
3 | scene: ContinousSimpleRoller
4 |
5 | # lambda: 1
6 | # gamma: 0.99
7 | max_iter: 3000
8 | # policies_num: 1
9 | # agents_num_p_policy: 1
10 | # reset_on_iteration: true
11 | seed: 100
12 | # mix: true
13 | # aux_cumulative_ratio: 0.4
14 | # good_trans_ratio: 1
15 | addition_objective: true
16 |
17 | critic_config:
18 |   # save_per_iter: 1000
19 |   write_summary_graph: true
20 |
21 |   # batch_size: 2048
22 |   # epoch_size: 10
23 |
24 |   init_td_threshold: 0.01
25 |   # td_threshold_decay_steps: 100
26 |   # td_threshold_rate: 0.5
27 |
28 |   # init_lr: 0.00005
29 |   decay_steps: 70
30 |   decay_rate: 0.7
31 |
32 |
33 | ppo_config:
34 |   # save_per_iter: 1000
35 |   write_summary_graph: true
36 |
37 |   # batch_size: 2048
38 |   # epoch_size: 10
39 |
40 |   # beta: 0.001
41 |   epsilon: 0.02
42 |
43 |   # init_lr: 0.00005
44 |   decay_steps: 70
45 |   decay_rate: 0.7
46 |
--------------------------------------------------------------------------------
/simple_roller/ppo_sep_critic/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 |
4 | import numpy as np
5 |
6 | sys.path.append('../..')
7 | from algorithm.ppo_sep_critic_main import Main
8 | from algorithm.agent import Agent
9 |
10 | if __name__ == '__main__':
11 |     logging.basicConfig(level=logging.INFO, format='[%(levelname)s] - [%(name)s] - %(message)s')
12 |
13 |     _log = logging.getLogger('tensorflow')
14 |     _log.setLevel(logging.ERROR)
15 |
16 |     logger = logging.getLogger('ppo')
17 |
18 |     class AgentHitted(Agent):
19 |         hitted = 0
20 |         hitted_real = 0
21 |
22 |         def _extra_log(self,
23 |                        state,
24 |                        action,
25 |                        reward,
26 |                        local_done,
27 |                        max_reached,
28 |                        state_):
29 |
30 |             if not self.done and reward >= 1:
31 |                 self.hitted_real += 1
32 |             if reward >= 1:
33 |                 self.hitted += 1
34 |
35 |     class MainHitted(Main):
36 |         def _log_episode_summaries(self, ppo, iteration, agents):
37 |             rewards = np.array([a.reward for a in agents])
38 |             hitted = sum([a.hitted for a in agents])
39 |             hitted_real = sum([a.hitted_real for a in agents])
40 |
41 |             ppo.write_constant_summaries([
42 |                 {'tag': 'reward/mean', 'simple_value': rewards.mean()},
43 |                 {'tag': 'reward/max', 'simple_value': rewards.max()},
44 |                 {'tag': 'reward/min', 'simple_value': rewards.min()},
45 |                 {'tag': 'reward/hitted', 'simple_value': hitted},
46 |                 {'tag': 'reward/hitted_real', 'simple_value': hitted_real}
47 |             ], iteration)
48 |
49 |         def _log_episode_info(self, ppo_i, iteration, agents):
50 |             rewards = [a.reward for a in agents]
51 |             hitted = sum([a.hitted for a in agents])
52 |             hitted_real = sum([a.hitted_real for a in agents])
53 |
54 |             rewards_sorted = ", ".join([f"{i:.1f}" for i in sorted(rewards)])
55 |             logger.info(f'{ppo_i}, iter {iteration}, rewards {rewards_sorted}, hitted {hitted}, hitted_real {hitted_real}')
56 |
57 |     MainHitted(sys.argv[1:], AgentHitted)
58 |
--------------------------------------------------------------------------------
/simple_roller/ppo_sep_critic/ppo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import tensorflow_probability as tfp
4 |
5 | import warnings
6 | warnings.filterwarnings("ignore", category=DeprecationWarning)
7 |
8 | initializer_helper = {
9 |     'kernel_initializer': tf.truncated_normal_initializer(0, .1),
10 |     'bias_initializer': tf.constant_initializer(.1)
11 | }
12 |
13 |
14 | class Critic_Custom(object):
15 |     def _build_net(self, s_inputs, scope, trainable, reuse=False):
16 |         with tf.variable_scope(scope, reuse=reuse):
17 |             l = tf.layers.dense(s_inputs, 512, tf.nn.relu, trainable=trainable, **initializer_helper)
18 |             l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
19 |             l = tf.layers.dense(l, 128, tf.nn.relu, trainable=trainable, **initializer_helper)
20 |             l = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
21 |             v = tf.layers.dense(l, 1, trainable=trainable, **initializer_helper)
22 |
23 |         return v
24 |
25 |
26 | class PPO_Custom(object):
27 |     def _build_net(self, s_inputs, scope, trainable, reuse=False):
28 |         with tf.variable_scope(scope, reuse=reuse):
29 |             l = tf.layers.dense(s_inputs, 512, tf.nn.relu, trainable=trainable, **initializer_helper)
30 |             l = tf.layers.dense(l, 256, tf.nn.relu, trainable=trainable, **initializer_helper)
31 |             l = tf.layers.dense(l, 128, tf.nn.relu, trainable=trainable, **initializer_helper)
32 |             l = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
33 |
34 |             mu = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
35 |             mu = tf.layers.dense(mu, self.a_dim, tf.nn.tanh, trainable=trainable, **initializer_helper)
36 |             sigma = tf.layers.dense(l, 32, tf.nn.relu, trainable=trainable, **initializer_helper)
37 |             sigma = tf.layers.dense(sigma, self.a_dim, tf.nn.sigmoid, trainable=trainable, **initializer_helper)
38 |
39 |             mu, sigma = mu, sigma + .1
40 |
41 |             policy = tf.distributions.Normal(loc=mu, scale=sigma)
42 |
43 |             variables = tf.get_variable_scope().global_variables()
44 |
45 |         return policy, variables
46 |
--------------------------------------------------------------------------------
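A structural note on how these files appear to fit together: each ppo.py defines classes that contain only network-construction methods, and each main.py imports only algorithm.ppo_main or algorithm.ppo_sep_critic_main, so the *_Custom classes are presumably picked up by the framework and combined, mixin-style, with the base classes in algorithm/ (ppo_base.py, ppo_sep_critic_base.py). That wiring is not visible in this listing; the stub below only illustrates the override mechanism itself, with invented names:

class _StubBase:
    # plays the role of algorithm/ppo_base.py: everything around the net
    def _build_net(self, s_inputs, scope, trainable, reuse=False):
        raise NotImplementedError

    def build(self):
        return self._build_net(None, 'theta', True)

class _StubCustom:
    # plays the role of PPO_Custom above: supplies only the network
    def _build_net(self, s_inputs, scope, trainable, reuse=False):
        return 'policy', ['variables']

class StubPPO(_StubCustom, _StubBase):
    # MRO puts _StubCustom first, so its _build_net wins
    pass

assert StubPPO().build() == ('policy', ['variables'])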