├── LICENSE
├── README.md
├── Report.pdf
├── dqnAgent_Trained_Model.pth
├── dqn_agent.py
├── media
│   ├── bananacollection.gif
│   └── exampleTrainingScoresGraph.jpg
├── model.py
├── python
│   ├── Basics.ipynb
│   ├── README.md
│   ├── communicator_objects
│   │   ├── __init__.py
│   │   ├── agent_action_proto_pb2.py
│   │   ├── agent_info_proto_pb2.py
│   │   ├── brain_parameters_proto_pb2.py
│   │   ├── brain_type_proto_pb2.py
│   │   ├── command_proto_pb2.py
│   │   ├── engine_configuration_proto_pb2.py
│   │   ├── environment_parameters_proto_pb2.py
│   │   ├── header_pb2.py
│   │   ├── resolution_proto_pb2.py
│   │   ├── space_type_proto_pb2.py
│   │   ├── unity_input_pb2.py
│   │   ├── unity_message_pb2.py
│   │   ├── unity_output_pb2.py
│   │   ├── unity_rl_initialization_input_pb2.py
│   │   ├── unity_rl_initialization_output_pb2.py
│   │   ├── unity_rl_input_pb2.py
│   │   ├── unity_rl_output_pb2.py
│   │   ├── unity_to_external_pb2.py
│   │   └── unity_to_external_pb2_grpc.py
│   ├── curricula
│   │   ├── push.json
│   │   ├── test.json
│   │   └── wall.json
│   ├── learn.py
│   ├── requirements.txt
│   ├── setup.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── mock_communicator.py
│   │   ├── test_bc.py
│   │   ├── test_ppo.py
│   │   ├── test_unityagents.py
│   │   └── test_unitytrainers.py
│   ├── trainer_config.yaml
│   ├── unityagents
│   │   ├── __init__.py
│   │   ├── brain.py
│   │   ├── communicator.py
│   │   ├── curriculum.py
│   │   ├── environment.py
│   │   ├── exception.py
│   │   ├── rpc_communicator.py
│   │   └── socket_communicator.py
│   └── unitytrainers
│       ├── __init__.py
│       ├── bc
│       │   ├── __init__.py
│       │   ├── models.py
│       │   └── trainer.py
│       ├── buffer.py
│       ├── models.py
│       ├── ppo
│       │   ├── __init__.py
│       │   ├── models.py
│       │   └── trainer.py
│       ├── trainer.py
│       └── trainer_controller.py
├── replay_memory.py
├── test.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Q-Network (DQN) Reinforcement Learning using PyTorch and Unity ML-Agents
2 | A simple example of how to implement vector-based DQN using PyTorch and an ML-Agents environment.
3 | Project for the Udacity Nanodegree in Deep Reinforcement Learning (DRL).
4 |
5 | The repository includes the following DQN-related files:
6 | - dqn_agent.py -> DQN-agent implementation
7 | - replay_memory.py -> DQN-agent's replay buffer implementation
8 | - model.py -> example PyTorch neural network for vector-based DQN learning
9 | - train.py -> initializes and implements the training process for a DQN-agent.
10 | - test.py -> tests a trained DQN-agent
11 |
12 | Code expanded and adapted from code examples provided by Udacity DRL Team, 2018.
13 |
14 | The repository also includes Mac/Linux/Windows versions of a simple Unity environment, *Banana*, for testing.
15 | This Unity application and testing environment was developed using ML-Agents Beta v0.4. The version of the Banana environment employed for this project was developed for the Udacity Deep Reinforcement Learning Nanodegree course. For more information about this course visit: https://www.udacity.com/course/deep-reinforcement-learning-nanodegree--nd893
16 |
17 | The files in the python/ directory are the ML-Agents toolkit files and dependencies required to run the Banana environment.
18 | For more information about the Unity ML-Agents Toolkit visit: https://github.com/Unity-Technologies/ml-agents
19 |
20 | For further details about DQN see: Mnih, V., Kavukcuoglu, K., Silver, D., Rusu, A. A., Veness, J., Bellemare, M. G., ... & Petersen, S. (2015). Human-level control through deep reinforcement learning. Nature, 518(7540), 529.
21 |
22 | ## Example Unity Environment - Bananas
23 | The example uses a modified version of the Unity ML-Agents Banana Collection example environment.
24 | The environment includes a single agent that can turn left or right and move forward or backward.
25 | The agent's task is to collect yellow bananas (reward of +1) that are scattered around a square
26 | game area, while avoiding purple bananas (reward of -1). For the version of Bananas employed here,
27 | the environment is considered solved when the average score over the last 100 episodes exceeds 13.
28 |
29 | ![Trained DQN-Agent Collecting Yellow Bananas](media/bananacollection.gif)
30 |
31 | ### Action Space
32 | At each time step, the agent can perform four possible actions:
33 | - `0` - walk forward
34 | - `1` - walk backward
35 | - `2` - turn left
36 | - `3` - turn right
37 |
38 | ### State Space
39 | The agent is trained from vector input data (not pixel input data).
40 | The state space has `37` dimensions and contains the agent's velocity, along with ray-based perception of objects around the agent's forward direction. A reward of `+1` is provided for collecting a yellow banana, and a reward of `-1` is provided for collecting a purple banana.
41 |
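The snippet below is a minimal sketch of how the Banana environment can be loaded and queried from Python using the `unityagents` package bundled in `python/` (see also `python/Basics.ipynb`). The `file_name` path is only an example for a Linux build; point it at the build for your operating system, as described in the installation and download sections below.

```python
from unityagents import UnityEnvironment
import numpy as np

env = UnityEnvironment(file_name="Banana_Linux/Banana.x86_64")  # adjust for your OS

# The Banana environment exposes a single brain, which we control from Python
brain_name = env.brain_names[0]
brain = env.brains[brain_name]

# Reset the environment and inspect the state and action spaces
env_info = env.reset(train_mode=True)[brain_name]
state = env_info.vector_observations[0]        # 37-dimensional state vector
action_size = brain.vector_action_space_size   # 4 discrete actions

# Take one random action and read back the reward and episode-done flag
env_info = env.step(np.random.randint(action_size))[brain_name]
reward = env_info.rewards[0]
done = env_info.local_done[0]

env.close()
```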
42 | ## Installation and Dependencies
43 | 1. Anaconda Python 3.6: Download and installation instructions here: https://www.anaconda.com/download/
44 |
45 | 2. Create (and activate) a new conda (virtual) environment with Python 3.6.
46 | - Linux or Mac:
47 |
48 | `conda create --name yourenvnamehere python=3.6`
49 |
50 | `source activate yourenvnamehere`
51 |
52 | - Windows:
53 |
54 | `conda create --name yourenvnamehere python=3.6`
55 |
56 | `activate yourenvnamehere`
57 |
58 | 3. Download and save this GitHub repository.
59 |
60 | 4. To install the required dependencies (torch, ML-Agents trainers (v0.4), etc.):
61 | - Navigate to where you downloaded and saved this GitHub repository (e.g., *yourpath/thisgithubrepository*)
62 | - Change to the 'python/' subdirectory and run from the command line:
63 |
64 | `pip3 install .`
65 |
66 | ## Download the Unity Environment
67 | For this example project, you do not need to install Unity, because a version of the Banana Unity environment is already built (compiled) as a standalone application.
68 |
69 | Download the relevant environment zip file from one of the links below. You only need to download the environment that matches your operating system:
70 |
71 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana_Linux.zip)
72 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana.app.zip)
73 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana_Windows_x86.zip)
74 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana_Windows_x86_64.zip)
75 |
76 | After you have downloaded the relevant zip file, navigate to where you downloaded and saved this GitHub repository, place the file in the main folder of the repository, then unzip (or decompress) the file.
77 |
78 | NOTE: The Banana environment is similar to, but not identical to, the Banana Collector environment on the Unity ML-Agents GitHub page.
79 |
80 | ## Training
81 | - activate the conda environment you created above
82 | - change directory to 'yourpath/thisgithubrepository'
83 | - open `train.py`, find STEP 2 (lines 55 to 65) and set the relevant version of Banana to match your operating system
84 | - run the following command:
85 |
86 | `python train.py`
87 |
88 | - training will complete once the agent reaches *solved_score* in `train.py`.
89 | - after training, a *dqnAgent_Trained_Model_datetime.pth* file will be saved with the trained model weights
90 | - a *dqnAgent_scores_datetime.csv* file will also be saved with the scores received during training. You can use this file to plot or assess training performance (see below figure).
91 | - It is recommended that you train multiple agents and test different hyperparameter settings in `train.py` and `dqn_agent.py`.
92 | - For more information about the DQN training algorithm and the training hyperparameters see the included `Report.pdf` file (a condensed sketch of the training loop is also shown below).
93 |
94 | ![Example of agent performance (score) as a function of training episodes](media/exampleTrainingScoresGraph.jpg)
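For reference, the following is a condensed sketch of the kind of epsilon-greedy training loop that `train.py` implements. It is not a copy of `train.py`: the episode count, epsilon schedule, and file names here are illustrative, while the `Agent` and `UnityEnvironment` calls match the APIs defined in `dqn_agent.py` and `python/unityagents`.

```python
from collections import deque
import numpy as np
import torch
from unityagents import UnityEnvironment
from dqn_agent import Agent

env = UnityEnvironment(file_name="Banana_Linux/Banana.x86_64")  # adjust for your OS
brain_name = env.brain_names[0]

agent = Agent(state_size=37, action_size=4, dqn_type='DQN', seed=0)

solved_score = 13.0                        # average score target (last 100 episodes)
eps, eps_min, eps_decay = 1.0, 0.01, 0.995 # illustrative epsilon-greedy schedule
scores_window = deque(maxlen=100)

for episode in range(1, 1001):
    env_info = env.reset(train_mode=True)[brain_name]
    state = env_info.vector_observations[0]
    score = 0
    while True:
        action = agent.act(state, eps)                    # epsilon-greedy action
        env_info = env.step(int(action))[brain_name]      # send action to Unity
        next_state = env_info.vector_observations[0]
        reward = env_info.rewards[0]
        done = env_info.local_done[0]
        agent.step(state, action, reward, next_state, done)  # store and learn
        state, score = next_state, score + reward
        if done:
            break
    scores_window.append(score)
    eps = max(eps_min, eps_decay * eps)
    if np.mean(scores_window) >= solved_score:
        torch.save(agent.network.state_dict(), 'dqnAgent_Trained_Model.pth')
        break

env.close()
```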
95 |
96 |
97 | ## Testing
98 | - activate the conda environment you created above
99 | - change directory to 'yourpath/thisgithubrepository'
100 | - run the following command:
101 |
102 | `python test.py`
103 |
104 | - An example model weights file is included in the repository (*dqnAgent_Trained_Model.pth*).
105 | - A different model weights file can be tested by changing the model file name defined in `test.py` on line 109.
106 |
--------------------------------------------------------------------------------
/Report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xkiwilabs/DQN-using-PyTorch-and-ML-Agents/8bd47f7c845bbbab2cbb34d717dd08c8b7a50aab/Report.pdf
--------------------------------------------------------------------------------
/dqnAgent_Trained_Model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xkiwilabs/DQN-using-PyTorch-and-ML-Agents/8bd47f7c845bbbab2cbb34d717dd08c8b7a50aab/dqnAgent_Trained_Model.pth
--------------------------------------------------------------------------------
/dqn_agent.py:
--------------------------------------------------------------------------------
1 | """
2 | DQN Agent for Vector Observation Learning
3 |
4 | Example Developed By:
5 | Michael Richardson, 2018
6 | Project for Udacity Nanodegree in Deep Reinforcement Learning (DRL)
7 | Code expanded and adapted from code examples provided by Udacity DRL Team, 2018.
8 | """
9 |
10 | # Import Required Packages
11 | import torch
12 | import torch.nn.functional as F
13 | import torch.optim as optim
14 |
15 | import numpy as np
16 | import random
17 | from collections import namedtuple, deque
18 |
19 | from model import QNetwork
20 | from replay_memory import ReplayBuffer
21 |
22 | # Determine if CPU or GPU computation should be used
23 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
24 |
25 |
26 | """
27 | ##################################################
28 | Agent Class
29 | Defines DQN Agent Methods
30 | Agent interacts with and learns from an environment.
31 | """
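# Added usage note (illustrative, not part of the original source): for the Banana
# environment the agent is typically constructed with the defaults defined in
# __init__() below, and dqn_type='DDQN' switches learn() to the Double-DQN update, e.g.
#   agent = Agent(state_size=37, action_size=4, dqn_type='DDQN', seed=0)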
31 | """ 32 | class Agent(): 33 | 34 | """ 35 | Initialize Agent, inclduing: 36 | DQN Hyperparameters 37 | Local and Targat State-Action Policy Networks 38 | Replay Memory Buffer from Replay Buffer Class (define below) 39 | """ 40 | def __init__(self, state_size, action_size, dqn_type='DQN', replay_memory_size=1e5, batch_size=64, gamma=0.99, 41 | learning_rate=1e-3, target_tau=2e-3, update_rate=4, seed=0): 42 | 43 | """ 44 | DQN Agent Parameters 45 | ====== 46 | state_size (int): dimension of each state 47 | action_size (int): dimension of each action 48 | dqn_type (string): can be either 'DQN' for vanillia dqn learning (default) or 'DDQN' for double-DQN. 49 | replay_memory size (int): size of the replay memory buffer (typically 5e4 to 5e6) 50 | batch_size (int): size of the memory batch used for model updates (typically 32, 64 or 128) 51 | gamma (float): paramete for setting the discoun ted value of future rewards (typically .95 to .995) 52 | learning_rate (float): specifies the rate of model learing (typically 1e-4 to 1e-3)) 53 | seed (int): random seed for initializing training point. 54 | """ 55 | self.dqn_type = dqn_type 56 | self.state_size = state_size 57 | self.action_size = action_size 58 | self.buffer_size = int(replay_memory_size) 59 | self.batch_size = batch_size 60 | self.gamma = gamma 61 | self.learn_rate = learning_rate 62 | self.tau = target_tau 63 | self.update_rate = update_rate 64 | self.seed = random.seed(seed) 65 | 66 | """ 67 | # DQN Agent Q-Network 68 | # For DQN training, two nerual network models are employed; 69 | # (a) A network that is updated every (step % update_rate == 0) 70 | # (b) A target network, with weights updated to equal the network at a slower (target_tau) rate. 71 | # The slower modulation of the target network weights operates to stablize learning. 72 | """ 73 | self.network = QNetwork(state_size, action_size, seed).to(device) 74 | self.target_network = QNetwork(state_size, action_size, seed).to(device) 75 | self.optimizer = optim.Adam(self.network.parameters(), lr=self.learn_rate) 76 | 77 | # Replay memory 78 | self.memory = ReplayBuffer(action_size, self.buffer_size, self.batch_size, seed) 79 | 80 | # Initialize time step (for updating every UPDATE_EVERY steps) 81 | self.t_step = 0 82 | 83 | 84 | ######################################################## 85 | # STEP() method 86 | # 87 | def step(self, state, action, reward, next_state, done): 88 | # Save experience in replay memory 89 | self.memory.add(state, action, reward, next_state, done) 90 | 91 | # Learn every UPDATE_EVERY time steps. 92 | self.t_step = (self.t_step + 1) % self.update_rate 93 | if self.t_step == 0: 94 | # If enough samples are available in memory, get random subset and learn 95 | if len(self.memory) > self.batch_size: 96 | experiences = self.memory.sample() 97 | self.learn(experiences, self.gamma) 98 | 99 | 100 | ######################################################## 101 | # ACT() method 102 | # 103 | def act(self, state, eps=0.0): 104 | """Returns actions for given state as per current policy. 
105 |
106 |         Params
107 |         ======
108 |             state (array_like): current state
109 |             eps (float): epsilon, for epsilon-greedy action selection
110 |         """
111 |         state = torch.from_numpy(state).float().unsqueeze(0).to(device)
112 |         self.network.eval()
113 |         with torch.no_grad():
114 |             action_values = self.network(state)
115 |         self.network.train()
116 |
117 |         # Epsilon-greedy action selection
118 |         if random.random() > eps:
119 |             return np.argmax(action_values.cpu().data.numpy())
120 |         else:
121 |             return random.choice(np.arange(self.action_size))
122 |
123 |
124 |     ########################################################
125 |     # LEARN() method
126 |     # Update value parameters using given batch of experience tuples.
127 |     def learn(self, experiences, gamma, DQN=True):
128 |
129 |         """
130 |         Params
131 |         ======
132 |             experiences (Tuple[torch.Variable]): tuple of (s, a, r, s', done) tuples
133 |             gamma (float): discount factor
134 |         """
135 |
136 |         states, actions, rewards, next_states, dones = experiences
137 |
138 |         # Get Q values for the chosen actions (s, a) from the local (online) network
139 |         Qsa = self.network(states).gather(1, actions)
140 |
141 |
142 |         if (self.dqn_type == 'DDQN'):
143 |             # Double DQN: select the greedy next action with the local network,
144 |             # then evaluate that action with the target network
145 |             Qsa_prime_actions = self.network(next_states).detach().max(1)[1].unsqueeze(1)
146 |             Qsa_prime_targets = self.target_network(next_states).detach().gather(1, Qsa_prime_actions)
147 |
148 |         else:
149 |             # Regular (Vanilla) DQN
150 |             # ************************
151 |             # Get max Q values for (s',a') from the target model
152 |             Qsa_prime_target_values = self.target_network(next_states).detach()
153 |             Qsa_prime_targets = Qsa_prime_target_values.max(1)[0].unsqueeze(1)
154 |
155 |
156 |         # Compute Q targets for current states
157 |         Qsa_targets = rewards + (gamma * Qsa_prime_targets * (1 - dones))
158 |
159 |         # Compute loss (error)
160 |         loss = F.mse_loss(Qsa, Qsa_targets)
161 |
162 |         # Minimize the loss
163 |         self.optimizer.zero_grad()
164 |         loss.backward()
165 |         self.optimizer.step()
166 |
167 |         # ------------------- update target network ------------------- #
168 |         self.soft_update(self.network, self.target_network, self.tau)
169 |
170 |
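    # Added reference note (comment only, not part of the original source): the two
    # target rules implemented in learn() above are
    #   DQN  : y = r + gamma * max_a' Q_target(s', a')
    #   DDQN : y = r + gamma * Q_target(s', argmax_a' Q_local(s', a'))
    # and in both cases y is regressed against Q_local(s, a) with an MSE loss.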
171 |     ########################################################
172 |     """
173 |     Soft update model parameters.
174 |     θ_target = τ*θ_local + (1 - τ)*θ_target
175 |     """
176 |     def soft_update(self, local_model, target_model, tau):
177 |         """
178 |         Params
179 |         ======
180 |             local_model (PyTorch model): weights will be copied from
181 |             target_model (PyTorch model): weights will be copied to
182 |             tau (float): interpolation parameter
183 |         """
184 |         for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
185 |             target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
186 |
--------------------------------------------------------------------------------
/media/bananacollection.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xkiwilabs/DQN-using-PyTorch-and-ML-Agents/8bd47f7c845bbbab2cbb34d717dd08c8b7a50aab/media/bananacollection.gif
--------------------------------------------------------------------------------
/media/exampleTrainingScoresGraph.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xkiwilabs/DQN-using-PyTorch-and-ML-Agents/8bd47f7c845bbbab2cbb34d717dd08c8b7a50aab/media/exampleTrainingScoresGraph.jpg
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Example Neural Network Model for Vector Observation DQN Agent
3 | DQN Model for Unity ML-Agents Environments using PyTorch
4 |
5 | Example Developed By:
6 | Michael Richardson, 2018
7 | Project for Udacity Nanodegree in Deep Reinforcement Learning (DRL)
8 | Code expanded and adapted from code examples provided by Udacity DRL Team, 2018.
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 | class QNetwork(nn.Module):
16 |
17 |     """
18 |     #################################################
19 |     Initialize neural network model
20 |     Initialize parameters and build model.
21 |     """
22 |     def __init__(self, state_size, action_size, seed, fc1_units=128, fc2_units=128):
23 |         """
24 |         Params
25 |         ======
26 |             state_size (int): Dimension of each state
27 |             action_size (int): Dimension of each action
28 |             seed (int): Random seed
29 |             fc1_units (int): Number of nodes in first hidden layer
30 |             fc2_units (int): Number of nodes in second hidden layer
31 |         """
32 |         super(QNetwork, self).__init__()
33 |         self.seed = torch.manual_seed(seed)
34 |         self.fc1 = nn.Linear(state_size, fc1_units)
35 |         self.fc2 = nn.Linear(fc1_units, fc2_units)
36 |         self.fc3 = nn.Linear(fc2_units, action_size)
37 |
38 |
39 |     """
40 |     ###################################################
41 |     Build a network that maps state -> action values.
42 |     """
43 |     def forward(self, state):
44 |
45 |         x = F.relu(self.fc1(state))
46 |         x = F.relu(self.fc2(x))
47 |         return self.fc3(x)
48 |
--------------------------------------------------------------------------------
/python/Basics.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Unity ML-Agents Toolkit\n",
8 |     "## Environment Basics\n",
9 |     "This notebook contains a walkthrough of the basic functions of the Python API for the Unity ML-Agents toolkit. For instructions on building a Unity environment, see [here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Getting-Started-with-Balance-Ball.md)."
10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "### 1. Set environment parameters\n", 17 | "\n", 18 | "Be sure to set `env_name` to the name of the Unity environment file you want to launch. Ensure that the environment build is in the `python/` directory." 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": { 25 | "collapsed": true 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "env_name = \"3DBall\" # Name of the Unity environment binary to launch\n", 30 | "train_mode = True # Whether to run the environment in training or inference mode" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "### 2. Load dependencies\n", 38 | "\n", 39 | "The following loads the necessary dependencies and checks the Python version (at runtime). ML-Agents Toolkit (v0.3 onwards) requires Python 3." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": { 46 | "collapsed": true 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "import matplotlib.pyplot as plt\n", 51 | "import numpy as np\n", 52 | "import sys\n", 53 | "\n", 54 | "from unityagents import UnityEnvironment\n", 55 | "\n", 56 | "%matplotlib inline\n", 57 | "\n", 58 | "print(\"Python version:\")\n", 59 | "print(sys.version)\n", 60 | "\n", 61 | "# check Python version\n", 62 | "if (sys.version_info[0] < 3):\n", 63 | " raise Exception(\"ERROR: ML-Agents Toolkit (v0.3 onwards) requires Python 3\")" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "### 3. Start the environment\n", 71 | "`UnityEnvironment` launches and begins communication with the environment when instantiated.\n", 72 | "\n", 73 | "Environments contain _brains_ which are responsible for deciding the actions of their associated _agents_. Here we check for the first brain available, and set it as the default brain we will be controlling from Python." 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "collapsed": true 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "env = UnityEnvironment(file_name=env_name)\n", 85 | "\n", 86 | "# Examine environment parameters\n", 87 | "print(str(env))\n", 88 | "\n", 89 | "# Set the default brain to work with\n", 90 | "default_brain = env.brain_names[0]\n", 91 | "brain = env.brains[default_brain]" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "### 4. Examine the observation and state spaces\n", 99 | "We can reset the environment to be provided with an initial set of observations and states for all the agents within the environment. In ML-Agents, _states_ refer to a vector of variables corresponding to relevant aspects of the environment for an agent. Likewise, _observations_ refer to a set of relevant pixel-wise visuals for an agent." 
100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": { 106 | "collapsed": true 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "# Reset the environment\n", 111 | "env_info = env.reset(train_mode=train_mode)[default_brain]\n", 112 | "\n", 113 | "# Examine the state space for the default brain\n", 114 | "print(\"Agent state looks like: \\n{}\".format(env_info.vector_observations[0]))\n", 115 | "\n", 116 | "# Examine the observation space for the default brain\n", 117 | "for observation in env_info.visual_observations:\n", 118 | " print(\"Agent observations look like:\")\n", 119 | " if observation.shape[3] == 3:\n", 120 | " plt.imshow(observation[0,:,:,:])\n", 121 | " else:\n", 122 | " plt.imshow(observation[0,:,:,0])" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "### 5. Take random actions in the environment\n", 130 | "Once we restart an environment, we can step the environment forward and provide actions to all of the agents within the environment. Here we simply choose random actions based on the `action_space_type` of the default brain. \n", 131 | "\n", 132 | "Once this cell is executed, 10 messages will be printed that detail how much reward will be accumulated for the next 10 episodes. The Unity environment will then pause, waiting for further signals telling it what to do next. Thus, not seeing any animation is expected when running this cell." 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "collapsed": true 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "for episode in range(10):\n", 144 | " env_info = env.reset(train_mode=train_mode)[default_brain]\n", 145 | " done = False\n", 146 | " episode_rewards = 0\n", 147 | " while not done:\n", 148 | " if brain.vector_action_space_type == 'continuous':\n", 149 | " env_info = env.step(np.random.randn(len(env_info.agents), \n", 150 | " brain.vector_action_space_size))[default_brain]\n", 151 | " else:\n", 152 | " env_info = env.step(np.random.randint(0, brain.vector_action_space_size, \n", 153 | " size=(len(env_info.agents))))[default_brain]\n", 154 | " episode_rewards += env_info.rewards[0]\n", 155 | " done = env_info.local_done[0]\n", 156 | " print(\"Total reward this episode: {}\".format(episode_rewards))" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "### 6. Close the environment when finished\n", 164 | "When we are finished using an environment, we can close it with the function below." 
165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": { 171 | "collapsed": true 172 | }, 173 | "outputs": [], 174 | "source": [ 175 | "env.close()" 176 | ] 177 | } 178 | ], 179 | "metadata": { 180 | "anaconda-cloud": {}, 181 | "kernelspec": { 182 | "display_name": "Python 3", 183 | "language": "python", 184 | "name": "python3" 185 | }, 186 | "language_info": { 187 | "codemirror_mode": { 188 | "name": "ipython", 189 | "version": 3 190 | }, 191 | "file_extension": ".py", 192 | "mimetype": "text/x-python", 193 | "name": "python", 194 | "nbconvert_exporter": "python", 195 | "pygments_lexer": "ipython3", 196 | "version": "3.6.3" 197 | } 198 | }, 199 | "nbformat": 4, 200 | "nbformat_minor": 1 201 | } 202 | -------------------------------------------------------------------------------- /python/README.md: -------------------------------------------------------------------------------- 1 | # Dependencies 2 | 3 | This is an amended version of the `python/` folder from the [ML-Agents repository](https://github.com/Unity-Technologies/ml-agents). It has been edited to include a few additional pip packages needed for the Deep Reinforcement Learning Nanodegree program. 4 | -------------------------------------------------------------------------------- /python/communicator_objects/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent_action_proto_pb2 import * 2 | from .agent_info_proto_pb2 import * 3 | from .brain_parameters_proto_pb2 import * 4 | from .brain_type_proto_pb2 import * 5 | from .command_proto_pb2 import * 6 | from .engine_configuration_proto_pb2 import * 7 | from .environment_parameters_proto_pb2 import * 8 | from .header_pb2 import * 9 | from .resolution_proto_pb2 import * 10 | from .space_type_proto_pb2 import * 11 | from .unity_input_pb2 import * 12 | from .unity_message_pb2 import * 13 | from .unity_output_pb2 import * 14 | from .unity_rl_initialization_input_pb2 import * 15 | from .unity_rl_initialization_output_pb2 import * 16 | from .unity_rl_input_pb2 import * 17 | from .unity_rl_output_pb2 import * 18 | from .unity_to_external_pb2 import * 19 | from .unity_to_external_pb2_grpc import * 20 | -------------------------------------------------------------------------------- /python/communicator_objects/agent_action_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/agent_action_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/agent_action_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n-communicator_objects/agent_action_proto.proto\x12\x14\x63ommunicator_objects\"R\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\x14\n\x0ctext_actions\x18\x02 \x01(\t\x12\x10\n\x08memories\x18\x03 \x03(\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _AGENTACTIONPROTO = _descriptor.Descriptor( 29 | name='AgentActionProto', 30 | full_name='communicator_objects.AgentActionProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='vector_actions', full_name='communicator_objects.AgentActionProto.vector_actions', index=0, 37 | number=1, type=2, cpp_type=6, label=3, 38 | has_default_value=False, default_value=[], 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='text_actions', full_name='communicator_objects.AgentActionProto.text_actions', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='memories', full_name='communicator_objects.AgentActionProto.memories', index=2, 51 | number=3, type=2, cpp_type=6, label=3, 52 | has_default_value=False, default_value=[], 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | syntax='proto3', 65 | extension_ranges=[], 66 | oneofs=[ 67 | ], 68 | serialized_start=71, 69 | serialized_end=153, 70 | ) 71 | 72 | DESCRIPTOR.message_types_by_name['AgentActionProto'] = _AGENTACTIONPROTO 73 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 74 | 75 | AgentActionProto = _reflection.GeneratedProtocolMessageType('AgentActionProto', (_message.Message,), dict( 76 | DESCRIPTOR = _AGENTACTIONPROTO, 77 | __module__ = 'communicator_objects.agent_action_proto_pb2' 78 | # @@protoc_insertion_point(class_scope:communicator_objects.AgentActionProto) 79 | )) 80 | _sym_db.RegisterMessage(AgentActionProto) 81 | 82 | 83 | DESCRIPTOR.has_options = True 84 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 85 | # @@protoc_insertion_point(module_scope) 86 | -------------------------------------------------------------------------------- /python/communicator_objects/agent_info_proto_pb2.py: 
-------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/agent_info_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/agent_info_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n+communicator_objects/agent_info_proto.proto\x12\x14\x63ommunicator_objects\"\xfd\x01\n\x0e\x41gentInfoProto\x12\"\n\x1astacked_vector_observation\x18\x01 \x03(\x02\x12\x1b\n\x13visual_observations\x18\x02 \x03(\x0c\x12\x18\n\x10text_observation\x18\x03 \x01(\t\x12\x1d\n\x15stored_vector_actions\x18\x04 \x03(\x02\x12\x1b\n\x13stored_text_actions\x18\x05 \x01(\t\x12\x10\n\x08memories\x18\x06 \x03(\x02\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _AGENTINFOPROTO = _descriptor.Descriptor( 29 | name='AgentInfoProto', 30 | full_name='communicator_objects.AgentInfoProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='stacked_vector_observation', full_name='communicator_objects.AgentInfoProto.stacked_vector_observation', index=0, 37 | number=1, type=2, cpp_type=6, label=3, 38 | has_default_value=False, default_value=[], 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='visual_observations', full_name='communicator_objects.AgentInfoProto.visual_observations', index=1, 44 | number=2, type=12, cpp_type=9, label=3, 45 | has_default_value=False, default_value=[], 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='text_observation', full_name='communicator_objects.AgentInfoProto.text_observation', index=2, 51 | number=3, type=9, cpp_type=9, label=1, 52 | has_default_value=False, default_value=_b("").decode('utf-8'), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | _descriptor.FieldDescriptor( 57 | name='stored_vector_actions', full_name='communicator_objects.AgentInfoProto.stored_vector_actions', index=3, 58 | number=4, type=2, cpp_type=6, label=3, 59 | has_default_value=False, default_value=[], 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None, file=DESCRIPTOR), 63 | _descriptor.FieldDescriptor( 64 | name='stored_text_actions', full_name='communicator_objects.AgentInfoProto.stored_text_actions', index=4, 65 | number=5, type=9, cpp_type=9, label=1, 66 | has_default_value=False, 
default_value=_b("").decode('utf-8'), 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None, file=DESCRIPTOR), 70 | _descriptor.FieldDescriptor( 71 | name='memories', full_name='communicator_objects.AgentInfoProto.memories', index=5, 72 | number=6, type=2, cpp_type=6, label=3, 73 | has_default_value=False, default_value=[], 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None, file=DESCRIPTOR), 77 | _descriptor.FieldDescriptor( 78 | name='reward', full_name='communicator_objects.AgentInfoProto.reward', index=6, 79 | number=7, type=2, cpp_type=6, label=1, 80 | has_default_value=False, default_value=float(0), 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None, file=DESCRIPTOR), 84 | _descriptor.FieldDescriptor( 85 | name='done', full_name='communicator_objects.AgentInfoProto.done', index=7, 86 | number=8, type=8, cpp_type=7, label=1, 87 | has_default_value=False, default_value=False, 88 | message_type=None, enum_type=None, containing_type=None, 89 | is_extension=False, extension_scope=None, 90 | options=None, file=DESCRIPTOR), 91 | _descriptor.FieldDescriptor( 92 | name='max_step_reached', full_name='communicator_objects.AgentInfoProto.max_step_reached', index=8, 93 | number=9, type=8, cpp_type=7, label=1, 94 | has_default_value=False, default_value=False, 95 | message_type=None, enum_type=None, containing_type=None, 96 | is_extension=False, extension_scope=None, 97 | options=None, file=DESCRIPTOR), 98 | _descriptor.FieldDescriptor( 99 | name='id', full_name='communicator_objects.AgentInfoProto.id', index=9, 100 | number=10, type=5, cpp_type=1, label=1, 101 | has_default_value=False, default_value=0, 102 | message_type=None, enum_type=None, containing_type=None, 103 | is_extension=False, extension_scope=None, 104 | options=None, file=DESCRIPTOR), 105 | ], 106 | extensions=[ 107 | ], 108 | nested_types=[], 109 | enum_types=[ 110 | ], 111 | options=None, 112 | is_extendable=False, 113 | syntax='proto3', 114 | extension_ranges=[], 115 | oneofs=[ 116 | ], 117 | serialized_start=70, 118 | serialized_end=323, 119 | ) 120 | 121 | DESCRIPTOR.message_types_by_name['AgentInfoProto'] = _AGENTINFOPROTO 122 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 123 | 124 | AgentInfoProto = _reflection.GeneratedProtocolMessageType('AgentInfoProto', (_message.Message,), dict( 125 | DESCRIPTOR = _AGENTINFOPROTO, 126 | __module__ = 'communicator_objects.agent_info_proto_pb2' 127 | # @@protoc_insertion_point(class_scope:communicator_objects.AgentInfoProto) 128 | )) 129 | _sym_db.RegisterMessage(AgentInfoProto) 130 | 131 | 132 | DESCRIPTOR.has_options = True 133 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 134 | # @@protoc_insertion_point(module_scope) 135 | -------------------------------------------------------------------------------- /python/communicator_objects/brain_parameters_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/brain_parameters_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2 17 | from communicator_objects import brain_type_proto_pb2 as communicator__objects_dot_brain__type__proto__pb2 18 | from communicator_objects import space_type_proto_pb2 as communicator__objects_dot_space__type__proto__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='communicator_objects/brain_parameters_proto.proto', 23 | package='communicator_objects', 24 | syntax='proto3', 25 | serialized_pb=_b('\n1communicator_objects/brain_parameters_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto\x1a+communicator_objects/brain_type_proto.proto\x1a+communicator_objects/space_type_proto.proto\"\xc6\x03\n\x14\x42rainParametersProto\x12\x1f\n\x17vector_observation_size\x18\x01 \x01(\x05\x12\'\n\x1fnum_stacked_vector_observations\x18\x02 \x01(\x05\x12\x1a\n\x12vector_action_size\x18\x03 \x01(\x05\x12\x41\n\x12\x63\x61mera_resolutions\x18\x04 \x03(\x0b\x32%.communicator_objects.ResolutionProto\x12\"\n\x1avector_action_descriptions\x18\x05 \x03(\t\x12\x46\n\x18vector_action_space_type\x18\x06 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12K\n\x1dvector_observation_space_type\x18\x07 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12\x12\n\nbrain_name\x18\x08 \x01(\t\x12\x38\n\nbrain_type\x18\t \x01(\x0e\x32$.communicator_objects.BrainTypeProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,communicator__objects_dot_brain__type__proto__pb2.DESCRIPTOR,communicator__objects_dot_space__type__proto__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _BRAINPARAMETERSPROTO = _descriptor.Descriptor( 33 | name='BrainParametersProto', 34 | full_name='communicator_objects.BrainParametersProto', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='vector_observation_size', full_name='communicator_objects.BrainParametersProto.vector_observation_size', index=0, 41 | number=1, type=5, cpp_type=1, label=1, 42 | has_default_value=False, default_value=0, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None, file=DESCRIPTOR), 46 | _descriptor.FieldDescriptor( 47 | name='num_stacked_vector_observations', full_name='communicator_objects.BrainParametersProto.num_stacked_vector_observations', index=1, 48 | number=2, type=5, cpp_type=1, label=1, 49 | has_default_value=False, default_value=0, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | options=None, file=DESCRIPTOR), 53 | _descriptor.FieldDescriptor( 54 | name='vector_action_size', full_name='communicator_objects.BrainParametersProto.vector_action_size', index=2, 55 | number=3, type=5, cpp_type=1, label=1, 56 | has_default_value=False, default_value=0, 57 | 
message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | options=None, file=DESCRIPTOR), 60 | _descriptor.FieldDescriptor( 61 | name='camera_resolutions', full_name='communicator_objects.BrainParametersProto.camera_resolutions', index=3, 62 | number=4, type=11, cpp_type=10, label=3, 63 | has_default_value=False, default_value=[], 64 | message_type=None, enum_type=None, containing_type=None, 65 | is_extension=False, extension_scope=None, 66 | options=None, file=DESCRIPTOR), 67 | _descriptor.FieldDescriptor( 68 | name='vector_action_descriptions', full_name='communicator_objects.BrainParametersProto.vector_action_descriptions', index=4, 69 | number=5, type=9, cpp_type=9, label=3, 70 | has_default_value=False, default_value=[], 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | options=None, file=DESCRIPTOR), 74 | _descriptor.FieldDescriptor( 75 | name='vector_action_space_type', full_name='communicator_objects.BrainParametersProto.vector_action_space_type', index=5, 76 | number=6, type=14, cpp_type=8, label=1, 77 | has_default_value=False, default_value=0, 78 | message_type=None, enum_type=None, containing_type=None, 79 | is_extension=False, extension_scope=None, 80 | options=None, file=DESCRIPTOR), 81 | _descriptor.FieldDescriptor( 82 | name='vector_observation_space_type', full_name='communicator_objects.BrainParametersProto.vector_observation_space_type', index=6, 83 | number=7, type=14, cpp_type=8, label=1, 84 | has_default_value=False, default_value=0, 85 | message_type=None, enum_type=None, containing_type=None, 86 | is_extension=False, extension_scope=None, 87 | options=None, file=DESCRIPTOR), 88 | _descriptor.FieldDescriptor( 89 | name='brain_name', full_name='communicator_objects.BrainParametersProto.brain_name', index=7, 90 | number=8, type=9, cpp_type=9, label=1, 91 | has_default_value=False, default_value=_b("").decode('utf-8'), 92 | message_type=None, enum_type=None, containing_type=None, 93 | is_extension=False, extension_scope=None, 94 | options=None, file=DESCRIPTOR), 95 | _descriptor.FieldDescriptor( 96 | name='brain_type', full_name='communicator_objects.BrainParametersProto.brain_type', index=8, 97 | number=9, type=14, cpp_type=8, label=1, 98 | has_default_value=False, default_value=0, 99 | message_type=None, enum_type=None, containing_type=None, 100 | is_extension=False, extension_scope=None, 101 | options=None, file=DESCRIPTOR), 102 | ], 103 | extensions=[ 104 | ], 105 | nested_types=[], 106 | enum_types=[ 107 | ], 108 | options=None, 109 | is_extendable=False, 110 | syntax='proto3', 111 | extension_ranges=[], 112 | oneofs=[ 113 | ], 114 | serialized_start=211, 115 | serialized_end=665, 116 | ) 117 | 118 | _BRAINPARAMETERSPROTO.fields_by_name['camera_resolutions'].message_type = communicator__objects_dot_resolution__proto__pb2._RESOLUTIONPROTO 119 | _BRAINPARAMETERSPROTO.fields_by_name['vector_action_space_type'].enum_type = communicator__objects_dot_space__type__proto__pb2._SPACETYPEPROTO 120 | _BRAINPARAMETERSPROTO.fields_by_name['vector_observation_space_type'].enum_type = communicator__objects_dot_space__type__proto__pb2._SPACETYPEPROTO 121 | _BRAINPARAMETERSPROTO.fields_by_name['brain_type'].enum_type = communicator__objects_dot_brain__type__proto__pb2._BRAINTYPEPROTO 122 | DESCRIPTOR.message_types_by_name['BrainParametersProto'] = _BRAINPARAMETERSPROTO 123 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 124 | 125 | BrainParametersProto = 
_reflection.GeneratedProtocolMessageType('BrainParametersProto', (_message.Message,), dict( 126 | DESCRIPTOR = _BRAINPARAMETERSPROTO, 127 | __module__ = 'communicator_objects.brain_parameters_proto_pb2' 128 | # @@protoc_insertion_point(class_scope:communicator_objects.BrainParametersProto) 129 | )) 130 | _sym_db.RegisterMessage(BrainParametersProto) 131 | 132 | 133 | DESCRIPTOR.has_options = True 134 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 135 | # @@protoc_insertion_point(module_scope) 136 | -------------------------------------------------------------------------------- /python/communicator_objects/brain_type_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/brain_type_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/brain_type_proto.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n+communicator_objects/brain_type_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto*G\n\x0e\x42rainTypeProto\x12\n\n\x06Player\x10\x00\x12\r\n\tHeuristic\x10\x01\x12\x0c\n\x08\x45xternal\x10\x02\x12\x0c\n\x08Internal\x10\x03\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,]) 27 | 28 | _BRAINTYPEPROTO = _descriptor.EnumDescriptor( 29 | name='BrainTypeProto', 30 | full_name='communicator_objects.BrainTypeProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | values=[ 34 | _descriptor.EnumValueDescriptor( 35 | name='Player', index=0, number=0, 36 | options=None, 37 | type=None), 38 | _descriptor.EnumValueDescriptor( 39 | name='Heuristic', index=1, number=1, 40 | options=None, 41 | type=None), 42 | _descriptor.EnumValueDescriptor( 43 | name='External', index=2, number=2, 44 | options=None, 45 | type=None), 46 | _descriptor.EnumValueDescriptor( 47 | name='Internal', index=3, number=3, 48 | options=None, 49 | type=None), 50 | ], 51 | containing_type=None, 52 | options=None, 53 | serialized_start=114, 54 | serialized_end=185, 55 | ) 56 | _sym_db.RegisterEnumDescriptor(_BRAINTYPEPROTO) 57 | 58 | BrainTypeProto = enum_type_wrapper.EnumTypeWrapper(_BRAINTYPEPROTO) 59 | Player = 0 60 | Heuristic = 1 61 | External = 2 62 | Internal = 3 63 | 64 | 65 | DESCRIPTOR.enum_types_by_name['BrainTypeProto'] = _BRAINTYPEPROTO 66 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 67 | 68 | 69 | DESCRIPTOR.has_options = True 70 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 71 | # @@protoc_insertion_point(module_scope) 72 | 
-------------------------------------------------------------------------------- /python/communicator_objects/command_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/command_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='communicator_objects/command_proto.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_pb=_b('\n(communicator_objects/command_proto.proto\x12\x14\x63ommunicator_objects*-\n\x0c\x43ommandProto\x12\x08\n\x04STEP\x10\x00\x12\t\n\x05RESET\x10\x01\x12\x08\n\x04QUIT\x10\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | ) 25 | 26 | _COMMANDPROTO = _descriptor.EnumDescriptor( 27 | name='CommandProto', 28 | full_name='communicator_objects.CommandProto', 29 | filename=None, 30 | file=DESCRIPTOR, 31 | values=[ 32 | _descriptor.EnumValueDescriptor( 33 | name='STEP', index=0, number=0, 34 | options=None, 35 | type=None), 36 | _descriptor.EnumValueDescriptor( 37 | name='RESET', index=1, number=1, 38 | options=None, 39 | type=None), 40 | _descriptor.EnumValueDescriptor( 41 | name='QUIT', index=2, number=2, 42 | options=None, 43 | type=None), 44 | ], 45 | containing_type=None, 46 | options=None, 47 | serialized_start=66, 48 | serialized_end=111, 49 | ) 50 | _sym_db.RegisterEnumDescriptor(_COMMANDPROTO) 51 | 52 | CommandProto = enum_type_wrapper.EnumTypeWrapper(_COMMANDPROTO) 53 | STEP = 0 54 | RESET = 1 55 | QUIT = 2 56 | 57 | 58 | DESCRIPTOR.enum_types_by_name['CommandProto'] = _COMMANDPROTO 59 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 60 | 61 | 62 | DESCRIPTOR.has_options = True 63 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 64 | # @@protoc_insertion_point(module_scope) 65 | -------------------------------------------------------------------------------- /python/communicator_objects/engine_configuration_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/engine_configuration_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/engine_configuration_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n5communicator_objects/engine_configuration_proto.proto\x12\x14\x63ommunicator_objects\"\x95\x01\n\x18\x45ngineConfigurationProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x15\n\rquality_level\x18\x03 \x01(\x05\x12\x12\n\ntime_scale\x18\x04 \x01(\x02\x12\x19\n\x11target_frame_rate\x18\x05 \x01(\x05\x12\x14\n\x0cshow_monitor\x18\x06 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _ENGINECONFIGURATIONPROTO = _descriptor.Descriptor( 29 | name='EngineConfigurationProto', 30 | full_name='communicator_objects.EngineConfigurationProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='width', full_name='communicator_objects.EngineConfigurationProto.width', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='height', full_name='communicator_objects.EngineConfigurationProto.height', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='quality_level', full_name='communicator_objects.EngineConfigurationProto.quality_level', index=2, 51 | number=3, type=5, cpp_type=1, label=1, 52 | has_default_value=False, default_value=0, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | _descriptor.FieldDescriptor( 57 | name='time_scale', full_name='communicator_objects.EngineConfigurationProto.time_scale', index=3, 58 | number=4, type=2, cpp_type=6, label=1, 59 | has_default_value=False, default_value=float(0), 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None, file=DESCRIPTOR), 63 | _descriptor.FieldDescriptor( 64 | name='target_frame_rate', full_name='communicator_objects.EngineConfigurationProto.target_frame_rate', index=4, 65 | number=5, type=5, cpp_type=1, label=1, 66 | has_default_value=False, default_value=0, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None, file=DESCRIPTOR), 70 | _descriptor.FieldDescriptor( 71 | name='show_monitor', full_name='communicator_objects.EngineConfigurationProto.show_monitor', index=5, 72 | number=6, type=8, cpp_type=7, label=1, 73 | 
has_default_value=False, default_value=False, 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None, file=DESCRIPTOR), 77 | ], 78 | extensions=[ 79 | ], 80 | nested_types=[], 81 | enum_types=[ 82 | ], 83 | options=None, 84 | is_extendable=False, 85 | syntax='proto3', 86 | extension_ranges=[], 87 | oneofs=[ 88 | ], 89 | serialized_start=80, 90 | serialized_end=229, 91 | ) 92 | 93 | DESCRIPTOR.message_types_by_name['EngineConfigurationProto'] = _ENGINECONFIGURATIONPROTO 94 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 95 | 96 | EngineConfigurationProto = _reflection.GeneratedProtocolMessageType('EngineConfigurationProto', (_message.Message,), dict( 97 | DESCRIPTOR = _ENGINECONFIGURATIONPROTO, 98 | __module__ = 'communicator_objects.engine_configuration_proto_pb2' 99 | # @@protoc_insertion_point(class_scope:communicator_objects.EngineConfigurationProto) 100 | )) 101 | _sym_db.RegisterMessage(EngineConfigurationProto) 102 | 103 | 104 | DESCRIPTOR.has_options = True 105 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 106 | # @@protoc_insertion_point(module_scope) 107 | -------------------------------------------------------------------------------- /python/communicator_objects/environment_parameters_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/environment_parameters_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/environment_parameters_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n7communicator_objects/environment_parameters_proto.proto\x12\x14\x63ommunicator_objects\"\xb5\x01\n\x1a\x45nvironmentParametersProto\x12_\n\x10\x66loat_parameters\x18\x01 \x03(\x0b\x32\x45.communicator_objects.EnvironmentParametersProto.FloatParametersEntry\x1a\x36\n\x14\x46loatParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY = _descriptor.Descriptor( 29 | name='FloatParametersEntry', 30 | full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='key', full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.key', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=False, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='value', 
full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.value', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=False, default_value=float(0), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 56 | is_extendable=False, 57 | syntax='proto3', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=209, 62 | serialized_end=263, 63 | ) 64 | 65 | _ENVIRONMENTPARAMETERSPROTO = _descriptor.Descriptor( 66 | name='EnvironmentParametersProto', 67 | full_name='communicator_objects.EnvironmentParametersProto', 68 | filename=None, 69 | file=DESCRIPTOR, 70 | containing_type=None, 71 | fields=[ 72 | _descriptor.FieldDescriptor( 73 | name='float_parameters', full_name='communicator_objects.EnvironmentParametersProto.float_parameters', index=0, 74 | number=1, type=11, cpp_type=10, label=3, 75 | has_default_value=False, default_value=[], 76 | message_type=None, enum_type=None, containing_type=None, 77 | is_extension=False, extension_scope=None, 78 | options=None, file=DESCRIPTOR), 79 | ], 80 | extensions=[ 81 | ], 82 | nested_types=[_ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, ], 83 | enum_types=[ 84 | ], 85 | options=None, 86 | is_extendable=False, 87 | syntax='proto3', 88 | extension_ranges=[], 89 | oneofs=[ 90 | ], 91 | serialized_start=82, 92 | serialized_end=263, 93 | ) 94 | 95 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.containing_type = _ENVIRONMENTPARAMETERSPROTO 96 | _ENVIRONMENTPARAMETERSPROTO.fields_by_name['float_parameters'].message_type = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY 97 | DESCRIPTOR.message_types_by_name['EnvironmentParametersProto'] = _ENVIRONMENTPARAMETERSPROTO 98 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 99 | 100 | EnvironmentParametersProto = _reflection.GeneratedProtocolMessageType('EnvironmentParametersProto', (_message.Message,), dict( 101 | 102 | FloatParametersEntry = _reflection.GeneratedProtocolMessageType('FloatParametersEntry', (_message.Message,), dict( 103 | DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, 104 | __module__ = 'communicator_objects.environment_parameters_proto_pb2' 105 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto.FloatParametersEntry) 106 | )) 107 | , 108 | DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO, 109 | __module__ = 'communicator_objects.environment_parameters_proto_pb2' 110 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto) 111 | )) 112 | _sym_db.RegisterMessage(EnvironmentParametersProto) 113 | _sym_db.RegisterMessage(EnvironmentParametersProto.FloatParametersEntry) 114 | 115 | 116 | DESCRIPTOR.has_options = True 117 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 118 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.has_options = True 119 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 120 | # @@protoc_insertion_point(module_scope) 121 | -------------------------------------------------------------------------------- /python/communicator_objects/header_pb2.py: -------------------------------------------------------------------------------- 1 | 
# Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/header.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/header.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n!communicator_objects/header.proto\x12\x14\x63ommunicator_objects\")\n\x06Header\x12\x0e\n\x06status\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _HEADER = _descriptor.Descriptor( 29 | name='Header', 30 | full_name='communicator_objects.Header', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='status', full_name='communicator_objects.Header.status', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='message', full_name='communicator_objects.Header.message', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=None, 56 | is_extendable=False, 57 | syntax='proto3', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=59, 62 | serialized_end=100, 63 | ) 64 | 65 | DESCRIPTOR.message_types_by_name['Header'] = _HEADER 66 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 67 | 68 | Header = _reflection.GeneratedProtocolMessageType('Header', (_message.Message,), dict( 69 | DESCRIPTOR = _HEADER, 70 | __module__ = 'communicator_objects.header_pb2' 71 | # @@protoc_insertion_point(class_scope:communicator_objects.Header) 72 | )) 73 | _sym_db.RegisterMessage(Header) 74 | 75 | 76 | DESCRIPTOR.has_options = True 77 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 78 | # @@protoc_insertion_point(module_scope) 79 | -------------------------------------------------------------------------------- /python/communicator_objects/resolution_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/resolution_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/resolution_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n+communicator_objects/resolution_proto.proto\x12\x14\x63ommunicator_objects\"D\n\x0fResolutionProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x12\n\ngray_scale\x18\x03 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _RESOLUTIONPROTO = _descriptor.Descriptor( 29 | name='ResolutionProto', 30 | full_name='communicator_objects.ResolutionProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='width', full_name='communicator_objects.ResolutionProto.width', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='height', full_name='communicator_objects.ResolutionProto.height', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='gray_scale', full_name='communicator_objects.ResolutionProto.gray_scale', index=2, 51 | number=3, type=8, cpp_type=7, label=1, 52 | has_default_value=False, default_value=False, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | syntax='proto3', 65 | extension_ranges=[], 66 | oneofs=[ 67 | ], 68 | serialized_start=69, 69 | serialized_end=137, 70 | ) 71 | 72 | DESCRIPTOR.message_types_by_name['ResolutionProto'] = _RESOLUTIONPROTO 73 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 74 | 75 | ResolutionProto = _reflection.GeneratedProtocolMessageType('ResolutionProto', (_message.Message,), dict( 76 | DESCRIPTOR = _RESOLUTIONPROTO, 77 | __module__ = 'communicator_objects.resolution_proto_pb2' 78 | # @@protoc_insertion_point(class_scope:communicator_objects.ResolutionProto) 79 | )) 80 | _sym_db.RegisterMessage(ResolutionProto) 81 | 82 | 83 | DESCRIPTOR.has_options = True 84 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 85 | # @@protoc_insertion_point(module_scope) 86 | -------------------------------------------------------------------------------- /python/communicator_objects/space_type_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer 
compiler. DO NOT EDIT! 2 | # source: communicator_objects/space_type_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/space_type_proto.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n+communicator_objects/space_type_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto*.\n\x0eSpaceTypeProto\x12\x0c\n\x08\x64iscrete\x10\x00\x12\x0e\n\ncontinuous\x10\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,]) 27 | 28 | _SPACETYPEPROTO = _descriptor.EnumDescriptor( 29 | name='SpaceTypeProto', 30 | full_name='communicator_objects.SpaceTypeProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | values=[ 34 | _descriptor.EnumValueDescriptor( 35 | name='discrete', index=0, number=0, 36 | options=None, 37 | type=None), 38 | _descriptor.EnumValueDescriptor( 39 | name='continuous', index=1, number=1, 40 | options=None, 41 | type=None), 42 | ], 43 | containing_type=None, 44 | options=None, 45 | serialized_start=114, 46 | serialized_end=160, 47 | ) 48 | _sym_db.RegisterEnumDescriptor(_SPACETYPEPROTO) 49 | 50 | SpaceTypeProto = enum_type_wrapper.EnumTypeWrapper(_SPACETYPEPROTO) 51 | discrete = 0 52 | continuous = 1 53 | 54 | 55 | DESCRIPTOR.enum_types_by_name['SpaceTypeProto'] = _SPACETYPEPROTO 56 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 57 | 58 | 59 | DESCRIPTOR.has_options = True 60 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 61 | # @@protoc_insertion_point(module_scope) 62 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_input_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/unity_input.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_rl_input_pb2 as communicator__objects_dot_unity__rl__input__pb2 17 | from communicator_objects import unity_rl_initialization_input_pb2 as communicator__objects_dot_unity__rl__initialization__input__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/unity_input.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n&communicator_objects/unity_input.proto\x12\x14\x63ommunicator_objects\x1a)communicator_objects/unity_rl_input.proto\x1a\x38\x63ommunicator_objects/unity_rl_initialization_input.proto\"\xb0\x01\n\nUnityInput\x12\x34\n\x08rl_input\x18\x01 \x01(\x0b\x32\".communicator_objects.UnityRLInput\x12Q\n\x17rl_initialization_input\x18\x02 \x01(\x0b\x32\x30.communicator_objects.UnityRLInitializationInput\x12\x19\n\x11\x63ustom_data_input\x18\x03 \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_unity__rl__input__pb2.DESCRIPTOR,communicator__objects_dot_unity__rl__initialization__input__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _UNITYINPUT = _descriptor.Descriptor( 32 | name='UnityInput', 33 | full_name='communicator_objects.UnityInput', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='rl_input', full_name='communicator_objects.UnityInput.rl_input', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='rl_initialization_input', full_name='communicator_objects.UnityInput.rl_initialization_input', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None, file=DESCRIPTOR), 52 | _descriptor.FieldDescriptor( 53 | name='custom_data_input', full_name='communicator_objects.UnityInput.custom_data_input', index=2, 54 | number=3, type=5, cpp_type=1, label=1, 55 | has_default_value=False, default_value=0, 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None, file=DESCRIPTOR), 59 | ], 60 | extensions=[ 61 | ], 62 | nested_types=[], 63 | enum_types=[ 64 | ], 65 | options=None, 66 | is_extendable=False, 67 | syntax='proto3', 68 | extension_ranges=[], 69 | oneofs=[ 70 | ], 71 | serialized_start=166, 72 | serialized_end=342, 73 | ) 74 | 75 | _UNITYINPUT.fields_by_name['rl_input'].message_type = communicator__objects_dot_unity__rl__input__pb2._UNITYRLINPUT 76 | _UNITYINPUT.fields_by_name['rl_initialization_input'].message_type = communicator__objects_dot_unity__rl__initialization__input__pb2._UNITYRLINITIALIZATIONINPUT 77 | 
DESCRIPTOR.message_types_by_name['UnityInput'] = _UNITYINPUT 78 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 79 | 80 | UnityInput = _reflection.GeneratedProtocolMessageType('UnityInput', (_message.Message,), dict( 81 | DESCRIPTOR = _UNITYINPUT, 82 | __module__ = 'communicator_objects.unity_input_pb2' 83 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityInput) 84 | )) 85 | _sym_db.RegisterMessage(UnityInput) 86 | 87 | 88 | DESCRIPTOR.has_options = True 89 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 90 | # @@protoc_insertion_point(module_scope) 91 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_message_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_message.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_output_pb2 as communicator__objects_dot_unity__output__pb2 17 | from communicator_objects import unity_input_pb2 as communicator__objects_dot_unity__input__pb2 18 | from communicator_objects import header_pb2 as communicator__objects_dot_header__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='communicator_objects/unity_message.proto', 23 | package='communicator_objects', 24 | syntax='proto3', 25 | serialized_pb=_b('\n(communicator_objects/unity_message.proto\x12\x14\x63ommunicator_objects\x1a\'communicator_objects/unity_output.proto\x1a&communicator_objects/unity_input.proto\x1a!communicator_objects/header.proto\"\xac\x01\n\x0cUnityMessage\x12,\n\x06header\x18\x01 \x01(\x0b\x32\x1c.communicator_objects.Header\x12\x37\n\x0cunity_output\x18\x02 \x01(\x0b\x32!.communicator_objects.UnityOutput\x12\x35\n\x0bunity_input\x18\x03 \x01(\x0b\x32 .communicator_objects.UnityInputB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[communicator__objects_dot_unity__output__pb2.DESCRIPTOR,communicator__objects_dot_unity__input__pb2.DESCRIPTOR,communicator__objects_dot_header__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _UNITYMESSAGE = _descriptor.Descriptor( 33 | name='UnityMessage', 34 | full_name='communicator_objects.UnityMessage', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='header', full_name='communicator_objects.UnityMessage.header', index=0, 41 | number=1, type=11, cpp_type=10, label=1, 42 | has_default_value=False, default_value=None, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None, file=DESCRIPTOR), 46 | _descriptor.FieldDescriptor( 47 | name='unity_output', full_name='communicator_objects.UnityMessage.unity_output', index=1, 48 | number=2, type=11, cpp_type=10, label=1, 49 | has_default_value=False, default_value=None, 50 | message_type=None, enum_type=None, containing_type=None, 51 | 
is_extension=False, extension_scope=None, 52 | options=None, file=DESCRIPTOR), 53 | _descriptor.FieldDescriptor( 54 | name='unity_input', full_name='communicator_objects.UnityMessage.unity_input', index=2, 55 | number=3, type=11, cpp_type=10, label=1, 56 | has_default_value=False, default_value=None, 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | options=None, file=DESCRIPTOR), 60 | ], 61 | extensions=[ 62 | ], 63 | nested_types=[], 64 | enum_types=[ 65 | ], 66 | options=None, 67 | is_extendable=False, 68 | syntax='proto3', 69 | extension_ranges=[], 70 | oneofs=[ 71 | ], 72 | serialized_start=183, 73 | serialized_end=355, 74 | ) 75 | 76 | _UNITYMESSAGE.fields_by_name['header'].message_type = communicator__objects_dot_header__pb2._HEADER 77 | _UNITYMESSAGE.fields_by_name['unity_output'].message_type = communicator__objects_dot_unity__output__pb2._UNITYOUTPUT 78 | _UNITYMESSAGE.fields_by_name['unity_input'].message_type = communicator__objects_dot_unity__input__pb2._UNITYINPUT 79 | DESCRIPTOR.message_types_by_name['UnityMessage'] = _UNITYMESSAGE 80 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 81 | 82 | UnityMessage = _reflection.GeneratedProtocolMessageType('UnityMessage', (_message.Message,), dict( 83 | DESCRIPTOR = _UNITYMESSAGE, 84 | __module__ = 'communicator_objects.unity_message_pb2' 85 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityMessage) 86 | )) 87 | _sym_db.RegisterMessage(UnityMessage) 88 | 89 | 90 | DESCRIPTOR.has_options = True 91 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 92 | # @@protoc_insertion_point(module_scope) 93 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_output_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/unity_output.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_rl_output_pb2 as communicator__objects_dot_unity__rl__output__pb2 17 | from communicator_objects import unity_rl_initialization_output_pb2 as communicator__objects_dot_unity__rl__initialization__output__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/unity_output.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n\'communicator_objects/unity_output.proto\x12\x14\x63ommunicator_objects\x1a*communicator_objects/unity_rl_output.proto\x1a\x39\x63ommunicator_objects/unity_rl_initialization_output.proto\"\xb6\x01\n\x0bUnityOutput\x12\x36\n\trl_output\x18\x01 \x01(\x0b\x32#.communicator_objects.UnityRLOutput\x12S\n\x18rl_initialization_output\x18\x02 \x01(\x0b\x32\x31.communicator_objects.UnityRLInitializationOutput\x12\x1a\n\x12\x63ustom_data_output\x18\x03 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_unity__rl__output__pb2.DESCRIPTOR,communicator__objects_dot_unity__rl__initialization__output__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _UNITYOUTPUT = _descriptor.Descriptor( 32 | name='UnityOutput', 33 | full_name='communicator_objects.UnityOutput', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='rl_output', full_name='communicator_objects.UnityOutput.rl_output', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='rl_initialization_output', full_name='communicator_objects.UnityOutput.rl_initialization_output', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None, file=DESCRIPTOR), 52 | _descriptor.FieldDescriptor( 53 | name='custom_data_output', full_name='communicator_objects.UnityOutput.custom_data_output', index=2, 54 | number=3, type=9, cpp_type=9, label=1, 55 | has_default_value=False, default_value=_b("").decode('utf-8'), 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None, file=DESCRIPTOR), 59 | ], 60 | extensions=[ 61 | ], 62 | nested_types=[], 63 | enum_types=[ 64 | ], 65 | options=None, 66 | is_extendable=False, 67 | syntax='proto3', 68 | extension_ranges=[], 69 | oneofs=[ 70 | ], 71 | serialized_start=169, 72 | serialized_end=351, 73 | ) 74 | 75 | _UNITYOUTPUT.fields_by_name['rl_output'].message_type = communicator__objects_dot_unity__rl__output__pb2._UNITYRLOUTPUT 76 | _UNITYOUTPUT.fields_by_name['rl_initialization_output'].message_type = 
communicator__objects_dot_unity__rl__initialization__output__pb2._UNITYRLINITIALIZATIONOUTPUT 77 | DESCRIPTOR.message_types_by_name['UnityOutput'] = _UNITYOUTPUT 78 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 79 | 80 | UnityOutput = _reflection.GeneratedProtocolMessageType('UnityOutput', (_message.Message,), dict( 81 | DESCRIPTOR = _UNITYOUTPUT, 82 | __module__ = 'communicator_objects.unity_output_pb2' 83 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityOutput) 84 | )) 85 | _sym_db.RegisterMessage(UnityOutput) 86 | 87 | 88 | DESCRIPTOR.has_options = True 89 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 90 | # @@protoc_insertion_point(module_scope) 91 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_rl_initialization_input_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_rl_initialization_input.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/unity_rl_initialization_input.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n8communicator_objects/unity_rl_initialization_input.proto\x12\x14\x63ommunicator_objects\"*\n\x1aUnityRLInitializationInput\x12\x0c\n\x04seed\x18\x01 \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _UNITYRLINITIALIZATIONINPUT = _descriptor.Descriptor( 29 | name='UnityRLInitializationInput', 30 | full_name='communicator_objects.UnityRLInitializationInput', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='seed', full_name='communicator_objects.UnityRLInitializationInput.seed', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | ], 43 | extensions=[ 44 | ], 45 | nested_types=[], 46 | enum_types=[ 47 | ], 48 | options=None, 49 | is_extendable=False, 50 | syntax='proto3', 51 | extension_ranges=[], 52 | oneofs=[ 53 | ], 54 | serialized_start=82, 55 | serialized_end=124, 56 | ) 57 | 58 | DESCRIPTOR.message_types_by_name['UnityRLInitializationInput'] = _UNITYRLINITIALIZATIONINPUT 59 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 60 | 61 | UnityRLInitializationInput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationInput', (_message.Message,), dict( 62 | DESCRIPTOR = _UNITYRLINITIALIZATIONINPUT, 63 | __module__ = 'communicator_objects.unity_rl_initialization_input_pb2' 64 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationInput) 65 | )) 66 | _sym_db.RegisterMessage(UnityRLInitializationInput) 67 | 68 | 69 | DESCRIPTOR.has_options = True 70 | 
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 71 | # @@protoc_insertion_point(module_scope) 72 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_rl_initialization_output_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_rl_initialization_output.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import brain_parameters_proto_pb2 as communicator__objects_dot_brain__parameters__proto__pb2 17 | from communicator_objects import environment_parameters_proto_pb2 as communicator__objects_dot_environment__parameters__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/unity_rl_initialization_output.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n9communicator_objects/unity_rl_initialization_output.proto\x12\x14\x63ommunicator_objects\x1a\x31\x63ommunicator_objects/brain_parameters_proto.proto\x1a\x37\x63ommunicator_objects/environment_parameters_proto.proto\"\xe6\x01\n\x1bUnityRLInitializationOutput\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x10\n\x08log_path\x18\x03 \x01(\t\x12\x44\n\x10\x62rain_parameters\x18\x05 \x03(\x0b\x32*.communicator_objects.BrainParametersProto\x12P\n\x16\x65nvironment_parameters\x18\x06 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_brain__parameters__proto__pb2.DESCRIPTOR,communicator__objects_dot_environment__parameters__proto__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _UNITYRLINITIALIZATIONOUTPUT = _descriptor.Descriptor( 32 | name='UnityRLInitializationOutput', 33 | full_name='communicator_objects.UnityRLInitializationOutput', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='name', full_name='communicator_objects.UnityRLInitializationOutput.name', index=0, 40 | number=1, type=9, cpp_type=9, label=1, 41 | has_default_value=False, default_value=_b("").decode('utf-8'), 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='version', full_name='communicator_objects.UnityRLInitializationOutput.version', index=1, 47 | number=2, type=9, cpp_type=9, label=1, 48 | has_default_value=False, default_value=_b("").decode('utf-8'), 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None, file=DESCRIPTOR), 52 | _descriptor.FieldDescriptor( 53 | name='log_path', full_name='communicator_objects.UnityRLInitializationOutput.log_path', index=2, 54 | number=3, type=9, cpp_type=9, label=1, 55 | has_default_value=False, 
default_value=_b("").decode('utf-8'), 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None, file=DESCRIPTOR), 59 | _descriptor.FieldDescriptor( 60 | name='brain_parameters', full_name='communicator_objects.UnityRLInitializationOutput.brain_parameters', index=3, 61 | number=5, type=11, cpp_type=10, label=3, 62 | has_default_value=False, default_value=[], 63 | message_type=None, enum_type=None, containing_type=None, 64 | is_extension=False, extension_scope=None, 65 | options=None, file=DESCRIPTOR), 66 | _descriptor.FieldDescriptor( 67 | name='environment_parameters', full_name='communicator_objects.UnityRLInitializationOutput.environment_parameters', index=4, 68 | number=6, type=11, cpp_type=10, label=1, 69 | has_default_value=False, default_value=None, 70 | message_type=None, enum_type=None, containing_type=None, 71 | is_extension=False, extension_scope=None, 72 | options=None, file=DESCRIPTOR), 73 | ], 74 | extensions=[ 75 | ], 76 | nested_types=[], 77 | enum_types=[ 78 | ], 79 | options=None, 80 | is_extendable=False, 81 | syntax='proto3', 82 | extension_ranges=[], 83 | oneofs=[ 84 | ], 85 | serialized_start=192, 86 | serialized_end=422, 87 | ) 88 | 89 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name['brain_parameters'].message_type = communicator__objects_dot_brain__parameters__proto__pb2._BRAINPARAMETERSPROTO 90 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name['environment_parameters'].message_type = communicator__objects_dot_environment__parameters__proto__pb2._ENVIRONMENTPARAMETERSPROTO 91 | DESCRIPTOR.message_types_by_name['UnityRLInitializationOutput'] = _UNITYRLINITIALIZATIONOUTPUT 92 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 93 | 94 | UnityRLInitializationOutput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationOutput', (_message.Message,), dict( 95 | DESCRIPTOR = _UNITYRLINITIALIZATIONOUTPUT, 96 | __module__ = 'communicator_objects.unity_rl_initialization_output_pb2' 97 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationOutput) 98 | )) 99 | _sym_db.RegisterMessage(UnityRLInitializationOutput) 100 | 101 | 102 | DESCRIPTOR.has_options = True 103 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 104 | # @@protoc_insertion_point(module_scope) 105 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_rl_input_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/unity_rl_input.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import agent_action_proto_pb2 as communicator__objects_dot_agent__action__proto__pb2 17 | from communicator_objects import environment_parameters_proto_pb2 as communicator__objects_dot_environment__parameters__proto__pb2 18 | from communicator_objects import command_proto_pb2 as communicator__objects_dot_command__proto__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='communicator_objects/unity_rl_input.proto', 23 | package='communicator_objects', 24 | syntax='proto3', 25 | serialized_pb=_b('\n)communicator_objects/unity_rl_input.proto\x12\x14\x63ommunicator_objects\x1a-communicator_objects/agent_action_proto.proto\x1a\x37\x63ommunicator_objects/environment_parameters_proto.proto\x1a(communicator_objects/command_proto.proto\"\xb4\x03\n\x0cUnityRLInput\x12K\n\ragent_actions\x18\x01 \x03(\x0b\x32\x34.communicator_objects.UnityRLInput.AgentActionsEntry\x12P\n\x16\x65nvironment_parameters\x18\x02 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProto\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x33\n\x07\x63ommand\x18\x04 \x01(\x0e\x32\".communicator_objects.CommandProto\x1aM\n\x14ListAgentActionProto\x12\x35\n\x05value\x18\x01 \x03(\x0b\x32&.communicator_objects.AgentActionProto\x1al\n\x11\x41gentActionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x46\n\x05value\x18\x02 \x01(\x0b\x32\x37.communicator_objects.UnityRLInput.ListAgentActionProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[communicator__objects_dot_agent__action__proto__pb2.DESCRIPTOR,communicator__objects_dot_environment__parameters__proto__pb2.DESCRIPTOR,communicator__objects_dot_command__proto__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _UNITYRLINPUT_LISTAGENTACTIONPROTO = _descriptor.Descriptor( 33 | name='ListAgentActionProto', 34 | full_name='communicator_objects.UnityRLInput.ListAgentActionProto', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='value', full_name='communicator_objects.UnityRLInput.ListAgentActionProto.value', index=0, 41 | number=1, type=11, cpp_type=10, label=3, 42 | has_default_value=False, default_value=[], 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None, file=DESCRIPTOR), 46 | ], 47 | extensions=[ 48 | ], 49 | nested_types=[], 50 | enum_types=[ 51 | ], 52 | options=None, 53 | is_extendable=False, 54 | syntax='proto3', 55 | extension_ranges=[], 56 | oneofs=[ 57 | ], 58 | serialized_start=463, 59 | serialized_end=540, 60 | ) 61 | 62 | _UNITYRLINPUT_AGENTACTIONSENTRY = _descriptor.Descriptor( 63 | name='AgentActionsEntry', 64 | full_name='communicator_objects.UnityRLInput.AgentActionsEntry', 65 | filename=None, 66 | file=DESCRIPTOR, 67 | containing_type=None, 68 | fields=[ 69 | _descriptor.FieldDescriptor( 70 | name='key', full_name='communicator_objects.UnityRLInput.AgentActionsEntry.key', index=0, 71 | 
number=1, type=9, cpp_type=9, label=1, 72 | has_default_value=False, default_value=_b("").decode('utf-8'), 73 | message_type=None, enum_type=None, containing_type=None, 74 | is_extension=False, extension_scope=None, 75 | options=None, file=DESCRIPTOR), 76 | _descriptor.FieldDescriptor( 77 | name='value', full_name='communicator_objects.UnityRLInput.AgentActionsEntry.value', index=1, 78 | number=2, type=11, cpp_type=10, label=1, 79 | has_default_value=False, default_value=None, 80 | message_type=None, enum_type=None, containing_type=None, 81 | is_extension=False, extension_scope=None, 82 | options=None, file=DESCRIPTOR), 83 | ], 84 | extensions=[ 85 | ], 86 | nested_types=[], 87 | enum_types=[ 88 | ], 89 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 90 | is_extendable=False, 91 | syntax='proto3', 92 | extension_ranges=[], 93 | oneofs=[ 94 | ], 95 | serialized_start=542, 96 | serialized_end=650, 97 | ) 98 | 99 | _UNITYRLINPUT = _descriptor.Descriptor( 100 | name='UnityRLInput', 101 | full_name='communicator_objects.UnityRLInput', 102 | filename=None, 103 | file=DESCRIPTOR, 104 | containing_type=None, 105 | fields=[ 106 | _descriptor.FieldDescriptor( 107 | name='agent_actions', full_name='communicator_objects.UnityRLInput.agent_actions', index=0, 108 | number=1, type=11, cpp_type=10, label=3, 109 | has_default_value=False, default_value=[], 110 | message_type=None, enum_type=None, containing_type=None, 111 | is_extension=False, extension_scope=None, 112 | options=None, file=DESCRIPTOR), 113 | _descriptor.FieldDescriptor( 114 | name='environment_parameters', full_name='communicator_objects.UnityRLInput.environment_parameters', index=1, 115 | number=2, type=11, cpp_type=10, label=1, 116 | has_default_value=False, default_value=None, 117 | message_type=None, enum_type=None, containing_type=None, 118 | is_extension=False, extension_scope=None, 119 | options=None, file=DESCRIPTOR), 120 | _descriptor.FieldDescriptor( 121 | name='is_training', full_name='communicator_objects.UnityRLInput.is_training', index=2, 122 | number=3, type=8, cpp_type=7, label=1, 123 | has_default_value=False, default_value=False, 124 | message_type=None, enum_type=None, containing_type=None, 125 | is_extension=False, extension_scope=None, 126 | options=None, file=DESCRIPTOR), 127 | _descriptor.FieldDescriptor( 128 | name='command', full_name='communicator_objects.UnityRLInput.command', index=3, 129 | number=4, type=14, cpp_type=8, label=1, 130 | has_default_value=False, default_value=0, 131 | message_type=None, enum_type=None, containing_type=None, 132 | is_extension=False, extension_scope=None, 133 | options=None, file=DESCRIPTOR), 134 | ], 135 | extensions=[ 136 | ], 137 | nested_types=[_UNITYRLINPUT_LISTAGENTACTIONPROTO, _UNITYRLINPUT_AGENTACTIONSENTRY, ], 138 | enum_types=[ 139 | ], 140 | options=None, 141 | is_extendable=False, 142 | syntax='proto3', 143 | extension_ranges=[], 144 | oneofs=[ 145 | ], 146 | serialized_start=214, 147 | serialized_end=650, 148 | ) 149 | 150 | _UNITYRLINPUT_LISTAGENTACTIONPROTO.fields_by_name['value'].message_type = communicator__objects_dot_agent__action__proto__pb2._AGENTACTIONPROTO 151 | _UNITYRLINPUT_LISTAGENTACTIONPROTO.containing_type = _UNITYRLINPUT 152 | _UNITYRLINPUT_AGENTACTIONSENTRY.fields_by_name['value'].message_type = _UNITYRLINPUT_LISTAGENTACTIONPROTO 153 | _UNITYRLINPUT_AGENTACTIONSENTRY.containing_type = _UNITYRLINPUT 154 | _UNITYRLINPUT.fields_by_name['agent_actions'].message_type = _UNITYRLINPUT_AGENTACTIONSENTRY 155 | 
_UNITYRLINPUT.fields_by_name['environment_parameters'].message_type = communicator__objects_dot_environment__parameters__proto__pb2._ENVIRONMENTPARAMETERSPROTO 156 | _UNITYRLINPUT.fields_by_name['command'].enum_type = communicator__objects_dot_command__proto__pb2._COMMANDPROTO 157 | DESCRIPTOR.message_types_by_name['UnityRLInput'] = _UNITYRLINPUT 158 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 159 | 160 | UnityRLInput = _reflection.GeneratedProtocolMessageType('UnityRLInput', (_message.Message,), dict( 161 | 162 | ListAgentActionProto = _reflection.GeneratedProtocolMessageType('ListAgentActionProto', (_message.Message,), dict( 163 | DESCRIPTOR = _UNITYRLINPUT_LISTAGENTACTIONPROTO, 164 | __module__ = 'communicator_objects.unity_rl_input_pb2' 165 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInput.ListAgentActionProto) 166 | )) 167 | , 168 | 169 | AgentActionsEntry = _reflection.GeneratedProtocolMessageType('AgentActionsEntry', (_message.Message,), dict( 170 | DESCRIPTOR = _UNITYRLINPUT_AGENTACTIONSENTRY, 171 | __module__ = 'communicator_objects.unity_rl_input_pb2' 172 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInput.AgentActionsEntry) 173 | )) 174 | , 175 | DESCRIPTOR = _UNITYRLINPUT, 176 | __module__ = 'communicator_objects.unity_rl_input_pb2' 177 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInput) 178 | )) 179 | _sym_db.RegisterMessage(UnityRLInput) 180 | _sym_db.RegisterMessage(UnityRLInput.ListAgentActionProto) 181 | _sym_db.RegisterMessage(UnityRLInput.AgentActionsEntry) 182 | 183 | 184 | DESCRIPTOR.has_options = True 185 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 186 | _UNITYRLINPUT_AGENTACTIONSENTRY.has_options = True 187 | _UNITYRLINPUT_AGENTACTIONSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 188 | # @@protoc_insertion_point(module_scope) 189 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_rl_output_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/unity_rl_output.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import agent_info_proto_pb2 as communicator__objects_dot_agent__info__proto__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='communicator_objects/unity_rl_output.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_pb=_b('\n*communicator_objects/unity_rl_output.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/agent_info_proto.proto\"\xa3\x02\n\rUnityRLOutput\x12\x13\n\x0bglobal_done\x18\x01 \x01(\x08\x12G\n\nagentInfos\x18\x02 \x03(\x0b\x32\x33.communicator_objects.UnityRLOutput.AgentInfosEntry\x1aI\n\x12ListAgentInfoProto\x12\x33\n\x05value\x18\x01 \x03(\x0b\x32$.communicator_objects.AgentInfoProto\x1ai\n\x0f\x41gentInfosEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x45\n\x05value\x18\x02 \x01(\x0b\x32\x36.communicator_objects.UnityRLOutput.ListAgentInfoProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | , 25 | dependencies=[communicator__objects_dot_agent__info__proto__pb2.DESCRIPTOR,]) 26 | 27 | 28 | 29 | 30 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO = _descriptor.Descriptor( 31 | name='ListAgentInfoProto', 32 | full_name='communicator_objects.UnityRLOutput.ListAgentInfoProto', 33 | filename=None, 34 | file=DESCRIPTOR, 35 | containing_type=None, 36 | fields=[ 37 | _descriptor.FieldDescriptor( 38 | name='value', full_name='communicator_objects.UnityRLOutput.ListAgentInfoProto.value', index=0, 39 | number=1, type=11, cpp_type=10, label=3, 40 | has_default_value=False, default_value=[], 41 | message_type=None, enum_type=None, containing_type=None, 42 | is_extension=False, extension_scope=None, 43 | options=None, file=DESCRIPTOR), 44 | ], 45 | extensions=[ 46 | ], 47 | nested_types=[], 48 | enum_types=[ 49 | ], 50 | options=None, 51 | is_extendable=False, 52 | syntax='proto3', 53 | extension_ranges=[], 54 | oneofs=[ 55 | ], 56 | serialized_start=225, 57 | serialized_end=298, 58 | ) 59 | 60 | _UNITYRLOUTPUT_AGENTINFOSENTRY = _descriptor.Descriptor( 61 | name='AgentInfosEntry', 62 | full_name='communicator_objects.UnityRLOutput.AgentInfosEntry', 63 | filename=None, 64 | file=DESCRIPTOR, 65 | containing_type=None, 66 | fields=[ 67 | _descriptor.FieldDescriptor( 68 | name='key', full_name='communicator_objects.UnityRLOutput.AgentInfosEntry.key', index=0, 69 | number=1, type=9, cpp_type=9, label=1, 70 | has_default_value=False, default_value=_b("").decode('utf-8'), 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | options=None, file=DESCRIPTOR), 74 | _descriptor.FieldDescriptor( 75 | name='value', full_name='communicator_objects.UnityRLOutput.AgentInfosEntry.value', index=1, 76 | number=2, type=11, cpp_type=10, label=1, 77 | has_default_value=False, default_value=None, 78 | message_type=None, enum_type=None, containing_type=None, 79 | is_extension=False, extension_scope=None, 80 | options=None, file=DESCRIPTOR), 81 | ], 82 | extensions=[ 83 | ], 84 | nested_types=[], 85 | enum_types=[ 86 
| ], 87 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 88 | is_extendable=False, 89 | syntax='proto3', 90 | extension_ranges=[], 91 | oneofs=[ 92 | ], 93 | serialized_start=300, 94 | serialized_end=405, 95 | ) 96 | 97 | _UNITYRLOUTPUT = _descriptor.Descriptor( 98 | name='UnityRLOutput', 99 | full_name='communicator_objects.UnityRLOutput', 100 | filename=None, 101 | file=DESCRIPTOR, 102 | containing_type=None, 103 | fields=[ 104 | _descriptor.FieldDescriptor( 105 | name='global_done', full_name='communicator_objects.UnityRLOutput.global_done', index=0, 106 | number=1, type=8, cpp_type=7, label=1, 107 | has_default_value=False, default_value=False, 108 | message_type=None, enum_type=None, containing_type=None, 109 | is_extension=False, extension_scope=None, 110 | options=None, file=DESCRIPTOR), 111 | _descriptor.FieldDescriptor( 112 | name='agentInfos', full_name='communicator_objects.UnityRLOutput.agentInfos', index=1, 113 | number=2, type=11, cpp_type=10, label=3, 114 | has_default_value=False, default_value=[], 115 | message_type=None, enum_type=None, containing_type=None, 116 | is_extension=False, extension_scope=None, 117 | options=None, file=DESCRIPTOR), 118 | ], 119 | extensions=[ 120 | ], 121 | nested_types=[_UNITYRLOUTPUT_LISTAGENTINFOPROTO, _UNITYRLOUTPUT_AGENTINFOSENTRY, ], 122 | enum_types=[ 123 | ], 124 | options=None, 125 | is_extendable=False, 126 | syntax='proto3', 127 | extension_ranges=[], 128 | oneofs=[ 129 | ], 130 | serialized_start=114, 131 | serialized_end=405, 132 | ) 133 | 134 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO.fields_by_name['value'].message_type = communicator__objects_dot_agent__info__proto__pb2._AGENTINFOPROTO 135 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO.containing_type = _UNITYRLOUTPUT 136 | _UNITYRLOUTPUT_AGENTINFOSENTRY.fields_by_name['value'].message_type = _UNITYRLOUTPUT_LISTAGENTINFOPROTO 137 | _UNITYRLOUTPUT_AGENTINFOSENTRY.containing_type = _UNITYRLOUTPUT 138 | _UNITYRLOUTPUT.fields_by_name['agentInfos'].message_type = _UNITYRLOUTPUT_AGENTINFOSENTRY 139 | DESCRIPTOR.message_types_by_name['UnityRLOutput'] = _UNITYRLOUTPUT 140 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 141 | 142 | UnityRLOutput = _reflection.GeneratedProtocolMessageType('UnityRLOutput', (_message.Message,), dict( 143 | 144 | ListAgentInfoProto = _reflection.GeneratedProtocolMessageType('ListAgentInfoProto', (_message.Message,), dict( 145 | DESCRIPTOR = _UNITYRLOUTPUT_LISTAGENTINFOPROTO, 146 | __module__ = 'communicator_objects.unity_rl_output_pb2' 147 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.ListAgentInfoProto) 148 | )) 149 | , 150 | 151 | AgentInfosEntry = _reflection.GeneratedProtocolMessageType('AgentInfosEntry', (_message.Message,), dict( 152 | DESCRIPTOR = _UNITYRLOUTPUT_AGENTINFOSENTRY, 153 | __module__ = 'communicator_objects.unity_rl_output_pb2' 154 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.AgentInfosEntry) 155 | )) 156 | , 157 | DESCRIPTOR = _UNITYRLOUTPUT, 158 | __module__ = 'communicator_objects.unity_rl_output_pb2' 159 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput) 160 | )) 161 | _sym_db.RegisterMessage(UnityRLOutput) 162 | _sym_db.RegisterMessage(UnityRLOutput.ListAgentInfoProto) 163 | _sym_db.RegisterMessage(UnityRLOutput.AgentInfosEntry) 164 | 165 | 166 | DESCRIPTOR.has_options = True 167 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 168 | 
_UNITYRLOUTPUT_AGENTINFOSENTRY.has_options = True 169 | _UNITYRLOUTPUT_AGENTINFOSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 170 | # @@protoc_insertion_point(module_scope) 171 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_to_external_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_to_external.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_message_pb2 as communicator__objects_dot_unity__message__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='communicator_objects/unity_to_external.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_pb=_b('\n,communicator_objects/unity_to_external.proto\x12\x14\x63ommunicator_objects\x1a(communicator_objects/unity_message.proto2g\n\x0fUnityToExternal\x12T\n\x08\x45xchange\x12\".communicator_objects.UnityMessage\x1a\".communicator_objects.UnityMessage\"\x00\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | , 25 | dependencies=[communicator__objects_dot_unity__message__pb2.DESCRIPTOR,]) 26 | 27 | 28 | 29 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 30 | 31 | 32 | DESCRIPTOR.has_options = True 33 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 34 | 35 | _UNITYTOEXTERNAL = _descriptor.ServiceDescriptor( 36 | name='UnityToExternal', 37 | full_name='communicator_objects.UnityToExternal', 38 | file=DESCRIPTOR, 39 | index=0, 40 | options=None, 41 | serialized_start=112, 42 | serialized_end=215, 43 | methods=[ 44 | _descriptor.MethodDescriptor( 45 | name='Exchange', 46 | full_name='communicator_objects.UnityToExternal.Exchange', 47 | index=0, 48 | containing_service=None, 49 | input_type=communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, 50 | output_type=communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, 51 | options=None, 52 | ), 53 | ]) 54 | _sym_db.RegisterServiceDescriptor(_UNITYTOEXTERNAL) 55 | 56 | DESCRIPTOR.services_by_name['UnityToExternal'] = _UNITYTOEXTERNAL 57 | 58 | # @@protoc_insertion_point(module_scope) 59 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_to_external_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | import grpc 3 | 4 | from communicator_objects import unity_message_pb2 as communicator__objects_dot_unity__message__pb2 5 | 6 | 7 | class UnityToExternalStub(object): 8 | # missing associated documentation comment in .proto file 9 | pass 10 | 11 | def __init__(self, channel): 12 | """Constructor. 13 | 14 | Args: 15 | channel: A grpc.Channel. 
16 | """ 17 | self.Exchange = channel.unary_unary( 18 | '/communicator_objects.UnityToExternal/Exchange', 19 | request_serializer=communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString, 20 | response_deserializer=communicator__objects_dot_unity__message__pb2.UnityMessage.FromString, 21 | ) 22 | 23 | 24 | class UnityToExternalServicer(object): 25 | # missing associated documentation comment in .proto file 26 | pass 27 | 28 | def Exchange(self, request, context): 29 | """Sends the academy parameters 30 | """ 31 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 32 | context.set_details('Method not implemented!') 33 | raise NotImplementedError('Method not implemented!') 34 | 35 | 36 | def add_UnityToExternalServicer_to_server(servicer, server): 37 | rpc_method_handlers = { 38 | 'Exchange': grpc.unary_unary_rpc_method_handler( 39 | servicer.Exchange, 40 | request_deserializer=communicator__objects_dot_unity__message__pb2.UnityMessage.FromString, 41 | response_serializer=communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString, 42 | ), 43 | } 44 | generic_handler = grpc.method_handlers_generic_handler( 45 | 'communicator_objects.UnityToExternal', rpc_method_handlers) 46 | server.add_generic_rpc_handlers((generic_handler,)) 47 | -------------------------------------------------------------------------------- /python/curricula/push.json: -------------------------------------------------------------------------------- 1 | { 2 | "measure" : "reward", 3 | "thresholds" : [0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75], 4 | "min_lesson_length" : 2, 5 | "signal_smoothing" : true, 6 | "parameters" : 7 | { 8 | "goal_size" : [3.5, 3.25, 3.0, 2.75, 2.5, 2.25, 2.0, 1.75, 1.5, 1.25, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 9 | "block_size": [1.5, 1.4, 1.3, 1.2, 1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 10 | "x_variation":[1.5, 1.55, 1.6, 1.65, 1.7, 1.75, 1.8, 1.85, 1.9, 1.95, 2.0, 2.1, 2.2, 2.3, 2.4, 2.5] 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /python/curricula/test.json: -------------------------------------------------------------------------------- 1 | { 2 | "measure" : "reward", 3 | "thresholds" : [10, 20, 50], 4 | "min_lesson_length" : 3, 5 | "signal_smoothing" : true, 6 | "parameters" : 7 | { 8 | "param1" : [0.7, 0.5, 0.3, 0.1], 9 | "param2" : [100, 50, 20, 15], 10 | "param3" : [0.2, 0.3, 0.7, 0.9] 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /python/curricula/wall.json: -------------------------------------------------------------------------------- 1 | { 2 | "measure" : "progress", 3 | "thresholds" : [0.1, 0.3, 0.5], 4 | "min_lesson_length" : 2, 5 | "signal_smoothing" : true, 6 | "parameters" : 7 | { 8 | "big_wall_min_height" : [0.0, 4.0, 6.0, 8.0], 9 | "big_wall_max_height" : [4.0, 7.0, 8.0, 8.0], 10 | "small_wall_height" : [1.5, 2.0, 2.5, 4.0] 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /python/learn.py: -------------------------------------------------------------------------------- 1 | # # Unity ML-Agents Toolkit 2 | # ## ML-Agent Learning 3 | 4 | import logging 5 | 6 | import os 7 | from docopt import docopt 8 | 9 | from unitytrainers.trainer_controller import TrainerController 10 | 11 | 12 | if __name__ == '__main__': 13 | print(''' 14 | 15 | ▄▄▄▓▓▓▓ 16 | ╓▓▓▓▓▓▓█▓▓▓▓▓ 17 | ,▄▄▄m▀▀▀' ,▓▓▓▀▓▓▄ ▓▓▓ ▓▓▌ 18 | ▄▓▓▓▀' ▄▓▓▀ ▓▓▓ ▄▄ ▄▄ ,▄▄ 
▄▄▄▄ ,▄▄ ▄▓▓▌▄ ▄▄▄ ,▄▄ 19 | ▄▓▓▓▀ ▄▓▓▀ ▐▓▓▌ ▓▓▌ ▐▓▓ ▐▓▓▓▀▀▀▓▓▌ ▓▓▓ ▀▓▓▌▀ ^▓▓▌ ╒▓▓▌ 20 | ▄▓▓▓▓▓▄▄▄▄▄▄▄▄▓▓▓ ▓▀ ▓▓▌ ▐▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▌ ▐▓▓▄ ▓▓▌ 21 | ▀▓▓▓▓▀▀▀▀▀▀▀▀▀▀▓▓▄ ▓▓ ▓▓▌ ▐▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▌ ▐▓▓▐▓▓ 22 | ^█▓▓▓ ▀▓▓▄ ▐▓▓▌ ▓▓▓▓▄▓▓▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▓▄ ▓▓▓▓` 23 | '▀▓▓▓▄ ^▓▓▓ ▓▓▓ └▀▀▀▀ ▀▀ ^▀▀ `▀▀ `▀▀ '▀▀ ▐▓▓▌ 24 | ▀▀▀▀▓▄▄▄ ▓▓▓▓▓▓, ▓▓▓▓▀ 25 | `▀█▓▓▓▓▓▓▓▓▓▌ 26 | ¬`▀▀▀█▓ 27 | 28 | ''') 29 | 30 | logger = logging.getLogger("unityagents") 31 | _USAGE = ''' 32 | Usage: 33 | learn () [options] 34 | learn [options] 35 | learn --help 36 | 37 | Options: 38 | --curriculum= Curriculum json file for environment [default: None]. 39 | --keep-checkpoints= How many model checkpoints to keep [default: 5]. 40 | --lesson= Start learning from this lesson [default: 0]. 41 | --load Whether to load the model or randomly initialize [default: False]. 42 | --run-id= The sub-directory name for model and summary statistics [default: ppo]. 43 | --save-freq= Frequency at which to save model [default: 50000]. 44 | --seed= Random seed used for training [default: -1]. 45 | --slow Whether to run the game at training speed [default: False]. 46 | --train Whether to train model, or only run inference [default: False]. 47 | --worker-id= Number to add to communication port (5005). Used for multi-environment [default: 0]. 48 | --docker-target-name=
Docker Volume to store curriculum, executable and model files [default: Empty]. 49 | --no-graphics Whether to run the Unity simulator in no-graphics mode [default: False]. 50 | ''' 51 | 52 | options = docopt(_USAGE) 53 | logger.info(options) 54 | # Docker Parameters 55 | if options['--docker-target-name'] == 'Empty': 56 | docker_target_name = '' 57 | else: 58 | docker_target_name = options['--docker-target-name'] 59 | 60 | # General parameters 61 | run_id = options['--run-id'] 62 | seed = int(options['--seed']) 63 | load_model = options['--load'] 64 | train_model = options['--train'] 65 | save_freq = int(options['--save-freq']) 66 | env_path = options[''] 67 | keep_checkpoints = int(options['--keep-checkpoints']) 68 | worker_id = int(options['--worker-id']) 69 | curriculum_file = str(options['--curriculum']) 70 | if curriculum_file == "None": 71 | curriculum_file = None 72 | lesson = int(options['--lesson']) 73 | fast_simulation = not bool(options['--slow']) 74 | no_graphics = options['--no-graphics'] 75 | 76 | # Constants 77 | # Assumption that this yaml is present in same dir as this file 78 | base_path = os.path.dirname(__file__) 79 | TRAINER_CONFIG_PATH = os.path.abspath(os.path.join(base_path, "trainer_config.yaml")) 80 | 81 | tc = TrainerController(env_path, run_id, save_freq, curriculum_file, fast_simulation, load_model, train_model, 82 | worker_id, keep_checkpoints, lesson, seed, docker_target_name, TRAINER_CONFIG_PATH, 83 | no_graphics) 84 | tc.start_learning() 85 | -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==1.7.1 2 | Pillow>=4.2.1 3 | matplotlib 4 | numpy>=1.11.0 5 | jupyter 6 | pytest>=3.2.2 7 | docopt 8 | pyyaml 9 | protobuf==3.5.2 10 | grpcio==1.11.0 11 | torch==0.4.0 12 | pandas 13 | scipy 14 | ipykernel 15 | -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, Command, find_packages 4 | 5 | 6 | with open('requirements.txt') as f: 7 | required = f.read().splitlines() 8 | 9 | setup(name='unityagents', 10 | version='0.4.0', 11 | description='Unity Machine Learning Agents', 12 | license='Apache License 2.0', 13 | author='Unity Technologies', 14 | author_email='ML-Agents@unity3d.com', 15 | url='https://github.com/Unity-Technologies/ml-agents', 16 | packages=find_packages(), 17 | install_requires = required, 18 | long_description= ("Unity Machine Learning Agents allows researchers and developers " 19 | "to transform games and simulations created using the Unity Editor into environments " 20 | "where intelligent agents can be trained using reinforcement learning, evolutionary " 21 | "strategies, or other machine learning methods through a simple to use Python API.") 22 | ) 23 | -------------------------------------------------------------------------------- /python/tests/__init__.py: -------------------------------------------------------------------------------- 1 | from unityagents import * 2 | from unitytrainers import * 3 | -------------------------------------------------------------------------------- /python/tests/mock_communicator.py: -------------------------------------------------------------------------------- 1 | 2 | from unityagents.communicator import Communicator 3 | from communicator_objects import UnityMessage, 
UnityOutput, UnityInput,\ 4 | ResolutionProto, BrainParametersProto, UnityRLInitializationOutput,\ 5 | AgentInfoProto, UnityRLOutput 6 | 7 | 8 | class MockCommunicator(Communicator): 9 | def __init__(self, discrete_action=False, visual_inputs=0): 10 | """ 11 | Python side of the grpc communication. Python is the client and Unity the server 12 | 13 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 14 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 15 | """ 16 | self.is_discrete = discrete_action 17 | self.steps = 0 18 | self.visual_inputs = visual_inputs 19 | self.has_been_closed = False 20 | 21 | def initialize(self, inputs: UnityInput) -> UnityOutput: 22 | resolutions = [ResolutionProto( 23 | width=30, 24 | height=40, 25 | gray_scale=False) for i in range(self.visual_inputs)] 26 | bp = BrainParametersProto( 27 | vector_observation_size=3, 28 | num_stacked_vector_observations=2, 29 | vector_action_size=2, 30 | camera_resolutions=resolutions, 31 | vector_action_descriptions=["", ""], 32 | vector_action_space_type=int(not self.is_discrete), 33 | vector_observation_space_type=1, 34 | brain_name="RealFakeBrain", 35 | brain_type=2 36 | ) 37 | rl_init = UnityRLInitializationOutput( 38 | name="RealFakeAcademy", 39 | version="API-4", 40 | log_path="", 41 | brain_parameters=[bp] 42 | ) 43 | return UnityOutput( 44 | rl_initialization_output=rl_init 45 | ) 46 | 47 | def exchange(self, inputs: UnityInput) -> UnityOutput: 48 | dict_agent_info = {} 49 | if self.is_discrete: 50 | vector_action = [1] 51 | else: 52 | vector_action = [1, 2] 53 | list_agent_info = [] 54 | for i in range(3): 55 | list_agent_info.append( 56 | AgentInfoProto( 57 | stacked_vector_observation=[1, 2, 3, 1, 2, 3], 58 | reward=1, 59 | stored_vector_actions=vector_action, 60 | stored_text_actions="", 61 | text_observation="", 62 | memories=[], 63 | done=(i == 2), 64 | max_step_reached=False, 65 | id=i 66 | )) 67 | dict_agent_info["RealFakeBrain"] = \ 68 | UnityRLOutput.ListAgentInfoProto(value=list_agent_info) 69 | global_done = False 70 | try: 71 | global_done = (inputs.rl_input.agent_actions["RealFakeBrain"].value[0].vector_actions[0] == -1) 72 | except: 73 | pass 74 | result = UnityRLOutput( 75 | global_done=global_done, 76 | agentInfos=dict_agent_info 77 | ) 78 | return UnityOutput( 79 | rl_output=result 80 | ) 81 | 82 | def close(self): 83 | """ 84 | Sends a shutdown signal to the unity environment, and closes the grpc connection. 
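In this mock there is no real gRPC connection to shut down; the call only records that it happened by setting has_been_closed, which tests/test_unityagents.py asserts in test_close.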
85 | """ 86 | self.has_been_closed = True 87 | -------------------------------------------------------------------------------- /python/tests/test_bc.py: -------------------------------------------------------------------------------- 1 | import unittest.mock as mock 2 | import pytest 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from unitytrainers.bc.models import BehavioralCloningModel 8 | from unityagents import UnityEnvironment 9 | from .mock_communicator import MockCommunicator 10 | 11 | 12 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 13 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 14 | def test_cc_bc_model(mock_communicator, mock_launcher): 15 | tf.reset_default_graph() 16 | with tf.Session() as sess: 17 | with tf.variable_scope("FakeGraphScope"): 18 | mock_communicator.return_value = MockCommunicator( 19 | discrete_action=False, visual_inputs=0) 20 | env = UnityEnvironment(' ') 21 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 22 | init = tf.global_variables_initializer() 23 | sess.run(init) 24 | 25 | run_list = [model.sample_action, model.policy] 26 | feed_dict = {model.batch_size: 2, 27 | model.sequence_length: 1, 28 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 29 | [3, 4, 5, 3, 4, 5]])} 30 | sess.run(run_list, feed_dict=feed_dict) 31 | env.close() 32 | 33 | 34 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 35 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 36 | def test_dc_bc_model(mock_communicator, mock_launcher): 37 | tf.reset_default_graph() 38 | with tf.Session() as sess: 39 | with tf.variable_scope("FakeGraphScope"): 40 | mock_communicator.return_value = MockCommunicator( 41 | discrete_action=True, visual_inputs=0) 42 | env = UnityEnvironment(' ') 43 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 44 | init = tf.global_variables_initializer() 45 | sess.run(init) 46 | 47 | run_list = [model.sample_action, model.policy] 48 | feed_dict = {model.batch_size: 2, 49 | model.dropout_rate: 1.0, 50 | model.sequence_length: 1, 51 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 52 | [3, 4, 5, 3, 4, 5]])} 53 | sess.run(run_list, feed_dict=feed_dict) 54 | env.close() 55 | 56 | 57 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 58 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 59 | def test_visual_dc_bc_model(mock_communicator, mock_launcher): 60 | tf.reset_default_graph() 61 | with tf.Session() as sess: 62 | with tf.variable_scope("FakeGraphScope"): 63 | mock_communicator.return_value = MockCommunicator( 64 | discrete_action=True, visual_inputs=2) 65 | env = UnityEnvironment(' ') 66 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 67 | init = tf.global_variables_initializer() 68 | sess.run(init) 69 | 70 | run_list = [model.sample_action, model.policy] 71 | feed_dict = {model.batch_size: 2, 72 | model.dropout_rate: 1.0, 73 | model.sequence_length: 1, 74 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 75 | [3, 4, 5, 3, 4, 5]]), 76 | model.visual_in[0]: np.ones([2, 40, 30, 3]), 77 | model.visual_in[1]: np.ones([2, 40, 30, 3])} 78 | sess.run(run_list, feed_dict=feed_dict) 79 | env.close() 80 | 81 | 82 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 83 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 84 | def test_visual_cc_bc_model(mock_communicator, mock_launcher): 85 | tf.reset_default_graph() 86 | with tf.Session() as sess: 87 | with tf.variable_scope("FakeGraphScope"): 88 | 
mock_communicator.return_value = MockCommunicator( 89 | discrete_action=False, visual_inputs=2) 90 | env = UnityEnvironment(' ') 91 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 92 | init = tf.global_variables_initializer() 93 | sess.run(init) 94 | 95 | run_list = [model.sample_action, model.policy] 96 | feed_dict = {model.batch_size: 2, 97 | model.sequence_length: 1, 98 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 99 | [3, 4, 5, 3, 4, 5]]), 100 | model.visual_in[0]: np.ones([2, 40, 30, 3]), 101 | model.visual_in[1]: np.ones([2, 40, 30, 3])} 102 | sess.run(run_list, feed_dict=feed_dict) 103 | env.close() 104 | 105 | 106 | if __name__ == '__main__': 107 | pytest.main() 108 | -------------------------------------------------------------------------------- /python/tests/test_unityagents.py: -------------------------------------------------------------------------------- 1 | import json 2 | import unittest.mock as mock 3 | import pytest 4 | import struct 5 | 6 | import numpy as np 7 | 8 | from unityagents import UnityEnvironment, UnityEnvironmentException, UnityActionException, \ 9 | BrainInfo, Curriculum 10 | from .mock_communicator import MockCommunicator 11 | 12 | 13 | dummy_curriculum = json.loads('''{ 14 | "measure" : "reward", 15 | "thresholds" : [10, 20, 50], 16 | "min_lesson_length" : 3, 17 | "signal_smoothing" : true, 18 | "parameters" : 19 | { 20 | "param1" : [0.7, 0.5, 0.3, 0.1], 21 | "param2" : [100, 50, 20, 15], 22 | "param3" : [0.2, 0.3, 0.7, 0.9] 23 | } 24 | }''') 25 | bad_curriculum = json.loads('''{ 26 | "measure" : "reward", 27 | "thresholds" : [10, 20, 50], 28 | "min_lesson_length" : 3, 29 | "signal_smoothing" : false, 30 | "parameters" : 31 | { 32 | "param1" : [0.7, 0.5, 0.3, 0.1], 33 | "param2" : [100, 50, 20], 34 | "param3" : [0.2, 0.3, 0.7, 0.9] 35 | } 36 | }''') 37 | 38 | 39 | def test_handles_bad_filename(): 40 | with pytest.raises(UnityEnvironmentException): 41 | UnityEnvironment(' ') 42 | 43 | 44 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 45 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 46 | def test_initialization(mock_communicator, mock_launcher): 47 | mock_communicator.return_value = MockCommunicator( 48 | discrete_action=False, visual_inputs=0) 49 | env = UnityEnvironment(' ') 50 | with pytest.raises(UnityActionException): 51 | env.step([0]) 52 | assert env.brain_names[0] == 'RealFakeBrain' 53 | env.close() 54 | 55 | 56 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 57 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 58 | def test_reset(mock_communicator, mock_launcher): 59 | mock_communicator.return_value = MockCommunicator( 60 | discrete_action=False, visual_inputs=0) 61 | env = UnityEnvironment(' ') 62 | brain = env.brains['RealFakeBrain'] 63 | brain_info = env.reset() 64 | env.close() 65 | assert not env.global_done 66 | assert isinstance(brain_info, dict) 67 | assert isinstance(brain_info['RealFakeBrain'], BrainInfo) 68 | assert isinstance(brain_info['RealFakeBrain'].visual_observations, list) 69 | assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray) 70 | assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations 71 | assert brain_info['RealFakeBrain'].vector_observations.shape[0] == \ 72 | len(brain_info['RealFakeBrain'].agents) 73 | assert brain_info['RealFakeBrain'].vector_observations.shape[1] == \ 74 | brain.vector_observation_space_size * brain.num_stacked_vector_observations 75 | 76 | 77 | 
@mock.patch('unityagents.UnityEnvironment.executable_launcher') 78 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 79 | def test_step(mock_communicator, mock_launcher): 80 | mock_communicator.return_value = MockCommunicator( 81 | discrete_action=False, visual_inputs=0) 82 | env = UnityEnvironment(' ') 83 | brain = env.brains['RealFakeBrain'] 84 | brain_info = env.reset() 85 | brain_info = env.step([0] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents)) 86 | with pytest.raises(UnityActionException): 87 | env.step([0]) 88 | brain_info = env.step([-1] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents)) 89 | with pytest.raises(UnityActionException): 90 | env.step([0] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents)) 91 | env.close() 92 | assert env.global_done 93 | assert isinstance(brain_info, dict) 94 | assert isinstance(brain_info['RealFakeBrain'], BrainInfo) 95 | assert isinstance(brain_info['RealFakeBrain'].visual_observations, list) 96 | assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray) 97 | assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations 98 | assert brain_info['RealFakeBrain'].vector_observations.shape[0] == \ 99 | len(brain_info['RealFakeBrain'].agents) 100 | assert brain_info['RealFakeBrain'].vector_observations.shape[1] == \ 101 | brain.vector_observation_space_size * brain.num_stacked_vector_observations 102 | 103 | print("\n\n\n\n\n\n\n" + str(brain_info['RealFakeBrain'].local_done)) 104 | assert not brain_info['RealFakeBrain'].local_done[0] 105 | assert brain_info['RealFakeBrain'].local_done[2] 106 | 107 | 108 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 109 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 110 | def test_close(mock_communicator, mock_launcher): 111 | comm = MockCommunicator( 112 | discrete_action=False, visual_inputs=0) 113 | mock_communicator.return_value = comm 114 | env = UnityEnvironment(' ') 115 | assert env._loaded 116 | env.close() 117 | assert not env._loaded 118 | assert comm.has_been_closed 119 | 120 | 121 | def test_curriculum(): 122 | open_name = '%s.open' % __name__ 123 | with mock.patch('json.load') as mock_load: 124 | with mock.patch(open_name, create=True) as mock_open: 125 | mock_open.return_value = 0 126 | mock_load.return_value = bad_curriculum 127 | with pytest.raises(UnityEnvironmentException): 128 | Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1, "param3": 1}) 129 | mock_load.return_value = dummy_curriculum 130 | with pytest.raises(UnityEnvironmentException): 131 | Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1}) 132 | curriculum = Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1, "param3": 1}) 133 | assert curriculum.get_lesson_number == 0 134 | curriculum.set_lesson_number(1) 135 | assert curriculum.get_lesson_number == 1 136 | curriculum.increment_lesson(10) 137 | assert curriculum.get_lesson_number == 1 138 | curriculum.increment_lesson(30) 139 | curriculum.increment_lesson(30) 140 | assert curriculum.get_lesson_number == 1 141 | assert curriculum.lesson_length == 3 142 | curriculum.increment_lesson(30) 143 | assert curriculum.get_config() == {'param1': 0.3, 'param2': 20, 'param3': 0.7} 144 | assert curriculum.get_config(0) == {"param1": 0.7, "param2": 100, "param3": 0.2} 145 | assert curriculum.lesson_length == 0 146 | assert curriculum.get_lesson_number == 2 147 | 148 | 149 | if 
__name__ == '__main__': 150 | pytest.main() 151 | -------------------------------------------------------------------------------- /python/tests/test_unitytrainers.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import unittest.mock as mock 3 | import pytest 4 | 5 | from unitytrainers.trainer_controller import TrainerController 6 | from unitytrainers.buffer import Buffer 7 | from unitytrainers.models import * 8 | from unitytrainers.ppo.trainer import PPOTrainer 9 | from unitytrainers.bc.trainer import BehavioralCloningTrainer 10 | from unityagents import UnityEnvironmentException 11 | from .mock_communicator import MockCommunicator 12 | 13 | dummy_start = '''{ 14 | "AcademyName": "RealFakeAcademy", 15 | "resetParameters": {}, 16 | "brainNames": ["RealFakeBrain"], 17 | "externalBrainNames": ["RealFakeBrain"], 18 | "logPath":"RealFakePath", 19 | "apiNumber":"API-3", 20 | "brainParameters": [{ 21 | "vectorObservationSize": 3, 22 | "numStackedVectorObservations" : 2, 23 | "vectorActionSize": 2, 24 | "memorySize": 0, 25 | "cameraResolutions": [], 26 | "vectorActionDescriptions": ["",""], 27 | "vectorActionSpaceType": 1, 28 | "vectorObservationSpaceType": 1 29 | }] 30 | }'''.encode() 31 | 32 | 33 | dummy_config = yaml.load(''' 34 | default: 35 | trainer: ppo 36 | batch_size: 32 37 | beta: 5.0e-3 38 | buffer_size: 512 39 | epsilon: 0.2 40 | gamma: 0.99 41 | hidden_units: 128 42 | lambd: 0.95 43 | learning_rate: 3.0e-4 44 | max_steps: 5.0e4 45 | normalize: true 46 | num_epoch: 5 47 | num_layers: 2 48 | time_horizon: 64 49 | sequence_length: 64 50 | summary_freq: 1000 51 | use_recurrent: false 52 | memory_size: 8 53 | use_curiosity: false 54 | curiosity_strength: 0.0 55 | curiosity_enc_size: 1 56 | ''') 57 | 58 | dummy_bc_config = yaml.load(''' 59 | default: 60 | trainer: imitation 61 | brain_to_imitate: ExpertBrain 62 | batches_per_epoch: 16 63 | batch_size: 32 64 | beta: 5.0e-3 65 | buffer_size: 512 66 | epsilon: 0.2 67 | gamma: 0.99 68 | hidden_units: 128 69 | lambd: 0.95 70 | learning_rate: 3.0e-4 71 | max_steps: 5.0e4 72 | normalize: true 73 | num_epoch: 5 74 | num_layers: 2 75 | time_horizon: 64 76 | sequence_length: 64 77 | summary_freq: 1000 78 | use_recurrent: false 79 | memory_size: 8 80 | use_curiosity: false 81 | curiosity_strength: 0.0 82 | curiosity_enc_size: 1 83 | ''') 84 | 85 | dummy_bad_config = yaml.load(''' 86 | default: 87 | trainer: incorrect_trainer 88 | brain_to_imitate: ExpertBrain 89 | batches_per_epoch: 16 90 | batch_size: 32 91 | beta: 5.0e-3 92 | buffer_size: 512 93 | epsilon: 0.2 94 | gamma: 0.99 95 | hidden_units: 128 96 | lambd: 0.95 97 | learning_rate: 3.0e-4 98 | max_steps: 5.0e4 99 | normalize: true 100 | num_epoch: 5 101 | num_layers: 2 102 | time_horizon: 64 103 | sequence_length: 64 104 | summary_freq: 1000 105 | use_recurrent: false 106 | memory_size: 8 107 | ''') 108 | 109 | 110 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 111 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 112 | def test_initialization(mock_communicator, mock_launcher): 113 | mock_communicator.return_value = MockCommunicator( 114 | discrete_action=True, visual_inputs=1) 115 | tc = TrainerController(' ', ' ', 1, None, True, True, False, 1, 116 | 1, 1, 1, '', "tests/test_unitytrainers.py", False) 117 | assert(tc.env.brain_names[0] == 'RealFakeBrain') 118 | 119 | 120 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 121 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 122 | 
def test_load_config(mock_communicator, mock_launcher): 123 | open_name = 'unitytrainers.trainer_controller' + '.open' 124 | with mock.patch('yaml.load') as mock_load: 125 | with mock.patch(open_name, create=True) as _: 126 | mock_load.return_value = dummy_config 127 | mock_communicator.return_value = MockCommunicator( 128 | discrete_action=True, visual_inputs=1) 129 | mock_load.return_value = dummy_config 130 | tc = TrainerController(' ', ' ', 1, None, True, True, False, 1, 131 | 1, 1, 1, '','', False) 132 | config = tc._load_config() 133 | assert(len(config) == 1) 134 | assert(config['default']['trainer'] == "ppo") 135 | 136 | 137 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 138 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 139 | def test_initialize_trainers(mock_communicator, mock_launcher): 140 | open_name = 'unitytrainers.trainer_controller' + '.open' 141 | with mock.patch('yaml.load') as mock_load: 142 | with mock.patch(open_name, create=True) as _: 143 | mock_communicator.return_value = MockCommunicator( 144 | discrete_action=True, visual_inputs=1) 145 | tc = TrainerController(' ', ' ', 1, None, True, True, False, 1, 146 | 1, 1, 1, '', "tests/test_unitytrainers.py", False) 147 | 148 | # Test for PPO trainer 149 | mock_load.return_value = dummy_config 150 | config = tc._load_config() 151 | tf.reset_default_graph() 152 | with tf.Session() as sess: 153 | tc._initialize_trainers(config, sess) 154 | assert(len(tc.trainers) == 1) 155 | assert(isinstance(tc.trainers['RealFakeBrain'], PPOTrainer)) 156 | 157 | # Test for Behavior Cloning Trainer 158 | mock_load.return_value = dummy_bc_config 159 | config = tc._load_config() 160 | tf.reset_default_graph() 161 | with tf.Session() as sess: 162 | tc._initialize_trainers(config, sess) 163 | assert(isinstance(tc.trainers['RealFakeBrain'], BehavioralCloningTrainer)) 164 | 165 | # Test for proper exception when trainer name is incorrect 166 | mock_load.return_value = dummy_bad_config 167 | config = tc._load_config() 168 | tf.reset_default_graph() 169 | with tf.Session() as sess: 170 | with pytest.raises(UnityEnvironmentException): 171 | tc._initialize_trainers(config, sess) 172 | 173 | 174 | def assert_array(a, b): 175 | assert a.shape == b.shape 176 | la = list(a.flatten()) 177 | lb = list(b.flatten()) 178 | for i in range(len(la)): 179 | assert la[i] == lb[i] 180 | 181 | 182 | def test_buffer(): 183 | b = Buffer() 184 | for fake_agent_id in range(4): 185 | for step in range(9): 186 | b[fake_agent_id]['vector_observation'].append( 187 | [100 * fake_agent_id + 10 * step + 1, 188 | 100 * fake_agent_id + 10 * step + 2, 189 | 100 * fake_agent_id + 10 * step + 3] 190 | ) 191 | b[fake_agent_id]['action'].append([100 * fake_agent_id + 10 * step + 4, 192 | 100 * fake_agent_id + 10 * step + 5]) 193 | a = b[1]['vector_observation'].get_batch(batch_size=2, training_length=1, sequential=True) 194 | assert_array(a, np.array([[171, 172, 173], [181, 182, 183]])) 195 | a = b[2]['vector_observation'].get_batch(batch_size=2, training_length=3, sequential=True) 196 | assert_array(a, np.array([ 197 | [[231, 232, 233], [241, 242, 243], [251, 252, 253]], 198 | [[261, 262, 263], [271, 272, 273], [281, 282, 283]] 199 | ])) 200 | a = b[2]['vector_observation'].get_batch(batch_size=2, training_length=3, sequential=False) 201 | assert_array(a, np.array([ 202 | [[251, 252, 253], [261, 262, 263], [271, 272, 273]], 203 | [[261, 262, 263], [271, 272, 273], [281, 282, 283]] 204 | ])) 205 | b[4].reset_agent() 206 | assert len(b[4]) == 0 207 | 
b.append_update_buffer(3, 208 | batch_size=None, training_length=2) 209 | b.append_update_buffer(2, 210 | batch_size=None, training_length=2) 211 | assert len(b.update_buffer['action']) == 10 212 | assert np.array(b.update_buffer['action']).shape == (10, 2, 2) 213 | 214 | 215 | if __name__ == '__main__': 216 | pytest.main() 217 | -------------------------------------------------------------------------------- /python/trainer_config.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | trainer: ppo 3 | batch_size: 1024 4 | beta: 5.0e-3 5 | buffer_size: 10240 6 | epsilon: 0.2 7 | gamma: 0.99 8 | hidden_units: 128 9 | lambd: 0.95 10 | learning_rate: 3.0e-4 11 | max_steps: 5.0e4 12 | memory_size: 256 13 | normalize: false 14 | num_epoch: 3 15 | num_layers: 2 16 | time_horizon: 64 17 | sequence_length: 64 18 | summary_freq: 1000 19 | use_recurrent: false 20 | use_curiosity: false 21 | curiosity_strength: 0.01 22 | curiosity_enc_size: 128 23 | 24 | BananaBrain: 25 | normalize: false 26 | batch_size: 1024 27 | beta: 5.0e-3 28 | buffer_size: 10240 29 | 30 | BouncerBrain: 31 | normalize: true 32 | max_steps: 5.0e5 33 | num_layers: 2 34 | hidden_units: 64 35 | 36 | PushBlockBrain: 37 | max_steps: 5.0e4 38 | batch_size: 128 39 | buffer_size: 2048 40 | beta: 1.0e-2 41 | hidden_units: 256 42 | summary_freq: 2000 43 | time_horizon: 64 44 | num_layers: 2 45 | 46 | SmallWallBrain: 47 | max_steps: 2.0e5 48 | batch_size: 128 49 | buffer_size: 2048 50 | beta: 5.0e-3 51 | hidden_units: 256 52 | summary_freq: 2000 53 | time_horizon: 128 54 | num_layers: 2 55 | normalize: false 56 | 57 | BigWallBrain: 58 | max_steps: 2.0e5 59 | batch_size: 128 60 | buffer_size: 2048 61 | beta: 5.0e-3 62 | hidden_units: 256 63 | summary_freq: 2000 64 | time_horizon: 128 65 | num_layers: 2 66 | normalize: false 67 | 68 | StrikerBrain: 69 | max_steps: 1.0e5 70 | batch_size: 128 71 | buffer_size: 2048 72 | beta: 5.0e-3 73 | hidden_units: 256 74 | summary_freq: 2000 75 | time_horizon: 128 76 | num_layers: 2 77 | normalize: false 78 | 79 | GoalieBrain: 80 | max_steps: 1.0e5 81 | batch_size: 128 82 | buffer_size: 2048 83 | beta: 5.0e-3 84 | hidden_units: 256 85 | summary_freq: 2000 86 | time_horizon: 128 87 | num_layers: 2 88 | normalize: false 89 | 90 | PyramidBrain: 91 | use_curiosity: true 92 | summary_freq: 2000 93 | curiosity_strength: 0.01 94 | curiosity_enc_size: 256 95 | time_horizon: 128 96 | batch_size: 128 97 | buffer_size: 2048 98 | hidden_units: 512 99 | num_layers: 2 100 | beta: 1.0e-2 101 | max_steps: 2.0e5 102 | num_epoch: 3 103 | 104 | VisualPyramidBrain: 105 | use_curiosity: true 106 | time_horizon: 128 107 | batch_size: 32 108 | buffer_size: 1024 109 | hidden_units: 256 110 | num_layers: 2 111 | beta: 1.0e-2 112 | max_steps: 5.0e5 113 | num_epoch: 3 114 | 115 | Ball3DBrain: 116 | normalize: true 117 | batch_size: 64 118 | buffer_size: 12000 119 | summary_freq: 1000 120 | time_horizon: 1000 121 | lambd: 0.99 122 | gamma: 0.995 123 | beta: 0.001 124 | 125 | Ball3DHardBrain: 126 | normalize: true 127 | batch_size: 1200 128 | buffer_size: 12000 129 | summary_freq: 1000 130 | time_horizon: 1000 131 | gamma: 0.995 132 | beta: 0.001 133 | 134 | TennisBrain: 135 | normalize: true 136 | 137 | CrawlerBrain: 138 | normalize: true 139 | num_epoch: 3 140 | time_horizon: 1000 141 | batch_size: 2024 142 | buffer_size: 20240 143 | gamma: 0.995 144 | max_steps: 1e6 145 | summary_freq: 3000 146 | num_layers: 3 147 | hidden_units: 512 148 | 149 | WalkerBrain: 150 | normalize: true 151 
| num_epoch: 3 152 | time_horizon: 1000 153 | batch_size: 2048 154 | buffer_size: 20480 155 | gamma: 0.995 156 | max_steps: 2e6 157 | summary_freq: 3000 158 | num_layers: 3 159 | hidden_units: 512 160 | 161 | ReacherBrain: 162 | normalize: true 163 | num_epoch: 3 164 | time_horizon: 1000 165 | batch_size: 2024 166 | buffer_size: 20240 167 | gamma: 0.995 168 | max_steps: 1e6 169 | summary_freq: 3000 170 | 171 | HallwayBrain: 172 | use_recurrent: true 173 | sequence_length: 64 174 | num_layers: 2 175 | hidden_units: 128 176 | memory_size: 256 177 | beta: 1.0e-2 178 | gamma: 0.99 179 | num_epoch: 3 180 | buffer_size: 1024 181 | batch_size: 128 182 | max_steps: 5.0e5 183 | summary_freq: 1000 184 | time_horizon: 64 185 | 186 | GridWorldBrain: 187 | batch_size: 32 188 | normalize: false 189 | num_layers: 1 190 | hidden_units: 256 191 | beta: 5.0e-3 192 | gamma: 0.9 193 | buffer_size: 256 194 | max_steps: 5.0e5 195 | summary_freq: 2000 196 | time_horizon: 5 197 | 198 | BasicBrain: 199 | batch_size: 32 200 | normalize: false 201 | num_layers: 1 202 | hidden_units: 20 203 | beta: 5.0e-3 204 | gamma: 0.9 205 | buffer_size: 256 206 | max_steps: 5.0e5 207 | summary_freq: 2000 208 | time_horizon: 3 209 | 210 | StudentBrain: 211 | trainer: imitation 212 | max_steps: 10000 213 | summary_freq: 1000 214 | brain_to_imitate: TeacherBrain 215 | batch_size: 16 216 | batches_per_epoch: 5 217 | num_layers: 4 218 | hidden_units: 64 219 | sequence_length: 16 220 | buffer_size: 128 221 | 222 | StudentRecurrentBrain: 223 | trainer: imitation 224 | max_steps: 10000 225 | summary_freq: 1000 226 | brain_to_imitate: TeacherBrain 227 | batch_size: 16 228 | batches_per_epoch: 5 229 | num_layers: 4 230 | hidden_units: 64 231 | use_recurrent: true 232 | sequence_length: 32 233 | buffer_size: 128 234 | -------------------------------------------------------------------------------- /python/unityagents/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import * 2 | from .brain import * 3 | from .exception import * 4 | from .curriculum import * 5 | -------------------------------------------------------------------------------- /python/unityagents/brain.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | 4 | class BrainInfo: 5 | def __init__(self, visual_observation, vector_observation, text_observations, memory=None, 6 | reward=None, agents=None, local_done=None, 7 | vector_action=None, text_action=None, max_reached=None): 8 | """ 9 | Describes experience at current step of all agents linked to a brain. 10 | """ 11 | self.visual_observations = visual_observation 12 | self.vector_observations = vector_observation 13 | self.text_observations = text_observations 14 | self.memories = memory 15 | self.rewards = reward 16 | self.local_done = local_done 17 | self.max_reached = max_reached 18 | self.agents = agents 19 | self.previous_vector_actions = vector_action 20 | self.previous_text_actions = text_action 21 | 22 | 23 | AllBrainInfo = Dict[str, BrainInfo] 24 | 25 | 26 | class BrainParameters: 27 | def __init__(self, brain_name, brain_param): 28 | """ 29 | Contains all brain-specific parameters. 30 | :param brain_name: Name of brain. 31 | :param brain_param: Dictionary of brain parameters. 
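Expected keys, as read below: "vectorObservationSize", "numStackedVectorObservations", "cameraResolutions", "vectorActionSize", "vectorActionDescriptions", "vectorActionSpaceType" and "vectorObservationSpaceType".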
32 | """ 33 | self.brain_name = brain_name 34 | self.vector_observation_space_size = brain_param["vectorObservationSize"] 35 | self.num_stacked_vector_observations = brain_param["numStackedVectorObservations"] 36 | self.number_visual_observations = len(brain_param["cameraResolutions"]) 37 | self.camera_resolutions = brain_param["cameraResolutions"] 38 | self.vector_action_space_size = brain_param["vectorActionSize"] 39 | self.vector_action_descriptions = brain_param["vectorActionDescriptions"] 40 | self.vector_action_space_type = ["discrete", "continuous"][brain_param["vectorActionSpaceType"]] 41 | self.vector_observation_space_type = ["discrete", "continuous"][brain_param["vectorObservationSpaceType"]] 42 | 43 | def __str__(self): 44 | return '''Unity brain name: {0} 45 | Number of Visual Observations (per agent): {1} 46 | Vector Observation space type: {2} 47 | Vector Observation space size (per agent): {3} 48 | Number of stacked Vector Observation: {4} 49 | Vector Action space type: {5} 50 | Vector Action space size (per agent): {6} 51 | Vector Action descriptions: {7}'''.format(self.brain_name, 52 | str(self.number_visual_observations), 53 | self.vector_observation_space_type, 54 | str(self.vector_observation_space_size), 55 | str(self.num_stacked_vector_observations), 56 | self.vector_action_space_type, 57 | str(self.vector_action_space_size), 58 | ', '.join(self.vector_action_descriptions)) 59 | -------------------------------------------------------------------------------- /python/unityagents/communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from communicator_objects import UnityOutput, UnityInput 4 | 5 | logging.basicConfig(level=logging.INFO) 6 | logger = logging.getLogger("unityagents") 7 | 8 | 9 | class Communicator(object): 10 | def __init__(self, worker_id=0, 11 | base_port=5005): 12 | """ 13 | Python side of the communication. Must be used in pair with the right Unity Communicator equivalent. 14 | 15 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 16 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 17 | """ 18 | 19 | def initialize(self, inputs: UnityInput) -> UnityOutput: 20 | """ 21 | Used to exchange initialization parameters between Python and the Environment 22 | :param inputs: The initialization input that will be sent to the environment. 23 | :return: UnityOutput: The initialization output sent by Unity 24 | """ 25 | 26 | def exchange(self, inputs: UnityInput) -> UnityOutput: 27 | """ 28 | Used to send an input and receive an output from the Environment 29 | :param inputs: The UnityInput that needs to be sent the Environment 30 | :return: The UnityOutputs generated by the Environment 31 | """ 32 | 33 | def close(self): 34 | """ 35 | Sends a shutdown signal to the unity environment, and closes the connection. 36 | """ 37 | 38 | -------------------------------------------------------------------------------- /python/unityagents/curriculum.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from .exception import UnityEnvironmentException 4 | 5 | import logging 6 | 7 | logger = logging.getLogger("unityagents") 8 | 9 | 10 | class Curriculum(object): 11 | def __init__(self, location, default_reset_parameters): 12 | """ 13 | Initializes a Curriculum object. 14 | :param location: Path to JSON defining curriculum. 
15 | :param default_reset_parameters: Set of reset parameters for environment. 16 | """ 17 | self.lesson_length = 0 18 | self.max_lesson_number = 0 19 | self.measure_type = None 20 | if location is None: 21 | self.data = None 22 | else: 23 | try: 24 | with open(location) as data_file: 25 | self.data = json.load(data_file) 26 | except IOError: 27 | raise UnityEnvironmentException( 28 | "The file {0} could not be found.".format(location)) 29 | except UnicodeDecodeError: 30 | raise UnityEnvironmentException("There was an error decoding {}".format(location)) 31 | self.smoothing_value = 0 32 | for key in ['parameters', 'measure', 'thresholds', 33 | 'min_lesson_length', 'signal_smoothing']: 34 | if key not in self.data: 35 | raise UnityEnvironmentException("{0} does not contain a " 36 | "{1} field.".format(location, key)) 37 | parameters = self.data['parameters'] 38 | self.measure_type = self.data['measure'] 39 | self.max_lesson_number = len(self.data['thresholds']) 40 | for key in parameters: 41 | if key not in default_reset_parameters: 42 | raise UnityEnvironmentException( 43 | "The parameter {0} in Curriculum {1} is not present in " 44 | "the Environment".format(key, location)) 45 | for key in parameters: 46 | if len(parameters[key]) != self.max_lesson_number + 1: 47 | raise UnityEnvironmentException( 48 | "The parameter {0} in Curriculum {1} must have {2} values " 49 | "but {3} were found".format(key, location, 50 | self.max_lesson_number + 1, len(parameters[key]))) 51 | self.set_lesson_number(0) 52 | 53 | @property 54 | def measure(self): 55 | return self.measure_type 56 | 57 | @property 58 | def get_lesson_number(self): 59 | return self.lesson_number 60 | 61 | def set_lesson_number(self, value): 62 | self.lesson_length = 0 63 | self.lesson_number = max(0, min(value, self.max_lesson_number)) 64 | 65 | def increment_lesson(self, progress): 66 | """ 67 | Increments the lesson number depending on the progree given. 68 | :param progress: Measure of progress (either reward or percentage steps completed). 69 | """ 70 | if self.data is None or progress is None: 71 | return 72 | if self.data["signal_smoothing"]: 73 | progress = self.smoothing_value * 0.25 + 0.75 * progress 74 | self.smoothing_value = progress 75 | self.lesson_length += 1 76 | if self.lesson_number < self.max_lesson_number: 77 | if ((progress > self.data['thresholds'][self.lesson_number]) and 78 | (self.lesson_length > self.data['min_lesson_length'])): 79 | self.lesson_length = 0 80 | self.lesson_number += 1 81 | config = {} 82 | parameters = self.data["parameters"] 83 | for key in parameters: 84 | config[key] = parameters[key][self.lesson_number] 85 | logger.info("\nLesson changed. Now in Lesson {0} : \t{1}" 86 | .format(self.lesson_number, 87 | ', '.join([str(x) + ' -> ' + str(config[x]) for x in config]))) 88 | 89 | def get_config(self, lesson=None): 90 | """ 91 | Returns reset parameters which correspond to the lesson. 92 | :param lesson: The lesson you want to get the config of. If None, the current lesson is returned. 93 | :return: The configuration of the reset parameters. 
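Example (illustrative, using the dummy curriculum from tests/test_unityagents.py): get_config(0) returns {"param1": 0.7, "param2": 100, "param3": 0.2}.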
94 | """ 95 | if self.data is None: 96 | return {} 97 | if lesson is None: 98 | lesson = self.lesson_number 99 | lesson = max(0, min(lesson, self.max_lesson_number)) 100 | config = {} 101 | parameters = self.data["parameters"] 102 | for key in parameters: 103 | config[key] = parameters[key][lesson] 104 | return config 105 | -------------------------------------------------------------------------------- /python/unityagents/exception.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger("unityagents") 3 | 4 | class UnityException(Exception): 5 | """ 6 | Any error related to ml-agents environment. 7 | """ 8 | pass 9 | 10 | class UnityEnvironmentException(UnityException): 11 | """ 12 | Related to errors starting and closing environment. 13 | """ 14 | pass 15 | 16 | 17 | class UnityActionException(UnityException): 18 | """ 19 | Related to errors with sending actions. 20 | """ 21 | pass 22 | 23 | class UnityTimeOutException(UnityException): 24 | """ 25 | Related to errors with communication timeouts. 26 | """ 27 | def __init__(self, message, log_file_path = None): 28 | if log_file_path is not None: 29 | try: 30 | with open(log_file_path, "r") as f: 31 | printing = False 32 | unity_error = '\n' 33 | for l in f: 34 | l=l.strip() 35 | if (l == 'Exception') or (l=='Error'): 36 | printing = True 37 | unity_error += '----------------------\n' 38 | if (l == ''): 39 | printing = False 40 | if printing: 41 | unity_error += l + '\n' 42 | logger.info(unity_error) 43 | logger.error("An error might have occured in the environment. " 44 | "You can check the logfile for more information at {}".format(log_file_path)) 45 | except: 46 | logger.error("An error might have occured in the environment. " 47 | "No unity-environment.log file could be found.") 48 | super(UnityTimeOutException, self).__init__(message) 49 | 50 | -------------------------------------------------------------------------------- /python/unityagents/rpc_communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import grpc 3 | 4 | from multiprocessing import Pipe 5 | from concurrent.futures import ThreadPoolExecutor 6 | 7 | from .communicator import Communicator 8 | from communicator_objects import UnityToExternalServicer, add_UnityToExternalServicer_to_server 9 | from communicator_objects import UnityMessage, UnityInput, UnityOutput 10 | from .exception import UnityTimeOutException 11 | 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger("unityagents") 15 | 16 | 17 | class UnityToExternalServicerImplementation(UnityToExternalServicer): 18 | parent_conn, child_conn = Pipe() 19 | 20 | def Initialize(self, request, context): 21 | self.child_conn.send(request) 22 | return self.child_conn.recv() 23 | 24 | def Exchange(self, request, context): 25 | self.child_conn.send(request) 26 | return self.child_conn.recv() 27 | 28 | 29 | class RpcCommunicator(Communicator): 30 | def __init__(self, worker_id=0, 31 | base_port=5005): 32 | """ 33 | Python side of the grpc communication. Python is the server and Unity the client 34 | 35 | 36 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 37 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 
38 | """ 39 | self.port = base_port + worker_id 40 | self.worker_id = worker_id 41 | self.server = None 42 | self.unity_to_external = None 43 | self.is_open = False 44 | 45 | def initialize(self, inputs: UnityInput) -> UnityOutput: 46 | try: 47 | # Establish communication grpc 48 | self.server = grpc.server(ThreadPoolExecutor(max_workers=10)) 49 | self.unity_to_external = UnityToExternalServicerImplementation() 50 | add_UnityToExternalServicer_to_server(self.unity_to_external, self.server) 51 | self.server.add_insecure_port('[::]:'+str(self.port)) 52 | self.server.start() 53 | except : 54 | raise UnityTimeOutException( 55 | "Couldn't start socket communication because worker number {} is still in use. " 56 | "You may need to manually close a previously opened environment " 57 | "or use a different worker number.".format(str(self.worker_id))) 58 | if not self.unity_to_external.parent_conn.poll(30): 59 | raise UnityTimeOutException( 60 | "The Unity environment took too long to respond. Make sure that :\n" 61 | "\t The environment does not need user interaction to launch\n" 62 | "\t The Academy and the External Brain(s) are attached to objects in the Scene\n" 63 | "\t The environment and the Python interface have compatible versions.") 64 | aca_param = self.unity_to_external.parent_conn.recv().unity_output 65 | self.is_open = True 66 | message = UnityMessage() 67 | message.header.status = 200 68 | message.unity_input.CopyFrom(inputs) 69 | self.unity_to_external.parent_conn.send(message) 70 | self.unity_to_external.parent_conn.recv() 71 | return aca_param 72 | 73 | def exchange(self, inputs: UnityInput) -> UnityOutput: 74 | message = UnityMessage() 75 | message.header.status = 200 76 | message.unity_input.CopyFrom(inputs) 77 | self.unity_to_external.parent_conn.send(message) 78 | output = self.unity_to_external.parent_conn.recv() 79 | if output.header.status != 200: 80 | return None 81 | return output.unity_output 82 | 83 | def close(self): 84 | """ 85 | Sends a shutdown signal to the unity environment, and closes the grpc connection. 86 | """ 87 | if self.is_open: 88 | message_input = UnityMessage() 89 | message_input.header.status = 400 90 | self.unity_to_external.parent_conn.send(message_input) 91 | self.unity_to_external.parent_conn.close() 92 | self.server.stop(False) 93 | self.is_open = False 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /python/unityagents/socket_communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import socket 3 | import struct 4 | 5 | from .communicator import Communicator 6 | from communicator_objects import UnityMessage, UnityOutput, UnityInput 7 | from .exception import UnityTimeOutException 8 | 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | logger = logging.getLogger("unityagents") 12 | 13 | 14 | class SocketCommunicator(Communicator): 15 | def __init__(self, worker_id=0, 16 | base_port=5005): 17 | """ 18 | Python side of the socket communication 19 | 20 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 21 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 
22 | """ 23 | 24 | self.port = base_port + worker_id 25 | self._buffer_size = 12000 26 | self.worker_id = worker_id 27 | self._socket = None 28 | self._conn = None 29 | 30 | def initialize(self, inputs: UnityInput) -> UnityOutput: 31 | try: 32 | # Establish communication socket 33 | self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 34 | self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 35 | self._socket.bind(("localhost", self.port)) 36 | except: 37 | raise UnityTimeOutException("Couldn't start socket communication because worker number {} is still in use. " 38 | "You may need to manually close a previously opened environment " 39 | "or use a different worker number.".format(str(self.worker_id))) 40 | try: 41 | self._socket.settimeout(30) 42 | self._socket.listen(1) 43 | self._conn, _ = self._socket.accept() 44 | self._conn.settimeout(30) 45 | except : 46 | raise UnityTimeOutException( 47 | "The Unity environment took too long to respond. Make sure that :\n" 48 | "\t The environment does not need user interaction to launch\n" 49 | "\t The Academy and the External Brain(s) are attached to objects in the Scene\n" 50 | "\t The environment and the Python interface have compatible versions.") 51 | message = UnityMessage() 52 | message.header.status = 200 53 | message.unity_input.CopyFrom(inputs) 54 | self._communicator_send(message.SerializeToString()) 55 | initialization_output = UnityMessage() 56 | initialization_output.ParseFromString(self._communicator_receive()) 57 | return initialization_output.unity_output 58 | 59 | def _communicator_receive(self): 60 | try: 61 | s = self._conn.recv(self._buffer_size) 62 | message_length = struct.unpack("I", bytearray(s[:4]))[0] 63 | s = s[4:] 64 | while len(s) != message_length: 65 | s += self._conn.recv(self._buffer_size) 66 | except socket.timeout as e: 67 | raise UnityTimeOutException("The environment took too long to respond.") 68 | return s 69 | 70 | def _communicator_send(self, message): 71 | self._conn.send(struct.pack("I", len(message)) + message) 72 | 73 | def exchange(self, inputs: UnityInput) -> UnityOutput: 74 | message = UnityMessage() 75 | message.header.status = 200 76 | message.unity_input.CopyFrom(inputs) 77 | self._communicator_send(message.SerializeToString()) 78 | outputs = UnityMessage() 79 | outputs.ParseFromString(self._communicator_receive()) 80 | if outputs.header.status != 200: 81 | return None 82 | return outputs.unity_output 83 | 84 | def close(self): 85 | """ 86 | Sends a shutdown signal to the unity environment, and closes the socket connection. 
87 | """ 88 | if self._socket is not None and self._conn is not None: 89 | message_input = UnityMessage() 90 | message_input.header.status = 400 91 | self._communicator_send(message_input.SerializeToString()) 92 | if self._socket is not None: 93 | self._socket.close() 94 | self._socket = None 95 | if self._socket is not None: 96 | self._conn.close() 97 | self._conn = None 98 | 99 | -------------------------------------------------------------------------------- /python/unitytrainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .buffer import * 2 | from .models import * 3 | from .trainer_controller import * 4 | from .bc.models import * 5 | from .bc.trainer import * 6 | from .ppo.models import * 7 | from .ppo.trainer import * 8 | -------------------------------------------------------------------------------- /python/unitytrainers/bc/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | from .trainer import * 3 | -------------------------------------------------------------------------------- /python/unitytrainers/bc/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as c_layers 3 | from unitytrainers.models import LearningModel 4 | 5 | 6 | class BehavioralCloningModel(LearningModel): 7 | def __init__(self, brain, h_size=128, lr=1e-4, n_layers=2, m_size=128, 8 | normalize=False, use_recurrent=False): 9 | LearningModel.__init__(self, m_size, normalize, use_recurrent, brain) 10 | 11 | num_streams = 1 12 | hidden_streams = self.create_observation_streams(num_streams, h_size, n_layers) 13 | hidden = hidden_streams[0] 14 | self.dropout_rate = tf.placeholder(dtype=tf.float32, shape=[], name="dropout_rate") 15 | hidden_reg = tf.layers.dropout(hidden, self.dropout_rate) 16 | if self.use_recurrent: 17 | tf.Variable(self.m_size, name="memory_size", trainable=False, dtype=tf.int32) 18 | self.memory_in = tf.placeholder(shape=[None, self.m_size], dtype=tf.float32, name='recurrent_in') 19 | hidden_reg, self.memory_out = self.create_recurrent_encoder(hidden_reg, self.memory_in, 20 | self.sequence_length) 21 | self.memory_out = tf.identity(self.memory_out, name='recurrent_out') 22 | self.policy = tf.layers.dense(hidden_reg, self.a_size, activation=None, use_bias=False, name='pre_action', 23 | kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01)) 24 | 25 | if brain.vector_action_space_type == "discrete": 26 | self.action_probs = tf.nn.softmax(self.policy) 27 | self.sample_action_float = tf.multinomial(self.policy, 1) 28 | self.sample_action_float = tf.identity(self.sample_action_float, name="action") 29 | self.sample_action = tf.cast(self.sample_action_float, tf.int32) 30 | self.true_action = tf.placeholder(shape=[None], dtype=tf.int32, name="teacher_action") 31 | self.action_oh = tf.one_hot(self.true_action, self.a_size) 32 | self.loss = tf.reduce_sum(-tf.log(self.action_probs + 1e-10) * self.action_oh) 33 | self.action_percent = tf.reduce_mean(tf.cast( 34 | tf.equal(tf.cast(tf.argmax(self.action_probs, axis=1), tf.int32), self.sample_action), tf.float32)) 35 | else: 36 | self.clipped_sample_action = tf.clip_by_value(self.policy, -1, 1) 37 | self.sample_action = tf.identity(self.clipped_sample_action, name="action") 38 | self.true_action = tf.placeholder(shape=[None, self.a_size], dtype=tf.float32, name="teacher_action") 39 | self.clipped_true_action = 
tf.clip_by_value(self.true_action, -1, 1) 40 | self.loss = tf.reduce_sum(tf.squared_difference(self.clipped_true_action, self.sample_action)) 41 | 42 | optimizer = tf.train.AdamOptimizer(learning_rate=lr) 43 | self.update = optimizer.minimize(self.loss) 44 | -------------------------------------------------------------------------------- /python/unitytrainers/buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from unityagents.exception import UnityException 4 | 5 | 6 | class BufferException(UnityException): 7 | """ 8 | Related to errors with the Buffer. 9 | """ 10 | pass 11 | 12 | 13 | class Buffer(dict): 14 | """ 15 | Buffer contains a dictionary of AgentBuffer. The AgentBuffers are indexed by agent_id. 16 | Buffer also contains an update_buffer that corresponds to the buffer used when updating the model. 17 | """ 18 | 19 | class AgentBuffer(dict): 20 | """ 21 | AgentBuffer contains a dictionary of AgentBufferFields. Each agent has his own AgentBuffer. 22 | The keys correspond to the name of the field. Example: state, action 23 | """ 24 | 25 | class AgentBufferField(list): 26 | """ 27 | AgentBufferField is a list of numpy arrays. When an agent collects a field, you can add it to his 28 | AgentBufferField with the append method. 29 | """ 30 | 31 | def __str__(self): 32 | return str(np.array(self).shape) 33 | 34 | def extend(self, data): 35 | """ 36 | Ads a list of np.arrays to the end of the list of np.arrays. 37 | :param data: The np.array list to append. 38 | """ 39 | self += list(np.array(data)) 40 | 41 | def set(self, data): 42 | """ 43 | Sets the list of np.array to the input data 44 | :param data: The np.array list to be set. 45 | """ 46 | self[:] = [] 47 | self[:] = list(np.array(data)) 48 | 49 | def get_batch(self, batch_size=None, training_length=1, sequential=True): 50 | """ 51 | Retrieve the last batch_size elements of length training_length 52 | from the list of np.array 53 | :param batch_size: The number of elements to retrieve. If None: 54 | All elements will be retrieved. 55 | :param training_length: The length of the sequence to be retrieved. If 56 | None: only takes one element. 57 | :param sequential: If true and training_length is not None: the elements 58 | will not repeat in the sequence. [a,b,c,d,e] with training_length = 2 and 59 | sequential=True gives [[0,a],[b,c],[d,e]]. If sequential=False gives 60 | [[a,b],[b,c],[c,d],[d,e]] 61 | """ 62 | if training_length == 1: 63 | # When the training length is 1, the method returns a list of elements, 64 | # not a list of sequences of elements. 65 | if batch_size is None: 66 | # If batch_size is None : All the elements of the AgentBufferField are returned. 
67 | return np.array(self) 68 | else: 69 | # return the batch_size last elements 70 | if batch_size > len(self): 71 | raise BufferException("Batch size requested is too large") 72 | return np.array(self[-batch_size:]) 73 | else: 74 | # The training_length is not None, the method returns a list of SEQUENCES of elements 75 | if not sequential: 76 | # The sequences will have overlapping elements 77 | if batch_size is None: 78 | # retrieve the maximum number of elements 79 | batch_size = len(self) - training_length + 1 80 | # The number of sequences of length training_length taken from a list of len(self) elements 81 | # with overlapping is equal to batch_size 82 | if (len(self) - training_length + 1) < batch_size: 83 | raise BufferException("The batch size and training length requested for get_batch where" 84 | " too large given the current number of data points.") 85 | tmp_list = [] 86 | for end in range(len(self) - batch_size + 1, len(self) + 1): 87 | tmp_list += [np.array(self[end - training_length:end])] 88 | return np.array(tmp_list) 89 | if sequential: 90 | # The sequences will not have overlapping elements (this involves padding) 91 | leftover = len(self) % training_length 92 | # leftover is the number of elements in the first sequence (this sequence might need 0 padding) 93 | if batch_size is None: 94 | # retrieve the maximum number of elements 95 | batch_size = len(self) // training_length + 1 * (leftover != 0) 96 | # The maximum number of sequences taken from a list of length len(self) without overlapping 97 | # with padding is equal to batch_size 98 | if batch_size > (len(self) // training_length + 1 * (leftover != 0)): 99 | raise BufferException("The batch size and training length requested for get_batch where" 100 | " too large given the current number of data points.") 101 | tmp_list = [] 102 | padding = np.array(self[-1]) * 0 103 | # The padding is made with zeros and its shape is given by the shape of the last element 104 | for end in range(len(self), len(self) % training_length, -training_length)[:batch_size]: 105 | tmp_list += [np.array(self[end - training_length:end])] 106 | if (leftover != 0) and (len(tmp_list) < batch_size): 107 | tmp_list += [np.array([padding] * (training_length - leftover) + self[:leftover])] 108 | tmp_list.reverse() 109 | return np.array(tmp_list) 110 | 111 | def reset_field(self): 112 | """ 113 | Resets the AgentBufferField 114 | """ 115 | self[:] = [] 116 | 117 | def __init__(self): 118 | self.last_brain_info = None 119 | self.last_take_action_outputs = None 120 | super(Buffer.AgentBuffer, self).__init__() 121 | 122 | def __str__(self): 123 | return ", ".join(["'{0}' : {1}".format(k, str(self[k])) for k in self.keys()]) 124 | 125 | def reset_agent(self): 126 | """ 127 | Resets the AgentBuffer 128 | """ 129 | for k in self.keys(): 130 | self[k].reset_field() 131 | self.last_brain_info = None 132 | self.last_take_action_outputs = None 133 | 134 | def __getitem__(self, key): 135 | if key not in self.keys(): 136 | self[key] = self.AgentBufferField() 137 | return super(Buffer.AgentBuffer, self).__getitem__(key) 138 | 139 | def check_length(self, key_list): 140 | """ 141 | Some methods will require that some fields have the same length. 142 | check_length will return true if the fields in key_list 143 | have the same length. 
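shuffle() and append_update_buffer() rely on this check to guard against misaligned fields.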
144 | :param key_list: The fields which length will be compared 145 | """ 146 | if len(key_list) < 2: 147 | return True 148 | l = None 149 | for key in key_list: 150 | if key not in self.keys(): 151 | return False 152 | if (l is not None) and (l != len(self[key])): 153 | return False 154 | l = len(self[key]) 155 | return True 156 | 157 | def shuffle(self, key_list=None): 158 | """ 159 | Shuffles the fields in key_list in a consistent way: The reordering will 160 | be the same across fields. 161 | :param key_list: The fields that must be shuffled. 162 | """ 163 | if key_list is None: 164 | key_list = list(self.keys()) 165 | if not self.check_length(key_list): 166 | raise BufferException("Unable to shuffle if the fields are not of same length") 167 | s = np.arange(len(self[key_list[0]])) 168 | np.random.shuffle(s) 169 | for key in key_list: 170 | self[key][:] = [self[key][i] for i in s] 171 | 172 | def __init__(self): 173 | self.update_buffer = self.AgentBuffer() 174 | super(Buffer, self).__init__() 175 | 176 | def __str__(self): 177 | return "update buffer :\n\t{0}\nlocal_buffers :\n{1}".format(str(self.update_buffer), 178 | '\n'.join( 179 | ['\tagent {0} :{1}'.format(k, str(self[k])) for 180 | k in self.keys()])) 181 | 182 | def __getitem__(self, key): 183 | if key not in self.keys(): 184 | self[key] = self.AgentBuffer() 185 | return super(Buffer, self).__getitem__(key) 186 | 187 | def reset_update_buffer(self): 188 | """ 189 | Resets the update buffer 190 | """ 191 | self.update_buffer.reset_agent() 192 | 193 | def reset_all(self): 194 | """ 195 | Resets all the local local_buffers 196 | """ 197 | agent_ids = list(self.keys()) 198 | for k in agent_ids: 199 | self[k].reset_agent() 200 | 201 | def append_update_buffer(self, agent_id, key_list=None, batch_size=None, training_length=None): 202 | """ 203 | Appends the buffer of an agent to the update buffer. 204 | :param agent_id: The id of the agent which data will be appended 205 | :param key_list: The fields that must be added. If None: all fields will be appended. 206 | :param batch_size: The number of elements that must be appended. If None: All of them will be. 207 | :param training_length: The length of the samples that must be appended. If None: only takes one element. 208 | """ 209 | if key_list is None: 210 | key_list = self[agent_id].keys() 211 | if not self[agent_id].check_length(key_list): 212 | raise BufferException("The length of the fields {0} for agent {1} where not of same length" 213 | .format(key_list, agent_id)) 214 | for field_key in key_list: 215 | self.update_buffer[field_key].extend( 216 | self[agent_id][field_key].get_batch(batch_size=batch_size, training_length=training_length) 217 | ) 218 | 219 | def append_all_agent_batch_to_update_buffer(self, key_list=None, batch_size=None, training_length=None): 220 | """ 221 | Appends the buffer of all agents to the update buffer. 222 | :param key_list: The fields that must be added. If None: all fields will be appended. 223 | :param batch_size: The number of elements that must be appended. If None: All of them will be. 224 | :param training_length: The length of the samples that must be appended. If None: only takes one element. 
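# A minimal sketch of how a trainer typically feeds the Buffer defined above; the agent id (0),
# the field names and the observation/action values are hypothetical.
import numpy as np
demo_buffer = Buffer()
demo_buffer[0]['states'].append(np.zeros(4))            # per-agent AgentBuffer is created on first access
demo_buffer[0]['actions'].append(np.array([1]))
demo_buffer.append_update_buffer(0, training_length=1)  # copy agent 0's fields into update_buffer
demo_buffer.reset_all()                                 # clear the per-agent buffers; update_buffer is kept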
225 | """ 226 | for agent_id in self.keys(): 227 | self.append_update_buffer(agent_id, key_list, batch_size, training_length) 228 | -------------------------------------------------------------------------------- /python/unitytrainers/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | from .trainer import * 3 | -------------------------------------------------------------------------------- /python/unitytrainers/ppo/models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import tensorflow as tf 4 | from unitytrainers.models import LearningModel 5 | 6 | logger = logging.getLogger("unityagents") 7 | 8 | 9 | class PPOModel(LearningModel): 10 | def __init__(self, brain, lr=1e-4, h_size=128, epsilon=0.2, beta=1e-3, max_step=5e6, 11 | normalize=False, use_recurrent=False, num_layers=2, m_size=None, use_curiosity=False, 12 | curiosity_strength=0.01, curiosity_enc_size=128): 13 | """ 14 | Takes a Unity environment and model-specific hyper-parameters and returns the 15 | appropriate PPO agent model for the environment. 16 | :param brain: BrainInfo used to generate specific network graph. 17 | :param lr: Learning rate. 18 | :param h_size: Size of hidden layers 19 | :param epsilon: Value for policy-divergence threshold. 20 | :param beta: Strength of entropy regularization. 21 | :return: a sub-class of PPOAgent tailored to the environment. 22 | :param max_step: Total number of training steps. 23 | :param normalize: Whether to normalize vector observation input. 24 | :param use_recurrent: Whether to use an LSTM layer in the network. 25 | :param num_layers Number of hidden layers between encoded input and policy & value layers 26 | :param m_size: Size of brain memory. 27 | """ 28 | LearningModel.__init__(self, m_size, normalize, use_recurrent, brain) 29 | self.use_curiosity = use_curiosity 30 | if num_layers < 1: 31 | num_layers = 1 32 | self.last_reward, self.new_reward, self.update_reward = self.create_reward_encoder() 33 | if brain.vector_action_space_type == "continuous": 34 | self.create_cc_actor_critic(h_size, num_layers) 35 | self.entropy = tf.ones_like(tf.reshape(self.value, [-1])) * self.entropy 36 | else: 37 | self.create_dc_actor_critic(h_size, num_layers) 38 | if self.use_curiosity: 39 | self.curiosity_enc_size = curiosity_enc_size 40 | self.curiosity_strength = curiosity_strength 41 | encoded_state, encoded_next_state = self.create_curiosity_encoders() 42 | self.create_inverse_model(encoded_state, encoded_next_state) 43 | self.create_forward_model(encoded_state, encoded_next_state) 44 | self.create_ppo_optimizer(self.probs, self.old_probs, self.value, 45 | self.entropy, beta, epsilon, lr, max_step) 46 | 47 | @staticmethod 48 | def create_reward_encoder(): 49 | """Creates TF ops to track and increment recent average cumulative reward.""" 50 | last_reward = tf.Variable(0, name="last_reward", trainable=False, dtype=tf.float32) 51 | new_reward = tf.placeholder(shape=[], dtype=tf.float32, name='new_reward') 52 | update_reward = tf.assign(last_reward, new_reward) 53 | return last_reward, new_reward, update_reward 54 | 55 | def create_curiosity_encoders(self): 56 | """ 57 | Creates state encoders for current and future observations. 58 | Used for implementation of Curiosity-driven Exploration by Self-supervised Prediction 59 | See https://arxiv.org/abs/1705.05363 for more details. 60 | :return: current and future state encoder tensors. 
61 | """ 62 | encoded_state_list = [] 63 | encoded_next_state_list = [] 64 | 65 | if self.v_size > 0: 66 | self.next_visual_in = [] 67 | visual_encoders = [] 68 | next_visual_encoders = [] 69 | for i in range(self.v_size): 70 | # Create input ops for next (t+1) visual observations. 71 | next_visual_input = self.create_visual_input(self.brain.camera_resolutions[i], 72 | name="next_visual_observation_" + str(i)) 73 | self.next_visual_in.append(next_visual_input) 74 | 75 | # Create the encoder ops for current and next visual input. Not that these encoders are siamese. 76 | encoded_visual = self.create_visual_observation_encoder(self.visual_in[i], self.curiosity_enc_size, 77 | self.swish, 1, "stream_{}_visual_obs_encoder" 78 | .format(i), False) 79 | 80 | encoded_next_visual = self.create_visual_observation_encoder(self.next_visual_in[i], 81 | self.curiosity_enc_size, 82 | self.swish, 1, 83 | "stream_{}_visual_obs_encoder".format(i), 84 | True) 85 | visual_encoders.append(encoded_visual) 86 | next_visual_encoders.append(encoded_next_visual) 87 | 88 | hidden_visual = tf.concat(visual_encoders, axis=1) 89 | hidden_next_visual = tf.concat(next_visual_encoders, axis=1) 90 | encoded_state_list.append(hidden_visual) 91 | encoded_next_state_list.append(hidden_next_visual) 92 | 93 | if self.o_size > 0: 94 | 95 | # Create the encoder ops for current and next vector input. Not that these encoders are siamese. 96 | if self.brain.vector_observation_space_type == "continuous": 97 | # Create input op for next (t+1) vector observation. 98 | self.next_vector_in = tf.placeholder(shape=[None, self.o_size], dtype=tf.float32, 99 | name='next_vector_observation') 100 | 101 | encoded_vector_obs = self.create_continuous_observation_encoder(self.vector_in, 102 | self.curiosity_enc_size, 103 | self.swish, 2, "vector_obs_encoder", 104 | False) 105 | encoded_next_vector_obs = self.create_continuous_observation_encoder(self.next_vector_in, 106 | self.curiosity_enc_size, 107 | self.swish, 2, 108 | "vector_obs_encoder", 109 | True) 110 | else: 111 | self.next_vector_in = tf.placeholder(shape=[None, 1], dtype=tf.int32, 112 | name='next_vector_observation') 113 | 114 | encoded_vector_obs = self.create_discrete_observation_encoder(self.vector_in, self.o_size, 115 | self.curiosity_enc_size, 116 | self.swish, 2, "vector_obs_encoder", 117 | False) 118 | encoded_next_vector_obs = self.create_discrete_observation_encoder(self.next_vector_in, self.o_size, 119 | self.curiosity_enc_size, 120 | self.swish, 2, "vector_obs_encoder", 121 | True) 122 | encoded_state_list.append(encoded_vector_obs) 123 | encoded_next_state_list.append(encoded_next_vector_obs) 124 | 125 | encoded_state = tf.concat(encoded_state_list, axis=1) 126 | encoded_next_state = tf.concat(encoded_next_state_list, axis=1) 127 | return encoded_state, encoded_next_state 128 | 129 | def create_inverse_model(self, encoded_state, encoded_next_state): 130 | """ 131 | Creates inverse model TensorFlow ops for Curiosity module. 132 | Predicts action taken given current and future encoded states. 133 | :param encoded_state: Tensor corresponding to encoded current state. 134 | :param encoded_next_state: Tensor corresponding to encoded next state. 
135 | """ 136 | combined_input = tf.concat([encoded_state, encoded_next_state], axis=1) 137 | hidden = tf.layers.dense(combined_input, 256, activation=self.swish) 138 | if self.brain.vector_action_space_type == "continuous": 139 | pred_action = tf.layers.dense(hidden, self.a_size, activation=None) 140 | squared_difference = tf.reduce_sum(tf.squared_difference(pred_action, self.selected_actions), axis=1) 141 | self.inverse_loss = tf.reduce_mean(tf.dynamic_partition(squared_difference, self.mask, 2)[1]) 142 | else: 143 | pred_action = tf.layers.dense(hidden, self.a_size, activation=tf.nn.softmax) 144 | cross_entropy = tf.reduce_sum(-tf.log(pred_action + 1e-10) * self.selected_actions, axis=1) 145 | self.inverse_loss = tf.reduce_mean(tf.dynamic_partition(cross_entropy, self.mask, 2)[1]) 146 | 147 | def create_forward_model(self, encoded_state, encoded_next_state): 148 | """ 149 | Creates forward model TensorFlow ops for Curiosity module. 150 | Predicts encoded future state based on encoded current state and given action. 151 | :param encoded_state: Tensor corresponding to encoded current state. 152 | :param encoded_next_state: Tensor corresponding to encoded next state. 153 | """ 154 | combined_input = tf.concat([encoded_state, self.selected_actions], axis=1) 155 | hidden = tf.layers.dense(combined_input, 256, activation=self.swish) 156 | # We compare against the concatenation of all observation streams, hence `self.v_size + int(self.o_size > 0)`. 157 | pred_next_state = tf.layers.dense(hidden, self.curiosity_enc_size * (self.v_size + int(self.o_size > 0)), 158 | activation=None) 159 | 160 | squared_difference = 0.5 * tf.reduce_sum(tf.squared_difference(pred_next_state, encoded_next_state), axis=1) 161 | self.intrinsic_reward = tf.clip_by_value(self.curiosity_strength * squared_difference, 0, 1) 162 | self.forward_loss = tf.reduce_mean(tf.dynamic_partition(squared_difference, self.mask, 2)[1]) 163 | 164 | def create_ppo_optimizer(self, probs, old_probs, value, entropy, beta, epsilon, lr, max_step): 165 | """ 166 | Creates training-specific Tensorflow ops for PPO models. 167 | :param probs: Current policy probabilities 168 | :param old_probs: Past policy probabilities 169 | :param value: Current value estimate 170 | :param beta: Entropy regularization strength 171 | :param entropy: Current policy entropy 172 | :param epsilon: Value for policy-divergence threshold 173 | :param lr: Learning rate 174 | :param max_step: Total number of training steps. 
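# A plain-NumPy sketch of the intrinsic reward produced by create_forward_model above;
# `phi_pred` / `phi_next` stand in for the predicted and true next-state encodings and are
# illustrative names only (0.01 is the default curiosity_strength).
import numpy as np
phi_pred = np.random.randn(3, 8)                            # batch of predicted next-state encodings
phi_next = np.random.randn(3, 8)                            # batch of true next-state encodings
err = 0.5 * np.sum(np.square(phi_pred - phi_next), axis=1)  # per-sample forward prediction error
r_int = np.clip(0.01 * err, 0.0, 1.0)                       # curiosity_strength * error, clipped to [0, 1]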
175 | """ 176 | self.returns_holder = tf.placeholder(shape=[None], dtype=tf.float32, name='discounted_rewards') 177 | self.advantage = tf.placeholder(shape=[None, 1], dtype=tf.float32, name='advantages') 178 | self.learning_rate = tf.train.polynomial_decay(lr, self.global_step, max_step, 1e-10, power=1.0) 179 | 180 | self.old_value = tf.placeholder(shape=[None], dtype=tf.float32, name='old_value_estimates') 181 | 182 | decay_epsilon = tf.train.polynomial_decay(epsilon, self.global_step, max_step, 0.1, power=1.0) 183 | decay_beta = tf.train.polynomial_decay(beta, self.global_step, max_step, 1e-5, power=1.0) 184 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate) 185 | 186 | clipped_value_estimate = self.old_value + tf.clip_by_value(tf.reduce_sum(value, axis=1) - self.old_value, 187 | - decay_epsilon, decay_epsilon) 188 | 189 | v_opt_a = tf.squared_difference(self.returns_holder, tf.reduce_sum(value, axis=1)) 190 | v_opt_b = tf.squared_difference(self.returns_holder, clipped_value_estimate) 191 | self.value_loss = tf.reduce_mean(tf.dynamic_partition(tf.maximum(v_opt_a, v_opt_b), self.mask, 2)[1]) 192 | 193 | # Here we calculate PPO policy loss. In continuous control this is done independently for each action gaussian 194 | # and then averaged together. This provides significantly better performance than treating the probability 195 | # as an average of probabilities, or as a joint probability. 196 | r_theta = probs / (old_probs + 1e-10) 197 | p_opt_a = r_theta * self.advantage 198 | p_opt_b = tf.clip_by_value(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * self.advantage 199 | self.policy_loss = -tf.reduce_mean(tf.dynamic_partition(tf.minimum(p_opt_a, p_opt_b), self.mask, 2)[1]) 200 | 201 | self.loss = self.policy_loss + 0.5 * self.value_loss - decay_beta * tf.reduce_mean( 202 | tf.dynamic_partition(entropy, self.mask, 2)[1]) 203 | 204 | if self.use_curiosity: 205 | self.loss += 10 * (0.2 * self.forward_loss + 0.8 * self.inverse_loss) 206 | self.update_batch = optimizer.minimize(self.loss) 207 | -------------------------------------------------------------------------------- /python/unitytrainers/trainer.py: -------------------------------------------------------------------------------- 1 | # # Unity ML-Agents Toolkit 2 | import logging 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | 7 | from unityagents import UnityException, AllBrainInfo 8 | 9 | logger = logging.getLogger("unityagents") 10 | 11 | 12 | class UnityTrainerException(UnityException): 13 | """ 14 | Related to errors with the Trainer. 15 | """ 16 | pass 17 | 18 | 19 | class Trainer(object): 20 | """This class is the abstract class for the unitytrainers""" 21 | 22 | def __init__(self, sess, env, brain_name, trainer_parameters, training): 23 | """ 24 | Responsible for collecting experiences and training a neural network model. 25 | :param sess: Tensorflow session. 26 | :param env: The UnityEnvironment. 27 | :param trainer_parameters: The parameters for the trainer (dictionary). 28 | :param training: Whether the trainer is set for training. 29 | """ 30 | self.brain_name = brain_name 31 | self.brain = env.brains[self.brain_name] 32 | self.trainer_parameters = trainer_parameters 33 | self.is_training = training 34 | self.sess = sess 35 | self.stats = {} 36 | self.summary_writer = None 37 | 38 | def __str__(self): 39 | return '''Empty Trainer''' 40 | 41 | @property 42 | def parameters(self): 43 | """ 44 | Returns the trainer parameters of the trainer. 
45 | """ 46 | raise UnityTrainerException("The parameters property was not implemented.") 47 | 48 | @property 49 | def graph_scope(self): 50 | """ 51 | Returns the graph scope of the trainer. 52 | """ 53 | raise UnityTrainerException("The graph_scope property was not implemented.") 54 | 55 | @property 56 | def get_max_steps(self): 57 | """ 58 | Returns the maximum number of steps. Is used to know when the trainer should be stopped. 59 | :return: The maximum number of steps of the trainer 60 | """ 61 | raise UnityTrainerException("The get_max_steps property was not implemented.") 62 | 63 | @property 64 | def get_step(self): 65 | """ 66 | Returns the number of steps the trainer has performed 67 | :return: the step count of the trainer 68 | """ 69 | raise UnityTrainerException("The get_step property was not implemented.") 70 | 71 | @property 72 | def get_last_reward(self): 73 | """ 74 | Returns the last reward the trainer has had 75 | :return: the new last reward 76 | """ 77 | raise UnityTrainerException("The get_last_reward property was not implemented.") 78 | 79 | def increment_step_and_update_last_reward(self): 80 | """ 81 | Increments the step count of the trainer and updates the last reward 82 | """ 83 | raise UnityTrainerException("The increment_step_and_update_last_reward method was not implemented.") 84 | 85 | def take_action(self, all_brain_info: AllBrainInfo): 86 | """ 87 | Decides actions given state/observation information, and takes them in the environment. 88 | :param all_brain_info: A dictionary of brain names and BrainInfo from environment. 89 | :return: a tuple containing action, memories, values and an object 90 | to be passed to add experiences 91 | """ 92 | raise UnityTrainerException("The take_action method was not implemented.") 93 | 94 | def add_experiences(self, curr_info: AllBrainInfo, next_info: AllBrainInfo, take_action_outputs): 95 | """ 96 | Adds experiences to each agent's experience history. 97 | :param curr_info: Current AllBrainInfo. 98 | :param next_info: Next AllBrainInfo. 99 | :param take_action_outputs: The outputs of the take action method. 100 | """ 101 | raise UnityTrainerException("The add_experiences method was not implemented.") 102 | 103 | def process_experiences(self, current_info: AllBrainInfo, next_info: AllBrainInfo): 104 | """ 105 | Checks agent histories for processing condition, and processes them as necessary. 106 | Processing involves calculating value and advantage targets for the model updating step. 107 | :param current_info: Dictionary of all current-step brains and corresponding BrainInfo. 108 | :param next_info: Dictionary of all next-step brains and corresponding BrainInfo. 109 | """ 110 | raise UnityTrainerException("The process_experiences method was not implemented.") 111 | 112 | def end_episode(self): 113 | """ 114 | A signal that the Episode has ended. The buffer must be reset. 115 | Gets called only when the academy resets. 116 | """ 117 | raise UnityTrainerException("The end_episode method was not implemented.") 118 | 119 | def is_ready_update(self): 120 | """ 121 | Returns whether or not the trainer has enough elements to run update model 122 | :return: A boolean corresponding to whether or not update_model() can be run 123 | """ 124 | raise UnityTrainerException("The is_ready_update method was not implemented.") 125 | 126 | def update_model(self): 127 | """ 128 | Uses training_buffer to update model.
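# A minimal sketch of how a concrete trainer could satisfy part of the abstract interface
# above; the class name and the 'max_steps' key are hypothetical.
class SketchTrainer(Trainer):
    @property
    def parameters(self):
        return self.trainer_parameters

    @property
    def get_max_steps(self):
        return float(self.trainer_parameters['max_steps'])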
129 | """ 130 | raise UnityTrainerException("The update_model method was not implemented.") 131 | 132 | def write_summary(self, lesson_number): 133 | """ 134 | Saves training statistics to Tensorboard. 135 | :param lesson_number: The lesson the trainer is at. 136 | """ 137 | if (self.get_step % self.trainer_parameters['summary_freq'] == 0 and self.get_step != 0 and 138 | self.is_training and self.get_step <= self.get_max_steps): 139 | if len(self.stats['cumulative_reward']) > 0: 140 | mean_reward = np.mean(self.stats['cumulative_reward']) 141 | logger.info(" {}: Step: {}. Mean Reward: {:0.3f}. Std of Reward: {:0.3f}." 142 | .format(self.brain_name, self.get_step, 143 | mean_reward, np.std(self.stats['cumulative_reward']))) 144 | else: 145 | logger.info(" {}: Step: {}. No episode was completed since last summary." 146 | .format(self.brain_name, self.get_step)) 147 | summary = tf.Summary() 148 | for key in self.stats: 149 | if len(self.stats[key]) > 0: 150 | stat_mean = float(np.mean(self.stats[key])) 151 | summary.value.add(tag='Info/{}'.format(key), simple_value=stat_mean) 152 | self.stats[key] = [] 153 | summary.value.add(tag='Info/Lesson', simple_value=lesson_number) 154 | self.summary_writer.add_summary(summary, self.get_step) 155 | self.summary_writer.flush() 156 | 157 | def write_tensorboard_text(self, key, input_dict): 158 | """ 159 | Saves text to Tensorboard. 160 | Note: Only works on tensorflow r1.2 or above. 161 | :param key: The name of the text. 162 | :param input_dict: A dictionary that will be displayed in a table on Tensorboard. 163 | """ 164 | try: 165 | s_op = tf.summary.text(key, tf.convert_to_tensor(([[str(x), str(input_dict[x])] for x in input_dict]))) 166 | s = self.sess.run(s_op) 167 | self.summary_writer.add_summary(s, self.get_step) 168 | except: 169 | logger.info("Cannot write text summary for Tensorboard. Tensorflow version must be r1.2 or above.") 170 | pass 171 | -------------------------------------------------------------------------------- /replay_memory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Replay Memory Class for DQN Agent for Vector Observation Learning 3 | 4 | Example Developed By: 5 | Michael Richardson, 2018 6 | Project for Udacity Nanodegree in Deep Reinforcement Learning (DRL) 7 | Code expanded and adapted from code examples provided by Udacity DRL Team, 2018. 8 | """ 9 | 10 | # Import Required Packages 11 | import torch 12 | import numpy as np 13 | import random 14 | from collections import namedtuple, deque 15 | from model import QNetwork 16 | 17 | # Determine if CPU or GPU computation should be used 18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 19 | 20 | 21 | """ 22 | ################################################## 23 | ReplayBuffer Class 24 | Defines a Replay Memory Buffer for a DQN or DDQN agent 25 | The buffer holds memories of: [state, action, reward, next state, done] tuples 26 | Random batches of replay memories are sampled for learning. 27 | """ 28 | class ReplayBuffer: 29 | """Fixed-size buffer to store experience tuples.""" 30 | 31 | def __init__(self, action_size, buffer_size, batch_size, seed): 32 | """Initialize a ReplayBuffer object.
33 | 34 | Params 35 | ====== 36 | action_size (int): dimension of each action 37 | buffer_size (int): maximum size of buffer 38 | batch_size (int): size of each training batch 39 | seed (int): random seed 40 | """ 41 | self.action_size = action_size 42 | self.memory = deque(maxlen=buffer_size) 43 | self.batch_size = batch_size 44 | self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) 45 | self.seed = random.seed(seed) 46 | 47 | def add(self, state, action, reward, next_state, done): 48 | """Add a new experience to memory.""" 49 | e = self.experience(state, action, reward, next_state, done) 50 | self.memory.append(e) 51 | 52 | def sample(self): 53 | """Randomly sample a batch of experiences from memory.""" 54 | experiences = random.sample(self.memory, k=self.batch_size) 55 | 56 | states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device) 57 | actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device) 58 | rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device) 59 | next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device) 60 | dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device) 61 | 62 | return (states, actions, rewards, next_states, dones) 63 | 64 | def __len__(self): 65 | """Return the current size of internal memory.""" 66 | return len(self.memory) 67 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Test DQN Model for Unity ML-Agents Environments using PyTorch 4 | 5 | This example tests a trained DQN NN model on a modified version of the Unity ML-Agents Banana Collection Example Environment. 6 | The environment includes a single agent, who can turn left or right and move forward or backward. 7 | The agent's task is to collect yellow bananas (reward of +1) that are scattered around a square 8 | game area, while avoiding purple bananas (reward of -1). For the version of Bananas employed here, 9 | the environment is considered solved when the average score over the last 100 episodes > 13. 10 | 11 | Example Developed By: 12 | Michael Richardson, 2018 13 | Project for Udacity Nanodegree in Deep Reinforcement Learning (DRL) 14 | Code Expanded and Adapted from Code provided by Udacity DRL Team, 2018.
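# A minimal usage sketch of the ReplayBuffer defined in replay_memory.py above; the
# hyperparameter values and the transition variables are illustrative only.
demo_memory = ReplayBuffer(action_size=4, buffer_size=int(1e5), batch_size=64, seed=0)
demo_memory.add(state, action, reward, next_state, done)                  # store one (S, A, R, S', done) transition
if len(demo_memory) >= 64:
    states, actions, rewards, next_states, dones = demo_memory.sample()   # random minibatch of torch tensors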
15 | """ 16 | 17 | ################################### 18 | # Import Required Packages 19 | import torch 20 | import time 21 | import random 22 | import numpy as np 23 | from dqn_agent import Agent 24 | from unityagents import UnityEnvironment 25 | 26 | """ 27 | ################################### 28 | STEP 1: Set the Test Parameters 29 | ====== 30 | num_episodes (int): number of test episodes 31 | """ 32 | num_episodes=10 33 | 34 | 35 | """ 36 | ################################### 37 | STEP 2: Start the Unity Environment 38 | # Use the corresponding call depending on your operating system 39 | """ 40 | env = UnityEnvironment(file_name="Banana.app") 41 | # - **Mac**: "Banana.app" 42 | # - **Windows** (x86): "Banana_Windows_x86/Banana.exe" 43 | # - **Windows** (x86_64): "Banana_Windows_x86_64/Banana.exe" 44 | # - **Linux** (x86): "Banana_Linux/Banana.x86" 45 | # - **Linux** (x86_64): "Banana_Linux/Banana.x86_64" 46 | # - **Linux** (x86, headless): "Banana_Linux_NoVis/Banana.x86" 47 | # - **Linux** (x86_64, headless): "Banana_Linux_NoVis/Banana.x86_64" 48 | 49 | """ 50 | ####################################### 51 | STEP 3: Get The Unity Environment Brain 52 | Unity ML-Agent applications or Environments contain "BRAINS" which are responsible for deciding 53 | the actions an agent or set of agents should take given a current set of environment (state) 54 | observations. The Banana environment has a single Brain; thus, we just need to access the first brain 55 | available (i.e., the default brain). We then set the default brain as the brain that will be controlled. 56 | """ 57 | # Get the default brain 58 | brain_name = env.brain_names[0] 59 | 60 | # Assign the default brain as the brain to be controlled 61 | brain = env.brains[brain_name] 62 | 63 | 64 | """ 65 | ############################################# 66 | STEP 4: Determine the size of the Action and State Spaces 67 | # 68 | # The simulation contains a single agent that navigates a large environment. 69 | # At each time step, it can perform four possible actions: 70 | # - `0` - walk forward 71 | # - `1` - walk backward 72 | # - `2` - turn left 73 | # - `3` - turn right 74 | # 75 | # The state space has `37` dimensions and contains the agent's velocity, 76 | # along with ray-based perception of objects around agent's forward direction. 77 | # A reward of `+1` is provided for collecting a yellow banana, and a reward of 78 | # `-1` is provided for collecting a purple banana. 79 | """ 80 | 81 | # Set the number of actions or action size 82 | action_size = brain.vector_action_space_size 83 | 84 | # Set the size of state observations or state size 85 | state_size = brain.vector_observation_space_size 86 | 87 | 88 | """ 89 | ################################### 90 | STEP 5: Initialize a DQN Agent from the Agent Class in dqn_agent.py 91 | A DQN agent is initialized with the following state, action and DQN hyperparameters. 92 | DQN Agent Parameters 93 | ====== 94 | state_size (int): dimension of each state (required) 95 | action_size (int): dimension of each action (required) 96 | 97 | The DQN agent specifies a local and target neural network for training. 98 | The network is defined in model.py. The input is a real (float) value vector of observations. 99 | (NOTE: not appropriate for pixel data). It is a dense, fully connected neural network, 100 | with 2 x 128 node hidden layers. The network can be modified by changing model.py.
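# A sketch of a PyTorch network consistent with the description above (vector input,
# two 128-node fully connected hidden layers, one output per action). The project's
# actual network is the QNetwork class in model.py, which may differ in detail.
import torch.nn as nn
import torch.nn.functional as F

class SketchQNetwork(nn.Module):
    def __init__(self, state_size, action_size):
        super().__init__()
        self.fc1 = nn.Linear(state_size, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_size)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.fc3(x)          # one Q-value per action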
101 | 102 | Here we initialize an agent using the Unity environment's state and action sizes determined above. 103 | We also load the model parameters from training 104 | """ 105 | # Initialize Agent 106 | agent = Agent(state_size=state_size, action_size=action_size, seed=0) 107 | 108 | # Load trained model weights 109 | agent.network.load_state_dict(torch.load('dqnAgent_Trained_Model.pth')) 110 | 111 | """ 112 | ################################### 113 | STEP 6: Play Banana for the specified number of Episodes 114 | """ 115 | # loop over num_episodes 116 | for i_episode in range(1, num_episodes+1): 117 | 118 | # reset the unity environment at the beginning of each episode 119 | # set train mode to false 120 | env_info = env.reset(train_mode=False)[brain_name] 121 | 122 | # get initial state of the unity environment 123 | state = env_info.vector_observations[0] 124 | 125 | # set the initial episode score to zero. 126 | score = 0 127 | 128 | # Run the episode loop; 129 | # At each loop step take an action as a function of the current state observations 130 | # If environment episode is done, exit loop... 131 | # Otherwise repeat until done == true 132 | while True: 133 | # determine epsilon-greedy action from current state 134 | action = agent.act(state) 135 | 136 | # send the action to the environment and receive resultant environment information 137 | env_info = env.step(action)[brain_name] 138 | 139 | next_state = env_info.vector_observations[0] # get the next state 140 | reward = env_info.rewards[0] # get the reward 141 | done = env_info.local_done[0] # see if episode has finished 142 | 143 | # set new state to current state for determining next action 144 | state = next_state 145 | 146 | # Update episode score 147 | score += reward 148 | 149 | # If unity indicates that episode is done, 150 | # then exit episode loop, to begin new episode 151 | if done: 152 | break 153 | 154 | # (Over-) Print current average score 155 | print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, score), end="") 156 | 157 | 158 | """ 159 | ################################### 160 | STEP 7: Everything is Finished -> Close the Environment. 161 | """ 162 | env.close() 163 | 164 | # END :) ############# 165 | 166 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | DQN for Unity ML-Agents Environments using PyTorch 4 | Includes examples of the following DQN training algorithms: 5 | -> Vanilla DQN, 6 | -> Double-DQN (DDQN) 7 | 8 | The example uses a modified version of the Unity ML-Agents Banana Collection Example Environment. 9 | The environment includes a single agent, who can turn left or right and move forward or backward. 10 | The agent's task is to collect yellow bananas (reward of +1) that are scattered around a square 11 | game area, while avoiding purple bananas (reward of -1). For the version of Bananas employed here, 12 | the environment is considered solved when the average score over the last 100 episodes > 13. 13 | 14 | Example Developed By: 15 | Michael Richardson, 2018 16 | Project for Udacity Nanodegree in Deep Reinforcement Learning (DRL) 17 | Code Expanded and Adapted from Code provided by Udacity DRL Team, 2018.
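# A sketch of the target computation that distinguishes vanilla DQN from Double-DQN (DDQN),
# using hypothetical q_local_net / q_target_net networks and batch tensors (next_states,
# rewards, dones, gamma); the project's actual update is implemented in dqn_agent.py and
# may differ in detail.
#   Vanilla DQN:  y = r + gamma * max_a Q_target(s', a)
#   Double-DQN:   y = r + gamma * Q_target(s', argmax_a Q_local(s', a))
q_next_dqn = q_target_net(next_states).max(dim=1, keepdim=True)[0]
best_actions = q_local_net(next_states).argmax(dim=1, keepdim=True)
q_next_ddqn = q_target_net(next_states).gather(1, best_actions)
targets = rewards + gamma * q_next_ddqn * (1 - dones)   # same combination step for either variant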
18 | """ 19 | 20 | ################################### 21 | # Import Required Packages 22 | import torch 23 | import time 24 | import random 25 | import numpy as np 26 | from collections import deque 27 | from dqn_agent import Agent 28 | from unityagents import UnityEnvironment 29 | 30 | """ 31 | ################################### 32 | STEP 1: Set the Training Parameters 33 | ====== 34 | num_episodes (int): maximum number of training episodes 35 | epsilon (float): starting value of epsilon, for epsilon-greedy action selection 36 | epsilon_min (float): minimum value of epsilon 37 | epsilon_decay (float): multiplicative factor (per episode) for decreasing epsilon 38 | scores (list): list to record the scores obtained from each episode 39 | scores_average_window (int): the window size employed for calculating the average score (e.g. 100) 40 | solved_score (float): the average score required for the environment to be considered solved 41 | (here we set the solved_score a little higher than 13 [i.e., 14] to ensure robust learning). 42 | """ 43 | num_episodes=2000 44 | epsilon=1.0 45 | epsilon_min=0.05 46 | epsilon_decay=0.99 47 | scores = [] 48 | scores_average_window = 100 49 | solved_score = 14 50 | 51 | 52 | """ 53 | ################################### 54 | STEP 2: Start the Unity Environment 55 | # Use the corresponding call depending on your operating system 56 | """ 57 | env = UnityEnvironment(file_name="Banana.app") 58 | # - **Mac**: "Banana.app" 59 | # - **Windows** (x86): "Banana_Windows_x86/Banana.exe" 60 | # - **Windows** (x86_64): "Banana_Windows_x86_64/Banana.exe" 61 | # - **Linux** (x86): "Banana_Linux/Banana.x86" 62 | # - **Linux** (x86_64): "Banana_Linux/Banana.x86_64" 63 | # - **Linux** (x86, headless): "Banana_Linux_NoVis/Banana.x86" 64 | # - **Linux** (x86_64, headless): "Banana_Linux_NoVis/Banana.x86_64" 65 | 66 | """ 67 | ####################################### 68 | STEP 3: Get The Unity Environment Brain 69 | Unity ML-Agent applications or Environments contain "BRAINS" which are responsible for deciding 70 | the actions an agent or set of agents should take given a current set of environment (state) 71 | observations. The Banana environment has a single Brain; thus, we just need to access the first brain 72 | available (i.e., the default brain). We then set the default brain as the brain that will be controlled. 73 | """ 74 | # Get the default brain 75 | brain_name = env.brain_names[0] 76 | 77 | # Assign the default brain as the brain to be controlled 78 | brain = env.brains[brain_name] 79 | 80 | 81 | """ 82 | ############################################# 83 | STEP 4: Determine the size of the Action and State Spaces 84 | # 85 | # The simulation contains a single agent that navigates a large environment. 86 | # At each time step, it can perform four possible actions: 87 | # - `0` - walk forward 88 | # - `1` - walk backward 89 | # - `2` - turn left 90 | # - `3` - turn right 91 | # 92 | # The state space has `37` dimensions and contains the agent's velocity, 93 | # along with ray-based perception of objects around agent's forward direction. 94 | # A reward of `+1` is provided for collecting a yellow banana, and a reward of 95 | # `-1` is provided for collecting a purple banana.
96 | """ 97 | 98 | # Set the number of actions or action size 99 | action_size = brain.vector_action_space_size 100 | 101 | # Set the size of state observations or state size 102 | state_size = brain.vector_observation_space_size 103 | 104 | 105 | """ 106 | ################################### 107 | STEP 5: Create a DQN Agent from the Agent Class in dqn_agent.py 108 | A DQN agent is initialized with the following state, action and DQN hyperparameters. 109 | DQN Agent Parameters 110 | ====== 111 | state_size (int): dimension of each state (required) 112 | action_size (int): dimension of each action (required) 113 | dqn_type (string): can be either 'DQN' for vanilla DQN learning (default) or 'DDQN' for double-DQN. 114 | replay_memory size (int): size of the replay memory buffer (default = 1e5) 115 | batch_size (int): size of the memory batch used for model updates (default = 64) 116 | gamma (float): parameter for setting the discounted value of future rewards (default = 0.99) 117 | learning_rate (float): specifies the rate of model learning (default = 5e-4) 118 | seed (int): random seed for initializing training point (default = 0) 119 | 120 | The DQN agent specifies a local and target neural network for training. 121 | The network is defined in model.py. The input is a real (float) value vector of observations. 122 | (NOTE: not appropriate for pixel data). It is a dense, fully connected neural network, 123 | with 2 x 128 node hidden layers. The network can be modified by changing model.py. 124 | 125 | Here we initialize an agent using the Unity environment's state and action sizes determined above 126 | and the default DQN hyperparameter settings. 127 | """ 128 | agent = Agent(state_size=state_size, action_size=action_size, dqn_type='DQN') 129 | 130 | 131 | """ 132 | ################################### 133 | STEP 6: Run the DQN Training Sequence 134 | The DQN RL Training Process involves the agent learning from repeated episodes of behaviour 135 | to map states to actions that maximize rewards received via environmental interaction. 136 | The artificial neural network is expected to converge on or approximate the optimal function 137 | that maps states to actions. 138 | 139 | The agent training process involves the following: 140 | (1) Reset the environment at the beginning of each episode. 141 | (2) Obtain (observe) current state, s, of the environment at time t 142 | (3) Use an epsilon-greedy policy to perform an action, a(t), in the environment 143 | given s(t), where the greedy action policy is specified by the neural network. 144 | (4) Observe the result of the action in terms of the reward received and 145 | the state of the environment at time t+1 (i.e., s(t+1)) 146 | (5) Calculate the error between the actual and expected Q value for s(t),a(t),r(t) and s(t+1) 147 | to update the neural network weights. 148 | (6) Update episode score (total reward received) and set s(t) -> s(t+1). 149 | (7) If episode is done, break and repeat from (1), otherwise repeat from (3). 150 | 151 | Below we also exit the training process early if the environment is solved. 152 | That is, if the average score for the previous 100 episodes is greater than solved_score.
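# A rough note on the epsilon schedule set in STEP 1 above and applied at the end of each
# training episode below (epsilon = max(epsilon_min, epsilon_decay * epsilon)):
#   after 100 episodes: 0.99**100 ~ 0.37
#   after about 300 episodes: 0.99**300 ~ 0.05, so epsilon then stays at the 0.05 floor.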
153 | """ 154 | 155 | # loop over num_episodes 156 | for i_episode in range(1, num_episodes+1): 157 | 158 | # reset the unity environment at the beginning of each episode 159 | env_info = env.reset(train_mode=True)[brain_name] 160 | 161 | # get initial state of the unity environment 162 | state = env_info.vector_observations[0] 163 | 164 | # set the initial episode score to zero. 165 | score = 0 166 | 167 | # Run the episode training loop; 168 | # At each loop step take an epsilon-greedy action as a function of the current state observations 169 | # Based on the resultant environmental state (next_state) and reward received update the Agent network 170 | # If environment episode is done, exit loop... 171 | # Otherwise repeat until done == true 172 | while True: 173 | # determine epsilon-greedy action from current state 174 | action = agent.act(state, epsilon) 175 | 176 | # send the action to the environment and receive resultant environment information 177 | env_info = env.step(action)[brain_name] 178 | 179 | next_state = env_info.vector_observations[0] # get the next state 180 | reward = env_info.rewards[0] # get the reward 181 | done = env_info.local_done[0] # see if episode has finished 182 | 183 | # Send (S, A, R, S') info to the DQN agent for a neural network update 184 | agent.step(state, action, reward, next_state, done) 185 | 186 | # set new state to current state for determining next action 187 | state = next_state 188 | 189 | # Update episode score 190 | score += reward 191 | 192 | # If unity indicates that episode is done, 193 | # then exit episode loop, to begin new episode 194 | if done: 195 | break 196 | 197 | # Add episode score to Scores and... 198 | # Calculate mean score over last 100 episodes 199 | # Mean score is calculated over current episodes until i_episode > 100 200 | scores.append(score) 201 | average_score = np.mean(scores[i_episode-min(i_episode,scores_average_window):i_episode+1]) 202 | 203 | # Decrease epsilon for epsilon-greedy policy by decay rate 204 | # Use max method to make sure epsilon doesn't decrease below epsilon_min 205 | epsilon = max(epsilon_min, epsilon_decay*epsilon) 206 | 207 | # (Over-) Print current average score 208 | print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, average_score), end="") 209 | 210 | # Print average score every scores_average_window episodes 211 | if i_episode % scores_average_window == 0: 212 | print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, average_score)) 213 | 214 | # Check to see if the task is solved (i.e., average_score > solved_score). 215 | # If yes, save the network weights and scores and end training. 216 | if average_score >= solved_score: 217 | print('\nEnvironment solved in {:d} episodes!\tAverage Score: {:.2f}'.format(i_episode, average_score)) 218 | 219 | # Save trained neural network weights 220 | timestr = time.strftime("%Y%m%d-%H%M%S") 221 | nn_filename = "dqnAgent_Trained_Model_" + timestr + ".pth" 222 | torch.save(agent.network.state_dict(), nn_filename) 223 | 224 | # Save the recorded Scores data 225 | scores_filename = "dqnAgent_scores_" + timestr + ".csv" 226 | np.savetxt(scores_filename, scores, delimiter=",") 227 | break 228 | 229 | 230 | """ 231 | ################################### 232 | STEP 7: Everything is Finished -> Close the Environment. 233 | """ 234 | env.close() 235 | 236 | # END :) ############# 237 | 238 | --------------------------------------------------------------------------------