├── LICENSE
├── README.md
├── Report.pdf
├── ddpgActor_Model.pth
├── ddpgCritic_Model.pth
├── ddpg_agent.py
├── media
│   ├── ReacherUnityEnvironment.png
│   ├── exampleTrainingScoresGraph.png
│   └── reachertask.gif
├── model.py
├── python
│   ├── Basics.ipynb
│   ├── README.md
│   ├── communicator_objects
│   │   ├── __init__.py
│   │   ├── agent_action_proto_pb2.py
│   │   ├── agent_info_proto_pb2.py
│   │   ├── brain_parameters_proto_pb2.py
│   │   ├── brain_type_proto_pb2.py
│   │   ├── command_proto_pb2.py
│   │   ├── engine_configuration_proto_pb2.py
│   │   ├── environment_parameters_proto_pb2.py
│   │   ├── header_pb2.py
│   │   ├── resolution_proto_pb2.py
│   │   ├── space_type_proto_pb2.py
│   │   ├── unity_input_pb2.py
│   │   ├── unity_message_pb2.py
│   │   ├── unity_output_pb2.py
│   │   ├── unity_rl_initialization_input_pb2.py
│   │   ├── unity_rl_initialization_output_pb2.py
│   │   ├── unity_rl_input_pb2.py
│   │   ├── unity_rl_output_pb2.py
│   │   ├── unity_to_external_pb2.py
│   │   └── unity_to_external_pb2_grpc.py
│   ├── curricula
│   │   ├── push.json
│   │   ├── test.json
│   │   └── wall.json
│   ├── learn.py
│   ├── requirements.txt
│   ├── setup.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── mock_communicator.py
│   │   ├── test_bc.py
│   │   ├── test_ppo.py
│   │   ├── test_unityagents.py
│   │   └── test_unitytrainers.py
│   ├── trainer_config.yaml
│   ├── unityagents
│   │   ├── __init__.py
│   │   ├── brain.py
│   │   ├── communicator.py
│   │   ├── curriculum.py
│   │   ├── environment.py
│   │   ├── exception.py
│   │   ├── rpc_communicator.py
│   │   └── socket_communicator.py
│   └── unitytrainers
│       ├── __init__.py
│       ├── bc
│       │   ├── __init__.py
│       │   ├── models.py
│       │   └── trainer.py
│       ├── buffer.py
│       ├── models.py
│       ├── ppo
│       │   ├── __init__.py
│       │   ├── models.py
│       │   └── trainer.py
│       ├── trainer.py
│       └── trainer_controller.py
├── replay_buffer.py
├── test.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # DDPG (Actor-Critic) Reinforcement Learning using PyTorch and Unity ML-Agents
2 | A simple example of how to implement vector-based DDPG using PyTorch and an ML-Agents environment.
3 | 
4 | The repository includes the following files:
5 | - ddpg_agent.py -> DDPG agent implementation
6 | - replay_buffer.py -> the DDPG agent's replay buffer implementation
7 | - model.py -> example PyTorch Actor and Critic neural networks
8 | - train.py -> initializes and runs the training process for a DDPG agent
9 | - test.py -> tests a trained DDPG agent
10 | 
11 | The repository also includes links to the Mac/Linux/Windows versions of a simple Unity environment, *Reacher*, for testing.
12 | This Unity application and testing environment was developed using ML-Agents Beta v0.4. The version of the Reacher environment used in this project was developed for the Udacity Deep Reinforcement Learning Nanodegree course. For more information about this course visit: https://www.udacity.com/course/deep-reinforcement-learning-nanodegree--nd893
13 | 
14 | The files in the `python/` directory are the ML-Agents toolkit files and dependencies required to run the Reacher environment.
15 | For more information about the Unity ML-Agents Toolkit visit: https://github.com/Unity-Technologies/ml-agents
16 | 
17 | ## Example Unity Environment - Reacher
18 | The example uses a modified version of the Unity ML-Agents Reacher Example Environment.
19 | In this environment, a double-jointed arm can move to target locations.
20 | A reward of +0.1 is provided for each step that the agent's hand is in the goal location.
21 | Thus, the goal of your agent is to maintain its position at the target location for as many
22 | time steps as possible. The environment runs multiple Unity agents in parallel to speed up training.
23 | 
24 | ### Multiagent Training
25 | The Reacher environment contains multiple Unity agents so that more experience is collected at every time step. The training agent collects observations from, and learns from the experiences of, all of the Unity agents simultaneously. The Reacher environment used here has 20 Unity agents (i.e., 20 double-jointed arms).
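One way to picture "learning from the experiences of all of the unity agents simultaneously" is a single replay buffer that receives one experience tuple per arm at every environment step, which mirrors how `Agent.step` in `ddpg_agent.py` loops over the agents and calls `self.memory.add`. The sketch below is a hypothetical, standard-library-only stand-in: the names `SharedReplayBuffer` and `add_step` are illustrative and are not the repository's `replay_buffer.py`, whose internals are not shown here. The placeholder sizes (20 arms, 33 state variables, 4 action values, buffer of 1e5, batch of 128) follow this README and the constants in `ddpg_agent.py`.

```python
import random
from collections import deque, namedtuple

# One transition (s, a, r, s', done) produced by a single arm.
Experience = namedtuple("Experience", ["state", "action", "reward", "next_state", "done"])


class SharedReplayBuffer:
    """Hypothetical minimal replay buffer shared by every arm in the environment."""

    def __init__(self, buffer_size, batch_size, seed=0):
        self.memory = deque(maxlen=buffer_size)  # oldest experiences are discarded first
        self.batch_size = batch_size
        self.rng = random.Random(seed)  # private RNG so sampling is reproducible

    def add_step(self, states, actions, rewards, next_states, dones):
        """Store one environment step as one experience per arm."""
        for s, a, r, s2, d in zip(states, actions, rewards, next_states, dones):
            self.memory.append(Experience(s, a, r, s2, d))

    def sample(self):
        """Draw a uniform random minibatch that can mix experiences from any arm."""
        return self.rng.sample(list(self.memory), self.batch_size)

    def __len__(self):
        return len(self.memory)


if __name__ == "__main__":
    num_agents, state_size, action_size = 20, 33, 4
    buffer = SharedReplayBuffer(buffer_size=100_000, batch_size=128)
    for _ in range(10):  # 10 environment steps filled with placeholder data
        states = [[0.0] * state_size for _ in range(num_agents)]
        actions = [[0.0] * action_size for _ in range(num_agents)]
        rewards = [0.1] * num_agents
        buffer.add_step(states, actions, rewards, states, [False] * num_agents)
    print(len(buffer))           # 10 steps x 20 arms = 200 experiences
    print(len(buffer.sample()))  # 128
```

Because each minibatch is sampled uniformly from the shared buffer, highly correlated consecutive steps from one arm are mixed with experiences from the other arms. Copying the deque into a list on every `sample()` call keeps the sketch short but costs time proportional to the buffer size.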
26 | 
27 | ![Trained DDPG-Agent Reacher Unity Agents Following Targets](media/reachertask.gif)
28 | 
29 | ### State and Action Space
30 | The observation space consists of 33 variables corresponding to
31 | position, rotation, velocity, and angular velocities of the arm.
32 | Each action is a vector with four numbers, corresponding to the torques
33 | applied to the arm's two joints. Every entry in the action vector should
34 | be a number between -1 and 1.
35 | 
36 | ## Installation and Dependencies
37 | 1. Anaconda Python 3.6: download and installation instructions are available here: https://www.anaconda.com/download/
38 | 
39 | 2. Create (and activate) a new conda (virtual) environment with Python 3.6.
40 |     - Linux or Mac:
41 | 
42 |         `conda create --name yourenvnamehere python=3.6`
43 | 
44 |         `source activate yourenvnamehere`
45 | 
46 |     - Windows:
47 | 
48 |         `conda create --name yourenvnamehere python=3.6`
49 | 
50 |         `activate yourenvnamehere`
51 | 
52 | 3. Download and save this GitHub repository.
53 | 
54 | 4. To install the required dependencies (torch, ML-Agents trainers (v0.4), etc.):
55 |     - Navigate to where you downloaded and saved this GitHub repository (e.g., *yourpath/thisgithubrepository*).
56 |     - Change to the `python/` subdirectory and run from the command line:
57 | 
58 |         `pip3 install .`
59 | 
60 |     - Note: depending on your system setup, you may have to install PyTorch separately.
61 | 
62 | ## Download the Unity Environment
63 | For this example project, you will not need to install Unity, because you can use a version of the Reacher Unity environment that is already built (compiled) as a standalone application.
64 | 
65 | Download the relevant environment zip file from one of the links below.
You only need to download the environment that matches your operating system:
66 | 
67 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/Reacher_Linux.zip)
68 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/Reacher.app.zip)
69 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/Reacher_Windows_x86.zip)
70 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/Reacher_Windows_x86_64.zip)
71 | 
72 | After you have downloaded the relevant zip file, navigate to where you saved this GitHub repository, place the file in the main folder of the repository, and unzip (decompress) it.
73 | 
74 | NOTE: This Reacher environment is similar, but not identical, to the Reacher environment on the Unity ML-Agents GitHub page.
75 | 
76 | ## Training
77 | - activate the conda environment you created above
78 | - change the directory to the 'yourpath/thisgithubrepository' directory.
79 | - open `train.py`, find STEP 2 (lines 47 to 54) and set the relevant version of Reacher to match your operating system.
80 | - run the following command:
81 | 
82 |     `python train.py`
83 | 
84 | - training will complete once the agent reaches the *solved_score* defined in `train.py`.
85 | - after training, *ddpgActor_Model_datetime.pth* and *ddpgCritic_Model_datetime.pth* files will be saved with the trained model weights.
86 | - a *ddpgAgent_scores_datetime.csv* file will also be saved with the scores received during training. You can use this file to plot or assess training performance (see the figure below).
87 | - It is recommended that you train multiple agents and test different hyperparameter settings in `train.py` and `ddpg_agent.py`.
88 | - For more information about the DDPG training algorithm and the training hyperparameters, see the included `Report.pdf` file.
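Training stops once the agent reaches *solved_score*. The Udacity Reacher project that this environment comes from defines "solved" as an average score of at least +30 over 100 consecutive episodes, where each episode's score is the mean of the undiscounted returns of all 20 arms. The sketch below assumes that convention; the function names are hypothetical, and the rule and threshold actually used are whatever `train.py` defines. The same check could be applied to the scores stored in the *ddpgAgent_scores_datetime.csv* file once they have been loaded into a list.

```python
from collections import deque


def episode_score(agent_returns):
    """Collapse one episode's per-arm undiscounted returns into a single score (their mean)."""
    return sum(agent_returns) / len(agent_returns)


def first_solved_episode(scores, solved_score, window=100):
    """Return the 1-based episode at which the mean of the last `window`
    episode scores first reaches `solved_score`, or None if it never does.

    The check only starts once `window` episodes have been recorded, so a
    few lucky early episodes cannot end training prematurely.
    """
    recent = deque(maxlen=window)  # sliding window of the most recent episode scores
    for episode, score in enumerate(scores, start=1):
        recent.append(score)
        if len(recent) == window and sum(recent) / window >= solved_score:
            return episode
    return None


if __name__ == "__main__":
    # Synthetic, steadily improving scores (not real training results).
    scores = [0.5 * i for i in range(200)]
    print(first_solved_episode(scores, solved_score=30.0))  # 111
    print(episode_score([29.0, 31.0, 30.0]))  # 30.0
```

Requiring a full window is a deliberate choice here: an implementation that averages whatever the window contains from the first episode onward could declare the task solved earlier than the 100-episode definition allows.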
89 | 
90 | ![Example of agent performance (score) as a function of training episodes](media/exampleTrainingScoresGraph.png)
91 | 
92 | 
93 | ## Testing
94 | - activate the conda environment you created above
95 | - change the directory to the 'yourpath/thisgithubrepository' directory.
96 | - run the following command:
97 | 
98 |     `python test.py`
99 | 
100 | - Example model weights files are included in the repository (*ddpgActor_Model.pth* and *ddpgCritic_Model.pth*).
101 | - A different model weights file can be tested by changing the model file name defined in `test.py`.
102 | 
--------------------------------------------------------------------------------
/Report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xkiwilabs/DDPG-using-PyTorch-and-ML-Agents/e68f85e20bbaccf89b92b072124db87fee9182ae/Report.pdf
--------------------------------------------------------------------------------
/ddpgActor_Model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xkiwilabs/DDPG-using-PyTorch-and-ML-Agents/e68f85e20bbaccf89b92b072124db87fee9182ae/ddpgActor_Model.pth
--------------------------------------------------------------------------------
/ddpgCritic_Model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xkiwilabs/DDPG-using-PyTorch-and-ML-Agents/e68f85e20bbaccf89b92b072124db87fee9182ae/ddpgCritic_Model.pth
--------------------------------------------------------------------------------
/ddpg_agent.py:
--------------------------------------------------------------------------------
1 | """
2 | Project for Udacity Nanodegree in Deep Reinforcement Learning (DRL)
3 | Code Expanded and Adapted from Code provided by Udacity DRL Team, 2018.
4 | """ 5 | 6 | import numpy as np 7 | import random 8 | import copy 9 | from collections import namedtuple, deque 10 | 11 | from model import Actor, Critic 12 | from replay_buffer import ReplayBuffer 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | import torch.optim as optim 17 | 18 | 19 | BUFFER_SIZE = int(1e5) # replay buffer size 20 | BATCH_SIZE = 128 # minibatch size 21 | GAMMA = 0.99 # discount factor 22 | TAU = 1e-3 # for soft update of target parameters 23 | LR_ACTOR = 1e-4 # learning rate of the actor 24 | LR_CRITIC = 1e-4 # learning rate of the critic 25 | WEIGHT_DECAY = 0.0 # L2 weight decay 26 | 27 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 28 | 29 | class Agent(): 30 | """Interacts with and learns from the environment.""" 31 | 32 | def __init__(self, state_size, action_size, num_agents, random_seed): 33 | """Initialize an Agent object. 34 | 35 | Params 36 | ====== 37 | state_size (int): dimension of each state 38 | action_size (int): dimension of each action 39 | random_seed (int): random seed 40 | """ 41 | self.state_size = state_size 42 | self.action_size = action_size 43 | self.num_agents = num_agents 44 | self.seed = random.seed(random_seed) 45 | 46 | # Actor Network (w/ Target Network) 47 | self.actor_local = Actor(state_size, action_size, random_seed).to(device) 48 | self.actor_target = Actor(state_size, action_size, random_seed).to(device) 49 | self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=LR_ACTOR) 50 | 51 | # Critic Network (w/ Target Network) 52 | self.critic_local = Critic(state_size, action_size, random_seed).to(device) 53 | self.critic_target = Critic(state_size, action_size, random_seed).to(device) 54 | self.critic_optimizer = optim.Adam(self.critic_local.parameters(), lr=LR_CRITIC, weight_decay=WEIGHT_DECAY) 55 | 56 | # Noise process for each agent 57 | self.noise = OUNoise((num_agents, action_size), random_seed) 58 | 59 | # Replay memory 60 | self.memory = 
ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, random_seed) 61 | 62 | def step(self, states, actions, rewards, next_states, dones): 63 | """Save experience in replay memory, and use random sample from buffer to learn.""" 64 | # Save experience / reward 65 | for agent in range(self.num_agents): 66 | self.memory.add(states[agent,:], actions[agent,:], rewards[agent], next_states[agent,:], dones[agent]) 67 | 68 | # Learn, if enough samples are available in memory 69 | if len(self.memory) > BATCH_SIZE: 70 | experiences = self.memory.sample() 71 | self.learn(experiences) 72 | 73 | def act(self, state, add_noise=True): 74 | """Returns actions for given state as per current policy.""" 75 | state = torch.from_numpy(state).float().to(device) 76 | acts = np.zeros((self.num_agents, self.action_size)) 77 | self.actor_local.eval() 78 | with torch.no_grad(): 79 | for agent in range(self.num_agents): 80 | acts[agent,:] = self.actor_local(state[agent,:]).cpu().data.numpy() 81 | self.actor_local.train() 82 | if add_noise: 83 | acts += self.noise.sample() 84 | return np.clip(acts, -1, 1) 85 | 86 | def reset(self): 87 | self.noise.reset() 88 | 89 | def learn(self, experiences): 90 | """Update policy and value parameters using given batch of experience tuples. 
91 |         Q_targets = r + γ * critic_target(next_state, actor_target(next_state)) * (1 - done)
92 |         where:
93 |             actor_target(state) -> action
94 |             critic_target(state, action) -> Q-value
95 | 
96 |         Params
97 |         ======
98 |             experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tensors
99 |             (the discount factor is the module-level constant GAMMA)
100 |         """
101 |         states, actions, rewards, next_states, dones = experiences
102 | 
103 |         # ---------------------------- update critic ---------------------------- #
104 |         # Get predicted next-state actions and Q values from target models (detached: no gradients reach the targets)
105 |         actions_next = self.actor_target(next_states)
106 |         Q_targets_next = self.critic_target(next_states, actions_next).detach()
107 |         # Compute Q targets for current states (y_i)
108 |         Q_targets = rewards + (GAMMA * Q_targets_next * (1 - dones))
109 |         # Compute critic loss
110 |         Q_expected = self.critic_local(states, actions)
111 |         critic_loss = F.mse_loss(Q_expected, Q_targets)
112 |         # Minimize the loss
113 |         self.critic_optimizer.zero_grad()
114 |         critic_loss.backward()
115 |         # torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 1)  # optional gradient clipping
116 |         self.critic_optimizer.step()
117 | 
118 |         # ---------------------------- update actor ---------------------------- #
119 |         # Compute actor loss
120 |         actions_pred = self.actor_local(states)
121 |         actor_loss = -self.critic_local(states, actions_pred).mean()
122 |         # Minimize the loss
123 |         self.actor_optimizer.zero_grad()
124 |         actor_loss.backward()
125 |         self.actor_optimizer.step()
126 | 
127 |         # ----------------------- update target networks ----------------------- #
128 |         self.soft_update(self.critic_local, self.critic_target, TAU)
129 |         self.soft_update(self.actor_local, self.actor_target, TAU)
130 | 
131 |     def soft_update(self, local_model, target_model, tau):
132 |         """Soft update model parameters.
133 |         θ_target = τ*θ_local + (1 - τ)*θ_target
134 | 
135 |         Params
136 |         ======
137 |             local_model: PyTorch model (weights will be copied from)
138 |             target_model: PyTorch model (weights will be copied to)
139 |             tau (float): interpolation parameter
140 |         """
141 |         for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
142 |             target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
143 | 
144 | class OUNoise:
145 |     """Ornstein-Uhlenbeck process."""
146 | 
147 |     def __init__(self, size, seed, mu=0.0, theta=0.15, sigma=0.15, sigma_min=0.05, sigma_decay=.975):
148 |         """Initialize parameters and noise process."""
149 |         self.mu = mu * np.ones(size)
150 |         self.theta = theta
151 |         self.sigma = sigma
152 |         self.sigma_min = sigma_min
153 |         self.sigma_decay = sigma_decay
154 |         self.rng = np.random.RandomState(seed)  # seeded NumPy RNG; random.seed() would not seed np.random
155 |         self.size = size
156 |         self.reset()
157 | 
158 |     def reset(self):
159 |         """Reset the internal state (= noise) to mean (mu) and decay sigma."""
160 |         self.state = copy.copy(self.mu)
161 |         # Reduce sigma from its initial value toward sigma_min on every reset
162 |         self.sigma = max(self.sigma_min, self.sigma*self.sigma_decay)
163 | 
164 |     def sample(self):
165 |         """Update internal state and return it as a noise sample."""
166 |         x = self.state
167 |         dx = self.theta * (self.mu - x) + self.sigma * self.rng.standard_normal(self.size)
168 |         self.state = x + dx
169 |         return self.state
--------------------------------------------------------------------------------
/media/ReacherUnityEnvironment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xkiwilabs/DDPG-using-PyTorch-and-ML-Agents/e68f85e20bbaccf89b92b072124db87fee9182ae/media/ReacherUnityEnvironment.png
--------------------------------------------------------------------------------
/media/exampleTrainingScoresGraph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xkiwilabs/DDPG-using-PyTorch-and-ML-Agents/e68f85e20bbaccf89b92b072124db87fee9182ae/media/exampleTrainingScoresGraph.png
--------------------------------------------------------------------------------
/media/reachertask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xkiwilabs/DDPG-using-PyTorch-and-ML-Agents/e68f85e20bbaccf89b92b072124db87fee9182ae/media/reachertask.gif
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Example Neural Network Model for Vector Observation DDPG (Actor-Critic) Agent
3 | DDPG Model for Unity ML-Agents Environments using PyTorch
4 | 
5 | Project for Udacity Nanodegree in Deep Reinforcement Learning (DRL)
6 | Code expanded and adapted from code examples provided by Udacity DRL Team, 2018.
7 | """
8 | 
9 | import numpy as np
10 | 
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | 
15 | def hidden_init(layer):
16 |     fan_in = layer.weight.data.size()[1]  # nn.Linear weight shape is (out_features, in_features)
17 |     lim = 1. / np.sqrt(fan_in)
18 |     return (-lim, lim)
19 | 
20 | class Actor(nn.Module):
21 |     """Actor (Policy) Model."""
22 | 
23 |     def __init__(self, state_size, action_size, seed, fc1_units=256, fc2_units=128):
24 |         """Initialize parameters and build model.
25 | Params 26 | ====== 27 | state_size (int): Dimension of each state 28 | action_size (int): Dimension of each action 29 | seed (int): Random seed 30 | fc1_units (int): Number of nodes in first hidden layer 31 | fc2_units (int): Number of nodes in second hidden layer 32 | """ 33 | super(Actor, self).__init__() 34 | self.seed = torch.manual_seed(seed) 35 | self.fc1 = nn.Linear(state_size, fc1_units) 36 | self.fc2 = nn.Linear(fc1_units, fc2_units) 37 | self.fc3 = nn.Linear(fc2_units, action_size) 38 | self.reset_parameters() 39 | 40 | def reset_parameters(self): 41 | self.fc1.weight.data.uniform_(*hidden_init(self.fc1)) 42 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2)) 43 | self.fc3.weight.data.uniform_(-3e-3, 3e-3) 44 | 45 | def forward(self, state): 46 | """Build an actor (policy) network that maps states -> actions.""" 47 | x = F.relu(self.fc1(state)) 48 | x = F.relu(self.fc2(x)) 49 | return torch.tanh(self.fc3(x)) 50 | 51 | 52 | class Critic(nn.Module): 53 | """Critic (Value) Model.""" 54 | 55 | def __init__(self, state_size, action_size, seed, fcs1_units=256, fc2_units=128): 56 | """Initialize parameters and build model. 
57 | Params 58 | ====== 59 | state_size (int): Dimension of each state 60 | action_size (int): Dimension of each action 61 | seed (int): Random seed 62 | fcs1_units (int): Number of nodes in the first hidden layer 63 | fc2_units (int): Number of nodes in the second hidden layer 64 | """ 65 | super(Critic, self).__init__() 66 | self.seed = torch.manual_seed(seed) 67 | self.fcs1 = nn.Linear(state_size, fcs1_units) 68 | self.fc2 = nn.Linear(fcs1_units+action_size, fc2_units) 69 | self.fc3 = nn.Linear(fc2_units, 1) 70 | self.reset_parameters() 71 | 72 | def reset_parameters(self): 73 | self.fcs1.weight.data.uniform_(*hidden_init(self.fcs1)) 74 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2)) 75 | self.fc3.weight.data.uniform_(-3e-3, 3e-3) 76 | 77 | def forward(self, state, action): 78 | """Build a critic (value) network that maps (state, action) pairs -> Q-values.""" 79 | xs = F.relu(self.fcs1(state)) 80 | x = torch.cat((xs, action), dim=1) 81 | x = F.relu(self.fc2(x)) 82 | return self.fc3(x) 83 | -------------------------------------------------------------------------------- /python/Basics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Unity ML-Agents Toolkit\n", 8 | "## Environment Basics\n", 9 | "This notebook contains a walkthrough of the basic functions of the Python API for the Unity ML-Agents toolkit. For instructions on building a Unity environment, see [here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Getting-Started-with-Balance-Ball.md)." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "### 1. Set environment parameters\n", 17 | "\n", 18 | "Be sure to set `env_name` to the name of the Unity environment file you want to launch. Ensure that the environment build is in the `python/` directory." 
19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": { 25 | "collapsed": true 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "env_name = \"3DBall\" # Name of the Unity environment binary to launch\n", 30 | "train_mode = True # Whether to run the environment in training or inference mode" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "### 2. Load dependencies\n", 38 | "\n", 39 | "The following loads the necessary dependencies and checks the Python version (at runtime). ML-Agents Toolkit (v0.3 onwards) requires Python 3." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": { 46 | "collapsed": true 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "import matplotlib.pyplot as plt\n", 51 | "import numpy as np\n", 52 | "import sys\n", 53 | "\n", 54 | "from unityagents import UnityEnvironment\n", 55 | "\n", 56 | "%matplotlib inline\n", 57 | "\n", 58 | "print(\"Python version:\")\n", 59 | "print(sys.version)\n", 60 | "\n", 61 | "# check Python version\n", 62 | "if (sys.version_info[0] < 3):\n", 63 | " raise Exception(\"ERROR: ML-Agents Toolkit (v0.3 onwards) requires Python 3\")" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "### 3. Start the environment\n", 71 | "`UnityEnvironment` launches and begins communication with the environment when instantiated.\n", 72 | "\n", 73 | "Environments contain _brains_ which are responsible for deciding the actions of their associated _agents_. Here we check for the first brain available, and set it as the default brain we will be controlling from Python." 
74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "collapsed": true 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "env = UnityEnvironment(file_name=env_name)\n", 85 | "\n", 86 | "# Examine environment parameters\n", 87 | "print(str(env))\n", 88 | "\n", 89 | "# Set the default brain to work with\n", 90 | "default_brain = env.brain_names[0]\n", 91 | "brain = env.brains[default_brain]" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "### 4. Examine the observation and state spaces\n", 99 | "We can reset the environment to be provided with an initial set of observations and states for all the agents within the environment. In ML-Agents, _states_ refer to a vector of variables corresponding to relevant aspects of the environment for an agent. Likewise, _observations_ refer to a set of relevant pixel-wise visuals for an agent." 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": { 106 | "collapsed": true 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "# Reset the environment\n", 111 | "env_info = env.reset(train_mode=train_mode)[default_brain]\n", 112 | "\n", 113 | "# Examine the state space for the default brain\n", 114 | "print(\"Agent state looks like: \\n{}\".format(env_info.vector_observations[0]))\n", 115 | "\n", 116 | "# Examine the observation space for the default brain\n", 117 | "for observation in env_info.visual_observations:\n", 118 | " print(\"Agent observations look like:\")\n", 119 | " if observation.shape[3] == 3:\n", 120 | " plt.imshow(observation[0,:,:,:])\n", 121 | " else:\n", 122 | " plt.imshow(observation[0,:,:,0])" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "### 5. 
Take random actions in the environment\n", 130 | "Once we restart an environment, we can step the environment forward and provide actions to all of the agents within the environment. Here we simply choose random actions based on the `action_space_type` of the default brain. \n", 131 | "\n", 132 | "Once this cell is executed, 10 messages will be printed that detail how much reward will be accumulated for the next 10 episodes. The Unity environment will then pause, waiting for further signals telling it what to do next. Thus, not seeing any animation is expected when running this cell." 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "collapsed": true 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "for episode in range(10):\n", 144 | " env_info = env.reset(train_mode=train_mode)[default_brain]\n", 145 | " done = False\n", 146 | " episode_rewards = 0\n", 147 | " while not done:\n", 148 | " if brain.vector_action_space_type == 'continuous':\n", 149 | " env_info = env.step(np.random.randn(len(env_info.agents), \n", 150 | " brain.vector_action_space_size))[default_brain]\n", 151 | " else:\n", 152 | " env_info = env.step(np.random.randint(0, brain.vector_action_space_size, \n", 153 | " size=(len(env_info.agents))))[default_brain]\n", 154 | " episode_rewards += env_info.rewards[0]\n", 155 | " done = env_info.local_done[0]\n", 156 | " print(\"Total reward this episode: {}\".format(episode_rewards))" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "### 6. Close the environment when finished\n", 164 | "When we are finished using an environment, we can close it with the function below." 
165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": { 171 | "collapsed": true 172 | }, 173 | "outputs": [], 174 | "source": [ 175 | "env.close()" 176 | ] 177 | } 178 | ], 179 | "metadata": { 180 | "anaconda-cloud": {}, 181 | "kernelspec": { 182 | "display_name": "Python 3", 183 | "language": "python", 184 | "name": "python3" 185 | }, 186 | "language_info": { 187 | "codemirror_mode": { 188 | "name": "ipython", 189 | "version": 3 190 | }, 191 | "file_extension": ".py", 192 | "mimetype": "text/x-python", 193 | "name": "python", 194 | "nbconvert_exporter": "python", 195 | "pygments_lexer": "ipython3", 196 | "version": "3.6.3" 197 | } 198 | }, 199 | "nbformat": 4, 200 | "nbformat_minor": 1 201 | } 202 | -------------------------------------------------------------------------------- /python/README.md: -------------------------------------------------------------------------------- 1 | # Dependencies 2 | 3 | This is an amended version of the `python/` folder from the [ML-Agents repository](https://github.com/Unity-Technologies/ml-agents). It has been edited to include a few additional pip packages needed for the Deep Reinforcement Learning Nanodegree program. 
4 | -------------------------------------------------------------------------------- /python/communicator_objects/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent_action_proto_pb2 import * 2 | from .agent_info_proto_pb2 import * 3 | from .brain_parameters_proto_pb2 import * 4 | from .brain_type_proto_pb2 import * 5 | from .command_proto_pb2 import * 6 | from .engine_configuration_proto_pb2 import * 7 | from .environment_parameters_proto_pb2 import * 8 | from .header_pb2 import * 9 | from .resolution_proto_pb2 import * 10 | from .space_type_proto_pb2 import * 11 | from .unity_input_pb2 import * 12 | from .unity_message_pb2 import * 13 | from .unity_output_pb2 import * 14 | from .unity_rl_initialization_input_pb2 import * 15 | from .unity_rl_initialization_output_pb2 import * 16 | from .unity_rl_input_pb2 import * 17 | from .unity_rl_output_pb2 import * 18 | from .unity_to_external_pb2 import * 19 | from .unity_to_external_pb2_grpc import * 20 | -------------------------------------------------------------------------------- /python/communicator_objects/agent_action_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/agent_action_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/agent_action_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n-communicator_objects/agent_action_proto.proto\x12\x14\x63ommunicator_objects\"R\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\x14\n\x0ctext_actions\x18\x02 \x01(\t\x12\x10\n\x08memories\x18\x03 \x03(\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _AGENTACTIONPROTO = _descriptor.Descriptor( 29 | name='AgentActionProto', 30 | full_name='communicator_objects.AgentActionProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='vector_actions', full_name='communicator_objects.AgentActionProto.vector_actions', index=0, 37 | number=1, type=2, cpp_type=6, label=3, 38 | has_default_value=False, default_value=[], 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='text_actions', full_name='communicator_objects.AgentActionProto.text_actions', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, 
extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='memories', full_name='communicator_objects.AgentActionProto.memories', index=2, 51 | number=3, type=2, cpp_type=6, label=3, 52 | has_default_value=False, default_value=[], 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | syntax='proto3', 65 | extension_ranges=[], 66 | oneofs=[ 67 | ], 68 | serialized_start=71, 69 | serialized_end=153, 70 | ) 71 | 72 | DESCRIPTOR.message_types_by_name['AgentActionProto'] = _AGENTACTIONPROTO 73 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 74 | 75 | AgentActionProto = _reflection.GeneratedProtocolMessageType('AgentActionProto', (_message.Message,), dict( 76 | DESCRIPTOR = _AGENTACTIONPROTO, 77 | __module__ = 'communicator_objects.agent_action_proto_pb2' 78 | # @@protoc_insertion_point(class_scope:communicator_objects.AgentActionProto) 79 | )) 80 | _sym_db.RegisterMessage(AgentActionProto) 81 | 82 | 83 | DESCRIPTOR.has_options = True 84 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 85 | # @@protoc_insertion_point(module_scope) 86 | -------------------------------------------------------------------------------- /python/communicator_objects/agent_info_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/agent_info_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/agent_info_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n+communicator_objects/agent_info_proto.proto\x12\x14\x63ommunicator_objects\"\xfd\x01\n\x0e\x41gentInfoProto\x12\"\n\x1astacked_vector_observation\x18\x01 \x03(\x02\x12\x1b\n\x13visual_observations\x18\x02 \x03(\x0c\x12\x18\n\x10text_observation\x18\x03 \x01(\t\x12\x1d\n\x15stored_vector_actions\x18\x04 \x03(\x02\x12\x1b\n\x13stored_text_actions\x18\x05 \x01(\t\x12\x10\n\x08memories\x18\x06 \x03(\x02\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _AGENTINFOPROTO = _descriptor.Descriptor( 29 | name='AgentInfoProto', 30 | full_name='communicator_objects.AgentInfoProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='stacked_vector_observation', full_name='communicator_objects.AgentInfoProto.stacked_vector_observation', index=0, 37 | number=1, type=2, cpp_type=6, label=3, 38 | has_default_value=False, default_value=[], 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | 
_descriptor.FieldDescriptor( 43 | name='visual_observations', full_name='communicator_objects.AgentInfoProto.visual_observations', index=1, 44 | number=2, type=12, cpp_type=9, label=3, 45 | has_default_value=False, default_value=[], 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='text_observation', full_name='communicator_objects.AgentInfoProto.text_observation', index=2, 51 | number=3, type=9, cpp_type=9, label=1, 52 | has_default_value=False, default_value=_b("").decode('utf-8'), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | _descriptor.FieldDescriptor( 57 | name='stored_vector_actions', full_name='communicator_objects.AgentInfoProto.stored_vector_actions', index=3, 58 | number=4, type=2, cpp_type=6, label=3, 59 | has_default_value=False, default_value=[], 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None, file=DESCRIPTOR), 63 | _descriptor.FieldDescriptor( 64 | name='stored_text_actions', full_name='communicator_objects.AgentInfoProto.stored_text_actions', index=4, 65 | number=5, type=9, cpp_type=9, label=1, 66 | has_default_value=False, default_value=_b("").decode('utf-8'), 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None, file=DESCRIPTOR), 70 | _descriptor.FieldDescriptor( 71 | name='memories', full_name='communicator_objects.AgentInfoProto.memories', index=5, 72 | number=6, type=2, cpp_type=6, label=3, 73 | has_default_value=False, default_value=[], 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None, file=DESCRIPTOR), 77 | _descriptor.FieldDescriptor( 78 | name='reward', 
full_name='communicator_objects.AgentInfoProto.reward', index=6, 79 | number=7, type=2, cpp_type=6, label=1, 80 | has_default_value=False, default_value=float(0), 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None, file=DESCRIPTOR), 84 | _descriptor.FieldDescriptor( 85 | name='done', full_name='communicator_objects.AgentInfoProto.done', index=7, 86 | number=8, type=8, cpp_type=7, label=1, 87 | has_default_value=False, default_value=False, 88 | message_type=None, enum_type=None, containing_type=None, 89 | is_extension=False, extension_scope=None, 90 | options=None, file=DESCRIPTOR), 91 | _descriptor.FieldDescriptor( 92 | name='max_step_reached', full_name='communicator_objects.AgentInfoProto.max_step_reached', index=8, 93 | number=9, type=8, cpp_type=7, label=1, 94 | has_default_value=False, default_value=False, 95 | message_type=None, enum_type=None, containing_type=None, 96 | is_extension=False, extension_scope=None, 97 | options=None, file=DESCRIPTOR), 98 | _descriptor.FieldDescriptor( 99 | name='id', full_name='communicator_objects.AgentInfoProto.id', index=9, 100 | number=10, type=5, cpp_type=1, label=1, 101 | has_default_value=False, default_value=0, 102 | message_type=None, enum_type=None, containing_type=None, 103 | is_extension=False, extension_scope=None, 104 | options=None, file=DESCRIPTOR), 105 | ], 106 | extensions=[ 107 | ], 108 | nested_types=[], 109 | enum_types=[ 110 | ], 111 | options=None, 112 | is_extendable=False, 113 | syntax='proto3', 114 | extension_ranges=[], 115 | oneofs=[ 116 | ], 117 | serialized_start=70, 118 | serialized_end=323, 119 | ) 120 | 121 | DESCRIPTOR.message_types_by_name['AgentInfoProto'] = _AGENTINFOPROTO 122 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 123 | 124 | AgentInfoProto = _reflection.GeneratedProtocolMessageType('AgentInfoProto', (_message.Message,), dict( 125 | DESCRIPTOR = _AGENTINFOPROTO, 126 | __module__ = 
'communicator_objects.agent_info_proto_pb2' 127 | # @@protoc_insertion_point(class_scope:communicator_objects.AgentInfoProto) 128 | )) 129 | _sym_db.RegisterMessage(AgentInfoProto) 130 | 131 | 132 | DESCRIPTOR.has_options = True 133 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 134 | # @@protoc_insertion_point(module_scope) 135 | -------------------------------------------------------------------------------- /python/communicator_objects/brain_parameters_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/brain_parameters_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2 17 | from communicator_objects import brain_type_proto_pb2 as communicator__objects_dot_brain__type__proto__pb2 18 | from communicator_objects import space_type_proto_pb2 as communicator__objects_dot_space__type__proto__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='communicator_objects/brain_parameters_proto.proto', 23 | package='communicator_objects', 24 | syntax='proto3', 25 | 
serialized_pb=_b('\n1communicator_objects/brain_parameters_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto\x1a+communicator_objects/brain_type_proto.proto\x1a+communicator_objects/space_type_proto.proto\"\xc6\x03\n\x14\x42rainParametersProto\x12\x1f\n\x17vector_observation_size\x18\x01 \x01(\x05\x12\'\n\x1fnum_stacked_vector_observations\x18\x02 \x01(\x05\x12\x1a\n\x12vector_action_size\x18\x03 \x01(\x05\x12\x41\n\x12\x63\x61mera_resolutions\x18\x04 \x03(\x0b\x32%.communicator_objects.ResolutionProto\x12\"\n\x1avector_action_descriptions\x18\x05 \x03(\t\x12\x46\n\x18vector_action_space_type\x18\x06 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12K\n\x1dvector_observation_space_type\x18\x07 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12\x12\n\nbrain_name\x18\x08 \x01(\t\x12\x38\n\nbrain_type\x18\t \x01(\x0e\x32$.communicator_objects.BrainTypeProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,communicator__objects_dot_brain__type__proto__pb2.DESCRIPTOR,communicator__objects_dot_space__type__proto__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _BRAINPARAMETERSPROTO = _descriptor.Descriptor( 33 | name='BrainParametersProto', 34 | full_name='communicator_objects.BrainParametersProto', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='vector_observation_size', full_name='communicator_objects.BrainParametersProto.vector_observation_size', index=0, 41 | number=1, type=5, cpp_type=1, label=1, 42 | has_default_value=False, default_value=0, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None, file=DESCRIPTOR), 46 | _descriptor.FieldDescriptor( 47 | name='num_stacked_vector_observations', full_name='communicator_objects.BrainParametersProto.num_stacked_vector_observations', 
index=1, 48 | number=2, type=5, cpp_type=1, label=1, 49 | has_default_value=False, default_value=0, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | options=None, file=DESCRIPTOR), 53 | _descriptor.FieldDescriptor( 54 | name='vector_action_size', full_name='communicator_objects.BrainParametersProto.vector_action_size', index=2, 55 | number=3, type=5, cpp_type=1, label=1, 56 | has_default_value=False, default_value=0, 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | options=None, file=DESCRIPTOR), 60 | _descriptor.FieldDescriptor( 61 | name='camera_resolutions', full_name='communicator_objects.BrainParametersProto.camera_resolutions', index=3, 62 | number=4, type=11, cpp_type=10, label=3, 63 | has_default_value=False, default_value=[], 64 | message_type=None, enum_type=None, containing_type=None, 65 | is_extension=False, extension_scope=None, 66 | options=None, file=DESCRIPTOR), 67 | _descriptor.FieldDescriptor( 68 | name='vector_action_descriptions', full_name='communicator_objects.BrainParametersProto.vector_action_descriptions', index=4, 69 | number=5, type=9, cpp_type=9, label=3, 70 | has_default_value=False, default_value=[], 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | options=None, file=DESCRIPTOR), 74 | _descriptor.FieldDescriptor( 75 | name='vector_action_space_type', full_name='communicator_objects.BrainParametersProto.vector_action_space_type', index=5, 76 | number=6, type=14, cpp_type=8, label=1, 77 | has_default_value=False, default_value=0, 78 | message_type=None, enum_type=None, containing_type=None, 79 | is_extension=False, extension_scope=None, 80 | options=None, file=DESCRIPTOR), 81 | _descriptor.FieldDescriptor( 82 | name='vector_observation_space_type', full_name='communicator_objects.BrainParametersProto.vector_observation_space_type', index=6, 83 | 
number=7, type=14, cpp_type=8, label=1, 84 | has_default_value=False, default_value=0, 85 | message_type=None, enum_type=None, containing_type=None, 86 | is_extension=False, extension_scope=None, 87 | options=None, file=DESCRIPTOR), 88 | _descriptor.FieldDescriptor( 89 | name='brain_name', full_name='communicator_objects.BrainParametersProto.brain_name', index=7, 90 | number=8, type=9, cpp_type=9, label=1, 91 | has_default_value=False, default_value=_b("").decode('utf-8'), 92 | message_type=None, enum_type=None, containing_type=None, 93 | is_extension=False, extension_scope=None, 94 | options=None, file=DESCRIPTOR), 95 | _descriptor.FieldDescriptor( 96 | name='brain_type', full_name='communicator_objects.BrainParametersProto.brain_type', index=8, 97 | number=9, type=14, cpp_type=8, label=1, 98 | has_default_value=False, default_value=0, 99 | message_type=None, enum_type=None, containing_type=None, 100 | is_extension=False, extension_scope=None, 101 | options=None, file=DESCRIPTOR), 102 | ], 103 | extensions=[ 104 | ], 105 | nested_types=[], 106 | enum_types=[ 107 | ], 108 | options=None, 109 | is_extendable=False, 110 | syntax='proto3', 111 | extension_ranges=[], 112 | oneofs=[ 113 | ], 114 | serialized_start=211, 115 | serialized_end=665, 116 | ) 117 | 118 | _BRAINPARAMETERSPROTO.fields_by_name['camera_resolutions'].message_type = communicator__objects_dot_resolution__proto__pb2._RESOLUTIONPROTO 119 | _BRAINPARAMETERSPROTO.fields_by_name['vector_action_space_type'].enum_type = communicator__objects_dot_space__type__proto__pb2._SPACETYPEPROTO 120 | _BRAINPARAMETERSPROTO.fields_by_name['vector_observation_space_type'].enum_type = communicator__objects_dot_space__type__proto__pb2._SPACETYPEPROTO 121 | _BRAINPARAMETERSPROTO.fields_by_name['brain_type'].enum_type = communicator__objects_dot_brain__type__proto__pb2._BRAINTYPEPROTO 122 | DESCRIPTOR.message_types_by_name['BrainParametersProto'] = _BRAINPARAMETERSPROTO 123 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 124 | 
125 | BrainParametersProto = _reflection.GeneratedProtocolMessageType('BrainParametersProto', (_message.Message,), dict( 126 | DESCRIPTOR = _BRAINPARAMETERSPROTO, 127 | __module__ = 'communicator_objects.brain_parameters_proto_pb2' 128 | # @@protoc_insertion_point(class_scope:communicator_objects.BrainParametersProto) 129 | )) 130 | _sym_db.RegisterMessage(BrainParametersProto) 131 | 132 | 133 | DESCRIPTOR.has_options = True 134 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 135 | # @@protoc_insertion_point(module_scope) 136 | -------------------------------------------------------------------------------- /python/communicator_objects/brain_type_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/brain_type_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/brain_type_proto.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | 
serialized_pb=_b('\n+communicator_objects/brain_type_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto*G\n\x0e\x42rainTypeProto\x12\n\n\x06Player\x10\x00\x12\r\n\tHeuristic\x10\x01\x12\x0c\n\x08\x45xternal\x10\x02\x12\x0c\n\x08Internal\x10\x03\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,]) 27 | 28 | _BRAINTYPEPROTO = _descriptor.EnumDescriptor( 29 | name='BrainTypeProto', 30 | full_name='communicator_objects.BrainTypeProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | values=[ 34 | _descriptor.EnumValueDescriptor( 35 | name='Player', index=0, number=0, 36 | options=None, 37 | type=None), 38 | _descriptor.EnumValueDescriptor( 39 | name='Heuristic', index=1, number=1, 40 | options=None, 41 | type=None), 42 | _descriptor.EnumValueDescriptor( 43 | name='External', index=2, number=2, 44 | options=None, 45 | type=None), 46 | _descriptor.EnumValueDescriptor( 47 | name='Internal', index=3, number=3, 48 | options=None, 49 | type=None), 50 | ], 51 | containing_type=None, 52 | options=None, 53 | serialized_start=114, 54 | serialized_end=185, 55 | ) 56 | _sym_db.RegisterEnumDescriptor(_BRAINTYPEPROTO) 57 | 58 | BrainTypeProto = enum_type_wrapper.EnumTypeWrapper(_BRAINTYPEPROTO) 59 | Player = 0 60 | Heuristic = 1 61 | External = 2 62 | Internal = 3 63 | 64 | 65 | DESCRIPTOR.enum_types_by_name['BrainTypeProto'] = _BRAINTYPEPROTO 66 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 67 | 68 | 69 | DESCRIPTOR.has_options = True 70 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 71 | # @@protoc_insertion_point(module_scope) 72 | -------------------------------------------------------------------------------- /python/communicator_objects/command_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the 
protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/command_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='communicator_objects/command_proto.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_pb=_b('\n(communicator_objects/command_proto.proto\x12\x14\x63ommunicator_objects*-\n\x0c\x43ommandProto\x12\x08\n\x04STEP\x10\x00\x12\t\n\x05RESET\x10\x01\x12\x08\n\x04QUIT\x10\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | ) 25 | 26 | _COMMANDPROTO = _descriptor.EnumDescriptor( 27 | name='CommandProto', 28 | full_name='communicator_objects.CommandProto', 29 | filename=None, 30 | file=DESCRIPTOR, 31 | values=[ 32 | _descriptor.EnumValueDescriptor( 33 | name='STEP', index=0, number=0, 34 | options=None, 35 | type=None), 36 | _descriptor.EnumValueDescriptor( 37 | name='RESET', index=1, number=1, 38 | options=None, 39 | type=None), 40 | _descriptor.EnumValueDescriptor( 41 | name='QUIT', index=2, number=2, 42 | options=None, 43 | type=None), 44 | ], 45 | containing_type=None, 46 | options=None, 47 | serialized_start=66, 48 | serialized_end=111, 49 | ) 50 | _sym_db.RegisterEnumDescriptor(_COMMANDPROTO) 51 | 52 | CommandProto = enum_type_wrapper.EnumTypeWrapper(_COMMANDPROTO) 53 | STEP = 0 54 | RESET = 1 55 | QUIT = 2 56 | 57 | 58 | DESCRIPTOR.enum_types_by_name['CommandProto'] = _COMMANDPROTO 59 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 60 
| 61 | 62 | DESCRIPTOR.has_options = True 63 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 64 | # @@protoc_insertion_point(module_scope) 65 | -------------------------------------------------------------------------------- /python/communicator_objects/engine_configuration_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/engine_configuration_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/engine_configuration_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n5communicator_objects/engine_configuration_proto.proto\x12\x14\x63ommunicator_objects\"\x95\x01\n\x18\x45ngineConfigurationProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x15\n\rquality_level\x18\x03 \x01(\x05\x12\x12\n\ntime_scale\x18\x04 \x01(\x02\x12\x19\n\x11target_frame_rate\x18\x05 \x01(\x05\x12\x14\n\x0cshow_monitor\x18\x06 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _ENGINECONFIGURATIONPROTO = _descriptor.Descriptor( 29 | name='EngineConfigurationProto', 30 | full_name='communicator_objects.EngineConfigurationProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | 
_descriptor.FieldDescriptor( 36 | name='width', full_name='communicator_objects.EngineConfigurationProto.width', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='height', full_name='communicator_objects.EngineConfigurationProto.height', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='quality_level', full_name='communicator_objects.EngineConfigurationProto.quality_level', index=2, 51 | number=3, type=5, cpp_type=1, label=1, 52 | has_default_value=False, default_value=0, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | _descriptor.FieldDescriptor( 57 | name='time_scale', full_name='communicator_objects.EngineConfigurationProto.time_scale', index=3, 58 | number=4, type=2, cpp_type=6, label=1, 59 | has_default_value=False, default_value=float(0), 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None, file=DESCRIPTOR), 63 | _descriptor.FieldDescriptor( 64 | name='target_frame_rate', full_name='communicator_objects.EngineConfigurationProto.target_frame_rate', index=4, 65 | number=5, type=5, cpp_type=1, label=1, 66 | has_default_value=False, default_value=0, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None, file=DESCRIPTOR), 70 | _descriptor.FieldDescriptor( 71 | name='show_monitor', 
full_name='communicator_objects.EngineConfigurationProto.show_monitor', index=5, 72 | number=6, type=8, cpp_type=7, label=1, 73 | has_default_value=False, default_value=False, 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None, file=DESCRIPTOR), 77 | ], 78 | extensions=[ 79 | ], 80 | nested_types=[], 81 | enum_types=[ 82 | ], 83 | options=None, 84 | is_extendable=False, 85 | syntax='proto3', 86 | extension_ranges=[], 87 | oneofs=[ 88 | ], 89 | serialized_start=80, 90 | serialized_end=229, 91 | ) 92 | 93 | DESCRIPTOR.message_types_by_name['EngineConfigurationProto'] = _ENGINECONFIGURATIONPROTO 94 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 95 | 96 | EngineConfigurationProto = _reflection.GeneratedProtocolMessageType('EngineConfigurationProto', (_message.Message,), dict( 97 | DESCRIPTOR = _ENGINECONFIGURATIONPROTO, 98 | __module__ = 'communicator_objects.engine_configuration_proto_pb2' 99 | # @@protoc_insertion_point(class_scope:communicator_objects.EngineConfigurationProto) 100 | )) 101 | _sym_db.RegisterMessage(EngineConfigurationProto) 102 | 103 | 104 | DESCRIPTOR.has_options = True 105 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 106 | # @@protoc_insertion_point(module_scope) 107 | -------------------------------------------------------------------------------- /python/communicator_objects/environment_parameters_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/environment_parameters_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/environment_parameters_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n7communicator_objects/environment_parameters_proto.proto\x12\x14\x63ommunicator_objects\"\xb5\x01\n\x1a\x45nvironmentParametersProto\x12_\n\x10\x66loat_parameters\x18\x01 \x03(\x0b\x32\x45.communicator_objects.EnvironmentParametersProto.FloatParametersEntry\x1a\x36\n\x14\x46loatParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY = _descriptor.Descriptor( 29 | name='FloatParametersEntry', 30 | full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='key', full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.key', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=False, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='value', 
full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.value', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=False, default_value=float(0), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 56 | is_extendable=False, 57 | syntax='proto3', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=209, 62 | serialized_end=263, 63 | ) 64 | 65 | _ENVIRONMENTPARAMETERSPROTO = _descriptor.Descriptor( 66 | name='EnvironmentParametersProto', 67 | full_name='communicator_objects.EnvironmentParametersProto', 68 | filename=None, 69 | file=DESCRIPTOR, 70 | containing_type=None, 71 | fields=[ 72 | _descriptor.FieldDescriptor( 73 | name='float_parameters', full_name='communicator_objects.EnvironmentParametersProto.float_parameters', index=0, 74 | number=1, type=11, cpp_type=10, label=3, 75 | has_default_value=False, default_value=[], 76 | message_type=None, enum_type=None, containing_type=None, 77 | is_extension=False, extension_scope=None, 78 | options=None, file=DESCRIPTOR), 79 | ], 80 | extensions=[ 81 | ], 82 | nested_types=[_ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, ], 83 | enum_types=[ 84 | ], 85 | options=None, 86 | is_extendable=False, 87 | syntax='proto3', 88 | extension_ranges=[], 89 | oneofs=[ 90 | ], 91 | serialized_start=82, 92 | serialized_end=263, 93 | ) 94 | 95 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.containing_type = _ENVIRONMENTPARAMETERSPROTO 96 | _ENVIRONMENTPARAMETERSPROTO.fields_by_name['float_parameters'].message_type = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY 97 | DESCRIPTOR.message_types_by_name['EnvironmentParametersProto'] = _ENVIRONMENTPARAMETERSPROTO 98 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 99 
| 100 | EnvironmentParametersProto = _reflection.GeneratedProtocolMessageType('EnvironmentParametersProto', (_message.Message,), dict( 101 | 102 | FloatParametersEntry = _reflection.GeneratedProtocolMessageType('FloatParametersEntry', (_message.Message,), dict( 103 | DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, 104 | __module__ = 'communicator_objects.environment_parameters_proto_pb2' 105 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto.FloatParametersEntry) 106 | )) 107 | , 108 | DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO, 109 | __module__ = 'communicator_objects.environment_parameters_proto_pb2' 110 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto) 111 | )) 112 | _sym_db.RegisterMessage(EnvironmentParametersProto) 113 | _sym_db.RegisterMessage(EnvironmentParametersProto.FloatParametersEntry) 114 | 115 | 116 | DESCRIPTOR.has_options = True 117 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 118 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.has_options = True 119 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 120 | # @@protoc_insertion_point(module_scope) 121 | -------------------------------------------------------------------------------- /python/communicator_objects/header_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/header.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/header.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n!communicator_objects/header.proto\x12\x14\x63ommunicator_objects\")\n\x06Header\x12\x0e\n\x06status\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _HEADER = _descriptor.Descriptor( 29 | name='Header', 30 | full_name='communicator_objects.Header', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='status', full_name='communicator_objects.Header.status', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='message', full_name='communicator_objects.Header.message', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=None, 56 | is_extendable=False, 57 | 
syntax='proto3', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=59, 62 | serialized_end=100, 63 | ) 64 | 65 | DESCRIPTOR.message_types_by_name['Header'] = _HEADER 66 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 67 | 68 | Header = _reflection.GeneratedProtocolMessageType('Header', (_message.Message,), dict( 69 | DESCRIPTOR = _HEADER, 70 | __module__ = 'communicator_objects.header_pb2' 71 | # @@protoc_insertion_point(class_scope:communicator_objects.Header) 72 | )) 73 | _sym_db.RegisterMessage(Header) 74 | 75 | 76 | DESCRIPTOR.has_options = True 77 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 78 | # @@protoc_insertion_point(module_scope) 79 | -------------------------------------------------------------------------------- /python/communicator_objects/resolution_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/resolution_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/resolution_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n+communicator_objects/resolution_proto.proto\x12\x14\x63ommunicator_objects\"D\n\x0fResolutionProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x12\n\ngray_scale\x18\x03 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _RESOLUTIONPROTO = _descriptor.Descriptor( 29 | name='ResolutionProto', 30 | full_name='communicator_objects.ResolutionProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='width', full_name='communicator_objects.ResolutionProto.width', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='height', full_name='communicator_objects.ResolutionProto.height', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 
| name='gray_scale', full_name='communicator_objects.ResolutionProto.gray_scale', index=2, 51 | number=3, type=8, cpp_type=7, label=1, 52 | has_default_value=False, default_value=False, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | syntax='proto3', 65 | extension_ranges=[], 66 | oneofs=[ 67 | ], 68 | serialized_start=69, 69 | serialized_end=137, 70 | ) 71 | 72 | DESCRIPTOR.message_types_by_name['ResolutionProto'] = _RESOLUTIONPROTO 73 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 74 | 75 | ResolutionProto = _reflection.GeneratedProtocolMessageType('ResolutionProto', (_message.Message,), dict( 76 | DESCRIPTOR = _RESOLUTIONPROTO, 77 | __module__ = 'communicator_objects.resolution_proto_pb2' 78 | # @@protoc_insertion_point(class_scope:communicator_objects.ResolutionProto) 79 | )) 80 | _sym_db.RegisterMessage(ResolutionProto) 81 | 82 | 83 | DESCRIPTOR.has_options = True 84 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 85 | # @@protoc_insertion_point(module_scope) 86 | -------------------------------------------------------------------------------- /python/communicator_objects/space_type_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/space_type_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/space_type_proto.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n+communicator_objects/space_type_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto*.\n\x0eSpaceTypeProto\x12\x0c\n\x08\x64iscrete\x10\x00\x12\x0e\n\ncontinuous\x10\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,]) 27 | 28 | _SPACETYPEPROTO = _descriptor.EnumDescriptor( 29 | name='SpaceTypeProto', 30 | full_name='communicator_objects.SpaceTypeProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | values=[ 34 | _descriptor.EnumValueDescriptor( 35 | name='discrete', index=0, number=0, 36 | options=None, 37 | type=None), 38 | _descriptor.EnumValueDescriptor( 39 | name='continuous', index=1, number=1, 40 | options=None, 41 | type=None), 42 | ], 43 | containing_type=None, 44 | options=None, 45 | serialized_start=114, 46 | serialized_end=160, 47 | ) 48 | _sym_db.RegisterEnumDescriptor(_SPACETYPEPROTO) 49 | 50 | SpaceTypeProto = enum_type_wrapper.EnumTypeWrapper(_SPACETYPEPROTO) 51 | discrete = 0 52 | continuous = 1 53 | 54 | 55 | 
DESCRIPTOR.enum_types_by_name['SpaceTypeProto'] = _SPACETYPEPROTO 56 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 57 | 58 | 59 | DESCRIPTOR.has_options = True 60 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 61 | # @@protoc_insertion_point(module_scope) 62 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_input_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_input.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_rl_input_pb2 as communicator__objects_dot_unity__rl__input__pb2 17 | from communicator_objects import unity_rl_initialization_input_pb2 as communicator__objects_dot_unity__rl__initialization__input__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/unity_input.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n&communicator_objects/unity_input.proto\x12\x14\x63ommunicator_objects\x1a)communicator_objects/unity_rl_input.proto\x1a\x38\x63ommunicator_objects/unity_rl_initialization_input.proto\"\xb0\x01\n\nUnityInput\x12\x34\n\x08rl_input\x18\x01 \x01(\x0b\x32\".communicator_objects.UnityRLInput\x12Q\n\x17rl_initialization_input\x18\x02 
\x01(\x0b\x32\x30.communicator_objects.UnityRLInitializationInput\x12\x19\n\x11\x63ustom_data_input\x18\x03 \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_unity__rl__input__pb2.DESCRIPTOR,communicator__objects_dot_unity__rl__initialization__input__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _UNITYINPUT = _descriptor.Descriptor( 32 | name='UnityInput', 33 | full_name='communicator_objects.UnityInput', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='rl_input', full_name='communicator_objects.UnityInput.rl_input', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='rl_initialization_input', full_name='communicator_objects.UnityInput.rl_initialization_input', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None, file=DESCRIPTOR), 52 | _descriptor.FieldDescriptor( 53 | name='custom_data_input', full_name='communicator_objects.UnityInput.custom_data_input', index=2, 54 | number=3, type=5, cpp_type=1, label=1, 55 | has_default_value=False, default_value=0, 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None, file=DESCRIPTOR), 59 | ], 60 | extensions=[ 61 | ], 62 | nested_types=[], 63 | enum_types=[ 64 | ], 65 | options=None, 66 | is_extendable=False, 67 | syntax='proto3', 68 | extension_ranges=[], 69 | oneofs=[ 70 | ], 71 | serialized_start=166, 72 | serialized_end=342, 73 | ) 74 | 75 | 
_UNITYINPUT.fields_by_name['rl_input'].message_type = communicator__objects_dot_unity__rl__input__pb2._UNITYRLINPUT 76 | _UNITYINPUT.fields_by_name['rl_initialization_input'].message_type = communicator__objects_dot_unity__rl__initialization__input__pb2._UNITYRLINITIALIZATIONINPUT 77 | DESCRIPTOR.message_types_by_name['UnityInput'] = _UNITYINPUT 78 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 79 | 80 | UnityInput = _reflection.GeneratedProtocolMessageType('UnityInput', (_message.Message,), dict( 81 | DESCRIPTOR = _UNITYINPUT, 82 | __module__ = 'communicator_objects.unity_input_pb2' 83 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityInput) 84 | )) 85 | _sym_db.RegisterMessage(UnityInput) 86 | 87 | 88 | DESCRIPTOR.has_options = True 89 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 90 | # @@protoc_insertion_point(module_scope) 91 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_message_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/unity_message.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_output_pb2 as communicator__objects_dot_unity__output__pb2 17 | from communicator_objects import unity_input_pb2 as communicator__objects_dot_unity__input__pb2 18 | from communicator_objects import header_pb2 as communicator__objects_dot_header__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='communicator_objects/unity_message.proto', 23 | package='communicator_objects', 24 | syntax='proto3', 25 | serialized_pb=_b('\n(communicator_objects/unity_message.proto\x12\x14\x63ommunicator_objects\x1a\'communicator_objects/unity_output.proto\x1a&communicator_objects/unity_input.proto\x1a!communicator_objects/header.proto\"\xac\x01\n\x0cUnityMessage\x12,\n\x06header\x18\x01 \x01(\x0b\x32\x1c.communicator_objects.Header\x12\x37\n\x0cunity_output\x18\x02 \x01(\x0b\x32!.communicator_objects.UnityOutput\x12\x35\n\x0bunity_input\x18\x03 \x01(\x0b\x32 .communicator_objects.UnityInputB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[communicator__objects_dot_unity__output__pb2.DESCRIPTOR,communicator__objects_dot_unity__input__pb2.DESCRIPTOR,communicator__objects_dot_header__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _UNITYMESSAGE = _descriptor.Descriptor( 33 | name='UnityMessage', 34 | full_name='communicator_objects.UnityMessage', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | 
name='header', full_name='communicator_objects.UnityMessage.header', index=0, 41 | number=1, type=11, cpp_type=10, label=1, 42 | has_default_value=False, default_value=None, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None, file=DESCRIPTOR), 46 | _descriptor.FieldDescriptor( 47 | name='unity_output', full_name='communicator_objects.UnityMessage.unity_output', index=1, 48 | number=2, type=11, cpp_type=10, label=1, 49 | has_default_value=False, default_value=None, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | options=None, file=DESCRIPTOR), 53 | _descriptor.FieldDescriptor( 54 | name='unity_input', full_name='communicator_objects.UnityMessage.unity_input', index=2, 55 | number=3, type=11, cpp_type=10, label=1, 56 | has_default_value=False, default_value=None, 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | options=None, file=DESCRIPTOR), 60 | ], 61 | extensions=[ 62 | ], 63 | nested_types=[], 64 | enum_types=[ 65 | ], 66 | options=None, 67 | is_extendable=False, 68 | syntax='proto3', 69 | extension_ranges=[], 70 | oneofs=[ 71 | ], 72 | serialized_start=183, 73 | serialized_end=355, 74 | ) 75 | 76 | _UNITYMESSAGE.fields_by_name['header'].message_type = communicator__objects_dot_header__pb2._HEADER 77 | _UNITYMESSAGE.fields_by_name['unity_output'].message_type = communicator__objects_dot_unity__output__pb2._UNITYOUTPUT 78 | _UNITYMESSAGE.fields_by_name['unity_input'].message_type = communicator__objects_dot_unity__input__pb2._UNITYINPUT 79 | DESCRIPTOR.message_types_by_name['UnityMessage'] = _UNITYMESSAGE 80 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 81 | 82 | UnityMessage = _reflection.GeneratedProtocolMessageType('UnityMessage', (_message.Message,), dict( 83 | DESCRIPTOR = _UNITYMESSAGE, 84 | __module__ = 'communicator_objects.unity_message_pb2' 85 | # 
@@protoc_insertion_point(class_scope:communicator_objects.UnityMessage) 86 | )) 87 | _sym_db.RegisterMessage(UnityMessage) 88 | 89 | 90 | DESCRIPTOR.has_options = True 91 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 92 | # @@protoc_insertion_point(module_scope) 93 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_output_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_output.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_rl_output_pb2 as communicator__objects_dot_unity__rl__output__pb2 17 | from communicator_objects import unity_rl_initialization_output_pb2 as communicator__objects_dot_unity__rl__initialization__output__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/unity_output.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n\'communicator_objects/unity_output.proto\x12\x14\x63ommunicator_objects\x1a*communicator_objects/unity_rl_output.proto\x1a\x39\x63ommunicator_objects/unity_rl_initialization_output.proto\"\xb6\x01\n\x0bUnityOutput\x12\x36\n\trl_output\x18\x01 \x01(\x0b\x32#.communicator_objects.UnityRLOutput\x12S\n\x18rl_initialization_output\x18\x02 
\x01(\x0b\x32\x31.communicator_objects.UnityRLInitializationOutput\x12\x1a\n\x12\x63ustom_data_output\x18\x03 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_unity__rl__output__pb2.DESCRIPTOR,communicator__objects_dot_unity__rl__initialization__output__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _UNITYOUTPUT = _descriptor.Descriptor( 32 | name='UnityOutput', 33 | full_name='communicator_objects.UnityOutput', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='rl_output', full_name='communicator_objects.UnityOutput.rl_output', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='rl_initialization_output', full_name='communicator_objects.UnityOutput.rl_initialization_output', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None, file=DESCRIPTOR), 52 | _descriptor.FieldDescriptor( 53 | name='custom_data_output', full_name='communicator_objects.UnityOutput.custom_data_output', index=2, 54 | number=3, type=9, cpp_type=9, label=1, 55 | has_default_value=False, default_value=_b("").decode('utf-8'), 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None, file=DESCRIPTOR), 59 | ], 60 | extensions=[ 61 | ], 62 | nested_types=[], 63 | enum_types=[ 64 | ], 65 | options=None, 66 | is_extendable=False, 67 | syntax='proto3', 68 | extension_ranges=[], 69 | oneofs=[ 70 | ], 71 | serialized_start=169, 72 | serialized_end=351, 73 | ) 74 | 75 | 
_UNITYOUTPUT.fields_by_name['rl_output'].message_type = communicator__objects_dot_unity__rl__output__pb2._UNITYRLOUTPUT 76 | _UNITYOUTPUT.fields_by_name['rl_initialization_output'].message_type = communicator__objects_dot_unity__rl__initialization__output__pb2._UNITYRLINITIALIZATIONOUTPUT 77 | DESCRIPTOR.message_types_by_name['UnityOutput'] = _UNITYOUTPUT 78 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 79 | 80 | UnityOutput = _reflection.GeneratedProtocolMessageType('UnityOutput', (_message.Message,), dict( 81 | DESCRIPTOR = _UNITYOUTPUT, 82 | __module__ = 'communicator_objects.unity_output_pb2' 83 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityOutput) 84 | )) 85 | _sym_db.RegisterMessage(UnityOutput) 86 | 87 | 88 | DESCRIPTOR.has_options = True 89 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 90 | # @@protoc_insertion_point(module_scope) 91 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_rl_initialization_input_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/unity_rl_initialization_input.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/unity_rl_initialization_input.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n8communicator_objects/unity_rl_initialization_input.proto\x12\x14\x63ommunicator_objects\"*\n\x1aUnityRLInitializationInput\x12\x0c\n\x04seed\x18\x01 \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _UNITYRLINITIALIZATIONINPUT = _descriptor.Descriptor( 29 | name='UnityRLInitializationInput', 30 | full_name='communicator_objects.UnityRLInitializationInput', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='seed', full_name='communicator_objects.UnityRLInitializationInput.seed', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | ], 43 | extensions=[ 44 | ], 45 | nested_types=[], 46 | enum_types=[ 47 | ], 48 | options=None, 49 | is_extendable=False, 50 | syntax='proto3', 51 | extension_ranges=[], 52 | oneofs=[ 53 | ], 54 | serialized_start=82, 55 | serialized_end=124, 56 | ) 57 | 58 | DESCRIPTOR.message_types_by_name['UnityRLInitializationInput'] = _UNITYRLINITIALIZATIONINPUT 59 | 
_sym_db.RegisterFileDescriptor(DESCRIPTOR) 60 | 61 | UnityRLInitializationInput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationInput', (_message.Message,), dict( 62 | DESCRIPTOR = _UNITYRLINITIALIZATIONINPUT, 63 | __module__ = 'communicator_objects.unity_rl_initialization_input_pb2' 64 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationInput) 65 | )) 66 | _sym_db.RegisterMessage(UnityRLInitializationInput) 67 | 68 | 69 | DESCRIPTOR.has_options = True 70 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 71 | # @@protoc_insertion_point(module_scope) 72 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_rl_initialization_output_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_rl_initialization_output.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import brain_parameters_proto_pb2 as communicator__objects_dot_brain__parameters__proto__pb2 17 | from communicator_objects import environment_parameters_proto_pb2 as communicator__objects_dot_environment__parameters__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/unity_rl_initialization_output.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | 
serialized_pb=_b('\n9communicator_objects/unity_rl_initialization_output.proto\x12\x14\x63ommunicator_objects\x1a\x31\x63ommunicator_objects/brain_parameters_proto.proto\x1a\x37\x63ommunicator_objects/environment_parameters_proto.proto\"\xe6\x01\n\x1bUnityRLInitializationOutput\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x10\n\x08log_path\x18\x03 \x01(\t\x12\x44\n\x10\x62rain_parameters\x18\x05 \x03(\x0b\x32*.communicator_objects.BrainParametersProto\x12P\n\x16\x65nvironment_parameters\x18\x06 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_brain__parameters__proto__pb2.DESCRIPTOR,communicator__objects_dot_environment__parameters__proto__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _UNITYRLINITIALIZATIONOUTPUT = _descriptor.Descriptor( 32 | name='UnityRLInitializationOutput', 33 | full_name='communicator_objects.UnityRLInitializationOutput', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='name', full_name='communicator_objects.UnityRLInitializationOutput.name', index=0, 40 | number=1, type=9, cpp_type=9, label=1, 41 | has_default_value=False, default_value=_b("").decode('utf-8'), 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='version', full_name='communicator_objects.UnityRLInitializationOutput.version', index=1, 47 | number=2, type=9, cpp_type=9, label=1, 48 | has_default_value=False, default_value=_b("").decode('utf-8'), 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None, file=DESCRIPTOR), 52 | _descriptor.FieldDescriptor( 53 | name='log_path', full_name='communicator_objects.UnityRLInitializationOutput.log_path', 
index=2, 54 | number=3, type=9, cpp_type=9, label=1, 55 | has_default_value=False, default_value=_b("").decode('utf-8'), 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None, file=DESCRIPTOR), 59 | _descriptor.FieldDescriptor( 60 | name='brain_parameters', full_name='communicator_objects.UnityRLInitializationOutput.brain_parameters', index=3, 61 | number=5, type=11, cpp_type=10, label=3, 62 | has_default_value=False, default_value=[], 63 | message_type=None, enum_type=None, containing_type=None, 64 | is_extension=False, extension_scope=None, 65 | options=None, file=DESCRIPTOR), 66 | _descriptor.FieldDescriptor( 67 | name='environment_parameters', full_name='communicator_objects.UnityRLInitializationOutput.environment_parameters', index=4, 68 | number=6, type=11, cpp_type=10, label=1, 69 | has_default_value=False, default_value=None, 70 | message_type=None, enum_type=None, containing_type=None, 71 | is_extension=False, extension_scope=None, 72 | options=None, file=DESCRIPTOR), 73 | ], 74 | extensions=[ 75 | ], 76 | nested_types=[], 77 | enum_types=[ 78 | ], 79 | options=None, 80 | is_extendable=False, 81 | syntax='proto3', 82 | extension_ranges=[], 83 | oneofs=[ 84 | ], 85 | serialized_start=192, 86 | serialized_end=422, 87 | ) 88 | 89 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name['brain_parameters'].message_type = communicator__objects_dot_brain__parameters__proto__pb2._BRAINPARAMETERSPROTO 90 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name['environment_parameters'].message_type = communicator__objects_dot_environment__parameters__proto__pb2._ENVIRONMENTPARAMETERSPROTO 91 | DESCRIPTOR.message_types_by_name['UnityRLInitializationOutput'] = _UNITYRLINITIALIZATIONOUTPUT 92 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 93 | 94 | UnityRLInitializationOutput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationOutput', (_message.Message,), dict( 95 | DESCRIPTOR = 
_UNITYRLINITIALIZATIONOUTPUT, 96 | __module__ = 'communicator_objects.unity_rl_initialization_output_pb2' 97 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationOutput) 98 | )) 99 | _sym_db.RegisterMessage(UnityRLInitializationOutput) 100 | 101 | 102 | DESCRIPTOR.has_options = True 103 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 104 | # @@protoc_insertion_point(module_scope) 105 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_rl_input_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_rl_input.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import agent_action_proto_pb2 as communicator__objects_dot_agent__action__proto__pb2 17 | from communicator_objects import environment_parameters_proto_pb2 as communicator__objects_dot_environment__parameters__proto__pb2 18 | from communicator_objects import command_proto_pb2 as communicator__objects_dot_command__proto__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='communicator_objects/unity_rl_input.proto', 23 | package='communicator_objects', 24 | syntax='proto3', 25 | 
serialized_pb=_b('\n)communicator_objects/unity_rl_input.proto\x12\x14\x63ommunicator_objects\x1a-communicator_objects/agent_action_proto.proto\x1a\x37\x63ommunicator_objects/environment_parameters_proto.proto\x1a(communicator_objects/command_proto.proto\"\xb4\x03\n\x0cUnityRLInput\x12K\n\ragent_actions\x18\x01 \x03(\x0b\x32\x34.communicator_objects.UnityRLInput.AgentActionsEntry\x12P\n\x16\x65nvironment_parameters\x18\x02 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProto\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x33\n\x07\x63ommand\x18\x04 \x01(\x0e\x32\".communicator_objects.CommandProto\x1aM\n\x14ListAgentActionProto\x12\x35\n\x05value\x18\x01 \x03(\x0b\x32&.communicator_objects.AgentActionProto\x1al\n\x11\x41gentActionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x46\n\x05value\x18\x02 \x01(\x0b\x32\x37.communicator_objects.UnityRLInput.ListAgentActionProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[communicator__objects_dot_agent__action__proto__pb2.DESCRIPTOR,communicator__objects_dot_environment__parameters__proto__pb2.DESCRIPTOR,communicator__objects_dot_command__proto__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _UNITYRLINPUT_LISTAGENTACTIONPROTO = _descriptor.Descriptor( 33 | name='ListAgentActionProto', 34 | full_name='communicator_objects.UnityRLInput.ListAgentActionProto', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='value', full_name='communicator_objects.UnityRLInput.ListAgentActionProto.value', index=0, 41 | number=1, type=11, cpp_type=10, label=3, 42 | has_default_value=False, default_value=[], 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None, file=DESCRIPTOR), 46 | ], 47 | extensions=[ 48 | ], 49 | nested_types=[], 50 | enum_types=[ 51 | ], 52 | options=None, 53 | is_extendable=False, 54 | syntax='proto3', 55 | 
extension_ranges=[], 56 | oneofs=[ 57 | ], 58 | serialized_start=463, 59 | serialized_end=540, 60 | ) 61 | 62 | _UNITYRLINPUT_AGENTACTIONSENTRY = _descriptor.Descriptor( 63 | name='AgentActionsEntry', 64 | full_name='communicator_objects.UnityRLInput.AgentActionsEntry', 65 | filename=None, 66 | file=DESCRIPTOR, 67 | containing_type=None, 68 | fields=[ 69 | _descriptor.FieldDescriptor( 70 | name='key', full_name='communicator_objects.UnityRLInput.AgentActionsEntry.key', index=0, 71 | number=1, type=9, cpp_type=9, label=1, 72 | has_default_value=False, default_value=_b("").decode('utf-8'), 73 | message_type=None, enum_type=None, containing_type=None, 74 | is_extension=False, extension_scope=None, 75 | options=None, file=DESCRIPTOR), 76 | _descriptor.FieldDescriptor( 77 | name='value', full_name='communicator_objects.UnityRLInput.AgentActionsEntry.value', index=1, 78 | number=2, type=11, cpp_type=10, label=1, 79 | has_default_value=False, default_value=None, 80 | message_type=None, enum_type=None, containing_type=None, 81 | is_extension=False, extension_scope=None, 82 | options=None, file=DESCRIPTOR), 83 | ], 84 | extensions=[ 85 | ], 86 | nested_types=[], 87 | enum_types=[ 88 | ], 89 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 90 | is_extendable=False, 91 | syntax='proto3', 92 | extension_ranges=[], 93 | oneofs=[ 94 | ], 95 | serialized_start=542, 96 | serialized_end=650, 97 | ) 98 | 99 | _UNITYRLINPUT = _descriptor.Descriptor( 100 | name='UnityRLInput', 101 | full_name='communicator_objects.UnityRLInput', 102 | filename=None, 103 | file=DESCRIPTOR, 104 | containing_type=None, 105 | fields=[ 106 | _descriptor.FieldDescriptor( 107 | name='agent_actions', full_name='communicator_objects.UnityRLInput.agent_actions', index=0, 108 | number=1, type=11, cpp_type=10, label=3, 109 | has_default_value=False, default_value=[], 110 | message_type=None, enum_type=None, containing_type=None, 111 | is_extension=False, extension_scope=None, 112 
| options=None, file=DESCRIPTOR), 113 | _descriptor.FieldDescriptor( 114 | name='environment_parameters', full_name='communicator_objects.UnityRLInput.environment_parameters', index=1, 115 | number=2, type=11, cpp_type=10, label=1, 116 | has_default_value=False, default_value=None, 117 | message_type=None, enum_type=None, containing_type=None, 118 | is_extension=False, extension_scope=None, 119 | options=None, file=DESCRIPTOR), 120 | _descriptor.FieldDescriptor( 121 | name='is_training', full_name='communicator_objects.UnityRLInput.is_training', index=2, 122 | number=3, type=8, cpp_type=7, label=1, 123 | has_default_value=False, default_value=False, 124 | message_type=None, enum_type=None, containing_type=None, 125 | is_extension=False, extension_scope=None, 126 | options=None, file=DESCRIPTOR), 127 | _descriptor.FieldDescriptor( 128 | name='command', full_name='communicator_objects.UnityRLInput.command', index=3, 129 | number=4, type=14, cpp_type=8, label=1, 130 | has_default_value=False, default_value=0, 131 | message_type=None, enum_type=None, containing_type=None, 132 | is_extension=False, extension_scope=None, 133 | options=None, file=DESCRIPTOR), 134 | ], 135 | extensions=[ 136 | ], 137 | nested_types=[_UNITYRLINPUT_LISTAGENTACTIONPROTO, _UNITYRLINPUT_AGENTACTIONSENTRY, ], 138 | enum_types=[ 139 | ], 140 | options=None, 141 | is_extendable=False, 142 | syntax='proto3', 143 | extension_ranges=[], 144 | oneofs=[ 145 | ], 146 | serialized_start=214, 147 | serialized_end=650, 148 | ) 149 | 150 | _UNITYRLINPUT_LISTAGENTACTIONPROTO.fields_by_name['value'].message_type = communicator__objects_dot_agent__action__proto__pb2._AGENTACTIONPROTO 151 | _UNITYRLINPUT_LISTAGENTACTIONPROTO.containing_type = _UNITYRLINPUT 152 | _UNITYRLINPUT_AGENTACTIONSENTRY.fields_by_name['value'].message_type = _UNITYRLINPUT_LISTAGENTACTIONPROTO 153 | _UNITYRLINPUT_AGENTACTIONSENTRY.containing_type = _UNITYRLINPUT 154 | _UNITYRLINPUT.fields_by_name['agent_actions'].message_type = 
_UNITYRLINPUT_AGENTACTIONSENTRY 155 | _UNITYRLINPUT.fields_by_name['environment_parameters'].message_type = communicator__objects_dot_environment__parameters__proto__pb2._ENVIRONMENTPARAMETERSPROTO 156 | _UNITYRLINPUT.fields_by_name['command'].enum_type = communicator__objects_dot_command__proto__pb2._COMMANDPROTO 157 | DESCRIPTOR.message_types_by_name['UnityRLInput'] = _UNITYRLINPUT 158 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 159 | 160 | UnityRLInput = _reflection.GeneratedProtocolMessageType('UnityRLInput', (_message.Message,), dict( 161 | 162 | ListAgentActionProto = _reflection.GeneratedProtocolMessageType('ListAgentActionProto', (_message.Message,), dict( 163 | DESCRIPTOR = _UNITYRLINPUT_LISTAGENTACTIONPROTO, 164 | __module__ = 'communicator_objects.unity_rl_input_pb2' 165 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInput.ListAgentActionProto) 166 | )) 167 | , 168 | 169 | AgentActionsEntry = _reflection.GeneratedProtocolMessageType('AgentActionsEntry', (_message.Message,), dict( 170 | DESCRIPTOR = _UNITYRLINPUT_AGENTACTIONSENTRY, 171 | __module__ = 'communicator_objects.unity_rl_input_pb2' 172 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInput.AgentActionsEntry) 173 | )) 174 | , 175 | DESCRIPTOR = _UNITYRLINPUT, 176 | __module__ = 'communicator_objects.unity_rl_input_pb2' 177 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInput) 178 | )) 179 | _sym_db.RegisterMessage(UnityRLInput) 180 | _sym_db.RegisterMessage(UnityRLInput.ListAgentActionProto) 181 | _sym_db.RegisterMessage(UnityRLInput.AgentActionsEntry) 182 | 183 | 184 | DESCRIPTOR.has_options = True 185 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 186 | _UNITYRLINPUT_AGENTACTIONSENTRY.has_options = True 187 | _UNITYRLINPUT_AGENTACTIONSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 188 | # 
@@protoc_insertion_point(module_scope) 189 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_rl_output_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_rl_output.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import agent_info_proto_pb2 as communicator__objects_dot_agent__info__proto__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='communicator_objects/unity_rl_output.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_pb=_b('\n*communicator_objects/unity_rl_output.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/agent_info_proto.proto\"\xa3\x02\n\rUnityRLOutput\x12\x13\n\x0bglobal_done\x18\x01 \x01(\x08\x12G\n\nagentInfos\x18\x02 \x03(\x0b\x32\x33.communicator_objects.UnityRLOutput.AgentInfosEntry\x1aI\n\x12ListAgentInfoProto\x12\x33\n\x05value\x18\x01 \x03(\x0b\x32$.communicator_objects.AgentInfoProto\x1ai\n\x0f\x41gentInfosEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x45\n\x05value\x18\x02 \x01(\x0b\x32\x36.communicator_objects.UnityRLOutput.ListAgentInfoProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | , 25 | dependencies=[communicator__objects_dot_agent__info__proto__pb2.DESCRIPTOR,]) 26 | 27 | 28 | 29 | 30 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO = _descriptor.Descriptor( 31 | 
name='ListAgentInfoProto', 32 | full_name='communicator_objects.UnityRLOutput.ListAgentInfoProto', 33 | filename=None, 34 | file=DESCRIPTOR, 35 | containing_type=None, 36 | fields=[ 37 | _descriptor.FieldDescriptor( 38 | name='value', full_name='communicator_objects.UnityRLOutput.ListAgentInfoProto.value', index=0, 39 | number=1, type=11, cpp_type=10, label=3, 40 | has_default_value=False, default_value=[], 41 | message_type=None, enum_type=None, containing_type=None, 42 | is_extension=False, extension_scope=None, 43 | options=None, file=DESCRIPTOR), 44 | ], 45 | extensions=[ 46 | ], 47 | nested_types=[], 48 | enum_types=[ 49 | ], 50 | options=None, 51 | is_extendable=False, 52 | syntax='proto3', 53 | extension_ranges=[], 54 | oneofs=[ 55 | ], 56 | serialized_start=225, 57 | serialized_end=298, 58 | ) 59 | 60 | _UNITYRLOUTPUT_AGENTINFOSENTRY = _descriptor.Descriptor( 61 | name='AgentInfosEntry', 62 | full_name='communicator_objects.UnityRLOutput.AgentInfosEntry', 63 | filename=None, 64 | file=DESCRIPTOR, 65 | containing_type=None, 66 | fields=[ 67 | _descriptor.FieldDescriptor( 68 | name='key', full_name='communicator_objects.UnityRLOutput.AgentInfosEntry.key', index=0, 69 | number=1, type=9, cpp_type=9, label=1, 70 | has_default_value=False, default_value=_b("").decode('utf-8'), 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | options=None, file=DESCRIPTOR), 74 | _descriptor.FieldDescriptor( 75 | name='value', full_name='communicator_objects.UnityRLOutput.AgentInfosEntry.value', index=1, 76 | number=2, type=11, cpp_type=10, label=1, 77 | has_default_value=False, default_value=None, 78 | message_type=None, enum_type=None, containing_type=None, 79 | is_extension=False, extension_scope=None, 80 | options=None, file=DESCRIPTOR), 81 | ], 82 | extensions=[ 83 | ], 84 | nested_types=[], 85 | enum_types=[ 86 | ], 87 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 88 | 
is_extendable=False, 89 | syntax='proto3', 90 | extension_ranges=[], 91 | oneofs=[ 92 | ], 93 | serialized_start=300, 94 | serialized_end=405, 95 | ) 96 | 97 | _UNITYRLOUTPUT = _descriptor.Descriptor( 98 | name='UnityRLOutput', 99 | full_name='communicator_objects.UnityRLOutput', 100 | filename=None, 101 | file=DESCRIPTOR, 102 | containing_type=None, 103 | fields=[ 104 | _descriptor.FieldDescriptor( 105 | name='global_done', full_name='communicator_objects.UnityRLOutput.global_done', index=0, 106 | number=1, type=8, cpp_type=7, label=1, 107 | has_default_value=False, default_value=False, 108 | message_type=None, enum_type=None, containing_type=None, 109 | is_extension=False, extension_scope=None, 110 | options=None, file=DESCRIPTOR), 111 | _descriptor.FieldDescriptor( 112 | name='agentInfos', full_name='communicator_objects.UnityRLOutput.agentInfos', index=1, 113 | number=2, type=11, cpp_type=10, label=3, 114 | has_default_value=False, default_value=[], 115 | message_type=None, enum_type=None, containing_type=None, 116 | is_extension=False, extension_scope=None, 117 | options=None, file=DESCRIPTOR), 118 | ], 119 | extensions=[ 120 | ], 121 | nested_types=[_UNITYRLOUTPUT_LISTAGENTINFOPROTO, _UNITYRLOUTPUT_AGENTINFOSENTRY, ], 122 | enum_types=[ 123 | ], 124 | options=None, 125 | is_extendable=False, 126 | syntax='proto3', 127 | extension_ranges=[], 128 | oneofs=[ 129 | ], 130 | serialized_start=114, 131 | serialized_end=405, 132 | ) 133 | 134 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO.fields_by_name['value'].message_type = communicator__objects_dot_agent__info__proto__pb2._AGENTINFOPROTO 135 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO.containing_type = _UNITYRLOUTPUT 136 | _UNITYRLOUTPUT_AGENTINFOSENTRY.fields_by_name['value'].message_type = _UNITYRLOUTPUT_LISTAGENTINFOPROTO 137 | _UNITYRLOUTPUT_AGENTINFOSENTRY.containing_type = _UNITYRLOUTPUT 138 | _UNITYRLOUTPUT.fields_by_name['agentInfos'].message_type = _UNITYRLOUTPUT_AGENTINFOSENTRY 139 | 
DESCRIPTOR.message_types_by_name['UnityRLOutput'] = _UNITYRLOUTPUT 140 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 141 | 142 | UnityRLOutput = _reflection.GeneratedProtocolMessageType('UnityRLOutput', (_message.Message,), dict( 143 | 144 | ListAgentInfoProto = _reflection.GeneratedProtocolMessageType('ListAgentInfoProto', (_message.Message,), dict( 145 | DESCRIPTOR = _UNITYRLOUTPUT_LISTAGENTINFOPROTO, 146 | __module__ = 'communicator_objects.unity_rl_output_pb2' 147 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.ListAgentInfoProto) 148 | )) 149 | , 150 | 151 | AgentInfosEntry = _reflection.GeneratedProtocolMessageType('AgentInfosEntry', (_message.Message,), dict( 152 | DESCRIPTOR = _UNITYRLOUTPUT_AGENTINFOSENTRY, 153 | __module__ = 'communicator_objects.unity_rl_output_pb2' 154 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.AgentInfosEntry) 155 | )) 156 | , 157 | DESCRIPTOR = _UNITYRLOUTPUT, 158 | __module__ = 'communicator_objects.unity_rl_output_pb2' 159 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput) 160 | )) 161 | _sym_db.RegisterMessage(UnityRLOutput) 162 | _sym_db.RegisterMessage(UnityRLOutput.ListAgentInfoProto) 163 | _sym_db.RegisterMessage(UnityRLOutput.AgentInfosEntry) 164 | 165 | 166 | DESCRIPTOR.has_options = True 167 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 168 | _UNITYRLOUTPUT_AGENTINFOSENTRY.has_options = True 169 | _UNITYRLOUTPUT_AGENTINFOSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 170 | # @@protoc_insertion_point(module_scope) 171 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_to_external_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/unity_to_external.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_message_pb2 as communicator__objects_dot_unity__message__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='communicator_objects/unity_to_external.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_pb=_b('\n,communicator_objects/unity_to_external.proto\x12\x14\x63ommunicator_objects\x1a(communicator_objects/unity_message.proto2g\n\x0fUnityToExternal\x12T\n\x08\x45xchange\x12\".communicator_objects.UnityMessage\x1a\".communicator_objects.UnityMessage\"\x00\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | , 25 | dependencies=[communicator__objects_dot_unity__message__pb2.DESCRIPTOR,]) 26 | 27 | 28 | 29 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 30 | 31 | 32 | DESCRIPTOR.has_options = True 33 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 34 | 35 | _UNITYTOEXTERNAL = _descriptor.ServiceDescriptor( 36 | name='UnityToExternal', 37 | full_name='communicator_objects.UnityToExternal', 38 | file=DESCRIPTOR, 39 | index=0, 40 | options=None, 41 | serialized_start=112, 42 | serialized_end=215, 43 | methods=[ 44 | _descriptor.MethodDescriptor( 45 | name='Exchange', 46 | full_name='communicator_objects.UnityToExternal.Exchange', 47 | index=0, 48 | containing_service=None, 49 | input_type=communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, 50 | 
output_type=communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, 51 | options=None, 52 | ), 53 | ]) 54 | _sym_db.RegisterServiceDescriptor(_UNITYTOEXTERNAL) 55 | 56 | DESCRIPTOR.services_by_name['UnityToExternal'] = _UNITYTOEXTERNAL 57 | 58 | # @@protoc_insertion_point(module_scope) 59 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_to_external_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | import grpc 3 | 4 | from communicator_objects import unity_message_pb2 as communicator__objects_dot_unity__message__pb2 5 | 6 | 7 | class UnityToExternalStub(object): 8 | # missing associated documentation comment in .proto file 9 | pass 10 | 11 | def __init__(self, channel): 12 | """Constructor. 13 | 14 | Args: 15 | channel: A grpc.Channel. 16 | """ 17 | self.Exchange = channel.unary_unary( 18 | '/communicator_objects.UnityToExternal/Exchange', 19 | request_serializer=communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString, 20 | response_deserializer=communicator__objects_dot_unity__message__pb2.UnityMessage.FromString, 21 | ) 22 | 23 | 24 | class UnityToExternalServicer(object): 25 | # missing associated documentation comment in .proto file 26 | pass 27 | 28 | def Exchange(self, request, context): 29 | """Sends the academy parameters 30 | """ 31 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 32 | context.set_details('Method not implemented!') 33 | raise NotImplementedError('Method not implemented!') 34 | 35 | 36 | def add_UnityToExternalServicer_to_server(servicer, server): 37 | rpc_method_handlers = { 38 | 'Exchange': grpc.unary_unary_rpc_method_handler( 39 | servicer.Exchange, 40 | request_deserializer=communicator__objects_dot_unity__message__pb2.UnityMessage.FromString, 41 | 
response_serializer=communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString, 42 | ), 43 | } 44 | generic_handler = grpc.method_handlers_generic_handler( 45 | 'communicator_objects.UnityToExternal', rpc_method_handlers) 46 | server.add_generic_rpc_handlers((generic_handler,)) 47 | -------------------------------------------------------------------------------- /python/curricula/push.json: -------------------------------------------------------------------------------- 1 | { 2 | "measure" : "reward", 3 | "thresholds" : [0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75], 4 | "min_lesson_length" : 2, 5 | "signal_smoothing" : true, 6 | "parameters" : 7 | { 8 | "goal_size" : [3.5, 3.25, 3.0, 2.75, 2.5, 2.25, 2.0, 1.75, 1.5, 1.25, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 9 | "block_size": [1.5, 1.4, 1.3, 1.2, 1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 10 | "x_variation":[1.5, 1.55, 1.6, 1.65, 1.7, 1.75, 1.8, 1.85, 1.9, 1.95, 2.0, 2.1, 2.2, 2.3, 2.4, 2.5] 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /python/curricula/test.json: -------------------------------------------------------------------------------- 1 | { 2 | "measure" : "reward", 3 | "thresholds" : [10, 20, 50], 4 | "min_lesson_length" : 3, 5 | "signal_smoothing" : true, 6 | "parameters" : 7 | { 8 | "param1" : [0.7, 0.5, 0.3, 0.1], 9 | "param2" : [100, 50, 20, 15], 10 | "param3" : [0.2, 0.3, 0.7, 0.9] 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /python/curricula/wall.json: -------------------------------------------------------------------------------- 1 | { 2 | "measure" : "progress", 3 | "thresholds" : [0.1, 0.3, 0.5], 4 | "min_lesson_length" : 2, 5 | "signal_smoothing" : true, 6 | "parameters" : 7 | { 8 | "big_wall_min_height" : [0.0, 4.0, 6.0, 8.0], 9 | "big_wall_max_height" : [4.0, 7.0, 8.0, 8.0], 10 | "small_wall_height" : [1.5, 
2.0, 2.5, 4.0] 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /python/learn.py: -------------------------------------------------------------------------------- 1 | # # Unity ML-Agents Toolkit 2 | # ## ML-Agent Learning 3 | 4 | import logging 5 | 6 | import os 7 | from docopt import docopt 8 | 9 | from unitytrainers.trainer_controller import TrainerController 10 | 11 | 12 | if __name__ == '__main__': 13 | print(''' 14 | 15 | ▄▄▄▓▓▓▓ 16 | ╓▓▓▓▓▓▓█▓▓▓▓▓ 17 | ,▄▄▄m▀▀▀' ,▓▓▓▀▓▓▄ ▓▓▓ ▓▓▌ 18 | ▄▓▓▓▀' ▄▓▓▀ ▓▓▓ ▄▄ ▄▄ ,▄▄ ▄▄▄▄ ,▄▄ ▄▓▓▌▄ ▄▄▄ ,▄▄ 19 | ▄▓▓▓▀ ▄▓▓▀ ▐▓▓▌ ▓▓▌ ▐▓▓ ▐▓▓▓▀▀▀▓▓▌ ▓▓▓ ▀▓▓▌▀ ^▓▓▌ ╒▓▓▌ 20 | ▄▓▓▓▓▓▄▄▄▄▄▄▄▄▓▓▓ ▓▀ ▓▓▌ ▐▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▌ ▐▓▓▄ ▓▓▌ 21 | ▀▓▓▓▓▀▀▀▀▀▀▀▀▀▀▓▓▄ ▓▓ ▓▓▌ ▐▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▌ ▐▓▓▐▓▓ 22 | ^█▓▓▓ ▀▓▓▄ ▐▓▓▌ ▓▓▓▓▄▓▓▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▓▄ ▓▓▓▓` 23 | '▀▓▓▓▄ ^▓▓▓ ▓▓▓ └▀▀▀▀ ▀▀ ^▀▀ `▀▀ `▀▀ '▀▀ ▐▓▓▌ 24 | ▀▀▀▀▓▄▄▄ ▓▓▓▓▓▓, ▓▓▓▓▀ 25 | `▀█▓▓▓▓▓▓▓▓▓▌ 26 | ¬`▀▀▀█▓ 27 | 28 | ''') 29 | 30 | logger = logging.getLogger("unityagents") 31 | _USAGE = ''' 32 | Usage: 33 | learn () [options] 34 | learn [options] 35 | learn --help 36 | 37 | Options: 38 | --curriculum= Curriculum json file for environment [default: None]. 39 | --keep-checkpoints= How many model checkpoints to keep [default: 5]. 40 | --lesson= Start learning from this lesson [default: 0]. 41 | --load Whether to load the model or randomly initialize [default: False]. 42 | --run-id= The sub-directory name for model and summary statistics [default: ppo]. 43 | --save-freq= Frequency at which to save model [default: 50000]. 44 | --seed= Random seed used for training [default: -1]. 45 | --slow Whether to run the game at training speed [default: False]. 46 | --train Whether to train model, or only run inference [default: False]. 47 | --worker-id= Number to add to communication port (5005). Used for multi-environment [default: 0]. 48 | --docker-target-name=
Docker Volume to store curriculum, executable and model files [default: Empty]. 49 | --no-graphics Whether to run the Unity simulator in no-graphics mode [default: False]. 50 | ''' 51 | 52 | options = docopt(_USAGE) 53 | logger.info(options) 54 | # Docker Parameters 55 | if options['--docker-target-name'] == 'Empty': 56 | docker_target_name = '' 57 | else: 58 | docker_target_name = options['--docker-target-name'] 59 | 60 | # General parameters 61 | run_id = options['--run-id'] 62 | seed = int(options['--seed']) 63 | load_model = options['--load'] 64 | train_model = options['--train'] 65 | save_freq = int(options['--save-freq']) 66 | env_path = options[''] 67 | keep_checkpoints = int(options['--keep-checkpoints']) 68 | worker_id = int(options['--worker-id']) 69 | curriculum_file = str(options['--curriculum']) 70 | if curriculum_file == "None": 71 | curriculum_file = None 72 | lesson = int(options['--lesson']) 73 | fast_simulation = not bool(options['--slow']) 74 | no_graphics = options['--no-graphics'] 75 | 76 | # Constants 77 | # Assumption that this yaml is present in same dir as this file 78 | base_path = os.path.dirname(__file__) 79 | TRAINER_CONFIG_PATH = os.path.abspath(os.path.join(base_path, "trainer_config.yaml")) 80 | 81 | tc = TrainerController(env_path, run_id, save_freq, curriculum_file, fast_simulation, load_model, train_model, 82 | worker_id, keep_checkpoints, lesson, seed, docker_target_name, TRAINER_CONFIG_PATH, 83 | no_graphics) 84 | tc.start_learning() 85 | -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==1.7.1 2 | Pillow>=4.2.1 3 | matplotlib 4 | numpy>=1.11.0 5 | jupyter 6 | pytest>=3.2.2 7 | docopt 8 | pyyaml 9 | protobuf==3.5.2 10 | grpcio==1.11.0 11 | torch==0.4.0 12 | pandas 13 | scipy 14 | ipykernel 15 | -------------------------------------------------------------------------------- 
/python/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, Command, find_packages 4 | 5 | 6 | with open('requirements.txt') as f: 7 | required = f.read().splitlines() 8 | 9 | setup(name='unityagents', 10 | version='0.4.0', 11 | description='Unity Machine Learning Agents', 12 | license='Apache License 2.0', 13 | author='Unity Technologies', 14 | author_email='ML-Agents@unity3d.com', 15 | url='https://github.com/Unity-Technologies/ml-agents', 16 | packages=find_packages(), 17 | install_requires = required, 18 | long_description= ("Unity Machine Learning Agents allows researchers and developers " 19 | "to transform games and simulations created using the Unity Editor into environments " 20 | "where intelligent agents can be trained using reinforcement learning, evolutionary " 21 | "strategies, or other machine learning methods through a simple to use Python API.") 22 | ) 23 | -------------------------------------------------------------------------------- /python/tests/__init__.py: -------------------------------------------------------------------------------- 1 | from unityagents import * 2 | from unitytrainers import * 3 | -------------------------------------------------------------------------------- /python/tests/mock_communicator.py: -------------------------------------------------------------------------------- 1 | 2 | from unityagents.communicator import Communicator 3 | from communicator_objects import UnityMessage, UnityOutput, UnityInput,\ 4 | ResolutionProto, BrainParametersProto, UnityRLInitializationOutput,\ 5 | AgentInfoProto, UnityRLOutput 6 | 7 | 8 | class MockCommunicator(Communicator): 9 | def __init__(self, discrete_action=False, visual_inputs=0): 10 | """ 11 | Mock of the Python side of the gRPC communication, used by the unit tests in place of a running Unity environment. 12 | 13 | :bool discrete_action: Whether the fake brain uses a discrete action space;
continuous actions are used otherwise. 14 | :int visual_inputs: Number of fake visual (camera) observations the fake brain reports. 15 | """ 16 | self.is_discrete = discrete_action 17 | self.steps = 0 18 | self.visual_inputs = visual_inputs 19 | self.has_been_closed = False 20 | 21 | def initialize(self, inputs: UnityInput) -> UnityOutput: 22 | resolutions = [ResolutionProto( 23 | width=30, 24 | height=40, 25 | gray_scale=False) for i in range(self.visual_inputs)] 26 | bp = BrainParametersProto( 27 | vector_observation_size=3, 28 | num_stacked_vector_observations=2, 29 | vector_action_size=2, 30 | camera_resolutions=resolutions, 31 | vector_action_descriptions=["", ""], 32 | vector_action_space_type=int(not self.is_discrete), 33 | vector_observation_space_type=1, 34 | brain_name="RealFakeBrain", 35 | brain_type=2 36 | ) 37 | rl_init = UnityRLInitializationOutput( 38 | name="RealFakeAcademy", 39 | version="API-4", 40 | log_path="", 41 | brain_parameters=[bp] 42 | ) 43 | return UnityOutput( 44 | rl_initialization_output=rl_init 45 | ) 46 | 47 | def exchange(self, inputs: UnityInput) -> UnityOutput: 48 | dict_agent_info = {} 49 | if self.is_discrete: 50 | vector_action = [1] 51 | else: 52 | vector_action = [1, 2] 53 | list_agent_info = [] 54 | for i in range(3): 55 | list_agent_info.append( 56 | AgentInfoProto( 57 | stacked_vector_observation=[1, 2, 3, 1, 2, 3], 58 | reward=1, 59 | stored_vector_actions=vector_action, 60 | stored_text_actions="", 61 | text_observation="", 62 | memories=[], 63 | done=(i == 2), 64 | max_step_reached=False, 65 | id=i 66 | )) 67 | dict_agent_info["RealFakeBrain"] = \ 68 | UnityRLOutput.ListAgentInfoProto(value=list_agent_info) 69 | global_done = False 70 | try: 71 | global_done = (inputs.rl_input.agent_actions["RealFakeBrain"].value[0].vector_actions[0] == -1) 72 | except: 73 | pass 74 | result = UnityRLOutput( 75 | global_done=global_done, 76 | agentInfos=dict_agent_info 77 | ) 78 | return UnityOutput( 79 | rl_output=result 80 | 
) 81 | 82 | def close(self): 83 | """ 84 | Sends a shutdown signal to the unity environment, and closes the grpc connection. 85 | """ 86 | self.has_been_closed = True 87 | -------------------------------------------------------------------------------- /python/tests/test_bc.py: -------------------------------------------------------------------------------- 1 | import unittest.mock as mock 2 | import pytest 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from unitytrainers.bc.models import BehavioralCloningModel 8 | from unityagents import UnityEnvironment 9 | from .mock_communicator import MockCommunicator 10 | 11 | 12 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 13 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 14 | def test_cc_bc_model(mock_communicator, mock_launcher): 15 | tf.reset_default_graph() 16 | with tf.Session() as sess: 17 | with tf.variable_scope("FakeGraphScope"): 18 | mock_communicator.return_value = MockCommunicator( 19 | discrete_action=False, visual_inputs=0) 20 | env = UnityEnvironment(' ') 21 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 22 | init = tf.global_variables_initializer() 23 | sess.run(init) 24 | 25 | run_list = [model.sample_action, model.policy] 26 | feed_dict = {model.batch_size: 2, 27 | model.sequence_length: 1, 28 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 29 | [3, 4, 5, 3, 4, 5]])} 30 | sess.run(run_list, feed_dict=feed_dict) 31 | env.close() 32 | 33 | 34 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 35 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 36 | def test_dc_bc_model(mock_communicator, mock_launcher): 37 | tf.reset_default_graph() 38 | with tf.Session() as sess: 39 | with tf.variable_scope("FakeGraphScope"): 40 | mock_communicator.return_value = MockCommunicator( 41 | discrete_action=True, visual_inputs=0) 42 | env = UnityEnvironment(' ') 43 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 44 | 
init = tf.global_variables_initializer() 45 | sess.run(init) 46 | 47 | run_list = [model.sample_action, model.policy] 48 | feed_dict = {model.batch_size: 2, 49 | model.dropout_rate: 1.0, 50 | model.sequence_length: 1, 51 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 52 | [3, 4, 5, 3, 4, 5]])} 53 | sess.run(run_list, feed_dict=feed_dict) 54 | env.close() 55 | 56 | 57 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 58 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 59 | def test_visual_dc_bc_model(mock_communicator, mock_launcher): 60 | tf.reset_default_graph() 61 | with tf.Session() as sess: 62 | with tf.variable_scope("FakeGraphScope"): 63 | mock_communicator.return_value = MockCommunicator( 64 | discrete_action=True, visual_inputs=2) 65 | env = UnityEnvironment(' ') 66 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 67 | init = tf.global_variables_initializer() 68 | sess.run(init) 69 | 70 | run_list = [model.sample_action, model.policy] 71 | feed_dict = {model.batch_size: 2, 72 | model.dropout_rate: 1.0, 73 | model.sequence_length: 1, 74 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 75 | [3, 4, 5, 3, 4, 5]]), 76 | model.visual_in[0]: np.ones([2, 40, 30, 3]), 77 | model.visual_in[1]: np.ones([2, 40, 30, 3])} 78 | sess.run(run_list, feed_dict=feed_dict) 79 | env.close() 80 | 81 | 82 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 83 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 84 | def test_visual_cc_bc_model(mock_communicator, mock_launcher): 85 | tf.reset_default_graph() 86 | with tf.Session() as sess: 87 | with tf.variable_scope("FakeGraphScope"): 88 | mock_communicator.return_value = MockCommunicator( 89 | discrete_action=False, visual_inputs=2) 90 | env = UnityEnvironment(' ') 91 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 92 | init = tf.global_variables_initializer() 93 | sess.run(init) 94 | 95 | run_list = [model.sample_action, model.policy] 96 | 
feed_dict = {model.batch_size: 2, 97 | model.sequence_length: 1, 98 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 99 | [3, 4, 5, 3, 4, 5]]), 100 | model.visual_in[0]: np.ones([2, 40, 30, 3]), 101 | model.visual_in[1]: np.ones([2, 40, 30, 3])} 102 | sess.run(run_list, feed_dict=feed_dict) 103 | env.close() 104 | 105 | 106 | if __name__ == '__main__': 107 | pytest.main() 108 | -------------------------------------------------------------------------------- /python/tests/test_unityagents.py: -------------------------------------------------------------------------------- 1 | import json 2 | import unittest.mock as mock 3 | import pytest 4 | import struct 5 | 6 | import numpy as np 7 | 8 | from unityagents import UnityEnvironment, UnityEnvironmentException, UnityActionException, \ 9 | BrainInfo, Curriculum 10 | from .mock_communicator import MockCommunicator 11 | 12 | 13 | dummy_curriculum = json.loads('''{ 14 | "measure" : "reward", 15 | "thresholds" : [10, 20, 50], 16 | "min_lesson_length" : 3, 17 | "signal_smoothing" : true, 18 | "parameters" : 19 | { 20 | "param1" : [0.7, 0.5, 0.3, 0.1], 21 | "param2" : [100, 50, 20, 15], 22 | "param3" : [0.2, 0.3, 0.7, 0.9] 23 | } 24 | }''') 25 | bad_curriculum = json.loads('''{ 26 | "measure" : "reward", 27 | "thresholds" : [10, 20, 50], 28 | "min_lesson_length" : 3, 29 | "signal_smoothing" : false, 30 | "parameters" : 31 | { 32 | "param1" : [0.7, 0.5, 0.3, 0.1], 33 | "param2" : [100, 50, 20], 34 | "param3" : [0.2, 0.3, 0.7, 0.9] 35 | } 36 | }''') 37 | 38 | 39 | def test_handles_bad_filename(): 40 | with pytest.raises(UnityEnvironmentException): 41 | UnityEnvironment(' ') 42 | 43 | 44 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 45 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 46 | def test_initialization(mock_communicator, mock_launcher): 47 | mock_communicator.return_value = MockCommunicator( 48 | discrete_action=False, visual_inputs=0) 49 | env = UnityEnvironment(' ') 50 | with 
pytest.raises(UnityActionException): 51 | env.step([0]) 52 | assert env.brain_names[0] == 'RealFakeBrain' 53 | env.close() 54 | 55 | 56 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 57 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 58 | def test_reset(mock_communicator, mock_launcher): 59 | mock_communicator.return_value = MockCommunicator( 60 | discrete_action=False, visual_inputs=0) 61 | env = UnityEnvironment(' ') 62 | brain = env.brains['RealFakeBrain'] 63 | brain_info = env.reset() 64 | env.close() 65 | assert not env.global_done 66 | assert isinstance(brain_info, dict) 67 | assert isinstance(brain_info['RealFakeBrain'], BrainInfo) 68 | assert isinstance(brain_info['RealFakeBrain'].visual_observations, list) 69 | assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray) 70 | assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations 71 | assert brain_info['RealFakeBrain'].vector_observations.shape[0] == \ 72 | len(brain_info['RealFakeBrain'].agents) 73 | assert brain_info['RealFakeBrain'].vector_observations.shape[1] == \ 74 | brain.vector_observation_space_size * brain.num_stacked_vector_observations 75 | 76 | 77 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 78 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 79 | def test_step(mock_communicator, mock_launcher): 80 | mock_communicator.return_value = MockCommunicator( 81 | discrete_action=False, visual_inputs=0) 82 | env = UnityEnvironment(' ') 83 | brain = env.brains['RealFakeBrain'] 84 | brain_info = env.reset() 85 | brain_info = env.step([0] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents)) 86 | with pytest.raises(UnityActionException): 87 | env.step([0]) 88 | brain_info = env.step([-1] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents)) 89 | with pytest.raises(UnityActionException): 90 | env.step([0] * 
brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents)) 91 | env.close() 92 | assert env.global_done 93 | assert isinstance(brain_info, dict) 94 | assert isinstance(brain_info['RealFakeBrain'], BrainInfo) 95 | assert isinstance(brain_info['RealFakeBrain'].visual_observations, list) 96 | assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray) 97 | assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations 98 | assert brain_info['RealFakeBrain'].vector_observations.shape[0] == \ 99 | len(brain_info['RealFakeBrain'].agents) 100 | assert brain_info['RealFakeBrain'].vector_observations.shape[1] == \ 101 | brain.vector_observation_space_size * brain.num_stacked_vector_observations 102 | 103 | print("\n\n\n\n\n\n\n" + str(brain_info['RealFakeBrain'].local_done)) 104 | assert not brain_info['RealFakeBrain'].local_done[0] 105 | assert brain_info['RealFakeBrain'].local_done[2] 106 | 107 | 108 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 109 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 110 | def test_close(mock_communicator, mock_launcher): 111 | comm = MockCommunicator( 112 | discrete_action=False, visual_inputs=0) 113 | mock_communicator.return_value = comm 114 | env = UnityEnvironment(' ') 115 | assert env._loaded 116 | env.close() 117 | assert not env._loaded 118 | assert comm.has_been_closed 119 | 120 | 121 | def test_curriculum(): 122 | open_name = '%s.open' % __name__ 123 | with mock.patch('json.load') as mock_load: 124 | with mock.patch(open_name, create=True) as mock_open: 125 | mock_open.return_value = 0 126 | mock_load.return_value = bad_curriculum 127 | with pytest.raises(UnityEnvironmentException): 128 | Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1, "param3": 1}) 129 | mock_load.return_value = dummy_curriculum 130 | with pytest.raises(UnityEnvironmentException): 131 | Curriculum('tests/test_unityagents.py', {"param1": 1, 
"param2": 1}) 132 | curriculum = Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1, "param3": 1}) 133 | assert curriculum.get_lesson_number == 0 134 | curriculum.set_lesson_number(1) 135 | assert curriculum.get_lesson_number == 1 136 | curriculum.increment_lesson(10) 137 | assert curriculum.get_lesson_number == 1 138 | curriculum.increment_lesson(30) 139 | curriculum.increment_lesson(30) 140 | assert curriculum.get_lesson_number == 1 141 | assert curriculum.lesson_length == 3 142 | curriculum.increment_lesson(30) 143 | assert curriculum.get_config() == {'param1': 0.3, 'param2': 20, 'param3': 0.7} 144 | assert curriculum.get_config(0) == {"param1": 0.7, "param2": 100, "param3": 0.2} 145 | assert curriculum.lesson_length == 0 146 | assert curriculum.get_lesson_number == 2 147 | 148 | 149 | if __name__ == '__main__': 150 | pytest.main() 151 | -------------------------------------------------------------------------------- /python/tests/test_unitytrainers.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import unittest.mock as mock 3 | import pytest 4 | 5 | from unitytrainers.trainer_controller import TrainerController 6 | from unitytrainers.buffer import Buffer 7 | from unitytrainers.models import * 8 | from unitytrainers.ppo.trainer import PPOTrainer 9 | from unitytrainers.bc.trainer import BehavioralCloningTrainer 10 | from unityagents import UnityEnvironmentException 11 | from .mock_communicator import MockCommunicator 12 | 13 | dummy_start = '''{ 14 | "AcademyName": "RealFakeAcademy", 15 | "resetParameters": {}, 16 | "brainNames": ["RealFakeBrain"], 17 | "externalBrainNames": ["RealFakeBrain"], 18 | "logPath":"RealFakePath", 19 | "apiNumber":"API-3", 20 | "brainParameters": [{ 21 | "vectorObservationSize": 3, 22 | "numStackedVectorObservations" : 2, 23 | "vectorActionSize": 2, 24 | "memorySize": 0, 25 | "cameraResolutions": [], 26 | "vectorActionDescriptions": ["",""], 27 | 
"vectorActionSpaceType": 1, 28 | "vectorObservationSpaceType": 1 29 | }] 30 | }'''.encode() 31 | 32 | 33 | dummy_config = yaml.load(''' 34 | default: 35 | trainer: ppo 36 | batch_size: 32 37 | beta: 5.0e-3 38 | buffer_size: 512 39 | epsilon: 0.2 40 | gamma: 0.99 41 | hidden_units: 128 42 | lambd: 0.95 43 | learning_rate: 3.0e-4 44 | max_steps: 5.0e4 45 | normalize: true 46 | num_epoch: 5 47 | num_layers: 2 48 | time_horizon: 64 49 | sequence_length: 64 50 | summary_freq: 1000 51 | use_recurrent: false 52 | memory_size: 8 53 | use_curiosity: false 54 | curiosity_strength: 0.0 55 | curiosity_enc_size: 1 56 | ''') 57 | 58 | dummy_bc_config = yaml.load(''' 59 | default: 60 | trainer: imitation 61 | brain_to_imitate: ExpertBrain 62 | batches_per_epoch: 16 63 | batch_size: 32 64 | beta: 5.0e-3 65 | buffer_size: 512 66 | epsilon: 0.2 67 | gamma: 0.99 68 | hidden_units: 128 69 | lambd: 0.95 70 | learning_rate: 3.0e-4 71 | max_steps: 5.0e4 72 | normalize: true 73 | num_epoch: 5 74 | num_layers: 2 75 | time_horizon: 64 76 | sequence_length: 64 77 | summary_freq: 1000 78 | use_recurrent: false 79 | memory_size: 8 80 | use_curiosity: false 81 | curiosity_strength: 0.0 82 | curiosity_enc_size: 1 83 | ''') 84 | 85 | dummy_bad_config = yaml.load(''' 86 | default: 87 | trainer: incorrect_trainer 88 | brain_to_imitate: ExpertBrain 89 | batches_per_epoch: 16 90 | batch_size: 32 91 | beta: 5.0e-3 92 | buffer_size: 512 93 | epsilon: 0.2 94 | gamma: 0.99 95 | hidden_units: 128 96 | lambd: 0.95 97 | learning_rate: 3.0e-4 98 | max_steps: 5.0e4 99 | normalize: true 100 | num_epoch: 5 101 | num_layers: 2 102 | time_horizon: 64 103 | sequence_length: 64 104 | summary_freq: 1000 105 | use_recurrent: false 106 | memory_size: 8 107 | ''') 108 | 109 | 110 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 111 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 112 | def test_initialization(mock_communicator, mock_launcher): 113 | mock_communicator.return_value = 
MockCommunicator( 114 | discrete_action=True, visual_inputs=1) 115 | tc = TrainerController(' ', ' ', 1, None, True, True, False, 1, 116 | 1, 1, 1, '', "tests/test_unitytrainers.py", False) 117 | assert(tc.env.brain_names[0] == 'RealFakeBrain') 118 | 119 | 120 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 121 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 122 | def test_load_config(mock_communicator, mock_launcher): 123 | open_name = 'unitytrainers.trainer_controller' + '.open' 124 | with mock.patch('yaml.load') as mock_load: 125 | with mock.patch(open_name, create=True) as _: 126 | mock_load.return_value = dummy_config 127 | mock_communicator.return_value = MockCommunicator( 128 | discrete_action=True, visual_inputs=1) 129 | mock_load.return_value = dummy_config 130 | tc = TrainerController(' ', ' ', 1, None, True, True, False, 1, 131 | 1, 1, 1, '','', False) 132 | config = tc._load_config() 133 | assert(len(config) == 1) 134 | assert(config['default']['trainer'] == "ppo") 135 | 136 | 137 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 138 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 139 | def test_initialize_trainers(mock_communicator, mock_launcher): 140 | open_name = 'unitytrainers.trainer_controller' + '.open' 141 | with mock.patch('yaml.load') as mock_load: 142 | with mock.patch(open_name, create=True) as _: 143 | mock_communicator.return_value = MockCommunicator( 144 | discrete_action=True, visual_inputs=1) 145 | tc = TrainerController(' ', ' ', 1, None, True, True, False, 1, 146 | 1, 1, 1, '', "tests/test_unitytrainers.py", False) 147 | 148 | # Test for PPO trainer 149 | mock_load.return_value = dummy_config 150 | config = tc._load_config() 151 | tf.reset_default_graph() 152 | with tf.Session() as sess: 153 | tc._initialize_trainers(config, sess) 154 | assert(len(tc.trainers) == 1) 155 | assert(isinstance(tc.trainers['RealFakeBrain'], PPOTrainer)) 156 | 157 | # Test for Behavior Cloning 
Trainer 158 | mock_load.return_value = dummy_bc_config 159 | config = tc._load_config() 160 | tf.reset_default_graph() 161 | with tf.Session() as sess: 162 | tc._initialize_trainers(config, sess) 163 | assert(isinstance(tc.trainers['RealFakeBrain'], BehavioralCloningTrainer)) 164 | 165 | # Test for proper exception when trainer name is incorrect 166 | mock_load.return_value = dummy_bad_config 167 | config = tc._load_config() 168 | tf.reset_default_graph() 169 | with tf.Session() as sess: 170 | with pytest.raises(UnityEnvironmentException): 171 | tc._initialize_trainers(config, sess) 172 | 173 | 174 | def assert_array(a, b): 175 | assert a.shape == b.shape 176 | la = list(a.flatten()) 177 | lb = list(b.flatten()) 178 | for i in range(len(la)): 179 | assert la[i] == lb[i] 180 | 181 | 182 | def test_buffer(): 183 | b = Buffer() 184 | for fake_agent_id in range(4): 185 | for step in range(9): 186 | b[fake_agent_id]['vector_observation'].append( 187 | [100 * fake_agent_id + 10 * step + 1, 188 | 100 * fake_agent_id + 10 * step + 2, 189 | 100 * fake_agent_id + 10 * step + 3] 190 | ) 191 | b[fake_agent_id]['action'].append([100 * fake_agent_id + 10 * step + 4, 192 | 100 * fake_agent_id + 10 * step + 5]) 193 | a = b[1]['vector_observation'].get_batch(batch_size=2, training_length=1, sequential=True) 194 | assert_array(a, np.array([[171, 172, 173], [181, 182, 183]])) 195 | a = b[2]['vector_observation'].get_batch(batch_size=2, training_length=3, sequential=True) 196 | assert_array(a, np.array([ 197 | [[231, 232, 233], [241, 242, 243], [251, 252, 253]], 198 | [[261, 262, 263], [271, 272, 273], [281, 282, 283]] 199 | ])) 200 | a = b[2]['vector_observation'].get_batch(batch_size=2, training_length=3, sequential=False) 201 | assert_array(a, np.array([ 202 | [[251, 252, 253], [261, 262, 263], [271, 272, 273]], 203 | [[261, 262, 263], [271, 272, 273], [281, 282, 283]] 204 | ])) 205 | b[4].reset_agent() 206 | assert len(b[4]) == 0 207 | b.append_update_buffer(3, 208 | 
batch_size=None, training_length=2) 209 | b.append_update_buffer(2, 210 | batch_size=None, training_length=2) 211 | assert len(b.update_buffer['action']) == 10 212 | assert np.array(b.update_buffer['action']).shape == (10, 2, 2) 213 | 214 | 215 | if __name__ == '__main__': 216 | pytest.main() 217 | -------------------------------------------------------------------------------- /python/trainer_config.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | trainer: ppo 3 | batch_size: 1024 4 | beta: 5.0e-3 5 | buffer_size: 10240 6 | epsilon: 0.2 7 | gamma: 0.99 8 | hidden_units: 128 9 | lambd: 0.95 10 | learning_rate: 3.0e-4 11 | max_steps: 5.0e4 12 | memory_size: 256 13 | normalize: false 14 | num_epoch: 3 15 | num_layers: 2 16 | time_horizon: 64 17 | sequence_length: 64 18 | summary_freq: 1000 19 | use_recurrent: false 20 | use_curiosity: false 21 | curiosity_strength: 0.01 22 | curiosity_enc_size: 128 23 | 24 | BananaBrain: 25 | normalize: false 26 | batch_size: 1024 27 | beta: 5.0e-3 28 | buffer_size: 10240 29 | 30 | BouncerBrain: 31 | normalize: true 32 | max_steps: 5.0e5 33 | num_layers: 2 34 | hidden_units: 64 35 | 36 | PushBlockBrain: 37 | max_steps: 5.0e4 38 | batch_size: 128 39 | buffer_size: 2048 40 | beta: 1.0e-2 41 | hidden_units: 256 42 | summary_freq: 2000 43 | time_horizon: 64 44 | num_layers: 2 45 | 46 | SmallWallBrain: 47 | max_steps: 2.0e5 48 | batch_size: 128 49 | buffer_size: 2048 50 | beta: 5.0e-3 51 | hidden_units: 256 52 | summary_freq: 2000 53 | time_horizon: 128 54 | num_layers: 2 55 | normalize: false 56 | 57 | BigWallBrain: 58 | max_steps: 2.0e5 59 | batch_size: 128 60 | buffer_size: 2048 61 | beta: 5.0e-3 62 | hidden_units: 256 63 | summary_freq: 2000 64 | time_horizon: 128 65 | num_layers: 2 66 | normalize: false 67 | 68 | StrikerBrain: 69 | max_steps: 1.0e5 70 | batch_size: 128 71 | buffer_size: 2048 72 | beta: 5.0e-3 73 | hidden_units: 256 74 | summary_freq: 2000 75 | time_horizon: 128 
76 | num_layers: 2 77 | normalize: false 78 | 79 | GoalieBrain: 80 | max_steps: 1.0e5 81 | batch_size: 128 82 | buffer_size: 2048 83 | beta: 5.0e-3 84 | hidden_units: 256 85 | summary_freq: 2000 86 | time_horizon: 128 87 | num_layers: 2 88 | normalize: false 89 | 90 | PyramidBrain: 91 | use_curiosity: true 92 | summary_freq: 2000 93 | curiosity_strength: 0.01 94 | curiosity_enc_size: 256 95 | time_horizon: 128 96 | batch_size: 128 97 | buffer_size: 2048 98 | hidden_units: 512 99 | num_layers: 2 100 | beta: 1.0e-2 101 | max_steps: 2.0e5 102 | num_epoch: 3 103 | 104 | VisualPyramidBrain: 105 | use_curiosity: true 106 | time_horizon: 128 107 | batch_size: 32 108 | buffer_size: 1024 109 | hidden_units: 256 110 | num_layers: 2 111 | beta: 1.0e-2 112 | max_steps: 5.0e5 113 | num_epoch: 3 114 | 115 | Ball3DBrain: 116 | normalize: true 117 | batch_size: 64 118 | buffer_size: 12000 119 | summary_freq: 1000 120 | time_horizon: 1000 121 | lambd: 0.99 122 | gamma: 0.995 123 | beta: 0.001 124 | 125 | Ball3DHardBrain: 126 | normalize: true 127 | batch_size: 1200 128 | buffer_size: 12000 129 | summary_freq: 1000 130 | time_horizon: 1000 131 | gamma: 0.995 132 | beta: 0.001 133 | 134 | TennisBrain: 135 | normalize: true 136 | 137 | CrawlerBrain: 138 | normalize: true 139 | num_epoch: 3 140 | time_horizon: 1000 141 | batch_size: 2024 142 | buffer_size: 20240 143 | gamma: 0.995 144 | max_steps: 1e6 145 | summary_freq: 3000 146 | num_layers: 3 147 | hidden_units: 512 148 | 149 | WalkerBrain: 150 | normalize: true 151 | num_epoch: 3 152 | time_horizon: 1000 153 | batch_size: 2048 154 | buffer_size: 20480 155 | gamma: 0.995 156 | max_steps: 2e6 157 | summary_freq: 3000 158 | num_layers: 3 159 | hidden_units: 512 160 | 161 | ReacherBrain: 162 | normalize: true 163 | num_epoch: 3 164 | time_horizon: 1000 165 | batch_size: 2024 166 | buffer_size: 20240 167 | gamma: 0.995 168 | max_steps: 1e6 169 | summary_freq: 3000 170 | 171 | HallwayBrain: 172 | use_recurrent: true 173 | 
sequence_length: 64 174 | num_layers: 2 175 | hidden_units: 128 176 | memory_size: 256 177 | beta: 1.0e-2 178 | gamma: 0.99 179 | num_epoch: 3 180 | buffer_size: 1024 181 | batch_size: 128 182 | max_steps: 5.0e5 183 | summary_freq: 1000 184 | time_horizon: 64 185 | 186 | GridWorldBrain: 187 | batch_size: 32 188 | normalize: false 189 | num_layers: 1 190 | hidden_units: 256 191 | beta: 5.0e-3 192 | gamma: 0.9 193 | buffer_size: 256 194 | max_steps: 5.0e5 195 | summary_freq: 2000 196 | time_horizon: 5 197 | 198 | BasicBrain: 199 | batch_size: 32 200 | normalize: false 201 | num_layers: 1 202 | hidden_units: 20 203 | beta: 5.0e-3 204 | gamma: 0.9 205 | buffer_size: 256 206 | max_steps: 5.0e5 207 | summary_freq: 2000 208 | time_horizon: 3 209 | 210 | StudentBrain: 211 | trainer: imitation 212 | max_steps: 10000 213 | summary_freq: 1000 214 | brain_to_imitate: TeacherBrain 215 | batch_size: 16 216 | batches_per_epoch: 5 217 | num_layers: 4 218 | hidden_units: 64 219 | sequence_length: 16 220 | buffer_size: 128 221 | 222 | StudentRecurrentBrain: 223 | trainer: imitation 224 | max_steps: 10000 225 | summary_freq: 1000 226 | brain_to_imitate: TeacherBrain 227 | batch_size: 16 228 | batches_per_epoch: 5 229 | num_layers: 4 230 | hidden_units: 64 231 | use_recurrent: true 232 | sequence_length: 32 233 | buffer_size: 128 234 | -------------------------------------------------------------------------------- /python/unityagents/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import * 2 | from .brain import * 3 | from .exception import * 4 | from .curriculum import * 5 | -------------------------------------------------------------------------------- /python/unityagents/brain.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | 4 | class BrainInfo: 5 | def __init__(self, visual_observation, vector_observation, text_observations, memory=None, 6 | 
reward=None, agents=None, local_done=None, 7 | vector_action=None, text_action=None, max_reached=None): 8 | """ 9 | Describes experience at current step of all agents linked to a brain. 10 | """ 11 | self.visual_observations = visual_observation 12 | self.vector_observations = vector_observation 13 | self.text_observations = text_observations 14 | self.memories = memory 15 | self.rewards = reward 16 | self.local_done = local_done 17 | self.max_reached = max_reached 18 | self.agents = agents 19 | self.previous_vector_actions = vector_action 20 | self.previous_text_actions = text_action 21 | 22 | 23 | AllBrainInfo = Dict[str, BrainInfo] 24 | 25 | 26 | class BrainParameters: 27 | def __init__(self, brain_name, brain_param): 28 | """ 29 | Contains all brain-specific parameters. 30 | :param brain_name: Name of brain. 31 | :param brain_param: Dictionary of brain parameters. 32 | """ 33 | self.brain_name = brain_name 34 | self.vector_observation_space_size = brain_param["vectorObservationSize"] 35 | self.num_stacked_vector_observations = brain_param["numStackedVectorObservations"] 36 | self.number_visual_observations = len(brain_param["cameraResolutions"]) 37 | self.camera_resolutions = brain_param["cameraResolutions"] 38 | self.vector_action_space_size = brain_param["vectorActionSize"] 39 | self.vector_action_descriptions = brain_param["vectorActionDescriptions"] 40 | self.vector_action_space_type = ["discrete", "continuous"][brain_param["vectorActionSpaceType"]] 41 | self.vector_observation_space_type = ["discrete", "continuous"][brain_param["vectorObservationSpaceType"]] 42 | 43 | def __str__(self): 44 | return '''Unity brain name: {0} 45 | Number of Visual Observations (per agent): {1} 46 | Vector Observation space type: {2} 47 | Vector Observation space size (per agent): {3} 48 | Number of stacked Vector Observation: {4} 49 | Vector Action space type: {5} 50 | Vector Action space size (per agent): {6} 51 | Vector Action descriptions: {7}'''.format(self.brain_name, 52 
| str(self.number_visual_observations), 53 | self.vector_observation_space_type, 54 | str(self.vector_observation_space_size), 55 | str(self.num_stacked_vector_observations), 56 | self.vector_action_space_type, 57 | str(self.vector_action_space_size), 58 | ', '.join(self.vector_action_descriptions)) 59 | -------------------------------------------------------------------------------- /python/unityagents/communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from communicator_objects import UnityOutput, UnityInput 4 | 5 | logging.basicConfig(level=logging.INFO) 6 | logger = logging.getLogger("unityagents") 7 | 8 | 9 | class Communicator(object): 10 | def __init__(self, worker_id=0, 11 | base_port=5005): 12 | """ 13 | Python side of the communication. Must be used in pair with the right Unity Communicator equivalent. 14 | 15 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 16 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 17 | """ 18 | 19 | def initialize(self, inputs: UnityInput) -> UnityOutput: 20 | """ 21 | Used to exchange initialization parameters between Python and the Environment 22 | :param inputs: The initialization input that will be sent to the environment. 23 | :return: UnityOutput: The initialization output sent by Unity 24 | """ 25 | 26 | def exchange(self, inputs: UnityInput) -> UnityOutput: 27 | """ 28 | Used to send an input and receive an output from the Environment 29 | :param inputs: The UnityInput that needs to be sent the Environment 30 | :return: The UnityOutputs generated by the Environment 31 | """ 32 | 33 | def close(self): 34 | """ 35 | Sends a shutdown signal to the unity environment, and closes the connection. 
36 | """ 37 | 38 | -------------------------------------------------------------------------------- /python/unityagents/curriculum.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from .exception import UnityEnvironmentException 4 | 5 | import logging 6 | 7 | logger = logging.getLogger("unityagents") 8 | 9 | 10 | class Curriculum(object): 11 | def __init__(self, location, default_reset_parameters): 12 | """ 13 | Initializes a Curriculum object. 14 | :param location: Path to JSON defining curriculum. 15 | :param default_reset_parameters: Set of reset parameters for environment. 16 | """ 17 | self.lesson_length = 0 18 | self.max_lesson_number = 0 19 | self.measure_type = None 20 | if location is None: 21 | self.data = None 22 | else: 23 | try: 24 | with open(location) as data_file: 25 | self.data = json.load(data_file) 26 | except IOError: 27 | raise UnityEnvironmentException( 28 | "The file {0} could not be found.".format(location)) 29 | except UnicodeDecodeError: 30 | raise UnityEnvironmentException("There was an error decoding {}".format(location)) 31 | self.smoothing_value = 0 32 | for key in ['parameters', 'measure', 'thresholds', 33 | 'min_lesson_length', 'signal_smoothing']: 34 | if key not in self.data: 35 | raise UnityEnvironmentException("{0} does not contain a " 36 | "{1} field.".format(location, key)) 37 | parameters = self.data['parameters'] 38 | self.measure_type = self.data['measure'] 39 | self.max_lesson_number = len(self.data['thresholds']) 40 | for key in parameters: 41 | if key not in default_reset_parameters: 42 | raise UnityEnvironmentException( 43 | "The parameter {0} in Curriculum {1} is not present in " 44 | "the Environment".format(key, location)) 45 | for key in parameters: 46 | if len(parameters[key]) != self.max_lesson_number + 1: 47 | raise UnityEnvironmentException( 48 | "The parameter {0} in Curriculum {1} must have {2} values " 49 | "but {3} were found".format(key, location, 
50 | self.max_lesson_number + 1, len(parameters[key]))) 51 | self.set_lesson_number(0) 52 | 53 | @property 54 | def measure(self): 55 | return self.measure_type 56 | 57 | @property 58 | def get_lesson_number(self): 59 | return self.lesson_number 60 | 61 | def set_lesson_number(self, value): 62 | self.lesson_length = 0 63 | self.lesson_number = max(0, min(value, self.max_lesson_number)) 64 | 65 | def increment_lesson(self, progress): 66 | """ 67 | Increments the lesson number depending on the progress given. 68 | :param progress: Measure of progress (either reward or percentage steps completed). 69 | """ 70 | if self.data is None or progress is None: 71 | return 72 | if self.data["signal_smoothing"]: 73 | progress = self.smoothing_value * 0.25 + 0.75 * progress 74 | self.smoothing_value = progress 75 | self.lesson_length += 1 76 | if self.lesson_number < self.max_lesson_number: 77 | if ((progress > self.data['thresholds'][self.lesson_number]) and 78 | (self.lesson_length > self.data['min_lesson_length'])): 79 | self.lesson_length = 0 80 | self.lesson_number += 1 81 | config = {} 82 | parameters = self.data["parameters"] 83 | for key in parameters: 84 | config[key] = parameters[key][self.lesson_number] 85 | logger.info("\nLesson changed. Now in Lesson {0} : \t{1}" 86 | .format(self.lesson_number, 87 | ', '.join([str(x) + ' -> ' + str(config[x]) for x in config]))) 88 | 89 | def get_config(self, lesson=None): 90 | """ 91 | Returns reset parameters which correspond to the lesson. 92 | :param lesson: The lesson you want to get the config of. If None, the current lesson is returned. 93 | :return: The configuration of the reset parameters.
94 | """ 95 | if self.data is None: 96 | return {} 97 | if lesson is None: 98 | lesson = self.lesson_number 99 | lesson = max(0, min(lesson, self.max_lesson_number)) 100 | config = {} 101 | parameters = self.data["parameters"] 102 | for key in parameters: 103 | config[key] = parameters[key][lesson] 104 | return config 105 | -------------------------------------------------------------------------------- /python/unityagents/exception.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger("unityagents") 3 | 4 | class UnityException(Exception): 5 | """ 6 | Any error related to ml-agents environment. 7 | """ 8 | pass 9 | 10 | class UnityEnvironmentException(UnityException): 11 | """ 12 | Related to errors starting and closing environment. 13 | """ 14 | pass 15 | 16 | 17 | class UnityActionException(UnityException): 18 | """ 19 | Related to errors with sending actions. 20 | """ 21 | pass 22 | 23 | class UnityTimeOutException(UnityException): 24 | """ 25 | Related to errors with communication timeouts. 26 | """ 27 | def __init__(self, message, log_file_path = None): 28 | if log_file_path is not None: 29 | try: 30 | with open(log_file_path, "r") as f: 31 | printing = False 32 | unity_error = '\n' 33 | for l in f: 34 | l=l.strip() 35 | if (l == 'Exception') or (l=='Error'): 36 | printing = True 37 | unity_error += '----------------------\n' 38 | if (l == ''): 39 | printing = False 40 | if printing: 41 | unity_error += l + '\n' 42 | logger.info(unity_error) 43 | logger.error("An error might have occurred in the environment. " 44 | "You can check the logfile for more information at {}".format(log_file_path)) 45 | except: 46 | logger.error("An error might have occurred in the environment.
" 47 | "No unity-environment.log file could be found.") 48 | super(UnityTimeOutException, self).__init__(message) 49 | 50 | -------------------------------------------------------------------------------- /python/unityagents/rpc_communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import grpc 3 | 4 | from multiprocessing import Pipe 5 | from concurrent.futures import ThreadPoolExecutor 6 | 7 | from .communicator import Communicator 8 | from communicator_objects import UnityToExternalServicer, add_UnityToExternalServicer_to_server 9 | from communicator_objects import UnityMessage, UnityInput, UnityOutput 10 | from .exception import UnityTimeOutException 11 | 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger("unityagents") 15 | 16 | 17 | class UnityToExternalServicerImplementation(UnityToExternalServicer): 18 | parent_conn, child_conn = Pipe() 19 | 20 | def Initialize(self, request, context): 21 | self.child_conn.send(request) 22 | return self.child_conn.recv() 23 | 24 | def Exchange(self, request, context): 25 | self.child_conn.send(request) 26 | return self.child_conn.recv() 27 | 28 | 29 | class RpcCommunicator(Communicator): 30 | def __init__(self, worker_id=0, 31 | base_port=5005): 32 | """ 33 | Python side of the grpc communication. Python is the server and Unity the client 34 | 35 | 36 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 37 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 
38 | """ 39 | self.port = base_port + worker_id 40 | self.worker_id = worker_id 41 | self.server = None 42 | self.unity_to_external = None 43 | self.is_open = False 44 | 45 | def initialize(self, inputs: UnityInput) -> UnityOutput: 46 | try: 47 | # Establish communication grpc 48 | self.server = grpc.server(ThreadPoolExecutor(max_workers=10)) 49 | self.unity_to_external = UnityToExternalServicerImplementation() 50 | add_UnityToExternalServicer_to_server(self.unity_to_external, self.server) 51 | self.server.add_insecure_port('[::]:'+str(self.port)) 52 | self.server.start() 53 | except : 54 | raise UnityTimeOutException( 55 | "Couldn't start socket communication because worker number {} is still in use. " 56 | "You may need to manually close a previously opened environment " 57 | "or use a different worker number.".format(str(self.worker_id))) 58 | if not self.unity_to_external.parent_conn.poll(30): 59 | raise UnityTimeOutException( 60 | "The Unity environment took too long to respond. Make sure that :\n" 61 | "\t The environment does not need user interaction to launch\n" 62 | "\t The Academy and the External Brain(s) are attached to objects in the Scene\n" 63 | "\t The environment and the Python interface have compatible versions.") 64 | aca_param = self.unity_to_external.parent_conn.recv().unity_output 65 | self.is_open = True 66 | message = UnityMessage() 67 | message.header.status = 200 68 | message.unity_input.CopyFrom(inputs) 69 | self.unity_to_external.parent_conn.send(message) 70 | self.unity_to_external.parent_conn.recv() 71 | return aca_param 72 | 73 | def exchange(self, inputs: UnityInput) -> UnityOutput: 74 | message = UnityMessage() 75 | message.header.status = 200 76 | message.unity_input.CopyFrom(inputs) 77 | self.unity_to_external.parent_conn.send(message) 78 | output = self.unity_to_external.parent_conn.recv() 79 | if output.header.status != 200: 80 | return None 81 | return output.unity_output 82 | 83 | def close(self): 84 | """ 85 | Sends a 
shutdown signal to the unity environment, and closes the grpc connection. 86 | """ 87 | if self.is_open: 88 | message_input = UnityMessage() 89 | message_input.header.status = 400 90 | self.unity_to_external.parent_conn.send(message_input) 91 | self.unity_to_external.parent_conn.close() 92 | self.server.stop(False) 93 | self.is_open = False 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /python/unityagents/socket_communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import socket 3 | import struct 4 | 5 | from .communicator import Communicator 6 | from communicator_objects import UnityMessage, UnityOutput, UnityInput 7 | from .exception import UnityTimeOutException 8 | 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | logger = logging.getLogger("unityagents") 12 | 13 | 14 | class SocketCommunicator(Communicator): 15 | def __init__(self, worker_id=0, 16 | base_port=5005): 17 | """ 18 | Python side of the socket communication 19 | 20 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 21 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 22 | """ 23 | 24 | self.port = base_port + worker_id 25 | self._buffer_size = 12000 26 | self.worker_id = worker_id 27 | self._socket = None 28 | self._conn = None 29 | 30 | def initialize(self, inputs: UnityInput) -> UnityOutput: 31 | try: 32 | # Establish communication socket 33 | self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 34 | self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 35 | self._socket.bind(("localhost", self.port)) 36 | except: 37 | raise UnityTimeOutException("Couldn't start socket communication because worker number {} is still in use. 
" 38 | "You may need to manually close a previously opened environment " 39 | "or use a different worker number.".format(str(self.worker_id))) 40 | try: 41 | self._socket.settimeout(30) 42 | self._socket.listen(1) 43 | self._conn, _ = self._socket.accept() 44 | self._conn.settimeout(30) 45 | except : 46 | raise UnityTimeOutException( 47 | "The Unity environment took too long to respond. Make sure that :\n" 48 | "\t The environment does not need user interaction to launch\n" 49 | "\t The Academy and the External Brain(s) are attached to objects in the Scene\n" 50 | "\t The environment and the Python interface have compatible versions.") 51 | message = UnityMessage() 52 | message.header.status = 200 53 | message.unity_input.CopyFrom(inputs) 54 | self._communicator_send(message.SerializeToString()) 55 | initialization_output = UnityMessage() 56 | initialization_output.ParseFromString(self._communicator_receive()) 57 | return initialization_output.unity_output 58 | 59 | def _communicator_receive(self): 60 | try: 61 | s = self._conn.recv(self._buffer_size) 62 | message_length = struct.unpack("I", bytearray(s[:4]))[0] 63 | s = s[4:] 64 | while len(s) != message_length: 65 | s += self._conn.recv(self._buffer_size) 66 | except socket.timeout as e: 67 | raise UnityTimeOutException("The environment took too long to respond.") 68 | return s 69 | 70 | def _communicator_send(self, message): 71 | self._conn.send(struct.pack("I", len(message)) + message) 72 | 73 | def exchange(self, inputs: UnityInput) -> UnityOutput: 74 | message = UnityMessage() 75 | message.header.status = 200 76 | message.unity_input.CopyFrom(inputs) 77 | self._communicator_send(message.SerializeToString()) 78 | outputs = UnityMessage() 79 | outputs.ParseFromString(self._communicator_receive()) 80 | if outputs.header.status != 200: 81 | return None 82 | return outputs.unity_output 83 | 84 | def close(self): 85 | """ 86 | Sends a shutdown signal to the unity environment, and closes the socket connection. 
87 | """ 88 | if self._socket is not None and self._conn is not None: 89 | message_input = UnityMessage() 90 | message_input.header.status = 400 91 | self._communicator_send(message_input.SerializeToString()) 92 | if self._socket is not None: 93 | self._socket.close() 94 | self._socket = None 95 | if self._conn is not None: 96 | self._conn.close() 97 | self._conn = None 98 | 99 | -------------------------------------------------------------------------------- /python/unitytrainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .buffer import * 2 | from .models import * 3 | from .trainer_controller import * 4 | from .bc.models import * 5 | from .bc.trainer import * 6 | from .ppo.models import * 7 | from .ppo.trainer import * 8 | -------------------------------------------------------------------------------- /python/unitytrainers/bc/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | from .trainer import * 3 | -------------------------------------------------------------------------------- /python/unitytrainers/bc/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as c_layers 3 | from unitytrainers.models import LearningModel 4 | 5 | 6 | class BehavioralCloningModel(LearningModel): 7 | def __init__(self, brain, h_size=128, lr=1e-4, n_layers=2, m_size=128, 8 | normalize=False, use_recurrent=False): 9 | LearningModel.__init__(self, m_size, normalize, use_recurrent, brain) 10 | 11 | num_streams = 1 12 | hidden_streams = self.create_observation_streams(num_streams, h_size, n_layers) 13 | hidden = hidden_streams[0] 14 | self.dropout_rate = tf.placeholder(dtype=tf.float32, shape=[], name="dropout_rate") 15 | hidden_reg = tf.layers.dropout(hidden, self.dropout_rate) 16 | if self.use_recurrent: 17 | tf.Variable(self.m_size, 
name="memory_size", trainable=False, dtype=tf.int32) 18 | self.memory_in = tf.placeholder(shape=[None, self.m_size], dtype=tf.float32, name='recurrent_in') 19 | hidden_reg, self.memory_out = self.create_recurrent_encoder(hidden_reg, self.memory_in, 20 | self.sequence_length) 21 | self.memory_out = tf.identity(self.memory_out, name='recurrent_out') 22 | self.policy = tf.layers.dense(hidden_reg, self.a_size, activation=None, use_bias=False, name='pre_action', 23 | kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01)) 24 | 25 | if brain.vector_action_space_type == "discrete": 26 | self.action_probs = tf.nn.softmax(self.policy) 27 | self.sample_action_float = tf.multinomial(self.policy, 1) 28 | self.sample_action_float = tf.identity(self.sample_action_float, name="action") 29 | self.sample_action = tf.cast(self.sample_action_float, tf.int32) 30 | self.true_action = tf.placeholder(shape=[None], dtype=tf.int32, name="teacher_action") 31 | self.action_oh = tf.one_hot(self.true_action, self.a_size) 32 | self.loss = tf.reduce_sum(-tf.log(self.action_probs + 1e-10) * self.action_oh) 33 | self.action_percent = tf.reduce_mean(tf.cast( 34 | tf.equal(tf.cast(tf.argmax(self.action_probs, axis=1), tf.int32), self.sample_action), tf.float32)) 35 | else: 36 | self.clipped_sample_action = tf.clip_by_value(self.policy, -1, 1) 37 | self.sample_action = tf.identity(self.clipped_sample_action, name="action") 38 | self.true_action = tf.placeholder(shape=[None, self.a_size], dtype=tf.float32, name="teacher_action") 39 | self.clipped_true_action = tf.clip_by_value(self.true_action, -1, 1) 40 | self.loss = tf.reduce_sum(tf.squared_difference(self.clipped_true_action, self.sample_action)) 41 | 42 | optimizer = tf.train.AdamOptimizer(learning_rate=lr) 43 | self.update = optimizer.minimize(self.loss) 44 | -------------------------------------------------------------------------------- /python/unitytrainers/buffer.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from unityagents.exception import UnityException 4 | 5 | 6 | class BufferException(UnityException): 7 | """ 8 | Related to errors with the Buffer. 9 | """ 10 | pass 11 | 12 | 13 | class Buffer(dict): 14 | """ 15 | Buffer contains a dictionary of AgentBuffer. The AgentBuffers are indexed by agent_id. 16 | Buffer also contains an update_buffer that corresponds to the buffer used when updating the model. 17 | """ 18 | 19 | class AgentBuffer(dict): 20 | """ 21 | AgentBuffer contains a dictionary of AgentBufferFields. Each agent has its own AgentBuffer. 22 | The keys correspond to the name of the field. Example: state, action 23 | """ 24 | 25 | class AgentBufferField(list): 26 | """ 27 | AgentBufferField is a list of numpy arrays. When an agent collects a field, you can add it to its 28 | AgentBufferField with the append method. 29 | """ 30 | 31 | def __str__(self): 32 | return str(np.array(self).shape) 33 | 34 | def extend(self, data): 35 | """ 36 | Adds a list of np.arrays to the end of the list of np.arrays. 37 | :param data: The np.array list to append. 38 | """ 39 | self += list(np.array(data)) 40 | 41 | def set(self, data): 42 | """ 43 | Sets the list of np.array to the input data 44 | :param data: The np.array list to be set. 45 | """ 46 | self[:] = [] 47 | self[:] = list(np.array(data)) 48 | 49 | def get_batch(self, batch_size=None, training_length=1, sequential=True): 50 | """ 51 | Retrieve the last batch_size elements of length training_length 52 | from the list of np.array 53 | :param batch_size: The number of elements to retrieve. If None: 54 | All elements will be retrieved. 55 | :param training_length: The length of the sequence to be retrieved. If 56 | None: only takes one element. 57 | :param sequential: If true and training_length is not None: the elements 58 | will not repeat in the sequence. 
[a,b,c,d,e] with training_length = 2 and 59 | sequential=True gives [[0,a],[b,c],[d,e]]. If sequential=False gives 60 | [[a,b],[b,c],[c,d],[d,e]] 61 | """ 62 | if training_length == 1: 63 | # When the training length is 1, the method returns a list of elements, 64 | # not a list of sequences of elements. 65 | if batch_size is None: 66 | # If batch_size is None : All the elements of the AgentBufferField are returned. 67 | return np.array(self) 68 | else: 69 | # return the batch_size last elements 70 | if batch_size > len(self): 71 | raise BufferException("Batch size requested is too large") 72 | return np.array(self[-batch_size:]) 73 | else: 74 | # The training_length is not None, the method returns a list of SEQUENCES of elements 75 | if not sequential: 76 | # The sequences will have overlapping elements 77 | if batch_size is None: 78 | # retrieve the maximum number of elements 79 | batch_size = len(self) - training_length + 1 80 | # The number of sequences of length training_length taken from a list of len(self) elements 81 | # with overlapping is equal to batch_size 82 | if (len(self) - training_length + 1) < batch_size: 83 | raise BufferException("The batch size and training length requested for get_batch were" 84 | " too large given the current number of data points.") 85 | tmp_list = [] 86 | for end in range(len(self) - batch_size + 1, len(self) + 1): 87 | tmp_list += [np.array(self[end - training_length:end])] 88 | return np.array(tmp_list) 89 | if sequential: 90 | # The sequences will not have overlapping elements (this involves padding) 91 | leftover = len(self) % training_length 92 | # leftover is the number of elements in the first sequence (this sequence might need 0 padding) 93 | if batch_size is None: 94 | # retrieve the maximum number of elements 95 | batch_size = len(self) // training_length + 1 * (leftover != 0) 96 | # The maximum number of sequences taken from a list of length len(self) without 
overlapping 97 | # with padding is equal to
batch_size 98 | if batch_size > (len(self) // training_length + 1 * (leftover != 0)): 99 | raise BufferException("The batch size and training length requested for get_batch were" 100 | " too large given the current number of data points.") 101 | tmp_list = [] 102 | padding = np.array(self[-1]) * 0 103 | # The padding is made with zeros and its shape is given by the shape of the last element 104 | for end in range(len(self), len(self) % training_length, -training_length)[:batch_size]: 105 | tmp_list += [np.array(self[end - training_length:end])] 106 | if (leftover != 0) and (len(tmp_list) < batch_size): 107 | tmp_list += [np.array([padding] * (training_length - leftover) + self[:leftover])] 108 | tmp_list.reverse() 109 | return np.array(tmp_list) 110 | 111 | def reset_field(self): 112 | """ 113 | Resets the AgentBufferField 114 | """ 115 | self[:] = [] 116 | 117 | def __init__(self): 118 | self.last_brain_info = None 119 | self.last_take_action_outputs = None 120 | super(Buffer.AgentBuffer, self).__init__() 121 | 122 | def __str__(self): 123 | return ", ".join(["'{0}' : {1}".format(k, str(self[k])) for k in self.keys()]) 124 | 125 | def reset_agent(self): 126 | """ 127 | Resets the AgentBuffer 128 | """ 129 | for k in self.keys(): 130 | self[k].reset_field() 131 | self.last_brain_info = None 132 | self.last_take_action_outputs = None 133 | 134 | def __getitem__(self, key): 135 | if key not in self.keys(): 136 | self[key] = self.AgentBufferField() 137 | return super(Buffer.AgentBuffer, self).__getitem__(key) 138 | 139 | def check_length(self, key_list): 140 | """ 141 | Some methods will require that some fields have the same length. 142 | check_length will return true if the fields in key_list 143 | have the same length.
144 | :param key_list: The fields whose lengths will be compared 145 | """ 146 | if len(key_list) < 2: 147 | return True 148 | l = None 149 | for key in key_list: 150 | if key not in self.keys(): 151 | return False 152 | if (l is not None) and (l != len(self[key])): 153 | return False 154 | l = len(self[key]) 155 | return True 156 | 157 | def shuffle(self, key_list=None): 158 | """ 159 | Shuffles the fields in key_list in a consistent way: The reordering will 160 | be the same across fields. 161 | :param key_list: The fields that must be shuffled. 162 | """ 163 | if key_list is None: 164 | key_list = list(self.keys()) 165 | if not self.check_length(key_list): 166 | raise BufferException("Unable to shuffle if the fields are not of the same length") 167 | s = np.arange(len(self[key_list[0]])) 168 | np.random.shuffle(s) 169 | for key in key_list: 170 | self[key][:] = [self[key][i] for i in s] 171 | 172 | def __init__(self): 173 | self.update_buffer = self.AgentBuffer() 174 | super(Buffer, self).__init__() 175 | 176 | def __str__(self): 177 | return "update buffer :\n\t{0}\nlocal_buffers :\n{1}".format(str(self.update_buffer), 178 | '\n'.join( 179 | ['\tagent {0} :{1}'.format(k, str(self[k])) for 180 | k in self.keys()])) 181 | 182 | def __getitem__(self, key): 183 | if key not in self.keys(): 184 | self[key] = self.AgentBuffer() 185 | return super(Buffer, self).__getitem__(key) 186 | 187 | def reset_update_buffer(self): 188 | """ 189 | Resets the update buffer 190 | """ 191 | self.update_buffer.reset_agent() 192 | 193 | def reset_all(self): 194 | """ 195 | Resets all the local buffers 196 | """ 197 | agent_ids = list(self.keys()) 198 | for k in agent_ids: 199 | self[k].reset_agent() 200 | 201 | def append_update_buffer(self, agent_id, key_list=None, batch_size=None, training_length=None): 202 | """ 203 | Appends the buffer of an agent to the update buffer.
204 | :param agent_id: The id of the agent whose data will be appended 205 | :param key_list: The fields that must be added. If None: all fields will be appended. 206 | :param batch_size: The number of elements that must be appended. If None: All of them will be. 207 | :param training_length: The length of the samples that must be appended. If None: only takes one element. 208 | """ 209 | if key_list is None: 210 | key_list = self[agent_id].keys() 211 | if not self[agent_id].check_length(key_list): 212 | raise BufferException("The length of the fields {0} for agent {1} were not of the same length" 213 | .format(key_list, agent_id)) 214 | for field_key in key_list: 215 | self.update_buffer[field_key].extend( 216 | self[agent_id][field_key].get_batch(batch_size=batch_size, training_length=training_length) 217 | ) 218 | 219 | def append_all_agent_batch_to_update_buffer(self, key_list=None, batch_size=None, training_length=None): 220 | """ 221 | Appends the buffer of all agents to the update buffer. 222 | :param key_list: The fields that must be added. If None: all fields will be appended. 223 | :param batch_size: The number of elements that must be appended. If None: All of them will be. 224 | :param training_length: The length of the samples that must be appended. If None: only takes one element.
225 | """ 226 | for agent_id in self.keys(): 227 | self.append_update_buffer(agent_id, key_list, batch_size, training_length) 228 | -------------------------------------------------------------------------------- /python/unitytrainers/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | from .trainer import * 3 | -------------------------------------------------------------------------------- /python/unitytrainers/ppo/models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import tensorflow as tf 4 | from unitytrainers.models import LearningModel 5 | 6 | logger = logging.getLogger("unityagents") 7 | 8 | 9 | class PPOModel(LearningModel): 10 | def __init__(self, brain, lr=1e-4, h_size=128, epsilon=0.2, beta=1e-3, max_step=5e6, 11 | normalize=False, use_recurrent=False, num_layers=2, m_size=None, use_curiosity=False, 12 | curiosity_strength=0.01, curiosity_enc_size=128): 13 | """ 14 | Takes a Unity environment and model-specific hyper-parameters and returns the 15 | appropriate PPO agent model for the environment. 16 | :param brain: BrainInfo used to generate specific network graph. 17 | :param lr: Learning rate. 18 | :param h_size: Size of hidden layers 19 | :param epsilon: Value for policy-divergence threshold. 20 | :param beta: Strength of entropy regularization. 21 | :return: a sub-class of PPOAgent tailored to the environment. 22 | :param max_step: Total number of training steps. 23 | :param normalize: Whether to normalize vector observation input. 24 | :param use_recurrent: Whether to use an LSTM layer in the network. 25 | :param num_layers: Number of hidden layers between encoded input and policy & value layers 26 | :param m_size: Size of brain memory.
27 | """ 28 | LearningModel.__init__(self, m_size, normalize, use_recurrent, brain) 29 | self.use_curiosity = use_curiosity 30 | if num_layers < 1: 31 | num_layers = 1 32 | self.last_reward, self.new_reward, self.update_reward = self.create_reward_encoder() 33 | if brain.vector_action_space_type == "continuous": 34 | self.create_cc_actor_critic(h_size, num_layers) 35 | self.entropy = tf.ones_like(tf.reshape(self.value, [-1])) * self.entropy 36 | else: 37 | self.create_dc_actor_critic(h_size, num_layers) 38 | if self.use_curiosity: 39 | self.curiosity_enc_size = curiosity_enc_size 40 | self.curiosity_strength = curiosity_strength 41 | encoded_state, encoded_next_state = self.create_curiosity_encoders() 42 | self.create_inverse_model(encoded_state, encoded_next_state) 43 | self.create_forward_model(encoded_state, encoded_next_state) 44 | self.create_ppo_optimizer(self.probs, self.old_probs, self.value, 45 | self.entropy, beta, epsilon, lr, max_step) 46 | 47 | @staticmethod 48 | def create_reward_encoder(): 49 | """Creates TF ops to track and increment recent average cumulative reward.""" 50 | last_reward = tf.Variable(0, name="last_reward", trainable=False, dtype=tf.float32) 51 | new_reward = tf.placeholder(shape=[], dtype=tf.float32, name='new_reward') 52 | update_reward = tf.assign(last_reward, new_reward) 53 | return last_reward, new_reward, update_reward 54 | 55 | def create_curiosity_encoders(self): 56 | """ 57 | Creates state encoders for current and future observations. 58 | Used for implementation of Curiosity-driven Exploration by Self-supervised Prediction 59 | See https://arxiv.org/abs/1705.05363 for more details. 60 | :return: current and future state encoder tensors. 61 | """ 62 | encoded_state_list = [] 63 | encoded_next_state_list = [] 64 | 65 | if self.v_size > 0: 66 | self.next_visual_in = [] 67 | visual_encoders = [] 68 | next_visual_encoders = [] 69 | for i in range(self.v_size): 70 | # Create input ops for next (t+1) visual observations. 
71 | next_visual_input = self.create_visual_input(self.brain.camera_resolutions[i], 72 | name="next_visual_observation_" + str(i)) 73 | self.next_visual_in.append(next_visual_input) 74 | 75 | # Create the encoder ops for current and next visual input. Note that these encoders are siamese. 76 | encoded_visual = self.create_visual_observation_encoder(self.visual_in[i], self.curiosity_enc_size, 77 | self.swish, 1, "stream_{}_visual_obs_encoder" 78 | .format(i), False) 79 | 80 | encoded_next_visual = self.create_visual_observation_encoder(self.next_visual_in[i], 81 | self.curiosity_enc_size, 82 | self.swish, 1, 83 | "stream_{}_visual_obs_encoder".format(i), 84 | True) 85 | visual_encoders.append(encoded_visual) 86 | next_visual_encoders.append(encoded_next_visual) 87 | 88 | hidden_visual = tf.concat(visual_encoders, axis=1) 89 | hidden_next_visual = tf.concat(next_visual_encoders, axis=1) 90 | encoded_state_list.append(hidden_visual) 91 | encoded_next_state_list.append(hidden_next_visual) 92 | 93 | if self.o_size > 0: 94 | 95 | # Create the encoder ops for current and next vector input. Note that these encoders are siamese. 96 | if self.brain.vector_observation_space_type == "continuous": 97 | # Create input op for next (t+1) vector observation.
98 | self.next_vector_in = tf.placeholder(shape=[None, self.o_size], dtype=tf.float32, 99 | name='next_vector_observation') 100 | 101 | encoded_vector_obs = self.create_continuous_observation_encoder(self.vector_in, 102 | self.curiosity_enc_size, 103 | self.swish, 2, "vector_obs_encoder", 104 | False) 105 | encoded_next_vector_obs = self.create_continuous_observation_encoder(self.next_vector_in, 106 | self.curiosity_enc_size, 107 | self.swish, 2, 108 | "vector_obs_encoder", 109 | True) 110 | else: 111 | self.next_vector_in = tf.placeholder(shape=[None, 1], dtype=tf.int32, 112 | name='next_vector_observation') 113 | 114 | encoded_vector_obs = self.create_discrete_observation_encoder(self.vector_in, self.o_size, 115 | self.curiosity_enc_size, 116 | self.swish, 2, "vector_obs_encoder", 117 | False) 118 | encoded_next_vector_obs = self.create_discrete_observation_encoder(self.next_vector_in, self.o_size, 119 | self.curiosity_enc_size, 120 | self.swish, 2, "vector_obs_encoder", 121 | True) 122 | encoded_state_list.append(encoded_vector_obs) 123 | encoded_next_state_list.append(encoded_next_vector_obs) 124 | 125 | encoded_state = tf.concat(encoded_state_list, axis=1) 126 | encoded_next_state = tf.concat(encoded_next_state_list, axis=1) 127 | return encoded_state, encoded_next_state 128 | 129 | def create_inverse_model(self, encoded_state, encoded_next_state): 130 | """ 131 | Creates inverse model TensorFlow ops for Curiosity module. 132 | Predicts action taken given current and future encoded states. 133 | :param encoded_state: Tensor corresponding to encoded current state. 134 | :param encoded_next_state: Tensor corresponding to encoded next state. 
135 | """ 136 | combined_input = tf.concat([encoded_state, encoded_next_state], axis=1) 137 | hidden = tf.layers.dense(combined_input, 256, activation=self.swish) 138 | if self.brain.vector_action_space_type == "continuous": 139 | pred_action = tf.layers.dense(hidden, self.a_size, activation=None) 140 | squared_difference = tf.reduce_sum(tf.squared_difference(pred_action, self.selected_actions), axis=1) 141 | self.inverse_loss = tf.reduce_mean(tf.dynamic_partition(squared_difference, self.mask, 2)[1]) 142 | else: 143 | pred_action = tf.layers.dense(hidden, self.a_size, activation=tf.nn.softmax) 144 | cross_entropy = tf.reduce_sum(-tf.log(pred_action + 1e-10) * self.selected_actions, axis=1) 145 | self.inverse_loss = tf.reduce_mean(tf.dynamic_partition(cross_entropy, self.mask, 2)[1]) 146 | 147 | def create_forward_model(self, encoded_state, encoded_next_state): 148 | """ 149 | Creates forward model TensorFlow ops for Curiosity module. 150 | Predicts encoded future state based on encoded current state and given action. 151 | :param encoded_state: Tensor corresponding to encoded current state. 152 | :param encoded_next_state: Tensor corresponding to encoded next state. 153 | """ 154 | combined_input = tf.concat([encoded_state, self.selected_actions], axis=1) 155 | hidden = tf.layers.dense(combined_input, 256, activation=self.swish) 156 | # We compare against the concatenation of all observation streams, hence `self.v_size + int(self.o_size > 0)`. 
157 | pred_next_state = tf.layers.dense(hidden, self.curiosity_enc_size * (self.v_size + int(self.o_size > 0)), 158 | activation=None) 159 | 160 | squared_difference = 0.5 * tf.reduce_sum(tf.squared_difference(pred_next_state, encoded_next_state), axis=1) 161 | self.intrinsic_reward = tf.clip_by_value(self.curiosity_strength * squared_difference, 0, 1) 162 | self.forward_loss = tf.reduce_mean(tf.dynamic_partition(squared_difference, self.mask, 2)[1]) 163 | 164 | def create_ppo_optimizer(self, probs, old_probs, value, entropy, beta, epsilon, lr, max_step): 165 | """ 166 | Creates training-specific Tensorflow ops for PPO models. 167 | :param probs: Current policy probabilities 168 | :param old_probs: Past policy probabilities 169 | :param value: Current value estimate 170 | :param beta: Entropy regularization strength 171 | :param entropy: Current policy entropy 172 | :param epsilon: Value for policy-divergence threshold 173 | :param lr: Learning rate 174 | :param max_step: Total number of training steps. 
175 | """ 176 | self.returns_holder = tf.placeholder(shape=[None], dtype=tf.float32, name='discounted_rewards') 177 | self.advantage = tf.placeholder(shape=[None, 1], dtype=tf.float32, name='advantages') 178 | self.learning_rate = tf.train.polynomial_decay(lr, self.global_step, max_step, 1e-10, power=1.0) 179 | 180 | self.old_value = tf.placeholder(shape=[None], dtype=tf.float32, name='old_value_estimates') 181 | 182 | decay_epsilon = tf.train.polynomial_decay(epsilon, self.global_step, max_step, 0.1, power=1.0) 183 | decay_beta = tf.train.polynomial_decay(beta, self.global_step, max_step, 1e-5, power=1.0) 184 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate) 185 | 186 | clipped_value_estimate = self.old_value + tf.clip_by_value(tf.reduce_sum(value, axis=1) - self.old_value, 187 | - decay_epsilon, decay_epsilon) 188 | 189 | v_opt_a = tf.squared_difference(self.returns_holder, tf.reduce_sum(value, axis=1)) 190 | v_opt_b = tf.squared_difference(self.returns_holder, clipped_value_estimate) 191 | self.value_loss = tf.reduce_mean(tf.dynamic_partition(tf.maximum(v_opt_a, v_opt_b), self.mask, 2)[1]) 192 | 193 | # Here we calculate PPO policy loss. In continuous control this is done independently for each action gaussian 194 | # and then averaged together. This provides significantly better performance than treating the probability 195 | # as an average of probabilities, or as a joint probability. 
196 | r_theta = probs / (old_probs + 1e-10) 197 | p_opt_a = r_theta * self.advantage 198 | p_opt_b = tf.clip_by_value(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * self.advantage 199 | self.policy_loss = -tf.reduce_mean(tf.dynamic_partition(tf.minimum(p_opt_a, p_opt_b), self.mask, 2)[1]) 200 | 201 | self.loss = self.policy_loss + 0.5 * self.value_loss - decay_beta * tf.reduce_mean( 202 | tf.dynamic_partition(entropy, self.mask, 2)[1]) 203 | 204 | if self.use_curiosity: 205 | self.loss += 10 * (0.2 * self.forward_loss + 0.8 * self.inverse_loss) 206 | self.update_batch = optimizer.minimize(self.loss) 207 | -------------------------------------------------------------------------------- /python/unitytrainers/trainer.py: -------------------------------------------------------------------------------- 1 | # # Unity ML-Agents Toolkit 2 | import logging 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | 7 | from unityagents import UnityException, AllBrainInfo 8 | 9 | logger = logging.getLogger("unityagents") 10 | 11 | 12 | class UnityTrainerException(UnityException): 13 | """ 14 | Related to errors with the Trainer. 15 | """ 16 | pass 17 | 18 | 19 | class Trainer(object): 20 | """This class is the abstract class for the unitytrainers""" 21 | 22 | def __init__(self, sess, env, brain_name, trainer_parameters, training): 23 | """ 24 | Responsible for collecting experiences and training a neural network model. 25 | :param sess: Tensorflow session. 26 | :param env: The UnityEnvironment. 27 | :param trainer_parameters: The parameters for the trainer (dictionary). 28 | :param training: Whether the trainer is set for training. 
29 | """ 30 | self.brain_name = brain_name 31 | self.brain = env.brains[self.brain_name] 32 | self.trainer_parameters = trainer_parameters 33 | self.is_training = training 34 | self.sess = sess 35 | self.stats = {} 36 | self.summary_writer = None 37 | 38 | def __str__(self): 39 | return '''Empty Trainer''' 40 | 41 | @property 42 | def parameters(self): 43 | """ 44 | Returns the trainer parameters of the trainer. 45 | """ 46 | raise UnityTrainerException("The parameters property was not implemented.") 47 | 48 | @property 49 | def graph_scope(self): 50 | """ 51 | Returns the graph scope of the trainer. 52 | """ 53 | raise UnityTrainerException("The graph_scope property was not implemented.") 54 | 55 | @property 56 | def get_max_steps(self): 57 | """ 58 | Returns the maximum number of steps. Used to determine when the trainer should stop. 59 | :return: The maximum number of steps of the trainer 60 | """ 61 | raise UnityTrainerException("The get_max_steps property was not implemented.") 62 | 63 | @property 64 | def get_step(self): 65 | """ 66 | Returns the number of steps the trainer has performed 67 | :return: the step count of the trainer 68 | """ 69 | raise UnityTrainerException("The get_step property was not implemented.") 70 | 71 | @property 72 | def get_last_reward(self): 73 | """ 74 | Returns the last reward the trainer has received 75 | :return: the new last reward 76 | """ 77 | raise UnityTrainerException("The get_last_reward property was not implemented.") 78 | 79 | def increment_step_and_update_last_reward(self): 80 | """ 81 | Increments the step count of the trainer and updates the last reward 82 | """ 83 | raise UnityTrainerException("The increment_step_and_update_last_reward method was not implemented.") 84 | 85 | def take_action(self, all_brain_info: AllBrainInfo): 86 | """ 87 | Decides actions given state/observation information, and takes them in environment. 
88 | :param all_brain_info: A dictionary of brain names and BrainInfo from environment.
89 | :return: a tuple containing action, memories, values and an object 90 | to be passed to add experiences 91 | """ 92 | raise UnityTrainerException("The take_action method was not implemented.") 93 | 94 | def add_experiences(self, curr_info: AllBrainInfo, next_info: AllBrainInfo, take_action_outputs): 95 | """ 96 | Adds experiences to each agent's experience history. 97 | :param curr_info: Current AllBrainInfo. 98 | :param next_info: Next AllBrainInfo. 99 | :param take_action_outputs: The outputs of the take action method. 100 | """ 101 | raise UnityTrainerException("The add_experiences method was not implemented.") 102 | 103 | def process_experiences(self, current_info: AllBrainInfo, next_info: AllBrainInfo): 104 | """ 105 | Checks agent histories for processing condition, and processes them as necessary. 106 | Processing involves calculating value and advantage targets for model updating step. 107 | :param current_info: Dictionary of all current-step brains and corresponding BrainInfo. 108 | :param next_info: Dictionary of all next-step brains and corresponding BrainInfo. 109 | """ 110 | raise UnityTrainerException("The process_experiences method was not implemented.") 111 | 112 | def end_episode(self): 113 | """ 114 | A signal that the Episode has ended. The buffer must be reset. 115 | Only called when the academy resets. 116 | """ 117 | raise UnityTrainerException("The end_episode method was not implemented.") 118 | 119 | def is_ready_update(self): 120 | """ 121 | Returns whether or not the trainer has enough elements to run update model 122 | :return: A boolean corresponding to whether or not update_model() can be run 123 | """ 124 | raise UnityTrainerException("The is_ready_update method was not implemented.") 125 | 126 | def update_model(self): 127 | """ 128 | Uses training_buffer to update model.
        """
        raise UnityTrainerException("The update_model method was not implemented.")

    def write_summary(self, lesson_number):
        """
        Saves training statistics to Tensorboard.
        :param lesson_number: The lesson the trainer is at.
        """
        if (self.get_step % self.trainer_parameters['summary_freq'] == 0 and self.get_step != 0 and
                self.is_training and self.get_step <= self.get_max_steps):
            if len(self.stats['cumulative_reward']) > 0:
                mean_reward = np.mean(self.stats['cumulative_reward'])
                logger.info(" {}: Step: {}. Mean Reward: {:0.3f}. Std of Reward: {:0.3f}."
                            .format(self.brain_name, self.get_step,
                                    mean_reward, np.std(self.stats['cumulative_reward'])))
            else:
                logger.info(" {}: Step: {}. No episode was completed since last summary."
                            .format(self.brain_name, self.get_step))
            summary = tf.Summary()
            for key in self.stats:
                if len(self.stats[key]) > 0:
                    stat_mean = float(np.mean(self.stats[key]))
                    summary.value.add(tag='Info/{}'.format(key), simple_value=stat_mean)
                    self.stats[key] = []
            summary.value.add(tag='Info/Lesson', simple_value=lesson_number)
            self.summary_writer.add_summary(summary, self.get_step)
            self.summary_writer.flush()

    def write_tensorboard_text(self, key, input_dict):
        """
        Saves text to Tensorboard.
        Note: Only works on tensorflow r1.2 or above.
        :param key: The name of the text.
        :param input_dict: A dictionary that will be displayed in a table on Tensorboard.
        """
        try:
            s_op = tf.summary.text(key, tf.convert_to_tensor(([[str(x), str(input_dict[x])] for x in input_dict])))
            s = self.sess.run(s_op)
            self.summary_writer.add_summary(s, self.get_step)
        except Exception:
            # Catch ordinary errors only, so that KeyboardInterrupt/SystemExit still propagate
            logger.info("Cannot write text summary for Tensorboard. Tensorflow version must be r1.2 or above.")

--------------------------------------------------------------------------------
/replay_buffer.py:
--------------------------------------------------------------------------------
"""
Project for Udacity Nanodegree in Deep Reinforcement Learning (DRL)
Code expanded and adapted from code examples provided by Udacity DRL Team, 2018.
"""

# Import Required Packages
import torch
import numpy as np
import random
from collections import namedtuple, deque

# Determine if CPU or GPU computation should be used
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class ReplayBuffer:
    """Fixed-size buffer to store experience tuples."""

    def __init__(self, action_size, buffer_size, batch_size, seed):
        """Initialize a ReplayBuffer object.
        Params
        ======
            action_size (int): dimension of each action
            buffer_size (int): maximum size of buffer
            batch_size (int): size of each training batch
            seed (int): random seed
        """
        self.action_size = action_size
        self.memory = deque(maxlen=buffer_size)  # internal memory (deque); oldest experiences are dropped first
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        # random.seed() returns None, so store the seed value itself and seed the RNG used by random.sample()
        self.seed = seed
        random.seed(seed)

    def add(self, state, action, reward, next_state, done):
        """Add a new experience to memory."""
        e = self.experience(state, action, reward, next_state, done)
        self.memory.append(e)

    def sample(self):
        """Randomly sample a batch of experiences from memory."""
        experiences = random.sample(self.memory, k=self.batch_size)

        # Stack each field into a (batch_size, ...) float tensor on the selected device
        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).float().to(device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)

        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------

"""
Test DDPG Model for Unity ML-Agents Environments using PyTorch

The example uses a modified version of the Unity ML-Agents Reacher Example Environment.
In this environment, a double-jointed arm can move to target locations.
A reward of +0.1 is provided for each step that the agent's hand is in the goal location.
Thus, the goal of your agent is to maintain its position at the target location for as many
time steps as possible.

Example Developed By:
Michael Richardson, 2018
Project for Udacity Nanodegree in Deep Reinforcement Learning (DRL)
Code Expanded and Adapted from Code provided by Udacity DRL Team, 2018.
"""

###################################
# Import Required Packages
import torch
import time
import random
import numpy as np
from ddpg_agent import Agent
from unityagents import UnityEnvironment

"""
###################################
STEP 1: Set the Test Parameters
======
    num_episodes (int): number of test episodes
"""
num_episodes = 5


"""
###################################
STEP 2: Start the Unity Environment
# Use the corresponding call depending on your operating system
"""
env = UnityEnvironment(file_name="Reacher.app")
# - **Mac**: "Reacher.app"
# - **Windows** (x86): "Reacher_Windows_x86/Reacher.exe"
# - **Windows** (x86_64): "Reacher_Windows_x86_64/Reacher.exe"
# - **Linux** (x86): "Reacher_Linux/Reacher.x86"
# - **Linux** (x86_64): "Reacher_Linux/Reacher.x86_64"
# - **Linux** (x86, headless): "Reacher_Linux_NoVis/Reacher.x86"
# - **Linux** (x86_64, headless): "Reacher_Linux_NoVis/Reacher.x86_64"

"""
#######################################
STEP 3: Get The Unity Environment Brain
Unity ML-Agent applications or Environments contain "BRAINS" which are responsible for deciding
the actions an agent or set of agents should take given a current set of environment (state)
observations. The Reacher environment has a single Brain, thus, we just need to access the first brain
available (i.e., the default brain). We then set the default brain as the brain that will be controlled.
"""
# Get the default brain
brain_name = env.brain_names[0]

# Assign the default brain as the brain to be controlled
brain = env.brains[brain_name]


"""
#############################################
STEP 4: Determine the size of the Action and State Spaces and the Number of Agents

The observation space consists of 33 variables corresponding to
position, rotation, velocity, and angular velocities of the arm.
Each action is a vector with four numbers, corresponding to torque
applicable to two joints. Every entry in the action vector should
be a number between -1 and 1.

The Reacher environment can contain multiple agents to speed up training.
To use multiple (active) training agents we need to know how many there are.
"""

# Set the number of actions or action size
action_size = brain.vector_action_space_size

# Set the size of state observations or state size
state_size = brain.vector_observation_space_size

# Get number of agents in Environment
env_info = env.reset(train_mode=True)[brain_name]
num_agents = len(env_info.agents)
print('\nNumber of Agents: ', num_agents)


"""
###################################
STEP 5: Initialize a DDPG Agent from the Agent Class in ddpg_agent.py
A DDPG agent initialized with the following parameters.
======
    state_size (int): dimension of each state (required)
    action_size (int): dimension of each action (required)
    num_agents (int): number of agents in the unity environment
    random_seed (int): random seed for initializing training point (default = 0)

Here we initialize an agent using the Unity environment's state and action size and number of Agents
determined above.
"""
# Initialize Agent
agent = Agent(state_size=state_size, action_size=action_size, num_agents=num_agents, random_seed=0)

# Load trained model weights
# map_location='cpu' lets weights saved on a GPU load on a CPU-only machine;
# load_state_dict() then copies them onto whatever device the networks use
agent.actor_local.load_state_dict(torch.load('ddpgActor_Model.pth', map_location='cpu'))
agent.critic_local.load_state_dict(torch.load('ddpgCritic_Model.pth', map_location='cpu'))

"""
###################################
STEP 6: Play Reacher for the specified number of Episodes
"""
# loop from num_episodes
for i_episode in range(1, num_episodes+1):

    # reset the unity environment at the beginning of each episode
    # set train mode to false
    env_info = env.reset(train_mode=False)[brain_name]

    # get initial state of the unity environment
    states = env_info.vector_observations

    # reset the training agent for new episode
    agent.reset()

    # set the initial episode scores to zero for each unity agent.
    scores = np.zeros(num_agents)

    # Run the episode loop;
    # At each loop step take an action as a function of the current state observations
    # If environment episode is done, exit loop...
    # Otherwise repeat until done == true
    while True:
        # determine actions for the unity agents from current state
        actions = agent.act(states)

        # send the actions to the unity agents in the environment and receive resultant environment information
        env_info = env.step(actions)[brain_name]

        next_states = env_info.vector_observations  # get the next states for each unity agent in the environment
        rewards = env_info.rewards                  # get the rewards for each unity agent in the environment
        dones = env_info.local_done                 # see if episode has finished for each unity agent in the environment

        # set new states to current states for determining next actions
        states = next_states

        # Update episode score for each unity agent
        scores += rewards

        # If any unity agent indicates that the episode is done,
        # then exit episode loop, to begin new episode
        if np.any(dones):
            break

    # Print the episode score averaged over all unity agents
    # (end="" belongs to print(), not to str.format(), where it was silently ignored)
    print('\nEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores)), end="")


"""
###################################
STEP 7: Everything is Finished -> Close the Environment.
"""
env.close()

# END :) #############


--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------

"""
DDPG (Actor-Critic) RL Example for Unity ML-Agents Environments using PyTorch
Includes an example DDPG training loop.

The example uses a modified version of the Unity ML-Agents Reacher Example Environment.
In this environment, a double-jointed arm can move to target locations.
A reward of +0.1 is provided for each step that the agent's hand is in the goal location.
Thus, the goal of your agent is to maintain its position at the target location for as many
time steps as possible.

Example Developed By:
Michael Richardson, 2018
Project for Udacity Nanodegree in Deep Reinforcement Learning (DRL)
Code Expanded and Adapted from Code provided by Udacity DRL Team, 2018.
"""

###################################
# Import Required Packages
import torch
import random
import numpy as np
from collections import deque
from ddpg_agent import Agent
from unityagents import UnityEnvironment

"""
###################################
STEP 1: Set the Training Parameters
======
    num_episodes (int): maximum number of training episodes
    episode_scores (list of float): list to record the scores obtained from each episode
    scores_average_window (int): the window size employed for calculating the average score (e.g. 100)
    solved_score (float): the average score required for the environment to be considered solved
"""
num_episodes = 500
episode_scores = []
scores_average_window = 100
solved_score = 30

"""
###################################
STEP 2: Start the Unity Environment
# Use the corresponding call depending on your operating system
"""
env = UnityEnvironment(file_name="Reacher.app")
# - **Mac**: "Reacher.app"
# - **Windows** (x86): "Reacher_Windows_x86/Reacher.exe"
# - **Windows** (x86_64): "Reacher_Windows_x86_64/Reacher.exe"
# - **Linux** (x86): "Reacher_Linux/Reacher.x86"
# - **Linux** (x86_64): "Reacher_Linux/Reacher.x86_64"
# - **Linux** (x86, headless): "Reacher_Linux_NoVis/Reacher.x86"
# - **Linux** (x86_64, headless): "Reacher_Linux_NoVis/Reacher.x86_64"

"""
#######################################
STEP 3: Get The Unity Environment Brain
Unity ML-Agent applications or Environments contain "BRAINS" which are responsible for deciding
the actions an agent or set of agents should take given a current set of environment (state)
observations. The Reacher environment has a single Brain, thus, we just need to access the first brain
available (i.e., the default brain). We then set the default brain as the brain that will be controlled.
"""
# Get the default brain
brain_name = env.brain_names[0]

# Assign the default brain as the brain to be controlled
brain = env.brains[brain_name]


"""
#############################################
STEP 4: Determine the size of the Action and State Spaces and the Number of Agents

The observation space consists of 33 variables corresponding to
position, rotation, velocity, and angular velocities of the arm.
Each action is a vector with four numbers, corresponding to torque
applicable to two joints. Every entry in the action vector should
be a number between -1 and 1.

The Reacher environment can contain multiple agents to speed up training.
To use multiple (active) training agents we need to know how many there are.
"""

# Set the number of actions or action size
action_size = brain.vector_action_space_size

# Set the size of state observations or state size
state_size = brain.vector_observation_space_size

# Get number of agents in Environment
env_info = env.reset(train_mode=True)[brain_name]
num_agents = len(env_info.agents)
print('\nNumber of Agents: ', num_agents)


"""
###################################
STEP 5: Create a DDPG Agent from the Agent Class in ddpg_agent.py
A DDPG agent initialized with the following parameters.
======
    state_size (int): dimension of each state (required)
    action_size (int): dimension of each action (required)
    num_agents (int): number of agents in the unity environment
    random_seed (int): random seed for initializing training point (default = 0)

Here we initialize an agent using the Unity environment's state and action size and number of Agents
determined above.
"""
agent = Agent(state_size=state_size, action_size=action_size, num_agents=num_agents, random_seed=0)


"""
###################################
STEP 6: Run the DDPG Training Sequence
The DDPG Training Process involves the agent learning from repeated episodes of behaviour
to map states to actions that maximize rewards received via environmental interaction.

The agent training process involves the following:
(1) Reset the environment at the beginning of each episode.
(2) Obtain (observe) current state, s, of the environment at time t
(3) Perform an action, a(t), in the environment given s(t)
(4) Observe the result of the action in terms of the reward received and
    the state of the environment at time t+1 (i.e., s(t+1))
(5) Update agent memory and learn from experience (i.e., agent.step)
(6) Update episode score (total reward received) and set s(t) -> s(t+1).
(7) If episode is done, break and repeat from (1), otherwise repeat from (3).

Below we also exit the training process early if the environment is solved.
That is, if the average score over the previous 100 episodes is greater than or equal to solved_score.
"""

# loop from num_episodes
for i_episode in range(1, num_episodes+1):

    # reset the unity environment at the beginning of each episode
    env_info = env.reset(train_mode=True)[brain_name]

    # get initial state of the unity environment
    states = env_info.vector_observations

    # reset the training agent for new episode
    agent.reset()

    # set the initial episode score to zero for each unity agent.
    agent_scores = np.zeros(num_agents)

    # Run the episode training loop;
    # At each loop step take an action as a function of the current state observations
    # Based on the resultant environmental state (next_state) and reward received, update the Agent's Actor and Critic networks
    # If environment episode is done, exit loop...
    # Otherwise repeat until done == true
    while True:
        # determine actions for the unity agents from current state
        actions = agent.act(states)

        # send the actions to the unity agents in the environment and receive resultant environment information
        env_info = env.step(actions)[brain_name]

        next_states = env_info.vector_observations  # get the next states for each unity agent in the environment
        rewards = env_info.rewards                  # get the rewards for each unity agent in the environment
        dones = env_info.local_done                 # see if episode has finished for each unity agent in the environment

        # Send (S, A, R, S') info to the training agent for replay buffer (memory) and network updates
        agent.step(states, actions, rewards, next_states, dones)

        # set new states to current states for determining next actions
        states = next_states

        # Update episode score for each unity agent
        agent_scores += rewards

        # If any unity agent indicates that the episode is done,
        # then exit episode loop, to begin new episode
        if np.any(dones):
            break

    # Add the episode score (mean over all unity agents) to episode_scores and
    # calculate the mean score over the last scores_average_window (100) episodes.
    # Until 100 episodes have been played, the mean is taken over all episodes so far.
    episode_scores.append(np.mean(agent_scores))
    average_score = np.mean(episode_scores[-scores_average_window:])

    # Print current and average score
    print('\nEpisode {}\tEpisode Score: {:.3f}\tAverage Score: {:.3f}'.format(i_episode, episode_scores[i_episode-1], average_score), end="")

    # Save trained Actor and Critic network weights after each episode
    an_filename = "ddpgActor_Model.pth"
    torch.save(agent.actor_local.state_dict(), an_filename)
    cn_filename = "ddpgCritic_Model.pth"
    torch.save(agent.critic_local.state_dict(), cn_filename)

    # Check to see if the task is solved (i.e., average_score >= solved_score over 100 episodes).
    # If yes, save the scores and end training (the network weights were already saved above).
    if i_episode > scores_average_window and average_score >= solved_score:
        print('\nEnvironment solved in {:d} episodes!\tAverage Score: {:.3f}'.format(i_episode, average_score))

        # Save the recorded Scores data
        scores_filename = "ddpgAgent_Scores.csv"
        np.savetxt(scores_filename, episode_scores, delimiter=",")
        break


"""
###################################
STEP 7: Everything is Finished -> Close the Environment.
"""
env.close()

# END :) #############


--------------------------------------------------------------------------------
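The solved check in train.py averages the most recent `scores_average_window` episode scores (all episodes so far while fewer than 100 have been played) and stops once that mean reaches `solved_score` after more than 100 episodes. The following is a minimal standalone sketch of the same rule, assuming only NumPy; `trailing_means` and `first_solved_episode` are hypothetical helpers that are not part of this repository, and the scores in the usage example are synthetic rather than real training results.

```python
from collections import deque

import numpy as np


def trailing_means(episode_scores, window=100):
    """Mean of the last min(i, window) scores after each episode i (hypothetical helper)."""
    recent = deque(maxlen=window)  # the oldest score drops out automatically once full
    means = []
    for score in episode_scores:
        recent.append(score)
        means.append(float(np.mean(recent)))
    return means


def first_solved_episode(episode_scores, window=100, solved_score=30.0):
    """First 1-based episode at which the train.py-style check passes, else None."""
    for i_episode, average_score in enumerate(trailing_means(episode_scores, window), start=1):
        # Same condition as train.py: more than `window` episodes and mean >= solved_score
        if i_episode > window and average_score >= solved_score:
            return i_episode
    return None


if __name__ == "__main__":
    # Synthetic scores (not real results): 50 poor episodes followed by 200 good ones.
    scores = [0.0] * 50 + [35.0] * 200
    print(first_solved_episode(scores))  # 136
```

With these synthetic scores the 100-episode mean first reaches 30 at episode 136, when the window holds 86 scores of 35 and 14 zeros (86 * 35 / 100 = 30.1). A `deque(maxlen=window)` yields the same mean as the slice used in train.py without any index arithmetic.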
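ddpg_agent.py is not part of this section, so exactly how `Agent.step` stores the experiences of the 20 arms is not shown here. Because `ReplayBuffer.add` in replay_buffer.py stores one experience tuple per call, one plausible approach, assumed in this sketch, is to split the batched arrays returned by the environment into one (state, action, reward, next_state, done) transition per arm and write them all to a single shared buffer. `SharedReplayBuffer` is a hypothetical NumPy-only stand-in (no PyTorch tensors or device handling), and the data in the usage example is random.

```python
import random
from collections import deque, namedtuple

import numpy as np

Experience = namedtuple("Experience", ["state", "action", "reward", "next_state", "done"])


class SharedReplayBuffer:
    """Hypothetical NumPy-only replay buffer shared by every arm in the environment."""

    def __init__(self, buffer_size, batch_size, seed=0):
        self.memory = deque(maxlen=buffer_size)  # oldest transitions are discarded first
        self.batch_size = batch_size
        self.rng = random.Random(seed)           # private RNG, so global random state is untouched

    def add_batch(self, states, actions, rewards, next_states, dones):
        # Assumption: row i of each array belongs to Unity agent (arm) i.
        for s, a, r, s2, d in zip(states, actions, rewards, next_states, dones):
            self.memory.append(Experience(s, a, r, s2, d))

    def sample(self):
        experiences = self.rng.sample(self.memory, k=self.batch_size)
        states = np.vstack([e.state for e in experiences])
        actions = np.vstack([e.action for e in experiences])
        rewards = np.array([e.reward for e in experiences], dtype=np.float32).reshape(-1, 1)
        next_states = np.vstack([e.next_state for e in experiences])
        dones = np.array([e.done for e in experiences], dtype=np.float32).reshape(-1, 1)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.memory)


if __name__ == "__main__":
    # Sizes quoted for Reacher: 20 arms, 33-dim observations, 4-dim actions.
    num_agents, state_size, action_size = 20, 33, 4
    rng = np.random.default_rng(0)
    buffer = SharedReplayBuffer(buffer_size=100_000, batch_size=128)
    for _ in range(10):  # ten simulated environment steps with random data
        states = rng.standard_normal((num_agents, state_size))
        actions = rng.uniform(-1.0, 1.0, (num_agents, action_size))
        next_states = rng.standard_normal((num_agents, state_size))
        buffer.add_batch(states, actions, [0.1] * num_agents, next_states, [False] * num_agents)
    print(len(buffer))  # 200 transitions: 10 steps x 20 arms
    states_b, actions_b, rewards_b, next_b, dones_b = buffer.sample()
    print(states_b.shape, actions_b.shape)  # (128, 33) (128, 4)
```

Under this assumption, every arm's experience can be drawn when the single actor and critic are updated, so each environment step adds 20 transitions instead of one.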