├── .github └── workflows │ └── manual.yml ├── .gitignore ├── CODEOWNERS ├── LICENSE ├── README.md ├── cheatsheet ├── LICENSE.txt ├── README.md ├── cheatsheet.pdf ├── cheatsheet.tex └── udacity-logo.png ├── cross-entropy ├── CEM.ipynb ├── README.md └── checkpoint.pth ├── ddpg-bipedal ├── DDPG.ipynb ├── README.md ├── ddpg_agent.py └── model.py ├── ddpg-pendulum ├── DDPG.ipynb ├── README.md ├── checkpoint_actor.pth ├── checkpoint_critic.pth ├── ddpg_agent.py └── model.py ├── discretization ├── Discretization.ipynb ├── Discretization_Solution.ipynb └── README.md ├── dqn ├── README.md ├── exercise │ ├── Deep_Q_Network.ipynb │ ├── dqn_agent.py │ └── model.py └── solution │ ├── Deep_Q_Network_Solution.ipynb │ ├── checkpoint.pth │ ├── dqn_agent.py │ └── model.py ├── dynamic-programming ├── Dynamic_Programming.ipynb ├── Dynamic_Programming_Solution.ipynb ├── README.md ├── check_test.py ├── frozenlake.py └── plot_utils.py ├── finance ├── DRL.ipynb ├── ddpg_agent.py ├── model.py ├── syntheticChrissAlmgren.py ├── text_images │ ├── 4.jpeg │ ├── Actor-Critic.png │ ├── RL.png │ ├── nvidia.png │ └── udacity.png └── utils.py ├── hill-climbing ├── Hill_Climbing.ipynb └── README.md ├── lab-taxi ├── README.md ├── agent.py ├── main.py └── monitor.py ├── monte-carlo ├── Monte_Carlo.ipynb ├── Monte_Carlo_Solution.ipynb ├── README.md ├── images │ └── optimal.png └── plot_utils.py ├── p1_navigation ├── Navigation.ipynb ├── Navigation_Pixels.ipynb └── README.md ├── p2_continuous-control ├── Continuous_Control.ipynb ├── Crawler.ipynb └── README.md ├── p3_collab-compet ├── README.md ├── Soccer.ipynb └── Tennis.ipynb ├── python ├── Basics.ipynb ├── README.md ├── communicator_objects │ ├── __init__.py │ ├── agent_action_proto_pb2.py │ ├── agent_info_proto_pb2.py │ ├── brain_parameters_proto_pb2.py │ ├── brain_type_proto_pb2.py │ ├── command_proto_pb2.py │ ├── engine_configuration_proto_pb2.py │ ├── environment_parameters_proto_pb2.py │ ├── header_pb2.py │ ├── resolution_proto_pb2.py │ ├── space_type_proto_pb2.py │ ├── unity_input_pb2.py │ ├── unity_message_pb2.py │ ├── unity_output_pb2.py │ ├── unity_rl_initialization_input_pb2.py │ ├── unity_rl_initialization_output_pb2.py │ ├── unity_rl_input_pb2.py │ ├── unity_rl_output_pb2.py │ ├── unity_to_external_pb2.py │ └── unity_to_external_pb2_grpc.py ├── curricula │ ├── push.json │ ├── test.json │ └── wall.json ├── learn.py ├── requirements.txt ├── setup.py ├── tests │ ├── __init__.py │ ├── mock_communicator.py │ ├── test_bc.py │ ├── test_ppo.py │ ├── test_unityagents.py │ └── test_unitytrainers.py ├── trainer_config.yaml ├── unityagents │ ├── __init__.py │ ├── brain.py │ ├── communicator.py │ ├── curriculum.py │ ├── environment.py │ ├── exception.py │ ├── rpc_communicator.py │ └── socket_communicator.py └── unitytrainers │ ├── __init__.py │ ├── bc │ ├── __init__.py │ ├── models.py │ └── trainer.py │ ├── buffer.py │ ├── models.py │ ├── ppo │ ├── __init__.py │ ├── models.py │ └── trainer.py │ ├── trainer.py │ └── trainer_controller.py ├── reinforce ├── README.md └── REINFORCE.ipynb ├── temporal-difference ├── README.md ├── Temporal_Difference.ipynb ├── Temporal_Difference_Solution.ipynb ├── check_test.py └── plot_utils.py └── tile-coding ├── README.md ├── Tile_Coding.ipynb └── Tile_Coding_Solution.ipynb /.github/workflows/manual.yml: -------------------------------------------------------------------------------- 1 | # Workflow to ensure whenever a Github PR is submitted, 2 | # a JIRA ticket gets created automatically. 
3 | name: Manual Workflow 4 | 5 | # Controls when the action will run. 6 | on: 7 | # Triggers the workflow on pull request events but only for the master branch 8 | pull_request_target: 9 | types: [opened, reopened] 10 | 11 | # Allows you to run this workflow manually from the Actions tab 12 | workflow_dispatch: 13 | 14 | jobs: 15 | test-transition-issue: 16 | name: Convert Github Issue to Jira Issue 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@master 21 | 22 | - name: Login 23 | uses: atlassian/gajira-login@master 24 | env: 25 | JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} 26 | JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} 27 | JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} 28 | 29 | - name: Create NEW JIRA ticket 30 | id: create 31 | uses: atlassian/gajira-create@master 32 | with: 33 | project: CONUPDATE 34 | issuetype: Task 35 | summary: | 36 | Github PR [Assign the ND component] | Repo: ${{ github.repository }} | PR# ${{github.event.number}} 37 | description: | 38 | Repo link: https://github.com/${{ github.repository }} 39 | PR no. ${{ github.event.pull_request.number }} 40 | PR title: ${{ github.event.pull_request.title }} 41 | PR description: ${{ github.event.pull_request.description }} 42 | In addition, please resolve other issues, if any. 43 | fields: '{"components": [{"name":"Github PR"}], "customfield_16449":"https://classroom.udacity.com/", "customfield_16450":"Resolve the PR", "labels": ["github"], "priority":{"id": "4"}}' 44 | 45 | - name: Log created issue 46 | run: echo "Issue ${{ steps.create.outputs.issue }} was created" 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Mac OS 2 | .DS_Store 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # dotenv 86 | .env 87 | 88 | # virtualenv 89 | .venv 90 | venv/ 91 | ENV/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @udacity/active-public-content -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Udacity 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /cheatsheet/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Udacity, Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /cheatsheet/README.md: -------------------------------------------------------------------------------- 1 | # Cheatsheet 2 | 3 | You are encouraged to use the [PDF file](https://github.com/udacity/deep-reinforcement-learning/tree/master/cheatsheet/cheatsheet.pdf) in the repository to guide your study of RL. -------------------------------------------------------------------------------- /cheatsheet/cheatsheet.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/cheatsheet/cheatsheet.pdf -------------------------------------------------------------------------------- /cheatsheet/udacity-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/cheatsheet/udacity-logo.png -------------------------------------------------------------------------------- /cross-entropy/README.md: -------------------------------------------------------------------------------- 1 | [//]: # (Image References) 2 | 3 | [image1]: https://user-images.githubusercontent.com/10624937/42135605-ba0e5f2c-7d12-11e8-9578-86d74e0976f8.gif "Trained Agent" 4 | 5 | # Cross-Entropy Method 6 | 7 | ### Instructions 8 | 9 | Open `CEM.ipynb` to see an implementation of the cross-entropy method with OpenAI Gym's MountainCarContinuous environment. 10 | 11 | Try to change the parameters in the notebook, to see if you can get the agent to train faster! 12 | 13 | ### Results 14 | 15 | ![Trained Agent][image1] -------------------------------------------------------------------------------- /cross-entropy/checkpoint.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/cross-entropy/checkpoint.pth -------------------------------------------------------------------------------- /ddpg-bipedal/README.md: -------------------------------------------------------------------------------- 1 | [//]: # (Image References) 2 | 3 | [image1]: https://user-images.githubusercontent.com/10624937/42135608-be87357e-7d12-11e8-8eca-e6d5fabdba6b.gif "Trained Agent" 4 | 5 | 6 | # Actor-Critic Methods 7 | 8 | ### Instructions 9 | 10 | Open `DDPG.ipynb` to see an implementation of DDPG with OpenAI Gym's BipedalWalker environment. 11 | 12 | ### Results 13 | 14 | ![Trained Agent][image1] 15 | -------------------------------------------------------------------------------- /ddpg-bipedal/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def hidden_init(layer): 8 | fan_in = layer.weight.data.size()[0] 9 | lim = 1. 
/ np.sqrt(fan_in) 10 | return (-lim, lim) 11 | 12 | class Actor(nn.Module): 13 | """Actor (Policy) Model.""" 14 | 15 | def __init__(self, state_size, action_size, seed, fc_units=256): 16 | """Initialize parameters and build model. 17 | Params 18 | ====== 19 | state_size (int): Dimension of each state 20 | action_size (int): Dimension of each action 21 | seed (int): Random seed 22 | fc1_units (int): Number of nodes in first hidden layer 23 | fc2_units (int): Number of nodes in second hidden layer 24 | """ 25 | super(Actor, self).__init__() 26 | self.seed = torch.manual_seed(seed) 27 | self.fc1 = nn.Linear(state_size, fc_units) 28 | self.fc2 = nn.Linear(fc_units, action_size) 29 | self.reset_parameters() 30 | 31 | def reset_parameters(self): 32 | self.fc1.weight.data.uniform_(*hidden_init(self.fc1)) 33 | self.fc2.weight.data.uniform_(-3e-3, 3e-3) 34 | 35 | def forward(self, state): 36 | """Build an actor (policy) network that maps states -> actions.""" 37 | x = F.relu(self.fc1(state)) 38 | return F.tanh(self.fc2(x)) 39 | 40 | 41 | class Critic(nn.Module): 42 | """Critic (Value) Model.""" 43 | 44 | def __init__(self, state_size, action_size, seed, fcs1_units=256, fc2_units=256, fc3_units=128): 45 | """Initialize parameters and build model. 46 | Params 47 | ====== 48 | state_size (int): Dimension of each state 49 | action_size (int): Dimension of each action 50 | seed (int): Random seed 51 | fcs1_units (int): Number of nodes in the first hidden layer 52 | fc2_units (int): Number of nodes in the second hidden layer 53 | """ 54 | super(Critic, self).__init__() 55 | self.seed = torch.manual_seed(seed) 56 | self.fcs1 = nn.Linear(state_size, fcs1_units) 57 | self.fc2 = nn.Linear(fcs1_units+action_size, fc2_units) 58 | self.fc3 = nn.Linear(fc2_units, fc3_units) 59 | self.fc4 = nn.Linear(fc3_units, 1) 60 | self.reset_parameters() 61 | 62 | def reset_parameters(self): 63 | self.fcs1.weight.data.uniform_(*hidden_init(self.fcs1)) 64 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2)) 65 | self.fc3.weight.data.uniform_(*hidden_init(self.fc3)) 66 | self.fc4.weight.data.uniform_(-3e-3, 3e-3) 67 | 68 | def forward(self, state, action): 69 | """Build a critic (value) network that maps (state, action) pairs -> Q-values.""" 70 | xs = F.leaky_relu(self.fcs1(state)) 71 | x = torch.cat((xs, action), dim=1) 72 | x = F.leaky_relu(self.fc2(x)) 73 | x = F.leaky_relu(self.fc3(x)) 74 | return self.fc4(x) 75 | -------------------------------------------------------------------------------- /ddpg-pendulum/README.md: -------------------------------------------------------------------------------- 1 | [//]: # (Image References) 2 | 3 | [image1]: https://user-images.githubusercontent.com/10624937/42135610-c37e0292-7d12-11e8-8228-4d3585f8c026.gif "Trained Agent" 4 | 5 | # Actor-Critic Methods 6 | 7 | ### Instructions 8 | 9 | Open `DDPG.ipynb` to see an implementation of DDPG with OpenAI Gym's Pendulum environment. 
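For a quick sanity check outside the notebook, you can load the actor network from `model.py` and take a single step in the environment. This is only a sketch (the full training loop, replay buffer, and noise process live in `DDPG.ipynb` and `ddpg_agent.py`), and it assumes the classic `Pendulum-v0` Gym id and its torque range of [-2, 2]:

```python
import gym
import torch
from model import Actor

env = gym.make('Pendulum-v0')
state_size = env.observation_space.shape[0]    # 3 observations for Pendulum
action_size = env.action_space.shape[0]        # 1 torque value

actor = Actor(state_size, action_size, seed=0)
state = env.reset()
with torch.no_grad():
    action = actor(torch.from_numpy(state).float().unsqueeze(0)).numpy()
# the actor's tanh output lies in [-1, 1]; scale it to Pendulum's torque range
next_state, reward, done, _ = env.step(2 * action[0])
```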
10 | 11 | ### Results 12 | 13 | ![Trained Agent][image1] -------------------------------------------------------------------------------- /ddpg-pendulum/checkpoint_actor.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/ddpg-pendulum/checkpoint_actor.pth -------------------------------------------------------------------------------- /ddpg-pendulum/checkpoint_critic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/ddpg-pendulum/checkpoint_critic.pth -------------------------------------------------------------------------------- /ddpg-pendulum/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def hidden_init(layer): 8 | fan_in = layer.weight.data.size()[0] 9 | lim = 1. / np.sqrt(fan_in) 10 | return (-lim, lim) 11 | 12 | class Actor(nn.Module): 13 | """Actor (Policy) Model.""" 14 | 15 | def __init__(self, state_size, action_size, seed, fc1_units=400, fc2_units=300): 16 | """Initialize parameters and build model. 17 | Params 18 | ====== 19 | state_size (int): Dimension of each state 20 | action_size (int): Dimension of each action 21 | seed (int): Random seed 22 | fc1_units (int): Number of nodes in first hidden layer 23 | fc2_units (int): Number of nodes in second hidden layer 24 | """ 25 | super(Actor, self).__init__() 26 | self.seed = torch.manual_seed(seed) 27 | self.fc1 = nn.Linear(state_size, fc1_units) 28 | self.fc2 = nn.Linear(fc1_units, fc2_units) 29 | self.fc3 = nn.Linear(fc2_units, action_size) 30 | self.reset_parameters() 31 | 32 | def reset_parameters(self): 33 | self.fc1.weight.data.uniform_(*hidden_init(self.fc1)) 34 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2)) 35 | self.fc3.weight.data.uniform_(-3e-3, 3e-3) 36 | 37 | def forward(self, state): 38 | """Build an actor (policy) network that maps states -> actions.""" 39 | x = F.relu(self.fc1(state)) 40 | x = F.relu(self.fc2(x)) 41 | return F.tanh(self.fc3(x)) 42 | 43 | 44 | class Critic(nn.Module): 45 | """Critic (Value) Model.""" 46 | 47 | def __init__(self, state_size, action_size, seed, fcs1_units=400, fc2_units=300): 48 | """Initialize parameters and build model. 
49 | Params 50 | ====== 51 | state_size (int): Dimension of each state 52 | action_size (int): Dimension of each action 53 | seed (int): Random seed 54 | fcs1_units (int): Number of nodes in the first hidden layer 55 | fc2_units (int): Number of nodes in the second hidden layer 56 | """ 57 | super(Critic, self).__init__() 58 | self.seed = torch.manual_seed(seed) 59 | self.fcs1 = nn.Linear(state_size, fcs1_units) 60 | self.fc2 = nn.Linear(fcs1_units+action_size, fc2_units) 61 | self.fc3 = nn.Linear(fc2_units, 1) 62 | self.reset_parameters() 63 | 64 | def reset_parameters(self): 65 | self.fcs1.weight.data.uniform_(*hidden_init(self.fcs1)) 66 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2)) 67 | self.fc3.weight.data.uniform_(-3e-3, 3e-3) 68 | 69 | def forward(self, state, action): 70 | """Build a critic (value) network that maps (state, action) pairs -> Q-values.""" 71 | xs = F.relu(self.fcs1(state)) 72 | x = torch.cat((xs, action), dim=1) 73 | x = F.relu(self.fc2(x)) 74 | return self.fc3(x) 75 | -------------------------------------------------------------------------------- /discretization/README.md: -------------------------------------------------------------------------------- 1 | [//]: # (Image References) 2 | 3 | [image1]: https://user-images.githubusercontent.com/10624937/42135605-ba0e5f2c-7d12-11e8-9578-86d74e0976f8.gif "Trained Agent" 4 | 5 | # Discretization 6 | 7 | ### Instructions 8 | 9 | Follow the instructions in `Discretization.ipynb` to learn how to discretize continuous state spaces, to use tabular solution methods to solve complex tasks. The corresponding solutions can be found in `Discretization_Solution.ipynb`. 10 | 11 | ### Results 12 | 13 | ![Trained Agent][image1] 14 | 15 | ### Resources 16 | 17 | To learn about more advanced discretization approaches, refer to the following: 18 | 19 | - Uther, W., and Veloso, M., 1998. [Tree Based Discretization for Continuous State Space Reinforcement Learning](http://www.cs.cmu.edu/~mmv/papers/will-aaai98.pdf). In _Proceedings of AAAI, 1998_, pp. 769-774. 20 | - Munos, R. and Moore, A., 2002. [Variable Resolution Discretization in Optimal Control](https://link.springer.com/content/pdf/10.1023%2FA%3A1017992615625.pdf). In _Machine Learning_, 49(2), pp. 291-323. 21 | -------------------------------------------------------------------------------- /dqn/README.md: -------------------------------------------------------------------------------- 1 | [//]: # (Image References) 2 | 3 | [image1]: https://user-images.githubusercontent.com/10624937/42135612-cbff24aa-7d12-11e8-9b6c-2b41e64b3bb0.gif "Trained Agent" 4 | 5 | # Deep Q-Network (DQN) 6 | 7 | ### Instructions 8 | 9 | In this exercise, you will implement Deep Q-Learning to solve OpenAI Gym's LunarLander environment. To begin, navigate to the `exercise/` folder, and follow the instructions in `Deep_Q_Network.ipynb`. 10 | 11 | (_Alternatively, if you'd prefer to explore a complete implementation, enter the `solution/` folder, and run the code in `Deep_Q_Network_Solution.ipynb`._) 12 | 13 | After you are able to get the code working, try to change the parameters in the notebook, to see if you can get the agent to train faster! You may also like to implement prioritized experience replay, or use it as a starting point to implement a Double DQN or Dueling DQN! 
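If you tackle the Double DQN extension, the only change is how the target value is formed in the agent's `learn()` method: the local network picks the greedy action, and the target network evaluates it. A hedged sketch, reusing the variable names from `dqn_agent.py`:

```python
# Double DQN target (sketch): select with qnetwork_local, evaluate with qnetwork_target
best_actions = self.qnetwork_local(next_states).detach().argmax(1).unsqueeze(1)
Q_targets_next = self.qnetwork_target(next_states).detach().gather(1, best_actions)
Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
```

The rest of the update (computing `Q_expected`, the MSE loss, and the soft update of the target network) stays the same.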
14 | 15 | ### Results 16 | 17 | ![Trained Agent][image1] 18 | 19 | ### Resources 20 | 21 | - [Human-Level Control through Deep Reinforcement Learning](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) 22 | - [Deep Reinforcement Learning with Double Q-Learning](https://arxiv.org/abs/1509.06461) 23 | - [Dueling Network Architectures for Deep Reinforcement Learning](https://arxiv.org/abs/1511.06581) 24 | - [Prioritized Experience Replay](https://arxiv.org/abs/1511.05952) 25 | -------------------------------------------------------------------------------- /dqn/exercise/dqn_agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | from collections import namedtuple, deque 4 | 5 | from model import QNetwork 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | 11 | BUFFER_SIZE = int(1e5) # replay buffer size 12 | BATCH_SIZE = 64 # minibatch size 13 | GAMMA = 0.99 # discount factor 14 | TAU = 1e-3 # for soft update of target parameters 15 | LR = 5e-4 # learning rate 16 | UPDATE_EVERY = 4 # how often to update the network 17 | 18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 19 | 20 | class Agent(): 21 | """Interacts with and learns from the environment.""" 22 | 23 | def __init__(self, state_size, action_size, seed): 24 | """Initialize an Agent object. 25 | 26 | Params 27 | ====== 28 | state_size (int): dimension of each state 29 | action_size (int): dimension of each action 30 | seed (int): random seed 31 | """ 32 | self.state_size = state_size 33 | self.action_size = action_size 34 | self.seed = random.seed(seed) 35 | 36 | # Q-Network 37 | self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device) 38 | self.qnetwork_target = QNetwork(state_size, action_size, seed).to(device) 39 | self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR) 40 | 41 | # Replay memory 42 | self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) 43 | # Initialize time step (for updating every UPDATE_EVERY steps) 44 | self.t_step = 0 45 | 46 | def step(self, state, action, reward, next_state, done): 47 | # Save experience in replay memory 48 | self.memory.add(state, action, reward, next_state, done) 49 | 50 | # Learn every UPDATE_EVERY time steps. 51 | self.t_step = (self.t_step + 1) % UPDATE_EVERY 52 | if self.t_step == 0: 53 | # If enough samples are available in memory, get random subset and learn 54 | if len(self.memory) > BATCH_SIZE: 55 | experiences = self.memory.sample() 56 | self.learn(experiences, GAMMA) 57 | 58 | def act(self, state, eps=0.): 59 | """Returns actions for given state as per current policy. 60 | 61 | Params 62 | ====== 63 | state (array_like): current state 64 | eps (float): epsilon, for epsilon-greedy action selection 65 | """ 66 | state = torch.from_numpy(state).float().unsqueeze(0).to(device) 67 | self.qnetwork_local.eval() 68 | with torch.no_grad(): 69 | action_values = self.qnetwork_local(state) 70 | self.qnetwork_local.train() 71 | 72 | # Epsilon-greedy action selection 73 | if random.random() > eps: 74 | return np.argmax(action_values.cpu().data.numpy()) 75 | else: 76 | return random.choice(np.arange(self.action_size)) 77 | 78 | def learn(self, experiences, gamma): 79 | """Update value parameters using given batch of experience tuples. 
80 | 81 | Params 82 | ====== 83 | experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples 84 | gamma (float): discount factor 85 | """ 86 | states, actions, rewards, next_states, dones = experiences 87 | 88 | ## TODO: compute and minimize the loss 89 | "*** YOUR CODE HERE ***" 90 | 91 | # ------------------- update target network ------------------- # 92 | self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU) 93 | 94 | def soft_update(self, local_model, target_model, tau): 95 | """Soft update model parameters. 96 | θ_target = τ*θ_local + (1 - τ)*θ_target 97 | 98 | Params 99 | ====== 100 | local_model (PyTorch model): weights will be copied from 101 | target_model (PyTorch model): weights will be copied to 102 | tau (float): interpolation parameter 103 | """ 104 | for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): 105 | target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data) 106 | 107 | 108 | class ReplayBuffer: 109 | """Fixed-size buffer to store experience tuples.""" 110 | 111 | def __init__(self, action_size, buffer_size, batch_size, seed): 112 | """Initialize a ReplayBuffer object. 113 | 114 | Params 115 | ====== 116 | action_size (int): dimension of each action 117 | buffer_size (int): maximum size of buffer 118 | batch_size (int): size of each training batch 119 | seed (int): random seed 120 | """ 121 | self.action_size = action_size 122 | self.memory = deque(maxlen=buffer_size) 123 | self.batch_size = batch_size 124 | self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) 125 | self.seed = random.seed(seed) 126 | 127 | def add(self, state, action, reward, next_state, done): 128 | """Add a new experience to memory.""" 129 | e = self.experience(state, action, reward, next_state, done) 130 | self.memory.append(e) 131 | 132 | def sample(self): 133 | """Randomly sample a batch of experiences from memory.""" 134 | experiences = random.sample(self.memory, k=self.batch_size) 135 | 136 | states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device) 137 | actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device) 138 | rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device) 139 | next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device) 140 | dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device) 141 | 142 | return (states, actions, rewards, next_states, dones) 143 | 144 | def __len__(self): 145 | """Return the current size of internal memory.""" 146 | return len(self.memory) -------------------------------------------------------------------------------- /dqn/exercise/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class QNetwork(nn.Module): 6 | """Actor (Policy) Model.""" 7 | 8 | def __init__(self, state_size, action_size, seed): 9 | """Initialize parameters and build model. 
10 | Params 11 | ====== 12 | state_size (int): Dimension of each state 13 | action_size (int): Dimension of each action 14 | seed (int): Random seed 15 | """ 16 | super(QNetwork, self).__init__() 17 | self.seed = torch.manual_seed(seed) 18 | "*** YOUR CODE HERE ***" 19 | 20 | def forward(self, state): 21 | """Build a network that maps state -> action values.""" 22 | pass 23 | -------------------------------------------------------------------------------- /dqn/solution/checkpoint.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/dqn/solution/checkpoint.pth -------------------------------------------------------------------------------- /dqn/solution/dqn_agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | from collections import namedtuple, deque 4 | 5 | from model import QNetwork 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | 11 | BUFFER_SIZE = int(1e5) # replay buffer size 12 | BATCH_SIZE = 64 # minibatch size 13 | GAMMA = 0.99 # discount factor 14 | TAU = 1e-3 # for soft update of target parameters 15 | LR = 5e-4 # learning rate 16 | UPDATE_EVERY = 4 # how often to update the network 17 | 18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 19 | 20 | class Agent(): 21 | """Interacts with and learns from the environment.""" 22 | 23 | def __init__(self, state_size, action_size, seed): 24 | """Initialize an Agent object. 25 | 26 | Params 27 | ====== 28 | state_size (int): dimension of each state 29 | action_size (int): dimension of each action 30 | seed (int): random seed 31 | """ 32 | self.state_size = state_size 33 | self.action_size = action_size 34 | self.seed = random.seed(seed) 35 | 36 | # Q-Network 37 | self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device) 38 | self.qnetwork_target = QNetwork(state_size, action_size, seed).to(device) 39 | self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR) 40 | 41 | # Replay memory 42 | self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) 43 | # Initialize time step (for updating every UPDATE_EVERY steps) 44 | self.t_step = 0 45 | 46 | def step(self, state, action, reward, next_state, done): 47 | # Save experience in replay memory 48 | self.memory.add(state, action, reward, next_state, done) 49 | 50 | # Learn every UPDATE_EVERY time steps. 51 | self.t_step = (self.t_step + 1) % UPDATE_EVERY 52 | if self.t_step == 0: 53 | # If enough samples are available in memory, get random subset and learn 54 | if len(self.memory) > BATCH_SIZE: 55 | experiences = self.memory.sample() 56 | self.learn(experiences, GAMMA) 57 | 58 | def act(self, state, eps=0.): 59 | """Returns actions for given state as per current policy. 
60 | 61 | Params 62 | ====== 63 | state (array_like): current state 64 | eps (float): epsilon, for epsilon-greedy action selection 65 | """ 66 | state = torch.from_numpy(state).float().unsqueeze(0).to(device) 67 | self.qnetwork_local.eval() 68 | with torch.no_grad(): 69 | action_values = self.qnetwork_local(state) 70 | self.qnetwork_local.train() 71 | 72 | # Epsilon-greedy action selection 73 | if random.random() > eps: 74 | return np.argmax(action_values.cpu().data.numpy()) 75 | else: 76 | return random.choice(np.arange(self.action_size)) 77 | 78 | def learn(self, experiences, gamma): 79 | """Update value parameters using given batch of experience tuples. 80 | 81 | Params 82 | ====== 83 | experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples 84 | gamma (float): discount factor 85 | """ 86 | states, actions, rewards, next_states, dones = experiences 87 | 88 | # Get max predicted Q values (for next states) from target model 89 | Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1) 90 | # Compute Q targets for current states 91 | Q_targets = rewards + (gamma * Q_targets_next * (1 - dones)) 92 | 93 | # Get expected Q values from local model 94 | Q_expected = self.qnetwork_local(states).gather(1, actions) 95 | 96 | # Compute loss 97 | loss = F.mse_loss(Q_expected, Q_targets) 98 | # Minimize the loss 99 | self.optimizer.zero_grad() 100 | loss.backward() 101 | self.optimizer.step() 102 | 103 | # ------------------- update target network ------------------- # 104 | self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU) 105 | 106 | def soft_update(self, local_model, target_model, tau): 107 | """Soft update model parameters. 108 | θ_target = τ*θ_local + (1 - τ)*θ_target 109 | 110 | Params 111 | ====== 112 | local_model (PyTorch model): weights will be copied from 113 | target_model (PyTorch model): weights will be copied to 114 | tau (float): interpolation parameter 115 | """ 116 | for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): 117 | target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data) 118 | 119 | 120 | class ReplayBuffer: 121 | """Fixed-size buffer to store experience tuples.""" 122 | 123 | def __init__(self, action_size, buffer_size, batch_size, seed): 124 | """Initialize a ReplayBuffer object. 
125 | 126 | Params 127 | ====== 128 | action_size (int): dimension of each action 129 | buffer_size (int): maximum size of buffer 130 | batch_size (int): size of each training batch 131 | seed (int): random seed 132 | """ 133 | self.action_size = action_size 134 | self.memory = deque(maxlen=buffer_size) 135 | self.batch_size = batch_size 136 | self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) 137 | self.seed = random.seed(seed) 138 | 139 | def add(self, state, action, reward, next_state, done): 140 | """Add a new experience to memory.""" 141 | e = self.experience(state, action, reward, next_state, done) 142 | self.memory.append(e) 143 | 144 | def sample(self): 145 | """Randomly sample a batch of experiences from memory.""" 146 | experiences = random.sample(self.memory, k=self.batch_size) 147 | 148 | states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device) 149 | actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device) 150 | rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device) 151 | next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device) 152 | dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device) 153 | 154 | return (states, actions, rewards, next_states, dones) 155 | 156 | def __len__(self): 157 | """Return the current size of internal memory.""" 158 | return len(self.memory) -------------------------------------------------------------------------------- /dqn/solution/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class QNetwork(nn.Module): 6 | """Actor (Policy) Model.""" 7 | 8 | def __init__(self, state_size, action_size, seed, fc1_units=64, fc2_units=64): 9 | """Initialize parameters and build model. 10 | Params 11 | ====== 12 | state_size (int): Dimension of each state 13 | action_size (int): Dimension of each action 14 | seed (int): Random seed 15 | fc1_units (int): Number of nodes in first hidden layer 16 | fc2_units (int): Number of nodes in second hidden layer 17 | """ 18 | super(QNetwork, self).__init__() 19 | self.seed = torch.manual_seed(seed) 20 | self.fc1 = nn.Linear(state_size, fc1_units) 21 | self.fc2 = nn.Linear(fc1_units, fc2_units) 22 | self.fc3 = nn.Linear(fc2_units, action_size) 23 | 24 | def forward(self, state): 25 | """Build a network that maps state -> action values.""" 26 | x = F.relu(self.fc1(state)) 27 | x = F.relu(self.fc2(x)) 28 | return self.fc3(x) 29 | -------------------------------------------------------------------------------- /dynamic-programming/README.md: -------------------------------------------------------------------------------- 1 | # Dynamic Programming 2 | 3 | ### Instructions 4 | 5 | Follow the instructions in `Dynamic_Programming.ipynb` to write your own implementations of many dynamic programming algorithms! The corresponding solutions can be found in `Dynamic_Programming_Solution.ipynb`. 
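All of the algorithms in the notebook consume the one-step dynamics exposed by the `FrozenLakeEnv` in `frozenlake.py` (included in this folder). A quick way to inspect them before you start:

```python
from frozenlake import FrozenLakeEnv

env = FrozenLakeEnv()
print(env.nS, env.nA)   # 16 states, 4 actions on the default 4x4 map
# env.P[s][a] is a list of (prob, next_state, reward, done) tuples
print(env.P[0][1])      # slippery transitions for action "down" from the start state
```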
6 | -------------------------------------------------------------------------------- /dynamic-programming/check_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import copy 3 | from IPython.display import Markdown, display 4 | import numpy as np 5 | from frozenlake import FrozenLakeEnv 6 | 7 | def printmd(string): 8 | display(Markdown(string)) 9 | 10 | def policy_evaluation_soln(env, policy, gamma=1, theta=1e-8): 11 | V = np.zeros(env.nS) 12 | while True: 13 | delta = 0 14 | for s in range(env.nS): 15 | Vs = 0 16 | for a, action_prob in enumerate(policy[s]): 17 | for prob, next_state, reward, done in env.P[s][a]: 18 | Vs += action_prob * prob * (reward + gamma * V[next_state]) 19 | delta = max(delta, np.abs(V[s]-Vs)) 20 | V[s] = Vs 21 | if delta < theta: 22 | break 23 | return V 24 | 25 | def q_from_v_soln(env, V, s, gamma=1): 26 | q = np.zeros(env.nA) 27 | for a in range(env.nA): 28 | for prob, next_state, reward, done in env.P[s][a]: 29 | q[a] += prob * (reward + gamma * V[next_state]) 30 | return q 31 | 32 | def policy_improvement_soln(env, V, gamma=1): 33 | policy = np.zeros([env.nS, env.nA]) / env.nA 34 | for s in range(env.nS): 35 | q = q_from_v_soln(env, V, s, gamma) 36 | best_a = np.argwhere(q==np.max(q)).flatten() 37 | policy[s] = np.sum([np.eye(env.nA)[i] for i in best_a], axis=0)/len(best_a) 38 | return policy 39 | 40 | def policy_iteration_soln(env, gamma=1, theta=1e-8): 41 | policy = np.ones([env.nS, env.nA]) / env.nA 42 | while True: 43 | V = policy_evaluation_soln(env, policy, gamma, theta) 44 | new_policy = policy_improvement_soln(env, V) 45 | if (new_policy == policy).all(): 46 | break; 47 | policy = copy.copy(new_policy) 48 | return policy, V 49 | 50 | env = FrozenLakeEnv() 51 | random_policy = np.ones([env.nS, env.nA]) / env.nA 52 | 53 | class Tests(unittest.TestCase): 54 | 55 | def policy_evaluation_check(self, policy_evaluation): 56 | soln = policy_evaluation_soln(env, random_policy) 57 | to_check = policy_evaluation(env, random_policy) 58 | np.testing.assert_array_almost_equal(soln, to_check) 59 | 60 | def q_from_v_check(self, q_from_v): 61 | V = policy_evaluation_soln(env, random_policy) 62 | soln = np.zeros([env.nS, env.nA]) 63 | to_check = np.zeros([env.nS, env.nA]) 64 | for s in range(env.nS): 65 | soln[s] = q_from_v_soln(env, V, s) 66 | to_check[s] = q_from_v(env, V, s) 67 | np.testing.assert_array_almost_equal(soln, to_check) 68 | 69 | def policy_improvement_check(self, policy_improvement): 70 | V = policy_evaluation_soln(env, random_policy) 71 | new_policy = policy_improvement(env, V) 72 | new_V = policy_evaluation_soln(env, new_policy) 73 | self.assertTrue(np.all(new_V >= V)) 74 | 75 | def policy_iteration_check(self, policy_iteration): 76 | policy_soln, _ = policy_iteration_soln(env) 77 | policy_to_check, _ = policy_iteration(env) 78 | soln = policy_evaluation_soln(env, policy_soln) 79 | to_check = policy_evaluation_soln(env, policy_to_check) 80 | np.testing.assert_array_almost_equal(soln, to_check) 81 | 82 | def truncated_policy_iteration_check(self, truncated_policy_iteration): 83 | self.policy_iteration_check(truncated_policy_iteration) 84 | 85 | def value_iteration_check(self, value_iteration): 86 | self.policy_iteration_check(value_iteration) 87 | 88 | check = Tests() 89 | 90 | def run_check(check_name, func): 91 | try: 92 | getattr(check, check_name)(func) 93 | except check.failureException as e: 94 | printmd('**PLEASE TRY AGAIN**') 95 | return 96 | printmd('**PASSED**') 
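# Typical usage from the notebook (a sketch): pass your implementation to run_check, e.g.
#   from check_test import run_check
#   run_check('policy_evaluation_check', policy_evaluation)
# run_check prints **PASSED** if the corresponding unit test succeeds,
# and **PLEASE TRY AGAIN** otherwise.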
-------------------------------------------------------------------------------- /dynamic-programming/frozenlake.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | from six import StringIO, b 4 | 5 | from gym import utils 6 | from gym.envs.toy_text import discrete 7 | 8 | LEFT = 0 9 | DOWN = 1 10 | RIGHT = 2 11 | UP = 3 12 | 13 | MAPS = { 14 | "4x4": [ 15 | "SFFF", 16 | "FHFH", 17 | "FFFH", 18 | "HFFG" 19 | ], 20 | "8x8": [ 21 | "SFFFFFFF", 22 | "FFFFFFFF", 23 | "FFFHFFFF", 24 | "FFFFFHFF", 25 | "FFFHFFFF", 26 | "FHHFFFHF", 27 | "FHFFHFHF", 28 | "FFFHFFFG" 29 | ], 30 | } 31 | 32 | class FrozenLakeEnv(discrete.DiscreteEnv): 33 | """ 34 | Winter is here. You and your friends were tossing around a frisbee at the park 35 | when you made a wild throw that left the frisbee out in the middle of the lake. 36 | The water is mostly frozen, but there are a few holes where the ice has melted. 37 | If you step into one of those holes, you'll fall into the freezing water. 38 | At this time, there's an international frisbee shortage, so it's absolutely imperative that 39 | you navigate across the lake and retrieve the disc. 40 | However, the ice is slippery, so you won't always move in the direction you intend. 41 | The surface is described using a grid like the following 42 | 43 | SFFF 44 | FHFH 45 | FFFH 46 | HFFG 47 | 48 | S : starting point, safe 49 | F : frozen surface, safe 50 | H : hole, fall to your doom 51 | G : goal, where the frisbee is located 52 | 53 | The episode ends when you reach the goal or fall in a hole. 54 | You receive a reward of 1 if you reach the goal, and zero otherwise. 55 | 56 | """ 57 | 58 | metadata = {'render.modes': ['human', 'ansi']} 59 | 60 | def __init__(self, desc=None, map_name="4x4",is_slippery=True): 61 | if desc is None and map_name is None: 62 | raise ValueError('Must provide either desc or map_name') 63 | elif desc is None: 64 | desc = MAPS[map_name] 65 | self.desc = desc = np.asarray(desc,dtype='c') 66 | self.nrow, self.ncol = nrow, ncol = desc.shape 67 | 68 | nA = 4 69 | nS = nrow * ncol 70 | 71 | isd = np.array(desc == b'S').astype('float64').ravel() 72 | isd /= isd.sum() 73 | 74 | P = {s : {a : [] for a in range(nA)} for s in range(nS)} 75 | 76 | def to_s(row, col): 77 | return row*ncol + col 78 | def inc(row, col, a): 79 | if a==0: # left 80 | col = max(col-1,0) 81 | elif a==1: # down 82 | row = min(row+1,nrow-1) 83 | elif a==2: # right 84 | col = min(col+1,ncol-1) 85 | elif a==3: # up 86 | row = max(row-1,0) 87 | return (row, col) 88 | 89 | for row in range(nrow): 90 | for col in range(ncol): 91 | s = to_s(row, col) 92 | for a in range(4): 93 | li = P[s][a] 94 | letter = desc[row, col] 95 | if letter in b'GH': 96 | li.append((1.0, s, 0, True)) 97 | else: 98 | if is_slippery: 99 | for b in [(a-1)%4, a, (a+1)%4]: 100 | newrow, newcol = inc(row, col, b) 101 | newstate = to_s(newrow, newcol) 102 | newletter = desc[newrow, newcol] 103 | done = bytes(newletter) in b'GH' 104 | rew = float(newletter == b'G') 105 | li.append((1.0/3.0, newstate, rew, done)) 106 | else: 107 | newrow, newcol = inc(row, col, a) 108 | newstate = to_s(newrow, newcol) 109 | newletter = desc[newrow, newcol] 110 | done = bytes(newletter) in b'GH' 111 | rew = float(newletter == b'G') 112 | li.append((1.0, newstate, rew, done)) 113 | 114 | # obtain one-step dynamics for dynamic programming setting 115 | self.P = P 116 | 117 | super(FrozenLakeEnv, self).__init__(nS, nA, P, isd) 118 | 119 | def _render(self, 
mode='human', close=False): 120 | if close: 121 | return 122 | outfile = StringIO() if mode == 'ansi' else sys.stdout 123 | 124 | row, col = self.s // self.ncol, self.s % self.ncol 125 | desc = self.desc.tolist() 126 | desc = [[c.decode('utf-8') for c in line] for line in desc] 127 | desc[row][col] = utils.colorize(desc[row][col], "red", highlight=True) 128 | if self.lastaction is not None: 129 | outfile.write(" ({})\n".format(["Left","Down","Right","Up"][self.lastaction])) 130 | else: 131 | outfile.write("\n") 132 | outfile.write("\n".join(''.join(line) for line in desc)+"\n") 133 | 134 | if mode != 'human': 135 | return outfile 136 | -------------------------------------------------------------------------------- /dynamic-programming/plot_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | def plot_values(V): 5 | # reshape value function 6 | V_sq = np.reshape(V, (4,4)) 7 | 8 | # plot the state-value function 9 | fig = plt.figure(figsize=(6, 6)) 10 | ax = fig.add_subplot(111) 11 | im = ax.imshow(V_sq, cmap='cool') 12 | for (j,i),label in np.ndenumerate(V_sq): 13 | ax.text(i, j, np.round(label, 5), ha='center', va='center', fontsize=14) 14 | plt.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False) 15 | plt.title('State-Value Function') 16 | plt.show() 17 | -------------------------------------------------------------------------------- /finance/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def hidden_init(layer): 8 | fan_in = layer.weight.data.size()[0] 9 | lim = 1. / np.sqrt(fan_in) 10 | return (-lim, lim) 11 | 12 | class Actor(nn.Module): 13 | """Actor (Policy) Model.""" 14 | 15 | def __init__(self, state_size, action_size, seed, fc1_units=24, fc2_units=48): 16 | """Initialize parameters and build model. 17 | Params 18 | ====== 19 | state_size (int): Dimension of each state 20 | action_size (int): Dimension of each action 21 | seed (int): Random seed 22 | fc1_units (int): Number of nodes in first hidden layer 23 | fc2_units (int): Number of nodes in second hidden layer 24 | """ 25 | super(Actor, self).__init__() 26 | self.seed = torch.manual_seed(seed) 27 | self.fc1 = nn.Linear(state_size, fc1_units) 28 | self.fc2 = nn.Linear(fc1_units, fc2_units) 29 | self.fc3 = nn.Linear(fc2_units, action_size) 30 | self.reset_parameters() 31 | 32 | def reset_parameters(self): 33 | self.fc1.weight.data.uniform_(*hidden_init(self.fc1)) 34 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2)) 35 | self.fc3.weight.data.uniform_(-3e-3, 3e-3) 36 | 37 | def forward(self, state): 38 | """Build an actor (policy) network that maps states -> actions.""" 39 | x = F.relu(self.fc1(state)) 40 | x = F.relu(self.fc2(x)) 41 | return F.tanh(self.fc3(x)) 42 | 43 | 44 | class Critic(nn.Module): 45 | """Critic (Value) Model.""" 46 | 47 | def __init__(self, state_size, action_size, seed, fcs1_units=24, fc2_units=48): 48 | """Initialize parameters and build model. 
49 | Params 50 | ====== 51 | state_size (int): Dimension of each state 52 | action_size (int): Dimension of each action 53 | seed (int): Random seed 54 | fcs1_units (int): Number of nodes in the first hidden layer 55 | fc2_units (int): Number of nodes in the second hidden layer 56 | """ 57 | super(Critic, self).__init__() 58 | self.seed = torch.manual_seed(seed) 59 | self.fcs1 = nn.Linear(state_size, fcs1_units) 60 | self.fc2 = nn.Linear(fcs1_units+action_size, fc2_units) 61 | self.fc3 = nn.Linear(fc2_units, 1) 62 | self.reset_parameters() 63 | 64 | def reset_parameters(self): 65 | self.fcs1.weight.data.uniform_(*hidden_init(self.fcs1)) 66 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2)) 67 | self.fc3.weight.data.uniform_(-3e-3, 3e-3) 68 | 69 | def forward(self, state, action): 70 | """Build a critic (value) network that maps (state, action) pairs -> Q-values.""" 71 | xs = F.relu(self.fcs1(state)) 72 | x = torch.cat((xs, action), dim=1) 73 | x = F.relu(self.fc2(x)) 74 | return self.fc3(x) 75 | -------------------------------------------------------------------------------- /finance/text_images/4.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/finance/text_images/4.jpeg -------------------------------------------------------------------------------- /finance/text_images/Actor-Critic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/finance/text_images/Actor-Critic.png -------------------------------------------------------------------------------- /finance/text_images/RL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/finance/text_images/RL.png -------------------------------------------------------------------------------- /finance/text_images/nvidia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/finance/text_images/nvidia.png -------------------------------------------------------------------------------- /finance/text_images/udacity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/finance/text_images/udacity.png -------------------------------------------------------------------------------- /hill-climbing/README.md: -------------------------------------------------------------------------------- 1 | [//]: # (Image References) 2 | 3 | [image1]: https://user-images.githubusercontent.com/10624937/42135683-dde5c6f0-7d13-11e8-90b1-8770df3e40cf.gif "Trained Agent" 4 | 5 | # Hill Climbing 6 | 7 | ### Instructions 8 | 9 | Open `Hill_Climbing.ipynb` to see an implementation of hill climbing with adaptive noise scaling with OpenAI Gym's Cartpole environment. 10 | 11 | Try to change the parameters in the notebook, to see if you can get the agent to train faster! 
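The core of the algorithm fits in a few lines. Below is a sketch of one update with adaptive noise scaling; the variable names (`best_return`, `best_weights`, `noise_scale`, `candidate_weights`) are illustrative, since the notebook defines its own, and `np` is assumed to be `numpy`:

```python
# one hill-climbing step with adaptive noise scaling (illustrative names)
if episode_return >= best_return:
    # the perturbed policy did at least as well: keep it and search more locally
    best_return, best_weights = episode_return, candidate_weights
    noise_scale = max(1e-3, noise_scale / 2)
else:
    # the perturbed policy did worse: discard it and search more broadly
    noise_scale = min(2.0, noise_scale * 2)
candidate_weights = best_weights + noise_scale * np.random.rand(*best_weights.shape)
```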
12 | 13 | ### Results 14 | 15 | ![Trained Agent][image1] -------------------------------------------------------------------------------- /lab-taxi/README.md: -------------------------------------------------------------------------------- 1 | # Taxi Problem 2 | 3 | ### Getting Started 4 | 5 | Read the description of the environment in subsection 3.1 of [this paper](https://arxiv.org/pdf/cs/9905014.pdf). You can verify that the description in the paper matches the OpenAI Gym environment by peeking at the code [here](https://github.com/openai/gym/blob/master/gym/envs/toy_text/taxi.py). 6 | 7 | 8 | ### Instructions 9 | 10 | The repository contains three files: 11 | - `agent.py`: Develop your reinforcement learning agent here. This is the only file that you should modify. 12 | - `monitor.py`: The `interact` function tests how well your agent learns from interaction with the environment. 13 | - `main.py`: Run this file in the terminal to check the performance of your agent. 14 | 15 | Begin by running the following command in the terminal: 16 | ``` 17 | python main.py 18 | ``` 19 | 20 | When you run `main.py`, the agent that you specify in `agent.py` interacts with the environment for 20,000 episodes. The details of the interaction are specified in `monitor.py`, which returns two variables: `avg_rewards` and `best_avg_reward`. 21 | - `avg_rewards` is a deque where `avg_rewards[i]` is the average (undiscounted) return collected by the agent from episodes `i+1` to episode `i+100`, inclusive. So, for instance, `avg_rewards[0]` is the average return collected by the agent over the first 100 episodes. 22 | - `best_avg_reward` is the largest entry in `avg_rewards`. This is the final score that you should use when determining how well your agent performed in the task. 23 | 24 | Your assignment is to modify the `agents.py` file to improve the agent's performance. 25 | - Use the `__init__()` method to define any needed instance variables. Currently, we define the number of actions available to the agent (`nA`) and initialize the action values (`Q`) to an empty dictionary of arrays. Feel free to add more instance variables; for example, you may find it useful to define the value of epsilon if the agent uses an epsilon-greedy policy for selecting actions. 26 | - The `select_action()` method accepts the environment state as input and returns the agent's choice of action. The default code that we have provided randomly selects an action. 27 | - The `step()` method accepts a (`state`, `action`, `reward`, `next_state`) tuple as input, along with the `done` variable, which is `True` if the episode has ended. The default code (which you should certainly change!) increments the action value of the previous state-action pair by 1. You should change this method to use the sampled tuple of experience to update the agent's knowledge of the problem. 28 | 29 | Once you have modified the function, you need only run `python main.py` to test your new agent. 30 | 31 | OpenAI Gym [defines "solving"](https://gym.openai.com/envs/Taxi-v1/) this task as getting average return of 9.7 over 100 consecutive trials. 32 | -------------------------------------------------------------------------------- /lab-taxi/agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | 4 | class Agent: 5 | 6 | def __init__(self, nA=6): 7 | """ Initialize agent. 
8 | 9 | Params 10 | ====== 11 | - nA: number of actions available to the agent 12 | """ 13 | self.nA = nA 14 | self.Q = defaultdict(lambda: np.zeros(self.nA)) 15 | 16 | def select_action(self, state): 17 | """ Given the state, select an action. 18 | 19 | Params 20 | ====== 21 | - state: the current state of the environment 22 | 23 | Returns 24 | ======= 25 | - action: an integer, compatible with the task's action space 26 | """ 27 | return np.random.choice(self.nA) 28 | 29 | def step(self, state, action, reward, next_state, done): 30 | """ Update the agent's knowledge, using the most recently sampled tuple. 31 | 32 | Params 33 | ====== 34 | - state: the previous state of the environment 35 | - action: the agent's previous choice of action 36 | - reward: last reward received 37 | - next_state: the current state of the environment 38 | - done: whether the episode is complete (True or False) 39 | """ 40 | self.Q[state][action] += 1 -------------------------------------------------------------------------------- /lab-taxi/main.py: -------------------------------------------------------------------------------- 1 | from agent import Agent 2 | from monitor import interact 3 | import gym 4 | import numpy as np 5 | 6 | env = gym.make('Taxi-v2') 7 | agent = Agent() 8 | avg_rewards, best_avg_reward = interact(env, agent) -------------------------------------------------------------------------------- /lab-taxi/monitor.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import sys 3 | import math 4 | import numpy as np 5 | 6 | def interact(env, agent, num_episodes=20000, window=100): 7 | """ Monitor agent's performance. 8 | 9 | Params 10 | ====== 11 | - env: instance of OpenAI Gym's Taxi-v1 environment 12 | - agent: instance of class Agent (see Agent.py for details) 13 | - num_episodes: number of episodes of agent-environment interaction 14 | - window: number of episodes to consider when calculating average rewards 15 | 16 | Returns 17 | ======= 18 | - avg_rewards: deque containing average rewards 19 | - best_avg_reward: largest value in the avg_rewards deque 20 | """ 21 | # initialize average rewards 22 | avg_rewards = deque(maxlen=num_episodes) 23 | # initialize best average reward 24 | best_avg_reward = -math.inf 25 | # initialize monitor for most recent rewards 26 | samp_rewards = deque(maxlen=window) 27 | # for each episode 28 | for i_episode in range(1, num_episodes+1): 29 | # begin the episode 30 | state = env.reset() 31 | # initialize the sampled reward 32 | samp_reward = 0 33 | while True: 34 | # agent selects an action 35 | action = agent.select_action(state) 36 | # agent performs the selected action 37 | next_state, reward, done, _ = env.step(action) 38 | # agent performs internal updates based on sampled experience 39 | agent.step(state, action, reward, next_state, done) 40 | # update the sampled reward 41 | samp_reward += reward 42 | # update the state (s <- s') to next time step 43 | state = next_state 44 | if done: 45 | # save final sampled reward 46 | samp_rewards.append(samp_reward) 47 | break 48 | if (i_episode >= 100): 49 | # get average reward from last 100 episodes 50 | avg_reward = np.mean(samp_rewards) 51 | # append to deque 52 | avg_rewards.append(avg_reward) 53 | # update best average reward 54 | if avg_reward > best_avg_reward: 55 | best_avg_reward = avg_reward 56 | # monitor progress 57 | print("\rEpisode {}/{} || Best average reward {}".format(i_episode, num_episodes, best_avg_reward), end="") 58 | 
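# flush so the carriage-return progress line above updates in place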
sys.stdout.flush() 59 | # check if task is solved (according to OpenAI Gym) 60 | if best_avg_reward >= 9.7: 61 | print('\nEnvironment solved in {} episodes.'.format(i_episode), end="") 62 | break 63 | if i_episode == num_episodes: print('\n') 64 | return avg_rewards, best_avg_reward -------------------------------------------------------------------------------- /monte-carlo/README.md: -------------------------------------------------------------------------------- 1 | # Monte Carlo Methods 2 | 3 | ### Instructions 4 | 5 | Follow the instructions in `Monte_Carlo.ipynb` to write your own implementations of many Monte Carlo methods! The corresponding solutions can be found in `Monte_Carlo_Solution.ipynb`. 6 | -------------------------------------------------------------------------------- /monte-carlo/images/optimal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/monte-carlo/images/optimal.png -------------------------------------------------------------------------------- /monte-carlo/plot_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mpl_toolkits.mplot3d import Axes3D 3 | import matplotlib.pyplot as plt 4 | from mpl_toolkits.axes_grid1 import make_axes_locatable 5 | 6 | def plot_blackjack_values(V): 7 | 8 | def get_Z(x, y, usable_ace): 9 | if (x,y,usable_ace) in V: 10 | return V[x,y,usable_ace] 11 | else: 12 | return 0 13 | 14 | def get_figure(usable_ace, ax): 15 | x_range = np.arange(11, 22) 16 | y_range = np.arange(1, 11) 17 | X, Y = np.meshgrid(x_range, y_range) 18 | 19 | Z = np.array([get_Z(x,y,usable_ace) for x,y in zip(np.ravel(X), np.ravel(Y))]).reshape(X.shape) 20 | 21 | surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=plt.cm.coolwarm, vmin=-1.0, vmax=1.0) 22 | ax.set_xlabel('Player\'s Current Sum') 23 | ax.set_ylabel('Dealer\'s Showing Card') 24 | ax.set_zlabel('State Value') 25 | ax.view_init(ax.elev, -120) 26 | 27 | fig = plt.figure(figsize=(20, 20)) 28 | ax = fig.add_subplot(211, projection='3d') 29 | ax.set_title('Usable Ace') 30 | get_figure(True, ax) 31 | ax = fig.add_subplot(212, projection='3d') 32 | ax.set_title('No Usable Ace') 33 | get_figure(False, ax) 34 | plt.show() 35 | 36 | def plot_policy(policy): 37 | 38 | def get_Z(x, y, usable_ace): 39 | if (x,y,usable_ace) in policy: 40 | return policy[x,y,usable_ace] 41 | else: 42 | return 1 43 | 44 | def get_figure(usable_ace, ax): 45 | x_range = np.arange(11, 22) 46 | y_range = np.arange(10, 0, -1) 47 | X, Y = np.meshgrid(x_range, y_range) 48 | Z = np.array([[get_Z(x,y,usable_ace) for x in x_range] for y in y_range]) 49 | surf = ax.imshow(Z, cmap=plt.get_cmap('Pastel2', 2), vmin=0, vmax=1, extent=[10.5, 21.5, 0.5, 10.5]) 50 | plt.xticks(x_range) 51 | plt.yticks(y_range) 52 | plt.gca().invert_yaxis() 53 | ax.set_xlabel('Player\'s Current Sum') 54 | ax.set_ylabel('Dealer\'s Showing Card') 55 | ax.grid(color='w', linestyle='-', linewidth=1) 56 | divider = make_axes_locatable(ax) 57 | cax = divider.append_axes("right", size="5%", pad=0.1) 58 | cbar = plt.colorbar(surf, ticks=[0,1], cax=cax) 59 | cbar.ax.set_yticklabels(['0 (STICK)','1 (HIT)']) 60 | 61 | fig = plt.figure(figsize=(15, 15)) 62 | ax = fig.add_subplot(121) 63 | ax.set_title('Usable Ace') 64 | get_figure(True, ax) 65 | ax = fig.add_subplot(122) 66 | ax.set_title('No Usable Ace') 67 | get_figure(False, ax) 68 | plt.show() 
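
Both helpers in `plot_utils.py` above expect dictionaries keyed by `(player's sum, dealer's showing card, usable ace)` tuples, which is the state representation used in the Blackjack notebook. A minimal usage sketch is shown below; the dummy `V` and `policy` values are illustrative placeholders, not the output of any solution notebook:

```python
# Illustrative only -- in the notebook, V and policy would come from your own
# Monte Carlo prediction / control implementation for Blackjack.
import numpy as np
from plot_utils import plot_blackjack_values, plot_policy

states = [(total, card, ace)
          for total in range(12, 22)      # player's current sum
          for card in range(1, 11)        # dealer's showing card
          for ace in (True, False)]       # usable ace?

# Dummy state-value estimates in [-1, 1] and a dummy policy (0 = STICK, 1 = HIT).
V = {s: np.random.uniform(-1, 1) for s in states}
policy = {s: (1 if s[0] < 18 else 0) for s in states}

plot_blackjack_values(V)   # 3D surface of the state-value estimates
plot_policy(policy)        # heatmap of the (dummy) policy
```

States missing from either dictionary are handled by the helpers themselves, which fall back to a default value of 0 (for `V`) or 1 (for `policy`).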
-------------------------------------------------------------------------------- /p1_navigation/README.md: -------------------------------------------------------------------------------- 1 | [//]: # (Image References) 2 | 3 | [image1]: https://user-images.githubusercontent.com/10624937/42135619-d90f2f28-7d12-11e8-8823-82b970a54d7e.gif "Trained Agent" 4 | 5 | # Project 1: Navigation 6 | 7 | ### Introduction 8 | 9 | For this project, you will train an agent to navigate (and collect bananas!) in a large, square world. 10 | 11 | ![Trained Agent][image1] 12 | 13 | A reward of +1 is provided for collecting a yellow banana, and a reward of -1 is provided for collecting a blue banana. Thus, the goal of your agent is to collect as many yellow bananas as possible while avoiding blue bananas. 14 | 15 | The state space has 37 dimensions and contains the agent's velocity, along with ray-based perception of objects around agent's forward direction. Given this information, the agent has to learn how to best select actions. Four discrete actions are available, corresponding to: 16 | - **`0`** - move forward. 17 | - **`1`** - move backward. 18 | - **`2`** - turn left. 19 | - **`3`** - turn right. 20 | 21 | The task is episodic, and in order to solve the environment, your agent must get an average score of +13 over 100 consecutive episodes. 22 | 23 | ### Getting Started 24 | 25 | 1. Download the environment from one of the links below. You need only select the environment that matches your operating system: 26 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana_Linux.zip) 27 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana.app.zip) 28 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana_Windows_x86.zip) 29 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana_Windows_x86_64.zip) 30 | 31 | (_For Windows users_) Check out [this link](https://support.microsoft.com/en-us/help/827218/how-to-determine-whether-a-computer-is-running-a-32-bit-version-or-64) if you need help with determining if your computer is running a 32-bit version or 64-bit version of the Windows operating system. 32 | 33 | (_For AWS_) If you'd like to train the agent on AWS (and have not [enabled a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md)), then please use [this link](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana_Linux_NoVis.zip) to obtain the environment. 34 | 35 | 2. Place the file in the DRLND GitHub repository, in the `p1_navigation/` folder, and unzip (or decompress) the file. 36 | 37 | ### Instructions 38 | 39 | Follow the instructions in `Navigation.ipynb` to get started with training your own agent! 40 | 41 | ### (Optional) Challenge: Learning from Pixels 42 | 43 | After you have successfully completed the project, if you're looking for an additional challenge, you have come to the right place! In the project, your agent learned from information such as its velocity, along with ray-based perception of objects around its forward direction. A more challenging task would be to learn directly from pixels! 44 | 45 | To solve this harder task, you'll need to download a new Unity environment. This environment is almost identical to the project environment, where the only difference is that the state is an 84 x 84 RGB image, corresponding to the agent's first-person view. 
(**Note**: Udacity students should not submit a project with this new environment.) 46 | 47 | You need only select the environment that matches your operating system: 48 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/VisualBanana_Linux.zip) 49 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/VisualBanana.app.zip) 50 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/VisualBanana_Windows_x86.zip) 51 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/VisualBanana_Windows_x86_64.zip) 52 | 53 | Then, place the file in the `p1_navigation/` folder in the DRLND GitHub repository, and unzip (or decompress) the file. Next, open `Navigation_Pixels.ipynb` and follow the instructions to learn how to use the Python API to control the agent. 54 | 55 | (_For AWS_) If you'd like to train the agent on AWS, you must follow the instructions to [set up X Server](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md), and then download the environment for the **Linux** operating system above. 56 | -------------------------------------------------------------------------------- /p2_continuous-control/README.md: -------------------------------------------------------------------------------- 1 | [//]: # (Image References) 2 | 3 | [image1]: https://user-images.githubusercontent.com/10624937/43851024-320ba930-9aff-11e8-8493-ee547c6af349.gif "Trained Agent" 4 | [image2]: https://user-images.githubusercontent.com/10624937/43851646-d899bf20-9b00-11e8-858c-29b5c2c94ccc.png "Crawler" 5 | 6 | 7 | # Project 2: Continuous Control 8 | 9 | ### Introduction 10 | 11 | For this project, you will work with the [Reacher](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Learning-Environment-Examples.md#reacher) environment. 12 | 13 | ![Trained Agent][image1] 14 | 15 | In this environment, a double-jointed arm can move to target locations. A reward of +0.1 is provided for each step that the agent's hand is in the goal location. Thus, the goal of your agent is to maintain its position at the target location for as many time steps as possible. 16 | 17 | The observation space consists of 33 variables corresponding to position, rotation, velocity, and angular velocities of the arm. Each action is a vector with four numbers, corresponding to torque applicable to two joints. Every entry in the action vector should be a number between -1 and 1. 18 | 19 | ### Distributed Training 20 | 21 | For this project, we will provide you with two separate versions of the Unity environment: 22 | - The first version contains a single agent. 23 | - The second version contains 20 identical agents, each with its own copy of the environment. 24 | 25 | The second version is useful for algorithms like [PPO](https://arxiv.org/pdf/1707.06347.pdf), [A3C](https://arxiv.org/pdf/1602.01783.pdf), and [D4PG](https://openreview.net/pdf?id=SyZipzbCb) that use multiple (non-interacting, parallel) copies of the same agent to distribute the task of gathering experience. 26 | 27 | ### Solving the Environment 28 | 29 | Note that your project submission need only solve one of the two versions of the environment. 30 | 31 | #### Option 1: Solve the First Version 32 | 33 | The task is episodic, and in order to solve the environment, your agent must get an average score of +30 over 100 consecutive episodes. 
34 | 35 | #### Option 2: Solve the Second Version 36 | 37 | The barrier for solving the second version of the environment is slightly different, to take into account the presence of many agents. In particular, your agents must get an average score of +30 (over 100 consecutive episodes, and over all agents). Specifically, 38 | - After each episode, we add up the rewards that each agent received (without discounting), to get a score for each agent. This yields 20 (potentially different) scores. We then take the average of these 20 scores. 39 | - This yields an **average score** for each episode (where the average is over all 20 agents). 40 | 41 | The environment is considered solved, when the average (over 100 episodes) of those average scores is at least +30. 42 | 43 | ### Getting Started 44 | 45 | 1. Download the environment from one of the links below. You need only select the environment that matches your operating system: 46 | 47 | - **_Version 1: One (1) Agent_** 48 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/one_agent/Reacher_Linux.zip) 49 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/one_agent/Reacher.app.zip) 50 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/one_agent/Reacher_Windows_x86.zip) 51 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/one_agent/Reacher_Windows_x86_64.zip) 52 | 53 | - **_Version 2: Twenty (20) Agents_** 54 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/Reacher_Linux.zip) 55 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/Reacher.app.zip) 56 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/Reacher_Windows_x86.zip) 57 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/Reacher_Windows_x86_64.zip) 58 | 59 | (_For Windows users_) Check out [this link](https://support.microsoft.com/en-us/help/827218/how-to-determine-whether-a-computer-is-running-a-32-bit-version-or-64) if you need help with determining if your computer is running a 32-bit version or 64-bit version of the Windows operating system. 60 | 61 | (_For AWS_) If you'd like to train the agent on AWS (and have not [enabled a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md)), then please use [this link](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/one_agent/Reacher_Linux_NoVis.zip) (version 1) or [this link](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/Reacher_Linux_NoVis.zip) (version 2) to obtain the "headless" version of the environment. You will **not** be able to watch the agent without enabling a virtual screen, but you will be able to train the agent. (_To watch the agent, you should follow the instructions to [enable a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md), and then download the environment for the **Linux** operating system above._) 62 | 63 | 2. Place the file in the DRLND GitHub repository, in the `p2_continuous-control/` folder, and unzip (or decompress) the file. 64 | 65 | ### Instructions 66 | 67 | Follow the instructions in `Continuous_Control.ipynb` to get started with training your own agent! 
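
To make the scoring rule above concrete, here is a small bookkeeping sketch for the 20-agent version. The `episode_rewards` array is a placeholder for the per-agent reward sums your training loop would accumulate; this sketch only illustrates the score computation, not a training implementation:

```python
from collections import deque
import numpy as np

scores_window = deque(maxlen=100)         # scores of the 100 most recent episodes

# Placeholder: one (undiscounted) summed reward per agent for a finished episode.
episode_rewards = np.zeros(20)

episode_score = np.mean(episode_rewards)  # average the 20 per-agent sums
scores_window.append(episode_score)

# Solved once the average of the last 100 episode scores reaches +30.
solved = len(scores_window) == 100 and np.mean(scores_window) >= 30.0
```

For the single-agent version (Option 1), the same bookkeeping applies with `episode_rewards` replaced by the lone agent's summed reward.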
68 | 69 | ### (Optional) Challenge: Crawler Environment 70 | 71 | After you have successfully completed the project, you might like to solve the more difficult **Crawler** environment. 72 | 73 | ![Crawler][image2] 74 | 75 | In this continuous control environment, the goal is to teach a creature with four legs to walk forward without falling. 76 | 77 | You can read more about this environment in the ML-Agents GitHub [here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Learning-Environment-Examples.md#crawler). To solve this harder task, you'll need to download a new Unity environment. (**Note**: Udacity students should not submit a project with this new environment.) 78 | 79 | You need only select the environment that matches your operating system: 80 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Crawler/Crawler_Linux.zip) 81 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Crawler/Crawler.app.zip) 82 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Crawler/Crawler_Windows_x86.zip) 83 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Crawler/Crawler_Windows_x86_64.zip) 84 | 85 | Then, place the file in the `p2_continuous-control/` folder in the DRLND GitHub repository, and unzip (or decompress) the file. Next, open `Crawler.ipynb` and follow the instructions to learn how to use the Python API to control the agent. 86 | 87 | (_For AWS_) If you'd like to train the agent on AWS (and have not [enabled a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md)), then please use [this link](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Crawler/Crawler_Linux_NoVis.zip) to obtain the "headless" version of the environment. You will **not** be able to watch the agent without enabling a virtual screen, but you will be able to train the agent. (_To watch the agent, you should follow the instructions to [enable a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md), and then download the environment for the **Linux** operating system above._) 88 | 89 | -------------------------------------------------------------------------------- /p3_collab-compet/README.md: -------------------------------------------------------------------------------- 1 | [//]: # (Image References) 2 | 3 | [image1]: https://user-images.githubusercontent.com/10624937/42135623-e770e354-7d12-11e8-998d-29fc74429ca2.gif "Trained Agent" 4 | [image2]: https://user-images.githubusercontent.com/10624937/42135622-e55fb586-7d12-11e8-8a54-3c31da15a90a.gif "Soccer" 5 | 6 | 7 | # Project 3: Collaboration and Competition 8 | 9 | ### Introduction 10 | 11 | For this project, you will work with the [Tennis](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Learning-Environment-Examples.md#tennis) environment. 12 | 13 | ![Trained Agent][image1] 14 | 15 | In this environment, two agents control rackets to bounce a ball over a net. If an agent hits the ball over the net, it receives a reward of +0.1. If an agent lets a ball hit the ground or hits the ball out of bounds, it receives a reward of -0.01. Thus, the goal of each agent is to keep the ball in play. 16 | 17 | The observation space consists of 8 variables corresponding to the position and velocity of the ball and racket. Each agent receives its own, local observation. 
Two continuous actions are available, corresponding to movement toward (or away from) the net, and jumping. 18 | 19 | The task is episodic, and in order to solve the environment, your agents must get an average score of +0.5 (over 100 consecutive episodes, after taking the maximum over both agents). Specifically, 20 | 21 | - After each episode, we add up the rewards that each agent received (without discounting), to get a score for each agent. This yields 2 (potentially different) scores. We then take the maximum of these 2 scores. 22 | - This yields a single **score** for each episode. 23 | 24 | The environment is considered solved, when the average (over 100 episodes) of those **scores** is at least +0.5. 25 | 26 | ### Getting Started 27 | 28 | 1. Download the environment from one of the links below. You need only select the environment that matches your operating system: 29 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Tennis/Tennis_Linux.zip) 30 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Tennis/Tennis.app.zip) 31 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Tennis/Tennis_Windows_x86.zip) 32 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Tennis/Tennis_Windows_x86_64.zip) 33 | 34 | (_For Windows users_) Check out [this link](https://support.microsoft.com/en-us/help/827218/how-to-determine-whether-a-computer-is-running-a-32-bit-version-or-64) if you need help with determining if your computer is running a 32-bit version or 64-bit version of the Windows operating system. 35 | 36 | (_For AWS_) If you'd like to train the agent on AWS (and have not [enabled a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md)), then please use [this link](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Tennis/Tennis_Linux_NoVis.zip) to obtain the "headless" version of the environment. You will **not** be able to watch the agent without enabling a virtual screen, but you will be able to train the agent. (_To watch the agent, you should follow the instructions to [enable a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md), and then download the environment for the **Linux** operating system above._) 37 | 38 | 2. Place the file in the DRLND GitHub repository, in the `p3_collab-compet/` folder, and unzip (or decompress) the file. 39 | 40 | ### Instructions 41 | 42 | Follow the instructions in `Tennis.ipynb` to get started with training your own agent! 43 | 44 | ### (Optional) Challenge: Soccer Environment 45 | 46 | After you have successfully completed the project, you might like to solve the more difficult **Soccer** environment. 47 | 48 | ![Soccer][image2] 49 | 50 | In this environment, the goal is to train a team of agents to play soccer. 51 | 52 | You can read more about this environment in the ML-Agents GitHub [here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Learning-Environment-Examples.md#soccer-twos). To solve this harder task, you'll need to download a new Unity environment. (**Note**: Udacity students should not submit a project with this new environment.) 
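
As a quick illustration of the Tennis scoring rule described earlier: unlike the per-agent average used for Reacher, the per-episode **score** here takes the maximum over the two agents. A one-line sketch with placeholder reward values:

```python
import numpy as np

# Placeholder: summed (undiscounted) rewards for the two agents in one episode.
episode_rewards = np.array([0.10, 0.09])

episode_score = np.max(episode_rewards)   # the better of the two agents
# The environment is solved when the average of these episode scores over
# 100 consecutive episodes is at least +0.5.
```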
53 | 54 | You need only select the environment that matches your operating system: 55 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Soccer/Soccer_Linux.zip) 56 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Soccer/Soccer.app.zip) 57 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Soccer/Soccer_Windows_x86.zip) 58 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Soccer/Soccer_Windows_x86_64.zip) 59 | 60 | Then, place the file in the `p3_collab-compet/` folder in the DRLND GitHub repository, and unzip (or decompress) the file. Next, open `Soccer.ipynb` and follow the instructions to learn how to use the Python API to control the agent. 61 | 62 | (_For AWS_) If you'd like to train the agents on AWS (and have not [enabled a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md)), then please use [this link](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Soccer/Soccer_Linux_NoVis.zip) to obtain the "headless" version of the environment. You will **not** be able to watch the agents without enabling a virtual screen, but you will be able to train the agents. (_To watch the agents, you should follow the instructions to [enable a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md), and then download the environment for the **Linux** operating system above._) 63 | -------------------------------------------------------------------------------- /python/Basics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Unity ML-Agents Toolkit\n", 8 | "## Environment Basics\n", 9 | "This notebook contains a walkthrough of the basic functions of the Python API for the Unity ML-Agents toolkit. For instructions on building a Unity environment, see [here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Getting-Started-with-Balance-Ball.md)." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "### 1. Set environment parameters\n", 17 | "\n", 18 | "Be sure to set `env_name` to the name of the Unity environment file you want to launch. Ensure that the environment build is in the `python/` directory." 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": { 25 | "collapsed": true 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "env_name = \"3DBall\" # Name of the Unity environment binary to launch\n", 30 | "train_mode = True # Whether to run the environment in training or inference mode" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "### 2. Load dependencies\n", 38 | "\n", 39 | "The following loads the necessary dependencies and checks the Python version (at runtime). ML-Agents Toolkit (v0.3 onwards) requires Python 3." 
40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": { 46 | "collapsed": true 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "import matplotlib.pyplot as plt\n", 51 | "import numpy as np\n", 52 | "import sys\n", 53 | "\n", 54 | "from unityagents import UnityEnvironment\n", 55 | "\n", 56 | "%matplotlib inline\n", 57 | "\n", 58 | "print(\"Python version:\")\n", 59 | "print(sys.version)\n", 60 | "\n", 61 | "# check Python version\n", 62 | "if (sys.version_info[0] < 3):\n", 63 | " raise Exception(\"ERROR: ML-Agents Toolkit (v0.3 onwards) requires Python 3\")" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "### 3. Start the environment\n", 71 | "`UnityEnvironment` launches and begins communication with the environment when instantiated.\n", 72 | "\n", 73 | "Environments contain _brains_ which are responsible for deciding the actions of their associated _agents_. Here we check for the first brain available, and set it as the default brain we will be controlling from Python." 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "collapsed": true 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "env = UnityEnvironment(file_name=env_name)\n", 85 | "\n", 86 | "# Examine environment parameters\n", 87 | "print(str(env))\n", 88 | "\n", 89 | "# Set the default brain to work with\n", 90 | "default_brain = env.brain_names[0]\n", 91 | "brain = env.brains[default_brain]" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "### 4. Examine the observation and state spaces\n", 99 | "We can reset the environment to be provided with an initial set of observations and states for all the agents within the environment. In ML-Agents, _states_ refer to a vector of variables corresponding to relevant aspects of the environment for an agent. Likewise, _observations_ refer to a set of relevant pixel-wise visuals for an agent." 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": { 106 | "collapsed": true 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "# Reset the environment\n", 111 | "env_info = env.reset(train_mode=train_mode)[default_brain]\n", 112 | "\n", 113 | "# Examine the state space for the default brain\n", 114 | "print(\"Agent state looks like: \\n{}\".format(env_info.vector_observations[0]))\n", 115 | "\n", 116 | "# Examine the observation space for the default brain\n", 117 | "for observation in env_info.visual_observations:\n", 118 | " print(\"Agent observations look like:\")\n", 119 | " if observation.shape[3] == 3:\n", 120 | " plt.imshow(observation[0,:,:,:])\n", 121 | " else:\n", 122 | " plt.imshow(observation[0,:,:,0])" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "### 5. Take random actions in the environment\n", 130 | "Once we restart an environment, we can step the environment forward and provide actions to all of the agents within the environment. Here we simply choose random actions based on the `action_space_type` of the default brain. \n", 131 | "\n", 132 | "Once this cell is executed, 10 messages will be printed that detail how much reward will be accumulated for the next 10 episodes. The Unity environment will then pause, waiting for further signals telling it what to do next. Thus, not seeing any animation is expected when running this cell." 
133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "collapsed": true 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "for episode in range(10):\n", 144 | " env_info = env.reset(train_mode=train_mode)[default_brain]\n", 145 | " done = False\n", 146 | " episode_rewards = 0\n", 147 | " while not done:\n", 148 | " if brain.vector_action_space_type == 'continuous':\n", 149 | " env_info = env.step(np.random.randn(len(env_info.agents), \n", 150 | " brain.vector_action_space_size))[default_brain]\n", 151 | " else:\n", 152 | " env_info = env.step(np.random.randint(0, brain.vector_action_space_size, \n", 153 | " size=(len(env_info.agents))))[default_brain]\n", 154 | " episode_rewards += env_info.rewards[0]\n", 155 | " done = env_info.local_done[0]\n", 156 | " print(\"Total reward this episode: {}\".format(episode_rewards))" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "### 6. Close the environment when finished\n", 164 | "When we are finished using an environment, we can close it with the function below." 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": { 171 | "collapsed": true 172 | }, 173 | "outputs": [], 174 | "source": [ 175 | "env.close()" 176 | ] 177 | } 178 | ], 179 | "metadata": { 180 | "anaconda-cloud": {}, 181 | "kernelspec": { 182 | "display_name": "Python 3", 183 | "language": "python", 184 | "name": "python3" 185 | }, 186 | "language_info": { 187 | "codemirror_mode": { 188 | "name": "ipython", 189 | "version": 3 190 | }, 191 | "file_extension": ".py", 192 | "mimetype": "text/x-python", 193 | "name": "python", 194 | "nbconvert_exporter": "python", 195 | "pygments_lexer": "ipython3", 196 | "version": "3.6.3" 197 | } 198 | }, 199 | "nbformat": 4, 200 | "nbformat_minor": 1 201 | } 202 | -------------------------------------------------------------------------------- /python/README.md: -------------------------------------------------------------------------------- 1 | # Dependencies 2 | 3 | This is an amended version of the `python/` folder from the [ML-Agents repository](https://github.com/Unity-Technologies/ml-agents). It has been edited to include a few additional pip packages needed for the Deep Reinforcement Learning Nanodegree program. 
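
Assuming the dependencies in `requirements.txt` have been installed (for example with `pip install .` from this folder), a minimal smoke test is to import the package that the notebooks in this repository rely on:

```python
# Quick sanity check that the amended ML-Agents Python package is importable.
import numpy as np
from unityagents import UnityEnvironment

print("numpy", np.__version__)
print("unityagents OK:", UnityEnvironment is not None)
```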
4 | -------------------------------------------------------------------------------- /python/communicator_objects/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent_action_proto_pb2 import * 2 | from .agent_info_proto_pb2 import * 3 | from .brain_parameters_proto_pb2 import * 4 | from .brain_type_proto_pb2 import * 5 | from .command_proto_pb2 import * 6 | from .engine_configuration_proto_pb2 import * 7 | from .environment_parameters_proto_pb2 import * 8 | from .header_pb2 import * 9 | from .resolution_proto_pb2 import * 10 | from .space_type_proto_pb2 import * 11 | from .unity_input_pb2 import * 12 | from .unity_message_pb2 import * 13 | from .unity_output_pb2 import * 14 | from .unity_rl_initialization_input_pb2 import * 15 | from .unity_rl_initialization_output_pb2 import * 16 | from .unity_rl_input_pb2 import * 17 | from .unity_rl_output_pb2 import * 18 | from .unity_to_external_pb2 import * 19 | from .unity_to_external_pb2_grpc import * 20 | -------------------------------------------------------------------------------- /python/communicator_objects/agent_action_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/agent_action_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/agent_action_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n-communicator_objects/agent_action_proto.proto\x12\x14\x63ommunicator_objects\"R\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\x14\n\x0ctext_actions\x18\x02 \x01(\t\x12\x10\n\x08memories\x18\x03 \x03(\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _AGENTACTIONPROTO = _descriptor.Descriptor( 29 | name='AgentActionProto', 30 | full_name='communicator_objects.AgentActionProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='vector_actions', full_name='communicator_objects.AgentActionProto.vector_actions', index=0, 37 | number=1, type=2, cpp_type=6, label=3, 38 | has_default_value=False, default_value=[], 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='text_actions', full_name='communicator_objects.AgentActionProto.text_actions', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='memories', full_name='communicator_objects.AgentActionProto.memories', index=2, 51 | number=3, type=2, cpp_type=6, label=3, 52 | has_default_value=False, 
default_value=[], 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | syntax='proto3', 65 | extension_ranges=[], 66 | oneofs=[ 67 | ], 68 | serialized_start=71, 69 | serialized_end=153, 70 | ) 71 | 72 | DESCRIPTOR.message_types_by_name['AgentActionProto'] = _AGENTACTIONPROTO 73 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 74 | 75 | AgentActionProto = _reflection.GeneratedProtocolMessageType('AgentActionProto', (_message.Message,), dict( 76 | DESCRIPTOR = _AGENTACTIONPROTO, 77 | __module__ = 'communicator_objects.agent_action_proto_pb2' 78 | # @@protoc_insertion_point(class_scope:communicator_objects.AgentActionProto) 79 | )) 80 | _sym_db.RegisterMessage(AgentActionProto) 81 | 82 | 83 | DESCRIPTOR.has_options = True 84 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 85 | # @@protoc_insertion_point(module_scope) 86 | -------------------------------------------------------------------------------- /python/communicator_objects/agent_info_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/agent_info_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/agent_info_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n+communicator_objects/agent_info_proto.proto\x12\x14\x63ommunicator_objects\"\xfd\x01\n\x0e\x41gentInfoProto\x12\"\n\x1astacked_vector_observation\x18\x01 \x03(\x02\x12\x1b\n\x13visual_observations\x18\x02 \x03(\x0c\x12\x18\n\x10text_observation\x18\x03 \x01(\t\x12\x1d\n\x15stored_vector_actions\x18\x04 \x03(\x02\x12\x1b\n\x13stored_text_actions\x18\x05 \x01(\t\x12\x10\n\x08memories\x18\x06 \x03(\x02\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _AGENTINFOPROTO = _descriptor.Descriptor( 29 | name='AgentInfoProto', 30 | full_name='communicator_objects.AgentInfoProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='stacked_vector_observation', full_name='communicator_objects.AgentInfoProto.stacked_vector_observation', index=0, 37 | number=1, type=2, cpp_type=6, label=3, 38 | has_default_value=False, default_value=[], 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='visual_observations', full_name='communicator_objects.AgentInfoProto.visual_observations', index=1, 44 | 
number=2, type=12, cpp_type=9, label=3, 45 | has_default_value=False, default_value=[], 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='text_observation', full_name='communicator_objects.AgentInfoProto.text_observation', index=2, 51 | number=3, type=9, cpp_type=9, label=1, 52 | has_default_value=False, default_value=_b("").decode('utf-8'), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | _descriptor.FieldDescriptor( 57 | name='stored_vector_actions', full_name='communicator_objects.AgentInfoProto.stored_vector_actions', index=3, 58 | number=4, type=2, cpp_type=6, label=3, 59 | has_default_value=False, default_value=[], 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None, file=DESCRIPTOR), 63 | _descriptor.FieldDescriptor( 64 | name='stored_text_actions', full_name='communicator_objects.AgentInfoProto.stored_text_actions', index=4, 65 | number=5, type=9, cpp_type=9, label=1, 66 | has_default_value=False, default_value=_b("").decode('utf-8'), 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None, file=DESCRIPTOR), 70 | _descriptor.FieldDescriptor( 71 | name='memories', full_name='communicator_objects.AgentInfoProto.memories', index=5, 72 | number=6, type=2, cpp_type=6, label=3, 73 | has_default_value=False, default_value=[], 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None, file=DESCRIPTOR), 77 | _descriptor.FieldDescriptor( 78 | name='reward', full_name='communicator_objects.AgentInfoProto.reward', index=6, 79 | number=7, type=2, cpp_type=6, label=1, 80 | has_default_value=False, default_value=float(0), 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None, file=DESCRIPTOR), 84 | _descriptor.FieldDescriptor( 85 | name='done', full_name='communicator_objects.AgentInfoProto.done', index=7, 86 | number=8, type=8, cpp_type=7, label=1, 87 | has_default_value=False, default_value=False, 88 | message_type=None, enum_type=None, containing_type=None, 89 | is_extension=False, extension_scope=None, 90 | options=None, file=DESCRIPTOR), 91 | _descriptor.FieldDescriptor( 92 | name='max_step_reached', full_name='communicator_objects.AgentInfoProto.max_step_reached', index=8, 93 | number=9, type=8, cpp_type=7, label=1, 94 | has_default_value=False, default_value=False, 95 | message_type=None, enum_type=None, containing_type=None, 96 | is_extension=False, extension_scope=None, 97 | options=None, file=DESCRIPTOR), 98 | _descriptor.FieldDescriptor( 99 | name='id', full_name='communicator_objects.AgentInfoProto.id', index=9, 100 | number=10, type=5, cpp_type=1, label=1, 101 | has_default_value=False, default_value=0, 102 | message_type=None, enum_type=None, containing_type=None, 103 | is_extension=False, extension_scope=None, 104 | options=None, file=DESCRIPTOR), 105 | ], 106 | extensions=[ 107 | ], 108 | nested_types=[], 109 | enum_types=[ 110 | ], 111 | options=None, 112 | is_extendable=False, 113 | syntax='proto3', 114 | extension_ranges=[], 115 | oneofs=[ 116 | ], 117 | serialized_start=70, 118 | serialized_end=323, 119 | ) 120 | 121 | 
DESCRIPTOR.message_types_by_name['AgentInfoProto'] = _AGENTINFOPROTO 122 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 123 | 124 | AgentInfoProto = _reflection.GeneratedProtocolMessageType('AgentInfoProto', (_message.Message,), dict( 125 | DESCRIPTOR = _AGENTINFOPROTO, 126 | __module__ = 'communicator_objects.agent_info_proto_pb2' 127 | # @@protoc_insertion_point(class_scope:communicator_objects.AgentInfoProto) 128 | )) 129 | _sym_db.RegisterMessage(AgentInfoProto) 130 | 131 | 132 | DESCRIPTOR.has_options = True 133 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 134 | # @@protoc_insertion_point(module_scope) 135 | -------------------------------------------------------------------------------- /python/communicator_objects/brain_parameters_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/brain_parameters_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2 17 | from communicator_objects import brain_type_proto_pb2 as communicator__objects_dot_brain__type__proto__pb2 18 | from communicator_objects import space_type_proto_pb2 as communicator__objects_dot_space__type__proto__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='communicator_objects/brain_parameters_proto.proto', 23 | package='communicator_objects', 24 | syntax='proto3', 25 | serialized_pb=_b('\n1communicator_objects/brain_parameters_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto\x1a+communicator_objects/brain_type_proto.proto\x1a+communicator_objects/space_type_proto.proto\"\xc6\x03\n\x14\x42rainParametersProto\x12\x1f\n\x17vector_observation_size\x18\x01 \x01(\x05\x12\'\n\x1fnum_stacked_vector_observations\x18\x02 \x01(\x05\x12\x1a\n\x12vector_action_size\x18\x03 \x01(\x05\x12\x41\n\x12\x63\x61mera_resolutions\x18\x04 \x03(\x0b\x32%.communicator_objects.ResolutionProto\x12\"\n\x1avector_action_descriptions\x18\x05 \x03(\t\x12\x46\n\x18vector_action_space_type\x18\x06 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12K\n\x1dvector_observation_space_type\x18\x07 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12\x12\n\nbrain_name\x18\x08 \x01(\t\x12\x38\n\nbrain_type\x18\t \x01(\x0e\x32$.communicator_objects.BrainTypeProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,communicator__objects_dot_brain__type__proto__pb2.DESCRIPTOR,communicator__objects_dot_space__type__proto__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _BRAINPARAMETERSPROTO = _descriptor.Descriptor( 33 | name='BrainParametersProto', 34 | full_name='communicator_objects.BrainParametersProto', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='vector_observation_size', 
full_name='communicator_objects.BrainParametersProto.vector_observation_size', index=0, 41 | number=1, type=5, cpp_type=1, label=1, 42 | has_default_value=False, default_value=0, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None, file=DESCRIPTOR), 46 | _descriptor.FieldDescriptor( 47 | name='num_stacked_vector_observations', full_name='communicator_objects.BrainParametersProto.num_stacked_vector_observations', index=1, 48 | number=2, type=5, cpp_type=1, label=1, 49 | has_default_value=False, default_value=0, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | options=None, file=DESCRIPTOR), 53 | _descriptor.FieldDescriptor( 54 | name='vector_action_size', full_name='communicator_objects.BrainParametersProto.vector_action_size', index=2, 55 | number=3, type=5, cpp_type=1, label=1, 56 | has_default_value=False, default_value=0, 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | options=None, file=DESCRIPTOR), 60 | _descriptor.FieldDescriptor( 61 | name='camera_resolutions', full_name='communicator_objects.BrainParametersProto.camera_resolutions', index=3, 62 | number=4, type=11, cpp_type=10, label=3, 63 | has_default_value=False, default_value=[], 64 | message_type=None, enum_type=None, containing_type=None, 65 | is_extension=False, extension_scope=None, 66 | options=None, file=DESCRIPTOR), 67 | _descriptor.FieldDescriptor( 68 | name='vector_action_descriptions', full_name='communicator_objects.BrainParametersProto.vector_action_descriptions', index=4, 69 | number=5, type=9, cpp_type=9, label=3, 70 | has_default_value=False, default_value=[], 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | options=None, file=DESCRIPTOR), 74 | _descriptor.FieldDescriptor( 75 | name='vector_action_space_type', full_name='communicator_objects.BrainParametersProto.vector_action_space_type', index=5, 76 | number=6, type=14, cpp_type=8, label=1, 77 | has_default_value=False, default_value=0, 78 | message_type=None, enum_type=None, containing_type=None, 79 | is_extension=False, extension_scope=None, 80 | options=None, file=DESCRIPTOR), 81 | _descriptor.FieldDescriptor( 82 | name='vector_observation_space_type', full_name='communicator_objects.BrainParametersProto.vector_observation_space_type', index=6, 83 | number=7, type=14, cpp_type=8, label=1, 84 | has_default_value=False, default_value=0, 85 | message_type=None, enum_type=None, containing_type=None, 86 | is_extension=False, extension_scope=None, 87 | options=None, file=DESCRIPTOR), 88 | _descriptor.FieldDescriptor( 89 | name='brain_name', full_name='communicator_objects.BrainParametersProto.brain_name', index=7, 90 | number=8, type=9, cpp_type=9, label=1, 91 | has_default_value=False, default_value=_b("").decode('utf-8'), 92 | message_type=None, enum_type=None, containing_type=None, 93 | is_extension=False, extension_scope=None, 94 | options=None, file=DESCRIPTOR), 95 | _descriptor.FieldDescriptor( 96 | name='brain_type', full_name='communicator_objects.BrainParametersProto.brain_type', index=8, 97 | number=9, type=14, cpp_type=8, label=1, 98 | has_default_value=False, default_value=0, 99 | message_type=None, enum_type=None, containing_type=None, 100 | is_extension=False, extension_scope=None, 101 | options=None, file=DESCRIPTOR), 102 | ], 103 | extensions=[ 104 | ], 105 | nested_types=[], 106 | 
enum_types=[ 107 | ], 108 | options=None, 109 | is_extendable=False, 110 | syntax='proto3', 111 | extension_ranges=[], 112 | oneofs=[ 113 | ], 114 | serialized_start=211, 115 | serialized_end=665, 116 | ) 117 | 118 | _BRAINPARAMETERSPROTO.fields_by_name['camera_resolutions'].message_type = communicator__objects_dot_resolution__proto__pb2._RESOLUTIONPROTO 119 | _BRAINPARAMETERSPROTO.fields_by_name['vector_action_space_type'].enum_type = communicator__objects_dot_space__type__proto__pb2._SPACETYPEPROTO 120 | _BRAINPARAMETERSPROTO.fields_by_name['vector_observation_space_type'].enum_type = communicator__objects_dot_space__type__proto__pb2._SPACETYPEPROTO 121 | _BRAINPARAMETERSPROTO.fields_by_name['brain_type'].enum_type = communicator__objects_dot_brain__type__proto__pb2._BRAINTYPEPROTO 122 | DESCRIPTOR.message_types_by_name['BrainParametersProto'] = _BRAINPARAMETERSPROTO 123 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 124 | 125 | BrainParametersProto = _reflection.GeneratedProtocolMessageType('BrainParametersProto', (_message.Message,), dict( 126 | DESCRIPTOR = _BRAINPARAMETERSPROTO, 127 | __module__ = 'communicator_objects.brain_parameters_proto_pb2' 128 | # @@protoc_insertion_point(class_scope:communicator_objects.BrainParametersProto) 129 | )) 130 | _sym_db.RegisterMessage(BrainParametersProto) 131 | 132 | 133 | DESCRIPTOR.has_options = True 134 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 135 | # @@protoc_insertion_point(module_scope) 136 | -------------------------------------------------------------------------------- /python/communicator_objects/brain_type_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/brain_type_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/brain_type_proto.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n+communicator_objects/brain_type_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto*G\n\x0e\x42rainTypeProto\x12\n\n\x06Player\x10\x00\x12\r\n\tHeuristic\x10\x01\x12\x0c\n\x08\x45xternal\x10\x02\x12\x0c\n\x08Internal\x10\x03\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,]) 27 | 28 | _BRAINTYPEPROTO = _descriptor.EnumDescriptor( 29 | name='BrainTypeProto', 30 | full_name='communicator_objects.BrainTypeProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | values=[ 34 | _descriptor.EnumValueDescriptor( 35 | name='Player', index=0, number=0, 36 | options=None, 37 | type=None), 38 | _descriptor.EnumValueDescriptor( 39 | name='Heuristic', index=1, number=1, 40 | options=None, 41 | type=None), 42 | _descriptor.EnumValueDescriptor( 43 | name='External', index=2, number=2, 44 | options=None, 45 | type=None), 46 | _descriptor.EnumValueDescriptor( 47 | name='Internal', index=3, number=3, 48 | options=None, 49 | type=None), 50 | ], 51 | containing_type=None, 52 | options=None, 53 | serialized_start=114, 54 | serialized_end=185, 55 | ) 56 | _sym_db.RegisterEnumDescriptor(_BRAINTYPEPROTO) 57 | 58 | BrainTypeProto = enum_type_wrapper.EnumTypeWrapper(_BRAINTYPEPROTO) 59 | Player = 0 60 | Heuristic = 1 61 | External = 2 62 | Internal = 3 63 | 64 | 65 | DESCRIPTOR.enum_types_by_name['BrainTypeProto'] = _BRAINTYPEPROTO 66 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 67 | 68 | 69 | DESCRIPTOR.has_options = True 70 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 71 | # @@protoc_insertion_point(module_scope) 72 | -------------------------------------------------------------------------------- /python/communicator_objects/command_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/command_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='communicator_objects/command_proto.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_pb=_b('\n(communicator_objects/command_proto.proto\x12\x14\x63ommunicator_objects*-\n\x0c\x43ommandProto\x12\x08\n\x04STEP\x10\x00\x12\t\n\x05RESET\x10\x01\x12\x08\n\x04QUIT\x10\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | ) 25 | 26 | _COMMANDPROTO = _descriptor.EnumDescriptor( 27 | name='CommandProto', 28 | full_name='communicator_objects.CommandProto', 29 | filename=None, 30 | file=DESCRIPTOR, 31 | values=[ 32 | _descriptor.EnumValueDescriptor( 33 | name='STEP', index=0, number=0, 34 | options=None, 35 | type=None), 36 | _descriptor.EnumValueDescriptor( 37 | name='RESET', index=1, number=1, 38 | options=None, 39 | type=None), 40 | _descriptor.EnumValueDescriptor( 41 | name='QUIT', index=2, number=2, 42 | options=None, 43 | type=None), 44 | ], 45 | containing_type=None, 46 | options=None, 47 | serialized_start=66, 48 | serialized_end=111, 49 | ) 50 | _sym_db.RegisterEnumDescriptor(_COMMANDPROTO) 51 | 52 | CommandProto = enum_type_wrapper.EnumTypeWrapper(_COMMANDPROTO) 53 | STEP = 0 54 | RESET = 1 55 | QUIT = 2 56 | 57 | 58 | DESCRIPTOR.enum_types_by_name['CommandProto'] = _COMMANDPROTO 59 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 60 | 61 | 62 | DESCRIPTOR.has_options = True 63 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 64 | # @@protoc_insertion_point(module_scope) 65 | -------------------------------------------------------------------------------- /python/communicator_objects/engine_configuration_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/engine_configuration_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/engine_configuration_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n5communicator_objects/engine_configuration_proto.proto\x12\x14\x63ommunicator_objects\"\x95\x01\n\x18\x45ngineConfigurationProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x15\n\rquality_level\x18\x03 \x01(\x05\x12\x12\n\ntime_scale\x18\x04 \x01(\x02\x12\x19\n\x11target_frame_rate\x18\x05 \x01(\x05\x12\x14\n\x0cshow_monitor\x18\x06 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _ENGINECONFIGURATIONPROTO = _descriptor.Descriptor( 29 | name='EngineConfigurationProto', 30 | full_name='communicator_objects.EngineConfigurationProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='width', full_name='communicator_objects.EngineConfigurationProto.width', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='height', full_name='communicator_objects.EngineConfigurationProto.height', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='quality_level', full_name='communicator_objects.EngineConfigurationProto.quality_level', index=2, 51 | number=3, type=5, cpp_type=1, label=1, 52 | has_default_value=False, default_value=0, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | _descriptor.FieldDescriptor( 57 | name='time_scale', full_name='communicator_objects.EngineConfigurationProto.time_scale', index=3, 58 | number=4, type=2, cpp_type=6, label=1, 59 | has_default_value=False, default_value=float(0), 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None, file=DESCRIPTOR), 63 | _descriptor.FieldDescriptor( 64 | name='target_frame_rate', full_name='communicator_objects.EngineConfigurationProto.target_frame_rate', index=4, 65 | number=5, type=5, cpp_type=1, label=1, 66 | has_default_value=False, default_value=0, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None, file=DESCRIPTOR), 70 | _descriptor.FieldDescriptor( 71 | name='show_monitor', full_name='communicator_objects.EngineConfigurationProto.show_monitor', index=5, 72 | number=6, type=8, cpp_type=7, label=1, 73 | 
has_default_value=False, default_value=False, 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None, file=DESCRIPTOR), 77 | ], 78 | extensions=[ 79 | ], 80 | nested_types=[], 81 | enum_types=[ 82 | ], 83 | options=None, 84 | is_extendable=False, 85 | syntax='proto3', 86 | extension_ranges=[], 87 | oneofs=[ 88 | ], 89 | serialized_start=80, 90 | serialized_end=229, 91 | ) 92 | 93 | DESCRIPTOR.message_types_by_name['EngineConfigurationProto'] = _ENGINECONFIGURATIONPROTO 94 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 95 | 96 | EngineConfigurationProto = _reflection.GeneratedProtocolMessageType('EngineConfigurationProto', (_message.Message,), dict( 97 | DESCRIPTOR = _ENGINECONFIGURATIONPROTO, 98 | __module__ = 'communicator_objects.engine_configuration_proto_pb2' 99 | # @@protoc_insertion_point(class_scope:communicator_objects.EngineConfigurationProto) 100 | )) 101 | _sym_db.RegisterMessage(EngineConfigurationProto) 102 | 103 | 104 | DESCRIPTOR.has_options = True 105 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 106 | # @@protoc_insertion_point(module_scope) 107 | -------------------------------------------------------------------------------- /python/communicator_objects/environment_parameters_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/environment_parameters_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/environment_parameters_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n7communicator_objects/environment_parameters_proto.proto\x12\x14\x63ommunicator_objects\"\xb5\x01\n\x1a\x45nvironmentParametersProto\x12_\n\x10\x66loat_parameters\x18\x01 \x03(\x0b\x32\x45.communicator_objects.EnvironmentParametersProto.FloatParametersEntry\x1a\x36\n\x14\x46loatParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY = _descriptor.Descriptor( 29 | name='FloatParametersEntry', 30 | full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='key', full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.key', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=False, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='value', 
full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.value', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=False, default_value=float(0), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 56 | is_extendable=False, 57 | syntax='proto3', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=209, 62 | serialized_end=263, 63 | ) 64 | 65 | _ENVIRONMENTPARAMETERSPROTO = _descriptor.Descriptor( 66 | name='EnvironmentParametersProto', 67 | full_name='communicator_objects.EnvironmentParametersProto', 68 | filename=None, 69 | file=DESCRIPTOR, 70 | containing_type=None, 71 | fields=[ 72 | _descriptor.FieldDescriptor( 73 | name='float_parameters', full_name='communicator_objects.EnvironmentParametersProto.float_parameters', index=0, 74 | number=1, type=11, cpp_type=10, label=3, 75 | has_default_value=False, default_value=[], 76 | message_type=None, enum_type=None, containing_type=None, 77 | is_extension=False, extension_scope=None, 78 | options=None, file=DESCRIPTOR), 79 | ], 80 | extensions=[ 81 | ], 82 | nested_types=[_ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, ], 83 | enum_types=[ 84 | ], 85 | options=None, 86 | is_extendable=False, 87 | syntax='proto3', 88 | extension_ranges=[], 89 | oneofs=[ 90 | ], 91 | serialized_start=82, 92 | serialized_end=263, 93 | ) 94 | 95 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.containing_type = _ENVIRONMENTPARAMETERSPROTO 96 | _ENVIRONMENTPARAMETERSPROTO.fields_by_name['float_parameters'].message_type = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY 97 | DESCRIPTOR.message_types_by_name['EnvironmentParametersProto'] = _ENVIRONMENTPARAMETERSPROTO 98 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 99 | 100 | EnvironmentParametersProto = _reflection.GeneratedProtocolMessageType('EnvironmentParametersProto', (_message.Message,), dict( 101 | 102 | FloatParametersEntry = _reflection.GeneratedProtocolMessageType('FloatParametersEntry', (_message.Message,), dict( 103 | DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, 104 | __module__ = 'communicator_objects.environment_parameters_proto_pb2' 105 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto.FloatParametersEntry) 106 | )) 107 | , 108 | DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO, 109 | __module__ = 'communicator_objects.environment_parameters_proto_pb2' 110 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto) 111 | )) 112 | _sym_db.RegisterMessage(EnvironmentParametersProto) 113 | _sym_db.RegisterMessage(EnvironmentParametersProto.FloatParametersEntry) 114 | 115 | 116 | DESCRIPTOR.has_options = True 117 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 118 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.has_options = True 119 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 120 | # @@protoc_insertion_point(module_scope) 121 | -------------------------------------------------------------------------------- /python/communicator_objects/header_pb2.py: -------------------------------------------------------------------------------- 1 | 
# Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/header.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/header.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n!communicator_objects/header.proto\x12\x14\x63ommunicator_objects\")\n\x06Header\x12\x0e\n\x06status\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _HEADER = _descriptor.Descriptor( 29 | name='Header', 30 | full_name='communicator_objects.Header', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='status', full_name='communicator_objects.Header.status', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='message', full_name='communicator_objects.Header.message', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=None, 56 | is_extendable=False, 57 | syntax='proto3', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=59, 62 | serialized_end=100, 63 | ) 64 | 65 | DESCRIPTOR.message_types_by_name['Header'] = _HEADER 66 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 67 | 68 | Header = _reflection.GeneratedProtocolMessageType('Header', (_message.Message,), dict( 69 | DESCRIPTOR = _HEADER, 70 | __module__ = 'communicator_objects.header_pb2' 71 | # @@protoc_insertion_point(class_scope:communicator_objects.Header) 72 | )) 73 | _sym_db.RegisterMessage(Header) 74 | 75 | 76 | DESCRIPTOR.has_options = True 77 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 78 | # @@protoc_insertion_point(module_scope) 79 | -------------------------------------------------------------------------------- /python/communicator_objects/resolution_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/resolution_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/resolution_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n+communicator_objects/resolution_proto.proto\x12\x14\x63ommunicator_objects\"D\n\x0fResolutionProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x12\n\ngray_scale\x18\x03 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _RESOLUTIONPROTO = _descriptor.Descriptor( 29 | name='ResolutionProto', 30 | full_name='communicator_objects.ResolutionProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='width', full_name='communicator_objects.ResolutionProto.width', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='height', full_name='communicator_objects.ResolutionProto.height', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='gray_scale', full_name='communicator_objects.ResolutionProto.gray_scale', index=2, 51 | number=3, type=8, cpp_type=7, label=1, 52 | has_default_value=False, default_value=False, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | syntax='proto3', 65 | extension_ranges=[], 66 | oneofs=[ 67 | ], 68 | serialized_start=69, 69 | serialized_end=137, 70 | ) 71 | 72 | DESCRIPTOR.message_types_by_name['ResolutionProto'] = _RESOLUTIONPROTO 73 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 74 | 75 | ResolutionProto = _reflection.GeneratedProtocolMessageType('ResolutionProto', (_message.Message,), dict( 76 | DESCRIPTOR = _RESOLUTIONPROTO, 77 | __module__ = 'communicator_objects.resolution_proto_pb2' 78 | # @@protoc_insertion_point(class_scope:communicator_objects.ResolutionProto) 79 | )) 80 | _sym_db.RegisterMessage(ResolutionProto) 81 | 82 | 83 | DESCRIPTOR.has_options = True 84 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 85 | # @@protoc_insertion_point(module_scope) 86 | -------------------------------------------------------------------------------- /python/communicator_objects/space_type_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer 
compiler. DO NOT EDIT! 2 | # source: communicator_objects/space_type_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/space_type_proto.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n+communicator_objects/space_type_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto*.\n\x0eSpaceTypeProto\x12\x0c\n\x08\x64iscrete\x10\x00\x12\x0e\n\ncontinuous\x10\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,]) 27 | 28 | _SPACETYPEPROTO = _descriptor.EnumDescriptor( 29 | name='SpaceTypeProto', 30 | full_name='communicator_objects.SpaceTypeProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | values=[ 34 | _descriptor.EnumValueDescriptor( 35 | name='discrete', index=0, number=0, 36 | options=None, 37 | type=None), 38 | _descriptor.EnumValueDescriptor( 39 | name='continuous', index=1, number=1, 40 | options=None, 41 | type=None), 42 | ], 43 | containing_type=None, 44 | options=None, 45 | serialized_start=114, 46 | serialized_end=160, 47 | ) 48 | _sym_db.RegisterEnumDescriptor(_SPACETYPEPROTO) 49 | 50 | SpaceTypeProto = enum_type_wrapper.EnumTypeWrapper(_SPACETYPEPROTO) 51 | discrete = 0 52 | continuous = 1 53 | 54 | 55 | DESCRIPTOR.enum_types_by_name['SpaceTypeProto'] = _SPACETYPEPROTO 56 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 57 | 58 | 59 | DESCRIPTOR.has_options = True 60 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 61 | # @@protoc_insertion_point(module_scope) 62 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_input_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
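# --- Illustrative composition sketch (not part of the generated module; assumes
# --- the sibling *_pb2 modules in this package). UnityInput nests the RL input
# --- messages, so a Python client typically builds it from those parts, e.g.:
#
#   from communicator_objects import unity_input_pb2, unity_rl_initialization_input_pb2
#   init = unity_rl_initialization_input_pb2.UnityRLInitializationInput(seed=42)
#   msg = unity_input_pb2.UnityInput(rl_initialization_input=init)
#   assert msg.rl_initialization_input.seed == 42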
2 | # source: communicator_objects/unity_input.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_rl_input_pb2 as communicator__objects_dot_unity__rl__input__pb2 17 | from communicator_objects import unity_rl_initialization_input_pb2 as communicator__objects_dot_unity__rl__initialization__input__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/unity_input.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n&communicator_objects/unity_input.proto\x12\x14\x63ommunicator_objects\x1a)communicator_objects/unity_rl_input.proto\x1a\x38\x63ommunicator_objects/unity_rl_initialization_input.proto\"\xb0\x01\n\nUnityInput\x12\x34\n\x08rl_input\x18\x01 \x01(\x0b\x32\".communicator_objects.UnityRLInput\x12Q\n\x17rl_initialization_input\x18\x02 \x01(\x0b\x32\x30.communicator_objects.UnityRLInitializationInput\x12\x19\n\x11\x63ustom_data_input\x18\x03 \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_unity__rl__input__pb2.DESCRIPTOR,communicator__objects_dot_unity__rl__initialization__input__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _UNITYINPUT = _descriptor.Descriptor( 32 | name='UnityInput', 33 | full_name='communicator_objects.UnityInput', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='rl_input', full_name='communicator_objects.UnityInput.rl_input', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='rl_initialization_input', full_name='communicator_objects.UnityInput.rl_initialization_input', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None, file=DESCRIPTOR), 52 | _descriptor.FieldDescriptor( 53 | name='custom_data_input', full_name='communicator_objects.UnityInput.custom_data_input', index=2, 54 | number=3, type=5, cpp_type=1, label=1, 55 | has_default_value=False, default_value=0, 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None, file=DESCRIPTOR), 59 | ], 60 | extensions=[ 61 | ], 62 | nested_types=[], 63 | enum_types=[ 64 | ], 65 | options=None, 66 | is_extendable=False, 67 | syntax='proto3', 68 | extension_ranges=[], 69 | oneofs=[ 70 | ], 71 | serialized_start=166, 72 | serialized_end=342, 73 | ) 74 | 75 | _UNITYINPUT.fields_by_name['rl_input'].message_type = communicator__objects_dot_unity__rl__input__pb2._UNITYRLINPUT 76 | _UNITYINPUT.fields_by_name['rl_initialization_input'].message_type = communicator__objects_dot_unity__rl__initialization__input__pb2._UNITYRLINITIALIZATIONINPUT 77 | 
DESCRIPTOR.message_types_by_name['UnityInput'] = _UNITYINPUT 78 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 79 | 80 | UnityInput = _reflection.GeneratedProtocolMessageType('UnityInput', (_message.Message,), dict( 81 | DESCRIPTOR = _UNITYINPUT, 82 | __module__ = 'communicator_objects.unity_input_pb2' 83 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityInput) 84 | )) 85 | _sym_db.RegisterMessage(UnityInput) 86 | 87 | 88 | DESCRIPTOR.has_options = True 89 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 90 | # @@protoc_insertion_point(module_scope) 91 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_message_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_message.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_output_pb2 as communicator__objects_dot_unity__output__pb2 17 | from communicator_objects import unity_input_pb2 as communicator__objects_dot_unity__input__pb2 18 | from communicator_objects import header_pb2 as communicator__objects_dot_header__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='communicator_objects/unity_message.proto', 23 | package='communicator_objects', 24 | syntax='proto3', 25 | serialized_pb=_b('\n(communicator_objects/unity_message.proto\x12\x14\x63ommunicator_objects\x1a\'communicator_objects/unity_output.proto\x1a&communicator_objects/unity_input.proto\x1a!communicator_objects/header.proto\"\xac\x01\n\x0cUnityMessage\x12,\n\x06header\x18\x01 \x01(\x0b\x32\x1c.communicator_objects.Header\x12\x37\n\x0cunity_output\x18\x02 \x01(\x0b\x32!.communicator_objects.UnityOutput\x12\x35\n\x0bunity_input\x18\x03 \x01(\x0b\x32 .communicator_objects.UnityInputB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[communicator__objects_dot_unity__output__pb2.DESCRIPTOR,communicator__objects_dot_unity__input__pb2.DESCRIPTOR,communicator__objects_dot_header__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _UNITYMESSAGE = _descriptor.Descriptor( 33 | name='UnityMessage', 34 | full_name='communicator_objects.UnityMessage', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='header', full_name='communicator_objects.UnityMessage.header', index=0, 41 | number=1, type=11, cpp_type=10, label=1, 42 | has_default_value=False, default_value=None, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None, file=DESCRIPTOR), 46 | _descriptor.FieldDescriptor( 47 | name='unity_output', full_name='communicator_objects.UnityMessage.unity_output', index=1, 48 | number=2, type=11, cpp_type=10, label=1, 49 | has_default_value=False, default_value=None, 50 | message_type=None, enum_type=None, containing_type=None, 51 | 
is_extension=False, extension_scope=None, 52 | options=None, file=DESCRIPTOR), 53 | _descriptor.FieldDescriptor( 54 | name='unity_input', full_name='communicator_objects.UnityMessage.unity_input', index=2, 55 | number=3, type=11, cpp_type=10, label=1, 56 | has_default_value=False, default_value=None, 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | options=None, file=DESCRIPTOR), 60 | ], 61 | extensions=[ 62 | ], 63 | nested_types=[], 64 | enum_types=[ 65 | ], 66 | options=None, 67 | is_extendable=False, 68 | syntax='proto3', 69 | extension_ranges=[], 70 | oneofs=[ 71 | ], 72 | serialized_start=183, 73 | serialized_end=355, 74 | ) 75 | 76 | _UNITYMESSAGE.fields_by_name['header'].message_type = communicator__objects_dot_header__pb2._HEADER 77 | _UNITYMESSAGE.fields_by_name['unity_output'].message_type = communicator__objects_dot_unity__output__pb2._UNITYOUTPUT 78 | _UNITYMESSAGE.fields_by_name['unity_input'].message_type = communicator__objects_dot_unity__input__pb2._UNITYINPUT 79 | DESCRIPTOR.message_types_by_name['UnityMessage'] = _UNITYMESSAGE 80 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 81 | 82 | UnityMessage = _reflection.GeneratedProtocolMessageType('UnityMessage', (_message.Message,), dict( 83 | DESCRIPTOR = _UNITYMESSAGE, 84 | __module__ = 'communicator_objects.unity_message_pb2' 85 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityMessage) 86 | )) 87 | _sym_db.RegisterMessage(UnityMessage) 88 | 89 | 90 | DESCRIPTOR.has_options = True 91 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 92 | # @@protoc_insertion_point(module_scope) 93 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_output_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
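# --- Illustrative decoding sketch (not part of the generated module; `payload`
# --- is a hypothetical serialized UnityOutput received from Unity). Presence of
# --- the nested message fields can be tested with HasField:
#
#   out = unity_output_pb2.UnityOutput.FromString(payload)
#   if out.HasField('rl_initialization_output'):
#       print(out.rl_initialization_output.name, out.rl_initialization_output.version)
#   if out.HasField('rl_output'):
#       print(out.rl_output.global_done)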
2 | # source: communicator_objects/unity_output.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_rl_output_pb2 as communicator__objects_dot_unity__rl__output__pb2 17 | from communicator_objects import unity_rl_initialization_output_pb2 as communicator__objects_dot_unity__rl__initialization__output__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/unity_output.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n\'communicator_objects/unity_output.proto\x12\x14\x63ommunicator_objects\x1a*communicator_objects/unity_rl_output.proto\x1a\x39\x63ommunicator_objects/unity_rl_initialization_output.proto\"\xb6\x01\n\x0bUnityOutput\x12\x36\n\trl_output\x18\x01 \x01(\x0b\x32#.communicator_objects.UnityRLOutput\x12S\n\x18rl_initialization_output\x18\x02 \x01(\x0b\x32\x31.communicator_objects.UnityRLInitializationOutput\x12\x1a\n\x12\x63ustom_data_output\x18\x03 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_unity__rl__output__pb2.DESCRIPTOR,communicator__objects_dot_unity__rl__initialization__output__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _UNITYOUTPUT = _descriptor.Descriptor( 32 | name='UnityOutput', 33 | full_name='communicator_objects.UnityOutput', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='rl_output', full_name='communicator_objects.UnityOutput.rl_output', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='rl_initialization_output', full_name='communicator_objects.UnityOutput.rl_initialization_output', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None, file=DESCRIPTOR), 52 | _descriptor.FieldDescriptor( 53 | name='custom_data_output', full_name='communicator_objects.UnityOutput.custom_data_output', index=2, 54 | number=3, type=9, cpp_type=9, label=1, 55 | has_default_value=False, default_value=_b("").decode('utf-8'), 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None, file=DESCRIPTOR), 59 | ], 60 | extensions=[ 61 | ], 62 | nested_types=[], 63 | enum_types=[ 64 | ], 65 | options=None, 66 | is_extendable=False, 67 | syntax='proto3', 68 | extension_ranges=[], 69 | oneofs=[ 70 | ], 71 | serialized_start=169, 72 | serialized_end=351, 73 | ) 74 | 75 | _UNITYOUTPUT.fields_by_name['rl_output'].message_type = communicator__objects_dot_unity__rl__output__pb2._UNITYRLOUTPUT 76 | _UNITYOUTPUT.fields_by_name['rl_initialization_output'].message_type = 
communicator__objects_dot_unity__rl__initialization__output__pb2._UNITYRLINITIALIZATIONOUTPUT 77 | DESCRIPTOR.message_types_by_name['UnityOutput'] = _UNITYOUTPUT 78 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 79 | 80 | UnityOutput = _reflection.GeneratedProtocolMessageType('UnityOutput', (_message.Message,), dict( 81 | DESCRIPTOR = _UNITYOUTPUT, 82 | __module__ = 'communicator_objects.unity_output_pb2' 83 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityOutput) 84 | )) 85 | _sym_db.RegisterMessage(UnityOutput) 86 | 87 | 88 | DESCRIPTOR.has_options = True 89 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 90 | # @@protoc_insertion_point(module_scope) 91 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_rl_initialization_input_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_rl_initialization_input.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/unity_rl_initialization_input.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n8communicator_objects/unity_rl_initialization_input.proto\x12\x14\x63ommunicator_objects\"*\n\x1aUnityRLInitializationInput\x12\x0c\n\x04seed\x18\x01 \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _UNITYRLINITIALIZATIONINPUT = _descriptor.Descriptor( 29 | name='UnityRLInitializationInput', 30 | full_name='communicator_objects.UnityRLInitializationInput', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='seed', full_name='communicator_objects.UnityRLInitializationInput.seed', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | ], 43 | extensions=[ 44 | ], 45 | nested_types=[], 46 | enum_types=[ 47 | ], 48 | options=None, 49 | is_extendable=False, 50 | syntax='proto3', 51 | extension_ranges=[], 52 | oneofs=[ 53 | ], 54 | serialized_start=82, 55 | serialized_end=124, 56 | ) 57 | 58 | DESCRIPTOR.message_types_by_name['UnityRLInitializationInput'] = _UNITYRLINITIALIZATIONINPUT 59 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 60 | 61 | UnityRLInitializationInput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationInput', (_message.Message,), dict( 62 | DESCRIPTOR = _UNITYRLINITIALIZATIONINPUT, 63 | __module__ = 'communicator_objects.unity_rl_initialization_input_pb2' 64 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationInput) 65 | )) 66 | _sym_db.RegisterMessage(UnityRLInitializationInput) 67 | 68 | 69 | DESCRIPTOR.has_options = True 70 | 
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 71 | # @@protoc_insertion_point(module_scope) 72 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_rl_initialization_output_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_rl_initialization_output.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import brain_parameters_proto_pb2 as communicator__objects_dot_brain__parameters__proto__pb2 17 | from communicator_objects import environment_parameters_proto_pb2 as communicator__objects_dot_environment__parameters__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/unity_rl_initialization_output.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n9communicator_objects/unity_rl_initialization_output.proto\x12\x14\x63ommunicator_objects\x1a\x31\x63ommunicator_objects/brain_parameters_proto.proto\x1a\x37\x63ommunicator_objects/environment_parameters_proto.proto\"\xe6\x01\n\x1bUnityRLInitializationOutput\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x10\n\x08log_path\x18\x03 \x01(\t\x12\x44\n\x10\x62rain_parameters\x18\x05 \x03(\x0b\x32*.communicator_objects.BrainParametersProto\x12P\n\x16\x65nvironment_parameters\x18\x06 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_brain__parameters__proto__pb2.DESCRIPTOR,communicator__objects_dot_environment__parameters__proto__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _UNITYRLINITIALIZATIONOUTPUT = _descriptor.Descriptor( 32 | name='UnityRLInitializationOutput', 33 | full_name='communicator_objects.UnityRLInitializationOutput', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='name', full_name='communicator_objects.UnityRLInitializationOutput.name', index=0, 40 | number=1, type=9, cpp_type=9, label=1, 41 | has_default_value=False, default_value=_b("").decode('utf-8'), 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='version', full_name='communicator_objects.UnityRLInitializationOutput.version', index=1, 47 | number=2, type=9, cpp_type=9, label=1, 48 | has_default_value=False, default_value=_b("").decode('utf-8'), 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None, file=DESCRIPTOR), 52 | _descriptor.FieldDescriptor( 53 | name='log_path', full_name='communicator_objects.UnityRLInitializationOutput.log_path', index=2, 54 | number=3, type=9, cpp_type=9, label=1, 55 | has_default_value=False, 
default_value=_b("").decode('utf-8'), 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None, file=DESCRIPTOR), 59 | _descriptor.FieldDescriptor( 60 | name='brain_parameters', full_name='communicator_objects.UnityRLInitializationOutput.brain_parameters', index=3, 61 | number=5, type=11, cpp_type=10, label=3, 62 | has_default_value=False, default_value=[], 63 | message_type=None, enum_type=None, containing_type=None, 64 | is_extension=False, extension_scope=None, 65 | options=None, file=DESCRIPTOR), 66 | _descriptor.FieldDescriptor( 67 | name='environment_parameters', full_name='communicator_objects.UnityRLInitializationOutput.environment_parameters', index=4, 68 | number=6, type=11, cpp_type=10, label=1, 69 | has_default_value=False, default_value=None, 70 | message_type=None, enum_type=None, containing_type=None, 71 | is_extension=False, extension_scope=None, 72 | options=None, file=DESCRIPTOR), 73 | ], 74 | extensions=[ 75 | ], 76 | nested_types=[], 77 | enum_types=[ 78 | ], 79 | options=None, 80 | is_extendable=False, 81 | syntax='proto3', 82 | extension_ranges=[], 83 | oneofs=[ 84 | ], 85 | serialized_start=192, 86 | serialized_end=422, 87 | ) 88 | 89 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name['brain_parameters'].message_type = communicator__objects_dot_brain__parameters__proto__pb2._BRAINPARAMETERSPROTO 90 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name['environment_parameters'].message_type = communicator__objects_dot_environment__parameters__proto__pb2._ENVIRONMENTPARAMETERSPROTO 91 | DESCRIPTOR.message_types_by_name['UnityRLInitializationOutput'] = _UNITYRLINITIALIZATIONOUTPUT 92 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 93 | 94 | UnityRLInitializationOutput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationOutput', (_message.Message,), dict( 95 | DESCRIPTOR = _UNITYRLINITIALIZATIONOUTPUT, 96 | __module__ = 'communicator_objects.unity_rl_initialization_output_pb2' 97 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationOutput) 98 | )) 99 | _sym_db.RegisterMessage(UnityRLInitializationOutput) 100 | 101 | 102 | DESCRIPTOR.has_options = True 103 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 104 | # @@protoc_insertion_point(module_scope) 105 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_rl_output_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/unity_rl_output.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import agent_info_proto_pb2 as communicator__objects_dot_agent__info__proto__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='communicator_objects/unity_rl_output.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_pb=_b('\n*communicator_objects/unity_rl_output.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/agent_info_proto.proto\"\xa3\x02\n\rUnityRLOutput\x12\x13\n\x0bglobal_done\x18\x01 \x01(\x08\x12G\n\nagentInfos\x18\x02 \x03(\x0b\x32\x33.communicator_objects.UnityRLOutput.AgentInfosEntry\x1aI\n\x12ListAgentInfoProto\x12\x33\n\x05value\x18\x01 \x03(\x0b\x32$.communicator_objects.AgentInfoProto\x1ai\n\x0f\x41gentInfosEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x45\n\x05value\x18\x02 \x01(\x0b\x32\x36.communicator_objects.UnityRLOutput.ListAgentInfoProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | , 25 | dependencies=[communicator__objects_dot_agent__info__proto__pb2.DESCRIPTOR,]) 26 | 27 | 28 | 29 | 30 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO = _descriptor.Descriptor( 31 | name='ListAgentInfoProto', 32 | full_name='communicator_objects.UnityRLOutput.ListAgentInfoProto', 33 | filename=None, 34 | file=DESCRIPTOR, 35 | containing_type=None, 36 | fields=[ 37 | _descriptor.FieldDescriptor( 38 | name='value', full_name='communicator_objects.UnityRLOutput.ListAgentInfoProto.value', index=0, 39 | number=1, type=11, cpp_type=10, label=3, 40 | has_default_value=False, default_value=[], 41 | message_type=None, enum_type=None, containing_type=None, 42 | is_extension=False, extension_scope=None, 43 | options=None, file=DESCRIPTOR), 44 | ], 45 | extensions=[ 46 | ], 47 | nested_types=[], 48 | enum_types=[ 49 | ], 50 | options=None, 51 | is_extendable=False, 52 | syntax='proto3', 53 | extension_ranges=[], 54 | oneofs=[ 55 | ], 56 | serialized_start=225, 57 | serialized_end=298, 58 | ) 59 | 60 | _UNITYRLOUTPUT_AGENTINFOSENTRY = _descriptor.Descriptor( 61 | name='AgentInfosEntry', 62 | full_name='communicator_objects.UnityRLOutput.AgentInfosEntry', 63 | filename=None, 64 | file=DESCRIPTOR, 65 | containing_type=None, 66 | fields=[ 67 | _descriptor.FieldDescriptor( 68 | name='key', full_name='communicator_objects.UnityRLOutput.AgentInfosEntry.key', index=0, 69 | number=1, type=9, cpp_type=9, label=1, 70 | has_default_value=False, default_value=_b("").decode('utf-8'), 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | options=None, file=DESCRIPTOR), 74 | _descriptor.FieldDescriptor( 75 | name='value', full_name='communicator_objects.UnityRLOutput.AgentInfosEntry.value', index=1, 76 | number=2, type=11, cpp_type=10, label=1, 77 | has_default_value=False, default_value=None, 78 | message_type=None, enum_type=None, containing_type=None, 79 | is_extension=False, extension_scope=None, 80 | options=None, file=DESCRIPTOR), 81 | ], 82 | extensions=[ 83 | ], 84 | nested_types=[], 85 | enum_types=[ 86 
| ], 87 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 88 | is_extendable=False, 89 | syntax='proto3', 90 | extension_ranges=[], 91 | oneofs=[ 92 | ], 93 | serialized_start=300, 94 | serialized_end=405, 95 | ) 96 | 97 | _UNITYRLOUTPUT = _descriptor.Descriptor( 98 | name='UnityRLOutput', 99 | full_name='communicator_objects.UnityRLOutput', 100 | filename=None, 101 | file=DESCRIPTOR, 102 | containing_type=None, 103 | fields=[ 104 | _descriptor.FieldDescriptor( 105 | name='global_done', full_name='communicator_objects.UnityRLOutput.global_done', index=0, 106 | number=1, type=8, cpp_type=7, label=1, 107 | has_default_value=False, default_value=False, 108 | message_type=None, enum_type=None, containing_type=None, 109 | is_extension=False, extension_scope=None, 110 | options=None, file=DESCRIPTOR), 111 | _descriptor.FieldDescriptor( 112 | name='agentInfos', full_name='communicator_objects.UnityRLOutput.agentInfos', index=1, 113 | number=2, type=11, cpp_type=10, label=3, 114 | has_default_value=False, default_value=[], 115 | message_type=None, enum_type=None, containing_type=None, 116 | is_extension=False, extension_scope=None, 117 | options=None, file=DESCRIPTOR), 118 | ], 119 | extensions=[ 120 | ], 121 | nested_types=[_UNITYRLOUTPUT_LISTAGENTINFOPROTO, _UNITYRLOUTPUT_AGENTINFOSENTRY, ], 122 | enum_types=[ 123 | ], 124 | options=None, 125 | is_extendable=False, 126 | syntax='proto3', 127 | extension_ranges=[], 128 | oneofs=[ 129 | ], 130 | serialized_start=114, 131 | serialized_end=405, 132 | ) 133 | 134 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO.fields_by_name['value'].message_type = communicator__objects_dot_agent__info__proto__pb2._AGENTINFOPROTO 135 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO.containing_type = _UNITYRLOUTPUT 136 | _UNITYRLOUTPUT_AGENTINFOSENTRY.fields_by_name['value'].message_type = _UNITYRLOUTPUT_LISTAGENTINFOPROTO 137 | _UNITYRLOUTPUT_AGENTINFOSENTRY.containing_type = _UNITYRLOUTPUT 138 | _UNITYRLOUTPUT.fields_by_name['agentInfos'].message_type = _UNITYRLOUTPUT_AGENTINFOSENTRY 139 | DESCRIPTOR.message_types_by_name['UnityRLOutput'] = _UNITYRLOUTPUT 140 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 141 | 142 | UnityRLOutput = _reflection.GeneratedProtocolMessageType('UnityRLOutput', (_message.Message,), dict( 143 | 144 | ListAgentInfoProto = _reflection.GeneratedProtocolMessageType('ListAgentInfoProto', (_message.Message,), dict( 145 | DESCRIPTOR = _UNITYRLOUTPUT_LISTAGENTINFOPROTO, 146 | __module__ = 'communicator_objects.unity_rl_output_pb2' 147 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.ListAgentInfoProto) 148 | )) 149 | , 150 | 151 | AgentInfosEntry = _reflection.GeneratedProtocolMessageType('AgentInfosEntry', (_message.Message,), dict( 152 | DESCRIPTOR = _UNITYRLOUTPUT_AGENTINFOSENTRY, 153 | __module__ = 'communicator_objects.unity_rl_output_pb2' 154 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.AgentInfosEntry) 155 | )) 156 | , 157 | DESCRIPTOR = _UNITYRLOUTPUT, 158 | __module__ = 'communicator_objects.unity_rl_output_pb2' 159 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput) 160 | )) 161 | _sym_db.RegisterMessage(UnityRLOutput) 162 | _sym_db.RegisterMessage(UnityRLOutput.ListAgentInfoProto) 163 | _sym_db.RegisterMessage(UnityRLOutput.AgentInfosEntry) 164 | 165 | 166 | DESCRIPTOR.has_options = True 167 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 168 | 
_UNITYRLOUTPUT_AGENTINFOSENTRY.has_options = True 169 | _UNITYRLOUTPUT_AGENTINFOSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 170 | # @@protoc_insertion_point(module_scope) 171 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_to_external_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_to_external.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_message_pb2 as communicator__objects_dot_unity__message__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='communicator_objects/unity_to_external.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_pb=_b('\n,communicator_objects/unity_to_external.proto\x12\x14\x63ommunicator_objects\x1a(communicator_objects/unity_message.proto2g\n\x0fUnityToExternal\x12T\n\x08\x45xchange\x12\".communicator_objects.UnityMessage\x1a\".communicator_objects.UnityMessage\"\x00\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | , 25 | dependencies=[communicator__objects_dot_unity__message__pb2.DESCRIPTOR,]) 26 | 27 | 28 | 29 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 30 | 31 | 32 | DESCRIPTOR.has_options = True 33 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 34 | 35 | _UNITYTOEXTERNAL = _descriptor.ServiceDescriptor( 36 | name='UnityToExternal', 37 | full_name='communicator_objects.UnityToExternal', 38 | file=DESCRIPTOR, 39 | index=0, 40 | options=None, 41 | serialized_start=112, 42 | serialized_end=215, 43 | methods=[ 44 | _descriptor.MethodDescriptor( 45 | name='Exchange', 46 | full_name='communicator_objects.UnityToExternal.Exchange', 47 | index=0, 48 | containing_service=None, 49 | input_type=communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, 50 | output_type=communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, 51 | options=None, 52 | ), 53 | ]) 54 | _sym_db.RegisterServiceDescriptor(_UNITYTOEXTERNAL) 55 | 56 | DESCRIPTOR.services_by_name['UnityToExternal'] = _UNITYTOEXTERNAL 57 | 58 | # @@protoc_insertion_point(module_scope) 59 | -------------------------------------------------------------------------------- /python/communicator_objects/unity_to_external_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | import grpc 3 | 4 | from communicator_objects import unity_message_pb2 as communicator__objects_dot_unity__message__pb2 5 | 6 | 7 | class UnityToExternalStub(object): 8 | # missing associated documentation comment in .proto file 9 | pass 10 | 11 | def __init__(self, channel): 12 | """Constructor. 13 | 14 | Args: 15 | channel: A grpc.Channel. 
16 | """ 17 | self.Exchange = channel.unary_unary( 18 | '/communicator_objects.UnityToExternal/Exchange', 19 | request_serializer=communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString, 20 | response_deserializer=communicator__objects_dot_unity__message__pb2.UnityMessage.FromString, 21 | ) 22 | 23 | 24 | class UnityToExternalServicer(object): 25 | # missing associated documentation comment in .proto file 26 | pass 27 | 28 | def Exchange(self, request, context): 29 | """Sends the academy parameters 30 | """ 31 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 32 | context.set_details('Method not implemented!') 33 | raise NotImplementedError('Method not implemented!') 34 | 35 | 36 | def add_UnityToExternalServicer_to_server(servicer, server): 37 | rpc_method_handlers = { 38 | 'Exchange': grpc.unary_unary_rpc_method_handler( 39 | servicer.Exchange, 40 | request_deserializer=communicator__objects_dot_unity__message__pb2.UnityMessage.FromString, 41 | response_serializer=communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString, 42 | ), 43 | } 44 | generic_handler = grpc.method_handlers_generic_handler( 45 | 'communicator_objects.UnityToExternal', rpc_method_handlers) 46 | server.add_generic_rpc_handlers((generic_handler,)) 47 | -------------------------------------------------------------------------------- /python/curricula/push.json: -------------------------------------------------------------------------------- 1 | { 2 | "measure" : "reward", 3 | "thresholds" : [0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75], 4 | "min_lesson_length" : 2, 5 | "signal_smoothing" : true, 6 | "parameters" : 7 | { 8 | "goal_size" : [3.5, 3.25, 3.0, 2.75, 2.5, 2.25, 2.0, 1.75, 1.5, 1.25, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 9 | "block_size": [1.5, 1.4, 1.3, 1.2, 1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 10 | "x_variation":[1.5, 1.55, 1.6, 1.65, 1.7, 1.75, 1.8, 1.85, 1.9, 1.95, 2.0, 2.1, 2.2, 2.3, 2.4, 2.5] 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /python/curricula/test.json: -------------------------------------------------------------------------------- 1 | { 2 | "measure" : "reward", 3 | "thresholds" : [10, 20, 50], 4 | "min_lesson_length" : 3, 5 | "signal_smoothing" : true, 6 | "parameters" : 7 | { 8 | "param1" : [0.7, 0.5, 0.3, 0.1], 9 | "param2" : [100, 50, 20, 15], 10 | "param3" : [0.2, 0.3, 0.7, 0.9] 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /python/curricula/wall.json: -------------------------------------------------------------------------------- 1 | { 2 | "measure" : "progress", 3 | "thresholds" : [0.1, 0.3, 0.5], 4 | "min_lesson_length" : 2, 5 | "signal_smoothing" : true, 6 | "parameters" : 7 | { 8 | "big_wall_min_height" : [0.0, 4.0, 6.0, 8.0], 9 | "big_wall_max_height" : [4.0, 7.0, 8.0, 8.0], 10 | "small_wall_height" : [1.5, 2.0, 2.5, 4.0] 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /python/learn.py: -------------------------------------------------------------------------------- 1 | # # Unity ML-Agents Toolkit 2 | # ## ML-Agent Learning 3 | 4 | import logging 5 | 6 | import os 7 | from docopt import docopt 8 | 9 | from unitytrainers.trainer_controller import TrainerController 10 | 11 | 12 | if __name__ == '__main__': 13 | print(''' 14 | 15 | ▄▄▄▓▓▓▓ 16 | ╓▓▓▓▓▓▓█▓▓▓▓▓ 17 | ,▄▄▄m▀▀▀' ,▓▓▓▀▓▓▄ ▓▓▓ ▓▓▌ 18 | ▄▓▓▓▀' ▄▓▓▀ ▓▓▓ ▄▄ ▄▄ ,▄▄ 
▄▄▄▄ ,▄▄ ▄▓▓▌▄ ▄▄▄ ,▄▄ 19 | ▄▓▓▓▀ ▄▓▓▀ ▐▓▓▌ ▓▓▌ ▐▓▓ ▐▓▓▓▀▀▀▓▓▌ ▓▓▓ ▀▓▓▌▀ ^▓▓▌ ╒▓▓▌ 20 | ▄▓▓▓▓▓▄▄▄▄▄▄▄▄▓▓▓ ▓▀ ▓▓▌ ▐▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▌ ▐▓▓▄ ▓▓▌ 21 | ▀▓▓▓▓▀▀▀▀▀▀▀▀▀▀▓▓▄ ▓▓ ▓▓▌ ▐▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▌ ▐▓▓▐▓▓ 22 | ^█▓▓▓ ▀▓▓▄ ▐▓▓▌ ▓▓▓▓▄▓▓▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▓▄ ▓▓▓▓` 23 | '▀▓▓▓▄ ^▓▓▓ ▓▓▓ └▀▀▀▀ ▀▀ ^▀▀ `▀▀ `▀▀ '▀▀ ▐▓▓▌ 24 | ▀▀▀▀▓▄▄▄ ▓▓▓▓▓▓, ▓▓▓▓▀ 25 | `▀█▓▓▓▓▓▓▓▓▓▌ 26 | ¬`▀▀▀█▓ 27 | 28 | ''') 29 | 30 | logger = logging.getLogger("unityagents") 31 | _USAGE = ''' 32 | Usage: 33 | learn (<env>) [options] 34 | learn [options] 35 | learn --help 36 | 37 | Options: 38 | --curriculum=<file> Curriculum json file for environment [default: None]. 39 | --keep-checkpoints=<n> How many model checkpoints to keep [default: 5]. 40 | --lesson=<n> Start learning from this lesson [default: 0]. 41 | --load Whether to load the model or randomly initialize [default: False]. 42 | --run-id=<path> The sub-directory name for model and summary statistics [default: ppo]. 43 | --save-freq=<n> Frequency at which to save model [default: 50000]. 44 | --seed=<n> Random seed used for training [default: -1]. 45 | --slow Whether to run the game at training speed [default: False]. 46 | --train Whether to train model, or only run inference [default: False]. 47 | --worker-id=<n> Number to add to communication port (5005). Used for multi-environment [default: 0]. 48 | --docker-target-name=<dt>
Docker Volume to store curriculum, executable and model files [default: Empty]. 49 | --no-graphics Whether to run the Unity simulator in no-graphics mode [default: False]. 50 | ''' 51 | 52 | options = docopt(_USAGE) 53 | logger.info(options) 54 | # Docker Parameters 55 | if options['--docker-target-name'] == 'Empty': 56 | docker_target_name = '' 57 | else: 58 | docker_target_name = options['--docker-target-name'] 59 | 60 | # General parameters 61 | run_id = options['--run-id'] 62 | seed = int(options['--seed']) 63 | load_model = options['--load'] 64 | train_model = options['--train'] 65 | save_freq = int(options['--save-freq']) 66 | env_path = options['<env>'] 67 | keep_checkpoints = int(options['--keep-checkpoints']) 68 | worker_id = int(options['--worker-id']) 69 | curriculum_file = str(options['--curriculum']) 70 | if curriculum_file == "None": 71 | curriculum_file = None 72 | lesson = int(options['--lesson']) 73 | fast_simulation = not bool(options['--slow']) 74 | no_graphics = options['--no-graphics'] 75 | 76 | # Constants 77 | # Assumption that this yaml is present in same dir as this file 78 | base_path = os.path.dirname(__file__) 79 | TRAINER_CONFIG_PATH = os.path.abspath(os.path.join(base_path, "trainer_config.yaml")) 80 | 81 | tc = TrainerController(env_path, run_id, save_freq, curriculum_file, fast_simulation, load_model, train_model, 82 | worker_id, keep_checkpoints, lesson, seed, docker_target_name, TRAINER_CONFIG_PATH, 83 | no_graphics) 84 | tc.start_learning() 85 | -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==1.7.1 2 | Pillow>=4.2.1 3 | matplotlib 4 | numpy>=1.11.0 5 | jupyter 6 | pytest>=3.2.2 7 | docopt 8 | pyyaml 9 | protobuf==3.5.2 10 | grpcio==1.11.0 11 | torch==0.4.0 12 | pandas 13 | scipy 14 | ipykernel 15 | -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, Command, find_packages 4 | 5 | 6 | with open('requirements.txt') as f: 7 | required = f.read().splitlines() 8 | 9 | setup(name='unityagents', 10 | version='0.4.0', 11 | description='Unity Machine Learning Agents', 12 | license='Apache License 2.0', 13 | author='Unity Technologies', 14 | author_email='ML-Agents@unity3d.com', 15 | url='https://github.com/Unity-Technologies/ml-agents', 16 | packages=find_packages(), 17 | install_requires = required, 18 | long_description= ("Unity Machine Learning Agents allows researchers and developers " 19 | "to transform games and simulations created using the Unity Editor into environments " 20 | "where intelligent agents can be trained using reinforcement learning, evolutionary " 21 | "strategies, or other machine learning methods through a simple to use Python API.") 22 | ) 23 | -------------------------------------------------------------------------------- /python/tests/__init__.py: -------------------------------------------------------------------------------- 1 | from unityagents import * 2 | from unitytrainers import * 3 | -------------------------------------------------------------------------------- /python/tests/mock_communicator.py: -------------------------------------------------------------------------------- 1 | 2 | from unityagents.communicator import Communicator 3 | from communicator_objects import UnityMessage,
UnityOutput, UnityInput,\ 4 | ResolutionProto, BrainParametersProto, UnityRLInitializationOutput,\ 5 | AgentInfoProto, UnityRLOutput 6 | 7 | 8 | class MockCommunicator(Communicator): 9 | def __init__(self, discrete_action=False, visual_inputs=0): 10 | """ 11 | Python side of the grpc communication. Python is the client and Unity the server 12 | 13 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 14 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 15 | """ 16 | self.is_discrete = discrete_action 17 | self.steps = 0 18 | self.visual_inputs = visual_inputs 19 | self.has_been_closed = False 20 | 21 | def initialize(self, inputs: UnityInput) -> UnityOutput: 22 | resolutions = [ResolutionProto( 23 | width=30, 24 | height=40, 25 | gray_scale=False) for i in range(self.visual_inputs)] 26 | bp = BrainParametersProto( 27 | vector_observation_size=3, 28 | num_stacked_vector_observations=2, 29 | vector_action_size=2, 30 | camera_resolutions=resolutions, 31 | vector_action_descriptions=["", ""], 32 | vector_action_space_type=int(not self.is_discrete), 33 | vector_observation_space_type=1, 34 | brain_name="RealFakeBrain", 35 | brain_type=2 36 | ) 37 | rl_init = UnityRLInitializationOutput( 38 | name="RealFakeAcademy", 39 | version="API-4", 40 | log_path="", 41 | brain_parameters=[bp] 42 | ) 43 | return UnityOutput( 44 | rl_initialization_output=rl_init 45 | ) 46 | 47 | def exchange(self, inputs: UnityInput) -> UnityOutput: 48 | dict_agent_info = {} 49 | if self.is_discrete: 50 | vector_action = [1] 51 | else: 52 | vector_action = [1, 2] 53 | list_agent_info = [] 54 | for i in range(3): 55 | list_agent_info.append( 56 | AgentInfoProto( 57 | stacked_vector_observation=[1, 2, 3, 1, 2, 3], 58 | reward=1, 59 | stored_vector_actions=vector_action, 60 | stored_text_actions="", 61 | text_observation="", 62 | memories=[], 63 | done=(i == 2), 64 | max_step_reached=False, 65 | id=i 66 | )) 67 | dict_agent_info["RealFakeBrain"] = \ 68 | UnityRLOutput.ListAgentInfoProto(value=list_agent_info) 69 | global_done = False 70 | try: 71 | global_done = (inputs.rl_input.agent_actions["RealFakeBrain"].value[0].vector_actions[0] == -1) 72 | except: 73 | pass 74 | result = UnityRLOutput( 75 | global_done=global_done, 76 | agentInfos=dict_agent_info 77 | ) 78 | return UnityOutput( 79 | rl_output=result 80 | ) 81 | 82 | def close(self): 83 | """ 84 | Sends a shutdown signal to the unity environment, and closes the grpc connection. 
85 | """ 86 | self.has_been_closed = True 87 | -------------------------------------------------------------------------------- /python/tests/test_bc.py: -------------------------------------------------------------------------------- 1 | import unittest.mock as mock 2 | import pytest 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from unitytrainers.bc.models import BehavioralCloningModel 8 | from unityagents import UnityEnvironment 9 | from .mock_communicator import MockCommunicator 10 | 11 | 12 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 13 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 14 | def test_cc_bc_model(mock_communicator, mock_launcher): 15 | tf.reset_default_graph() 16 | with tf.Session() as sess: 17 | with tf.variable_scope("FakeGraphScope"): 18 | mock_communicator.return_value = MockCommunicator( 19 | discrete_action=False, visual_inputs=0) 20 | env = UnityEnvironment(' ') 21 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 22 | init = tf.global_variables_initializer() 23 | sess.run(init) 24 | 25 | run_list = [model.sample_action, model.policy] 26 | feed_dict = {model.batch_size: 2, 27 | model.sequence_length: 1, 28 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 29 | [3, 4, 5, 3, 4, 5]])} 30 | sess.run(run_list, feed_dict=feed_dict) 31 | env.close() 32 | 33 | 34 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 35 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 36 | def test_dc_bc_model(mock_communicator, mock_launcher): 37 | tf.reset_default_graph() 38 | with tf.Session() as sess: 39 | with tf.variable_scope("FakeGraphScope"): 40 | mock_communicator.return_value = MockCommunicator( 41 | discrete_action=True, visual_inputs=0) 42 | env = UnityEnvironment(' ') 43 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 44 | init = tf.global_variables_initializer() 45 | sess.run(init) 46 | 47 | run_list = [model.sample_action, model.policy] 48 | feed_dict = {model.batch_size: 2, 49 | model.dropout_rate: 1.0, 50 | model.sequence_length: 1, 51 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 52 | [3, 4, 5, 3, 4, 5]])} 53 | sess.run(run_list, feed_dict=feed_dict) 54 | env.close() 55 | 56 | 57 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 58 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 59 | def test_visual_dc_bc_model(mock_communicator, mock_launcher): 60 | tf.reset_default_graph() 61 | with tf.Session() as sess: 62 | with tf.variable_scope("FakeGraphScope"): 63 | mock_communicator.return_value = MockCommunicator( 64 | discrete_action=True, visual_inputs=2) 65 | env = UnityEnvironment(' ') 66 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 67 | init = tf.global_variables_initializer() 68 | sess.run(init) 69 | 70 | run_list = [model.sample_action, model.policy] 71 | feed_dict = {model.batch_size: 2, 72 | model.dropout_rate: 1.0, 73 | model.sequence_length: 1, 74 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 75 | [3, 4, 5, 3, 4, 5]]), 76 | model.visual_in[0]: np.ones([2, 40, 30, 3]), 77 | model.visual_in[1]: np.ones([2, 40, 30, 3])} 78 | sess.run(run_list, feed_dict=feed_dict) 79 | env.close() 80 | 81 | 82 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 83 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 84 | def test_visual_cc_bc_model(mock_communicator, mock_launcher): 85 | tf.reset_default_graph() 86 | with tf.Session() as sess: 87 | with tf.variable_scope("FakeGraphScope"): 88 | 
mock_communicator.return_value = MockCommunicator( 89 | discrete_action=False, visual_inputs=2) 90 | env = UnityEnvironment(' ') 91 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 92 | init = tf.global_variables_initializer() 93 | sess.run(init) 94 | 95 | run_list = [model.sample_action, model.policy] 96 | feed_dict = {model.batch_size: 2, 97 | model.sequence_length: 1, 98 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 99 | [3, 4, 5, 3, 4, 5]]), 100 | model.visual_in[0]: np.ones([2, 40, 30, 3]), 101 | model.visual_in[1]: np.ones([2, 40, 30, 3])} 102 | sess.run(run_list, feed_dict=feed_dict) 103 | env.close() 104 | 105 | 106 | if __name__ == '__main__': 107 | pytest.main() 108 | -------------------------------------------------------------------------------- /python/tests/test_unityagents.py: -------------------------------------------------------------------------------- 1 | import json 2 | import unittest.mock as mock 3 | import pytest 4 | import struct 5 | 6 | import numpy as np 7 | 8 | from unityagents import UnityEnvironment, UnityEnvironmentException, UnityActionException, \ 9 | BrainInfo, Curriculum 10 | from .mock_communicator import MockCommunicator 11 | 12 | 13 | dummy_curriculum = json.loads('''{ 14 | "measure" : "reward", 15 | "thresholds" : [10, 20, 50], 16 | "min_lesson_length" : 3, 17 | "signal_smoothing" : true, 18 | "parameters" : 19 | { 20 | "param1" : [0.7, 0.5, 0.3, 0.1], 21 | "param2" : [100, 50, 20, 15], 22 | "param3" : [0.2, 0.3, 0.7, 0.9] 23 | } 24 | }''') 25 | bad_curriculum = json.loads('''{ 26 | "measure" : "reward", 27 | "thresholds" : [10, 20, 50], 28 | "min_lesson_length" : 3, 29 | "signal_smoothing" : false, 30 | "parameters" : 31 | { 32 | "param1" : [0.7, 0.5, 0.3, 0.1], 33 | "param2" : [100, 50, 20], 34 | "param3" : [0.2, 0.3, 0.7, 0.9] 35 | } 36 | }''') 37 | 38 | 39 | def test_handles_bad_filename(): 40 | with pytest.raises(UnityEnvironmentException): 41 | UnityEnvironment(' ') 42 | 43 | 44 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 45 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 46 | def test_initialization(mock_communicator, mock_launcher): 47 | mock_communicator.return_value = MockCommunicator( 48 | discrete_action=False, visual_inputs=0) 49 | env = UnityEnvironment(' ') 50 | with pytest.raises(UnityActionException): 51 | env.step([0]) 52 | assert env.brain_names[0] == 'RealFakeBrain' 53 | env.close() 54 | 55 | 56 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 57 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 58 | def test_reset(mock_communicator, mock_launcher): 59 | mock_communicator.return_value = MockCommunicator( 60 | discrete_action=False, visual_inputs=0) 61 | env = UnityEnvironment(' ') 62 | brain = env.brains['RealFakeBrain'] 63 | brain_info = env.reset() 64 | env.close() 65 | assert not env.global_done 66 | assert isinstance(brain_info, dict) 67 | assert isinstance(brain_info['RealFakeBrain'], BrainInfo) 68 | assert isinstance(brain_info['RealFakeBrain'].visual_observations, list) 69 | assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray) 70 | assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations 71 | assert brain_info['RealFakeBrain'].vector_observations.shape[0] == \ 72 | len(brain_info['RealFakeBrain'].agents) 73 | assert brain_info['RealFakeBrain'].vector_observations.shape[1] == \ 74 | brain.vector_observation_space_size * brain.num_stacked_vector_observations 75 | 76 | 77 | 
@mock.patch('unityagents.UnityEnvironment.executable_launcher') 78 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 79 | def test_step(mock_communicator, mock_launcher): 80 | mock_communicator.return_value = MockCommunicator( 81 | discrete_action=False, visual_inputs=0) 82 | env = UnityEnvironment(' ') 83 | brain = env.brains['RealFakeBrain'] 84 | brain_info = env.reset() 85 | brain_info = env.step([0] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents)) 86 | with pytest.raises(UnityActionException): 87 | env.step([0]) 88 | brain_info = env.step([-1] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents)) 89 | with pytest.raises(UnityActionException): 90 | env.step([0] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents)) 91 | env.close() 92 | assert env.global_done 93 | assert isinstance(brain_info, dict) 94 | assert isinstance(brain_info['RealFakeBrain'], BrainInfo) 95 | assert isinstance(brain_info['RealFakeBrain'].visual_observations, list) 96 | assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray) 97 | assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations 98 | assert brain_info['RealFakeBrain'].vector_observations.shape[0] == \ 99 | len(brain_info['RealFakeBrain'].agents) 100 | assert brain_info['RealFakeBrain'].vector_observations.shape[1] == \ 101 | brain.vector_observation_space_size * brain.num_stacked_vector_observations 102 | 103 | print("\n\n\n\n\n\n\n" + str(brain_info['RealFakeBrain'].local_done)) 104 | assert not brain_info['RealFakeBrain'].local_done[0] 105 | assert brain_info['RealFakeBrain'].local_done[2] 106 | 107 | 108 | @mock.patch('unityagents.UnityEnvironment.executable_launcher') 109 | @mock.patch('unityagents.UnityEnvironment.get_communicator') 110 | def test_close(mock_communicator, mock_launcher): 111 | comm = MockCommunicator( 112 | discrete_action=False, visual_inputs=0) 113 | mock_communicator.return_value = comm 114 | env = UnityEnvironment(' ') 115 | assert env._loaded 116 | env.close() 117 | assert not env._loaded 118 | assert comm.has_been_closed 119 | 120 | 121 | def test_curriculum(): 122 | open_name = '%s.open' % __name__ 123 | with mock.patch('json.load') as mock_load: 124 | with mock.patch(open_name, create=True) as mock_open: 125 | mock_open.return_value = 0 126 | mock_load.return_value = bad_curriculum 127 | with pytest.raises(UnityEnvironmentException): 128 | Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1, "param3": 1}) 129 | mock_load.return_value = dummy_curriculum 130 | with pytest.raises(UnityEnvironmentException): 131 | Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1}) 132 | curriculum = Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1, "param3": 1}) 133 | assert curriculum.get_lesson_number == 0 134 | curriculum.set_lesson_number(1) 135 | assert curriculum.get_lesson_number == 1 136 | curriculum.increment_lesson(10) 137 | assert curriculum.get_lesson_number == 1 138 | curriculum.increment_lesson(30) 139 | curriculum.increment_lesson(30) 140 | assert curriculum.get_lesson_number == 1 141 | assert curriculum.lesson_length == 3 142 | curriculum.increment_lesson(30) 143 | assert curriculum.get_config() == {'param1': 0.3, 'param2': 20, 'param3': 0.7} 144 | assert curriculum.get_config(0) == {"param1": 0.7, "param2": 100, "param3": 0.2} 145 | assert curriculum.lesson_length == 0 146 | assert curriculum.get_lesson_number == 2 147 | 148 | 149 | if 
__name__ == '__main__': 150 | pytest.main() 151 | -------------------------------------------------------------------------------- /python/trainer_config.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | trainer: ppo 3 | batch_size: 1024 4 | beta: 5.0e-3 5 | buffer_size: 10240 6 | epsilon: 0.2 7 | gamma: 0.99 8 | hidden_units: 128 9 | lambd: 0.95 10 | learning_rate: 3.0e-4 11 | max_steps: 5.0e4 12 | memory_size: 256 13 | normalize: false 14 | num_epoch: 3 15 | num_layers: 2 16 | time_horizon: 64 17 | sequence_length: 64 18 | summary_freq: 1000 19 | use_recurrent: false 20 | use_curiosity: false 21 | curiosity_strength: 0.01 22 | curiosity_enc_size: 128 23 | 24 | BananaBrain: 25 | normalize: false 26 | batch_size: 1024 27 | beta: 5.0e-3 28 | buffer_size: 10240 29 | 30 | BouncerBrain: 31 | normalize: true 32 | max_steps: 5.0e5 33 | num_layers: 2 34 | hidden_units: 64 35 | 36 | PushBlockBrain: 37 | max_steps: 5.0e4 38 | batch_size: 128 39 | buffer_size: 2048 40 | beta: 1.0e-2 41 | hidden_units: 256 42 | summary_freq: 2000 43 | time_horizon: 64 44 | num_layers: 2 45 | 46 | SmallWallBrain: 47 | max_steps: 2.0e5 48 | batch_size: 128 49 | buffer_size: 2048 50 | beta: 5.0e-3 51 | hidden_units: 256 52 | summary_freq: 2000 53 | time_horizon: 128 54 | num_layers: 2 55 | normalize: false 56 | 57 | BigWallBrain: 58 | max_steps: 2.0e5 59 | batch_size: 128 60 | buffer_size: 2048 61 | beta: 5.0e-3 62 | hidden_units: 256 63 | summary_freq: 2000 64 | time_horizon: 128 65 | num_layers: 2 66 | normalize: false 67 | 68 | StrikerBrain: 69 | max_steps: 1.0e5 70 | batch_size: 128 71 | buffer_size: 2048 72 | beta: 5.0e-3 73 | hidden_units: 256 74 | summary_freq: 2000 75 | time_horizon: 128 76 | num_layers: 2 77 | normalize: false 78 | 79 | GoalieBrain: 80 | max_steps: 1.0e5 81 | batch_size: 128 82 | buffer_size: 2048 83 | beta: 5.0e-3 84 | hidden_units: 256 85 | summary_freq: 2000 86 | time_horizon: 128 87 | num_layers: 2 88 | normalize: false 89 | 90 | PyramidBrain: 91 | use_curiosity: true 92 | summary_freq: 2000 93 | curiosity_strength: 0.01 94 | curiosity_enc_size: 256 95 | time_horizon: 128 96 | batch_size: 128 97 | buffer_size: 2048 98 | hidden_units: 512 99 | num_layers: 2 100 | beta: 1.0e-2 101 | max_steps: 2.0e5 102 | num_epoch: 3 103 | 104 | VisualPyramidBrain: 105 | use_curiosity: true 106 | time_horizon: 128 107 | batch_size: 32 108 | buffer_size: 1024 109 | hidden_units: 256 110 | num_layers: 2 111 | beta: 1.0e-2 112 | max_steps: 5.0e5 113 | num_epoch: 3 114 | 115 | Ball3DBrain: 116 | normalize: true 117 | batch_size: 64 118 | buffer_size: 12000 119 | summary_freq: 1000 120 | time_horizon: 1000 121 | lambd: 0.99 122 | gamma: 0.995 123 | beta: 0.001 124 | 125 | Ball3DHardBrain: 126 | normalize: true 127 | batch_size: 1200 128 | buffer_size: 12000 129 | summary_freq: 1000 130 | time_horizon: 1000 131 | gamma: 0.995 132 | beta: 0.001 133 | 134 | TennisBrain: 135 | normalize: true 136 | 137 | CrawlerBrain: 138 | normalize: true 139 | num_epoch: 3 140 | time_horizon: 1000 141 | batch_size: 2024 142 | buffer_size: 20240 143 | gamma: 0.995 144 | max_steps: 1e6 145 | summary_freq: 3000 146 | num_layers: 3 147 | hidden_units: 512 148 | 149 | WalkerBrain: 150 | normalize: true 151 | num_epoch: 3 152 | time_horizon: 1000 153 | batch_size: 2048 154 | buffer_size: 20480 155 | gamma: 0.995 156 | max_steps: 2e6 157 | summary_freq: 3000 158 | num_layers: 3 159 | hidden_units: 512 160 | 161 | ReacherBrain: 162 | normalize: true 163 | num_epoch: 3 164 | 
time_horizon: 1000 165 | batch_size: 2024 166 | buffer_size: 20240 167 | gamma: 0.995 168 | max_steps: 1e6 169 | summary_freq: 3000 170 | 171 | HallwayBrain: 172 | use_recurrent: true 173 | sequence_length: 64 174 | num_layers: 2 175 | hidden_units: 128 176 | memory_size: 256 177 | beta: 1.0e-2 178 | gamma: 0.99 179 | num_epoch: 3 180 | buffer_size: 1024 181 | batch_size: 128 182 | max_steps: 5.0e5 183 | summary_freq: 1000 184 | time_horizon: 64 185 | 186 | GridWorldBrain: 187 | batch_size: 32 188 | normalize: false 189 | num_layers: 1 190 | hidden_units: 256 191 | beta: 5.0e-3 192 | gamma: 0.9 193 | buffer_size: 256 194 | max_steps: 5.0e5 195 | summary_freq: 2000 196 | time_horizon: 5 197 | 198 | BasicBrain: 199 | batch_size: 32 200 | normalize: false 201 | num_layers: 1 202 | hidden_units: 20 203 | beta: 5.0e-3 204 | gamma: 0.9 205 | buffer_size: 256 206 | max_steps: 5.0e5 207 | summary_freq: 2000 208 | time_horizon: 3 209 | 210 | StudentBrain: 211 | trainer: imitation 212 | max_steps: 10000 213 | summary_freq: 1000 214 | brain_to_imitate: TeacherBrain 215 | batch_size: 16 216 | batches_per_epoch: 5 217 | num_layers: 4 218 | hidden_units: 64 219 | sequence_length: 16 220 | buffer_size: 128 221 | 222 | StudentRecurrentBrain: 223 | trainer: imitation 224 | max_steps: 10000 225 | summary_freq: 1000 226 | brain_to_imitate: TeacherBrain 227 | batch_size: 16 228 | batches_per_epoch: 5 229 | num_layers: 4 230 | hidden_units: 64 231 | use_recurrent: true 232 | sequence_length: 32 233 | buffer_size: 128 234 | -------------------------------------------------------------------------------- /python/unityagents/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import * 2 | from .brain import * 3 | from .exception import * 4 | from .curriculum import * 5 | -------------------------------------------------------------------------------- /python/unityagents/brain.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | 4 | class BrainInfo: 5 | def __init__(self, visual_observation, vector_observation, text_observations, memory=None, 6 | reward=None, agents=None, local_done=None, 7 | vector_action=None, text_action=None, max_reached=None): 8 | """ 9 | Describes experience at current step of all agents linked to a brain. 10 | """ 11 | self.visual_observations = visual_observation 12 | self.vector_observations = vector_observation 13 | self.text_observations = text_observations 14 | self.memories = memory 15 | self.rewards = reward 16 | self.local_done = local_done 17 | self.max_reached = max_reached 18 | self.agents = agents 19 | self.previous_vector_actions = vector_action 20 | self.previous_text_actions = text_action 21 | 22 | 23 | AllBrainInfo = Dict[str, BrainInfo] 24 | 25 | 26 | class BrainParameters: 27 | def __init__(self, brain_name, brain_param): 28 | """ 29 | Contains all brain-specific parameters. 30 | :param brain_name: Name of brain. 31 | :param brain_param: Dictionary of brain parameters. 
32 | """ 33 | self.brain_name = brain_name 34 | self.vector_observation_space_size = brain_param["vectorObservationSize"] 35 | self.num_stacked_vector_observations = brain_param["numStackedVectorObservations"] 36 | self.number_visual_observations = len(brain_param["cameraResolutions"]) 37 | self.camera_resolutions = brain_param["cameraResolutions"] 38 | self.vector_action_space_size = brain_param["vectorActionSize"] 39 | self.vector_action_descriptions = brain_param["vectorActionDescriptions"] 40 | self.vector_action_space_type = ["discrete", "continuous"][brain_param["vectorActionSpaceType"]] 41 | self.vector_observation_space_type = ["discrete", "continuous"][brain_param["vectorObservationSpaceType"]] 42 | 43 | def __str__(self): 44 | return '''Unity brain name: {0} 45 | Number of Visual Observations (per agent): {1} 46 | Vector Observation space type: {2} 47 | Vector Observation space size (per agent): {3} 48 | Number of stacked Vector Observation: {4} 49 | Vector Action space type: {5} 50 | Vector Action space size (per agent): {6} 51 | Vector Action descriptions: {7}'''.format(self.brain_name, 52 | str(self.number_visual_observations), 53 | self.vector_observation_space_type, 54 | str(self.vector_observation_space_size), 55 | str(self.num_stacked_vector_observations), 56 | self.vector_action_space_type, 57 | str(self.vector_action_space_size), 58 | ', '.join(self.vector_action_descriptions)) 59 | -------------------------------------------------------------------------------- /python/unityagents/communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from communicator_objects import UnityOutput, UnityInput 4 | 5 | logging.basicConfig(level=logging.INFO) 6 | logger = logging.getLogger("unityagents") 7 | 8 | 9 | class Communicator(object): 10 | def __init__(self, worker_id=0, 11 | base_port=5005): 12 | """ 13 | Python side of the communication. Must be used in pair with the right Unity Communicator equivalent. 14 | 15 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 16 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 17 | """ 18 | 19 | def initialize(self, inputs: UnityInput) -> UnityOutput: 20 | """ 21 | Used to exchange initialization parameters between Python and the Environment 22 | :param inputs: The initialization input that will be sent to the environment. 23 | :return: UnityOutput: The initialization output sent by Unity 24 | """ 25 | 26 | def exchange(self, inputs: UnityInput) -> UnityOutput: 27 | """ 28 | Used to send an input and receive an output from the Environment 29 | :param inputs: The UnityInput that needs to be sent the Environment 30 | :return: The UnityOutputs generated by the Environment 31 | """ 32 | 33 | def close(self): 34 | """ 35 | Sends a shutdown signal to the unity environment, and closes the connection. 36 | """ 37 | 38 | -------------------------------------------------------------------------------- /python/unityagents/curriculum.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from .exception import UnityEnvironmentException 4 | 5 | import logging 6 | 7 | logger = logging.getLogger("unityagents") 8 | 9 | 10 | class Curriculum(object): 11 | def __init__(self, location, default_reset_parameters): 12 | """ 13 | Initializes a Curriculum object. 14 | :param location: Path to JSON defining curriculum. 
15 | :param default_reset_parameters: Set of reset parameters for environment. 16 | """ 17 | self.lesson_length = 0 18 | self.max_lesson_number = 0 19 | self.measure_type = None 20 | if location is None: 21 | self.data = None 22 | else: 23 | try: 24 | with open(location) as data_file: 25 | self.data = json.load(data_file) 26 | except IOError: 27 | raise UnityEnvironmentException( 28 | "The file {0} could not be found.".format(location)) 29 | except UnicodeDecodeError: 30 | raise UnityEnvironmentException("There was an error decoding {}".format(location)) 31 | self.smoothing_value = 0 32 | for key in ['parameters', 'measure', 'thresholds', 33 | 'min_lesson_length', 'signal_smoothing']: 34 | if key not in self.data: 35 | raise UnityEnvironmentException("{0} does not contain a " 36 | "{1} field.".format(location, key)) 37 | parameters = self.data['parameters'] 38 | self.measure_type = self.data['measure'] 39 | self.max_lesson_number = len(self.data['thresholds']) 40 | for key in parameters: 41 | if key not in default_reset_parameters: 42 | raise UnityEnvironmentException( 43 | "The parameter {0} in Curriculum {1} is not present in " 44 | "the Environment".format(key, location)) 45 | for key in parameters: 46 | if len(parameters[key]) != self.max_lesson_number + 1: 47 | raise UnityEnvironmentException( 48 | "The parameter {0} in Curriculum {1} must have {2} values " 49 | "but {3} were found".format(key, location, 50 | self.max_lesson_number + 1, len(parameters[key]))) 51 | self.set_lesson_number(0) 52 | 53 | @property 54 | def measure(self): 55 | return self.measure_type 56 | 57 | @property 58 | def get_lesson_number(self): 59 | return self.lesson_number 60 | 61 | def set_lesson_number(self, value): 62 | self.lesson_length = 0 63 | self.lesson_number = max(0, min(value, self.max_lesson_number)) 64 | 65 | def increment_lesson(self, progress): 66 | """ 67 | Increments the lesson number depending on the progree given. 68 | :param progress: Measure of progress (either reward or percentage steps completed). 69 | """ 70 | if self.data is None or progress is None: 71 | return 72 | if self.data["signal_smoothing"]: 73 | progress = self.smoothing_value * 0.25 + 0.75 * progress 74 | self.smoothing_value = progress 75 | self.lesson_length += 1 76 | if self.lesson_number < self.max_lesson_number: 77 | if ((progress > self.data['thresholds'][self.lesson_number]) and 78 | (self.lesson_length > self.data['min_lesson_length'])): 79 | self.lesson_length = 0 80 | self.lesson_number += 1 81 | config = {} 82 | parameters = self.data["parameters"] 83 | for key in parameters: 84 | config[key] = parameters[key][self.lesson_number] 85 | logger.info("\nLesson changed. Now in Lesson {0} : \t{1}" 86 | .format(self.lesson_number, 87 | ', '.join([str(x) + ' -> ' + str(config[x]) for x in config]))) 88 | 89 | def get_config(self, lesson=None): 90 | """ 91 | Returns reset parameters which correspond to the lesson. 92 | :param lesson: The lesson you want to get the config of. If None, the current lesson is returned. 93 | :return: The configuration of the reset parameters. 
94 | """ 95 | if self.data is None: 96 | return {} 97 | if lesson is None: 98 | lesson = self.lesson_number 99 | lesson = max(0, min(lesson, self.max_lesson_number)) 100 | config = {} 101 | parameters = self.data["parameters"] 102 | for key in parameters: 103 | config[key] = parameters[key][lesson] 104 | return config 105 | -------------------------------------------------------------------------------- /python/unityagents/exception.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger("unityagents") 3 | 4 | class UnityException(Exception): 5 | """ 6 | Any error related to ml-agents environment. 7 | """ 8 | pass 9 | 10 | class UnityEnvironmentException(UnityException): 11 | """ 12 | Related to errors starting and closing environment. 13 | """ 14 | pass 15 | 16 | 17 | class UnityActionException(UnityException): 18 | """ 19 | Related to errors with sending actions. 20 | """ 21 | pass 22 | 23 | class UnityTimeOutException(UnityException): 24 | """ 25 | Related to errors with communication timeouts. 26 | """ 27 | def __init__(self, message, log_file_path = None): 28 | if log_file_path is not None: 29 | try: 30 | with open(log_file_path, "r") as f: 31 | printing = False 32 | unity_error = '\n' 33 | for l in f: 34 | l=l.strip() 35 | if (l == 'Exception') or (l=='Error'): 36 | printing = True 37 | unity_error += '----------------------\n' 38 | if (l == ''): 39 | printing = False 40 | if printing: 41 | unity_error += l + '\n' 42 | logger.info(unity_error) 43 | logger.error("An error might have occured in the environment. " 44 | "You can check the logfile for more information at {}".format(log_file_path)) 45 | except: 46 | logger.error("An error might have occured in the environment. " 47 | "No unity-environment.log file could be found.") 48 | super(UnityTimeOutException, self).__init__(message) 49 | 50 | -------------------------------------------------------------------------------- /python/unityagents/rpc_communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import grpc 3 | 4 | from multiprocessing import Pipe 5 | from concurrent.futures import ThreadPoolExecutor 6 | 7 | from .communicator import Communicator 8 | from communicator_objects import UnityToExternalServicer, add_UnityToExternalServicer_to_server 9 | from communicator_objects import UnityMessage, UnityInput, UnityOutput 10 | from .exception import UnityTimeOutException 11 | 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger("unityagents") 15 | 16 | 17 | class UnityToExternalServicerImplementation(UnityToExternalServicer): 18 | parent_conn, child_conn = Pipe() 19 | 20 | def Initialize(self, request, context): 21 | self.child_conn.send(request) 22 | return self.child_conn.recv() 23 | 24 | def Exchange(self, request, context): 25 | self.child_conn.send(request) 26 | return self.child_conn.recv() 27 | 28 | 29 | class RpcCommunicator(Communicator): 30 | def __init__(self, worker_id=0, 31 | base_port=5005): 32 | """ 33 | Python side of the grpc communication. Python is the server and Unity the client 34 | 35 | 36 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 37 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 
38 | """ 39 | self.port = base_port + worker_id 40 | self.worker_id = worker_id 41 | self.server = None 42 | self.unity_to_external = None 43 | self.is_open = False 44 | 45 | def initialize(self, inputs: UnityInput) -> UnityOutput: 46 | try: 47 | # Establish communication grpc 48 | self.server = grpc.server(ThreadPoolExecutor(max_workers=10)) 49 | self.unity_to_external = UnityToExternalServicerImplementation() 50 | add_UnityToExternalServicer_to_server(self.unity_to_external, self.server) 51 | self.server.add_insecure_port('[::]:'+str(self.port)) 52 | self.server.start() 53 | except : 54 | raise UnityTimeOutException( 55 | "Couldn't start socket communication because worker number {} is still in use. " 56 | "You may need to manually close a previously opened environment " 57 | "or use a different worker number.".format(str(self.worker_id))) 58 | if not self.unity_to_external.parent_conn.poll(30): 59 | raise UnityTimeOutException( 60 | "The Unity environment took too long to respond. Make sure that :\n" 61 | "\t The environment does not need user interaction to launch\n" 62 | "\t The Academy and the External Brain(s) are attached to objects in the Scene\n" 63 | "\t The environment and the Python interface have compatible versions.") 64 | aca_param = self.unity_to_external.parent_conn.recv().unity_output 65 | self.is_open = True 66 | message = UnityMessage() 67 | message.header.status = 200 68 | message.unity_input.CopyFrom(inputs) 69 | self.unity_to_external.parent_conn.send(message) 70 | self.unity_to_external.parent_conn.recv() 71 | return aca_param 72 | 73 | def exchange(self, inputs: UnityInput) -> UnityOutput: 74 | message = UnityMessage() 75 | message.header.status = 200 76 | message.unity_input.CopyFrom(inputs) 77 | self.unity_to_external.parent_conn.send(message) 78 | output = self.unity_to_external.parent_conn.recv() 79 | if output.header.status != 200: 80 | return None 81 | return output.unity_output 82 | 83 | def close(self): 84 | """ 85 | Sends a shutdown signal to the unity environment, and closes the grpc connection. 86 | """ 87 | if self.is_open: 88 | message_input = UnityMessage() 89 | message_input.header.status = 400 90 | self.unity_to_external.parent_conn.send(message_input) 91 | self.unity_to_external.parent_conn.close() 92 | self.server.stop(False) 93 | self.is_open = False 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /python/unityagents/socket_communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import socket 3 | import struct 4 | 5 | from .communicator import Communicator 6 | from communicator_objects import UnityMessage, UnityOutput, UnityInput 7 | from .exception import UnityTimeOutException 8 | 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | logger = logging.getLogger("unityagents") 12 | 13 | 14 | class SocketCommunicator(Communicator): 15 | def __init__(self, worker_id=0, 16 | base_port=5005): 17 | """ 18 | Python side of the socket communication 19 | 20 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 21 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 
22 | """ 23 | 24 | self.port = base_port + worker_id 25 | self._buffer_size = 12000 26 | self.worker_id = worker_id 27 | self._socket = None 28 | self._conn = None 29 | 30 | def initialize(self, inputs: UnityInput) -> UnityOutput: 31 | try: 32 | # Establish communication socket 33 | self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 34 | self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 35 | self._socket.bind(("localhost", self.port)) 36 | except: 37 | raise UnityTimeOutException("Couldn't start socket communication because worker number {} is still in use. " 38 | "You may need to manually close a previously opened environment " 39 | "or use a different worker number.".format(str(self.worker_id))) 40 | try: 41 | self._socket.settimeout(30) 42 | self._socket.listen(1) 43 | self._conn, _ = self._socket.accept() 44 | self._conn.settimeout(30) 45 | except : 46 | raise UnityTimeOutException( 47 | "The Unity environment took too long to respond. Make sure that :\n" 48 | "\t The environment does not need user interaction to launch\n" 49 | "\t The Academy and the External Brain(s) are attached to objects in the Scene\n" 50 | "\t The environment and the Python interface have compatible versions.") 51 | message = UnityMessage() 52 | message.header.status = 200 53 | message.unity_input.CopyFrom(inputs) 54 | self._communicator_send(message.SerializeToString()) 55 | initialization_output = UnityMessage() 56 | initialization_output.ParseFromString(self._communicator_receive()) 57 | return initialization_output.unity_output 58 | 59 | def _communicator_receive(self): 60 | try: 61 | s = self._conn.recv(self._buffer_size) 62 | message_length = struct.unpack("I", bytearray(s[:4]))[0] 63 | s = s[4:] 64 | while len(s) != message_length: 65 | s += self._conn.recv(self._buffer_size) 66 | except socket.timeout as e: 67 | raise UnityTimeOutException("The environment took too long to respond.") 68 | return s 69 | 70 | def _communicator_send(self, message): 71 | self._conn.send(struct.pack("I", len(message)) + message) 72 | 73 | def exchange(self, inputs: UnityInput) -> UnityOutput: 74 | message = UnityMessage() 75 | message.header.status = 200 76 | message.unity_input.CopyFrom(inputs) 77 | self._communicator_send(message.SerializeToString()) 78 | outputs = UnityMessage() 79 | outputs.ParseFromString(self._communicator_receive()) 80 | if outputs.header.status != 200: 81 | return None 82 | return outputs.unity_output 83 | 84 | def close(self): 85 | """ 86 | Sends a shutdown signal to the unity environment, and closes the socket connection. 
87 | """ 88 | if self._socket is not None and self._conn is not None: 89 | message_input = UnityMessage() 90 | message_input.header.status = 400 91 | self._communicator_send(message_input.SerializeToString()) 92 | if self._socket is not None: 93 | self._socket.close() 94 | self._socket = None 95 | if self._socket is not None: 96 | self._conn.close() 97 | self._conn = None 98 | 99 | -------------------------------------------------------------------------------- /python/unitytrainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .buffer import * 2 | from .models import * 3 | from .trainer_controller import * 4 | from .bc.models import * 5 | from .bc.trainer import * 6 | from .ppo.models import * 7 | from .ppo.trainer import * 8 | -------------------------------------------------------------------------------- /python/unitytrainers/bc/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | from .trainer import * 3 | -------------------------------------------------------------------------------- /python/unitytrainers/bc/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as c_layers 3 | from unitytrainers.models import LearningModel 4 | 5 | 6 | class BehavioralCloningModel(LearningModel): 7 | def __init__(self, brain, h_size=128, lr=1e-4, n_layers=2, m_size=128, 8 | normalize=False, use_recurrent=False): 9 | LearningModel.__init__(self, m_size, normalize, use_recurrent, brain) 10 | 11 | num_streams = 1 12 | hidden_streams = self.create_observation_streams(num_streams, h_size, n_layers) 13 | hidden = hidden_streams[0] 14 | self.dropout_rate = tf.placeholder(dtype=tf.float32, shape=[], name="dropout_rate") 15 | hidden_reg = tf.layers.dropout(hidden, self.dropout_rate) 16 | if self.use_recurrent: 17 | tf.Variable(self.m_size, name="memory_size", trainable=False, dtype=tf.int32) 18 | self.memory_in = tf.placeholder(shape=[None, self.m_size], dtype=tf.float32, name='recurrent_in') 19 | hidden_reg, self.memory_out = self.create_recurrent_encoder(hidden_reg, self.memory_in, 20 | self.sequence_length) 21 | self.memory_out = tf.identity(self.memory_out, name='recurrent_out') 22 | self.policy = tf.layers.dense(hidden_reg, self.a_size, activation=None, use_bias=False, name='pre_action', 23 | kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01)) 24 | 25 | if brain.vector_action_space_type == "discrete": 26 | self.action_probs = tf.nn.softmax(self.policy) 27 | self.sample_action_float = tf.multinomial(self.policy, 1) 28 | self.sample_action_float = tf.identity(self.sample_action_float, name="action") 29 | self.sample_action = tf.cast(self.sample_action_float, tf.int32) 30 | self.true_action = tf.placeholder(shape=[None], dtype=tf.int32, name="teacher_action") 31 | self.action_oh = tf.one_hot(self.true_action, self.a_size) 32 | self.loss = tf.reduce_sum(-tf.log(self.action_probs + 1e-10) * self.action_oh) 33 | self.action_percent = tf.reduce_mean(tf.cast( 34 | tf.equal(tf.cast(tf.argmax(self.action_probs, axis=1), tf.int32), self.sample_action), tf.float32)) 35 | else: 36 | self.clipped_sample_action = tf.clip_by_value(self.policy, -1, 1) 37 | self.sample_action = tf.identity(self.clipped_sample_action, name="action") 38 | self.true_action = tf.placeholder(shape=[None, self.a_size], dtype=tf.float32, name="teacher_action") 39 | self.clipped_true_action = 
tf.clip_by_value(self.true_action, -1, 1) 40 | self.loss = tf.reduce_sum(tf.squared_difference(self.clipped_true_action, self.sample_action)) 41 | 42 | optimizer = tf.train.AdamOptimizer(learning_rate=lr) 43 | self.update = optimizer.minimize(self.loss) 44 | -------------------------------------------------------------------------------- /python/unitytrainers/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | from .trainer import * 3 | -------------------------------------------------------------------------------- /python/unitytrainers/trainer.py: -------------------------------------------------------------------------------- 1 | # # Unity ML-Agents Toolkit 2 | import logging 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | 7 | from unityagents import UnityException, AllBrainInfo 8 | 9 | logger = logging.getLogger("unityagents") 10 | 11 | 12 | class UnityTrainerException(UnityException): 13 | """ 14 | Related to errors with the Trainer. 15 | """ 16 | pass 17 | 18 | 19 | class Trainer(object): 20 | """This class is the abstract class for the unitytrainers""" 21 | 22 | def __init__(self, sess, env, brain_name, trainer_parameters, training): 23 | """ 24 | Responsible for collecting experiences and training a neural network model. 25 | :param sess: Tensorflow session. 26 | :param env: The UnityEnvironment. 27 | :param trainer_parameters: The parameters for the trainer (dictionary). 28 | :param training: Whether the trainer is set for training. 29 | """ 30 | self.brain_name = brain_name 31 | self.brain = env.brains[self.brain_name] 32 | self.trainer_parameters = trainer_parameters 33 | self.is_training = training 34 | self.sess = sess 35 | self.stats = {} 36 | self.summary_writer = None 37 | 38 | def __str__(self): 39 | return '''Empty Trainer''' 40 | 41 | @property 42 | def parameters(self): 43 | """ 44 | Returns the trainer parameters of the trainer. 45 | """ 46 | raise UnityTrainerException("The parameters property was not implemented.") 47 | 48 | @property 49 | def graph_scope(self): 50 | """ 51 | Returns the graph scope of the trainer. 52 | """ 53 | raise UnityTrainerException("The graph_scope property was not implemented.") 54 | 55 | @property 56 | def get_max_steps(self): 57 | """ 58 | Returns the maximum number of steps. Is used to know when the trainer should be stopped. 59 | :return: The maximum number of steps of the trainer 60 | """ 61 | raise UnityTrainerException("The get_max_steps property was not implemented.") 62 | 63 | @property 64 | def get_step(self): 65 | """ 66 | Returns the number of steps the trainer has performed 67 | :return: the step count of the trainer 68 | """ 69 | raise UnityTrainerException("The get_step property was not implemented.") 70 | 71 | @property 72 | def get_last_reward(self): 73 | """ 74 | Returns the last reward the trainer has had 75 | :return: the new last reward 76 | """ 77 | raise UnityTrainerException("The get_last_reward property was not implemented.") 78 | 79 | def increment_step_and_update_last_reward(self): 80 | """ 81 | Increment the step count of the trainer and updates the last reward 82 | """ 83 | raise UnityTrainerException("The increment_step_and_update_last_reward method was not implemented.") 84 | 85 | def take_action(self, all_brain_info: AllBrainInfo): 86 | """ 87 | Decides actions given state/observation information, and takes them in environment. 88 | :param all_brain_info: A dictionary of brain names and BrainInfo from environment. 
89 | :return: a tuple containing action, memories, values and an object 90 | to be passed to add experiences 91 | """ 92 | raise UnityTrainerException("The take_action method was not implemented.") 93 | 94 | def add_experiences(self, curr_info: AllBrainInfo, next_info: AllBrainInfo, take_action_outputs): 95 | """ 96 | Adds experiences to each agent's experience history. 97 | :param curr_info: Current AllBrainInfo. 98 | :param next_info: Next AllBrainInfo. 99 | :param take_action_outputs: The outputs of the take action method. 100 | """ 101 | raise UnityTrainerException("The add_experiences method was not implemented.") 102 | 103 | def process_experiences(self, current_info: AllBrainInfo, next_info: AllBrainInfo): 104 | """ 105 | Checks agent histories for processing condition, and processes them as necessary. 106 | Processing involves calculating value and advantage targets for model updating step. 107 | :param current_info: Dictionary of all current-step brains and corresponding BrainInfo. 108 | :param next_info: Dictionary of all next-step brains and corresponding BrainInfo. 109 | """ 110 | raise UnityTrainerException("The process_experiences method was not implemented.") 111 | 112 | def end_episode(self): 113 | """ 114 | A signal that the Episode has ended. The buffer must be reset. 115 | Get only called when the academy resets. 116 | """ 117 | raise UnityTrainerException("The end_episode method was not implemented.") 118 | 119 | def is_ready_update(self): 120 | """ 121 | Returns whether or not the trainer has enough elements to run update model 122 | :return: A boolean corresponding to wether or not update_model() can be run 123 | """ 124 | raise UnityTrainerException("The is_ready_update method was not implemented.") 125 | 126 | def update_model(self): 127 | """ 128 | Uses training_buffer to update model. 129 | """ 130 | raise UnityTrainerException("The update_model method was not implemented.") 131 | 132 | def write_summary(self, lesson_number): 133 | """ 134 | Saves training statistics to Tensorboard. 135 | :param lesson_number: The lesson the trainer is at. 136 | """ 137 | if (self.get_step % self.trainer_parameters['summary_freq'] == 0 and self.get_step != 0 and 138 | self.is_training and self.get_step <= self.get_max_steps): 139 | if len(self.stats['cumulative_reward']) > 0: 140 | mean_reward = np.mean(self.stats['cumulative_reward']) 141 | logger.info(" {}: Step: {}. Mean Reward: {:0.3f}. Std of Reward: {:0.3f}." 142 | .format(self.brain_name, self.get_step, 143 | mean_reward, np.std(self.stats['cumulative_reward']))) 144 | else: 145 | logger.info(" {}: Step: {}. No episode was completed since last summary." 146 | .format(self.brain_name, self.get_step)) 147 | summary = tf.Summary() 148 | for key in self.stats: 149 | if len(self.stats[key]) > 0: 150 | stat_mean = float(np.mean(self.stats[key])) 151 | summary.value.add(tag='Info/{}'.format(key), simple_value=stat_mean) 152 | self.stats[key] = [] 153 | summary.value.add(tag='Info/Lesson', simple_value=lesson_number) 154 | self.summary_writer.add_summary(summary, self.get_step) 155 | self.summary_writer.flush() 156 | 157 | def write_tensorboard_text(self, key, input_dict): 158 | """ 159 | Saves text to Tensorboard. 160 | Note: Only works on tensorflow r1.2 or above. 161 | :param key: The name of the text. 162 | :param input_dict: A dictionary that will be displayed in a table on Tensorboard. 
163 | """ 164 | try: 165 | s_op = tf.summary.text(key, tf.convert_to_tensor(([[str(x), str(input_dict[x])] for x in input_dict]))) 166 | s = self.sess.run(s_op) 167 | self.summary_writer.add_summary(s, self.get_step) 168 | except: 169 | logger.info("Cannot write text summary for Tensorboard. Tensorflow version must be r1.2 or above.") 170 | pass 171 | -------------------------------------------------------------------------------- /reinforce/README.md: -------------------------------------------------------------------------------- 1 | [//]: # (Image References) 2 | 3 | [image1]: https://user-images.githubusercontent.com/10624937/42135683-dde5c6f0-7d13-11e8-90b1-8770df3e40cf.gif "Trained Agent" 4 | 5 | # REINFORCE 6 | 7 | ### Instructions 8 | 9 | Open `REINFORCE.ipynb` to see an implementation of REINFORCE (also known as Monte Carlo Policy Gradients) with OpenAI Gym's Cartpole environment. 10 | 11 | Try to change the parameters in the notebook, to see if you can get the agent to train faster! 12 | 13 | ### Results 14 | 15 | ![Trained Agent][image1] 16 | -------------------------------------------------------------------------------- /temporal-difference/README.md: -------------------------------------------------------------------------------- 1 | # Temporal-Difference Methods 2 | 3 | ### Instructions 4 | 5 | Follow the instructions in `Temporal_Difference.ipynb` to write your own implementations of many temporal-difference methods! The corresponding solutions can be found in `Temporal_Difference_Solution.ipynb`. 6 | -------------------------------------------------------------------------------- /temporal-difference/check_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from IPython.display import Markdown, display 3 | import numpy as np 4 | 5 | def printmd(string): 6 | display(Markdown(string)) 7 | 8 | V_opt = np.zeros((4,12)) 9 | V_opt[0:13][0] = -np.arange(3, 15)[::-1] 10 | V_opt[0:13][1] = -np.arange(3, 15)[::-1] + 1 11 | V_opt[0:13][2] = -np.arange(3, 15)[::-1] + 2 12 | V_opt[3][0] = -13 13 | 14 | pol_opt = np.hstack((np.ones(11), 2, 0)) 15 | 16 | V_true = np.zeros((4,12)) 17 | for i in range(3): 18 | V_true[0:13][i] = -np.arange(3, 15)[::-1] - i 19 | V_true[1][11] = -2 20 | V_true[2][11] = -1 21 | V_true[3][0] = -17 22 | 23 | def get_long_path(V): 24 | return np.array(np.hstack((V[0:13][0], V[1][0], V[1][11], V[2][0], V[2][11], V[3][0], V[3][11]))) 25 | 26 | def get_optimal_path(policy): 27 | return np.array(np.hstack((policy[2][:], policy[3][0]))) 28 | 29 | class Tests(unittest.TestCase): 30 | 31 | def td_prediction_check(self, V): 32 | to_check = get_long_path(V) 33 | soln = get_long_path(V_true) 34 | np.testing.assert_array_almost_equal(soln, to_check) 35 | 36 | def td_control_check(self, policy): 37 | to_check = get_optimal_path(policy) 38 | np.testing.assert_equal(pol_opt, to_check) 39 | 40 | check = Tests() 41 | 42 | def run_check(check_name, func): 43 | try: 44 | getattr(check, check_name)(func) 45 | except check.failureException as e: 46 | printmd('**PLEASE TRY AGAIN**') 47 | return 48 | printmd('**PASSED**') -------------------------------------------------------------------------------- /temporal-difference/plot_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | sns.set_style("white") 5 | 6 | def plot_values(V): 7 | # reshape the state-value function 8 | V = np.reshape(V, (4,12)) 9 
| # plot the state-value function 10 | fig = plt.figure(figsize=(15,5)) 11 | ax = fig.add_subplot(111) 12 | im = ax.imshow(V, cmap='cool') 13 | for (j,i),label in np.ndenumerate(V): 14 | ax.text(i, j, np.round(label,3), ha='center', va='center', fontsize=14) 15 | plt.tick_params(bottom='off', left='off', labelbottom='off', labelleft='off') 16 | plt.title('State-Value Function') 17 | plt.show() -------------------------------------------------------------------------------- /tile-coding/README.md: -------------------------------------------------------------------------------- 1 | # Tile Coding 2 | 3 | ### Instructions 4 | 5 | Follow the instructions in `Tile_Coding.ipynb` to learn how to discretize continuous state spaces, to use tabular solution methods to solve complex tasks. The corresponding solutions can be found in `Tile_Coding_Solution.ipynb`. 6 | --------------------------------------------------------------------------------
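For reference, here is a minimal tile-coding sketch. It is not part of the repository, and the helper names `create_tiling_grid` and `tile_encode` are illustrative rather than taken from `Tile_Coding.ipynb`: each tiling is a regular grid shifted by a small offset, and a continuous sample is encoded as one discrete cell index per tiling.

```python
# Minimal tile-coding sketch (illustrative only; helper names are hypothetical).
import numpy as np

def create_tiling_grid(low, high, bins=(10, 10), offsets=(0.0, 0.0)):
    """Internal split points of one tiling over a 2-D continuous space."""
    return [np.linspace(low[d], high[d], bins[d] + 1)[1:-1] + offsets[d]
            for d in range(len(bins))]

def tile_encode(sample, tilings):
    """Map a continuous sample to one (row, col) cell index per tiling."""
    return [tuple(int(np.digitize(s, split_points))
                  for s, split_points in zip(sample, grid))
            for grid in tilings]

low, high = [-1.0, -5.0], [1.0, 5.0]
tilings = [create_tiling_grid(low, high, bins=(10, 10), offsets=(off, off * 5.0))
           for off in (-0.066, 0.0, 0.066)]
print(tile_encode([0.25, -1.91], tilings))  # one (row, col) index per tiling
```

Because the tilings are offset from one another, nearby states share most but not all of their active tiles, which is what allows a table-based method to generalize across the continuous state space.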