├── .github
│   └── workflows
│       └── manual.yml
├── .gitignore
├── CODEOWNERS
├── LICENSE
├── README.md
├── cheatsheet
│   ├── LICENSE.txt
│   ├── README.md
│   ├── cheatsheet.pdf
│   ├── cheatsheet.tex
│   └── udacity-logo.png
├── cross-entropy
│   ├── CEM.ipynb
│   ├── README.md
│   └── checkpoint.pth
├── ddpg-bipedal
│   ├── DDPG.ipynb
│   ├── README.md
│   ├── ddpg_agent.py
│   └── model.py
├── ddpg-pendulum
│   ├── DDPG.ipynb
│   ├── README.md
│   ├── checkpoint_actor.pth
│   ├── checkpoint_critic.pth
│   ├── ddpg_agent.py
│   └── model.py
├── discretization
│   ├── Discretization.ipynb
│   ├── Discretization_Solution.ipynb
│   └── README.md
├── dqn
│   ├── README.md
│   ├── exercise
│   │   ├── Deep_Q_Network.ipynb
│   │   ├── dqn_agent.py
│   │   └── model.py
│   └── solution
│       ├── Deep_Q_Network_Solution.ipynb
│       ├── checkpoint.pth
│       ├── dqn_agent.py
│       └── model.py
├── dynamic-programming
│   ├── Dynamic_Programming.ipynb
│   ├── Dynamic_Programming_Solution.ipynb
│   ├── README.md
│   ├── check_test.py
│   ├── frozenlake.py
│   └── plot_utils.py
├── finance
│   ├── DRL.ipynb
│   ├── ddpg_agent.py
│   ├── model.py
│   ├── syntheticChrissAlmgren.py
│   ├── text_images
│   │   ├── 4.jpeg
│   │   ├── Actor-Critic.png
│   │   ├── RL.png
│   │   ├── nvidia.png
│   │   └── udacity.png
│   └── utils.py
├── hill-climbing
│   ├── Hill_Climbing.ipynb
│   └── README.md
├── lab-taxi
│   ├── README.md
│   ├── agent.py
│   ├── main.py
│   └── monitor.py
├── monte-carlo
│   ├── Monte_Carlo.ipynb
│   ├── Monte_Carlo_Solution.ipynb
│   ├── README.md
│   ├── images
│   │   └── optimal.png
│   └── plot_utils.py
├── p1_navigation
│   ├── Navigation.ipynb
│   ├── Navigation_Pixels.ipynb
│   └── README.md
├── p2_continuous-control
│   ├── Continuous_Control.ipynb
│   ├── Crawler.ipynb
│   └── README.md
├── p3_collab-compet
│   ├── README.md
│   ├── Soccer.ipynb
│   └── Tennis.ipynb
├── python
│   ├── Basics.ipynb
│   ├── README.md
│   ├── communicator_objects
│   │   ├── __init__.py
│   │   ├── agent_action_proto_pb2.py
│   │   ├── agent_info_proto_pb2.py
│   │   ├── brain_parameters_proto_pb2.py
│   │   ├── brain_type_proto_pb2.py
│   │   ├── command_proto_pb2.py
│   │   ├── engine_configuration_proto_pb2.py
│   │   ├── environment_parameters_proto_pb2.py
│   │   ├── header_pb2.py
│   │   ├── resolution_proto_pb2.py
│   │   ├── space_type_proto_pb2.py
│   │   ├── unity_input_pb2.py
│   │   ├── unity_message_pb2.py
│   │   ├── unity_output_pb2.py
│   │   ├── unity_rl_initialization_input_pb2.py
│   │   ├── unity_rl_initialization_output_pb2.py
│   │   ├── unity_rl_input_pb2.py
│   │   ├── unity_rl_output_pb2.py
│   │   ├── unity_to_external_pb2.py
│   │   └── unity_to_external_pb2_grpc.py
│   ├── curricula
│   │   ├── push.json
│   │   ├── test.json
│   │   └── wall.json
│   ├── learn.py
│   ├── requirements.txt
│   ├── setup.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── mock_communicator.py
│   │   ├── test_bc.py
│   │   ├── test_ppo.py
│   │   ├── test_unityagents.py
│   │   └── test_unitytrainers.py
│   ├── trainer_config.yaml
│   ├── unityagents
│   │   ├── __init__.py
│   │   ├── brain.py
│   │   ├── communicator.py
│   │   ├── curriculum.py
│   │   ├── environment.py
│   │   ├── exception.py
│   │   ├── rpc_communicator.py
│   │   └── socket_communicator.py
│   └── unitytrainers
│       ├── __init__.py
│       ├── bc
│       │   ├── __init__.py
│       │   ├── models.py
│       │   └── trainer.py
│       ├── buffer.py
│       ├── models.py
│       ├── ppo
│       │   ├── __init__.py
│       │   ├── models.py
│       │   └── trainer.py
│       ├── trainer.py
│       └── trainer_controller.py
├── reinforce
│   ├── README.md
│   └── REINFORCE.ipynb
├── temporal-difference
│   ├── README.md
│   ├── Temporal_Difference.ipynb
│   ├── Temporal_Difference_Solution.ipynb
│   ├── check_test.py
│   └── plot_utils.py
└── tile-coding
    ├── README.md
    ├── Tile_Coding.ipynb
    └── Tile_Coding_Solution.ipynb
/.github/workflows/manual.yml:
--------------------------------------------------------------------------------
1 | # Workflow to ensure whenever a Github PR is submitted,
2 | # a JIRA ticket gets created automatically.
3 | name: Manual Workflow
4 |
5 | # Controls when the action will run.
6 | on:
7 | # Triggers the workflow when a pull request is opened or reopened
8 | pull_request_target:
9 | types: [opened, reopened]
10 |
11 | # Allows you to run this workflow manually from the Actions tab
12 | workflow_dispatch:
13 |
14 | jobs:
15 | test-transition-issue:
16 | name: Convert Github Issue to Jira Issue
17 | runs-on: ubuntu-latest
18 | steps:
19 | - name: Checkout
20 | uses: actions/checkout@master
21 |
22 | - name: Login
23 | uses: atlassian/gajira-login@master
24 | env:
25 | JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
26 | JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
27 | JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
28 |
29 | - name: Create NEW JIRA ticket
30 | id: create
31 | uses: atlassian/gajira-create@master
32 | with:
33 | project: CONUPDATE
34 | issuetype: Task
35 | summary: |
36 | Github PR [Assign the ND component] | Repo: ${{ github.repository }} | PR# ${{github.event.number}}
37 | description: |
38 | Repo link: https://github.com/${{ github.repository }}
39 | PR no. ${{ github.event.pull_request.number }}
40 | PR title: ${{ github.event.pull_request.title }}
41 | PR description: ${{ github.event.pull_request.body }}
42 | In addition, please resolve other issues, if any.
43 | fields: '{"components": [{"name":"Github PR"}], "customfield_16449":"https://classroom.udacity.com/", "customfield_16450":"Resolve the PR", "labels": ["github"], "priority":{"id": "4"}}'
44 |
45 | - name: Log created issue
46 | run: echo "Issue ${{ steps.create.outputs.issue }} was created"
47 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Mac OS
2 | .DS_Store
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | env/
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # dotenv
86 | .env
87 |
88 | # virtualenv
89 | .venv
90 | venv/
91 | ENV/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @udacity/active-public-content
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Udacity
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/cheatsheet/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 Udacity, Inc.
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 |
--------------------------------------------------------------------------------
/cheatsheet/README.md:
--------------------------------------------------------------------------------
1 | # Cheatsheet
2 |
3 | You are encouraged to use the [PDF file](https://github.com/udacity/deep-reinforcement-learning/tree/master/cheatsheet/cheatsheet.pdf) in the repository to guide your study of RL.
--------------------------------------------------------------------------------
/cheatsheet/cheatsheet.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/cheatsheet/cheatsheet.pdf
--------------------------------------------------------------------------------
/cheatsheet/udacity-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/cheatsheet/udacity-logo.png
--------------------------------------------------------------------------------
/cross-entropy/README.md:
--------------------------------------------------------------------------------
1 | [//]: # (Image References)
2 |
3 | [image1]: https://user-images.githubusercontent.com/10624937/42135605-ba0e5f2c-7d12-11e8-9578-86d74e0976f8.gif "Trained Agent"
4 |
5 | # Cross-Entropy Method
6 |
7 | ### Instructions
8 |
9 | Open `CEM.ipynb` to see an implementation of the cross-entropy method with OpenAI Gym's MountainCarContinuous environment.
10 |
11 | Try changing the parameters in the notebook to see if you can get the agent to train faster!
12 |
13 | ### Results
14 |
15 | ![Trained Agent][image1]
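16 |
17 | For quick reference, the core of the cross-entropy method fits in a few lines. The sketch below is not the notebook's exact code: `evaluate(weights)` is a hypothetical helper that runs one episode with the given policy weights and returns the episodic return, and the hyperparameters are illustrative.
18 |
19 | ```python
20 | import numpy as np
21 |
22 | def cem(evaluate, dim, n_iterations=100, pop_size=50, elite_frac=0.2, sigma=0.5):
23 |     # hypothetical evaluate(weights) runs one episode and returns its score
24 |     n_elite = int(pop_size * elite_frac)
25 |     best_weight = sigma * np.random.randn(dim)
26 |     for _ in range(n_iterations):
27 |         # sample a population of candidate weights around the current estimate
28 |         population = [best_weight + sigma * np.random.randn(dim) for _ in range(pop_size)]
29 |         rewards = np.array([evaluate(w) for w in population])
30 |         # keep the elite candidates and average them to form the next estimate
31 |         elite_idxs = rewards.argsort()[-n_elite:]
32 |         best_weight = np.array([population[i] for i in elite_idxs]).mean(axis=0)
33 |     return best_weight
34 | ```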
--------------------------------------------------------------------------------
/cross-entropy/checkpoint.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/cross-entropy/checkpoint.pth
--------------------------------------------------------------------------------
/ddpg-bipedal/README.md:
--------------------------------------------------------------------------------
1 | [//]: # (Image References)
2 |
3 | [image1]: https://user-images.githubusercontent.com/10624937/42135608-be87357e-7d12-11e8-8eca-e6d5fabdba6b.gif "Trained Agent"
4 |
5 |
6 | # Actor-Critic Methods
7 |
8 | ### Instructions
9 |
10 | Open `DDPG.ipynb` to see an implementation of DDPG with OpenAI Gym's BipedalWalker environment.
11 |
12 | ### Results
13 |
14 | ![Trained Agent][image1]
15 |
--------------------------------------------------------------------------------
/ddpg-bipedal/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | def hidden_init(layer):
8 | fan_in = layer.weight.data.size()[0]
9 | lim = 1. / np.sqrt(fan_in)
10 | return (-lim, lim)
11 |
12 | class Actor(nn.Module):
13 | """Actor (Policy) Model."""
14 |
15 | def __init__(self, state_size, action_size, seed, fc_units=256):
16 | """Initialize parameters and build model.
17 | Params
18 | ======
19 | state_size (int): Dimension of each state
20 | action_size (int): Dimension of each action
21 | seed (int): Random seed
22 | fc_units (int): Number of nodes in the hidden layer
23 |     (this Actor has a single hidden layer: state -> fc_units -> action)
24 | """
25 | super(Actor, self).__init__()
26 | self.seed = torch.manual_seed(seed)
27 | self.fc1 = nn.Linear(state_size, fc_units)
28 | self.fc2 = nn.Linear(fc_units, action_size)
29 | self.reset_parameters()
30 |
31 | def reset_parameters(self):
32 | self.fc1.weight.data.uniform_(*hidden_init(self.fc1))
33 | self.fc2.weight.data.uniform_(-3e-3, 3e-3)
34 |
35 | def forward(self, state):
36 | """Build an actor (policy) network that maps states -> actions."""
37 | x = F.relu(self.fc1(state))
38 | return F.tanh(self.fc2(x))
39 |
40 |
41 | class Critic(nn.Module):
42 | """Critic (Value) Model."""
43 |
44 | def __init__(self, state_size, action_size, seed, fcs1_units=256, fc2_units=256, fc3_units=128):
45 | """Initialize parameters and build model.
46 | Params
47 | ======
48 | state_size (int): Dimension of each state
49 | action_size (int): Dimension of each action
50 | seed (int): Random seed
51 | fcs1_units (int): Number of nodes in the first hidden layer
52 | fc2_units (int), fc3_units (int): Number of nodes in the second and third hidden layers
53 | """
54 | super(Critic, self).__init__()
55 | self.seed = torch.manual_seed(seed)
56 | self.fcs1 = nn.Linear(state_size, fcs1_units)
57 | self.fc2 = nn.Linear(fcs1_units+action_size, fc2_units)
58 | self.fc3 = nn.Linear(fc2_units, fc3_units)
59 | self.fc4 = nn.Linear(fc3_units, 1)
60 | self.reset_parameters()
61 |
62 | def reset_parameters(self):
63 | self.fcs1.weight.data.uniform_(*hidden_init(self.fcs1))
64 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
65 | self.fc3.weight.data.uniform_(*hidden_init(self.fc3))
66 | self.fc4.weight.data.uniform_(-3e-3, 3e-3)
67 |
68 | def forward(self, state, action):
69 | """Build a critic (value) network that maps (state, action) pairs -> Q-values."""
70 | xs = F.leaky_relu(self.fcs1(state))
71 | x = torch.cat((xs, action), dim=1)
72 | x = F.leaky_relu(self.fc2(x))
73 | x = F.leaky_relu(self.fc3(x))
74 | return self.fc4(x)
75 |
--------------------------------------------------------------------------------
/ddpg-pendulum/README.md:
--------------------------------------------------------------------------------
1 | [//]: # (Image References)
2 |
3 | [image1]: https://user-images.githubusercontent.com/10624937/42135610-c37e0292-7d12-11e8-8228-4d3585f8c026.gif "Trained Agent"
4 |
5 | # Actor-Critic Methods
6 |
7 | ### Instructions
8 |
9 | Open `DDPG.ipynb` to see an implementation of DDPG with OpenAI Gym's Pendulum environment.
10 |
11 | ### Results
12 |
13 | ![Trained Agent][image1]
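14 |
15 | The full agent lives in `ddpg_agent.py`; for orientation only, below is a condensed sketch of a single DDPG learning step built around the `Actor` and `Critic` classes from `model.py`. The network and optimizer handles (`actor_local`, `critic_target`, `actor_opt`, ...) are assumed to be created elsewhere, so treat this as a rough outline rather than the repository's exact implementation.
16 |
17 | ```python
18 | import torch
19 | import torch.nn.functional as F
20 |
21 | def ddpg_update(states, actions, rewards, next_states, dones,
22 |                 actor_local, actor_target, critic_local, critic_target,
23 |                 actor_opt, critic_opt, gamma=0.99, tau=1e-3):
24 |     # critic target: y = r + gamma * Q_target(s', mu_target(s'))
25 |     with torch.no_grad():
26 |         next_actions = actor_target(next_states)
27 |         q_targets = rewards + gamma * critic_target(next_states, next_actions) * (1 - dones)
28 |     critic_loss = F.mse_loss(critic_local(states, actions), q_targets)
29 |     critic_opt.zero_grad()
30 |     critic_loss.backward()
31 |     critic_opt.step()
32 |
33 |     # actor: ascend Q(s, mu(s)) by minimizing its negative
34 |     actor_loss = -critic_local(states, actor_local(states)).mean()
35 |     actor_opt.zero_grad()
36 |     actor_loss.backward()
37 |     actor_opt.step()
38 |
39 |     # soft-update both target networks toward the local networks
40 |     for target, local in ((critic_target, critic_local), (actor_target, actor_local)):
41 |         for t_param, l_param in zip(target.parameters(), local.parameters()):
42 |             t_param.data.copy_(tau * l_param.data + (1.0 - tau) * t_param.data)
43 | ```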
--------------------------------------------------------------------------------
/ddpg-pendulum/checkpoint_actor.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/ddpg-pendulum/checkpoint_actor.pth
--------------------------------------------------------------------------------
/ddpg-pendulum/checkpoint_critic.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/ddpg-pendulum/checkpoint_critic.pth
--------------------------------------------------------------------------------
/ddpg-pendulum/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | def hidden_init(layer):
8 | fan_in = layer.weight.data.size()[0]
9 | lim = 1. / np.sqrt(fan_in)
10 | return (-lim, lim)
11 |
12 | class Actor(nn.Module):
13 | """Actor (Policy) Model."""
14 |
15 | def __init__(self, state_size, action_size, seed, fc1_units=400, fc2_units=300):
16 | """Initialize parameters and build model.
17 | Params
18 | ======
19 | state_size (int): Dimension of each state
20 | action_size (int): Dimension of each action
21 | seed (int): Random seed
22 | fc1_units (int): Number of nodes in first hidden layer
23 | fc2_units (int): Number of nodes in second hidden layer
24 | """
25 | super(Actor, self).__init__()
26 | self.seed = torch.manual_seed(seed)
27 | self.fc1 = nn.Linear(state_size, fc1_units)
28 | self.fc2 = nn.Linear(fc1_units, fc2_units)
29 | self.fc3 = nn.Linear(fc2_units, action_size)
30 | self.reset_parameters()
31 |
32 | def reset_parameters(self):
33 | self.fc1.weight.data.uniform_(*hidden_init(self.fc1))
34 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
35 | self.fc3.weight.data.uniform_(-3e-3, 3e-3)
36 |
37 | def forward(self, state):
38 | """Build an actor (policy) network that maps states -> actions."""
39 | x = F.relu(self.fc1(state))
40 | x = F.relu(self.fc2(x))
41 | return F.tanh(self.fc3(x))
42 |
43 |
44 | class Critic(nn.Module):
45 | """Critic (Value) Model."""
46 |
47 | def __init__(self, state_size, action_size, seed, fcs1_units=400, fc2_units=300):
48 | """Initialize parameters and build model.
49 | Params
50 | ======
51 | state_size (int): Dimension of each state
52 | action_size (int): Dimension of each action
53 | seed (int): Random seed
54 | fcs1_units (int): Number of nodes in the first hidden layer
55 | fc2_units (int): Number of nodes in the second hidden layer
56 | """
57 | super(Critic, self).__init__()
58 | self.seed = torch.manual_seed(seed)
59 | self.fcs1 = nn.Linear(state_size, fcs1_units)
60 | self.fc2 = nn.Linear(fcs1_units+action_size, fc2_units)
61 | self.fc3 = nn.Linear(fc2_units, 1)
62 | self.reset_parameters()
63 |
64 | def reset_parameters(self):
65 | self.fcs1.weight.data.uniform_(*hidden_init(self.fcs1))
66 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
67 | self.fc3.weight.data.uniform_(-3e-3, 3e-3)
68 |
69 | def forward(self, state, action):
70 | """Build a critic (value) network that maps (state, action) pairs -> Q-values."""
71 | xs = F.relu(self.fcs1(state))
72 | x = torch.cat((xs, action), dim=1)
73 | x = F.relu(self.fc2(x))
74 | return self.fc3(x)
75 |
--------------------------------------------------------------------------------
/discretization/README.md:
--------------------------------------------------------------------------------
1 | [//]: # (Image References)
2 |
3 | [image1]: https://user-images.githubusercontent.com/10624937/42135605-ba0e5f2c-7d12-11e8-9578-86d74e0976f8.gif "Trained Agent"
4 |
5 | # Discretization
6 |
7 | ### Instructions
8 |
9 | Follow the instructions in `Discretization.ipynb` to learn how to discretize continuous state spaces so that tabular solution methods can be applied to more complex tasks. The corresponding solutions can be found in `Discretization_Solution.ipynb`.
10 |
11 | ### Results
12 |
13 | ![Trained Agent][image1]
14 |
15 | ### Resources
16 |
17 | To learn about more advanced discretization approaches, refer to the following:
18 |
19 | - Uther, W., and Veloso, M., 1998. [Tree Based Discretization for Continuous State Space Reinforcement Learning](http://www.cs.cmu.edu/~mmv/papers/will-aaai98.pdf). In _Proceedings of AAAI, 1998_, pp. 769-774.
20 | - Munos, R. and Moore, A., 2002. [Variable Resolution Discretization in Optimal Control](https://link.springer.com/content/pdf/10.1023%2FA%3A1017992615625.pdf). In _Machine Learning_, 49(2), pp. 291-323.
21 |
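22 | For reference, uniform grid discretization (the simplest approach the notebook builds up) can be sketched as follows; the function names here mirror the idea rather than the notebook's exact API.
23 |
24 | ```python
25 | import numpy as np
26 |
27 | def create_uniform_grid(low, high, bins):
28 |     # one array of interior split points per state dimension
29 |     return [np.linspace(low[d], high[d], bins[d] + 1)[1:-1] for d in range(len(bins))]
30 |
31 | def discretize(sample, grid):
32 |     # map a continuous sample to a tuple of integer bin indices
33 |     return tuple(int(np.digitize(s, g)) for s, g in zip(sample, grid))
34 |
35 | # example: a 2-D state in [-1, 1] x [0, 5] on a 10 x 10 grid
36 | grid = create_uniform_grid([-1.0, 0.0], [1.0, 5.0], bins=(10, 10))
37 | print(discretize([0.25, 3.7], grid))  # -> (6, 7)
38 | ```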
--------------------------------------------------------------------------------
/dqn/README.md:
--------------------------------------------------------------------------------
1 | [//]: # (Image References)
2 |
3 | [image1]: https://user-images.githubusercontent.com/10624937/42135612-cbff24aa-7d12-11e8-9b6c-2b41e64b3bb0.gif "Trained Agent"
4 |
5 | # Deep Q-Network (DQN)
6 |
7 | ### Instructions
8 |
9 | In this exercise, you will implement Deep Q-Learning to solve OpenAI Gym's LunarLander environment. To begin, navigate to the `exercise/` folder, and follow the instructions in `Deep_Q_Network.ipynb`.
10 |
11 | (_Alternatively, if you'd prefer to explore a complete implementation, enter the `solution/` folder, and run the code in `Deep_Q_Network_Solution.ipynb`._)
12 |
13 | Once you are able to get the code working, try changing the parameters in the notebook to see if you can get the agent to train faster! You may also like to implement prioritized experience replay, or use your implementation as a starting point for a Double DQN or Dueling DQN (a short Double DQN sketch appears at the end of this README)!
14 |
15 | ### Results
16 |
17 | ![Trained Agent][image1]
18 |
19 | ### Resources
20 |
21 | - [Human-Level Control through Deep Reinforcement Learning](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
22 | - [Deep Reinforcement Learning with Double Q-Learning](https://arxiv.org/abs/1509.06461)
23 | - [Dueling Network Architectures for Deep Reinforcement Learning](https://arxiv.org/abs/1511.06581)
24 | - [Prioritized Experience Replay](https://arxiv.org/abs/1511.05952)
25 |
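26 | If you try the Double DQN extension, the main change to the solution's `learn()` is how the targets are computed: the local network selects the greedy action and the target network evaluates it. A minimal, illustrative sketch, reusing the tensor shapes from `dqn_agent.py`:
27 |
28 | ```python
29 | import torch
30 |
31 | def double_dqn_targets(qnetwork_local, qnetwork_target, rewards, next_states, dones, gamma=0.99):
32 |     with torch.no_grad():
33 |         # action selection with the online (local) network
34 |         best_actions = qnetwork_local(next_states).argmax(dim=1, keepdim=True)
35 |         # action evaluation with the target network
36 |         q_next = qnetwork_target(next_states).gather(1, best_actions)
37 |     # standard bootstrapped target, zeroed at terminal states
38 |     return rewards + gamma * q_next * (1 - dones)
39 | ```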
--------------------------------------------------------------------------------
/dqn/exercise/dqn_agent.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | from collections import namedtuple, deque
4 |
5 | from model import QNetwork
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 |
11 | BUFFER_SIZE = int(1e5) # replay buffer size
12 | BATCH_SIZE = 64 # minibatch size
13 | GAMMA = 0.99 # discount factor
14 | TAU = 1e-3 # for soft update of target parameters
15 | LR = 5e-4 # learning rate
16 | UPDATE_EVERY = 4 # how often to update the network
17 |
18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19 |
20 | class Agent():
21 | """Interacts with and learns from the environment."""
22 |
23 | def __init__(self, state_size, action_size, seed):
24 | """Initialize an Agent object.
25 |
26 | Params
27 | ======
28 | state_size (int): dimension of each state
29 | action_size (int): dimension of each action
30 | seed (int): random seed
31 | """
32 | self.state_size = state_size
33 | self.action_size = action_size
34 | self.seed = random.seed(seed)
35 |
36 | # Q-Network
37 | self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
38 | self.qnetwork_target = QNetwork(state_size, action_size, seed).to(device)
39 | self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
40 |
41 | # Replay memory
42 | self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
43 | # Initialize time step (for updating every UPDATE_EVERY steps)
44 | self.t_step = 0
45 |
46 | def step(self, state, action, reward, next_state, done):
47 | # Save experience in replay memory
48 | self.memory.add(state, action, reward, next_state, done)
49 |
50 | # Learn every UPDATE_EVERY time steps.
51 | self.t_step = (self.t_step + 1) % UPDATE_EVERY
52 | if self.t_step == 0:
53 | # If enough samples are available in memory, get random subset and learn
54 | if len(self.memory) > BATCH_SIZE:
55 | experiences = self.memory.sample()
56 | self.learn(experiences, GAMMA)
57 |
58 | def act(self, state, eps=0.):
59 | """Returns actions for given state as per current policy.
60 |
61 | Params
62 | ======
63 | state (array_like): current state
64 | eps (float): epsilon, for epsilon-greedy action selection
65 | """
66 | state = torch.from_numpy(state).float().unsqueeze(0).to(device)
67 | self.qnetwork_local.eval()
68 | with torch.no_grad():
69 | action_values = self.qnetwork_local(state)
70 | self.qnetwork_local.train()
71 |
72 | # Epsilon-greedy action selection
73 | if random.random() > eps:
74 | return np.argmax(action_values.cpu().data.numpy())
75 | else:
76 | return random.choice(np.arange(self.action_size))
77 |
78 | def learn(self, experiences, gamma):
79 | """Update value parameters using given batch of experience tuples.
80 |
81 | Params
82 | ======
83 | experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
84 | gamma (float): discount factor
85 | """
86 | states, actions, rewards, next_states, dones = experiences
87 |
88 | ## TODO: compute and minimize the loss
89 | "*** YOUR CODE HERE ***"
90 |
91 | # ------------------- update target network ------------------- #
92 | self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)
93 |
94 | def soft_update(self, local_model, target_model, tau):
95 | """Soft update model parameters.
96 | θ_target = τ*θ_local + (1 - τ)*θ_target
97 |
98 | Params
99 | ======
100 | local_model (PyTorch model): weights will be copied from
101 | target_model (PyTorch model): weights will be copied to
102 | tau (float): interpolation parameter
103 | """
104 | for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
105 | target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
106 |
107 |
108 | class ReplayBuffer:
109 | """Fixed-size buffer to store experience tuples."""
110 |
111 | def __init__(self, action_size, buffer_size, batch_size, seed):
112 | """Initialize a ReplayBuffer object.
113 |
114 | Params
115 | ======
116 | action_size (int): dimension of each action
117 | buffer_size (int): maximum size of buffer
118 | batch_size (int): size of each training batch
119 | seed (int): random seed
120 | """
121 | self.action_size = action_size
122 | self.memory = deque(maxlen=buffer_size)
123 | self.batch_size = batch_size
124 | self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
125 | self.seed = random.seed(seed)
126 |
127 | def add(self, state, action, reward, next_state, done):
128 | """Add a new experience to memory."""
129 | e = self.experience(state, action, reward, next_state, done)
130 | self.memory.append(e)
131 |
132 | def sample(self):
133 | """Randomly sample a batch of experiences from memory."""
134 | experiences = random.sample(self.memory, k=self.batch_size)
135 |
136 | states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
137 | actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
138 | rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
139 | next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)
140 | dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)
141 |
142 | return (states, actions, rewards, next_states, dones)
143 |
144 | def __len__(self):
145 | """Return the current size of internal memory."""
146 | return len(self.memory)
--------------------------------------------------------------------------------
/dqn/exercise/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class QNetwork(nn.Module):
6 | """Actor (Policy) Model."""
7 |
8 | def __init__(self, state_size, action_size, seed):
9 | """Initialize parameters and build model.
10 | Params
11 | ======
12 | state_size (int): Dimension of each state
13 | action_size (int): Dimension of each action
14 | seed (int): Random seed
15 | """
16 | super(QNetwork, self).__init__()
17 | self.seed = torch.manual_seed(seed)
18 | "*** YOUR CODE HERE ***"
19 |
20 | def forward(self, state):
21 | """Build a network that maps state -> action values."""
22 | pass
23 |
--------------------------------------------------------------------------------
/dqn/solution/checkpoint.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/dqn/solution/checkpoint.pth
--------------------------------------------------------------------------------
/dqn/solution/dqn_agent.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | from collections import namedtuple, deque
4 |
5 | from model import QNetwork
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 |
11 | BUFFER_SIZE = int(1e5) # replay buffer size
12 | BATCH_SIZE = 64 # minibatch size
13 | GAMMA = 0.99 # discount factor
14 | TAU = 1e-3 # for soft update of target parameters
15 | LR = 5e-4 # learning rate
16 | UPDATE_EVERY = 4 # how often to update the network
17 |
18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19 |
20 | class Agent():
21 | """Interacts with and learns from the environment."""
22 |
23 | def __init__(self, state_size, action_size, seed):
24 | """Initialize an Agent object.
25 |
26 | Params
27 | ======
28 | state_size (int): dimension of each state
29 | action_size (int): dimension of each action
30 | seed (int): random seed
31 | """
32 | self.state_size = state_size
33 | self.action_size = action_size
34 | self.seed = random.seed(seed)
35 |
36 | # Q-Network
37 | self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
38 | self.qnetwork_target = QNetwork(state_size, action_size, seed).to(device)
39 | self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
40 |
41 | # Replay memory
42 | self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
43 | # Initialize time step (for updating every UPDATE_EVERY steps)
44 | self.t_step = 0
45 |
46 | def step(self, state, action, reward, next_state, done):
47 | # Save experience in replay memory
48 | self.memory.add(state, action, reward, next_state, done)
49 |
50 | # Learn every UPDATE_EVERY time steps.
51 | self.t_step = (self.t_step + 1) % UPDATE_EVERY
52 | if self.t_step == 0:
53 | # If enough samples are available in memory, get random subset and learn
54 | if len(self.memory) > BATCH_SIZE:
55 | experiences = self.memory.sample()
56 | self.learn(experiences, GAMMA)
57 |
58 | def act(self, state, eps=0.):
59 | """Returns actions for given state as per current policy.
60 |
61 | Params
62 | ======
63 | state (array_like): current state
64 | eps (float): epsilon, for epsilon-greedy action selection
65 | """
66 | state = torch.from_numpy(state).float().unsqueeze(0).to(device)
67 | self.qnetwork_local.eval()
68 | with torch.no_grad():
69 | action_values = self.qnetwork_local(state)
70 | self.qnetwork_local.train()
71 |
72 | # Epsilon-greedy action selection
73 | if random.random() > eps:
74 | return np.argmax(action_values.cpu().data.numpy())
75 | else:
76 | return random.choice(np.arange(self.action_size))
77 |
78 | def learn(self, experiences, gamma):
79 | """Update value parameters using given batch of experience tuples.
80 |
81 | Params
82 | ======
83 | experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
84 | gamma (float): discount factor
85 | """
86 | states, actions, rewards, next_states, dones = experiences
87 |
88 | # Get max predicted Q values (for next states) from target model
89 | Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)
90 | # Compute Q targets for current states
91 | Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
92 |
93 | # Get expected Q values from local model
94 | Q_expected = self.qnetwork_local(states).gather(1, actions)
95 |
96 | # Compute loss
97 | loss = F.mse_loss(Q_expected, Q_targets)
98 | # Minimize the loss
99 | self.optimizer.zero_grad()
100 | loss.backward()
101 | self.optimizer.step()
102 |
103 | # ------------------- update target network ------------------- #
104 | self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)
105 |
106 | def soft_update(self, local_model, target_model, tau):
107 | """Soft update model parameters.
108 | θ_target = τ*θ_local + (1 - τ)*θ_target
109 |
110 | Params
111 | ======
112 | local_model (PyTorch model): weights will be copied from
113 | target_model (PyTorch model): weights will be copied to
114 | tau (float): interpolation parameter
115 | """
116 | for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
117 | target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
118 |
119 |
120 | class ReplayBuffer:
121 | """Fixed-size buffer to store experience tuples."""
122 |
123 | def __init__(self, action_size, buffer_size, batch_size, seed):
124 | """Initialize a ReplayBuffer object.
125 |
126 | Params
127 | ======
128 | action_size (int): dimension of each action
129 | buffer_size (int): maximum size of buffer
130 | batch_size (int): size of each training batch
131 | seed (int): random seed
132 | """
133 | self.action_size = action_size
134 | self.memory = deque(maxlen=buffer_size)
135 | self.batch_size = batch_size
136 | self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
137 | self.seed = random.seed(seed)
138 |
139 | def add(self, state, action, reward, next_state, done):
140 | """Add a new experience to memory."""
141 | e = self.experience(state, action, reward, next_state, done)
142 | self.memory.append(e)
143 |
144 | def sample(self):
145 | """Randomly sample a batch of experiences from memory."""
146 | experiences = random.sample(self.memory, k=self.batch_size)
147 |
148 | states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
149 | actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
150 | rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
151 | next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)
152 | dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)
153 |
154 | return (states, actions, rewards, next_states, dones)
155 |
156 | def __len__(self):
157 | """Return the current size of internal memory."""
158 | return len(self.memory)
--------------------------------------------------------------------------------
/dqn/solution/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class QNetwork(nn.Module):
6 | """Actor (Policy) Model."""
7 |
8 | def __init__(self, state_size, action_size, seed, fc1_units=64, fc2_units=64):
9 | """Initialize parameters and build model.
10 | Params
11 | ======
12 | state_size (int): Dimension of each state
13 | action_size (int): Dimension of each action
14 | seed (int): Random seed
15 | fc1_units (int): Number of nodes in first hidden layer
16 | fc2_units (int): Number of nodes in second hidden layer
17 | """
18 | super(QNetwork, self).__init__()
19 | self.seed = torch.manual_seed(seed)
20 | self.fc1 = nn.Linear(state_size, fc1_units)
21 | self.fc2 = nn.Linear(fc1_units, fc2_units)
22 | self.fc3 = nn.Linear(fc2_units, action_size)
23 |
24 | def forward(self, state):
25 | """Build a network that maps state -> action values."""
26 | x = F.relu(self.fc1(state))
27 | x = F.relu(self.fc2(x))
28 | return self.fc3(x)
29 |
--------------------------------------------------------------------------------
/dynamic-programming/README.md:
--------------------------------------------------------------------------------
1 | # Dynamic Programming
2 |
3 | ### Instructions
4 |
5 | Follow the instructions in `Dynamic_Programming.ipynb` to write your own implementations of many dynamic programming algorithms! The corresponding solutions can be found in `Dynamic_Programming_Solution.ipynb`.
6 |
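7 | As a point of reference, one of the algorithms the notebook asks for is value iteration. A minimal sketch against the one-step dynamics exposed by `FrozenLakeEnv` (`env.P`, `env.nS`, `env.nA`) could look like the following; the notebook may specify a slightly different signature.
8 |
9 | ```python
10 | import numpy as np
11 |
12 | def one_step_lookahead(env, V, s, gamma=1.0):
13 |     # expected return of each action from state s under the current value estimate
14 |     q = np.zeros(env.nA)
15 |     for a in range(env.nA):
16 |         for prob, next_state, reward, done in env.P[s][a]:
17 |             q[a] += prob * (reward + gamma * V[next_state])
18 |     return q
19 |
20 | def value_iteration(env, gamma=1.0, theta=1e-8):
21 |     V = np.zeros(env.nS)
22 |     while True:
23 |         delta = 0
24 |         for s in range(env.nS):
25 |             v_new = max(one_step_lookahead(env, V, s, gamma))
26 |             delta = max(delta, abs(V[s] - v_new))
27 |             V[s] = v_new
28 |         if delta < theta:
29 |             break
30 |     # recover a deterministic greedy policy from the converged values
31 |     policy = np.zeros([env.nS, env.nA])
32 |     for s in range(env.nS):
33 |         policy[s][np.argmax(one_step_lookahead(env, V, s, gamma))] = 1
34 |     return policy, V
35 | ```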
--------------------------------------------------------------------------------
/dynamic-programming/check_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import copy
3 | from IPython.display import Markdown, display
4 | import numpy as np
5 | from frozenlake import FrozenLakeEnv
6 |
7 | def printmd(string):
8 | display(Markdown(string))
9 |
10 | def policy_evaluation_soln(env, policy, gamma=1, theta=1e-8):
11 | V = np.zeros(env.nS)
12 | while True:
13 | delta = 0
14 | for s in range(env.nS):
15 | Vs = 0
16 | for a, action_prob in enumerate(policy[s]):
17 | for prob, next_state, reward, done in env.P[s][a]:
18 | Vs += action_prob * prob * (reward + gamma * V[next_state])
19 | delta = max(delta, np.abs(V[s]-Vs))
20 | V[s] = Vs
21 | if delta < theta:
22 | break
23 | return V
24 |
25 | def q_from_v_soln(env, V, s, gamma=1):
26 | q = np.zeros(env.nA)
27 | for a in range(env.nA):
28 | for prob, next_state, reward, done in env.P[s][a]:
29 | q[a] += prob * (reward + gamma * V[next_state])
30 | return q
31 |
32 | def policy_improvement_soln(env, V, gamma=1):
33 | policy = np.zeros([env.nS, env.nA]) / env.nA
34 | for s in range(env.nS):
35 | q = q_from_v_soln(env, V, s, gamma)
36 | best_a = np.argwhere(q==np.max(q)).flatten()
37 | policy[s] = np.sum([np.eye(env.nA)[i] for i in best_a], axis=0)/len(best_a)
38 | return policy
39 |
40 | def policy_iteration_soln(env, gamma=1, theta=1e-8):
41 | policy = np.ones([env.nS, env.nA]) / env.nA
42 | while True:
43 | V = policy_evaluation_soln(env, policy, gamma, theta)
44 | new_policy = policy_improvement_soln(env, V)
45 | if (new_policy == policy).all():
46 | break;
47 | policy = copy.copy(new_policy)
48 | return policy, V
49 |
50 | env = FrozenLakeEnv()
51 | random_policy = np.ones([env.nS, env.nA]) / env.nA
52 |
53 | class Tests(unittest.TestCase):
54 |
55 | def policy_evaluation_check(self, policy_evaluation):
56 | soln = policy_evaluation_soln(env, random_policy)
57 | to_check = policy_evaluation(env, random_policy)
58 | np.testing.assert_array_almost_equal(soln, to_check)
59 |
60 | def q_from_v_check(self, q_from_v):
61 | V = policy_evaluation_soln(env, random_policy)
62 | soln = np.zeros([env.nS, env.nA])
63 | to_check = np.zeros([env.nS, env.nA])
64 | for s in range(env.nS):
65 | soln[s] = q_from_v_soln(env, V, s)
66 | to_check[s] = q_from_v(env, V, s)
67 | np.testing.assert_array_almost_equal(soln, to_check)
68 |
69 | def policy_improvement_check(self, policy_improvement):
70 | V = policy_evaluation_soln(env, random_policy)
71 | new_policy = policy_improvement(env, V)
72 | new_V = policy_evaluation_soln(env, new_policy)
73 | self.assertTrue(np.all(new_V >= V))
74 |
75 | def policy_iteration_check(self, policy_iteration):
76 | policy_soln, _ = policy_iteration_soln(env)
77 | policy_to_check, _ = policy_iteration(env)
78 | soln = policy_evaluation_soln(env, policy_soln)
79 | to_check = policy_evaluation_soln(env, policy_to_check)
80 | np.testing.assert_array_almost_equal(soln, to_check)
81 |
82 | def truncated_policy_iteration_check(self, truncated_policy_iteration):
83 | self.policy_iteration_check(truncated_policy_iteration)
84 |
85 | def value_iteration_check(self, value_iteration):
86 | self.policy_iteration_check(value_iteration)
87 |
88 | check = Tests()
89 |
90 | def run_check(check_name, func):
91 | try:
92 | getattr(check, check_name)(func)
93 | except check.failureException as e:
94 | printmd('**PLEASE TRY AGAIN**')
95 | return
96 | printmd('**PASSED**')
--------------------------------------------------------------------------------
/dynamic-programming/frozenlake.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys
3 | from six import StringIO, b
4 |
5 | from gym import utils
6 | from gym.envs.toy_text import discrete
7 |
8 | LEFT = 0
9 | DOWN = 1
10 | RIGHT = 2
11 | UP = 3
12 |
13 | MAPS = {
14 | "4x4": [
15 | "SFFF",
16 | "FHFH",
17 | "FFFH",
18 | "HFFG"
19 | ],
20 | "8x8": [
21 | "SFFFFFFF",
22 | "FFFFFFFF",
23 | "FFFHFFFF",
24 | "FFFFFHFF",
25 | "FFFHFFFF",
26 | "FHHFFFHF",
27 | "FHFFHFHF",
28 | "FFFHFFFG"
29 | ],
30 | }
31 |
32 | class FrozenLakeEnv(discrete.DiscreteEnv):
33 | """
34 | Winter is here. You and your friends were tossing around a frisbee at the park
35 | when you made a wild throw that left the frisbee out in the middle of the lake.
36 | The water is mostly frozen, but there are a few holes where the ice has melted.
37 | If you step into one of those holes, you'll fall into the freezing water.
38 | At this time, there's an international frisbee shortage, so it's absolutely imperative that
39 | you navigate across the lake and retrieve the disc.
40 | However, the ice is slippery, so you won't always move in the direction you intend.
41 | The surface is described using a grid like the following
42 |
43 | SFFF
44 | FHFH
45 | FFFH
46 | HFFG
47 |
48 | S : starting point, safe
49 | F : frozen surface, safe
50 | H : hole, fall to your doom
51 | G : goal, where the frisbee is located
52 |
53 | The episode ends when you reach the goal or fall in a hole.
54 | You receive a reward of 1 if you reach the goal, and zero otherwise.
55 |
56 | """
57 |
58 | metadata = {'render.modes': ['human', 'ansi']}
59 |
60 | def __init__(self, desc=None, map_name="4x4",is_slippery=True):
61 | if desc is None and map_name is None:
62 | raise ValueError('Must provide either desc or map_name')
63 | elif desc is None:
64 | desc = MAPS[map_name]
65 | self.desc = desc = np.asarray(desc,dtype='c')
66 | self.nrow, self.ncol = nrow, ncol = desc.shape
67 |
68 | nA = 4
69 | nS = nrow * ncol
70 |
71 | isd = np.array(desc == b'S').astype('float64').ravel()
72 | isd /= isd.sum()
73 |
74 | P = {s : {a : [] for a in range(nA)} for s in range(nS)}
75 |
76 | def to_s(row, col):
77 | return row*ncol + col
78 | def inc(row, col, a):
79 | if a==0: # left
80 | col = max(col-1,0)
81 | elif a==1: # down
82 | row = min(row+1,nrow-1)
83 | elif a==2: # right
84 | col = min(col+1,ncol-1)
85 | elif a==3: # up
86 | row = max(row-1,0)
87 | return (row, col)
88 |
89 | for row in range(nrow):
90 | for col in range(ncol):
91 | s = to_s(row, col)
92 | for a in range(4):
93 | li = P[s][a]
94 | letter = desc[row, col]
95 | if letter in b'GH':
96 | li.append((1.0, s, 0, True))
97 | else:
98 | if is_slippery:
99 | for b in [(a-1)%4, a, (a+1)%4]:
100 | newrow, newcol = inc(row, col, b)
101 | newstate = to_s(newrow, newcol)
102 | newletter = desc[newrow, newcol]
103 | done = bytes(newletter) in b'GH'
104 | rew = float(newletter == b'G')
105 | li.append((1.0/3.0, newstate, rew, done))
106 | else:
107 | newrow, newcol = inc(row, col, a)
108 | newstate = to_s(newrow, newcol)
109 | newletter = desc[newrow, newcol]
110 | done = bytes(newletter) in b'GH'
111 | rew = float(newletter == b'G')
112 | li.append((1.0, newstate, rew, done))
113 |
114 | # obtain one-step dynamics for dynamic programming setting
115 | self.P = P
116 |
117 | super(FrozenLakeEnv, self).__init__(nS, nA, P, isd)
118 |
119 | def _render(self, mode='human', close=False):
120 | if close:
121 | return
122 | outfile = StringIO() if mode == 'ansi' else sys.stdout
123 |
124 | row, col = self.s // self.ncol, self.s % self.ncol
125 | desc = self.desc.tolist()
126 | desc = [[c.decode('utf-8') for c in line] for line in desc]
127 | desc[row][col] = utils.colorize(desc[row][col], "red", highlight=True)
128 | if self.lastaction is not None:
129 | outfile.write(" ({})\n".format(["Left","Down","Right","Up"][self.lastaction]))
130 | else:
131 | outfile.write("\n")
132 | outfile.write("\n".join(''.join(line) for line in desc)+"\n")
133 |
134 | if mode != 'human':
135 | return outfile
136 |
--------------------------------------------------------------------------------
/dynamic-programming/plot_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | def plot_values(V):
5 | # reshape value function
6 | V_sq = np.reshape(V, (4,4))
7 |
8 | # plot the state-value function
9 | fig = plt.figure(figsize=(6, 6))
10 | ax = fig.add_subplot(111)
11 | im = ax.imshow(V_sq, cmap='cool')
12 | for (j,i),label in np.ndenumerate(V_sq):
13 | ax.text(i, j, np.round(label, 5), ha='center', va='center', fontsize=14)
14 | plt.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False)
15 | plt.title('State-Value Function')
16 | plt.show()
17 |
--------------------------------------------------------------------------------
/finance/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | def hidden_init(layer):
8 | fan_in = layer.weight.data.size()[0]
9 | lim = 1. / np.sqrt(fan_in)
10 | return (-lim, lim)
11 |
12 | class Actor(nn.Module):
13 | """Actor (Policy) Model."""
14 |
15 | def __init__(self, state_size, action_size, seed, fc1_units=24, fc2_units=48):
16 | """Initialize parameters and build model.
17 | Params
18 | ======
19 | state_size (int): Dimension of each state
20 | action_size (int): Dimension of each action
21 | seed (int): Random seed
22 | fc1_units (int): Number of nodes in first hidden layer
23 | fc2_units (int): Number of nodes in second hidden layer
24 | """
25 | super(Actor, self).__init__()
26 | self.seed = torch.manual_seed(seed)
27 | self.fc1 = nn.Linear(state_size, fc1_units)
28 | self.fc2 = nn.Linear(fc1_units, fc2_units)
29 | self.fc3 = nn.Linear(fc2_units, action_size)
30 | self.reset_parameters()
31 |
32 | def reset_parameters(self):
33 | self.fc1.weight.data.uniform_(*hidden_init(self.fc1))
34 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
35 | self.fc3.weight.data.uniform_(-3e-3, 3e-3)
36 |
37 | def forward(self, state):
38 | """Build an actor (policy) network that maps states -> actions."""
39 | x = F.relu(self.fc1(state))
40 | x = F.relu(self.fc2(x))
41 | return F.tanh(self.fc3(x))
42 |
43 |
44 | class Critic(nn.Module):
45 | """Critic (Value) Model."""
46 |
47 | def __init__(self, state_size, action_size, seed, fcs1_units=24, fc2_units=48):
48 | """Initialize parameters and build model.
49 | Params
50 | ======
51 | state_size (int): Dimension of each state
52 | action_size (int): Dimension of each action
53 | seed (int): Random seed
54 | fcs1_units (int): Number of nodes in the first hidden layer
55 | fc2_units (int): Number of nodes in the second hidden layer
56 | """
57 | super(Critic, self).__init__()
58 | self.seed = torch.manual_seed(seed)
59 | self.fcs1 = nn.Linear(state_size, fcs1_units)
60 | self.fc2 = nn.Linear(fcs1_units+action_size, fc2_units)
61 | self.fc3 = nn.Linear(fc2_units, 1)
62 | self.reset_parameters()
63 |
64 | def reset_parameters(self):
65 | self.fcs1.weight.data.uniform_(*hidden_init(self.fcs1))
66 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
67 | self.fc3.weight.data.uniform_(-3e-3, 3e-3)
68 |
69 | def forward(self, state, action):
70 | """Build a critic (value) network that maps (state, action) pairs -> Q-values."""
71 | xs = F.relu(self.fcs1(state))
72 | x = torch.cat((xs, action), dim=1)
73 | x = F.relu(self.fc2(x))
74 | return self.fc3(x)
75 |
--------------------------------------------------------------------------------
/finance/text_images/4.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/finance/text_images/4.jpeg
--------------------------------------------------------------------------------
/finance/text_images/Actor-Critic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/finance/text_images/Actor-Critic.png
--------------------------------------------------------------------------------
/finance/text_images/RL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/finance/text_images/RL.png
--------------------------------------------------------------------------------
/finance/text_images/nvidia.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/finance/text_images/nvidia.png
--------------------------------------------------------------------------------
/finance/text_images/udacity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/finance/text_images/udacity.png
--------------------------------------------------------------------------------
/hill-climbing/README.md:
--------------------------------------------------------------------------------
1 | [//]: # (Image References)
2 |
3 | [image1]: https://user-images.githubusercontent.com/10624937/42135683-dde5c6f0-7d13-11e8-90b1-8770df3e40cf.gif "Trained Agent"
4 |
5 | # Hill Climbing
6 |
7 | ### Instructions
8 |
9 | Open `Hill_Climbing.ipynb` to see an implementation of hill climbing with adaptive noise scaling, applied to OpenAI Gym's Cartpole environment.
10 |
11 | Try changing the parameters in the notebook to see if you can get the agent to train faster!
12 |
13 | ### Results
14 |
15 | ![Trained Agent][image1]
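16 |
17 | The core idea, stripped of the notebook's environment setup and plotting, is only a few lines: perturb the best weights found so far, keep the perturbation if it scores at least as well, and shrink or grow the noise radius accordingly. In this sketch `evaluate(weights)` is a hypothetical helper that runs one episode and returns its return, and the hyperparameters are illustrative.
18 |
19 | ```python
20 | import numpy as np
21 |
22 | def hill_climbing(evaluate, dim, n_iterations=1000, noise_scale=1e-2,
23 |                   noise_min=1e-3, noise_max=2.0):
24 |     best_w = 1e-4 * np.random.randn(dim)
25 |     best_r = evaluate(best_w)
26 |     for _ in range(n_iterations):
27 |         candidate = best_w + noise_scale * np.random.randn(dim)
28 |         reward = evaluate(candidate)
29 |         if reward >= best_r:
30 |             # improvement: keep the candidate and shrink the search radius
31 |             best_w, best_r = candidate, reward
32 |             noise_scale = max(noise_min, noise_scale / 2)
33 |         else:
34 |             # no improvement: widen the search radius around the old best
35 |             noise_scale = min(noise_max, noise_scale * 2)
36 |     return best_w
37 | ```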
--------------------------------------------------------------------------------
/lab-taxi/README.md:
--------------------------------------------------------------------------------
1 | # Taxi Problem
2 |
3 | ### Getting Started
4 |
5 | Read the description of the environment in subsection 3.1 of [this paper](https://arxiv.org/pdf/cs/9905014.pdf). You can verify that the description in the paper matches the OpenAI Gym environment by peeking at the code [here](https://github.com/openai/gym/blob/master/gym/envs/toy_text/taxi.py).
6 |
7 |
8 | ### Instructions
9 |
10 | The repository contains three files:
11 | - `agent.py`: Develop your reinforcement learning agent here. This is the only file that you should modify.
12 | - `monitor.py`: The `interact` function tests how well your agent learns from interaction with the environment.
13 | - `main.py`: Run this file in the terminal to check the performance of your agent.
14 |
15 | Begin by running the following command in the terminal:
16 | ```
17 | python main.py
18 | ```
19 |
20 | When you run `main.py`, the agent that you specify in `agent.py` interacts with the environment for 20,000 episodes. The details of the interaction are specified in `monitor.py`, which returns two variables: `avg_rewards` and `best_avg_reward`.
21 | - `avg_rewards` is a deque where `avg_rewards[i]` is the average (undiscounted) return collected by the agent from episodes `i+1` to episode `i+100`, inclusive. So, for instance, `avg_rewards[0]` is the average return collected by the agent over the first 100 episodes.
22 | - `best_avg_reward` is the largest entry in `avg_rewards`. This is the final score that you should use when determining how well your agent performed in the task.
23 |
24 | Your assignment is to modify the `agent.py` file to improve the agent's performance.
25 | - Use the `__init__()` method to define any needed instance variables. Currently, we define the number of actions available to the agent (`nA`) and initialize the action values (`Q`) to an empty dictionary of arrays. Feel free to add more instance variables; for example, you may find it useful to define the value of epsilon if the agent uses an epsilon-greedy policy for selecting actions.
26 | - The `select_action()` method accepts the environment state as input and returns the agent's choice of action. The default code that we have provided randomly selects an action.
27 | - The `step()` method accepts a (`state`, `action`, `reward`, `next_state`) tuple as input, along with the `done` variable, which is `True` if the episode has ended. The default code (which you should certainly change!) increments the action value of the previous state-action pair by 1. You should change this method to use the sampled tuple of experience to update the agent's knowledge of the problem.
28 |
29 | Once you have modified these methods, you need only run `python main.py` to test your new agent; one possible direction is sketched at the end of this README.
30 |
31 | OpenAI Gym [defines "solving"](https://gym.openai.com/envs/Taxi-v1/) this task as getting an average return of 9.7 over 100 consecutive trials.
32 |
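33 | As one possible starting point (not the only reasonable design, and with purely illustrative hyperparameters), a Q-learning agent for `agent.py` can be sketched as follows:
34 |
35 | ```python
36 | import numpy as np
37 | from collections import defaultdict
38 |
39 | class Agent:
40 |
41 |     def __init__(self, nA=6, alpha=0.1, gamma=1.0, epsilon=0.005):
42 |         self.nA = nA                # number of actions
43 |         self.alpha = alpha          # step size
44 |         self.gamma = gamma          # discount factor
45 |         self.epsilon = epsilon      # exploration rate
46 |         self.Q = defaultdict(lambda: np.zeros(self.nA))
47 |
48 |     def select_action(self, state):
49 |         # epsilon-greedy action selection
50 |         if np.random.random() > self.epsilon:
51 |             return int(np.argmax(self.Q[state]))
52 |         return np.random.choice(self.nA)
53 |
54 |     def step(self, state, action, reward, next_state, done):
55 |         # Q-learning (sarsamax) update from the sampled transition
56 |         target = reward + (0 if done else self.gamma * np.max(self.Q[next_state]))
57 |         self.Q[state][action] += self.alpha * (target - self.Q[state][action])
58 | ```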
--------------------------------------------------------------------------------
/lab-taxi/agent.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import defaultdict
3 |
4 | class Agent:
5 |
6 | def __init__(self, nA=6):
7 | """ Initialize agent.
8 |
9 | Params
10 | ======
11 | - nA: number of actions available to the agent
12 | """
13 | self.nA = nA
14 | self.Q = defaultdict(lambda: np.zeros(self.nA))
15 |
16 | def select_action(self, state):
17 | """ Given the state, select an action.
18 |
19 | Params
20 | ======
21 | - state: the current state of the environment
22 |
23 | Returns
24 | =======
25 | - action: an integer, compatible with the task's action space
26 | """
27 | return np.random.choice(self.nA)
28 |
29 | def step(self, state, action, reward, next_state, done):
30 | """ Update the agent's knowledge, using the most recently sampled tuple.
31 |
32 | Params
33 | ======
34 | - state: the previous state of the environment
35 | - action: the agent's previous choice of action
36 | - reward: last reward received
37 | - next_state: the current state of the environment
38 | - done: whether the episode is complete (True or False)
39 | """
40 | self.Q[state][action] += 1
--------------------------------------------------------------------------------
/lab-taxi/main.py:
--------------------------------------------------------------------------------
1 | from agent import Agent
2 | from monitor import interact
3 | import gym
4 | import numpy as np
5 |
6 | env = gym.make('Taxi-v2')
7 | agent = Agent()
8 | avg_rewards, best_avg_reward = interact(env, agent)
--------------------------------------------------------------------------------
/lab-taxi/monitor.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | import sys
3 | import math
4 | import numpy as np
5 |
6 | def interact(env, agent, num_episodes=20000, window=100):
7 | """ Monitor agent's performance.
8 |
9 | Params
10 | ======
11 | - env: instance of OpenAI Gym's Taxi-v1 environment
12 | - agent: instance of class Agent (see agent.py for details)
13 | - num_episodes: number of episodes of agent-environment interaction
14 | - window: number of episodes to consider when calculating average rewards
15 |
16 | Returns
17 | =======
18 | - avg_rewards: deque containing average rewards
19 | - best_avg_reward: largest value in the avg_rewards deque
20 | """
21 | # initialize average rewards
22 | avg_rewards = deque(maxlen=num_episodes)
23 | # initialize best average reward
24 | best_avg_reward = -math.inf
25 | # initialize monitor for most recent rewards
26 | samp_rewards = deque(maxlen=window)
27 | # for each episode
28 | for i_episode in range(1, num_episodes+1):
29 | # begin the episode
30 | state = env.reset()
31 | # initialize the sampled reward
32 | samp_reward = 0
33 | while True:
34 | # agent selects an action
35 | action = agent.select_action(state)
36 | # agent performs the selected action
37 | next_state, reward, done, _ = env.step(action)
38 | # agent performs internal updates based on sampled experience
39 | agent.step(state, action, reward, next_state, done)
40 | # update the sampled reward
41 | samp_reward += reward
42 | # update the state (s <- s') to next time step
43 | state = next_state
44 | if done:
45 | # save final sampled reward
46 | samp_rewards.append(samp_reward)
47 | break
48 |         if (i_episode >= window):
49 |             # get average reward over the last window episodes
50 | avg_reward = np.mean(samp_rewards)
51 | # append to deque
52 | avg_rewards.append(avg_reward)
53 | # update best average reward
54 | if avg_reward > best_avg_reward:
55 | best_avg_reward = avg_reward
56 | # monitor progress
57 | print("\rEpisode {}/{} || Best average reward {}".format(i_episode, num_episodes, best_avg_reward), end="")
58 | sys.stdout.flush()
59 | # check if task is solved (according to OpenAI Gym)
60 | if best_avg_reward >= 9.7:
61 | print('\nEnvironment solved in {} episodes.'.format(i_episode), end="")
62 | break
63 | if i_episode == num_episodes: print('\n')
64 | return avg_rewards, best_avg_reward
--------------------------------------------------------------------------------
/monte-carlo/README.md:
--------------------------------------------------------------------------------
1 | # Monte Carlo Methods
2 |
3 | ### Instructions
4 |
5 | Follow the instructions in `Monte_Carlo.ipynb` to write your own implementations of many Monte Carlo methods! The corresponding solutions can be found in `Monte_Carlo_Solution.ipynb`.
6 |
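7 | The notebook uses the helper functions in `plot_utils.py` to visualize your results. As a rough sketch of the data structures those helpers expect (the states and values below are invented for illustration), the state-value function is a dictionary keyed by Blackjack states `(player_sum, dealer_card, usable_ace)`, and the policy maps the same keys to `0` (STICK) or `1` (HIT):
8 |
9 | ```python
10 | # Illustrative sketch: the states and values below are made up, not computed results.
11 | from plot_utils import plot_blackjack_values, plot_policy
12 |
13 | V = {(21, 3, True): 1.0, (14, 10, False): -0.6}   # state -> estimated state value
14 | policy = {(21, 3, True): 0, (14, 10, False): 1}   # state -> 0 (STICK) or 1 (HIT)
15 |
16 | plot_blackjack_values(V)   # 3D surface plot of the state-value function
17 | plot_policy(policy)        # 2D plot of the corresponding policy
18 | ```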
--------------------------------------------------------------------------------
/monte-carlo/images/optimal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/udacity/deep-reinforcement-learning/561eec3ae8678a23a4557f1a15414a9b076fdfff/monte-carlo/images/optimal.png
--------------------------------------------------------------------------------
/monte-carlo/plot_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mpl_toolkits.mplot3d import Axes3D
3 | import matplotlib.pyplot as plt
4 | from mpl_toolkits.axes_grid1 import make_axes_locatable
5 |
6 | def plot_blackjack_values(V):
7 |
8 | def get_Z(x, y, usable_ace):
9 | if (x,y,usable_ace) in V:
10 | return V[x,y,usable_ace]
11 | else:
12 | return 0
13 |
14 | def get_figure(usable_ace, ax):
15 | x_range = np.arange(11, 22)
16 | y_range = np.arange(1, 11)
17 | X, Y = np.meshgrid(x_range, y_range)
18 |
19 | Z = np.array([get_Z(x,y,usable_ace) for x,y in zip(np.ravel(X), np.ravel(Y))]).reshape(X.shape)
20 |
21 | surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=plt.cm.coolwarm, vmin=-1.0, vmax=1.0)
22 | ax.set_xlabel('Player\'s Current Sum')
23 | ax.set_ylabel('Dealer\'s Showing Card')
24 | ax.set_zlabel('State Value')
25 | ax.view_init(ax.elev, -120)
26 |
27 | fig = plt.figure(figsize=(20, 20))
28 | ax = fig.add_subplot(211, projection='3d')
29 | ax.set_title('Usable Ace')
30 | get_figure(True, ax)
31 | ax = fig.add_subplot(212, projection='3d')
32 | ax.set_title('No Usable Ace')
33 | get_figure(False, ax)
34 | plt.show()
35 |
36 | def plot_policy(policy):
37 |
38 | def get_Z(x, y, usable_ace):
39 | if (x,y,usable_ace) in policy:
40 | return policy[x,y,usable_ace]
41 | else:
42 | return 1
43 |
44 | def get_figure(usable_ace, ax):
45 | x_range = np.arange(11, 22)
46 | y_range = np.arange(10, 0, -1)
47 | X, Y = np.meshgrid(x_range, y_range)
48 | Z = np.array([[get_Z(x,y,usable_ace) for x in x_range] for y in y_range])
49 | surf = ax.imshow(Z, cmap=plt.get_cmap('Pastel2', 2), vmin=0, vmax=1, extent=[10.5, 21.5, 0.5, 10.5])
50 | plt.xticks(x_range)
51 | plt.yticks(y_range)
52 | plt.gca().invert_yaxis()
53 | ax.set_xlabel('Player\'s Current Sum')
54 | ax.set_ylabel('Dealer\'s Showing Card')
55 | ax.grid(color='w', linestyle='-', linewidth=1)
56 | divider = make_axes_locatable(ax)
57 | cax = divider.append_axes("right", size="5%", pad=0.1)
58 | cbar = plt.colorbar(surf, ticks=[0,1], cax=cax)
59 | cbar.ax.set_yticklabels(['0 (STICK)','1 (HIT)'])
60 |
61 | fig = plt.figure(figsize=(15, 15))
62 | ax = fig.add_subplot(121)
63 | ax.set_title('Usable Ace')
64 | get_figure(True, ax)
65 | ax = fig.add_subplot(122)
66 | ax.set_title('No Usable Ace')
67 | get_figure(False, ax)
68 | plt.show()
--------------------------------------------------------------------------------
/p1_navigation/README.md:
--------------------------------------------------------------------------------
1 | [//]: # (Image References)
2 |
3 | [image1]: https://user-images.githubusercontent.com/10624937/42135619-d90f2f28-7d12-11e8-8823-82b970a54d7e.gif "Trained Agent"
4 |
5 | # Project 1: Navigation
6 |
7 | ### Introduction
8 |
9 | For this project, you will train an agent to navigate (and collect bananas!) in a large, square world.
10 |
11 | ![Trained Agent][image1]
12 |
13 | A reward of +1 is provided for collecting a yellow banana, and a reward of -1 is provided for collecting a blue banana. Thus, the goal of your agent is to collect as many yellow bananas as possible while avoiding blue bananas.
14 |
15 | The state space has 37 dimensions and contains the agent's velocity, along with ray-based perception of objects around the agent's forward direction. Given this information, the agent has to learn how best to select actions. Four discrete actions are available, corresponding to:
16 | - **`0`** - move forward.
17 | - **`1`** - move backward.
18 | - **`2`** - turn left.
19 | - **`3`** - turn right.
20 |
21 | The task is episodic, and in order to solve the environment, your agent must get an average score of +13 over 100 consecutive episodes.
22 |
23 | ### Getting Started
24 |
25 | 1. Download the environment from one of the links below. You need only select the environment that matches your operating system:
26 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana_Linux.zip)
27 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana.app.zip)
28 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana_Windows_x86.zip)
29 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana_Windows_x86_64.zip)
30 |
31 | (_For Windows users_) Check out [this link](https://support.microsoft.com/en-us/help/827218/how-to-determine-whether-a-computer-is-running-a-32-bit-version-or-64) if you need help with determining if your computer is running a 32-bit version or 64-bit version of the Windows operating system.
32 |
33 | (_For AWS_) If you'd like to train the agent on AWS (and have not [enabled a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md)), then please use [this link](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana_Linux_NoVis.zip) to obtain the environment.
34 |
35 | 2. Place the file in the DRLND GitHub repository, in the `p1_navigation/` folder, and unzip (or decompress) the file.
36 |
37 | ### Instructions
38 |
39 | Follow the instructions in `Navigation.ipynb` to get started with training your own agent!
40 |
41 | ### (Optional) Challenge: Learning from Pixels
42 |
43 | After you have successfully completed the project, if you're looking for an additional challenge, you have come to the right place! In the project, your agent learned from information such as its velocity, along with ray-based perception of objects around its forward direction. A more challenging task would be to learn directly from pixels!
44 |
45 | To solve this harder task, you'll need to download a new Unity environment. This environment is almost identical to the project environment, except that the state is an 84 x 84 RGB image corresponding to the agent's first-person view. (**Note**: Udacity students should not submit a project with this new environment.)
46 |
47 | You need only select the environment that matches your operating system:
48 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/VisualBanana_Linux.zip)
49 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/VisualBanana.app.zip)
50 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/VisualBanana_Windows_x86.zip)
51 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/VisualBanana_Windows_x86_64.zip)
52 |
53 | Then, place the file in the `p1_navigation/` folder in the DRLND GitHub repository, and unzip (or decompress) the file. Next, open `Navigation_Pixels.ipynb` and follow the instructions to learn how to use the Python API to control the agent.
54 |
55 | (_For AWS_) If you'd like to train the agent on AWS, you must follow the instructions to [set up X Server](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md), and then download the environment for the **Linux** operating system above.
56 |
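57 | Whichever version of the environment you use, the basic interaction loop follows the `unityagents` API demonstrated in `python/Basics.ipynb`. Below is a minimal, hedged sketch in which random actions stand in for your trained agent; the `file_name` path is a placeholder that depends on your operating system and where you unzipped the environment.
58 |
59 | ```python
60 | # Hedged sketch: random actions as a placeholder for a trained agent; file_name is a placeholder.
61 | import numpy as np
62 | from unityagents import UnityEnvironment
63 |
64 | env = UnityEnvironment(file_name="Banana_Linux/Banana.x86_64")  # placeholder path
65 | brain_name = env.brain_names[0]
66 |
67 | env_info = env.reset(train_mode=False)[brain_name]
68 | state = env_info.vector_observations[0]    # 37-dimensional state
69 | score = 0
70 | while True:
71 |     action = np.random.randint(4)          # replace with your agent's action selection
72 |     env_info = env.step(action)[brain_name]
73 |     state = env_info.vector_observations[0]
74 |     score += env_info.rewards[0]           # +1 for yellow bananas, -1 for blue bananas
75 |     if env_info.local_done[0]:
76 |         break
77 | print("Score: {}".format(score))
78 | env.close()
79 | ```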
--------------------------------------------------------------------------------
/p2_continuous-control/README.md:
--------------------------------------------------------------------------------
1 | [//]: # (Image References)
2 |
3 | [image1]: https://user-images.githubusercontent.com/10624937/43851024-320ba930-9aff-11e8-8493-ee547c6af349.gif "Trained Agent"
4 | [image2]: https://user-images.githubusercontent.com/10624937/43851646-d899bf20-9b00-11e8-858c-29b5c2c94ccc.png "Crawler"
5 |
6 |
7 | # Project 2: Continuous Control
8 |
9 | ### Introduction
10 |
11 | For this project, you will work with the [Reacher](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Learning-Environment-Examples.md#reacher) environment.
12 |
13 | ![Trained Agent][image1]
14 |
15 | In this environment, a double-jointed arm can move to target locations. A reward of +0.1 is provided for each step that the agent's hand is in the goal location. Thus, the goal of your agent is to maintain its position at the target location for as many time steps as possible.
16 |
17 | The observation space consists of 33 variables corresponding to position, rotation, velocity, and angular velocities of the arm. Each action is a vector with four numbers, corresponding to torque applicable to two joints. Every entry in the action vector should be a number between -1 and 1.
18 |
19 | ### Distributed Training
20 |
21 | For this project, we will provide you with two separate versions of the Unity environment:
22 | - The first version contains a single agent.
23 | - The second version contains 20 identical agents, each with its own copy of the environment.
24 |
25 | The second version is useful for algorithms like [PPO](https://arxiv.org/pdf/1707.06347.pdf), [A3C](https://arxiv.org/pdf/1602.01783.pdf), and [D4PG](https://openreview.net/pdf?id=SyZipzbCb) that use multiple (non-interacting, parallel) copies of the same agent to distribute the task of gathering experience.
26 |
27 | ### Solving the Environment
28 |
29 | Note that your project submission need only solve one of the two versions of the environment.
30 |
31 | #### Option 1: Solve the First Version
32 |
33 | The task is episodic, and in order to solve the environment, your agent must get an average score of +30 over 100 consecutive episodes.
34 |
35 | #### Option 2: Solve the Second Version
36 |
37 | The barrier for solving the second version of the environment is slightly different, to take into account the presence of many agents. In particular, your agents must get an average score of +30 (over 100 consecutive episodes, and over all agents). Specifically,
38 | - After each episode, we add up the rewards that each agent received (without discounting), to get a score for each agent. This yields 20 (potentially different) scores. We then take the average of these 20 scores.
39 | - This yields an **average score** for each episode (where the average is over all 20 agents).
40 |
41 | The environment is considered solved when the average (over 100 episodes) of those average scores is at least +30.
42 |
43 | ### Getting Started
44 |
45 | 1. Download the environment from one of the links below. You need only select the environment that matches your operating system:
46 |
47 | - **_Version 1: One (1) Agent_**
48 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/one_agent/Reacher_Linux.zip)
49 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/one_agent/Reacher.app.zip)
50 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/one_agent/Reacher_Windows_x86.zip)
51 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/one_agent/Reacher_Windows_x86_64.zip)
52 |
53 | - **_Version 2: Twenty (20) Agents_**
54 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/Reacher_Linux.zip)
55 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/Reacher.app.zip)
56 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/Reacher_Windows_x86.zip)
57 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/Reacher_Windows_x86_64.zip)
58 |
59 | (_For Windows users_) Check out [this link](https://support.microsoft.com/en-us/help/827218/how-to-determine-whether-a-computer-is-running-a-32-bit-version-or-64) if you need help with determining if your computer is running a 32-bit version or 64-bit version of the Windows operating system.
60 |
61 | (_For AWS_) If you'd like to train the agent on AWS (and have not [enabled a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md)), then please use [this link](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/one_agent/Reacher_Linux_NoVis.zip) (version 1) or [this link](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Reacher/Reacher_Linux_NoVis.zip) (version 2) to obtain the "headless" version of the environment. You will **not** be able to watch the agent without enabling a virtual screen, but you will be able to train the agent. (_To watch the agent, you should follow the instructions to [enable a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md), and then download the environment for the **Linux** operating system above._)
62 |
63 | 2. Place the file in the DRLND GitHub repository, in the `p2_continuous-control/` folder, and unzip (or decompress) the file.
64 |
65 | ### Instructions
66 |
67 | Follow the instructions in `Continuous_Control.ipynb` to get started with training your own agent!
68 |
69 | ### (Optional) Challenge: Crawler Environment
70 |
71 | After you have successfully completed the project, you might like to solve the more difficult **Crawler** environment.
72 |
73 | ![Crawler][image2]
74 |
75 | In this continuous control environment, the goal is to teach a creature with four legs to walk forward without falling.
76 |
77 | You can read more about this environment in the ML-Agents GitHub [here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Learning-Environment-Examples.md#crawler). To solve this harder task, you'll need to download a new Unity environment. (**Note**: Udacity students should not submit a project with this new environment.)
78 |
79 | You need only select the environment that matches your operating system:
80 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Crawler/Crawler_Linux.zip)
81 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Crawler/Crawler.app.zip)
82 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Crawler/Crawler_Windows_x86.zip)
83 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Crawler/Crawler_Windows_x86_64.zip)
84 |
85 | Then, place the file in the `p2_continuous-control/` folder in the DRLND GitHub repository, and unzip (or decompress) the file. Next, open `Crawler.ipynb` and follow the instructions to learn how to use the Python API to control the agent.
86 |
87 | (_For AWS_) If you'd like to train the agent on AWS (and have not [enabled a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md)), then please use [this link](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P2/Crawler/Crawler_Linux_NoVis.zip) to obtain the "headless" version of the environment. You will **not** be able to watch the agent without enabling a virtual screen, but you will be able to train the agent. (_To watch the agent, you should follow the instructions to [enable a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md), and then download the environment for the **Linux** operating system above._)
88 |
89 |
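90 | For reference, the scoring rule for version 2 can be written down directly with the `unityagents` API demonstrated in `python/Basics.ipynb`. Below is a minimal, hedged sketch in which clipped random actions stand in for your policy; the `file_name` path and the episode count are placeholders.
91 |
92 | ```python
93 | # Hedged sketch: per-episode scoring for the 20-agent version (average over agents,
94 | # then a 100-episode moving average). Random actions are a placeholder for your policy.
95 | import numpy as np
96 | from collections import deque
97 | from unityagents import UnityEnvironment
98 |
99 | env = UnityEnvironment(file_name="Reacher_Linux/Reacher.x86_64")  # placeholder path
100 | brain_name = env.brain_names[0]
101 | scores_window = deque(maxlen=100)                  # most recent 100 episode averages
102 |
103 | for i_episode in range(1, 201):                    # episode count is illustrative
104 |     env_info = env.reset(train_mode=True)[brain_name]
105 |     num_agents = len(env_info.agents)              # 20 in version 2
106 |     scores = np.zeros(num_agents)                  # undiscounted return per agent
107 |     while True:
108 |         actions = np.clip(np.random.randn(num_agents, 4), -1, 1)  # replace with your policy
109 |         env_info = env.step(actions)[brain_name]
110 |         scores += env_info.rewards
111 |         if np.any(env_info.local_done):
112 |             break
113 |     scores_window.append(np.mean(scores))          # average score over all 20 agents
114 |     if len(scores_window) == 100 and np.mean(scores_window) >= 30.0:
115 |         print("Environment solved in {} episodes.".format(i_episode))
116 |         break
117 |
118 | env.close()
119 | ```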
--------------------------------------------------------------------------------
/p3_collab-compet/README.md:
--------------------------------------------------------------------------------
1 | [//]: # (Image References)
2 |
3 | [image1]: https://user-images.githubusercontent.com/10624937/42135623-e770e354-7d12-11e8-998d-29fc74429ca2.gif "Trained Agent"
4 | [image2]: https://user-images.githubusercontent.com/10624937/42135622-e55fb586-7d12-11e8-8a54-3c31da15a90a.gif "Soccer"
5 |
6 |
7 | # Project 3: Collaboration and Competition
8 |
9 | ### Introduction
10 |
11 | For this project, you will work with the [Tennis](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Learning-Environment-Examples.md#tennis) environment.
12 |
13 | ![Trained Agent][image1]
14 |
15 | In this environment, two agents control rackets to bounce a ball over a net. If an agent hits the ball over the net, it receives a reward of +0.1. If an agent lets a ball hit the ground or hits the ball out of bounds, it receives a reward of -0.01. Thus, the goal of each agent is to keep the ball in play.
16 |
17 | The observation space consists of 8 variables corresponding to the position and velocity of the ball and racket. Each agent receives its own, local observation. Two continuous actions are available, corresponding to movement toward (or away from) the net, and jumping.
18 |
19 | The task is episodic, and in order to solve the environment, your agents must get an average score of +0.5 (over 100 consecutive episodes, after taking the maximum over both agents). Specifically,
20 |
21 | - After each episode, we add up the rewards that each agent received (without discounting), to get a score for each agent. This yields 2 (potentially different) scores. We then take the maximum of these 2 scores.
22 | - This yields a single **score** for each episode.
23 |
24 | The environment is considered solved when the average (over 100 episodes) of those **scores** is at least +0.5.
25 |
26 | ### Getting Started
27 |
28 | 1. Download the environment from one of the links below. You need only select the environment that matches your operating system:
29 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Tennis/Tennis_Linux.zip)
30 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Tennis/Tennis.app.zip)
31 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Tennis/Tennis_Windows_x86.zip)
32 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Tennis/Tennis_Windows_x86_64.zip)
33 |
34 | (_For Windows users_) Check out [this link](https://support.microsoft.com/en-us/help/827218/how-to-determine-whether-a-computer-is-running-a-32-bit-version-or-64) if you need help with determining if your computer is running a 32-bit version or 64-bit version of the Windows operating system.
35 |
36 | (_For AWS_) If you'd like to train the agent on AWS (and have not [enabled a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md)), then please use [this link](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Tennis/Tennis_Linux_NoVis.zip) to obtain the "headless" version of the environment. You will **not** be able to watch the agent without enabling a virtual screen, but you will be able to train the agent. (_To watch the agent, you should follow the instructions to [enable a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md), and then download the environment for the **Linux** operating system above._)
37 |
38 | 2. Place the file in the DRLND GitHub repository, in the `p3_collab-compet/` folder, and unzip (or decompress) the file.
39 |
40 | ### Instructions
41 |
42 | Follow the instructions in `Tennis.ipynb` to get started with training your own agent!
43 |
44 | ### (Optional) Challenge: Soccer Environment
45 |
46 | After you have successfully completed the project, you might like to solve the more difficult **Soccer** environment.
47 |
48 | ![Soccer][image2]
49 |
50 | In this environment, the goal is to train a team of agents to play soccer.
51 |
52 | You can read more about this environment in the ML-Agents GitHub [here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Learning-Environment-Examples.md#soccer-twos). To solve this harder task, you'll need to download a new Unity environment. (**Note**: Udacity students should not submit a project with this new environment.)
53 |
54 | You need only select the environment that matches your operating system:
55 | - Linux: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Soccer/Soccer_Linux.zip)
56 | - Mac OSX: [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Soccer/Soccer.app.zip)
57 | - Windows (32-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Soccer/Soccer_Windows_x86.zip)
58 | - Windows (64-bit): [click here](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Soccer/Soccer_Windows_x86_64.zip)
59 |
60 | Then, place the file in the `p3_collab-compet/` folder in the DRLND GitHub repository, and unzip (or decompress) the file. Next, open `Soccer.ipynb` and follow the instructions to learn how to use the Python API to control the agent.
61 |
62 | (_For AWS_) If you'd like to train the agents on AWS (and have not [enabled a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md)), then please use [this link](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P3/Soccer/Soccer_Linux_NoVis.zip) to obtain the "headless" version of the environment. You will **not** be able to watch the agents without enabling a virtual screen, but you will be able to train the agents. (_To watch the agents, you should follow the instructions to [enable a virtual screen](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-on-Amazon-Web-Service.md), and then download the environment for the **Linux** operating system above._)
63 |
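64 | For reference, the scoring rule above (maximum over the two agents, then a 100-episode moving average) can be written down with the same `unityagents` API demonstrated in `python/Basics.ipynb`. Below is a minimal, hedged sketch in which clipped random actions stand in for your agents; the `file_name` path and the episode count are placeholders.
65 |
66 | ```python
67 | # Hedged sketch: per-episode scoring for Tennis (max over the two agents).
68 | # Random actions are a placeholder for your agents' policies.
69 | import numpy as np
70 | from collections import deque
71 | from unityagents import UnityEnvironment
72 |
73 | env = UnityEnvironment(file_name="Tennis_Linux/Tennis.x86_64")  # placeholder path
74 | brain_name = env.brain_names[0]
75 | scores_window = deque(maxlen=100)             # most recent 100 episode scores
76 |
77 | for i_episode in range(1, 3001):              # episode count is illustrative
78 |     env_info = env.reset(train_mode=True)[brain_name]
79 |     num_agents = len(env_info.agents)         # 2 agents
80 |     scores = np.zeros(num_agents)             # undiscounted return per agent
81 |     while True:
82 |         actions = np.clip(np.random.randn(num_agents, 2), -1, 1)  # 2 continuous actions each
83 |         env_info = env.step(actions)[brain_name]
84 |         scores += env_info.rewards
85 |         if np.any(env_info.local_done):
86 |             break
87 |     scores_window.append(np.max(scores))      # episode score = max over the two agents
88 |     if len(scores_window) == 100 and np.mean(scores_window) >= 0.5:
89 |         print("Environment solved in {} episodes.".format(i_episode))
90 |         break
91 |
92 | env.close()
93 | ```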
--------------------------------------------------------------------------------
/python/Basics.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Unity ML-Agents Toolkit\n",
8 | "## Environment Basics\n",
9 | "This notebook contains a walkthrough of the basic functions of the Python API for the Unity ML-Agents toolkit. For instructions on building a Unity environment, see [here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Getting-Started-with-Balance-Ball.md)."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "### 1. Set environment parameters\n",
17 | "\n",
18 | "Be sure to set `env_name` to the name of the Unity environment file you want to launch. Ensure that the environment build is in the `python/` directory."
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {
25 | "collapsed": true
26 | },
27 | "outputs": [],
28 | "source": [
29 | "env_name = \"3DBall\" # Name of the Unity environment binary to launch\n",
30 | "train_mode = True # Whether to run the environment in training or inference mode"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "### 2. Load dependencies\n",
38 | "\n",
39 | "The following loads the necessary dependencies and checks the Python version (at runtime). ML-Agents Toolkit (v0.3 onwards) requires Python 3."
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {
46 | "collapsed": true
47 | },
48 | "outputs": [],
49 | "source": [
50 | "import matplotlib.pyplot as plt\n",
51 | "import numpy as np\n",
52 | "import sys\n",
53 | "\n",
54 | "from unityagents import UnityEnvironment\n",
55 | "\n",
56 | "%matplotlib inline\n",
57 | "\n",
58 | "print(\"Python version:\")\n",
59 | "print(sys.version)\n",
60 | "\n",
61 | "# check Python version\n",
62 | "if (sys.version_info[0] < 3):\n",
63 | " raise Exception(\"ERROR: ML-Agents Toolkit (v0.3 onwards) requires Python 3\")"
64 | ]
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "metadata": {},
69 | "source": [
70 | "### 3. Start the environment\n",
71 | "`UnityEnvironment` launches and begins communication with the environment when instantiated.\n",
72 | "\n",
73 | "Environments contain _brains_ which are responsible for deciding the actions of their associated _agents_. Here we check for the first brain available, and set it as the default brain we will be controlling from Python."
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "metadata": {
80 | "collapsed": true
81 | },
82 | "outputs": [],
83 | "source": [
84 | "env = UnityEnvironment(file_name=env_name)\n",
85 | "\n",
86 | "# Examine environment parameters\n",
87 | "print(str(env))\n",
88 | "\n",
89 | "# Set the default brain to work with\n",
90 | "default_brain = env.brain_names[0]\n",
91 | "brain = env.brains[default_brain]"
92 | ]
93 | },
94 | {
95 | "cell_type": "markdown",
96 | "metadata": {},
97 | "source": [
98 | "### 4. Examine the observation and state spaces\n",
99 | "We can reset the environment to be provided with an initial set of observations and states for all the agents within the environment. In ML-Agents, _states_ refer to a vector of variables corresponding to relevant aspects of the environment for an agent. Likewise, _observations_ refer to a set of relevant pixel-wise visuals for an agent."
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "metadata": {
106 | "collapsed": true
107 | },
108 | "outputs": [],
109 | "source": [
110 | "# Reset the environment\n",
111 | "env_info = env.reset(train_mode=train_mode)[default_brain]\n",
112 | "\n",
113 | "# Examine the state space for the default brain\n",
114 | "print(\"Agent state looks like: \\n{}\".format(env_info.vector_observations[0]))\n",
115 | "\n",
116 | "# Examine the observation space for the default brain\n",
117 | "for observation in env_info.visual_observations:\n",
118 | " print(\"Agent observations look like:\")\n",
119 | " if observation.shape[3] == 3:\n",
120 | " plt.imshow(observation[0,:,:,:])\n",
121 | " else:\n",
122 | " plt.imshow(observation[0,:,:,0])"
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "metadata": {},
128 | "source": [
129 | "### 5. Take random actions in the environment\n",
130 | "Once we restart an environment, we can step the environment forward and provide actions to all of the agents within the environment. Here we simply choose random actions based on the `action_space_type` of the default brain. \n",
131 | "\n",
132 |     "Once this cell is executed, 10 messages will be printed, one per episode, reporting how much reward was accumulated in that episode. The Unity environment will then pause, waiting for further signals telling it what to do next. Thus, not seeing any animation is expected when running this cell."
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": null,
138 | "metadata": {
139 | "collapsed": true
140 | },
141 | "outputs": [],
142 | "source": [
143 | "for episode in range(10):\n",
144 | " env_info = env.reset(train_mode=train_mode)[default_brain]\n",
145 | " done = False\n",
146 | " episode_rewards = 0\n",
147 | " while not done:\n",
148 | " if brain.vector_action_space_type == 'continuous':\n",
149 | " env_info = env.step(np.random.randn(len(env_info.agents), \n",
150 | " brain.vector_action_space_size))[default_brain]\n",
151 | " else:\n",
152 | " env_info = env.step(np.random.randint(0, brain.vector_action_space_size, \n",
153 | " size=(len(env_info.agents))))[default_brain]\n",
154 | " episode_rewards += env_info.rewards[0]\n",
155 | " done = env_info.local_done[0]\n",
156 | " print(\"Total reward this episode: {}\".format(episode_rewards))"
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "metadata": {},
162 | "source": [
163 | "### 6. Close the environment when finished\n",
164 | "When we are finished using an environment, we can close it with the function below."
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": null,
170 | "metadata": {
171 | "collapsed": true
172 | },
173 | "outputs": [],
174 | "source": [
175 | "env.close()"
176 | ]
177 | }
178 | ],
179 | "metadata": {
180 | "anaconda-cloud": {},
181 | "kernelspec": {
182 | "display_name": "Python 3",
183 | "language": "python",
184 | "name": "python3"
185 | },
186 | "language_info": {
187 | "codemirror_mode": {
188 | "name": "ipython",
189 | "version": 3
190 | },
191 | "file_extension": ".py",
192 | "mimetype": "text/x-python",
193 | "name": "python",
194 | "nbconvert_exporter": "python",
195 | "pygments_lexer": "ipython3",
196 | "version": "3.6.3"
197 | }
198 | },
199 | "nbformat": 4,
200 | "nbformat_minor": 1
201 | }
202 |
--------------------------------------------------------------------------------
/python/README.md:
--------------------------------------------------------------------------------
1 | # Dependencies
2 |
3 | This is an amended version of the `python/` folder from the [ML-Agents repository](https://github.com/Unity-Technologies/ml-agents). It has been edited to include a few additional pip packages needed for the Deep Reinforcement Learning Nanodegree program.
4 |
--------------------------------------------------------------------------------
/python/communicator_objects/__init__.py:
--------------------------------------------------------------------------------
1 | from .agent_action_proto_pb2 import *
2 | from .agent_info_proto_pb2 import *
3 | from .brain_parameters_proto_pb2 import *
4 | from .brain_type_proto_pb2 import *
5 | from .command_proto_pb2 import *
6 | from .engine_configuration_proto_pb2 import *
7 | from .environment_parameters_proto_pb2 import *
8 | from .header_pb2 import *
9 | from .resolution_proto_pb2 import *
10 | from .space_type_proto_pb2 import *
11 | from .unity_input_pb2 import *
12 | from .unity_message_pb2 import *
13 | from .unity_output_pb2 import *
14 | from .unity_rl_initialization_input_pb2 import *
15 | from .unity_rl_initialization_output_pb2 import *
16 | from .unity_rl_input_pb2 import *
17 | from .unity_rl_output_pb2 import *
18 | from .unity_to_external_pb2 import *
19 | from .unity_to_external_pb2_grpc import *
20 |
--------------------------------------------------------------------------------
/python/communicator_objects/agent_action_proto_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/agent_action_proto.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='communicator_objects/agent_action_proto.proto',
20 | package='communicator_objects',
21 | syntax='proto3',
22 | serialized_pb=_b('\n-communicator_objects/agent_action_proto.proto\x12\x14\x63ommunicator_objects\"R\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\x14\n\x0ctext_actions\x18\x02 \x01(\t\x12\x10\n\x08memories\x18\x03 \x03(\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
23 | )
24 |
25 |
26 |
27 |
28 | _AGENTACTIONPROTO = _descriptor.Descriptor(
29 | name='AgentActionProto',
30 | full_name='communicator_objects.AgentActionProto',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='vector_actions', full_name='communicator_objects.AgentActionProto.vector_actions', index=0,
37 | number=1, type=2, cpp_type=6, label=3,
38 | has_default_value=False, default_value=[],
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None, file=DESCRIPTOR),
42 | _descriptor.FieldDescriptor(
43 | name='text_actions', full_name='communicator_objects.AgentActionProto.text_actions', index=1,
44 | number=2, type=9, cpp_type=9, label=1,
45 | has_default_value=False, default_value=_b("").decode('utf-8'),
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None, file=DESCRIPTOR),
49 | _descriptor.FieldDescriptor(
50 | name='memories', full_name='communicator_objects.AgentActionProto.memories', index=2,
51 | number=3, type=2, cpp_type=6, label=3,
52 | has_default_value=False, default_value=[],
53 | message_type=None, enum_type=None, containing_type=None,
54 | is_extension=False, extension_scope=None,
55 | options=None, file=DESCRIPTOR),
56 | ],
57 | extensions=[
58 | ],
59 | nested_types=[],
60 | enum_types=[
61 | ],
62 | options=None,
63 | is_extendable=False,
64 | syntax='proto3',
65 | extension_ranges=[],
66 | oneofs=[
67 | ],
68 | serialized_start=71,
69 | serialized_end=153,
70 | )
71 |
72 | DESCRIPTOR.message_types_by_name['AgentActionProto'] = _AGENTACTIONPROTO
73 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
74 |
75 | AgentActionProto = _reflection.GeneratedProtocolMessageType('AgentActionProto', (_message.Message,), dict(
76 | DESCRIPTOR = _AGENTACTIONPROTO,
77 | __module__ = 'communicator_objects.agent_action_proto_pb2'
78 | # @@protoc_insertion_point(class_scope:communicator_objects.AgentActionProto)
79 | ))
80 | _sym_db.RegisterMessage(AgentActionProto)
81 |
82 |
83 | DESCRIPTOR.has_options = True
84 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
85 | # @@protoc_insertion_point(module_scope)
86 |
--------------------------------------------------------------------------------
/python/communicator_objects/agent_info_proto_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/agent_info_proto.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='communicator_objects/agent_info_proto.proto',
20 | package='communicator_objects',
21 | syntax='proto3',
22 | serialized_pb=_b('\n+communicator_objects/agent_info_proto.proto\x12\x14\x63ommunicator_objects\"\xfd\x01\n\x0e\x41gentInfoProto\x12\"\n\x1astacked_vector_observation\x18\x01 \x03(\x02\x12\x1b\n\x13visual_observations\x18\x02 \x03(\x0c\x12\x18\n\x10text_observation\x18\x03 \x01(\t\x12\x1d\n\x15stored_vector_actions\x18\x04 \x03(\x02\x12\x1b\n\x13stored_text_actions\x18\x05 \x01(\t\x12\x10\n\x08memories\x18\x06 \x03(\x02\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
23 | )
24 |
25 |
26 |
27 |
28 | _AGENTINFOPROTO = _descriptor.Descriptor(
29 | name='AgentInfoProto',
30 | full_name='communicator_objects.AgentInfoProto',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='stacked_vector_observation', full_name='communicator_objects.AgentInfoProto.stacked_vector_observation', index=0,
37 | number=1, type=2, cpp_type=6, label=3,
38 | has_default_value=False, default_value=[],
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None, file=DESCRIPTOR),
42 | _descriptor.FieldDescriptor(
43 | name='visual_observations', full_name='communicator_objects.AgentInfoProto.visual_observations', index=1,
44 | number=2, type=12, cpp_type=9, label=3,
45 | has_default_value=False, default_value=[],
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None, file=DESCRIPTOR),
49 | _descriptor.FieldDescriptor(
50 | name='text_observation', full_name='communicator_objects.AgentInfoProto.text_observation', index=2,
51 | number=3, type=9, cpp_type=9, label=1,
52 | has_default_value=False, default_value=_b("").decode('utf-8'),
53 | message_type=None, enum_type=None, containing_type=None,
54 | is_extension=False, extension_scope=None,
55 | options=None, file=DESCRIPTOR),
56 | _descriptor.FieldDescriptor(
57 | name='stored_vector_actions', full_name='communicator_objects.AgentInfoProto.stored_vector_actions', index=3,
58 | number=4, type=2, cpp_type=6, label=3,
59 | has_default_value=False, default_value=[],
60 | message_type=None, enum_type=None, containing_type=None,
61 | is_extension=False, extension_scope=None,
62 | options=None, file=DESCRIPTOR),
63 | _descriptor.FieldDescriptor(
64 | name='stored_text_actions', full_name='communicator_objects.AgentInfoProto.stored_text_actions', index=4,
65 | number=5, type=9, cpp_type=9, label=1,
66 | has_default_value=False, default_value=_b("").decode('utf-8'),
67 | message_type=None, enum_type=None, containing_type=None,
68 | is_extension=False, extension_scope=None,
69 | options=None, file=DESCRIPTOR),
70 | _descriptor.FieldDescriptor(
71 | name='memories', full_name='communicator_objects.AgentInfoProto.memories', index=5,
72 | number=6, type=2, cpp_type=6, label=3,
73 | has_default_value=False, default_value=[],
74 | message_type=None, enum_type=None, containing_type=None,
75 | is_extension=False, extension_scope=None,
76 | options=None, file=DESCRIPTOR),
77 | _descriptor.FieldDescriptor(
78 | name='reward', full_name='communicator_objects.AgentInfoProto.reward', index=6,
79 | number=7, type=2, cpp_type=6, label=1,
80 | has_default_value=False, default_value=float(0),
81 | message_type=None, enum_type=None, containing_type=None,
82 | is_extension=False, extension_scope=None,
83 | options=None, file=DESCRIPTOR),
84 | _descriptor.FieldDescriptor(
85 | name='done', full_name='communicator_objects.AgentInfoProto.done', index=7,
86 | number=8, type=8, cpp_type=7, label=1,
87 | has_default_value=False, default_value=False,
88 | message_type=None, enum_type=None, containing_type=None,
89 | is_extension=False, extension_scope=None,
90 | options=None, file=DESCRIPTOR),
91 | _descriptor.FieldDescriptor(
92 | name='max_step_reached', full_name='communicator_objects.AgentInfoProto.max_step_reached', index=8,
93 | number=9, type=8, cpp_type=7, label=1,
94 | has_default_value=False, default_value=False,
95 | message_type=None, enum_type=None, containing_type=None,
96 | is_extension=False, extension_scope=None,
97 | options=None, file=DESCRIPTOR),
98 | _descriptor.FieldDescriptor(
99 | name='id', full_name='communicator_objects.AgentInfoProto.id', index=9,
100 | number=10, type=5, cpp_type=1, label=1,
101 | has_default_value=False, default_value=0,
102 | message_type=None, enum_type=None, containing_type=None,
103 | is_extension=False, extension_scope=None,
104 | options=None, file=DESCRIPTOR),
105 | ],
106 | extensions=[
107 | ],
108 | nested_types=[],
109 | enum_types=[
110 | ],
111 | options=None,
112 | is_extendable=False,
113 | syntax='proto3',
114 | extension_ranges=[],
115 | oneofs=[
116 | ],
117 | serialized_start=70,
118 | serialized_end=323,
119 | )
120 |
121 | DESCRIPTOR.message_types_by_name['AgentInfoProto'] = _AGENTINFOPROTO
122 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
123 |
124 | AgentInfoProto = _reflection.GeneratedProtocolMessageType('AgentInfoProto', (_message.Message,), dict(
125 | DESCRIPTOR = _AGENTINFOPROTO,
126 | __module__ = 'communicator_objects.agent_info_proto_pb2'
127 | # @@protoc_insertion_point(class_scope:communicator_objects.AgentInfoProto)
128 | ))
129 | _sym_db.RegisterMessage(AgentInfoProto)
130 |
131 |
132 | DESCRIPTOR.has_options = True
133 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
134 | # @@protoc_insertion_point(module_scope)
135 |
--------------------------------------------------------------------------------
/python/communicator_objects/brain_parameters_proto_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/brain_parameters_proto.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 | from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2
17 | from communicator_objects import brain_type_proto_pb2 as communicator__objects_dot_brain__type__proto__pb2
18 | from communicator_objects import space_type_proto_pb2 as communicator__objects_dot_space__type__proto__pb2
19 |
20 |
21 | DESCRIPTOR = _descriptor.FileDescriptor(
22 | name='communicator_objects/brain_parameters_proto.proto',
23 | package='communicator_objects',
24 | syntax='proto3',
25 | serialized_pb=_b('\n1communicator_objects/brain_parameters_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto\x1a+communicator_objects/brain_type_proto.proto\x1a+communicator_objects/space_type_proto.proto\"\xc6\x03\n\x14\x42rainParametersProto\x12\x1f\n\x17vector_observation_size\x18\x01 \x01(\x05\x12\'\n\x1fnum_stacked_vector_observations\x18\x02 \x01(\x05\x12\x1a\n\x12vector_action_size\x18\x03 \x01(\x05\x12\x41\n\x12\x63\x61mera_resolutions\x18\x04 \x03(\x0b\x32%.communicator_objects.ResolutionProto\x12\"\n\x1avector_action_descriptions\x18\x05 \x03(\t\x12\x46\n\x18vector_action_space_type\x18\x06 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12K\n\x1dvector_observation_space_type\x18\x07 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12\x12\n\nbrain_name\x18\x08 \x01(\t\x12\x38\n\nbrain_type\x18\t \x01(\x0e\x32$.communicator_objects.BrainTypeProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
26 | ,
27 | dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,communicator__objects_dot_brain__type__proto__pb2.DESCRIPTOR,communicator__objects_dot_space__type__proto__pb2.DESCRIPTOR,])
28 |
29 |
30 |
31 |
32 | _BRAINPARAMETERSPROTO = _descriptor.Descriptor(
33 | name='BrainParametersProto',
34 | full_name='communicator_objects.BrainParametersProto',
35 | filename=None,
36 | file=DESCRIPTOR,
37 | containing_type=None,
38 | fields=[
39 | _descriptor.FieldDescriptor(
40 | name='vector_observation_size', full_name='communicator_objects.BrainParametersProto.vector_observation_size', index=0,
41 | number=1, type=5, cpp_type=1, label=1,
42 | has_default_value=False, default_value=0,
43 | message_type=None, enum_type=None, containing_type=None,
44 | is_extension=False, extension_scope=None,
45 | options=None, file=DESCRIPTOR),
46 | _descriptor.FieldDescriptor(
47 | name='num_stacked_vector_observations', full_name='communicator_objects.BrainParametersProto.num_stacked_vector_observations', index=1,
48 | number=2, type=5, cpp_type=1, label=1,
49 | has_default_value=False, default_value=0,
50 | message_type=None, enum_type=None, containing_type=None,
51 | is_extension=False, extension_scope=None,
52 | options=None, file=DESCRIPTOR),
53 | _descriptor.FieldDescriptor(
54 | name='vector_action_size', full_name='communicator_objects.BrainParametersProto.vector_action_size', index=2,
55 | number=3, type=5, cpp_type=1, label=1,
56 | has_default_value=False, default_value=0,
57 | message_type=None, enum_type=None, containing_type=None,
58 | is_extension=False, extension_scope=None,
59 | options=None, file=DESCRIPTOR),
60 | _descriptor.FieldDescriptor(
61 | name='camera_resolutions', full_name='communicator_objects.BrainParametersProto.camera_resolutions', index=3,
62 | number=4, type=11, cpp_type=10, label=3,
63 | has_default_value=False, default_value=[],
64 | message_type=None, enum_type=None, containing_type=None,
65 | is_extension=False, extension_scope=None,
66 | options=None, file=DESCRIPTOR),
67 | _descriptor.FieldDescriptor(
68 | name='vector_action_descriptions', full_name='communicator_objects.BrainParametersProto.vector_action_descriptions', index=4,
69 | number=5, type=9, cpp_type=9, label=3,
70 | has_default_value=False, default_value=[],
71 | message_type=None, enum_type=None, containing_type=None,
72 | is_extension=False, extension_scope=None,
73 | options=None, file=DESCRIPTOR),
74 | _descriptor.FieldDescriptor(
75 | name='vector_action_space_type', full_name='communicator_objects.BrainParametersProto.vector_action_space_type', index=5,
76 | number=6, type=14, cpp_type=8, label=1,
77 | has_default_value=False, default_value=0,
78 | message_type=None, enum_type=None, containing_type=None,
79 | is_extension=False, extension_scope=None,
80 | options=None, file=DESCRIPTOR),
81 | _descriptor.FieldDescriptor(
82 | name='vector_observation_space_type', full_name='communicator_objects.BrainParametersProto.vector_observation_space_type', index=6,
83 | number=7, type=14, cpp_type=8, label=1,
84 | has_default_value=False, default_value=0,
85 | message_type=None, enum_type=None, containing_type=None,
86 | is_extension=False, extension_scope=None,
87 | options=None, file=DESCRIPTOR),
88 | _descriptor.FieldDescriptor(
89 | name='brain_name', full_name='communicator_objects.BrainParametersProto.brain_name', index=7,
90 | number=8, type=9, cpp_type=9, label=1,
91 | has_default_value=False, default_value=_b("").decode('utf-8'),
92 | message_type=None, enum_type=None, containing_type=None,
93 | is_extension=False, extension_scope=None,
94 | options=None, file=DESCRIPTOR),
95 | _descriptor.FieldDescriptor(
96 | name='brain_type', full_name='communicator_objects.BrainParametersProto.brain_type', index=8,
97 | number=9, type=14, cpp_type=8, label=1,
98 | has_default_value=False, default_value=0,
99 | message_type=None, enum_type=None, containing_type=None,
100 | is_extension=False, extension_scope=None,
101 | options=None, file=DESCRIPTOR),
102 | ],
103 | extensions=[
104 | ],
105 | nested_types=[],
106 | enum_types=[
107 | ],
108 | options=None,
109 | is_extendable=False,
110 | syntax='proto3',
111 | extension_ranges=[],
112 | oneofs=[
113 | ],
114 | serialized_start=211,
115 | serialized_end=665,
116 | )
117 |
118 | _BRAINPARAMETERSPROTO.fields_by_name['camera_resolutions'].message_type = communicator__objects_dot_resolution__proto__pb2._RESOLUTIONPROTO
119 | _BRAINPARAMETERSPROTO.fields_by_name['vector_action_space_type'].enum_type = communicator__objects_dot_space__type__proto__pb2._SPACETYPEPROTO
120 | _BRAINPARAMETERSPROTO.fields_by_name['vector_observation_space_type'].enum_type = communicator__objects_dot_space__type__proto__pb2._SPACETYPEPROTO
121 | _BRAINPARAMETERSPROTO.fields_by_name['brain_type'].enum_type = communicator__objects_dot_brain__type__proto__pb2._BRAINTYPEPROTO
122 | DESCRIPTOR.message_types_by_name['BrainParametersProto'] = _BRAINPARAMETERSPROTO
123 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
124 |
125 | BrainParametersProto = _reflection.GeneratedProtocolMessageType('BrainParametersProto', (_message.Message,), dict(
126 | DESCRIPTOR = _BRAINPARAMETERSPROTO,
127 | __module__ = 'communicator_objects.brain_parameters_proto_pb2'
128 | # @@protoc_insertion_point(class_scope:communicator_objects.BrainParametersProto)
129 | ))
130 | _sym_db.RegisterMessage(BrainParametersProto)
131 |
132 |
133 | DESCRIPTOR.has_options = True
134 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
135 | # @@protoc_insertion_point(module_scope)
136 |
--------------------------------------------------------------------------------
/python/communicator_objects/brain_type_proto_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/brain_type_proto.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf.internal import enum_type_wrapper
7 | from google.protobuf import descriptor as _descriptor
8 | from google.protobuf import message as _message
9 | from google.protobuf import reflection as _reflection
10 | from google.protobuf import symbol_database as _symbol_database
11 | from google.protobuf import descriptor_pb2
12 | # @@protoc_insertion_point(imports)
13 |
14 | _sym_db = _symbol_database.Default()
15 |
16 |
17 | from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2
18 |
19 |
20 | DESCRIPTOR = _descriptor.FileDescriptor(
21 | name='communicator_objects/brain_type_proto.proto',
22 | package='communicator_objects',
23 | syntax='proto3',
24 | serialized_pb=_b('\n+communicator_objects/brain_type_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto*G\n\x0e\x42rainTypeProto\x12\n\n\x06Player\x10\x00\x12\r\n\tHeuristic\x10\x01\x12\x0c\n\x08\x45xternal\x10\x02\x12\x0c\n\x08Internal\x10\x03\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
25 | ,
26 | dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,])
27 |
28 | _BRAINTYPEPROTO = _descriptor.EnumDescriptor(
29 | name='BrainTypeProto',
30 | full_name='communicator_objects.BrainTypeProto',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | values=[
34 | _descriptor.EnumValueDescriptor(
35 | name='Player', index=0, number=0,
36 | options=None,
37 | type=None),
38 | _descriptor.EnumValueDescriptor(
39 | name='Heuristic', index=1, number=1,
40 | options=None,
41 | type=None),
42 | _descriptor.EnumValueDescriptor(
43 | name='External', index=2, number=2,
44 | options=None,
45 | type=None),
46 | _descriptor.EnumValueDescriptor(
47 | name='Internal', index=3, number=3,
48 | options=None,
49 | type=None),
50 | ],
51 | containing_type=None,
52 | options=None,
53 | serialized_start=114,
54 | serialized_end=185,
55 | )
56 | _sym_db.RegisterEnumDescriptor(_BRAINTYPEPROTO)
57 |
58 | BrainTypeProto = enum_type_wrapper.EnumTypeWrapper(_BRAINTYPEPROTO)
59 | Player = 0
60 | Heuristic = 1
61 | External = 2
62 | Internal = 3
63 |
64 |
65 | DESCRIPTOR.enum_types_by_name['BrainTypeProto'] = _BRAINTYPEPROTO
66 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
67 |
68 |
69 | DESCRIPTOR.has_options = True
70 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
71 | # @@protoc_insertion_point(module_scope)
72 |
--------------------------------------------------------------------------------
/python/communicator_objects/command_proto_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/command_proto.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf.internal import enum_type_wrapper
7 | from google.protobuf import descriptor as _descriptor
8 | from google.protobuf import message as _message
9 | from google.protobuf import reflection as _reflection
10 | from google.protobuf import symbol_database as _symbol_database
11 | from google.protobuf import descriptor_pb2
12 | # @@protoc_insertion_point(imports)
13 |
14 | _sym_db = _symbol_database.Default()
15 |
16 |
17 |
18 |
19 | DESCRIPTOR = _descriptor.FileDescriptor(
20 | name='communicator_objects/command_proto.proto',
21 | package='communicator_objects',
22 | syntax='proto3',
23 | serialized_pb=_b('\n(communicator_objects/command_proto.proto\x12\x14\x63ommunicator_objects*-\n\x0c\x43ommandProto\x12\x08\n\x04STEP\x10\x00\x12\t\n\x05RESET\x10\x01\x12\x08\n\x04QUIT\x10\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
24 | )
25 |
26 | _COMMANDPROTO = _descriptor.EnumDescriptor(
27 | name='CommandProto',
28 | full_name='communicator_objects.CommandProto',
29 | filename=None,
30 | file=DESCRIPTOR,
31 | values=[
32 | _descriptor.EnumValueDescriptor(
33 | name='STEP', index=0, number=0,
34 | options=None,
35 | type=None),
36 | _descriptor.EnumValueDescriptor(
37 | name='RESET', index=1, number=1,
38 | options=None,
39 | type=None),
40 | _descriptor.EnumValueDescriptor(
41 | name='QUIT', index=2, number=2,
42 | options=None,
43 | type=None),
44 | ],
45 | containing_type=None,
46 | options=None,
47 | serialized_start=66,
48 | serialized_end=111,
49 | )
50 | _sym_db.RegisterEnumDescriptor(_COMMANDPROTO)
51 |
52 | CommandProto = enum_type_wrapper.EnumTypeWrapper(_COMMANDPROTO)
53 | STEP = 0
54 | RESET = 1
55 | QUIT = 2
56 |
57 |
58 | DESCRIPTOR.enum_types_by_name['CommandProto'] = _COMMANDPROTO
59 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
60 |
61 |
62 | DESCRIPTOR.has_options = True
63 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
64 | # @@protoc_insertion_point(module_scope)
65 |
--------------------------------------------------------------------------------
/python/communicator_objects/engine_configuration_proto_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/engine_configuration_proto.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='communicator_objects/engine_configuration_proto.proto',
20 | package='communicator_objects',
21 | syntax='proto3',
22 | serialized_pb=_b('\n5communicator_objects/engine_configuration_proto.proto\x12\x14\x63ommunicator_objects\"\x95\x01\n\x18\x45ngineConfigurationProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x15\n\rquality_level\x18\x03 \x01(\x05\x12\x12\n\ntime_scale\x18\x04 \x01(\x02\x12\x19\n\x11target_frame_rate\x18\x05 \x01(\x05\x12\x14\n\x0cshow_monitor\x18\x06 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
23 | )
24 |
25 |
26 |
27 |
28 | _ENGINECONFIGURATIONPROTO = _descriptor.Descriptor(
29 | name='EngineConfigurationProto',
30 | full_name='communicator_objects.EngineConfigurationProto',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='width', full_name='communicator_objects.EngineConfigurationProto.width', index=0,
37 | number=1, type=5, cpp_type=1, label=1,
38 | has_default_value=False, default_value=0,
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None, file=DESCRIPTOR),
42 | _descriptor.FieldDescriptor(
43 | name='height', full_name='communicator_objects.EngineConfigurationProto.height', index=1,
44 | number=2, type=5, cpp_type=1, label=1,
45 | has_default_value=False, default_value=0,
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None, file=DESCRIPTOR),
49 | _descriptor.FieldDescriptor(
50 | name='quality_level', full_name='communicator_objects.EngineConfigurationProto.quality_level', index=2,
51 | number=3, type=5, cpp_type=1, label=1,
52 | has_default_value=False, default_value=0,
53 | message_type=None, enum_type=None, containing_type=None,
54 | is_extension=False, extension_scope=None,
55 | options=None, file=DESCRIPTOR),
56 | _descriptor.FieldDescriptor(
57 | name='time_scale', full_name='communicator_objects.EngineConfigurationProto.time_scale', index=3,
58 | number=4, type=2, cpp_type=6, label=1,
59 | has_default_value=False, default_value=float(0),
60 | message_type=None, enum_type=None, containing_type=None,
61 | is_extension=False, extension_scope=None,
62 | options=None, file=DESCRIPTOR),
63 | _descriptor.FieldDescriptor(
64 | name='target_frame_rate', full_name='communicator_objects.EngineConfigurationProto.target_frame_rate', index=4,
65 | number=5, type=5, cpp_type=1, label=1,
66 | has_default_value=False, default_value=0,
67 | message_type=None, enum_type=None, containing_type=None,
68 | is_extension=False, extension_scope=None,
69 | options=None, file=DESCRIPTOR),
70 | _descriptor.FieldDescriptor(
71 | name='show_monitor', full_name='communicator_objects.EngineConfigurationProto.show_monitor', index=5,
72 | number=6, type=8, cpp_type=7, label=1,
73 | has_default_value=False, default_value=False,
74 | message_type=None, enum_type=None, containing_type=None,
75 | is_extension=False, extension_scope=None,
76 | options=None, file=DESCRIPTOR),
77 | ],
78 | extensions=[
79 | ],
80 | nested_types=[],
81 | enum_types=[
82 | ],
83 | options=None,
84 | is_extendable=False,
85 | syntax='proto3',
86 | extension_ranges=[],
87 | oneofs=[
88 | ],
89 | serialized_start=80,
90 | serialized_end=229,
91 | )
92 |
93 | DESCRIPTOR.message_types_by_name['EngineConfigurationProto'] = _ENGINECONFIGURATIONPROTO
94 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
95 |
96 | EngineConfigurationProto = _reflection.GeneratedProtocolMessageType('EngineConfigurationProto', (_message.Message,), dict(
97 | DESCRIPTOR = _ENGINECONFIGURATIONPROTO,
98 | __module__ = 'communicator_objects.engine_configuration_proto_pb2'
99 | # @@protoc_insertion_point(class_scope:communicator_objects.EngineConfigurationProto)
100 | ))
101 | _sym_db.RegisterMessage(EngineConfigurationProto)
102 |
103 |
104 | DESCRIPTOR.has_options = True
105 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
106 | # @@protoc_insertion_point(module_scope)
107 |
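A minimal sketch of constructing and round-tripping the message defined above (the field values are illustrative, not repository defaults):

from communicator_objects.engine_configuration_proto_pb2 import EngineConfigurationProto

cfg = EngineConfigurationProto(
    width=80, height=80,       # render resolution in pixels
    quality_level=1,           # Unity quality-setting index
    time_scale=100.0,          # speed-up factor typically used while training
    target_frame_rate=-1,      # -1 lets Unity render as fast as it can
    show_monitor=False)

wire = cfg.SerializeToString()                       # wire-format bytes sent to Unity
restored = EngineConfigurationProto.FromString(wire)
assert restored.time_scale == cfg.time_scale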
--------------------------------------------------------------------------------
/python/communicator_objects/environment_parameters_proto_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/environment_parameters_proto.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='communicator_objects/environment_parameters_proto.proto',
20 | package='communicator_objects',
21 | syntax='proto3',
22 | serialized_pb=_b('\n7communicator_objects/environment_parameters_proto.proto\x12\x14\x63ommunicator_objects\"\xb5\x01\n\x1a\x45nvironmentParametersProto\x12_\n\x10\x66loat_parameters\x18\x01 \x03(\x0b\x32\x45.communicator_objects.EnvironmentParametersProto.FloatParametersEntry\x1a\x36\n\x14\x46loatParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
23 | )
24 |
25 |
26 |
27 |
28 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY = _descriptor.Descriptor(
29 | name='FloatParametersEntry',
30 | full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='key', full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.key', index=0,
37 | number=1, type=9, cpp_type=9, label=1,
38 | has_default_value=False, default_value=_b("").decode('utf-8'),
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None, file=DESCRIPTOR),
42 | _descriptor.FieldDescriptor(
43 | name='value', full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.value', index=1,
44 | number=2, type=2, cpp_type=6, label=1,
45 | has_default_value=False, default_value=float(0),
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None, file=DESCRIPTOR),
49 | ],
50 | extensions=[
51 | ],
52 | nested_types=[],
53 | enum_types=[
54 | ],
55 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')),
56 | is_extendable=False,
57 | syntax='proto3',
58 | extension_ranges=[],
59 | oneofs=[
60 | ],
61 | serialized_start=209,
62 | serialized_end=263,
63 | )
64 |
65 | _ENVIRONMENTPARAMETERSPROTO = _descriptor.Descriptor(
66 | name='EnvironmentParametersProto',
67 | full_name='communicator_objects.EnvironmentParametersProto',
68 | filename=None,
69 | file=DESCRIPTOR,
70 | containing_type=None,
71 | fields=[
72 | _descriptor.FieldDescriptor(
73 | name='float_parameters', full_name='communicator_objects.EnvironmentParametersProto.float_parameters', index=0,
74 | number=1, type=11, cpp_type=10, label=3,
75 | has_default_value=False, default_value=[],
76 | message_type=None, enum_type=None, containing_type=None,
77 | is_extension=False, extension_scope=None,
78 | options=None, file=DESCRIPTOR),
79 | ],
80 | extensions=[
81 | ],
82 | nested_types=[_ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, ],
83 | enum_types=[
84 | ],
85 | options=None,
86 | is_extendable=False,
87 | syntax='proto3',
88 | extension_ranges=[],
89 | oneofs=[
90 | ],
91 | serialized_start=82,
92 | serialized_end=263,
93 | )
94 |
95 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.containing_type = _ENVIRONMENTPARAMETERSPROTO
96 | _ENVIRONMENTPARAMETERSPROTO.fields_by_name['float_parameters'].message_type = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY
97 | DESCRIPTOR.message_types_by_name['EnvironmentParametersProto'] = _ENVIRONMENTPARAMETERSPROTO
98 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
99 |
100 | EnvironmentParametersProto = _reflection.GeneratedProtocolMessageType('EnvironmentParametersProto', (_message.Message,), dict(
101 |
102 | FloatParametersEntry = _reflection.GeneratedProtocolMessageType('FloatParametersEntry', (_message.Message,), dict(
103 | DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY,
104 | __module__ = 'communicator_objects.environment_parameters_proto_pb2'
105 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto.FloatParametersEntry)
106 | ))
107 | ,
108 | DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO,
109 | __module__ = 'communicator_objects.environment_parameters_proto_pb2'
110 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto)
111 | ))
112 | _sym_db.RegisterMessage(EnvironmentParametersProto)
113 | _sym_db.RegisterMessage(EnvironmentParametersProto.FloatParametersEntry)
114 |
115 |
116 | DESCRIPTOR.has_options = True
117 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
118 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.has_options = True
119 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001'))
120 | # @@protoc_insertion_point(module_scope)
121 |
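Because float_parameters is a proto3 map<string, float>, the generated message exposes it as a dict-like field. A small illustrative sketch (the keys used here are hypothetical reset parameters):

from communicator_objects.environment_parameters_proto_pb2 import EnvironmentParametersProto

params = EnvironmentParametersProto()
params.float_parameters['goal_size'] = 3.5    # hypothetical parameter names
params.float_parameters['block_size'] = 1.5

decoded = EnvironmentParametersProto.FromString(params.SerializeToString())
print(dict(decoded.float_parameters))         # {'goal_size': 3.5, 'block_size': 1.5}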
--------------------------------------------------------------------------------
/python/communicator_objects/header_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/header.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='communicator_objects/header.proto',
20 | package='communicator_objects',
21 | syntax='proto3',
22 | serialized_pb=_b('\n!communicator_objects/header.proto\x12\x14\x63ommunicator_objects\")\n\x06Header\x12\x0e\n\x06status\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
23 | )
24 |
25 |
26 |
27 |
28 | _HEADER = _descriptor.Descriptor(
29 | name='Header',
30 | full_name='communicator_objects.Header',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='status', full_name='communicator_objects.Header.status', index=0,
37 | number=1, type=5, cpp_type=1, label=1,
38 | has_default_value=False, default_value=0,
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None, file=DESCRIPTOR),
42 | _descriptor.FieldDescriptor(
43 | name='message', full_name='communicator_objects.Header.message', index=1,
44 | number=2, type=9, cpp_type=9, label=1,
45 | has_default_value=False, default_value=_b("").decode('utf-8'),
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None, file=DESCRIPTOR),
49 | ],
50 | extensions=[
51 | ],
52 | nested_types=[],
53 | enum_types=[
54 | ],
55 | options=None,
56 | is_extendable=False,
57 | syntax='proto3',
58 | extension_ranges=[],
59 | oneofs=[
60 | ],
61 | serialized_start=59,
62 | serialized_end=100,
63 | )
64 |
65 | DESCRIPTOR.message_types_by_name['Header'] = _HEADER
66 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
67 |
68 | Header = _reflection.GeneratedProtocolMessageType('Header', (_message.Message,), dict(
69 | DESCRIPTOR = _HEADER,
70 | __module__ = 'communicator_objects.header_pb2'
71 | # @@protoc_insertion_point(class_scope:communicator_objects.Header)
72 | ))
73 | _sym_db.RegisterMessage(Header)
74 |
75 |
76 | DESCRIPTOR.has_options = True
77 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
78 | # @@protoc_insertion_point(module_scope)
79 |
--------------------------------------------------------------------------------
/python/communicator_objects/resolution_proto_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/resolution_proto.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='communicator_objects/resolution_proto.proto',
20 | package='communicator_objects',
21 | syntax='proto3',
22 | serialized_pb=_b('\n+communicator_objects/resolution_proto.proto\x12\x14\x63ommunicator_objects\"D\n\x0fResolutionProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x12\n\ngray_scale\x18\x03 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
23 | )
24 |
25 |
26 |
27 |
28 | _RESOLUTIONPROTO = _descriptor.Descriptor(
29 | name='ResolutionProto',
30 | full_name='communicator_objects.ResolutionProto',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='width', full_name='communicator_objects.ResolutionProto.width', index=0,
37 | number=1, type=5, cpp_type=1, label=1,
38 | has_default_value=False, default_value=0,
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None, file=DESCRIPTOR),
42 | _descriptor.FieldDescriptor(
43 | name='height', full_name='communicator_objects.ResolutionProto.height', index=1,
44 | number=2, type=5, cpp_type=1, label=1,
45 | has_default_value=False, default_value=0,
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None, file=DESCRIPTOR),
49 | _descriptor.FieldDescriptor(
50 | name='gray_scale', full_name='communicator_objects.ResolutionProto.gray_scale', index=2,
51 | number=3, type=8, cpp_type=7, label=1,
52 | has_default_value=False, default_value=False,
53 | message_type=None, enum_type=None, containing_type=None,
54 | is_extension=False, extension_scope=None,
55 | options=None, file=DESCRIPTOR),
56 | ],
57 | extensions=[
58 | ],
59 | nested_types=[],
60 | enum_types=[
61 | ],
62 | options=None,
63 | is_extendable=False,
64 | syntax='proto3',
65 | extension_ranges=[],
66 | oneofs=[
67 | ],
68 | serialized_start=69,
69 | serialized_end=137,
70 | )
71 |
72 | DESCRIPTOR.message_types_by_name['ResolutionProto'] = _RESOLUTIONPROTO
73 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
74 |
75 | ResolutionProto = _reflection.GeneratedProtocolMessageType('ResolutionProto', (_message.Message,), dict(
76 | DESCRIPTOR = _RESOLUTIONPROTO,
77 | __module__ = 'communicator_objects.resolution_proto_pb2'
78 | # @@protoc_insertion_point(class_scope:communicator_objects.ResolutionProto)
79 | ))
80 | _sym_db.RegisterMessage(ResolutionProto)
81 |
82 |
83 | DESCRIPTOR.has_options = True
84 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
85 | # @@protoc_insertion_point(module_scope)
86 |
--------------------------------------------------------------------------------
/python/communicator_objects/space_type_proto_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/space_type_proto.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf.internal import enum_type_wrapper
7 | from google.protobuf import descriptor as _descriptor
8 | from google.protobuf import message as _message
9 | from google.protobuf import reflection as _reflection
10 | from google.protobuf import symbol_database as _symbol_database
11 | from google.protobuf import descriptor_pb2
12 | # @@protoc_insertion_point(imports)
13 |
14 | _sym_db = _symbol_database.Default()
15 |
16 |
17 | from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2
18 |
19 |
20 | DESCRIPTOR = _descriptor.FileDescriptor(
21 | name='communicator_objects/space_type_proto.proto',
22 | package='communicator_objects',
23 | syntax='proto3',
24 | serialized_pb=_b('\n+communicator_objects/space_type_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto*.\n\x0eSpaceTypeProto\x12\x0c\n\x08\x64iscrete\x10\x00\x12\x0e\n\ncontinuous\x10\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
25 | ,
26 | dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,])
27 |
28 | _SPACETYPEPROTO = _descriptor.EnumDescriptor(
29 | name='SpaceTypeProto',
30 | full_name='communicator_objects.SpaceTypeProto',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | values=[
34 | _descriptor.EnumValueDescriptor(
35 | name='discrete', index=0, number=0,
36 | options=None,
37 | type=None),
38 | _descriptor.EnumValueDescriptor(
39 | name='continuous', index=1, number=1,
40 | options=None,
41 | type=None),
42 | ],
43 | containing_type=None,
44 | options=None,
45 | serialized_start=114,
46 | serialized_end=160,
47 | )
48 | _sym_db.RegisterEnumDescriptor(_SPACETYPEPROTO)
49 |
50 | SpaceTypeProto = enum_type_wrapper.EnumTypeWrapper(_SPACETYPEPROTO)
51 | discrete = 0
52 | continuous = 1
53 |
54 |
55 | DESCRIPTOR.enum_types_by_name['SpaceTypeProto'] = _SPACETYPEPROTO
56 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
57 |
58 |
59 | DESCRIPTOR.has_options = True
60 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
61 | # @@protoc_insertion_point(module_scope)
62 |
--------------------------------------------------------------------------------
/python/communicator_objects/unity_input_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/unity_input.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 | from communicator_objects import unity_rl_input_pb2 as communicator__objects_dot_unity__rl__input__pb2
17 | from communicator_objects import unity_rl_initialization_input_pb2 as communicator__objects_dot_unity__rl__initialization__input__pb2
18 |
19 |
20 | DESCRIPTOR = _descriptor.FileDescriptor(
21 | name='communicator_objects/unity_input.proto',
22 | package='communicator_objects',
23 | syntax='proto3',
24 | serialized_pb=_b('\n&communicator_objects/unity_input.proto\x12\x14\x63ommunicator_objects\x1a)communicator_objects/unity_rl_input.proto\x1a\x38\x63ommunicator_objects/unity_rl_initialization_input.proto\"\xb0\x01\n\nUnityInput\x12\x34\n\x08rl_input\x18\x01 \x01(\x0b\x32\".communicator_objects.UnityRLInput\x12Q\n\x17rl_initialization_input\x18\x02 \x01(\x0b\x32\x30.communicator_objects.UnityRLInitializationInput\x12\x19\n\x11\x63ustom_data_input\x18\x03 \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
25 | ,
26 | dependencies=[communicator__objects_dot_unity__rl__input__pb2.DESCRIPTOR,communicator__objects_dot_unity__rl__initialization__input__pb2.DESCRIPTOR,])
27 |
28 |
29 |
30 |
31 | _UNITYINPUT = _descriptor.Descriptor(
32 | name='UnityInput',
33 | full_name='communicator_objects.UnityInput',
34 | filename=None,
35 | file=DESCRIPTOR,
36 | containing_type=None,
37 | fields=[
38 | _descriptor.FieldDescriptor(
39 | name='rl_input', full_name='communicator_objects.UnityInput.rl_input', index=0,
40 | number=1, type=11, cpp_type=10, label=1,
41 | has_default_value=False, default_value=None,
42 | message_type=None, enum_type=None, containing_type=None,
43 | is_extension=False, extension_scope=None,
44 | options=None, file=DESCRIPTOR),
45 | _descriptor.FieldDescriptor(
46 | name='rl_initialization_input', full_name='communicator_objects.UnityInput.rl_initialization_input', index=1,
47 | number=2, type=11, cpp_type=10, label=1,
48 | has_default_value=False, default_value=None,
49 | message_type=None, enum_type=None, containing_type=None,
50 | is_extension=False, extension_scope=None,
51 | options=None, file=DESCRIPTOR),
52 | _descriptor.FieldDescriptor(
53 | name='custom_data_input', full_name='communicator_objects.UnityInput.custom_data_input', index=2,
54 | number=3, type=5, cpp_type=1, label=1,
55 | has_default_value=False, default_value=0,
56 | message_type=None, enum_type=None, containing_type=None,
57 | is_extension=False, extension_scope=None,
58 | options=None, file=DESCRIPTOR),
59 | ],
60 | extensions=[
61 | ],
62 | nested_types=[],
63 | enum_types=[
64 | ],
65 | options=None,
66 | is_extendable=False,
67 | syntax='proto3',
68 | extension_ranges=[],
69 | oneofs=[
70 | ],
71 | serialized_start=166,
72 | serialized_end=342,
73 | )
74 |
75 | _UNITYINPUT.fields_by_name['rl_input'].message_type = communicator__objects_dot_unity__rl__input__pb2._UNITYRLINPUT
76 | _UNITYINPUT.fields_by_name['rl_initialization_input'].message_type = communicator__objects_dot_unity__rl__initialization__input__pb2._UNITYRLINITIALIZATIONINPUT
77 | DESCRIPTOR.message_types_by_name['UnityInput'] = _UNITYINPUT
78 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
79 |
80 | UnityInput = _reflection.GeneratedProtocolMessageType('UnityInput', (_message.Message,), dict(
81 | DESCRIPTOR = _UNITYINPUT,
82 | __module__ = 'communicator_objects.unity_input_pb2'
83 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityInput)
84 | ))
85 | _sym_db.RegisterMessage(UnityInput)
86 |
87 |
88 | DESCRIPTOR.has_options = True
89 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
90 | # @@protoc_insertion_point(module_scope)
91 |
--------------------------------------------------------------------------------
/python/communicator_objects/unity_message_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/unity_message.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 | from communicator_objects import unity_output_pb2 as communicator__objects_dot_unity__output__pb2
17 | from communicator_objects import unity_input_pb2 as communicator__objects_dot_unity__input__pb2
18 | from communicator_objects import header_pb2 as communicator__objects_dot_header__pb2
19 |
20 |
21 | DESCRIPTOR = _descriptor.FileDescriptor(
22 | name='communicator_objects/unity_message.proto',
23 | package='communicator_objects',
24 | syntax='proto3',
25 | serialized_pb=_b('\n(communicator_objects/unity_message.proto\x12\x14\x63ommunicator_objects\x1a\'communicator_objects/unity_output.proto\x1a&communicator_objects/unity_input.proto\x1a!communicator_objects/header.proto\"\xac\x01\n\x0cUnityMessage\x12,\n\x06header\x18\x01 \x01(\x0b\x32\x1c.communicator_objects.Header\x12\x37\n\x0cunity_output\x18\x02 \x01(\x0b\x32!.communicator_objects.UnityOutput\x12\x35\n\x0bunity_input\x18\x03 \x01(\x0b\x32 .communicator_objects.UnityInputB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
26 | ,
27 | dependencies=[communicator__objects_dot_unity__output__pb2.DESCRIPTOR,communicator__objects_dot_unity__input__pb2.DESCRIPTOR,communicator__objects_dot_header__pb2.DESCRIPTOR,])
28 |
29 |
30 |
31 |
32 | _UNITYMESSAGE = _descriptor.Descriptor(
33 | name='UnityMessage',
34 | full_name='communicator_objects.UnityMessage',
35 | filename=None,
36 | file=DESCRIPTOR,
37 | containing_type=None,
38 | fields=[
39 | _descriptor.FieldDescriptor(
40 | name='header', full_name='communicator_objects.UnityMessage.header', index=0,
41 | number=1, type=11, cpp_type=10, label=1,
42 | has_default_value=False, default_value=None,
43 | message_type=None, enum_type=None, containing_type=None,
44 | is_extension=False, extension_scope=None,
45 | options=None, file=DESCRIPTOR),
46 | _descriptor.FieldDescriptor(
47 | name='unity_output', full_name='communicator_objects.UnityMessage.unity_output', index=1,
48 | number=2, type=11, cpp_type=10, label=1,
49 | has_default_value=False, default_value=None,
50 | message_type=None, enum_type=None, containing_type=None,
51 | is_extension=False, extension_scope=None,
52 | options=None, file=DESCRIPTOR),
53 | _descriptor.FieldDescriptor(
54 | name='unity_input', full_name='communicator_objects.UnityMessage.unity_input', index=2,
55 | number=3, type=11, cpp_type=10, label=1,
56 | has_default_value=False, default_value=None,
57 | message_type=None, enum_type=None, containing_type=None,
58 | is_extension=False, extension_scope=None,
59 | options=None, file=DESCRIPTOR),
60 | ],
61 | extensions=[
62 | ],
63 | nested_types=[],
64 | enum_types=[
65 | ],
66 | options=None,
67 | is_extendable=False,
68 | syntax='proto3',
69 | extension_ranges=[],
70 | oneofs=[
71 | ],
72 | serialized_start=183,
73 | serialized_end=355,
74 | )
75 |
76 | _UNITYMESSAGE.fields_by_name['header'].message_type = communicator__objects_dot_header__pb2._HEADER
77 | _UNITYMESSAGE.fields_by_name['unity_output'].message_type = communicator__objects_dot_unity__output__pb2._UNITYOUTPUT
78 | _UNITYMESSAGE.fields_by_name['unity_input'].message_type = communicator__objects_dot_unity__input__pb2._UNITYINPUT
79 | DESCRIPTOR.message_types_by_name['UnityMessage'] = _UNITYMESSAGE
80 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
81 |
82 | UnityMessage = _reflection.GeneratedProtocolMessageType('UnityMessage', (_message.Message,), dict(
83 | DESCRIPTOR = _UNITYMESSAGE,
84 | __module__ = 'communicator_objects.unity_message_pb2'
85 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityMessage)
86 | ))
87 | _sym_db.RegisterMessage(UnityMessage)
88 |
89 |
90 | DESCRIPTOR.has_options = True
91 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
92 | # @@protoc_insertion_point(module_scope)
93 |
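UnityMessage is the envelope exchanged between the Python client and Unity; a minimal sketch of composing one (the status code and seed are illustrative):

from communicator_objects.unity_message_pb2 import UnityMessage

msg = UnityMessage()
msg.header.status = 200                               # nested Header is created on first access
msg.header.message = 'ok'
msg.unity_input.rl_initialization_input.seed = 42     # nested UnityInput, likewise

echo = UnityMessage.FromString(msg.SerializeToString())
assert echo.unity_input.rl_initialization_input.seed == 42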
--------------------------------------------------------------------------------
/python/communicator_objects/unity_output_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/unity_output.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 | from communicator_objects import unity_rl_output_pb2 as communicator__objects_dot_unity__rl__output__pb2
17 | from communicator_objects import unity_rl_initialization_output_pb2 as communicator__objects_dot_unity__rl__initialization__output__pb2
18 |
19 |
20 | DESCRIPTOR = _descriptor.FileDescriptor(
21 | name='communicator_objects/unity_output.proto',
22 | package='communicator_objects',
23 | syntax='proto3',
24 | serialized_pb=_b('\n\'communicator_objects/unity_output.proto\x12\x14\x63ommunicator_objects\x1a*communicator_objects/unity_rl_output.proto\x1a\x39\x63ommunicator_objects/unity_rl_initialization_output.proto\"\xb6\x01\n\x0bUnityOutput\x12\x36\n\trl_output\x18\x01 \x01(\x0b\x32#.communicator_objects.UnityRLOutput\x12S\n\x18rl_initialization_output\x18\x02 \x01(\x0b\x32\x31.communicator_objects.UnityRLInitializationOutput\x12\x1a\n\x12\x63ustom_data_output\x18\x03 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
25 | ,
26 | dependencies=[communicator__objects_dot_unity__rl__output__pb2.DESCRIPTOR,communicator__objects_dot_unity__rl__initialization__output__pb2.DESCRIPTOR,])
27 |
28 |
29 |
30 |
31 | _UNITYOUTPUT = _descriptor.Descriptor(
32 | name='UnityOutput',
33 | full_name='communicator_objects.UnityOutput',
34 | filename=None,
35 | file=DESCRIPTOR,
36 | containing_type=None,
37 | fields=[
38 | _descriptor.FieldDescriptor(
39 | name='rl_output', full_name='communicator_objects.UnityOutput.rl_output', index=0,
40 | number=1, type=11, cpp_type=10, label=1,
41 | has_default_value=False, default_value=None,
42 | message_type=None, enum_type=None, containing_type=None,
43 | is_extension=False, extension_scope=None,
44 | options=None, file=DESCRIPTOR),
45 | _descriptor.FieldDescriptor(
46 | name='rl_initialization_output', full_name='communicator_objects.UnityOutput.rl_initialization_output', index=1,
47 | number=2, type=11, cpp_type=10, label=1,
48 | has_default_value=False, default_value=None,
49 | message_type=None, enum_type=None, containing_type=None,
50 | is_extension=False, extension_scope=None,
51 | options=None, file=DESCRIPTOR),
52 | _descriptor.FieldDescriptor(
53 | name='custom_data_output', full_name='communicator_objects.UnityOutput.custom_data_output', index=2,
54 | number=3, type=9, cpp_type=9, label=1,
55 | has_default_value=False, default_value=_b("").decode('utf-8'),
56 | message_type=None, enum_type=None, containing_type=None,
57 | is_extension=False, extension_scope=None,
58 | options=None, file=DESCRIPTOR),
59 | ],
60 | extensions=[
61 | ],
62 | nested_types=[],
63 | enum_types=[
64 | ],
65 | options=None,
66 | is_extendable=False,
67 | syntax='proto3',
68 | extension_ranges=[],
69 | oneofs=[
70 | ],
71 | serialized_start=169,
72 | serialized_end=351,
73 | )
74 |
75 | _UNITYOUTPUT.fields_by_name['rl_output'].message_type = communicator__objects_dot_unity__rl__output__pb2._UNITYRLOUTPUT
76 | _UNITYOUTPUT.fields_by_name['rl_initialization_output'].message_type = communicator__objects_dot_unity__rl__initialization__output__pb2._UNITYRLINITIALIZATIONOUTPUT
77 | DESCRIPTOR.message_types_by_name['UnityOutput'] = _UNITYOUTPUT
78 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
79 |
80 | UnityOutput = _reflection.GeneratedProtocolMessageType('UnityOutput', (_message.Message,), dict(
81 | DESCRIPTOR = _UNITYOUTPUT,
82 | __module__ = 'communicator_objects.unity_output_pb2'
83 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityOutput)
84 | ))
85 | _sym_db.RegisterMessage(UnityOutput)
86 |
87 |
88 | DESCRIPTOR.has_options = True
89 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
90 | # @@protoc_insertion_point(module_scope)
91 |
--------------------------------------------------------------------------------
/python/communicator_objects/unity_rl_initialization_input_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/unity_rl_initialization_input.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='communicator_objects/unity_rl_initialization_input.proto',
20 | package='communicator_objects',
21 | syntax='proto3',
22 | serialized_pb=_b('\n8communicator_objects/unity_rl_initialization_input.proto\x12\x14\x63ommunicator_objects\"*\n\x1aUnityRLInitializationInput\x12\x0c\n\x04seed\x18\x01 \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
23 | )
24 |
25 |
26 |
27 |
28 | _UNITYRLINITIALIZATIONINPUT = _descriptor.Descriptor(
29 | name='UnityRLInitializationInput',
30 | full_name='communicator_objects.UnityRLInitializationInput',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='seed', full_name='communicator_objects.UnityRLInitializationInput.seed', index=0,
37 | number=1, type=5, cpp_type=1, label=1,
38 | has_default_value=False, default_value=0,
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None, file=DESCRIPTOR),
42 | ],
43 | extensions=[
44 | ],
45 | nested_types=[],
46 | enum_types=[
47 | ],
48 | options=None,
49 | is_extendable=False,
50 | syntax='proto3',
51 | extension_ranges=[],
52 | oneofs=[
53 | ],
54 | serialized_start=82,
55 | serialized_end=124,
56 | )
57 |
58 | DESCRIPTOR.message_types_by_name['UnityRLInitializationInput'] = _UNITYRLINITIALIZATIONINPUT
59 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
60 |
61 | UnityRLInitializationInput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationInput', (_message.Message,), dict(
62 | DESCRIPTOR = _UNITYRLINITIALIZATIONINPUT,
63 | __module__ = 'communicator_objects.unity_rl_initialization_input_pb2'
64 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationInput)
65 | ))
66 | _sym_db.RegisterMessage(UnityRLInitializationInput)
67 |
68 |
69 | DESCRIPTOR.has_options = True
70 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
71 | # @@protoc_insertion_point(module_scope)
72 |
--------------------------------------------------------------------------------
/python/communicator_objects/unity_rl_initialization_output_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/unity_rl_initialization_output.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 | from communicator_objects import brain_parameters_proto_pb2 as communicator__objects_dot_brain__parameters__proto__pb2
17 | from communicator_objects import environment_parameters_proto_pb2 as communicator__objects_dot_environment__parameters__proto__pb2
18 |
19 |
20 | DESCRIPTOR = _descriptor.FileDescriptor(
21 | name='communicator_objects/unity_rl_initialization_output.proto',
22 | package='communicator_objects',
23 | syntax='proto3',
24 | serialized_pb=_b('\n9communicator_objects/unity_rl_initialization_output.proto\x12\x14\x63ommunicator_objects\x1a\x31\x63ommunicator_objects/brain_parameters_proto.proto\x1a\x37\x63ommunicator_objects/environment_parameters_proto.proto\"\xe6\x01\n\x1bUnityRLInitializationOutput\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x10\n\x08log_path\x18\x03 \x01(\t\x12\x44\n\x10\x62rain_parameters\x18\x05 \x03(\x0b\x32*.communicator_objects.BrainParametersProto\x12P\n\x16\x65nvironment_parameters\x18\x06 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
25 | ,
26 | dependencies=[communicator__objects_dot_brain__parameters__proto__pb2.DESCRIPTOR,communicator__objects_dot_environment__parameters__proto__pb2.DESCRIPTOR,])
27 |
28 |
29 |
30 |
31 | _UNITYRLINITIALIZATIONOUTPUT = _descriptor.Descriptor(
32 | name='UnityRLInitializationOutput',
33 | full_name='communicator_objects.UnityRLInitializationOutput',
34 | filename=None,
35 | file=DESCRIPTOR,
36 | containing_type=None,
37 | fields=[
38 | _descriptor.FieldDescriptor(
39 | name='name', full_name='communicator_objects.UnityRLInitializationOutput.name', index=0,
40 | number=1, type=9, cpp_type=9, label=1,
41 | has_default_value=False, default_value=_b("").decode('utf-8'),
42 | message_type=None, enum_type=None, containing_type=None,
43 | is_extension=False, extension_scope=None,
44 | options=None, file=DESCRIPTOR),
45 | _descriptor.FieldDescriptor(
46 | name='version', full_name='communicator_objects.UnityRLInitializationOutput.version', index=1,
47 | number=2, type=9, cpp_type=9, label=1,
48 | has_default_value=False, default_value=_b("").decode('utf-8'),
49 | message_type=None, enum_type=None, containing_type=None,
50 | is_extension=False, extension_scope=None,
51 | options=None, file=DESCRIPTOR),
52 | _descriptor.FieldDescriptor(
53 | name='log_path', full_name='communicator_objects.UnityRLInitializationOutput.log_path', index=2,
54 | number=3, type=9, cpp_type=9, label=1,
55 | has_default_value=False, default_value=_b("").decode('utf-8'),
56 | message_type=None, enum_type=None, containing_type=None,
57 | is_extension=False, extension_scope=None,
58 | options=None, file=DESCRIPTOR),
59 | _descriptor.FieldDescriptor(
60 | name='brain_parameters', full_name='communicator_objects.UnityRLInitializationOutput.brain_parameters', index=3,
61 | number=5, type=11, cpp_type=10, label=3,
62 | has_default_value=False, default_value=[],
63 | message_type=None, enum_type=None, containing_type=None,
64 | is_extension=False, extension_scope=None,
65 | options=None, file=DESCRIPTOR),
66 | _descriptor.FieldDescriptor(
67 | name='environment_parameters', full_name='communicator_objects.UnityRLInitializationOutput.environment_parameters', index=4,
68 | number=6, type=11, cpp_type=10, label=1,
69 | has_default_value=False, default_value=None,
70 | message_type=None, enum_type=None, containing_type=None,
71 | is_extension=False, extension_scope=None,
72 | options=None, file=DESCRIPTOR),
73 | ],
74 | extensions=[
75 | ],
76 | nested_types=[],
77 | enum_types=[
78 | ],
79 | options=None,
80 | is_extendable=False,
81 | syntax='proto3',
82 | extension_ranges=[],
83 | oneofs=[
84 | ],
85 | serialized_start=192,
86 | serialized_end=422,
87 | )
88 |
89 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name['brain_parameters'].message_type = communicator__objects_dot_brain__parameters__proto__pb2._BRAINPARAMETERSPROTO
90 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name['environment_parameters'].message_type = communicator__objects_dot_environment__parameters__proto__pb2._ENVIRONMENTPARAMETERSPROTO
91 | DESCRIPTOR.message_types_by_name['UnityRLInitializationOutput'] = _UNITYRLINITIALIZATIONOUTPUT
92 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
93 |
94 | UnityRLInitializationOutput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationOutput', (_message.Message,), dict(
95 | DESCRIPTOR = _UNITYRLINITIALIZATIONOUTPUT,
96 | __module__ = 'communicator_objects.unity_rl_initialization_output_pb2'
97 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationOutput)
98 | ))
99 | _sym_db.RegisterMessage(UnityRLInitializationOutput)
100 |
101 |
102 | DESCRIPTOR.has_options = True
103 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
104 | # @@protoc_insertion_point(module_scope)
105 |
--------------------------------------------------------------------------------
/python/communicator_objects/unity_rl_output_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/unity_rl_output.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 | from communicator_objects import agent_info_proto_pb2 as communicator__objects_dot_agent__info__proto__pb2
17 |
18 |
19 | DESCRIPTOR = _descriptor.FileDescriptor(
20 | name='communicator_objects/unity_rl_output.proto',
21 | package='communicator_objects',
22 | syntax='proto3',
23 | serialized_pb=_b('\n*communicator_objects/unity_rl_output.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/agent_info_proto.proto\"\xa3\x02\n\rUnityRLOutput\x12\x13\n\x0bglobal_done\x18\x01 \x01(\x08\x12G\n\nagentInfos\x18\x02 \x03(\x0b\x32\x33.communicator_objects.UnityRLOutput.AgentInfosEntry\x1aI\n\x12ListAgentInfoProto\x12\x33\n\x05value\x18\x01 \x03(\x0b\x32$.communicator_objects.AgentInfoProto\x1ai\n\x0f\x41gentInfosEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x45\n\x05value\x18\x02 \x01(\x0b\x32\x36.communicator_objects.UnityRLOutput.ListAgentInfoProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
24 | ,
25 | dependencies=[communicator__objects_dot_agent__info__proto__pb2.DESCRIPTOR,])
26 |
27 |
28 |
29 |
30 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO = _descriptor.Descriptor(
31 | name='ListAgentInfoProto',
32 | full_name='communicator_objects.UnityRLOutput.ListAgentInfoProto',
33 | filename=None,
34 | file=DESCRIPTOR,
35 | containing_type=None,
36 | fields=[
37 | _descriptor.FieldDescriptor(
38 | name='value', full_name='communicator_objects.UnityRLOutput.ListAgentInfoProto.value', index=0,
39 | number=1, type=11, cpp_type=10, label=3,
40 | has_default_value=False, default_value=[],
41 | message_type=None, enum_type=None, containing_type=None,
42 | is_extension=False, extension_scope=None,
43 | options=None, file=DESCRIPTOR),
44 | ],
45 | extensions=[
46 | ],
47 | nested_types=[],
48 | enum_types=[
49 | ],
50 | options=None,
51 | is_extendable=False,
52 | syntax='proto3',
53 | extension_ranges=[],
54 | oneofs=[
55 | ],
56 | serialized_start=225,
57 | serialized_end=298,
58 | )
59 |
60 | _UNITYRLOUTPUT_AGENTINFOSENTRY = _descriptor.Descriptor(
61 | name='AgentInfosEntry',
62 | full_name='communicator_objects.UnityRLOutput.AgentInfosEntry',
63 | filename=None,
64 | file=DESCRIPTOR,
65 | containing_type=None,
66 | fields=[
67 | _descriptor.FieldDescriptor(
68 | name='key', full_name='communicator_objects.UnityRLOutput.AgentInfosEntry.key', index=0,
69 | number=1, type=9, cpp_type=9, label=1,
70 | has_default_value=False, default_value=_b("").decode('utf-8'),
71 | message_type=None, enum_type=None, containing_type=None,
72 | is_extension=False, extension_scope=None,
73 | options=None, file=DESCRIPTOR),
74 | _descriptor.FieldDescriptor(
75 | name='value', full_name='communicator_objects.UnityRLOutput.AgentInfosEntry.value', index=1,
76 | number=2, type=11, cpp_type=10, label=1,
77 | has_default_value=False, default_value=None,
78 | message_type=None, enum_type=None, containing_type=None,
79 | is_extension=False, extension_scope=None,
80 | options=None, file=DESCRIPTOR),
81 | ],
82 | extensions=[
83 | ],
84 | nested_types=[],
85 | enum_types=[
86 | ],
87 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')),
88 | is_extendable=False,
89 | syntax='proto3',
90 | extension_ranges=[],
91 | oneofs=[
92 | ],
93 | serialized_start=300,
94 | serialized_end=405,
95 | )
96 |
97 | _UNITYRLOUTPUT = _descriptor.Descriptor(
98 | name='UnityRLOutput',
99 | full_name='communicator_objects.UnityRLOutput',
100 | filename=None,
101 | file=DESCRIPTOR,
102 | containing_type=None,
103 | fields=[
104 | _descriptor.FieldDescriptor(
105 | name='global_done', full_name='communicator_objects.UnityRLOutput.global_done', index=0,
106 | number=1, type=8, cpp_type=7, label=1,
107 | has_default_value=False, default_value=False,
108 | message_type=None, enum_type=None, containing_type=None,
109 | is_extension=False, extension_scope=None,
110 | options=None, file=DESCRIPTOR),
111 | _descriptor.FieldDescriptor(
112 | name='agentInfos', full_name='communicator_objects.UnityRLOutput.agentInfos', index=1,
113 | number=2, type=11, cpp_type=10, label=3,
114 | has_default_value=False, default_value=[],
115 | message_type=None, enum_type=None, containing_type=None,
116 | is_extension=False, extension_scope=None,
117 | options=None, file=DESCRIPTOR),
118 | ],
119 | extensions=[
120 | ],
121 | nested_types=[_UNITYRLOUTPUT_LISTAGENTINFOPROTO, _UNITYRLOUTPUT_AGENTINFOSENTRY, ],
122 | enum_types=[
123 | ],
124 | options=None,
125 | is_extendable=False,
126 | syntax='proto3',
127 | extension_ranges=[],
128 | oneofs=[
129 | ],
130 | serialized_start=114,
131 | serialized_end=405,
132 | )
133 |
134 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO.fields_by_name['value'].message_type = communicator__objects_dot_agent__info__proto__pb2._AGENTINFOPROTO
135 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO.containing_type = _UNITYRLOUTPUT
136 | _UNITYRLOUTPUT_AGENTINFOSENTRY.fields_by_name['value'].message_type = _UNITYRLOUTPUT_LISTAGENTINFOPROTO
137 | _UNITYRLOUTPUT_AGENTINFOSENTRY.containing_type = _UNITYRLOUTPUT
138 | _UNITYRLOUTPUT.fields_by_name['agentInfos'].message_type = _UNITYRLOUTPUT_AGENTINFOSENTRY
139 | DESCRIPTOR.message_types_by_name['UnityRLOutput'] = _UNITYRLOUTPUT
140 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
141 |
142 | UnityRLOutput = _reflection.GeneratedProtocolMessageType('UnityRLOutput', (_message.Message,), dict(
143 |
144 | ListAgentInfoProto = _reflection.GeneratedProtocolMessageType('ListAgentInfoProto', (_message.Message,), dict(
145 | DESCRIPTOR = _UNITYRLOUTPUT_LISTAGENTINFOPROTO,
146 | __module__ = 'communicator_objects.unity_rl_output_pb2'
147 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.ListAgentInfoProto)
148 | ))
149 | ,
150 |
151 | AgentInfosEntry = _reflection.GeneratedProtocolMessageType('AgentInfosEntry', (_message.Message,), dict(
152 | DESCRIPTOR = _UNITYRLOUTPUT_AGENTINFOSENTRY,
153 | __module__ = 'communicator_objects.unity_rl_output_pb2'
154 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.AgentInfosEntry)
155 | ))
156 | ,
157 | DESCRIPTOR = _UNITYRLOUTPUT,
158 | __module__ = 'communicator_objects.unity_rl_output_pb2'
159 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput)
160 | ))
161 | _sym_db.RegisterMessage(UnityRLOutput)
162 | _sym_db.RegisterMessage(UnityRLOutput.ListAgentInfoProto)
163 | _sym_db.RegisterMessage(UnityRLOutput.AgentInfosEntry)
164 |
165 |
166 | DESCRIPTOR.has_options = True
167 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
168 | _UNITYRLOUTPUT_AGENTINFOSENTRY.has_options = True
169 | _UNITYRLOUTPUT_AGENTINFOSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001'))
170 | # @@protoc_insertion_point(module_scope)
171 |
--------------------------------------------------------------------------------
/python/communicator_objects/unity_to_external_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: communicator_objects/unity_to_external.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 | from communicator_objects import unity_message_pb2 as communicator__objects_dot_unity__message__pb2
17 |
18 |
19 | DESCRIPTOR = _descriptor.FileDescriptor(
20 | name='communicator_objects/unity_to_external.proto',
21 | package='communicator_objects',
22 | syntax='proto3',
23 | serialized_pb=_b('\n,communicator_objects/unity_to_external.proto\x12\x14\x63ommunicator_objects\x1a(communicator_objects/unity_message.proto2g\n\x0fUnityToExternal\x12T\n\x08\x45xchange\x12\".communicator_objects.UnityMessage\x1a\".communicator_objects.UnityMessage\"\x00\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
24 | ,
25 | dependencies=[communicator__objects_dot_unity__message__pb2.DESCRIPTOR,])
26 |
27 |
28 |
29 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
30 |
31 |
32 | DESCRIPTOR.has_options = True
33 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
34 |
35 | _UNITYTOEXTERNAL = _descriptor.ServiceDescriptor(
36 | name='UnityToExternal',
37 | full_name='communicator_objects.UnityToExternal',
38 | file=DESCRIPTOR,
39 | index=0,
40 | options=None,
41 | serialized_start=112,
42 | serialized_end=215,
43 | methods=[
44 | _descriptor.MethodDescriptor(
45 | name='Exchange',
46 | full_name='communicator_objects.UnityToExternal.Exchange',
47 | index=0,
48 | containing_service=None,
49 | input_type=communicator__objects_dot_unity__message__pb2._UNITYMESSAGE,
50 | output_type=communicator__objects_dot_unity__message__pb2._UNITYMESSAGE,
51 | options=None,
52 | ),
53 | ])
54 | _sym_db.RegisterServiceDescriptor(_UNITYTOEXTERNAL)
55 |
56 | DESCRIPTOR.services_by_name['UnityToExternal'] = _UNITYTOEXTERNAL
57 |
58 | # @@protoc_insertion_point(module_scope)
59 |
--------------------------------------------------------------------------------
/python/communicator_objects/unity_to_external_pb2_grpc.py:
--------------------------------------------------------------------------------
1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2 | import grpc
3 |
4 | from communicator_objects import unity_message_pb2 as communicator__objects_dot_unity__message__pb2
5 |
6 |
7 | class UnityToExternalStub(object):
8 | # missing associated documentation comment in .proto file
9 | pass
10 |
11 | def __init__(self, channel):
12 | """Constructor.
13 |
14 | Args:
15 | channel: A grpc.Channel.
16 | """
17 | self.Exchange = channel.unary_unary(
18 | '/communicator_objects.UnityToExternal/Exchange',
19 | request_serializer=communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString,
20 | response_deserializer=communicator__objects_dot_unity__message__pb2.UnityMessage.FromString,
21 | )
22 |
23 |
24 | class UnityToExternalServicer(object):
25 | # missing associated documentation comment in .proto file
26 | pass
27 |
28 | def Exchange(self, request, context):
29 | """Sends the academy parameters
30 | """
31 | context.set_code(grpc.StatusCode.UNIMPLEMENTED)
32 | context.set_details('Method not implemented!')
33 | raise NotImplementedError('Method not implemented!')
34 |
35 |
36 | def add_UnityToExternalServicer_to_server(servicer, server):
37 | rpc_method_handlers = {
38 | 'Exchange': grpc.unary_unary_rpc_method_handler(
39 | servicer.Exchange,
40 | request_deserializer=communicator__objects_dot_unity__message__pb2.UnityMessage.FromString,
41 | response_serializer=communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString,
42 | ),
43 | }
44 | generic_handler = grpc.method_handlers_generic_handler(
45 | 'communicator_objects.UnityToExternal', rpc_method_handlers)
46 | server.add_generic_rpc_handlers((generic_handler,))
47 |
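A minimal client-side sketch for the service above, assuming a Unity-side server is already listening on the default ML-Agents port (5005); the header contents are illustrative:

import grpc

from communicator_objects import UnityMessage
from communicator_objects.unity_to_external_pb2_grpc import UnityToExternalStub

channel = grpc.insecure_channel('localhost:5005')
stub = UnityToExternalStub(channel)

request = UnityMessage()
request.header.status = 200
reply = stub.Exchange(request)     # blocking unary-unary RPC declared in unity_to_external.proto
print(reply.header.status)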
--------------------------------------------------------------------------------
/python/curricula/push.json:
--------------------------------------------------------------------------------
1 | {
2 | "measure" : "reward",
3 | "thresholds" : [0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75],
4 | "min_lesson_length" : 2,
5 | "signal_smoothing" : true,
6 | "parameters" :
7 | {
8 | "goal_size" : [3.5, 3.25, 3.0, 2.75, 2.5, 2.25, 2.0, 1.75, 1.5, 1.25, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
9 | "block_size": [1.5, 1.4, 1.3, 1.2, 1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
10 | "x_variation":[1.5, 1.55, 1.6, 1.65, 1.7, 1.75, 1.8, 1.85, 1.9, 1.95, 2.0, 2.1, 2.2, 2.3, 2.4, 2.5]
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/python/curricula/test.json:
--------------------------------------------------------------------------------
1 | {
2 | "measure" : "reward",
3 | "thresholds" : [10, 20, 50],
4 | "min_lesson_length" : 3,
5 | "signal_smoothing" : true,
6 | "parameters" :
7 | {
8 | "param1" : [0.7, 0.5, 0.3, 0.1],
9 | "param2" : [100, 50, 20, 15],
10 | "param3" : [0.2, 0.3, 0.7, 0.9]
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/python/curricula/wall.json:
--------------------------------------------------------------------------------
1 | {
2 | "measure" : "progress",
3 | "thresholds" : [0.1, 0.3, 0.5],
4 | "min_lesson_length" : 2,
5 | "signal_smoothing" : true,
6 | "parameters" :
7 | {
8 | "big_wall_min_height" : [0.0, 4.0, 6.0, 8.0],
9 | "big_wall_max_height" : [4.0, 7.0, 8.0, 8.0],
10 | "small_wall_height" : [1.5, 2.0, 2.5, 4.0]
11 | }
12 | }
13 |
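The three curriculum files above share the same layout: a list of lesson thresholds plus, for every parameter, one value per lesson (one more value than thresholds). A minimal sketch of reading one lesson's settings (threshold tracking and signal smoothing are handled by unityagents/curriculum.py; the path below assumes the repository root as the working directory):

import json

with open('python/curricula/wall.json') as f:
    curriculum = json.load(f)

lesson = 2                                   # illustrative lesson index
config = {name: values[lesson] for name, values in curriculum['parameters'].items()}
print(config)  # {'big_wall_min_height': 6.0, 'big_wall_max_height': 8.0, 'small_wall_height': 2.5}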
--------------------------------------------------------------------------------
/python/learn.py:
--------------------------------------------------------------------------------
1 | # # Unity ML-Agents Toolkit
2 | # ## ML-Agent Learning
3 |
4 | import logging
5 |
6 | import os
7 | from docopt import docopt
8 |
9 | from unitytrainers.trainer_controller import TrainerController
10 |
11 |
12 | if __name__ == '__main__':
13 | print('''
14 |
15 | ▄▄▄▓▓▓▓
16 | ╓▓▓▓▓▓▓█▓▓▓▓▓
17 | ,▄▄▄m▀▀▀' ,▓▓▓▀▓▓▄ ▓▓▓ ▓▓▌
18 | ▄▓▓▓▀' ▄▓▓▀ ▓▓▓ ▄▄ ▄▄ ,▄▄ ▄▄▄▄ ,▄▄ ▄▓▓▌▄ ▄▄▄ ,▄▄
19 | ▄▓▓▓▀ ▄▓▓▀ ▐▓▓▌ ▓▓▌ ▐▓▓ ▐▓▓▓▀▀▀▓▓▌ ▓▓▓ ▀▓▓▌▀ ^▓▓▌ ╒▓▓▌
20 | ▄▓▓▓▓▓▄▄▄▄▄▄▄▄▓▓▓ ▓▀ ▓▓▌ ▐▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▌ ▐▓▓▄ ▓▓▌
21 | ▀▓▓▓▓▀▀▀▀▀▀▀▀▀▀▓▓▄ ▓▓ ▓▓▌ ▐▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▌ ▐▓▓▐▓▓
22 | ^█▓▓▓ ▀▓▓▄ ▐▓▓▌ ▓▓▓▓▄▓▓▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▓▄ ▓▓▓▓`
23 | '▀▓▓▓▄ ^▓▓▓ ▓▓▓ └▀▀▀▀ ▀▀ ^▀▀ `▀▀ `▀▀ '▀▀ ▐▓▓▌
24 | ▀▀▀▀▓▄▄▄ ▓▓▓▓▓▓, ▓▓▓▓▀
25 | `▀█▓▓▓▓▓▓▓▓▓▌
26 | ¬`▀▀▀█▓
27 |
28 | ''')
29 |
30 | logger = logging.getLogger("unityagents")
31 | _USAGE = '''
32 | Usage:
33 | learn (<env>) [options]
34 | learn [options]
35 | learn --help
36 |
37 | Options:
38 | --curriculum=<file> Curriculum json file for environment [default: None].
39 | --keep-checkpoints=<n> How many model checkpoints to keep [default: 5].
40 | --lesson=<n> Start learning from this lesson [default: 0].
41 | --load Whether to load the model or randomly initialize [default: False].
42 | --run-id=<path> The sub-directory name for model and summary statistics [default: ppo].
43 | --save-freq=<n> Frequency at which to save model [default: 50000].
44 | --seed=<n> Random seed used for training [default: -1].
45 | --slow Whether to run the game at training speed [default: False].
46 | --train Whether to train model, or only run inference [default: False].
47 | --worker-id=<n> Number to add to communication port (5005). Used for multi-environment [default: 0].
48 | --docker-target-name=<dt> Docker Volume to store curriculum, executable and model files [default: Empty].
49 | --no-graphics Whether to run the Unity simulator in no-graphics mode [default: False].
50 | '''
51 |
52 | options = docopt(_USAGE)
53 | logger.info(options)
54 | # Docker Parameters
55 | if options['--docker-target-name'] == 'Empty':
56 | docker_target_name = ''
57 | else:
58 | docker_target_name = options['--docker-target-name']
59 |
60 | # General parameters
61 | run_id = options['--run-id']
62 | seed = int(options['--seed'])
63 | load_model = options['--load']
64 | train_model = options['--train']
65 | save_freq = int(options['--save-freq'])
66 | env_path = options['<env>']
67 | keep_checkpoints = int(options['--keep-checkpoints'])
68 | worker_id = int(options['--worker-id'])
69 | curriculum_file = str(options['--curriculum'])
70 | if curriculum_file == "None":
71 | curriculum_file = None
72 | lesson = int(options['--lesson'])
73 | fast_simulation = not bool(options['--slow'])
74 | no_graphics = options['--no-graphics']
75 |
76 | # Constants
77 | # Assumption that this yaml is present in same dir as this file
78 | base_path = os.path.dirname(__file__)
79 | TRAINER_CONFIG_PATH = os.path.abspath(os.path.join(base_path, "trainer_config.yaml"))
80 |
81 | tc = TrainerController(env_path, run_id, save_freq, curriculum_file, fast_simulation, load_model, train_model,
82 | worker_id, keep_checkpoints, lesson, seed, docker_target_name, TRAINER_CONFIG_PATH,
83 | no_graphics)
84 | tc.start_learning()
85 |
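
learn.py delegates all command-line parsing to docopt: the `_USAGE` triple-quoted string is itself the grammar, and `docopt(_USAGE)` returns a dictionary keyed by option names and positionals. A standalone sketch of that mechanism (the usage string and argv values below are illustrative, not the full grammar above):

    from docopt import docopt

    _DOC = """
    Usage:
      learn [<env>] [options]

    Options:
      --run-id=<path>    Sub-directory for models and summaries [default: ppo].
      --train            Train the model instead of only running inference.
    """

    # argv is supplied explicitly here instead of coming from sys.argv.
    options = docopt(_DOC, argv=["Banana.x86_64", "--run-id=run_1", "--train"])
    print(options["<env>"])      # 'Banana.x86_64'
    print(options["--run-id"])   # 'run_1'
    print(options["--train"])    # True
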
--------------------------------------------------------------------------------
/python/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==1.7.1
2 | Pillow>=4.2.1
3 | matplotlib
4 | numpy>=1.11.0
5 | jupyter
6 | pytest>=3.2.2
7 | docopt
8 | pyyaml
9 | protobuf==3.5.2
10 | grpcio==1.11.0
11 | torch==0.4.0
12 | pandas
13 | scipy
14 | ipykernel
15 |
--------------------------------------------------------------------------------
/python/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import setup, Command, find_packages
4 |
5 |
6 | with open('requirements.txt') as f:
7 | required = f.read().splitlines()
8 |
9 | setup(name='unityagents',
10 | version='0.4.0',
11 | description='Unity Machine Learning Agents',
12 | license='Apache License 2.0',
13 | author='Unity Technologies',
14 | author_email='ML-Agents@unity3d.com',
15 | url='https://github.com/Unity-Technologies/ml-agents',
16 | packages=find_packages(),
17 | install_requires = required,
18 | long_description= ("Unity Machine Learning Agents allows researchers and developers "
19 | "to transform games and simulations created using the Unity Editor into environments "
20 | "where intelligent agents can be trained using reinforcement learning, evolutionary "
21 | "strategies, or other machine learning methods through a simple to use Python API.")
22 | )
23 |
--------------------------------------------------------------------------------
/python/tests/__init__.py:
--------------------------------------------------------------------------------
1 | from unityagents import *
2 | from unitytrainers import *
3 |
--------------------------------------------------------------------------------
/python/tests/mock_communicator.py:
--------------------------------------------------------------------------------
1 |
2 | from unityagents.communicator import Communicator
3 | from communicator_objects import UnityMessage, UnityOutput, UnityInput,\
4 | ResolutionProto, BrainParametersProto, UnityRLInitializationOutput,\
5 | AgentInfoProto, UnityRLOutput
6 |
7 |
8 | class MockCommunicator(Communicator):
9 | def __init__(self, discrete_action=False, visual_inputs=0):
10 | """
11 |         A stand-in for the grpc communicator, used by the tests to fake the Unity side of the exchange.
12 |
13 |         :bool discrete_action: Whether the mocked brain uses a discrete action space.
14 |         :int visual_inputs: Number of mocked camera observations to report.
15 | """
16 | self.is_discrete = discrete_action
17 | self.steps = 0
18 | self.visual_inputs = visual_inputs
19 | self.has_been_closed = False
20 |
21 | def initialize(self, inputs: UnityInput) -> UnityOutput:
22 | resolutions = [ResolutionProto(
23 | width=30,
24 | height=40,
25 | gray_scale=False) for i in range(self.visual_inputs)]
26 | bp = BrainParametersProto(
27 | vector_observation_size=3,
28 | num_stacked_vector_observations=2,
29 | vector_action_size=2,
30 | camera_resolutions=resolutions,
31 | vector_action_descriptions=["", ""],
32 | vector_action_space_type=int(not self.is_discrete),
33 | vector_observation_space_type=1,
34 | brain_name="RealFakeBrain",
35 | brain_type=2
36 | )
37 | rl_init = UnityRLInitializationOutput(
38 | name="RealFakeAcademy",
39 | version="API-4",
40 | log_path="",
41 | brain_parameters=[bp]
42 | )
43 | return UnityOutput(
44 | rl_initialization_output=rl_init
45 | )
46 |
47 | def exchange(self, inputs: UnityInput) -> UnityOutput:
48 | dict_agent_info = {}
49 | if self.is_discrete:
50 | vector_action = [1]
51 | else:
52 | vector_action = [1, 2]
53 | list_agent_info = []
54 | for i in range(3):
55 | list_agent_info.append(
56 | AgentInfoProto(
57 | stacked_vector_observation=[1, 2, 3, 1, 2, 3],
58 | reward=1,
59 | stored_vector_actions=vector_action,
60 | stored_text_actions="",
61 | text_observation="",
62 | memories=[],
63 | done=(i == 2),
64 | max_step_reached=False,
65 | id=i
66 | ))
67 | dict_agent_info["RealFakeBrain"] = \
68 | UnityRLOutput.ListAgentInfoProto(value=list_agent_info)
69 | global_done = False
70 | try:
71 | global_done = (inputs.rl_input.agent_actions["RealFakeBrain"].value[0].vector_actions[0] == -1)
72 | except:
73 | pass
74 | result = UnityRLOutput(
75 | global_done=global_done,
76 | agentInfos=dict_agent_info
77 | )
78 | return UnityOutput(
79 | rl_output=result
80 | )
81 |
82 | def close(self):
83 | """
84 | Sends a shutdown signal to the unity environment, and closes the grpc connection.
85 | """
86 | self.has_been_closed = True
87 |
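
The mock can be driven directly, which makes the round trip used by the tests below easier to follow: `initialize` returns the fake academy/brain description and `exchange` returns a step result for three fake agents, the last of which reports `done`. A small sketch, assuming it is run from the python/ directory so that `communicator_objects` and `tests` are importable (attribute names are taken from the mock above):

    from communicator_objects import UnityInput
    from tests.mock_communicator import MockCommunicator

    comm = MockCommunicator(discrete_action=False, visual_inputs=0)
    init_out = comm.initialize(UnityInput())
    print(init_out.rl_initialization_output.brain_parameters[0].brain_name)   # 'RealFakeBrain'

    step_out = comm.exchange(UnityInput())
    agent_infos = step_out.rl_output.agentInfos["RealFakeBrain"].value
    print([a.done for a in agent_infos])                                       # [False, False, True]
    comm.close()
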
--------------------------------------------------------------------------------
/python/tests/test_bc.py:
--------------------------------------------------------------------------------
1 | import unittest.mock as mock
2 | import pytest
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from unitytrainers.bc.models import BehavioralCloningModel
8 | from unityagents import UnityEnvironment
9 | from .mock_communicator import MockCommunicator
10 |
11 |
12 | @mock.patch('unityagents.UnityEnvironment.executable_launcher')
13 | @mock.patch('unityagents.UnityEnvironment.get_communicator')
14 | def test_cc_bc_model(mock_communicator, mock_launcher):
15 | tf.reset_default_graph()
16 | with tf.Session() as sess:
17 | with tf.variable_scope("FakeGraphScope"):
18 | mock_communicator.return_value = MockCommunicator(
19 | discrete_action=False, visual_inputs=0)
20 | env = UnityEnvironment(' ')
21 | model = BehavioralCloningModel(env.brains["RealFakeBrain"])
22 | init = tf.global_variables_initializer()
23 | sess.run(init)
24 |
25 | run_list = [model.sample_action, model.policy]
26 | feed_dict = {model.batch_size: 2,
27 | model.sequence_length: 1,
28 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3],
29 | [3, 4, 5, 3, 4, 5]])}
30 | sess.run(run_list, feed_dict=feed_dict)
31 | env.close()
32 |
33 |
34 | @mock.patch('unityagents.UnityEnvironment.executable_launcher')
35 | @mock.patch('unityagents.UnityEnvironment.get_communicator')
36 | def test_dc_bc_model(mock_communicator, mock_launcher):
37 | tf.reset_default_graph()
38 | with tf.Session() as sess:
39 | with tf.variable_scope("FakeGraphScope"):
40 | mock_communicator.return_value = MockCommunicator(
41 | discrete_action=True, visual_inputs=0)
42 | env = UnityEnvironment(' ')
43 | model = BehavioralCloningModel(env.brains["RealFakeBrain"])
44 | init = tf.global_variables_initializer()
45 | sess.run(init)
46 |
47 | run_list = [model.sample_action, model.policy]
48 | feed_dict = {model.batch_size: 2,
49 | model.dropout_rate: 1.0,
50 | model.sequence_length: 1,
51 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3],
52 | [3, 4, 5, 3, 4, 5]])}
53 | sess.run(run_list, feed_dict=feed_dict)
54 | env.close()
55 |
56 |
57 | @mock.patch('unityagents.UnityEnvironment.executable_launcher')
58 | @mock.patch('unityagents.UnityEnvironment.get_communicator')
59 | def test_visual_dc_bc_model(mock_communicator, mock_launcher):
60 | tf.reset_default_graph()
61 | with tf.Session() as sess:
62 | with tf.variable_scope("FakeGraphScope"):
63 | mock_communicator.return_value = MockCommunicator(
64 | discrete_action=True, visual_inputs=2)
65 | env = UnityEnvironment(' ')
66 | model = BehavioralCloningModel(env.brains["RealFakeBrain"])
67 | init = tf.global_variables_initializer()
68 | sess.run(init)
69 |
70 | run_list = [model.sample_action, model.policy]
71 | feed_dict = {model.batch_size: 2,
72 | model.dropout_rate: 1.0,
73 | model.sequence_length: 1,
74 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3],
75 | [3, 4, 5, 3, 4, 5]]),
76 | model.visual_in[0]: np.ones([2, 40, 30, 3]),
77 | model.visual_in[1]: np.ones([2, 40, 30, 3])}
78 | sess.run(run_list, feed_dict=feed_dict)
79 | env.close()
80 |
81 |
82 | @mock.patch('unityagents.UnityEnvironment.executable_launcher')
83 | @mock.patch('unityagents.UnityEnvironment.get_communicator')
84 | def test_visual_cc_bc_model(mock_communicator, mock_launcher):
85 | tf.reset_default_graph()
86 | with tf.Session() as sess:
87 | with tf.variable_scope("FakeGraphScope"):
88 | mock_communicator.return_value = MockCommunicator(
89 | discrete_action=False, visual_inputs=2)
90 | env = UnityEnvironment(' ')
91 | model = BehavioralCloningModel(env.brains["RealFakeBrain"])
92 | init = tf.global_variables_initializer()
93 | sess.run(init)
94 |
95 | run_list = [model.sample_action, model.policy]
96 | feed_dict = {model.batch_size: 2,
97 | model.sequence_length: 1,
98 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3],
99 | [3, 4, 5, 3, 4, 5]]),
100 | model.visual_in[0]: np.ones([2, 40, 30, 3]),
101 | model.visual_in[1]: np.ones([2, 40, 30, 3])}
102 | sess.run(run_list, feed_dict=feed_dict)
103 | env.close()
104 |
105 |
106 | if __name__ == '__main__':
107 | pytest.main()
108 |
--------------------------------------------------------------------------------
/python/tests/test_unityagents.py:
--------------------------------------------------------------------------------
1 | import json
2 | import unittest.mock as mock
3 | import pytest
4 | import struct
5 |
6 | import numpy as np
7 |
8 | from unityagents import UnityEnvironment, UnityEnvironmentException, UnityActionException, \
9 | BrainInfo, Curriculum
10 | from .mock_communicator import MockCommunicator
11 |
12 |
13 | dummy_curriculum = json.loads('''{
14 | "measure" : "reward",
15 | "thresholds" : [10, 20, 50],
16 | "min_lesson_length" : 3,
17 | "signal_smoothing" : true,
18 | "parameters" :
19 | {
20 | "param1" : [0.7, 0.5, 0.3, 0.1],
21 | "param2" : [100, 50, 20, 15],
22 | "param3" : [0.2, 0.3, 0.7, 0.9]
23 | }
24 | }''')
25 | bad_curriculum = json.loads('''{
26 | "measure" : "reward",
27 | "thresholds" : [10, 20, 50],
28 | "min_lesson_length" : 3,
29 | "signal_smoothing" : false,
30 | "parameters" :
31 | {
32 | "param1" : [0.7, 0.5, 0.3, 0.1],
33 | "param2" : [100, 50, 20],
34 | "param3" : [0.2, 0.3, 0.7, 0.9]
35 | }
36 | }''')
37 |
38 |
39 | def test_handles_bad_filename():
40 | with pytest.raises(UnityEnvironmentException):
41 | UnityEnvironment(' ')
42 |
43 |
44 | @mock.patch('unityagents.UnityEnvironment.executable_launcher')
45 | @mock.patch('unityagents.UnityEnvironment.get_communicator')
46 | def test_initialization(mock_communicator, mock_launcher):
47 | mock_communicator.return_value = MockCommunicator(
48 | discrete_action=False, visual_inputs=0)
49 | env = UnityEnvironment(' ')
50 | with pytest.raises(UnityActionException):
51 | env.step([0])
52 | assert env.brain_names[0] == 'RealFakeBrain'
53 | env.close()
54 |
55 |
56 | @mock.patch('unityagents.UnityEnvironment.executable_launcher')
57 | @mock.patch('unityagents.UnityEnvironment.get_communicator')
58 | def test_reset(mock_communicator, mock_launcher):
59 | mock_communicator.return_value = MockCommunicator(
60 | discrete_action=False, visual_inputs=0)
61 | env = UnityEnvironment(' ')
62 | brain = env.brains['RealFakeBrain']
63 | brain_info = env.reset()
64 | env.close()
65 | assert not env.global_done
66 | assert isinstance(brain_info, dict)
67 | assert isinstance(brain_info['RealFakeBrain'], BrainInfo)
68 | assert isinstance(brain_info['RealFakeBrain'].visual_observations, list)
69 | assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray)
70 | assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations
71 | assert brain_info['RealFakeBrain'].vector_observations.shape[0] == \
72 | len(brain_info['RealFakeBrain'].agents)
73 | assert brain_info['RealFakeBrain'].vector_observations.shape[1] == \
74 | brain.vector_observation_space_size * brain.num_stacked_vector_observations
75 |
76 |
77 | @mock.patch('unityagents.UnityEnvironment.executable_launcher')
78 | @mock.patch('unityagents.UnityEnvironment.get_communicator')
79 | def test_step(mock_communicator, mock_launcher):
80 | mock_communicator.return_value = MockCommunicator(
81 | discrete_action=False, visual_inputs=0)
82 | env = UnityEnvironment(' ')
83 | brain = env.brains['RealFakeBrain']
84 | brain_info = env.reset()
85 | brain_info = env.step([0] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents))
86 | with pytest.raises(UnityActionException):
87 | env.step([0])
88 | brain_info = env.step([-1] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents))
89 | with pytest.raises(UnityActionException):
90 | env.step([0] * brain.vector_action_space_size * len(brain_info['RealFakeBrain'].agents))
91 | env.close()
92 | assert env.global_done
93 | assert isinstance(brain_info, dict)
94 | assert isinstance(brain_info['RealFakeBrain'], BrainInfo)
95 | assert isinstance(brain_info['RealFakeBrain'].visual_observations, list)
96 | assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray)
97 | assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations
98 | assert brain_info['RealFakeBrain'].vector_observations.shape[0] == \
99 | len(brain_info['RealFakeBrain'].agents)
100 | assert brain_info['RealFakeBrain'].vector_observations.shape[1] == \
101 | brain.vector_observation_space_size * brain.num_stacked_vector_observations
102 |
103 | print("\n\n\n\n\n\n\n" + str(brain_info['RealFakeBrain'].local_done))
104 | assert not brain_info['RealFakeBrain'].local_done[0]
105 | assert brain_info['RealFakeBrain'].local_done[2]
106 |
107 |
108 | @mock.patch('unityagents.UnityEnvironment.executable_launcher')
109 | @mock.patch('unityagents.UnityEnvironment.get_communicator')
110 | def test_close(mock_communicator, mock_launcher):
111 | comm = MockCommunicator(
112 | discrete_action=False, visual_inputs=0)
113 | mock_communicator.return_value = comm
114 | env = UnityEnvironment(' ')
115 | assert env._loaded
116 | env.close()
117 | assert not env._loaded
118 | assert comm.has_been_closed
119 |
120 |
121 | def test_curriculum():
122 | open_name = '%s.open' % __name__
123 | with mock.patch('json.load') as mock_load:
124 | with mock.patch(open_name, create=True) as mock_open:
125 | mock_open.return_value = 0
126 | mock_load.return_value = bad_curriculum
127 | with pytest.raises(UnityEnvironmentException):
128 | Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1, "param3": 1})
129 | mock_load.return_value = dummy_curriculum
130 | with pytest.raises(UnityEnvironmentException):
131 | Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1})
132 | curriculum = Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1, "param3": 1})
133 | assert curriculum.get_lesson_number == 0
134 | curriculum.set_lesson_number(1)
135 | assert curriculum.get_lesson_number == 1
136 | curriculum.increment_lesson(10)
137 | assert curriculum.get_lesson_number == 1
138 | curriculum.increment_lesson(30)
139 | curriculum.increment_lesson(30)
140 | assert curriculum.get_lesson_number == 1
141 | assert curriculum.lesson_length == 3
142 | curriculum.increment_lesson(30)
143 | assert curriculum.get_config() == {'param1': 0.3, 'param2': 20, 'param3': 0.7}
144 | assert curriculum.get_config(0) == {"param1": 0.7, "param2": 100, "param3": 0.2}
145 | assert curriculum.lesson_length == 0
146 | assert curriculum.get_lesson_number == 2
147 |
148 |
149 | if __name__ == '__main__':
150 | pytest.main()
151 |
--------------------------------------------------------------------------------
/python/trainer_config.yaml:
--------------------------------------------------------------------------------
1 | default:
2 | trainer: ppo
3 | batch_size: 1024
4 | beta: 5.0e-3
5 | buffer_size: 10240
6 | epsilon: 0.2
7 | gamma: 0.99
8 | hidden_units: 128
9 | lambd: 0.95
10 | learning_rate: 3.0e-4
11 | max_steps: 5.0e4
12 | memory_size: 256
13 | normalize: false
14 | num_epoch: 3
15 | num_layers: 2
16 | time_horizon: 64
17 | sequence_length: 64
18 | summary_freq: 1000
19 | use_recurrent: false
20 | use_curiosity: false
21 | curiosity_strength: 0.01
22 | curiosity_enc_size: 128
23 |
24 | BananaBrain:
25 | normalize: false
26 | batch_size: 1024
27 | beta: 5.0e-3
28 | buffer_size: 10240
29 |
30 | BouncerBrain:
31 | normalize: true
32 | max_steps: 5.0e5
33 | num_layers: 2
34 | hidden_units: 64
35 |
36 | PushBlockBrain:
37 | max_steps: 5.0e4
38 | batch_size: 128
39 | buffer_size: 2048
40 | beta: 1.0e-2
41 | hidden_units: 256
42 | summary_freq: 2000
43 | time_horizon: 64
44 | num_layers: 2
45 |
46 | SmallWallBrain:
47 | max_steps: 2.0e5
48 | batch_size: 128
49 | buffer_size: 2048
50 | beta: 5.0e-3
51 | hidden_units: 256
52 | summary_freq: 2000
53 | time_horizon: 128
54 | num_layers: 2
55 | normalize: false
56 |
57 | BigWallBrain:
58 | max_steps: 2.0e5
59 | batch_size: 128
60 | buffer_size: 2048
61 | beta: 5.0e-3
62 | hidden_units: 256
63 | summary_freq: 2000
64 | time_horizon: 128
65 | num_layers: 2
66 | normalize: false
67 |
68 | StrikerBrain:
69 | max_steps: 1.0e5
70 | batch_size: 128
71 | buffer_size: 2048
72 | beta: 5.0e-3
73 | hidden_units: 256
74 | summary_freq: 2000
75 | time_horizon: 128
76 | num_layers: 2
77 | normalize: false
78 |
79 | GoalieBrain:
80 | max_steps: 1.0e5
81 | batch_size: 128
82 | buffer_size: 2048
83 | beta: 5.0e-3
84 | hidden_units: 256
85 | summary_freq: 2000
86 | time_horizon: 128
87 | num_layers: 2
88 | normalize: false
89 |
90 | PyramidBrain:
91 | use_curiosity: true
92 | summary_freq: 2000
93 | curiosity_strength: 0.01
94 | curiosity_enc_size: 256
95 | time_horizon: 128
96 | batch_size: 128
97 | buffer_size: 2048
98 | hidden_units: 512
99 | num_layers: 2
100 | beta: 1.0e-2
101 | max_steps: 2.0e5
102 | num_epoch: 3
103 |
104 | VisualPyramidBrain:
105 | use_curiosity: true
106 | time_horizon: 128
107 | batch_size: 32
108 | buffer_size: 1024
109 | hidden_units: 256
110 | num_layers: 2
111 | beta: 1.0e-2
112 | max_steps: 5.0e5
113 | num_epoch: 3
114 |
115 | Ball3DBrain:
116 | normalize: true
117 | batch_size: 64
118 | buffer_size: 12000
119 | summary_freq: 1000
120 | time_horizon: 1000
121 | lambd: 0.99
122 | gamma: 0.995
123 | beta: 0.001
124 |
125 | Ball3DHardBrain:
126 | normalize: true
127 | batch_size: 1200
128 | buffer_size: 12000
129 | summary_freq: 1000
130 | time_horizon: 1000
131 | gamma: 0.995
132 | beta: 0.001
133 |
134 | TennisBrain:
135 | normalize: true
136 |
137 | CrawlerBrain:
138 | normalize: true
139 | num_epoch: 3
140 | time_horizon: 1000
141 | batch_size: 2024
142 | buffer_size: 20240
143 | gamma: 0.995
144 | max_steps: 1e6
145 | summary_freq: 3000
146 | num_layers: 3
147 | hidden_units: 512
148 |
149 | WalkerBrain:
150 | normalize: true
151 | num_epoch: 3
152 | time_horizon: 1000
153 | batch_size: 2048
154 | buffer_size: 20480
155 | gamma: 0.995
156 | max_steps: 2e6
157 | summary_freq: 3000
158 | num_layers: 3
159 | hidden_units: 512
160 |
161 | ReacherBrain:
162 | normalize: true
163 | num_epoch: 3
164 | time_horizon: 1000
165 | batch_size: 2024
166 | buffer_size: 20240
167 | gamma: 0.995
168 | max_steps: 1e6
169 | summary_freq: 3000
170 |
171 | HallwayBrain:
172 | use_recurrent: true
173 | sequence_length: 64
174 | num_layers: 2
175 | hidden_units: 128
176 | memory_size: 256
177 | beta: 1.0e-2
178 | gamma: 0.99
179 | num_epoch: 3
180 | buffer_size: 1024
181 | batch_size: 128
182 | max_steps: 5.0e5
183 | summary_freq: 1000
184 | time_horizon: 64
185 |
186 | GridWorldBrain:
187 | batch_size: 32
188 | normalize: false
189 | num_layers: 1
190 | hidden_units: 256
191 | beta: 5.0e-3
192 | gamma: 0.9
193 | buffer_size: 256
194 | max_steps: 5.0e5
195 | summary_freq: 2000
196 | time_horizon: 5
197 |
198 | BasicBrain:
199 | batch_size: 32
200 | normalize: false
201 | num_layers: 1
202 | hidden_units: 20
203 | beta: 5.0e-3
204 | gamma: 0.9
205 | buffer_size: 256
206 | max_steps: 5.0e5
207 | summary_freq: 2000
208 | time_horizon: 3
209 |
210 | StudentBrain:
211 | trainer: imitation
212 | max_steps: 10000
213 | summary_freq: 1000
214 | brain_to_imitate: TeacherBrain
215 | batch_size: 16
216 | batches_per_epoch: 5
217 | num_layers: 4
218 | hidden_units: 64
219 | sequence_length: 16
220 | buffer_size: 128
221 |
222 | StudentRecurrentBrain:
223 | trainer: imitation
224 | max_steps: 10000
225 | summary_freq: 1000
226 | brain_to_imitate: TeacherBrain
227 | batch_size: 16
228 | batches_per_epoch: 5
229 | num_layers: 4
230 | hidden_units: 64
231 | use_recurrent: true
232 | sequence_length: 32
233 | buffer_size: 128
234 |
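
Each top-level key in this file is a brain name. The trainer controller (defined in `unitytrainers/trainer_controller.py`, not reproduced in this listing) starts from the `default` section and overlays any brain-specific section on top of it, roughly as in this sketch (path assumes the python/ directory):

    import yaml

    with open("trainer_config.yaml") as f:
        config = yaml.safe_load(f)

    brain_name = "CrawlerBrain"
    resolved = dict(config["default"])            # start from the shared defaults
    resolved.update(config.get(brain_name, {}))   # brain-specific keys win
    print(resolved["batch_size"], resolved["hidden_units"])   # 2024 512 for CrawlerBrain
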
--------------------------------------------------------------------------------
/python/unityagents/__init__.py:
--------------------------------------------------------------------------------
1 | from .environment import *
2 | from .brain import *
3 | from .exception import *
4 | from .curriculum import *
5 |
--------------------------------------------------------------------------------
/python/unityagents/brain.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 |
4 | class BrainInfo:
5 | def __init__(self, visual_observation, vector_observation, text_observations, memory=None,
6 | reward=None, agents=None, local_done=None,
7 | vector_action=None, text_action=None, max_reached=None):
8 | """
9 | Describes experience at current step of all agents linked to a brain.
10 | """
11 | self.visual_observations = visual_observation
12 | self.vector_observations = vector_observation
13 | self.text_observations = text_observations
14 | self.memories = memory
15 | self.rewards = reward
16 | self.local_done = local_done
17 | self.max_reached = max_reached
18 | self.agents = agents
19 | self.previous_vector_actions = vector_action
20 | self.previous_text_actions = text_action
21 |
22 |
23 | AllBrainInfo = Dict[str, BrainInfo]
24 |
25 |
26 | class BrainParameters:
27 | def __init__(self, brain_name, brain_param):
28 | """
29 | Contains all brain-specific parameters.
30 | :param brain_name: Name of brain.
31 | :param brain_param: Dictionary of brain parameters.
32 | """
33 | self.brain_name = brain_name
34 | self.vector_observation_space_size = brain_param["vectorObservationSize"]
35 | self.num_stacked_vector_observations = brain_param["numStackedVectorObservations"]
36 | self.number_visual_observations = len(brain_param["cameraResolutions"])
37 | self.camera_resolutions = brain_param["cameraResolutions"]
38 | self.vector_action_space_size = brain_param["vectorActionSize"]
39 | self.vector_action_descriptions = brain_param["vectorActionDescriptions"]
40 | self.vector_action_space_type = ["discrete", "continuous"][brain_param["vectorActionSpaceType"]]
41 | self.vector_observation_space_type = ["discrete", "continuous"][brain_param["vectorObservationSpaceType"]]
42 |
43 | def __str__(self):
44 | return '''Unity brain name: {0}
45 | Number of Visual Observations (per agent): {1}
46 | Vector Observation space type: {2}
47 | Vector Observation space size (per agent): {3}
48 | Number of stacked Vector Observation: {4}
49 | Vector Action space type: {5}
50 | Vector Action space size (per agent): {6}
51 | Vector Action descriptions: {7}'''.format(self.brain_name,
52 | str(self.number_visual_observations),
53 | self.vector_observation_space_type,
54 | str(self.vector_observation_space_size),
55 | str(self.num_stacked_vector_observations),
56 | self.vector_action_space_type,
57 | str(self.vector_action_space_size),
58 | ', '.join(self.vector_action_descriptions))
59 |
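
`BrainParameters` is normally built by `unityagents/environment.py` from the dictionary Unity sends during initialization; the keys it reads are exactly the ones in the constructor above. A hypothetical dictionary in that shape, for illustration only:

    from unityagents.brain import BrainParameters

    # Hypothetical values; key names match the constructor above.
    brain_param = {
        "vectorObservationSize": 3,
        "numStackedVectorObservations": 2,
        "cameraResolutions": [],
        "vectorActionSize": 2,
        "vectorActionDescriptions": ["", ""],
        "vectorActionSpaceType": 1,        # 1 -> "continuous", 0 -> "discrete"
        "vectorObservationSpaceType": 1,
    }
    print(BrainParameters("ExampleBrain", brain_param))
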
--------------------------------------------------------------------------------
/python/unityagents/communicator.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from communicator_objects import UnityOutput, UnityInput
4 |
5 | logging.basicConfig(level=logging.INFO)
6 | logger = logging.getLogger("unityagents")
7 |
8 |
9 | class Communicator(object):
10 | def __init__(self, worker_id=0,
11 | base_port=5005):
12 | """
13 | Python side of the communication. Must be used in pair with the right Unity Communicator equivalent.
14 |
15 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
16 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios.
17 | """
18 |
19 | def initialize(self, inputs: UnityInput) -> UnityOutput:
20 | """
21 | Used to exchange initialization parameters between Python and the Environment
22 | :param inputs: The initialization input that will be sent to the environment.
23 | :return: UnityOutput: The initialization output sent by Unity
24 | """
25 |
26 | def exchange(self, inputs: UnityInput) -> UnityOutput:
27 | """
28 | Used to send an input and receive an output from the Environment
29 |         :param inputs: The UnityInput that needs to be sent to the Environment
30 | :return: The UnityOutputs generated by the Environment
31 | """
32 |
33 | def close(self):
34 | """
35 | Sends a shutdown signal to the unity environment, and closes the connection.
36 | """
37 |
38 |
--------------------------------------------------------------------------------
/python/unityagents/curriculum.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from .exception import UnityEnvironmentException
4 |
5 | import logging
6 |
7 | logger = logging.getLogger("unityagents")
8 |
9 |
10 | class Curriculum(object):
11 | def __init__(self, location, default_reset_parameters):
12 | """
13 | Initializes a Curriculum object.
14 | :param location: Path to JSON defining curriculum.
15 | :param default_reset_parameters: Set of reset parameters for environment.
16 | """
17 | self.lesson_length = 0
18 | self.max_lesson_number = 0
19 | self.measure_type = None
20 | if location is None:
21 | self.data = None
22 | else:
23 | try:
24 | with open(location) as data_file:
25 | self.data = json.load(data_file)
26 | except IOError:
27 | raise UnityEnvironmentException(
28 | "The file {0} could not be found.".format(location))
29 | except UnicodeDecodeError:
30 | raise UnityEnvironmentException("There was an error decoding {}".format(location))
31 | self.smoothing_value = 0
32 | for key in ['parameters', 'measure', 'thresholds',
33 | 'min_lesson_length', 'signal_smoothing']:
34 | if key not in self.data:
35 | raise UnityEnvironmentException("{0} does not contain a "
36 | "{1} field.".format(location, key))
37 | parameters = self.data['parameters']
38 | self.measure_type = self.data['measure']
39 | self.max_lesson_number = len(self.data['thresholds'])
40 | for key in parameters:
41 | if key not in default_reset_parameters:
42 | raise UnityEnvironmentException(
43 | "The parameter {0} in Curriculum {1} is not present in "
44 | "the Environment".format(key, location))
45 | for key in parameters:
46 | if len(parameters[key]) != self.max_lesson_number + 1:
47 | raise UnityEnvironmentException(
48 | "The parameter {0} in Curriculum {1} must have {2} values "
49 | "but {3} were found".format(key, location,
50 | self.max_lesson_number + 1, len(parameters[key])))
51 | self.set_lesson_number(0)
52 |
53 | @property
54 | def measure(self):
55 | return self.measure_type
56 |
57 | @property
58 | def get_lesson_number(self):
59 | return self.lesson_number
60 |
61 | def set_lesson_number(self, value):
62 | self.lesson_length = 0
63 | self.lesson_number = max(0, min(value, self.max_lesson_number))
64 |
65 | def increment_lesson(self, progress):
66 | """
67 |         Increments the lesson number depending on the progress given.
68 | :param progress: Measure of progress (either reward or percentage steps completed).
69 | """
70 | if self.data is None or progress is None:
71 | return
72 | if self.data["signal_smoothing"]:
73 | progress = self.smoothing_value * 0.25 + 0.75 * progress
74 | self.smoothing_value = progress
75 | self.lesson_length += 1
76 | if self.lesson_number < self.max_lesson_number:
77 | if ((progress > self.data['thresholds'][self.lesson_number]) and
78 | (self.lesson_length > self.data['min_lesson_length'])):
79 | self.lesson_length = 0
80 | self.lesson_number += 1
81 | config = {}
82 | parameters = self.data["parameters"]
83 | for key in parameters:
84 | config[key] = parameters[key][self.lesson_number]
85 | logger.info("\nLesson changed. Now in Lesson {0} : \t{1}"
86 | .format(self.lesson_number,
87 | ', '.join([str(x) + ' -> ' + str(config[x]) for x in config])))
88 |
89 | def get_config(self, lesson=None):
90 | """
91 | Returns reset parameters which correspond to the lesson.
92 | :param lesson: The lesson you want to get the config of. If None, the current lesson is returned.
93 | :return: The configuration of the reset parameters.
94 | """
95 | if self.data is None:
96 | return {}
97 | if lesson is None:
98 | lesson = self.lesson_number
99 | lesson = max(0, min(lesson, self.max_lesson_number))
100 | config = {}
101 | parameters = self.data["parameters"]
102 | for key in parameters:
103 | config[key] = parameters[key][lesson]
104 | return config
105 |
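
Tying the class above back to the JSON curricula earlier in this listing: each call to `increment_lesson` smooths the incoming progress measure (when `signal_smoothing` is set), counts one step of lesson length, and only advances once the smoothed progress exceeds the current threshold and the lesson has lasted longer than `min_lesson_length`. A sketch using `curricula/wall.json` from above, assuming the package's dependencies are installed and the path resolves from the python/ directory:

    from unityagents import Curriculum

    defaults = {"big_wall_min_height": 0.0,
                "big_wall_max_height": 4.0,
                "small_wall_height": 1.5}
    curriculum = Curriculum("curricula/wall.json", defaults)

    print(curriculum.get_config())        # lesson 0: walls at their easiest settings
    for _ in range(5):                    # progress well above the first threshold (0.1),
        curriculum.increment_lesson(0.9)  # so the lesson advances after min_lesson_length
    print(curriculum.get_lesson_number)   # a later lesson (property, no parentheses)
    print(curriculum.get_config())
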
--------------------------------------------------------------------------------
/python/unityagents/exception.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logger = logging.getLogger("unityagents")
3 |
4 | class UnityException(Exception):
5 | """
6 | Any error related to ml-agents environment.
7 | """
8 | pass
9 |
10 | class UnityEnvironmentException(UnityException):
11 | """
12 | Related to errors starting and closing environment.
13 | """
14 | pass
15 |
16 |
17 | class UnityActionException(UnityException):
18 | """
19 | Related to errors with sending actions.
20 | """
21 | pass
22 |
23 | class UnityTimeOutException(UnityException):
24 | """
25 | Related to errors with communication timeouts.
26 | """
27 | def __init__(self, message, log_file_path = None):
28 | if log_file_path is not None:
29 | try:
30 | with open(log_file_path, "r") as f:
31 | printing = False
32 | unity_error = '\n'
33 | for l in f:
34 | l=l.strip()
35 | if (l == 'Exception') or (l=='Error'):
36 | printing = True
37 | unity_error += '----------------------\n'
38 | if (l == ''):
39 | printing = False
40 | if printing:
41 | unity_error += l + '\n'
42 | logger.info(unity_error)
43 |                 logger.error("An error might have occurred in the environment. "
44 | "You can check the logfile for more information at {}".format(log_file_path))
45 | except:
46 |                 logger.error("An error might have occurred in the environment. "
47 | "No unity-environment.log file could be found.")
48 | super(UnityTimeOutException, self).__init__(message)
49 |
50 |
--------------------------------------------------------------------------------
/python/unityagents/rpc_communicator.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import grpc
3 |
4 | from multiprocessing import Pipe
5 | from concurrent.futures import ThreadPoolExecutor
6 |
7 | from .communicator import Communicator
8 | from communicator_objects import UnityToExternalServicer, add_UnityToExternalServicer_to_server
9 | from communicator_objects import UnityMessage, UnityInput, UnityOutput
10 | from .exception import UnityTimeOutException
11 |
12 |
13 | logging.basicConfig(level=logging.INFO)
14 | logger = logging.getLogger("unityagents")
15 |
16 |
17 | class UnityToExternalServicerImplementation(UnityToExternalServicer):
18 | parent_conn, child_conn = Pipe()
19 |
20 | def Initialize(self, request, context):
21 | self.child_conn.send(request)
22 | return self.child_conn.recv()
23 |
24 | def Exchange(self, request, context):
25 | self.child_conn.send(request)
26 | return self.child_conn.recv()
27 |
28 |
29 | class RpcCommunicator(Communicator):
30 | def __init__(self, worker_id=0,
31 | base_port=5005):
32 | """
33 | Python side of the grpc communication. Python is the server and Unity the client
34 |
35 |
36 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
37 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios.
38 | """
39 | self.port = base_port + worker_id
40 | self.worker_id = worker_id
41 | self.server = None
42 | self.unity_to_external = None
43 | self.is_open = False
44 |
45 | def initialize(self, inputs: UnityInput) -> UnityOutput:
46 | try:
47 | # Establish communication grpc
48 | self.server = grpc.server(ThreadPoolExecutor(max_workers=10))
49 | self.unity_to_external = UnityToExternalServicerImplementation()
50 | add_UnityToExternalServicer_to_server(self.unity_to_external, self.server)
51 | self.server.add_insecure_port('[::]:'+str(self.port))
52 | self.server.start()
53 | except :
54 | raise UnityTimeOutException(
55 | "Couldn't start socket communication because worker number {} is still in use. "
56 | "You may need to manually close a previously opened environment "
57 | "or use a different worker number.".format(str(self.worker_id)))
58 | if not self.unity_to_external.parent_conn.poll(30):
59 | raise UnityTimeOutException(
60 | "The Unity environment took too long to respond. Make sure that :\n"
61 | "\t The environment does not need user interaction to launch\n"
62 | "\t The Academy and the External Brain(s) are attached to objects in the Scene\n"
63 | "\t The environment and the Python interface have compatible versions.")
64 | aca_param = self.unity_to_external.parent_conn.recv().unity_output
65 | self.is_open = True
66 | message = UnityMessage()
67 | message.header.status = 200
68 | message.unity_input.CopyFrom(inputs)
69 | self.unity_to_external.parent_conn.send(message)
70 | self.unity_to_external.parent_conn.recv()
71 | return aca_param
72 |
73 | def exchange(self, inputs: UnityInput) -> UnityOutput:
74 | message = UnityMessage()
75 | message.header.status = 200
76 | message.unity_input.CopyFrom(inputs)
77 | self.unity_to_external.parent_conn.send(message)
78 | output = self.unity_to_external.parent_conn.recv()
79 | if output.header.status != 200:
80 | return None
81 | return output.unity_output
82 |
83 | def close(self):
84 | """
85 | Sends a shutdown signal to the unity environment, and closes the grpc connection.
86 | """
87 | if self.is_open:
88 | message_input = UnityMessage()
89 | message_input.header.status = 400
90 | self.unity_to_external.parent_conn.send(message_input)
91 | self.unity_to_external.parent_conn.close()
92 | self.server.stop(False)
93 | self.is_open = False
94 |
95 |
96 |
97 |
98 |
--------------------------------------------------------------------------------
/python/unityagents/socket_communicator.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import socket
3 | import struct
4 |
5 | from .communicator import Communicator
6 | from communicator_objects import UnityMessage, UnityOutput, UnityInput
7 | from .exception import UnityTimeOutException
8 |
9 |
10 | logging.basicConfig(level=logging.INFO)
11 | logger = logging.getLogger("unityagents")
12 |
13 |
14 | class SocketCommunicator(Communicator):
15 | def __init__(self, worker_id=0,
16 | base_port=5005):
17 | """
18 | Python side of the socket communication
19 |
20 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
21 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios.
22 | """
23 |
24 | self.port = base_port + worker_id
25 | self._buffer_size = 12000
26 | self.worker_id = worker_id
27 | self._socket = None
28 | self._conn = None
29 |
30 | def initialize(self, inputs: UnityInput) -> UnityOutput:
31 | try:
32 | # Establish communication socket
33 | self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
34 | self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
35 | self._socket.bind(("localhost", self.port))
36 | except:
37 | raise UnityTimeOutException("Couldn't start socket communication because worker number {} is still in use. "
38 | "You may need to manually close a previously opened environment "
39 | "or use a different worker number.".format(str(self.worker_id)))
40 | try:
41 | self._socket.settimeout(30)
42 | self._socket.listen(1)
43 | self._conn, _ = self._socket.accept()
44 | self._conn.settimeout(30)
45 | except :
46 | raise UnityTimeOutException(
47 | "The Unity environment took too long to respond. Make sure that :\n"
48 | "\t The environment does not need user interaction to launch\n"
49 | "\t The Academy and the External Brain(s) are attached to objects in the Scene\n"
50 | "\t The environment and the Python interface have compatible versions.")
51 | message = UnityMessage()
52 | message.header.status = 200
53 | message.unity_input.CopyFrom(inputs)
54 | self._communicator_send(message.SerializeToString())
55 | initialization_output = UnityMessage()
56 | initialization_output.ParseFromString(self._communicator_receive())
57 | return initialization_output.unity_output
58 |
59 | def _communicator_receive(self):
60 | try:
61 | s = self._conn.recv(self._buffer_size)
62 | message_length = struct.unpack("I", bytearray(s[:4]))[0]
63 | s = s[4:]
64 | while len(s) != message_length:
65 | s += self._conn.recv(self._buffer_size)
66 | except socket.timeout as e:
67 | raise UnityTimeOutException("The environment took too long to respond.")
68 | return s
69 |
70 | def _communicator_send(self, message):
71 | self._conn.send(struct.pack("I", len(message)) + message)
72 |
73 | def exchange(self, inputs: UnityInput) -> UnityOutput:
74 | message = UnityMessage()
75 | message.header.status = 200
76 | message.unity_input.CopyFrom(inputs)
77 | self._communicator_send(message.SerializeToString())
78 | outputs = UnityMessage()
79 | outputs.ParseFromString(self._communicator_receive())
80 | if outputs.header.status != 200:
81 | return None
82 | return outputs.unity_output
83 |
84 | def close(self):
85 | """
86 | Sends a shutdown signal to the unity environment, and closes the socket connection.
87 | """
88 | if self._socket is not None and self._conn is not None:
89 | message_input = UnityMessage()
90 | message_input.header.status = 400
91 | self._communicator_send(message_input.SerializeToString())
92 | if self._socket is not None:
93 | self._socket.close()
94 | self._socket = None
95 |         if self._conn is not None:
96 | self._conn.close()
97 | self._conn = None
98 |
99 |
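
The wire format above is simple length-prefixed framing: every serialized payload is preceded by its byte length packed as a 4-byte unsigned int (`struct.pack("I", ...)`), and the receiver first unpacks those 4 bytes to know how much more to read. A self-contained sketch of just the framing, independent of sockets and protobuf:

    import struct

    payload = b"serialized protobuf bytes"        # placeholder for message.SerializeToString()

    framed = struct.pack("I", len(payload)) + payload       # what _communicator_send builds
    length = struct.unpack("I", framed[:4])[0]               # what _communicator_receive reads first
    body = framed[4:4 + length]
    assert body == payload and length == len(payload)
    print(length)
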
--------------------------------------------------------------------------------
/python/unitytrainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .buffer import *
2 | from .models import *
3 | from .trainer_controller import *
4 | from .bc.models import *
5 | from .bc.trainer import *
6 | from .ppo.models import *
7 | from .ppo.trainer import *
8 |
--------------------------------------------------------------------------------
/python/unitytrainers/bc/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import *
2 | from .trainer import *
3 |
--------------------------------------------------------------------------------
/python/unitytrainers/bc/models.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib.layers as c_layers
3 | from unitytrainers.models import LearningModel
4 |
5 |
6 | class BehavioralCloningModel(LearningModel):
7 | def __init__(self, brain, h_size=128, lr=1e-4, n_layers=2, m_size=128,
8 | normalize=False, use_recurrent=False):
9 | LearningModel.__init__(self, m_size, normalize, use_recurrent, brain)
10 |
11 | num_streams = 1
12 | hidden_streams = self.create_observation_streams(num_streams, h_size, n_layers)
13 | hidden = hidden_streams[0]
14 | self.dropout_rate = tf.placeholder(dtype=tf.float32, shape=[], name="dropout_rate")
15 | hidden_reg = tf.layers.dropout(hidden, self.dropout_rate)
16 | if self.use_recurrent:
17 | tf.Variable(self.m_size, name="memory_size", trainable=False, dtype=tf.int32)
18 | self.memory_in = tf.placeholder(shape=[None, self.m_size], dtype=tf.float32, name='recurrent_in')
19 | hidden_reg, self.memory_out = self.create_recurrent_encoder(hidden_reg, self.memory_in,
20 | self.sequence_length)
21 | self.memory_out = tf.identity(self.memory_out, name='recurrent_out')
22 | self.policy = tf.layers.dense(hidden_reg, self.a_size, activation=None, use_bias=False, name='pre_action',
23 | kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01))
24 |
25 | if brain.vector_action_space_type == "discrete":
26 | self.action_probs = tf.nn.softmax(self.policy)
27 | self.sample_action_float = tf.multinomial(self.policy, 1)
28 | self.sample_action_float = tf.identity(self.sample_action_float, name="action")
29 | self.sample_action = tf.cast(self.sample_action_float, tf.int32)
30 | self.true_action = tf.placeholder(shape=[None], dtype=tf.int32, name="teacher_action")
31 | self.action_oh = tf.one_hot(self.true_action, self.a_size)
32 | self.loss = tf.reduce_sum(-tf.log(self.action_probs + 1e-10) * self.action_oh)
33 | self.action_percent = tf.reduce_mean(tf.cast(
34 | tf.equal(tf.cast(tf.argmax(self.action_probs, axis=1), tf.int32), self.sample_action), tf.float32))
35 | else:
36 | self.clipped_sample_action = tf.clip_by_value(self.policy, -1, 1)
37 | self.sample_action = tf.identity(self.clipped_sample_action, name="action")
38 | self.true_action = tf.placeholder(shape=[None, self.a_size], dtype=tf.float32, name="teacher_action")
39 | self.clipped_true_action = tf.clip_by_value(self.true_action, -1, 1)
40 | self.loss = tf.reduce_sum(tf.squared_difference(self.clipped_true_action, self.sample_action))
41 |
42 | optimizer = tf.train.AdamOptimizer(learning_rate=lr)
43 | self.update = optimizer.minimize(self.loss)
44 |
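
For the discrete-action branch above, the loss is the negative log-likelihood of the teacher's action under the policy's softmax, summed over the batch. A small NumPy sketch of that computation with made-up numbers, mirroring the `tf.reduce_sum(-tf.log(...) * one_hot)` line (this is not the trainer's code):

    import numpy as np

    logits = np.array([[2.0, 0.5, -1.0],          # hypothetical policy outputs,
                       [0.1, 0.2, 3.0]])          # batch of 2, 3 discrete actions
    teacher_actions = np.array([0, 2])            # actions demonstrated by the teacher

    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)   # softmax
    one_hot = np.eye(3)[teacher_actions]
    loss = np.sum(-np.log(probs + 1e-10) * one_hot)                      # cross-entropy
    print(loss)
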
--------------------------------------------------------------------------------
/python/unitytrainers/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import *
2 | from .trainer import *
3 |
--------------------------------------------------------------------------------
/python/unitytrainers/trainer.py:
--------------------------------------------------------------------------------
1 | # # Unity ML-Agents Toolkit
2 | import logging
3 |
4 | import tensorflow as tf
5 | import numpy as np
6 |
7 | from unityagents import UnityException, AllBrainInfo
8 |
9 | logger = logging.getLogger("unityagents")
10 |
11 |
12 | class UnityTrainerException(UnityException):
13 | """
14 | Related to errors with the Trainer.
15 | """
16 | pass
17 |
18 |
19 | class Trainer(object):
20 | """This class is the abstract class for the unitytrainers"""
21 |
22 | def __init__(self, sess, env, brain_name, trainer_parameters, training):
23 | """
24 | Responsible for collecting experiences and training a neural network model.
25 | :param sess: Tensorflow session.
26 | :param env: The UnityEnvironment.
27 | :param trainer_parameters: The parameters for the trainer (dictionary).
28 | :param training: Whether the trainer is set for training.
29 | """
30 | self.brain_name = brain_name
31 | self.brain = env.brains[self.brain_name]
32 | self.trainer_parameters = trainer_parameters
33 | self.is_training = training
34 | self.sess = sess
35 | self.stats = {}
36 | self.summary_writer = None
37 |
38 | def __str__(self):
39 | return '''Empty Trainer'''
40 |
41 | @property
42 | def parameters(self):
43 | """
44 | Returns the trainer parameters of the trainer.
45 | """
46 | raise UnityTrainerException("The parameters property was not implemented.")
47 |
48 | @property
49 | def graph_scope(self):
50 | """
51 | Returns the graph scope of the trainer.
52 | """
53 | raise UnityTrainerException("The graph_scope property was not implemented.")
54 |
55 | @property
56 | def get_max_steps(self):
57 | """
58 | Returns the maximum number of steps. Is used to know when the trainer should be stopped.
59 | :return: The maximum number of steps of the trainer
60 | """
61 | raise UnityTrainerException("The get_max_steps property was not implemented.")
62 |
63 | @property
64 | def get_step(self):
65 | """
66 | Returns the number of steps the trainer has performed
67 | :return: the step count of the trainer
68 | """
69 | raise UnityTrainerException("The get_step property was not implemented.")
70 |
71 | @property
72 | def get_last_reward(self):
73 | """
74 | Returns the last reward the trainer has had
75 | :return: the new last reward
76 | """
77 | raise UnityTrainerException("The get_last_reward property was not implemented.")
78 |
79 | def increment_step_and_update_last_reward(self):
80 | """
81 |         Increments the step count of the trainer and updates the last reward.
82 | """
83 | raise UnityTrainerException("The increment_step_and_update_last_reward method was not implemented.")
84 |
85 | def take_action(self, all_brain_info: AllBrainInfo):
86 | """
87 | Decides actions given state/observation information, and takes them in environment.
88 | :param all_brain_info: A dictionary of brain names and BrainInfo from environment.
89 | :return: a tuple containing action, memories, values and an object
90 | to be passed to add experiences
91 | """
92 | raise UnityTrainerException("The take_action method was not implemented.")
93 |
94 | def add_experiences(self, curr_info: AllBrainInfo, next_info: AllBrainInfo, take_action_outputs):
95 | """
96 | Adds experiences to each agent's experience history.
97 | :param curr_info: Current AllBrainInfo.
98 | :param next_info: Next AllBrainInfo.
99 | :param take_action_outputs: The outputs of the take action method.
100 | """
101 | raise UnityTrainerException("The add_experiences method was not implemented.")
102 |
103 | def process_experiences(self, current_info: AllBrainInfo, next_info: AllBrainInfo):
104 | """
105 | Checks agent histories for processing condition, and processes them as necessary.
106 | Processing involves calculating value and advantage targets for model updating step.
107 | :param current_info: Dictionary of all current-step brains and corresponding BrainInfo.
108 | :param next_info: Dictionary of all next-step brains and corresponding BrainInfo.
109 | """
110 | raise UnityTrainerException("The process_experiences method was not implemented.")
111 |
112 | def end_episode(self):
113 | """
114 | A signal that the Episode has ended. The buffer must be reset.
115 |         It only gets called when the academy resets.
116 | """
117 | raise UnityTrainerException("The end_episode method was not implemented.")
118 |
119 | def is_ready_update(self):
120 | """
121 | Returns whether or not the trainer has enough elements to run update model
122 |         :return: A boolean corresponding to whether or not update_model() can be run
123 | """
124 | raise UnityTrainerException("The is_ready_update method was not implemented.")
125 |
126 | def update_model(self):
127 | """
128 | Uses training_buffer to update model.
129 | """
130 | raise UnityTrainerException("The update_model method was not implemented.")
131 |
132 | def write_summary(self, lesson_number):
133 | """
134 | Saves training statistics to Tensorboard.
135 | :param lesson_number: The lesson the trainer is at.
136 | """
137 | if (self.get_step % self.trainer_parameters['summary_freq'] == 0 and self.get_step != 0 and
138 | self.is_training and self.get_step <= self.get_max_steps):
139 | if len(self.stats['cumulative_reward']) > 0:
140 | mean_reward = np.mean(self.stats['cumulative_reward'])
141 | logger.info(" {}: Step: {}. Mean Reward: {:0.3f}. Std of Reward: {:0.3f}."
142 | .format(self.brain_name, self.get_step,
143 | mean_reward, np.std(self.stats['cumulative_reward'])))
144 | else:
145 | logger.info(" {}: Step: {}. No episode was completed since last summary."
146 | .format(self.brain_name, self.get_step))
147 | summary = tf.Summary()
148 | for key in self.stats:
149 | if len(self.stats[key]) > 0:
150 | stat_mean = float(np.mean(self.stats[key]))
151 | summary.value.add(tag='Info/{}'.format(key), simple_value=stat_mean)
152 | self.stats[key] = []
153 | summary.value.add(tag='Info/Lesson', simple_value=lesson_number)
154 | self.summary_writer.add_summary(summary, self.get_step)
155 | self.summary_writer.flush()
156 |
157 | def write_tensorboard_text(self, key, input_dict):
158 | """
159 | Saves text to Tensorboard.
160 | Note: Only works on tensorflow r1.2 or above.
161 | :param key: The name of the text.
162 | :param input_dict: A dictionary that will be displayed in a table on Tensorboard.
163 | """
164 | try:
165 | s_op = tf.summary.text(key, tf.convert_to_tensor(([[str(x), str(input_dict[x])] for x in input_dict])))
166 | s = self.sess.run(s_op)
167 | self.summary_writer.add_summary(s, self.get_step)
168 | except:
169 | logger.info("Cannot write text summary for Tensorboard. Tensorflow version must be r1.2 or above.")
170 | pass
171 |
--------------------------------------------------------------------------------
/reinforce/README.md:
--------------------------------------------------------------------------------
1 | [//]: # (Image References)
2 |
3 | [image1]: https://user-images.githubusercontent.com/10624937/42135683-dde5c6f0-7d13-11e8-90b1-8770df3e40cf.gif "Trained Agent"
4 |
5 | # REINFORCE
6 |
7 | ### Instructions
8 |
9 | Open `REINFORCE.ipynb` to see an implementation of REINFORCE (also known as Monte Carlo Policy Gradients) with OpenAI Gym's CartPole environment.
10 |
11 | Try changing the parameters in the notebook to see if you can get the agent to train faster!
12 |
13 | ### Results
14 |
15 | ![Trained Agent][image1]
16 |
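
For reference, the update the notebook implements weights each action's log-probability by the discounted return that followed it. A minimal sketch of the return calculation and the resulting surrogate loss, with made-up numbers (this is not the notebook's code):

    rewards = [1.0, 1.0, 1.0]          # hypothetical rewards from one episode
    log_probs = [-0.2, -0.5, -0.1]     # hypothetical log pi(a_t | s_t) values
    gamma = 0.99

    returns, G = [], 0.0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.insert(0, G)           # G_t = r_t + gamma * G_{t+1}

    # Minimizing this loss ascends the REINFORCE policy-gradient estimate.
    loss = -sum(lp * G_t for lp, G_t in zip(log_probs, returns))
    print(returns, loss)
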
--------------------------------------------------------------------------------
/temporal-difference/README.md:
--------------------------------------------------------------------------------
1 | # Temporal-Difference Methods
2 |
3 | ### Instructions
4 |
5 | Follow the instructions in `Temporal_Difference.ipynb` to write your own implementations of many temporal-difference methods! The corresponding solutions can be found in `Temporal_Difference_Solution.ipynb`.
6 |
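
The notebook walks through Sarsa, Q-learning (Sarsamax), and Expected Sarsa. As a quick reference, here is a one-step Q-learning update with made-up numbers (a sketch, not the notebook's implementation):

    alpha, gamma = 0.1, 1.0
    Q = {("s0", "left"): 0.0, ("s0", "right"): 0.0,
         ("s1", "left"): 1.0, ("s1", "right"): 2.0}

    s, a, r, s_next = "s0", "right", 0.5, "s1"      # one hypothetical transition
    best_next = max(Q[(s_next, a2)] for a2 in ("left", "right"))
    Q[(s, a)] += alpha * (r + gamma * best_next - Q[(s, a)])   # Sarsamax backup
    print(Q[("s0", "right")])                                   # 0.25
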
--------------------------------------------------------------------------------
/temporal-difference/check_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from IPython.display import Markdown, display
3 | import numpy as np
4 |
5 | def printmd(string):
6 | display(Markdown(string))
7 |
8 | V_opt = np.zeros((4,12))
9 | V_opt[0] = -np.arange(3, 15)[::-1]
10 | V_opt[1] = -np.arange(3, 15)[::-1] + 1
11 | V_opt[2] = -np.arange(3, 15)[::-1] + 2
12 | V_opt[3][0] = -13
13 |
14 | pol_opt = np.hstack((np.ones(11), 2, 0))
15 |
16 | V_true = np.zeros((4,12))
17 | for i in range(3):
18 |     V_true[i] = -np.arange(3, 15)[::-1] - i
19 | V_true[1][11] = -2
20 | V_true[2][11] = -1
21 | V_true[3][0] = -17
22 |
23 | def get_long_path(V):
24 |     return np.array(np.hstack((V[0], V[1][0], V[1][11], V[2][0], V[2][11], V[3][0], V[3][11])))
25 |
26 | def get_optimal_path(policy):
27 | return np.array(np.hstack((policy[2][:], policy[3][0])))
28 |
29 | class Tests(unittest.TestCase):
30 |
31 | def td_prediction_check(self, V):
32 | to_check = get_long_path(V)
33 | soln = get_long_path(V_true)
34 | np.testing.assert_array_almost_equal(soln, to_check)
35 |
36 | def td_control_check(self, policy):
37 | to_check = get_optimal_path(policy)
38 | np.testing.assert_equal(pol_opt, to_check)
39 |
40 | check = Tests()
41 |
42 | def run_check(check_name, func):
43 | try:
44 | getattr(check, check_name)(func)
45 | except check.failureException as e:
46 | printmd('**PLEASE TRY AGAIN**')
47 | return
48 | printmd('**PASSED**')
--------------------------------------------------------------------------------
/temporal-difference/plot_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 | sns.set_style("white")
5 |
6 | def plot_values(V):
7 | # reshape the state-value function
8 | V = np.reshape(V, (4,12))
9 | # plot the state-value function
10 | fig = plt.figure(figsize=(15,5))
11 | ax = fig.add_subplot(111)
12 | im = ax.imshow(V, cmap='cool')
13 | for (j,i),label in np.ndenumerate(V):
14 | ax.text(i, j, np.round(label,3), ha='center', va='center', fontsize=14)
15 | plt.tick_params(bottom='off', left='off', labelbottom='off', labelleft='off')
16 | plt.title('State-Value Function')
17 | plt.show()
--------------------------------------------------------------------------------
/tile-coding/README.md:
--------------------------------------------------------------------------------
1 | # Tile Coding
2 |
3 | ### Instructions
4 |
5 | Follow the instructions in `Tile_Coding.ipynb` to learn how to discretize continuous state spaces so that tabular solution methods can be used to solve complex tasks. The corresponding solutions can be found in `Tile_Coding_Solution.ipynb`.
6 |
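
As a rough illustration of the idea the notebook develops: several overlapping grids ("tilings"), each shifted by a fraction of the tile width, turn one continuous value into a small set of active tile indices. A toy 1-D sketch (not the notebook's implementation):

    def tile_indices(x, low=-1.0, high=1.0, bins=10, tilings=4):
        width = (high - low) / bins
        active = []
        for t in range(tilings):
            offset = t * width / tilings                   # shift each tiling slightly
            idx = int((x - low + offset) / width)
            active.append(min(max(idx, 0), bins - 1))      # clip to the grid
        return active

    print(tile_indices(0.15))   # one active tile per tiling, e.g. [5, 6, 6, 6]
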
--------------------------------------------------------------------------------