├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── onpolicy.iml
│   └── vcs.xml
├── LICENSE
├── README.md
├── environment.yaml
├── onpolicy
│   ├── __init__.py
│   ├── algorithms
│   │   ├── __init__.py
│   │   ├── r_mappo
│   │   │   ├── __init__.py
│   │   │   ├── algorithm
│   │   │   │   ├── rMAPPOPolicy.py
│   │   │   │   └── r_actor_critic.py
│   │   │   └── r_mappo.py
│   │   └── utils
│   │       ├── act.py
│   │       ├── attention.py
│   │       ├── cnn.py
│   │       ├── distributions.py
│   │       ├── invariant.py
│   │       ├── mix.py
│   │       ├── mlp.py
│   │       ├── rnn.py
│   │       ├── util.py
│   │       └── vit.py
│   ├── config.py
│   ├── docs
│   │   ├── Makefile
│   │   ├── make.bat
│   │   └── source
│   │       ├── _templates
│   │       │   ├── module.rst_t
│   │       │   ├── package.rst_t
│   │       │   └── toc.rst_t
│   │       ├── conf.py
│   │       ├── full.gif
│   │       ├── index.rst
│   │       ├── quickstart.rst
│   │       └── setup.rst
│   ├── envs
│   │   ├── __init__.py
│   │   ├── env_wrappers.py
│   │   └── gridworld
│   │       ├── .travis.yml
│   │       ├── GridWorld_Env.py
│   │       ├── LICENSE
│   │       ├── README.md
│   │       ├── benchmark.py
│   │       ├── figures
│   │       │   ├── BlockedUnlockPickup.png
│   │       │   ├── DistShift1.png
│   │       │   ├── DistShift2.png
│   │       │   ├── KeyCorridorS3R1.png
│   │       │   ├── KeyCorridorS3R2.png
│   │       │   ├── KeyCorridorS3R3.png
│   │       │   ├── KeyCorridorS4R3.png
│   │       │   ├── KeyCorridorS5R3.png
│   │       │   ├── KeyCorridorS6R3.png
│   │       │   ├── LavaCrossingS11N5.png
│   │       │   ├── LavaCrossingS9N1.png
│   │       │   ├── LavaCrossingS9N2.png
│   │       │   ├── LavaCrossingS9N3.png
│   │       │   ├── LavaGapS6.png
│   │       │   ├── ObstructedMaze-1Dl.png
│   │       │   ├── ObstructedMaze-1Dlh.png
│   │       │   ├── ObstructedMaze-1Dlhb.png
│   │       │   ├── ObstructedMaze-1Q.png
│   │       │   ├── ObstructedMaze-2Dl.png
│   │       │   ├── ObstructedMaze-2Dlh.png
│   │       │   ├── ObstructedMaze-2Dlhb.png
│   │       │   ├── ObstructedMaze-2Q.png
│   │       │   ├── ObstructedMaze-4Q.png
│   │       │   ├── SimpleCrossingS11N5.png
│   │       │   ├── SimpleCrossingS9N1.png
│   │       │   ├── SimpleCrossingS9N2.png
│   │       │   ├── SimpleCrossingS9N3.png
│   │       │   ├── Unlock.png
│   │       │   ├── UnlockPickup.png
│   │       │   ├── door-key-curriculum.gif
│   │       │   ├── door-key-env.png
│   │       │   ├── dynamic_obstacles.gif
│   │       │   ├── empty-env.png
│   │       │   ├── fetch-env.png
│   │       │   ├── four-rooms-env.png
│   │       │   ├── gotodoor-6x6.mp4
│   │       │   ├── gotodoor-6x6.png
│   │       │   └── multi-room.gif
│   │       ├── frontier
│   │       │   ├── apf.py
│   │       │   ├── nearest.py
│   │       │   ├── rrt.py
│   │       │   ├── utility.py
│   │       │   ├── utils.py
│   │       │   └── voronoi.py
│   │       ├── gym_minigrid
│   │       │   ├── __init__.py
│   │       │   ├── envs
│   │       │   │   ├── __init__.py
│   │       │   │   ├── blockedunlockpickup.py
│   │       │   │   ├── crossing.py
│   │       │   │   ├── distshift.py
│   │       │   │   ├── doorkey.py
│   │       │   │   ├── dynamicobstacles.py
│   │       │   │   ├── empty.py
│   │       │   │   ├── fetch.py
│   │       │   │   ├── fourrooms.py
│   │       │   │   ├── gotodoor.py
│   │       │   │   ├── gotoobject.py
│   │       │   │   ├── human.py
│   │       │   │   ├── irregular_room.py
│   │       │   │   ├── keycorridor.py
│   │       │   │   ├── lavagap.py
│   │       │   │   ├── lockedroom.py
│   │       │   │   ├── memory.py
│   │       │   │   ├── multiexploration.py
│   │       │   │   ├── multiroom.py
│   │       │   │   ├── obstructedmaze.py
│   │       │   │   ├── playground_v0.py
│   │       │   │   ├── putnear.py
│   │       │   │   ├── redbluedoors.py
│   │       │   │   ├── unlock.py
│   │       │   │   └── unlockpickup.py
│   │       │   ├── minigrid.py
│   │       │   ├── register.py
│   │       │   ├── rendering.py
│   │       │   ├── roomgrid.py
│   │       │   ├── window.py
│   │       │   └── wrappers.py
│   │       ├── manual_control.py
│   │       ├── run_tests.py
│   │       └── setup.py
│   ├── runner
│   │   └── shared
│   │       ├── base_runner.py
│   │       └── gridworld_runner.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── render
│   │   │   ├── __init__.py
│   │   │   ├── render_gridworld.py
│   │   │   └── render_gridworld_ft.py
│   │   ├── render_gridworld.sh
│   │   ├── render_gridworld_ft.sh
│   │   ├── train
│   │   │   ├── __init__.py
│   │   │   └── train_gridworld.py
│   │   └── train_gridworld.sh
│   └── utils
│       ├── RRT
│       │   ├── __init__.py
│       │   ├── rrt.py
│       │   ├── rrt_with_pathsmoothing.py
│       │   ├── rrt_with_sobol_sampler.py
│       │   └── sobol
│       │       ├── __init__.py
│       │       └── sobol.py
│       ├── __init__.py
│       ├── multi_discrete.py
│       ├── shared_buffer.py
│       ├── util.py
│       └── valuenorm.py
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | results
3 | *.pyc
4 | /gifs/
5 | build
6 | win_rate
7 | *.so
8 | *.csv
9 | RODE
10 | *.egg-info
11 | wandb
12 | .vscode
13 | api
14 | logs
15 | .nfs*
16 | *.so
17 | *.out
18 | png
19 | *.log
20 |
21 | docker
22 | dataset
23 | data
24 | pretrained_models
25 |
26 | onpolicy/scripts/gifs/
27 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Default ignored files
3 | /workspace.xml
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/onpolicy.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
14 |
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Yang Xinyi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Asynchronous Multi-Agent Reinforcement Learning for Efficient Real-Time Multi-Robot Cooperative Exploration
2 |
3 | This is a PyTorch implementation of the paper: [Asynchronous Multi-Agent Reinforcement Learning for Efficient Real-Time Multi-Robot Cooperative Exploration](https://arxiv.org/abs/2301.03398)
4 |
5 | Project Website: https://sites.google.com/view/ace-aamas
6 |
7 | ## Training
8 |
9 | You can start training by running `sh train_gridworld.sh` in the directory [onpolicy/scripts](onpolicy/scripts).
10 |
11 | ## Evaluation
12 |
13 | As with training, you can run `sh render_gridworld.sh` in the directory [onpolicy/scripts](onpolicy/scripts) to start evaluation. Remember to set the path to the corresponding model, the correct hyperparameters, and the related evaluation parameters.
14 |
15 | We also provide implementations of planning-based baselines. You can run `sh render_gridworld_ft.sh` to evaluate the planning-based methods. Note that `algorithm_name` determines the global planning method; it can be set to one of `mappo`, `ft_rrt`, `ft_apf`, `ft_nearest`, and `ft_utility`.
16 |
17 | You can also visualize the results and generate GIFs by adding `--use_render` and `--save_gifs` to the scripts.
18 |
19 | ## Citation
20 | If you find this repository useful, please cite our [paper](https://arxiv.org/abs/2301.03398):
21 | ```
22 | @misc{yu2023asynchronous,
23 | title={Asynchronous Multi-Agent Reinforcement Learning for Efficient Real-Time Multi-Robot Cooperative Exploration},
24 | author={Chao Yu and Xinyi Yang and Jiaxuan Gao and Jiayu Chen and Yunfei Li and Jijia Liu and Yunfei Xiang and Ruixin Huang and Huazhong Yang and Yi Wu and Yu Wang},
25 | year={2023},
26 | eprint={2301.03398},
27 | archivePrefix={arXiv},
28 | primaryClass={cs.RO}
29 | }
30 | ```
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: marl
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _tflow_select=2.1.0=gpu
7 | - absl-py=0.9.0=py36_0
8 | - astor=0.8.0=py36_0
9 | - blas=1.0=mkl
10 | - c-ares=1.15.0=h7b6447c_1001
11 | - ca-certificates=2020.1.1=0
12 | - certifi=2020.4.5.2=py36_0
13 | - cudatoolkit=10.0.130=0
14 | - cudnn=7.6.5=cuda10.0_0
15 | - cupti=10.0.130=0
16 | - gast=0.2.2=py36_0
17 | - google-pasta=0.2.0=py_0
18 | - grpcio=1.14.1=py36h9ba97e2_0
19 | - h5py=2.10.0=py36h7918eee_0
20 | - hdf5=1.10.4=hb1b8bf9_0
21 | - intel-openmp=2020.1=217
22 | - keras-applications=1.0.8=py_0
23 | - keras-preprocessing=1.1.0=py_1
24 | - libedit=3.1=heed3624_0
25 | - libffi=3.2.1=hd88cf55_4
26 | - libgcc-ng=9.1.0=hdf63c60_0
27 | - libgfortran-ng=7.3.0=hdf63c60_0
28 | - libprotobuf=3.12.3=hd408876_0
29 | - libstdcxx-ng=9.1.0=hdf63c60_0
30 | - markdown=3.1.1=py36_0
31 | - mkl=2020.1=217
32 | - mkl-service=2.3.0=py36he904b0f_0
33 | - mkl_fft=1.1.0=py36h23d657b_0
34 | - mkl_random=1.1.1=py36h0573a6f_0
35 | - ncurses=6.0=h9df7e31_2
36 | - numpy=1.18.1=py36h4f9e942_0
37 | - numpy-base=1.18.1=py36hde5b4d6_1
38 | - openssl=1.0.2u=h7b6447c_0
39 | - opt_einsum=3.1.0=py_0
40 | - pip=20.1.1=py36_1
41 | - protobuf=3.12.3=py36he6710b0_0
42 | - python=3.6.2=hca45abc_19
43 | - readline=7.0=ha6073c6_4
44 | - scipy=1.4.1=py36h0b6359f_0
45 | - setuptools=47.3.0=py36_0
46 | - six=1.15.0=py_0
47 | - sqlite=3.23.1=he433501_0
48 | - tensorboard=2.0.0=pyhb38c66f_1
49 | - tensorflow=2.0.0=gpu_py36h6b29c10_0
50 | - tensorflow-base=2.0.0=gpu_py36h0ec5d1f_0
51 | - tensorflow-estimator=2.0.0=pyh2649769_0
52 | - tensorflow-gpu=2.0.0=h0d30ee6_0
53 | - termcolor=1.1.0=py36_1
54 | - tk=8.6.8=hbc83047_0
55 | - werkzeug=0.16.1=py_0
56 | - wheel=0.34.2=py36_0
57 | - wrapt=1.12.1=py36h7b6447c_1
58 | - xz=5.2.5=h7b6447c_0
59 | - zlib=1.2.11=h7b6447c_3
60 | - pip:
61 | - aiohttp==3.6.2
62 | - aioredis==1.3.1
63 | - astunparse==1.6.3
64 | - async-timeout==3.0.1
65 | - atari-py==0.2.6
66 | - atomicwrites==1.2.1
67 | - attrs==18.2.0
68 | - beautifulsoup4==4.9.1
69 | - blessings==1.7
70 | - cachetools==4.1.1
71 | - cffi==1.14.1
72 | - chardet==3.0.4
73 | - click==7.1.2
74 | - cloudpickle==1.3.0
75 | - colorama==0.4.3
76 | - colorful==0.5.4
77 | - configparser==5.0.1
78 | - contextvars==2.4
79 | - cycler==0.10.0
80 | - cython==0.29.21
81 | - deepdiff==4.3.2
82 | - dill==0.3.2
83 | - docker-pycreds==0.4.0
84 | - docopt==0.6.2
85 | - fasteners==0.15
86 | - filelock==3.0.12
87 | - funcsigs==1.0.2
88 | - future==0.16.0
89 | - gin==0.1.6
90 | - gin-config==0.3.0
91 | - gitdb==4.0.5
92 | - gitpython==3.1.9
93 | - glfw==1.12.0
94 | - google==3.0.0
95 | - google-api-core==1.22.1
96 | - google-auth==1.21.0
97 | - google-auth-oauthlib==0.4.1
98 | - googleapis-common-protos==1.52.0
99 | - gpustat==0.6.0
100 | - gql==0.2.0
101 | - graphql-core==1.1
102 | - gym==0.17.2
103 | - hiredis==1.1.0
104 | - idna==2.7
105 | - idna-ssl==1.1.0
106 | - imageio==2.4.1
107 | - immutables==0.14
108 | - importlib-metadata==1.7.0
109 | - joblib==0.16.0
110 | - jsonnet==0.16.0
111 | - jsonpickle==0.9.6
112 | - jsonschema==3.2.0
113 | - kiwisolver==1.0.1
114 | - lockfile==0.12.2
115 | - mappo==0.0.1
116 | - matplotlib==3.0.0
117 | - mock==2.0.0
118 | - monotonic==1.5
119 | - more-itertools==4.3.0
120 | - mpi4py==3.0.3
121 | - mpyq==0.2.5
122 | - msgpack==1.0.0
123 | - mujoco-py==2.0.2.13
124 | - mujoco-worldgen==0.0.0
125 | - multidict==4.7.6
126 | - munch==2.3.2
127 | - nvidia-ml-py3==7.352.0
128 | - oauthlib==3.1.0
129 | - opencensus==0.7.10
130 | - opencensus-context==0.1.1
131 | - opencv-python==4.2.0.34
132 | - ordered-set==4.0.2
133 | - packaging==20.4
134 | - pandas==1.1.1
135 | - pathlib2==2.3.2
136 | - pathtools==0.1.2
137 | - pbr==4.3.0
138 | - pillow==5.3.0
139 | - pluggy==0.7.1
140 | - portpicker==1.2.0
141 | - probscale==0.2.3
142 | - progressbar2==3.53.1
143 | - prometheus-client==0.8.0
144 | - promise==2.3
145 | - psutil==5.7.2
146 | - py==1.6.0
147 | - py-spy==0.3.3
148 | - pyasn1==0.4.8
149 | - pyasn1-modules==0.2.8
150 | - pycparser==2.20
151 | - pygame==1.9.4
152 | - pyglet==1.5.0
153 | - pyopengl==3.1.5
154 | - pyopengl-accelerate==3.1.5
155 | - pyparsing==2.2.2
156 | - pyrsistent==0.16.0
157 | - pysc2==3.0.0
158 | - pytest==3.8.2
159 | - python-dateutil==2.7.3
160 | - python-utils==2.4.0
161 | - pytz==2020.1
162 | - pyyaml==3.13
163 | - pyzmq==19.0.2
164 | - ray==0.8.0
165 | - redis==3.4.1
166 | - requests==2.24.0
167 | - requests-oauthlib==1.3.0
168 | - rsa==4.6
169 | - s2clientprotocol==4.10.1.75800.0
170 | - s2protocol==4.11.4.78285.0
171 | - sacred==0.7.2
172 | - seaborn==0.10.1
173 | - sentry-sdk==0.18.0
174 | - shortuuid==1.0.1
175 | - sk-video==1.1.10
176 | - smmap==3.0.4
177 | - snakeviz==1.0.0
178 | - soupsieve==2.0.1
179 | - subprocess32==3.5.4
180 | - tabulate==0.8.7
181 | - tensorboard-logger==0.1.0
182 | - tensorboard-plugin-wit==1.7.0
183 | - tensorboardx==2.0
184 | - torch==1.5.1+cu101
185 | - torchvision==0.6.1+cu101
186 | - tornado==5.1.1
187 | - tqdm==4.48.2
188 | - typing-extensions==3.7.4.3
189 | - urllib3==1.23
190 | - wandb==0.10.5
191 | - watchdog==0.10.3
192 | - websocket-client==0.53.0
193 | - whichcraft==0.5.2
194 | - xmltodict==0.12.0
195 | - yarl==1.5.1
196 | - zipp==3.1.0
197 | - zmq==0.0.0
198 | prefix: /home/yuchao/anaconda3/envs/marl
199 |
--------------------------------------------------------------------------------
/onpolicy/__init__.py:
--------------------------------------------------------------------------------
1 | from onpolicy import algorithms, envs, runner, scripts, utils, config
2 |
3 |
4 | __version__ = "0.1.0"
5 |
6 | __all__ = [
7 | "algorithms",
8 | "envs",
9 | "runner",
10 | "scripts",
11 | "utils",
12 | "config",
13 | ]
--------------------------------------------------------------------------------
/onpolicy/algorithms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/algorithms/__init__.py
--------------------------------------------------------------------------------
/onpolicy/algorithms/r_mappo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/algorithms/r_mappo/__init__.py
--------------------------------------------------------------------------------
/onpolicy/algorithms/r_mappo/algorithm/rMAPPOPolicy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from onpolicy.algorithms.r_mappo.algorithm.r_actor_critic import R_Actor, R_Critic
4 | from onpolicy.utils.util import update_linear_schedule
5 |
6 |
7 | class R_MAPPOPolicy:
8 | def __init__(self, args, obs_space, share_obs_space, act_space, device=torch.device("cpu")):
9 |
10 | self.device = device
11 | self.lr = args.lr
12 | self.critic_lr = args.critic_lr
13 | self.opti_eps = args.opti_eps
14 | self.weight_decay = args.weight_decay
15 |
16 | self.obs_space = obs_space
17 | self.share_obs_space = share_obs_space
18 | self.act_space = act_space
19 |
20 | self.actor = R_Actor(args, self.obs_space, self.act_space, self.device)
21 | self.critic = R_Critic(args, self.share_obs_space, self.device)
22 |
23 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.lr, eps=self.opti_eps, weight_decay=self.weight_decay)
24 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr, eps=self.opti_eps, weight_decay=self.weight_decay)
25 |
26 | def lr_decay(self, episode, episodes):
27 | update_linear_schedule(self.actor_optimizer, episode, episodes, self.lr)
28 | update_linear_schedule(self.critic_optimizer, episode, episodes, self.critic_lr)
29 |
30 | def get_actions(self, share_obs, obs, rnn_states_actor, rnn_states_critic, masks, available_actions=None, deterministic=False):
31 | actions, action_log_probs, rnn_states_actor = self.actor(obs, rnn_states_actor, masks, available_actions, deterministic)
32 | values, rnn_states_critic = self.critic(share_obs, rnn_states_critic, masks)
33 | return values, actions, action_log_probs, rnn_states_actor, rnn_states_critic
34 |
35 | def get_values(self, share_obs, rnn_states_critic, masks):
36 | values, _ = self.critic(share_obs, rnn_states_critic, masks)
37 | return values
38 |
39 | def evaluate_actions(self, share_obs, obs, rnn_states_actor, rnn_states_critic, action, masks, available_actions=None, active_masks=None):
40 | action_log_probs, dist_entropy, policy_values = self.actor.evaluate_actions(obs, rnn_states_actor, action, masks, available_actions, active_masks)
41 | values, _ = self.critic(share_obs, rnn_states_critic, masks)
42 | return values, action_log_probs, dist_entropy, policy_values
43 |
44 | def act(self, obs, rnn_states_actor, masks, available_actions=None, deterministic=False):
45 | actions, _, rnn_states_actor = self.actor(obs, rnn_states_actor, masks, available_actions, deterministic)
46 | return actions, rnn_states_actor
47 |
--------------------------------------------------------------------------------
/onpolicy/algorithms/r_mappo/algorithm/r_actor_critic.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from onpolicy.algorithms.utils.util import init, check
9 | from onpolicy.algorithms.utils.cnn import CNNBase
10 | from onpolicy.algorithms.utils.mlp import MLPBase, MLPLayer
11 | from onpolicy.algorithms.utils.mix import MIXBase
12 | from onpolicy.algorithms.utils.rnn import RNNLayer
13 | from onpolicy.algorithms.utils.act import ACTLayer
14 | from onpolicy.utils.util import get_shape_from_obs_space
15 |
16 | class R_Actor(nn.Module):
17 | def __init__(self, args, obs_space, action_space, device=torch.device("cpu")):
18 | super(R_Actor, self).__init__()
19 | self.hidden_size = args.hidden_size
20 |
21 | self._gain = args.gain
22 | self._use_orthogonal = args.use_orthogonal
23 | self._activation_id = args.activation_id
24 | self._use_policy_active_masks = args.use_policy_active_masks
25 | self._use_naive_recurrent_policy = args.use_naive_recurrent_policy
26 | self._use_recurrent_policy = args.use_recurrent_policy
27 | self._use_policy_vhead = args.use_policy_vhead
28 | self._recurrent_N = args.recurrent_N
29 | self._grid_goal = args.grid_goal
30 | self.tpdv = dict(dtype=torch.float32, device=device)
31 |
32 | obs_shape = get_shape_from_obs_space(obs_space)
33 |
34 | if 'Dict' in obs_shape.__class__.__name__:
35 | self._mixed_obs = True
36 | self.base = MIXBase(args, obs_shape, cnn_layers_params=args.cnn_layers_params)
37 | else:
38 | self._mixed_obs = False
39 | self.base = CNNBase(args, obs_shape) if len(obs_shape)==3 else MLPBase(args, obs_shape, use_attn_internal=args.use_attn_internal, use_cat_self=True)
40 |
41 | input_size = self.base.output_size
42 |
43 | if self._use_naive_recurrent_policy or self._use_recurrent_policy:
44 | self.rnn = RNNLayer(input_size, self.hidden_size, self._recurrent_N, self._use_orthogonal)
45 | input_size = self.hidden_size
46 |
47 | self.act = ACTLayer(action_space, input_size, self._use_orthogonal, self._gain, args=args)
48 |
49 | if self._use_policy_vhead:
50 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][self._use_orthogonal]
51 | def init_(m):
52 | return init(m, init_method, lambda x: nn.init.constant_(x, 0))
53 | self.v_out = init_(nn.Linear(input_size, 1))
54 |
55 | self.to(device)
56 |
57 | def forward(self, obs, rnn_states, masks, available_actions=None, deterministic=False):
58 | if self._mixed_obs:
59 | for key in obs.keys():
60 | obs[key] = check(obs[key]).to(**self.tpdv)
61 | else:
62 | obs = check(obs).to(**self.tpdv)
63 | rnn_states = check(rnn_states).to(**self.tpdv)
64 | masks = check(masks).to(**self.tpdv)
65 |
66 | if available_actions is not None:
67 | available_actions = check(available_actions).to(**self.tpdv)
68 |
69 | actor_features = self.base(obs)
70 |
71 | if self._use_naive_recurrent_policy or self._use_recurrent_policy:
72 | if self._grid_goal:
73 | actor_features[0], rnn_states = self.rnn(actor_features[0], rnn_states, masks)
74 | else:
75 | actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)
76 |
77 | actions, action_log_probs = self.act(actor_features, available_actions, deterministic)
78 |
79 | return actions, action_log_probs, rnn_states
80 |
81 | def evaluate_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None):
82 | if self._mixed_obs:
83 | for key in obs.keys():
84 | obs[key] = check(obs[key]).to(**self.tpdv)
85 | else:
86 | obs = check(obs).to(**self.tpdv)
87 |
88 | rnn_states = check(rnn_states).to(**self.tpdv)
89 | action = check(action).to(**self.tpdv)
90 | masks = check(masks).to(**self.tpdv)
91 |
92 | if available_actions is not None:
93 | available_actions = check(available_actions).to(**self.tpdv)
94 |
95 | if active_masks is not None:
96 | active_masks = check(active_masks).to(**self.tpdv)
97 |
98 | actor_features = self.base(obs)
99 |
100 | if self._use_naive_recurrent_policy or self._use_recurrent_policy:
101 | if self._grid_goal:
102 | actor_features[0], rnn_states = self.rnn(actor_features[0], rnn_states, masks)
103 | else:
104 | actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)
105 |
106 | action_log_probs, dist_entropy = self.act.evaluate_actions(actor_features, action, available_actions, active_masks = active_masks if self._use_policy_active_masks else None)
107 |
108 | values = self.v_out(actor_features) if self._use_policy_vhead else None
109 |
110 | return action_log_probs, dist_entropy, values
111 |
112 | def get_policy_values(self, obs, rnn_states, masks):
113 | if self._mixed_obs:
114 | for key in obs.keys():
115 | obs[key] = check(obs[key]).to(**self.tpdv)
116 | else:
117 | obs = check(obs).to(**self.tpdv)
118 | rnn_states = check(rnn_states).to(**self.tpdv)
119 | masks = check(masks).to(**self.tpdv)
120 |
121 | actor_features = self.base(obs)
122 |
123 | if self._use_naive_recurrent_policy or self._use_recurrent_policy:
124 | actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)
125 |
126 | values = self.v_out(actor_features)
127 |
128 | return values
129 |
130 | class R_Critic(nn.Module):
131 | def __init__(self, args, share_obs_space, device=torch.device("cpu")):
132 | super(R_Critic, self).__init__()
133 | self.hidden_size = args.hidden_size
134 | self._use_orthogonal = args.use_orthogonal
135 | self._activation_id = args.activation_id
136 | self._use_naive_recurrent_policy = args.use_naive_recurrent_policy
137 | self._use_recurrent_policy = args.use_recurrent_policy
138 | self._recurrent_N = args.recurrent_N
139 | self._grid_goal = args.grid_goal
140 | self.tpdv = dict(dtype=torch.float32, device=device)
141 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][self._use_orthogonal]
142 |
143 | share_obs_shape = get_shape_from_obs_space(share_obs_space)
144 |
145 | if 'Dict' in share_obs_shape.__class__.__name__:
146 | self._mixed_obs = True
147 | self.base = MIXBase(args, share_obs_shape, cnn_layers_params=args.cnn_layers_params)
148 | else:
149 | self._mixed_obs = False
150 | self.base = CNNBase(args, share_obs_shape) if len(share_obs_shape)==3 else MLPBase(args, share_obs_shape, use_attn_internal=True, use_cat_self=args.use_cat_self)
151 |
152 | input_size = self.base.output_size
153 |
154 | if self._use_naive_recurrent_policy or self._use_recurrent_policy:
155 | self.rnn = RNNLayer(input_size, self.hidden_size, self._recurrent_N, self._use_orthogonal)
156 | input_size = self.hidden_size
157 |
158 | def init_(m):
159 | return init(m, init_method, lambda x: nn.init.constant_(x, 0))
160 |
161 | self.v_out = init_(nn.Linear(input_size, 1))
162 |
163 | self.to(device)
164 |
165 | def forward(self, share_obs, rnn_states, masks):
166 | if self._mixed_obs:
167 | for key in share_obs.keys():
168 | share_obs[key] = check(share_obs[key]).to(**self.tpdv)
169 | else:
170 | share_obs = check(share_obs).to(**self.tpdv)
171 | rnn_states = check(rnn_states).to(**self.tpdv)
172 | masks = check(masks).to(**self.tpdv)
173 |
174 | critic_features = self.base(share_obs)
175 |
176 | if self._grid_goal:
177 | critic_features = critic_features[0]
178 |
179 | if self._use_naive_recurrent_policy or self._use_recurrent_policy:
180 | critic_features, rnn_states = self.rnn(critic_features, rnn_states, masks)
181 |
182 | values = self.v_out(critic_features)
183 |
184 | return values, rnn_states
185 |
--------------------------------------------------------------------------------
/onpolicy/algorithms/utils/cnn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from .util import init
8 |
9 | class Flatten(nn.Module):
10 | def forward(self, x):
11 | return x.view(x.size(0), -1)
12 |
13 | class CNNLayer(nn.Module):
14 | def __init__(self, obs_shape, hidden_size, use_orthogonal, activation_id, kernel_size=3, stride=1):
15 | super(CNNLayer, self).__init__()
16 |
17 | active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id]
18 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
19 | gain = nn.init.calculate_gain(['tanh', 'relu', 'leaky_relu', 'leaky_relu'][activation_id])
20 |
21 | def init_(m):
22 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain)
23 |
24 | input_channel = obs_shape[0]
25 | input_width = obs_shape[1]
26 | input_height = obs_shape[2]
27 |
28 | self.cnn = nn.Sequential(
29 | init_(nn.Conv2d(in_channels=input_channel, out_channels=hidden_size//2, kernel_size=kernel_size, stride=stride)), active_func,
30 | Flatten(),
31 | init_(nn.Linear(hidden_size//2 * (input_width-kernel_size+stride) * (input_height-kernel_size+stride), hidden_size)), active_func,
32 | init_(nn.Linear(hidden_size, hidden_size)), active_func)
33 |
34 | def forward(self, x):
35 | x = x / 255.0
36 | x = self.cnn(x)
37 |
38 | return x
39 |
40 | class CNNBase(nn.Module):
41 | def __init__(self, args, obs_shape):
42 | super(CNNBase, self).__init__()
43 |
44 | self._use_orthogonal = args.use_orthogonal
45 | self._activation_id = args.activation_id
46 | self.hidden_size = args.hidden_size
47 |
48 | self.cnn = CNNLayer(obs_shape, self.hidden_size, self._use_orthogonal, self._activation_id)
49 |
50 | def forward(self, x):
51 | x = self.cnn(x)
52 | return x
53 |
54 | @property
55 | def output_size(self):
56 | return self.hidden_size
57 |
--------------------------------------------------------------------------------
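A quick way to sanity-check the flattened size used in `CNNLayer`'s first `Linear` is to trace the shapes by hand: with the default `kernel_size=3` and `stride=1`, a single conv layer shrinks each spatial dimension by `kernel_size - stride`, so the flatten dimension is `hidden_size//2 * (W-kernel_size+stride) * (H-kernel_size+stride)`. Below is a minimal sketch, assuming the `onpolicy` package is importable; the observation shape, hidden size, and batch size are placeholder values, not values taken from this repository.

```python
import torch

from onpolicy.algorithms.utils.cnn import CNNLayer

# hypothetical 3-channel 7x7 observation; obs_shape is (channels, width, height)
obs_shape = (3, 7, 7)
hidden_size = 64

layer = CNNLayer(obs_shape, hidden_size, use_orthogonal=True, activation_id=1)

# conv output is (7 - 3 + 1) x (7 - 3 + 1) = 5 x 5, so the first Linear expects
# (hidden_size // 2) * 5 * 5 = 32 * 25 = 800 input features
x = torch.zeros(4, *obs_shape)   # a batch of 4 observations (values in [0, 255])
out = layer(x)                   # inputs are divided by 255 inside forward()
print(out.shape)                 # torch.Size([4, 64])
```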
/onpolicy/algorithms/utils/distributions.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from .util import init
8 |
9 | """
10 | Modify standard PyTorch distributions so they are compatible with this code.
11 | """
12 |
13 | #
14 | # Standardize distribution interfaces
15 | #
16 |
17 | # Categorical
18 | class FixedCategorical(torch.distributions.Categorical):
19 | def sample(self):
20 | return super().sample().unsqueeze(-1)
21 |
22 | def log_probs(self, actions):
23 | return (
24 | super()
25 | .log_prob(actions.squeeze(-1))
26 | .view(actions.size(0), -1)
27 | .sum(-1)
28 | .unsqueeze(-1)
29 | )
30 |
31 | def mode(self):
32 | return self.probs.argmax(dim=-1, keepdim=True)
33 |
34 |
35 | # Normal
36 | class FixedNormal(torch.distributions.Normal):
37 | def log_probs(self, actions):
38 | return super().log_prob(actions).sum(-1, keepdim=True)
39 |
40 | def entropy(self):
41 | return super().entropy().sum(-1)
42 |
43 | def mode(self):
44 | return self.mean
45 |
46 |
47 | # Bernoulli
48 | class FixedBernoulli(torch.distributions.Bernoulli):
49 | def log_probs(self, actions):
50 | return super().log_prob(actions).view(actions.size(0), -1).sum(-1).unsqueeze(-1)
51 |
52 | def entropy(self):
53 | return super().entropy().sum(-1)
54 |
55 | def mode(self):
56 | return torch.gt(self.probs, 0.5).float()
57 |
58 |
59 | class Categorical(nn.Module):
60 | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
61 | super(Categorical, self).__init__()
62 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
63 | def init_(m):
64 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)
65 |
66 | self.linear = init_(nn.Linear(num_inputs, num_outputs))
67 |
68 | def forward(self, x, available_actions=None, trans=True):
69 | if trans:
70 | x = self.linear(x)
71 | if available_actions is not None:
72 | x[available_actions == 0] = -1e10
73 | return FixedCategorical(logits=x)
74 |
75 |
76 | class DiagGaussian(nn.Module):
77 | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
78 | super(DiagGaussian, self).__init__()
79 |
80 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
81 | def init_(m):
82 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)
83 |
84 | self.fc_mean = init_(nn.Linear(num_inputs, num_outputs))
85 | self.logstd = AddBias(torch.zeros(num_outputs))
86 |
87 | def forward(self, x, trans=True):
88 | if trans:
89 | action_mean = self.fc_mean(x)
90 | else:
91 | action_mean = x
92 |
93 | # An ugly hack for my KFAC implementation.
94 | zeros = torch.zeros(action_mean.size())
95 | if x.is_cuda:
96 | zeros = zeros.cuda()
97 |
98 | action_logstd = self.logstd(zeros)
99 | return FixedNormal(action_mean, action_logstd.exp())
100 |
101 |
102 | class Bernoulli(nn.Module):
103 | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
104 | super(Bernoulli, self).__init__()
105 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
106 | def init_(m):
107 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)
108 |
109 | self.linear = init_(nn.Linear(num_inputs, num_outputs))
110 |
111 | def forward(self, x):
112 | x = self.linear(x)
113 | return FixedBernoulli(logits=x)
114 |
115 | class AddBias(nn.Module):
116 | def __init__(self, bias):
117 | super(AddBias, self).__init__()
118 | self._bias = nn.Parameter(bias.unsqueeze(1))
119 |
120 | def forward(self, x):
121 | if x.dim() == 2:
122 | bias = self._bias.t().view(1, -1)
123 | else:
124 | bias = self._bias.t().view(1, -1, 1, 1)
125 |
126 | return x + bias
127 |
--------------------------------------------------------------------------------
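For reference, here is a minimal sketch of how the `Categorical` head and its `FixedCategorical` wrapper behave with an action mask. It assumes the `onpolicy` package is importable; the feature size, action count, and batch size are placeholder values.

```python
import torch

from onpolicy.algorithms.utils.distributions import Categorical

# hypothetical discrete action head: 16-dim actor features, 5 actions, batch of 3
head = Categorical(num_inputs=16, num_outputs=5)
features = torch.randn(3, 16)

# available_actions == 0 marks an action as invalid; its logit is pushed to -1e10,
# so after the softmax its probability underflows to zero
available = torch.tensor([[1., 1., 1., 0., 0.]]).repeat(3, 1)
dist = head(features, available_actions=available)

actions = dist.sample()               # shape (3, 1): sample() appends a trailing dim
log_probs = dist.log_probs(actions)   # shape (3, 1), summed over the last axis
greedy = dist.mode()                  # argmax of the probabilities, shape (3, 1)
```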
/onpolicy/algorithms/utils/invariant.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from onpolicy.algorithms.utils.vit import ViT, Attention, PreNorm, Transformer, CrossAttention, FeedForward
6 | from einops.layers.torch import Rearrange
7 | from einops import rearrange, repeat
8 | import random
9 |
10 | def get_position_embedding(pos, hidden_dim, device = torch.device("cpu")):
11 | scaled_time = 2 * torch.arange(hidden_dim / 2) / hidden_dim
12 | scaled_time = 10000 ** scaled_time
13 | scaled_time = pos / scaled_time
14 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=0).to(device)
15 |
16 | def get_explicit_position_embedding(n_embed_input=10, n_embed_output=8 , device = torch.device("cpu")):
17 |
18 | return nn.Embedding(n_embed_input, n_embed_output).to(device)
19 |
20 | class AlterEncoder(nn.Module):
21 | def __init__(self, num_grids, input_dim, depth = 2, hidden_dim = 128, heads = 4, dim_head = 32, mlp_dim = 128, dropout = 0.):
22 | super().__init__()
23 | self.num_grids = num_grids
24 | self.hidden_dim = hidden_dim
25 | self.depth = depth
26 | self.encode_actor = nn.Linear(input_dim, hidden_dim)
27 | self.encode_other = nn.Linear(input_dim, hidden_dim)
28 | self.last_cross_attn = nn.ModuleList([
29 | nn.LayerNorm(hidden_dim),
30 | nn.Linear(hidden_dim, 2 * heads * dim_head, bias = False),
31 | CrossAttention(hidden_dim, heads = heads, dim_head = dim_head, dropout = dropout),
32 | PreNorm(hidden_dim, FeedForward(hidden_dim, mlp_dim, dropout = dropout))
33 | ])
34 |
35 | def forward(self, data):
36 | x, others = data
37 | B = x.shape[0]
38 | # print("alter_attn", x.shape)
39 | x = self.encode_actor(x)
40 | all = [x,]
41 | for i, y in enumerate(others):
42 | y = self.encode_other(y)
43 | all.append(y)
44 | num_agents = len(all)
45 | out = torch.stack(all, dim = 1) # B x num_agents x 64 x D
46 | out = rearrange(out, "b n g d -> (b g) n d", b = B, n = num_agents, g = self.num_grids)
47 | norm, to_kv, cross_attn, ff= self.last_cross_attn
48 | out = norm(out)
49 | x = out[:, :1, :] # 64B x 1 x D
50 | others = out[:, 1:, :] # 64B x (n-1) x D
51 | if num_agents > 1:
52 | k, v = to_kv(others).chunk(2, dim=-1)
53 | out = cross_attn(x, k, v) + x # # 64B x 1 x D
54 | else:
55 | out = x
56 | out = ff(out) + out
57 | out = rearrange(out, " (b g) n d -> n b g d", b = B, n = 1, g = self.num_grids)[0]
58 | return out
59 |
--------------------------------------------------------------------------------
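The helper `get_position_embedding` builds a standard sinusoidal embedding: the first `hidden_dim/2` entries are sines and the rest cosines of `pos / 10000^(2i/hidden_dim)`. A tiny sketch with placeholder values, assuming the `onpolicy` package is importable:

```python
from onpolicy.algorithms.utils.invariant import get_position_embedding

# embed scalar position 3 into an 8-dimensional vector
emb = get_position_embedding(3, hidden_dim=8)
print(emb.shape)   # torch.Size([8]); emb[:4] are sines, emb[4:] are cosines
```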
/onpolicy/algorithms/utils/mlp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from .util import init, get_clones
8 | from .attention import Encoder
9 |
10 | class MLPLayer(nn.Module):
11 | def __init__(self, input_dim, hidden_size, layer_N, use_orthogonal, activation_id):
12 | super(MLPLayer, self).__init__()
13 | self._layer_N = layer_N
14 |
15 | active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id]
16 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
17 | gain = nn.init.calculate_gain(['tanh', 'relu', 'leaky_relu', 'leaky_relu'][activation_id])
18 |
19 | def init_(m):
20 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain)
21 |
22 | self.fc1 = nn.Sequential(
23 | init_(nn.Linear(input_dim, hidden_size)), active_func, nn.LayerNorm(hidden_size))
24 | self.fc_h = nn.Sequential(init_(
25 | nn.Linear(hidden_size, hidden_size)), active_func, nn.LayerNorm(hidden_size))
26 | self.fc2 = get_clones(self.fc_h, self._layer_N)
27 |
28 | def forward(self, x):
29 | x = self.fc1(x)
30 | for i in range(self._layer_N):
31 | x = self.fc2[i](x)
32 | return x
33 |
34 | class CONVLayer(nn.Module):
35 | def __init__(self, input_dim, hidden_size, use_orthogonal, activation_id):
36 | super(CONVLayer, self).__init__()
37 |
38 | active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id]
39 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
40 | gain = nn.init.calculate_gain(['tanh', 'relu', 'leaky_relu', 'leaky_relu'][activation_id])
41 |
42 | def init_(m):
43 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain)
44 |
45 | self.conv = nn.Sequential(
46 | init_(nn.Conv1d(in_channels=input_dim, out_channels=hidden_size//4, kernel_size=3, stride=2, padding=0)), active_func, #nn.BatchNorm1d(hidden_size//4),
47 | init_(nn.Conv1d(in_channels=hidden_size//4, out_channels=hidden_size//2, kernel_size=3, stride=1, padding=1)), active_func, #nn.BatchNorm1d(hidden_size//2),
48 | init_(nn.Conv1d(in_channels=hidden_size//2, out_channels=hidden_size, kernel_size=3, stride=1, padding=1)), active_func) #, nn.BatchNorm1d(hidden_size))
49 |
50 | def forward(self, x):
51 | x = self.conv(x)
52 | return x
53 |
54 |
55 | class MLPBase(nn.Module):
56 | def __init__(self, args, obs_shape, use_attn_internal=False, use_cat_self=True):
57 | super(MLPBase, self).__init__()
58 |
59 | self._use_feature_normalization = args.use_feature_normalization
60 | self._use_orthogonal = args.use_orthogonal
61 | self._activation_id = args.activation_id
62 | self._use_attn = args.use_attn
63 | self._use_attn_internal = use_attn_internal
64 | self._use_average_pool = args.use_average_pool
65 | self._use_conv1d = args.use_conv1d
66 | self._stacked_frames = args.stacked_frames
67 | self._layer_N = 0 if args.use_single_network else args.layer_N
68 | self._attn_size = args.attn_size
69 | self.hidden_size = args.hidden_size
70 |
71 | obs_dim = obs_shape[0]
72 |
73 | if self._use_feature_normalization:
74 | self.feature_norm = nn.LayerNorm(obs_dim)
75 |
76 | if self._use_attn and self._use_attn_internal:
77 |
78 | if self._use_average_pool:
79 | if use_cat_self:
80 | inputs_dim = self._attn_size + obs_shape[-1][1]
81 | else:
82 | inputs_dim = self._attn_size
83 | else:
84 | split_inputs_dim = 0
85 | split_shape = obs_shape[1:]
86 | for i in range(len(split_shape)):
87 | split_inputs_dim += split_shape[i][0]
88 | inputs_dim = split_inputs_dim * self._attn_size
89 | self.attn = Encoder(args, obs_shape, use_cat_self)
90 | self.attn_norm = nn.LayerNorm(inputs_dim)
91 | else:
92 | inputs_dim = obs_dim
93 |
94 | if self._use_conv1d:
95 | self.conv = CONVLayer(self._stacked_frames, self.hidden_size, self._use_orthogonal, self._activation_id)
96 | random_x = torch.FloatTensor(1, self._stacked_frames, inputs_dim//self._stacked_frames)
97 | random_out = self.conv(random_x)
98 | assert len(random_out.shape)==3
99 | inputs_dim = random_out.size(-1) * random_out.size(-2)
100 |
101 | self.mlp = MLPLayer(inputs_dim, self.hidden_size,
102 | self._layer_N, self._use_orthogonal, self._activation_id)
103 |
104 | def forward(self, x):
105 | if self._use_feature_normalization:
106 | x = self.feature_norm(x)
107 |
108 | if self._use_attn and self._use_attn_internal:
109 | x = self.attn(x, self_idx=-1)
110 | x = self.attn_norm(x)
111 |
112 | if self._use_conv1d:
113 | batch_size = x.size(0)
114 | x = x.view(batch_size, self._stacked_frames, -1)
115 | x = self.conv(x)
116 | x = x.view(batch_size, -1)
117 |
118 | x = self.mlp(x)
119 |
120 | return x
121 |
122 | @property
123 | def output_size(self):
124 | return self.hidden_size
--------------------------------------------------------------------------------
/onpolicy/algorithms/utils/rnn.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | class RNNLayer(nn.Module):
9 | def __init__(self, inputs_dim, outputs_dim, recurrent_N, use_orthogonal):
10 | super(RNNLayer, self).__init__()
11 | self._recurrent_N = recurrent_N
12 | self._use_orthogonal = use_orthogonal
13 |
14 | self.rnn = nn.GRU(inputs_dim, outputs_dim, num_layers=self._recurrent_N)
15 | for name, param in self.rnn.named_parameters():
16 | if 'bias' in name:
17 | nn.init.constant_(param, 0)
18 | elif 'weight' in name:
19 | if self._use_orthogonal:
20 | nn.init.orthogonal_(param)
21 | else:
22 | nn.init.xavier_uniform_(param)
23 | self.norm = nn.LayerNorm(outputs_dim)
24 |
25 | def forward(self, x, hxs, masks):
26 | if x.size(0) == hxs.size(0):
27 | x, hxs = self.rnn(x.unsqueeze(0), (hxs * masks.repeat(1, self._recurrent_N).unsqueeze(-1)).transpose(0, 1).contiguous())
28 | #x= self.gru(x.unsqueeze(0))
29 | x = x.squeeze(0)
30 | hxs = hxs.transpose(0, 1)
31 | else:
32 | # x is a (T, N, -1) tensor that has been flattened to (T * N, -1)
33 | N = hxs.size(0)
34 | T = int(x.size(0) / N)
35 |
36 | # unflatten
37 | x = x.view(T, N, x.size(1))
38 |
39 | # Same deal with masks
40 | masks = masks.view(T, N)
41 |
42 | # Let's figure out which steps in the sequence have a zero for any agent
43 | # We will always assume t=0 has a zero in it as that makes the logic cleaner
44 | has_zeros = ((masks[1:] == 0.0)
45 | .any(dim=-1)
46 | .nonzero()
47 | .squeeze()
48 | .cpu())
49 |
50 | # +1 to correct the masks[1:]
51 | if has_zeros.dim() == 0:
52 | # Deal with scalar
53 | has_zeros = [has_zeros.item() + 1]
54 | else:
55 | has_zeros = (has_zeros + 1).numpy().tolist()
56 |
57 | # add t=0 and t=T to the list
58 | has_zeros = [0] + has_zeros + [T]
59 |
60 | hxs = hxs.transpose(0, 1)
61 |
62 | outputs = []
63 | for i in range(len(has_zeros) - 1):
64 | # We can now process steps that don't have any zeros in masks together!
65 | # This is much faster
66 | start_idx = has_zeros[i]
67 | end_idx = has_zeros[i + 1]
68 | temp = (hxs * masks[start_idx].view(1, -1, 1).repeat(self._recurrent_N, 1, 1)).contiguous()
69 | rnn_scores, hxs = self.rnn(x[start_idx:end_idx], temp)
70 | outputs.append(rnn_scores)
71 |
72 | # assert len(outputs) == T
73 | # x is a (T, N, -1) tensor
74 | x = torch.cat(outputs, dim=0)
75 |
76 | # flatten
77 | x = x.reshape(T * N, -1)
78 | hxs = hxs.transpose(0, 1)
79 |
80 | x = self.norm(x)
81 | return x, hxs
82 |
--------------------------------------------------------------------------------
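`RNNLayer.forward` supports two calling conventions: a single-step rollout call where `x` and `hxs` share their first dimension, and a training call where a length-`T` sequence has been flattened to `(T * N, -1)` and the masks are used to cut the sequence at episode boundaries so hidden state does not leak across resets. A minimal sketch with placeholder sizes, assuming the `onpolicy` package is importable:

```python
import torch

from onpolicy.algorithms.utils.rnn import RNNLayer

# placeholder sizes: 8 parallel threads, 32-dim features, 64-dim hidden, 1 GRU layer
N, D, H, rN = 8, 32, 64, 1
rnn = RNNLayer(inputs_dim=D, outputs_dim=H, recurrent_N=rN, use_orthogonal=True)

# (1) rollout: one step per thread, x.size(0) == hxs.size(0)
x = torch.randn(N, D)
hxs = torch.zeros(N, rN, H)      # per-thread hidden states
masks = torch.ones(N, 1)         # 0 marks an episode reset, 1 continues the episode
out, hxs = rnn(x, hxs, masks)    # out: (N, H), hxs: (N, rN, H)

# (2) training: a length-T sequence flattened to (T * N, D); forward() splits the
# sequence at every step whose mask contains a zero and runs the GRU in chunks
T = 5
seq = torch.randn(T * N, D)
seq_masks = torch.ones(T * N, 1)
seq_masks[2 * N:3 * N] = 0.0     # pretend every thread was reset at t = 2
out_seq, hxs2 = rnn(seq, torch.zeros(N, rN, H), seq_masks)   # out_seq: (T * N, H)
```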
/onpolicy/algorithms/utils/util.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import copy
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | def init(module, weight_init, bias_init, gain=1):
10 | weight_init(module.weight.data, gain=gain)
11 | bias_init(module.bias.data)
12 | return module
13 |
14 | def get_clones(module, N):
15 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
16 |
17 | def check(input):
18 | output = torch.from_numpy(input) if type(input) == np.ndarray else input
19 | return output
20 |
--------------------------------------------------------------------------------
/onpolicy/algorithms/utils/vit.py:
--------------------------------------------------------------------------------
1 | from os import XATTR_SIZE_MAX
2 | import torch
3 | from torch import nn, einsum
4 |
5 | from einops import rearrange, repeat
6 | from einops.layers.torch import Rearrange
7 |
8 | # https://github.com/lucidrains/vit-pytorch
9 |
10 | # helpers
11 |
12 | def pair(t):
13 | return t if isinstance(t, tuple) else (t, t)
14 |
15 | # classes
16 |
17 | class PreNorm(nn.Module):
18 | def __init__(self, dim, fn):
19 | super().__init__()
20 | self.norm = nn.LayerNorm(dim)
21 | self.fn = fn
22 | def forward(self, x, **kwargs):
23 | return self.fn(self.norm(x), **kwargs)
24 |
25 | class FeedForward(nn.Module):
26 | def __init__(self, dim, hidden_dim, dropout = 0.):
27 | super().__init__()
28 | self.net = nn.Sequential(
29 | nn.Linear(dim, hidden_dim),
30 | nn.GELU(),
31 | nn.Dropout(dropout),
32 | nn.Linear(hidden_dim, dim),
33 | nn.Dropout(dropout)
34 | )
35 | def forward(self, x):
36 | return self.net(x)
37 |
38 | class Attention(nn.Module):
39 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
40 | super().__init__()
41 | inner_dim = dim_head * heads
42 | project_out = not (heads == 1 and dim_head == dim)
43 |
44 | self.heads = heads
45 | self.scale = dim_head ** -0.5
46 |
47 | self.attend = nn.Softmax(dim = -1)
48 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
49 |
50 | self.to_out = nn.Sequential(
51 | nn.Linear(inner_dim, dim),
52 | nn.Dropout(dropout)
53 | ) if project_out else nn.Identity()
54 |
55 | def forward(self, x):
56 | b, n, _, h = *x.shape, self.heads
57 | qkv = self.to_qkv(x).chunk(3, dim = -1)
58 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
59 |
60 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
61 |
62 | attn = self.attend(dots)
63 |
64 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
65 | out = rearrange(out, 'b h n d -> b n (h d)')
66 | return self.to_out(out)
67 |
68 | class CrossAttention(nn.Module):
69 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
70 | super().__init__()
71 |
72 | inner_dim = dim_head * heads
73 | project_out = not (heads == 1 and dim_head == dim)
74 |
75 | self.heads = heads
76 | self.scale = dim_head ** -0.5
77 |
78 | self.attend = nn.Softmax(dim = -1)
79 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
80 |
81 | self.to_out = nn.Sequential(
82 | nn.Linear(inner_dim, dim),
83 | nn.Dropout(dropout)
84 | ) if project_out else nn.Identity()
85 |
86 | def forward(self, x, k, v):
87 | b, n, _, h = *x.shape, self.heads
88 | q = self.to_q(x)
89 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), [q, k, v])
90 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
91 | attn = self.attend(dots)
92 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
93 | out = rearrange(out, 'b h n d -> b n (h d)')
94 | return self.to_out(out)
95 |
96 | class Transformer(nn.Module):
97 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
98 | super().__init__()
99 | self.layers = nn.ModuleList([])
100 | for _ in range(depth):
101 | self.layers.append(nn.ModuleList([
102 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
103 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
104 | ]))
105 | def forward(self, x):
106 | for attn, ff in self.layers:
107 | x = attn(x) + x
108 | x = ff(x) + x
109 | return x
110 |
111 | class ViT(nn.Module):
112 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
113 | super().__init__()
114 | image_height, image_width = pair(image_size)
115 | patch_height, patch_width = pair(patch_size)
116 | self.r, self.c = image_height // patch_height, image_width // patch_width
117 |
118 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
119 |
120 | num_patches = (image_height // patch_height) * (image_width // patch_width)
121 | patch_dim = channels * patch_height * patch_width
122 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
123 |
124 | self.to_patch_embedding = nn.Sequential(
125 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
126 | nn.Linear(patch_dim, dim),
127 | )
128 |
129 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
130 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
131 | self.dropout = nn.Dropout(emb_dropout)
132 |
133 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
134 |
135 | self.pool = pool
136 | self.to_latent = nn.Identity()
137 |
138 | self.mlp_head = nn.Sequential(
139 | nn.LayerNorm(dim),
140 | nn.Linear(dim, num_classes)
141 | )
142 |
143 | self.recover = Rearrange('b (h w) c -> b c h w', h = self.r, w = self.c)
144 |
145 | def forward(self, img):
146 | x = self.to_patch_embedding(img)
147 | b, n, _ = x.shape
148 |
149 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
150 | x = torch.cat((cls_tokens, x), dim=1)
151 | x += self.pos_embedding[:, :(n + 1)]
152 | x = self.dropout(x)
153 |
154 | x = self.transformer(x)
155 |
156 | # x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
157 |
158 | # x = self.to_latent(x)
159 | x = self.recover(x[:, 1:])
160 |
161 | return x
--------------------------------------------------------------------------------
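Note that this `ViT.forward` diverges from the upstream lucidrains classifier: instead of pooling the tokens and applying `mlp_head`, it drops the cls token and rearranges the patch tokens back into a `(batch, dim, H/patch, W/patch)` feature map via `self.recover`. A minimal sketch with placeholder sizes, assuming the `onpolicy` package and `einops` are importable:

```python
import torch

from onpolicy.algorithms.utils.vit import ViT

# placeholder configuration: an 8x8 "image" with 4 channels, split into 2x2 patches
vit = ViT(image_size=8, patch_size=2, num_classes=1, dim=64, depth=2,
          heads=4, mlp_dim=128, channels=4)

img = torch.randn(3, 4, 8, 8)    # (batch, channels, height, width)
out = vit(img)
print(out.shape)                 # torch.Size([3, 64, 4, 4]) -- a spatial feature map
```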
/onpolicy/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SPHINXPROJ = onpolicy
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/onpolicy/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 | set SPHINXPROJ=onpolicy
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
20 | echo.installed, then set the SPHINXBUILD environment variable to point
21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
22 | echo.may add the Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/onpolicy/docs/source/_templates/module.rst_t:
--------------------------------------------------------------------------------
1 | {%- if show_headings %}
2 | {{- [basename, "module"] | join(' ') | e | heading }}
3 |
4 | {% endif -%}
5 | .. automodule:: {{ qualname }}
6 | :members:
7 | :undoc-members:
--------------------------------------------------------------------------------
/onpolicy/docs/source/_templates/package.rst_t:
--------------------------------------------------------------------------------
1 | {%- macro automodule(modname, options) -%}
2 | .. automodule:: {{ modname }}
3 | {%- for option in options %}
4 | :{{ option }}:
5 | {%- endfor %}
6 | {%- endmacro %}
7 |
8 | {%- macro toctree(docnames) -%}
9 | .. toctree::
10 | :maxdepth: {{ maxdepth }}
11 | {% for docname in docnames %}
12 | {{ docname }}
13 | {%- endfor %}
14 | {%- endmacro %}
15 |
16 | {%- if is_namespace %}
17 | {{- [pkgname, "namespace"] | join(" ") | e | heading }}
18 | {% else %}
19 | {{- [pkgname, "package"] | join(" ") | e | heading }}
20 | {% endif %}
21 |
22 | {%- if modulefirst and not is_namespace %}
23 | {{ automodule(pkgname, automodule_options) }}
24 | {% endif %}
25 |
26 | {%- if subpackages %}
27 | {{ toctree(subpackages) }}
28 | {% endif %}
29 |
30 | {%- if submodules %}
31 | {% if separatemodules %}
32 | {{ toctree(submodules) }}
33 | {%- else %}
34 | {%- for submodule in submodules %}
35 | {% if show_headings %}
36 | {{- [submodule, "module"] | join(" ") | e | heading(2) }}
37 | {% endif %}
38 | {{ automodule(submodule, automodule_options) }}
39 | {% endfor %}
40 | {%- endif %}
41 | {% endif %}
42 |
43 | {%- if not modulefirst and not is_namespace %}
44 | Module contents
45 | ---------------
46 |
47 | {{ automodule(pkgname, automodule_options) }}
48 | {% endif %}
--------------------------------------------------------------------------------
/onpolicy/docs/source/_templates/toc.rst_t:
--------------------------------------------------------------------------------
1 | {{ header | heading }}
2 |
3 | .. toctree::
4 | :maxdepth: {{ maxdepth }}
5 | {% for docname in docnames %}
6 | {{ docname }}
7 | {%- endfor %}
--------------------------------------------------------------------------------
/onpolicy/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Configuration file for the Sphinx documentation builder.
4 | #
5 | # This file does only contain a selection of the most common options. For a
6 | # full list see the documentation:
7 | # http://www.sphinx-doc.org/en/master/config
8 |
9 | # -- Path setup --------------------------------------------------------------
10 |
11 | # If extensions (or modules to document with autodoc) are in another directory,
12 | # add these directories to sys.path here. If the directory is relative to the
13 | # documentation root, use os.path.abspath to make it absolute, like shown here.
14 | #
15 | import os
16 | import sys
17 | sys.path.insert(0, os.path.abspath("../.."))
18 |
19 | import sphinx_rtd_theme
20 | import sphinxcontrib
21 | from recommonmark.parser import CommonMarkParser
22 |
23 |
24 | # -- Project information -----------------------------------------------------
25 |
26 | project = 'onpolicy'
27 | copyright = '2020, Chao Yu'
28 | author = 'Chao Yu'
29 |
30 | # The short X.Y version
31 | version = ''
32 | # The full version, including alpha/beta/rc tags
33 | release = '0.1.0'
34 |
35 |
36 | # -- General configuration ---------------------------------------------------
37 |
38 | # If your documentation needs a minimal Sphinx version, state it here.
39 | #
40 | # needs_sphinx = '1.0'
41 |
42 | # Add any Sphinx extension module names here, as strings. They can be
43 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
44 | # ones.
45 | extensions = [
46 | "sphinx_rtd_theme", # Read The Docs theme
47 | 'sphinx.ext.autodoc', # Automatically extract docs from docstrings
48 | 'sphinx.ext.doctest',
49 | 'sphinx.ext.intersphinx',
50 | 'sphinx.ext.todo',
51 | 'sphinx.ext.coverage', # make coverage generates documentation coverage reports
52 | 'sphinx.ext.mathjax',
53 | 'sphinx.ext.ifconfig',
54 | 'sphinx.ext.viewcode', # link to sourcecode from docs
55 | 'sphinx.ext.githubpages',
56 | "sphinx.ext.napoleon", # support Numpy and Google doc style
57 | "sphinxcontrib.apidoc",
58 | ]
59 |
60 | # configuring automated generation of api documentation
61 | # See: https://github.com/sphinx-contrib/apidoc
62 |
63 | apidoc_module_dir = "../.."
64 | apidoc_excluded_paths = ["envs", "setup.py"]
65 | apidoc_module_first = True
66 | apidoc_extra_args = [
67 | "--force",
68 | "--separate",
69 | "--ext-viewcode",
70 | "--doc-project=onpolicy",
71 | "--maxdepth=2",
72 | "--templatedir=_templates/apidoc",
73 | ]
74 |
75 |
76 | # Add any paths that contain templates here, relative to this directory.
77 | templates_path = ['_templates']
78 |
79 | # The suffix(es) of source filenames.
80 | # You can specify multiple suffix as a list of string:
81 | #
82 | # source_suffix = ['.rst', '.md']
83 | source_suffix = ['.rst', '.md', '.MD']
84 |
85 |
86 | # The master toctree document.
87 | master_doc = 'index'
88 |
89 | # The language for content autogenerated by Sphinx. Refer to documentation
90 | # for a list of supported languages.
91 | #
92 | # This is also used if you do content translation via gettext catalogs.
93 | # Usually you set "language" from the command line for these cases.
94 | language = None
95 |
96 | # List of patterns, relative to source directory, that match files and
97 | # directories to ignore when looking for source files.
98 | # This pattern also affects html_static_path and html_extra_path .
99 | exclude_patterns = []
100 |
101 | # The name of the Pygments (syntax highlighting) style to use.
102 | pygments_style = 'sphinx'
103 |
104 |
105 | # -- Options for HTML output -------------------------------------------------
106 |
107 | # The theme to use for HTML and HTML Help pages. See the documentation for
108 | # a list of builtin themes.
109 | #
110 | html_theme = 'sphinx_rtd_theme'
111 |
112 | # Theme options are theme-specific and customize the look and feel of a theme
113 | # further. For a list of options available for each theme, see the
114 | # documentation.
115 | #
116 | # html_theme_options = {}
117 |
118 | # Add any paths that contain custom static files (such as style sheets) here,
119 | # relative to this directory. They are copied after the builtin static files,
120 | # so a file named "default.css" will overwrite the builtin "default.css".
121 | html_static_path = ['_static']
122 |
123 | # Custom sidebar templates, must be a dictionary that maps document names
124 | # to template names.
125 | #
126 | # The default sidebars (for documents that don't match any pattern) are
127 | # defined by theme itself. Builtin themes are using these templates by
128 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
129 | # 'searchbox.html']``.
130 | #
131 | # html_sidebars = {}
132 |
133 |
134 | # -- Options for HTMLHelp output ---------------------------------------------
135 |
136 | # Output file base name for HTML help builder.
137 | htmlhelp_basename = 'onpolicydoc'
138 |
139 |
140 | # -- Options for LaTeX output ------------------------------------------------
141 |
142 | latex_elements = {
143 | # The paper size ('letterpaper' or 'a4paper').
144 | #
145 | # 'papersize': 'letterpaper',
146 |
147 | # The font size ('10pt', '11pt' or '12pt').
148 | #
149 | # 'pointsize': '10pt',
150 |
151 | # Additional stuff for the LaTeX preamble.
152 | #
153 | # 'preamble': '',
154 |
155 | # Latex figure (float) alignment
156 | #
157 | # 'figure_align': 'htbp',
158 | }
159 |
160 | # Grouping the document tree into LaTeX files. List of tuples
161 | # (source start file, target name, title,
162 | # author, documentclass [howto, manual, or own class]).
163 | latex_documents = [
164 | (master_doc, 'onpolicy.tex', 'onpolicy Documentation',
165 | 'Chao Yu', 'manual'),
166 | ]
167 |
168 |
169 | # -- Options for manual page output ------------------------------------------
170 |
171 | # One entry per manual page. List of tuples
172 | # (source start file, name, description, authors, manual section).
173 | man_pages = [
174 | (master_doc, 'onpolicy', 'onpolicy Documentation',
175 | [author], 1)
176 | ]
177 |
178 |
179 | # -- Options for Texinfo output ----------------------------------------------
180 |
181 | # Grouping the document tree into Texinfo files. List of tuples
182 | # (source start file, target name, title, author,
183 | # dir menu entry, description, category)
184 | texinfo_documents = [
185 | (master_doc, 'onpolicy', 'onpolicy Documentation',
186 | author, 'onpolicy', 'One line description of project.',
187 | 'Miscellaneous'),
188 | ]
189 |
190 |
191 | # -- Extension configuration -------------------------------------------------
192 |
193 | # -- Options for intersphinx extension ---------------------------------------
194 |
195 | # Example configuration for intersphinx: refer to the Python standard library.
196 | intersphinx_mapping = {'https://docs.python.org/': None}
197 |
198 | # -- Options for todo extension ----------------------------------------------
199 |
200 | # If true, `todo` and `todoList` produce output, else they produce nothing.
201 | todo_include_todos = True
202 | from recommonmark.parser import CommonMarkParser  # Markdown parser used below
203 | source_parsers = {
204 | '.md': CommonMarkParser,
205 | '.MD': CommonMarkParser,
206 | }
--------------------------------------------------------------------------------
/onpolicy/docs/source/full.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/docs/source/full.gif
--------------------------------------------------------------------------------
/onpolicy/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. onpolicy documentation master file, created by
2 | sphinx-quickstart on Sun Dec 6 23:54:05 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to onpolicy's documentation!
7 | ====================================
8 |
9 | .. image:: full.gif
10 |
11 | .. include:: setup.rst
12 |
13 | .. include:: quickstart.rst
14 |
15 | .. toctree::
16 | :maxdepth: 2
17 | :caption: API Reference:
18 |
19 | api/modules.rst
20 |
21 |
22 |
23 | Indices and tables
24 | ==================
25 |
26 | * :ref:`genindex`
27 | * :ref:`modindex`
28 | * :ref:`search`
29 |
--------------------------------------------------------------------------------
/onpolicy/docs/source/quickstart.rst:
--------------------------------------------------------------------------------
1 | Quickstart:
2 | =========================================
3 |
4 |
5 | Training the Highway Environment with MAPPO Algorithm
6 | --------------------------------------------------------
7 |
8 | A single command starts training agents in the Highway environment with the MAPPO algorithm:
9 | 1. change directory to `scripts/train/`
10 | 2. run the following command
11 |
12 | .. code-block:: bash
13 |
14 | ./train_highway.sh
15 |
16 |
17 |
18 | Hyperparameters
19 | --------------------
20 |
21 | Within the `train_highway.sh` file, `train_highway.py` under `scripts/train/` is called with some of the hyperparameters specified.
22 |
23 | .. code-block:: bash
24 |
25 | CUDA_VISIBLE_DEVICES=0 python train/train_highway.py --env_name ${env} --algorithm_name ${algo} --experiment_name ${exp} --scenario_name ${scenario} --task_type ${task} --n_attackers ${n_attackers} --n_defenders ${n_defenders} --n_dummies ${n_dummies} --seed ${seed} --n_training_threads 2 --n_rollout_threads 2 --n_eval_rollout_threads 2 --horizon 40 --episode_length 40 --log_interval 1 --use_wandb
26 |
27 |
28 | Hyperparameters come in two types:
29 |
30 | - common hyperparameters used in all environments, which are parsed in ./config.py
31 | - private hyperparameters used by a specific environment, which are parsed by that environment's training script under ./scripts/train/ (e.g., train_highway.py); see the sketch at the end of this section
32 |
33 | Take the highway environment as an example:
34 | - the common hyperparameters are the following:
35 |
36 | .. automodule:: config
37 | :members:
38 |
39 | - the private hyperparameters are the following:
40 |
41 | Take the highway environment as an example:
42 |
43 | .. autofunction:: scripts.train.train_highway.make_train_env
44 |
45 |
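As a rough sketch of how the two layers fit together (the names `get_config` and `parse_args` are assumptions used here for illustration, not taken from the documented code), the environment-specific script typically extends the common parser defined in `config.py`:

.. code-block:: python

    import sys
    from onpolicy.config import get_config  # assumed home of the common parser

    def parse_args(args, parser):
        # private hyperparameters for one particular environment (illustrative)
        parser.add_argument('--scenario_name', type=str, default='highway-v0')
        parser.add_argument('--n_defenders', type=int, default=2)
        return parser.parse_known_args(args)[0]

    parser = get_config()                    # common hyperparameters (config.py)
    all_args = parse_args(sys.argv[1:], parser)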
46 | Hierarchical structure
47 | ------------------------
--------------------------------------------------------------------------------
/onpolicy/docs/source/setup.rst:
--------------------------------------------------------------------------------
1 | Setup:
2 | =========================================
3 | .. code-block:: bash
4 |
5 | python3.7 -m venv ~/venv/python37_onpolicy
6 | source ~/venv/python37_onpolicy/bin/activate
7 | python3 -m pip install --upgrade pip
8 | pip install setuptools==51.0.0
9 | pip install wheel
10 | pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
11 | pip install wandb setproctitle absl-py gym matplotlib pandas pygame imageio tensorboardX numba
12 | # change to the root directory of the onpolicy library
13 | pip install -e .
14 |
15 |
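After the commands above complete, a quick sanity check (purely illustrative) is to import the library and confirm that the expected CUDA build of PyTorch is active:

.. code-block:: python

    import torch
    import onpolicy  # should import cleanly after `pip install -e .`
    print(torch.__version__, torch.cuda.is_available())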
--------------------------------------------------------------------------------
/onpolicy/envs/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | import socket
3 | from absl import flags
4 | FLAGS = flags.FLAGS
5 | FLAGS(['train_sc.py'])
6 |
7 |
8 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - "3.5"
4 |
5 | # command to install dependencies
6 | install:
7 | - pip3 install -e .
8 |
9 | # command to run tests
10 | script: ./run_tests.py
11 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/GridWorld_Env.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from .gym_minigrid.envs.human import HumanEnv
3 | from onpolicy.envs.gridworld.gym_minigrid.register import register
4 | import numpy as np
5 | from icecream import ic
6 | from onpolicy.utils.multi_discrete import MultiDiscrete
7 |
8 | class GridWorldEnv(object):
9 | def __init__(self, args):
10 |
11 | self.num_agents = args.num_agents
12 | self.scenario_name = args.scenario_name
13 | self.use_random_pos = args.use_random_pos
14 | self.agent_pos = None if self.use_random_pos else args.agent_pos
15 | self.num_obstacles = args.num_obstacles
16 | self.use_single_reward = args.use_single_reward
17 | self.use_discrect = args.use_discrect
18 |
19 | register(
20 | id=self.scenario_name,
21 | grid_size=args.grid_size,
22 | max_steps=args.max_steps,
23 | local_step_num=args.local_step_num,
24 | agent_view_size=args.agent_view_size,
25 | num_agents=self.num_agents,
26 | num_obstacles=self.num_obstacles,
27 | agent_pos=self.agent_pos,
28 | use_merge_plan=args.use_merge_plan,
29 | use_merge=args.use_merge,
30 | use_constrict_map=args.use_constrict_map,
31 | use_fc_net=args.use_fc_net,
32 | use_agent_id=args.use_agent_id,
33 | use_stack=args.use_stack,
34 | use_orientation=args.use_orientation,
35 | use_same_location=args.use_same_location,
36 | use_complete_reward=args.use_complete_reward,
37 | use_agent_obstacle=args.use_agent_obstacle,
38 | use_multiroom=args.use_multiroom,
39 | use_irregular_room=args.use_irregular_room,
40 | use_time_penalty=args.use_time_penalty,
41 | use_overlap_penalty=args.use_overlap_penalty,
42 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:MultiExplorationEnv',
43 | astar_cost_mode=args.astar_cost_mode
44 | )
45 |
46 | self.env = gym.make(self.scenario_name)
47 | self.max_steps = self.env.max_steps
48 | # print("max step is {}".format(self.max_steps))
49 |
50 | self.observation_space = self.env.observation_space
51 | self.share_observation_space = self.env.observation_space
52 |
53 | if self.use_discrect:
54 | self.action_space = [
55 | MultiDiscrete([[0, args.grid_size - 1],[0, args.grid_size - 1]])
56 | for _ in range(self.num_agents)
57 | ]
58 | else:
59 | self.action_space = [
60 | gym.spaces.Box(low=0.0, high=1.0, shape=(2,), dtype=np.float32)
61 | for _ in range(self.num_agents)
62 | ]
63 |
64 | def seed(self, seed=None):
65 | if seed is None:
66 | self.env.seed(1)
67 | else:
68 | self.env.seed(seed)
69 |
70 | def reset(self, choose=True):
71 | if choose:
72 | obs, info = self.env.reset()
73 | else:
74 | obs = [
75 | {
76 | 'image': np.zeros((self.env.width, self.env.height, 3), dtype='uint8'),
77 | 'direction': 0,
78 | 'mission': " "
79 | } for agent_id in range(self.num_agents)
80 | ]
81 | info = {}
82 | return obs, info
83 |
84 | def step(self, actions):
85 |         if not np.all(actions == np.ones((self.num_agents, 1)).astype(int) * (-1.0)):
86 | obs, rewards, done, infos = self.env.step(actions)
87 | dones = np.array([done for agent_id in range(self.num_agents)])
88 | if self.use_single_reward:
89 | rewards = 0.3 * np.expand_dims(infos['agent_explored_reward'], axis=1) + 0.7 * np.expand_dims(
90 | np.array([infos['merge_explored_reward'] for _ in range(self.num_agents)]), axis=1)
91 | else:
92 | rewards = np.expand_dims(
93 | np.array([infos['merge_explored_reward'] for _ in range(self.num_agents)]), axis=1)
94 | else:
95 | obs = [
96 | {
97 | 'image': np.zeros((self.env.width, self.env.height, 3), dtype='uint8'),
98 | 'direction': 0,
99 | 'mission': " "
100 | } for agent_id in range(self.num_agents)
101 | ]
102 | rewards = np.zeros((self.num_agents, 1))
103 | dones = np.array([None for agent_id in range(self.num_agents)])
104 | infos = {}
105 |
106 | return obs, rewards, dones, infos
107 |
108 | def close(self):
109 | self.env.close()
110 |
111 | def get_short_term_action(self, input):
112 | outputs = self.env.get_short_term_action(input)
113 | return outputs
114 |
115 | def render(self, mode="human", short_goal_pos=None):
116 | if mode == "human":
117 | self.env.render(mode=mode, short_goal_pos=short_goal_pos)
118 | else:
119 | return self.env.render(mode=mode, short_goal_pos=short_goal_pos)
120 |
121 | def ft_get_short_term_goals(self, args, mode=""):
122 | mode_list = ['apf', 'utility', 'nearest', 'rrt', 'voronoi']
123 | assert mode in mode_list, (f"frontier global mode should be in {mode_list}")
124 | return self.env.ft_get_short_term_goals(args, mode=mode)
125 |
126 | def ft_get_short_term_actions(self, *args):
127 | return self.env.ft_get_short_term_actions(*args)
128 |
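A hypothetical usage sketch of GridWorldEnv (not part of the repo): the attribute names mirror the fields read in __init__ above, but the scenario id and every value are illustrative assumptions and may need adjusting to the registered MultiExploration environment.

```python
# Hypothetical sketch; all values and the scenario id are assumptions.
from types import SimpleNamespace
import numpy as np
from onpolicy.envs.gridworld.GridWorld_Env import GridWorldEnv

args = SimpleNamespace(
    num_agents=2, scenario_name="MiniGrid-MultiExploration-v0",  # assumed id
    use_random_pos=True, agent_pos=None, num_obstacles=0,
    use_single_reward=True, use_discrect=False,
    grid_size=25, max_steps=200, local_step_num=3, agent_view_size=7,
    use_merge_plan=True, use_merge=True, use_constrict_map=False,
    use_fc_net=False, use_agent_id=False, use_stack=False,
    use_orientation=False, use_same_location=False, use_complete_reward=True,
    use_agent_obstacle=False, use_multiroom=False, use_irregular_room=False,
    use_time_penalty=False, use_overlap_penalty=False, astar_cost_mode="normal",
)

env = GridWorldEnv(args)
env.seed(1)
obs, info = env.reset()
# with use_discrect=False each agent's action is a 2-D goal in [0, 1] (Box space)
actions = np.random.rand(args.num_agents, 2).astype(np.float32)
obs, rewards, dones, infos = env.step(actions)
env.close()
```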
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/benchmark.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import time
4 | import argparse
5 | import gym_minigrid
6 | import gym
7 | from gym_minigrid.wrappers import *
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument(
11 | "--env-name",
12 | dest="env_name",
13 | help="gym environment to load",
14 | default='MiniGrid-LavaGapS7-v0'
15 | )
16 | parser.add_argument("--num_resets", type=int, default=200)
17 | parser.add_argument("--num_frames", type=int, default=5000)
18 | args = parser.parse_args()
19 |
20 | env = gym.make(args.env_name)
21 |
22 | # Benchmark env.reset
23 | t0 = time.time()
24 | for i in range(args.num_resets):
25 | env.reset()
26 | t1 = time.time()
27 | dt = t1 - t0
28 | reset_time = (1000 * dt) / args.num_resets
29 |
30 | # Benchmark rendering
31 | t0 = time.time()
32 | for i in range(args.num_frames):
33 | env.render('rgb_array')
34 | t1 = time.time()
35 | dt = t1 - t0
36 | frames_per_sec = args.num_frames / dt
37 |
38 | # Create an environment with an RGB agent observation
39 | env = gym.make(args.env_name)
40 | env = RGBImgPartialObsWrapper(env)
41 | env = ImgObsWrapper(env)
42 |
43 | # Benchmark stepping with the agent's partial RGB observation
44 | t0 = time.time()
45 | for i in range(args.num_frames):
46 | obs, reward, done, info = env.step(0)
47 | t1 = time.time()
48 | dt = t1 - t0
49 | agent_view_fps = args.num_frames / dt
50 |
51 | print('Env reset time: {:.1f} ms'.format(reset_time))
52 | print('Rendering FPS : {:.0f}'.format(frames_per_sec))
53 | print('Agent view FPS: {:.0f}'.format(agent_view_fps))
54 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/BlockedUnlockPickup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/BlockedUnlockPickup.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/DistShift1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/DistShift1.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/DistShift2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/DistShift2.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/KeyCorridorS3R1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/KeyCorridorS3R1.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/KeyCorridorS3R2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/KeyCorridorS3R2.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/KeyCorridorS3R3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/KeyCorridorS3R3.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/KeyCorridorS4R3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/KeyCorridorS4R3.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/KeyCorridorS5R3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/KeyCorridorS5R3.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/KeyCorridorS6R3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/KeyCorridorS6R3.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/LavaCrossingS11N5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/LavaCrossingS11N5.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/LavaCrossingS9N1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/LavaCrossingS9N1.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/LavaCrossingS9N2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/LavaCrossingS9N2.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/LavaCrossingS9N3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/LavaCrossingS9N3.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/LavaGapS6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/LavaGapS6.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/ObstructedMaze-1Dl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-1Dl.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/ObstructedMaze-1Dlh.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-1Dlh.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/ObstructedMaze-1Dlhb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-1Dlhb.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/ObstructedMaze-1Q.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-1Q.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/ObstructedMaze-2Dl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-2Dl.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/ObstructedMaze-2Dlh.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-2Dlh.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/ObstructedMaze-2Dlhb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-2Dlhb.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/ObstructedMaze-2Q.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-2Q.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/ObstructedMaze-4Q.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-4Q.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/SimpleCrossingS11N5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/SimpleCrossingS11N5.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/SimpleCrossingS9N1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/SimpleCrossingS9N1.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/SimpleCrossingS9N2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/SimpleCrossingS9N2.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/SimpleCrossingS9N3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/SimpleCrossingS9N3.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/Unlock.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/Unlock.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/UnlockPickup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/UnlockPickup.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/door-key-curriculum.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/door-key-curriculum.gif
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/door-key-env.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/door-key-env.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/dynamic_obstacles.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/dynamic_obstacles.gif
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/empty-env.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/empty-env.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/fetch-env.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/fetch-env.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/four-rooms-env.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/four-rooms-env.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/gotodoor-6x6.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/gotodoor-6x6.mp4
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/gotodoor-6x6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/gotodoor-6x6.png
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/figures/multi-room.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/multi-room.gif
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/frontier/apf.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from queue import deque
3 | import pyastar2d
4 |
5 | # class of APF(Artificial Potential Field)
6 |
7 |
8 | class APF(object):
9 | def __init__(self, args):
10 | self.args = args
11 |
12 | self.cluster_radius = args.apf_cluster_radius
13 | self.k_attract = args.apf_k_attract
14 | self.k_agents = args.apf_k_agents
15 | self.AGENT_INFERENCE_RADIUS = args.apf_AGENT_INFERENCE_RADIUS
16 | self.num_iters = args.apf_num_iters
17 | self.repeat_penalty = args.apf_repeat_penalty
18 | self.dis_type = args.apf_dis_type
19 |
20 | self.num_agents = args.num_agents
21 |
22 | def distance(self, a, b):
23 | a = np.array(a)
24 | b = np.array(b)
25 | if self.dis_type == "l2":
26 | return np.sqrt(((a-b)**2).sum())
27 | elif self.dis_type == "l1":
28 | return abs(a-b).sum()
29 |
30 | def schedule(self, map, locations, steps, agent_id, penalty, full_path=True):
31 | '''
32 | APF to schedule path for agent agent_id
33 | map: H x W
34 | - 0 for explored & available cell
35 | - 1 for obstacle
36 | - 2 for target (frontier)
37 | locations: num_agents x 2
38 | steps: available actions
39 | penalty: repeat penalty
40 | full_path: default True, False for single step (i.e., next cell)
41 | '''
42 | H, W = map.shape
43 |
44 | # find available targets
45 | vis = np.zeros((H, W), dtype=np.uint8)
46 | que = deque([])
47 | x, y = locations[agent_id]
48 | vis[x, y] = 1
49 | que.append((x, y))
50 | while len(que) > 0:
51 | x, y = que.popleft()
52 | for dx, dy in steps:
53 | x1 = x + dx
54 | y1 = y + dy
55 | if vis[x1, y1] == 0 and map[x1, y1] in [0, 2]:
56 | vis[x1, y1] = 1
57 | que.append((x1, y1))
58 |
59 | targets = []
60 | for i in range(H):
61 | for j in range(W):
62 | if map[i, j] == 2 and vis[i, j] == 1:
63 | targets.append((i, j))
64 | # clustering
65 | clusters = []
66 | num_targets = len(targets)
67 | valid = [True for _ in range(num_targets)]
68 | for i in range(num_targets):
69 | if valid[i]:
70 | # not clustered
71 | chosen_targets = []
72 | for j in range(num_targets):
73 | if valid[j] and self.distance(targets[i], targets[j]) <= self.cluster_radius:
74 | valid[j] = False
75 | chosen_targets.append(targets[j])
76 | min_r = 1e6
77 | center = None
78 | for a in chosen_targets:
79 | max_d = max([self.distance(a, b) for b in chosen_targets])
80 | if max_d < min_r:
81 | min_r = max_d
82 | center = a
83 | clusters.append({"center": center, "weight": len(chosen_targets)})
84 |
85 | # potential
86 | num_clusters = len(clusters)
87 | potential = np.zeros((H, W))
88 | potential[map == 1] = 1e6
89 |
90 | # potential of targets & obstacles (wave-front dist)
91 | for cluster in clusters:
92 | sx, sy = cluster["center"]
93 | w = cluster["weight"]
94 | dis = np.ones((H, W), dtype=np.int64) * 1e6
95 | dis[sx, sy] = 0
96 | que = deque([(sx, sy)])
97 | while len(que) > 0:
98 | (x, y) = que.popleft()
99 | for dx, dy in steps:
100 | x1 = x + dx
101 | y1 = y + dy
102 | if dis[x1, y1] == 1e6 and map[x1, y1] in [0, 2]:
103 | dis[x1, y1] = dis[x, y]+1
104 | que.append((x1, y1))
105 | dis[sx, sy] = 1e6
106 | dis = 1 / dis
107 | dis[sx, sy] = 0
108 | potential[map != 1] -= dis[map != 1] * self.k_attract * w
109 |
110 | # potential of agents
111 | for x in range(H):
112 | for y in range(W):
113 | for agent_loc in locations:
114 | d = self.distance(agent_loc, (x, y))
115 | if d <= self.AGENT_INFERENCE_RADIUS:
116 | potential[x, y] += self.k_agents * (self.AGENT_INFERENCE_RADIUS - d)
117 |
118 | # repeat penalty
119 | potential += penalty
120 |
121 | # schedule path
122 | it = 1
123 | current_loc = locations[agent_id]
124 | current_potential = 1e4
125 | minDis2Target = 1e6
126 | path = [(current_loc[0], current_loc[1])]
127 | while it <= self.num_iters and minDis2Target > 1:
128 | it = it + 1
129 | potential[current_loc[0], current_loc[1]] += self.repeat_penalty
130 | best_neigh = None
131 | min_potential = 1e6
132 | for dx, dy in steps:
133 | neighbor_loc = (current_loc[0] + dx, current_loc[1] + dy)
134 | if map[neighbor_loc[0], neighbor_loc[1]] == 1:
135 | continue
136 | if min_potential > potential[neighbor_loc[0], neighbor_loc[1]]:
137 | min_potential = potential[neighbor_loc[0], neighbor_loc[1]]
138 | best_neigh = neighbor_loc
139 | if current_potential > min_potential:
140 | current_potential = min_potential
141 | current_loc = best_neigh
142 | path.append(best_neigh)
143 | for tar in targets:
144 | l = self.distance(current_loc, tar)
145 | if l == 0:
146 | continue
147 | minDis2Target = min(minDis2Target, l)
148 | if l <= 1:
149 | path.append((tar[0], tar[1]))
150 | break
151 | if not full_path and len(path) > 1:
152 | return path[1] # next grid
153 | random_plan = False
154 | if minDis2Target > 1:
155 | random_plan = True
156 | for i in range(agent_id):
157 | if locations[i][0] == locations[agent_id][0] and locations[i][1] == locations[agent_id][1]:
158 | random_plan = True # two agents are at the same location, replan
159 | if random_plan:
160 |             # if no frontier was reached, randomly pick a target as the goal
161 | if num_targets == 0:
162 | targets = [(np.random.randint(0, H), np.random.randint(0, W))]
163 | num_targets = 1
164 | w = np.random.randint(0, num_targets)
165 | goal = targets[w]
166 |             temp_map = np.ones((H, W), dtype=np.float32)
167 |             temp_map[map == 1] = 1000000  # obstacles get a prohibitive traversal cost
168 |             temp_map[map == 2] = 1.0
169 | path = pyastar2d.astar_path(temp_map, locations[agent_id], goal, allow_diagonal=False)
170 | if len(path) == 1:
171 | path = (locations[agent_id], goal)
172 | return path
173 | return path
174 |
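A minimal sketch of driving APF.schedule on a toy map, assuming pyastar2d and the onpolicy package are importable; the apf_* fields mirror the attributes read in __init__ above, and all values are illustrative.

```python
from types import SimpleNamespace
import numpy as np
from onpolicy.envs.gridworld.frontier.apf import APF

args = SimpleNamespace(
    apf_cluster_radius=2.0, apf_k_attract=1.0, apf_k_agents=1.0,
    apf_AGENT_INFERENCE_RADIUS=5.0, apf_num_iters=500,
    apf_repeat_penalty=1.0, apf_dis_type="l1", num_agents=2,
)

# map encoding from the docstring: 0 explored free cell, 1 obstacle, 2 frontier
grid = np.ones((10, 10), dtype=np.int32)    # start with a border of obstacles
grid[1:-1, 1:-1] = 0
grid[5, 5] = 2                              # a single frontier cell
locations = [(1, 1), (8, 8)]
steps = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # 4-connected moves
penalty = np.zeros((10, 10), dtype=np.float32)

apf = APF(args)
path = apf.schedule(grid, locations, steps, agent_id=0, penalty=penalty)
print(path)  # list of cells leading from agent 0 towards the frontier
```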
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/frontier/nearest.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .utils import *
3 | import random
4 |
5 |
6 | def nearest_goal(map, loc, steps):
7 | dis, vis = bfs(map, loc, steps)
8 | H, W = map.shape
9 | frontiers = []
10 | for x in range(H):
11 | for y in range(W):
12 | if map[x, y] == 2 and vis[x, y]:
13 | frontiers.append((x, y))
14 | if len(frontiers) == 0:
15 | goal = random.randint(0, H-1), random.randint(0, W-1)
16 | return goal
17 | dist = [dis[x, y] for x, y in frontiers]
18 | mi = min(dist)
19 | candidates = [(x, y) for i, (x, y) in enumerate(frontiers) if dist[i] == mi]
20 | goal = random.choice(candidates)
21 | return goal
22 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/frontier/rrt.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.frontier.utils import generate_square
2 | from onpolicy.envs.gridworld.frontier.utils import find_rectangle_obstacles
3 | from onpolicy.utils.RRT.rrt import RRT
4 | import numpy as np
5 | from .utils import *
6 | import random
7 |
8 |
9 | class rrt_config(object):
10 | cluster_radius = 0
11 | rrt_max_iter = 10000
12 | rrt_num_targets = 500
13 | utility_edge_len = 7
14 |
15 |
16 | def rrt_goal(map, unexplored, loc):
17 | H, W = map.shape
18 | map = map.astype(np.int32)
19 |
20 | obstacles = find_rectangle_obstacles(map)
21 |
22 | rrt = RRT(start=(loc[0] + 0.5, loc[1] + 0.5),
23 | goals=[],
24 | rand_area=((0, H), (0, W)),
25 | obstacle_list=obstacles,
26 | expand_dis=0.5,
27 | goal_sample_rate=-1,
28 | max_iter=rrt_config.rrt_max_iter) # maybe more iterations?
29 |
30 | rrt_map = unexplored.copy().astype(np.int32)
31 | targets = rrt.select_frontiers(rrt_map, num_targets=rrt_config.rrt_num_targets)
32 | # print("targets: ", len(targets))
33 |
34 | clusters = get_frontier_cluster(targets, cluster_radius=rrt_config.cluster_radius)
35 | # print("clusters: ",len(clusters))
36 |
37 | if len(clusters) == 0:
38 | x, y = random.randint(0, H-1), random.randint(0, W-1)
39 | while map[x, y] == 1:
40 | x, y = random.randint(0, H-1), random.randint(0, W-1)
41 | goal = (x, y)
42 | return goal
43 | for cluster in clusters:
44 | center = cluster['center']
45 | # navigation cost
46 | nav_cost = l1distance(center, loc)
47 | # information gain
48 | mat = generate_square(H, W, center, rrt_config.utility_edge_len)
49 | area = mat.sum()
50 | info_gain = rrt_map[mat == 1].sum()
51 | info_gain /= area
52 | cluster['info_gain'] = info_gain
53 | cluster['nav_cost'] = nav_cost
54 | D = max([cluster['nav_cost'] for cluster in clusters])
55 | goal = None
56 | mx = -1e9
57 | for cluster in clusters:
58 | cluster['nav_cost'] /= D
59 | cluster['utility'] = cluster['info_gain'] - 1.0 * cluster['nav_cost']
60 | if mx < cluster['utility']:
61 | mx = cluster['utility']
62 | goal = cluster['center']
63 |
64 | return goal
65 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/frontier/utility.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.frontier.utils import generate_square
2 | import numpy as np
3 | from .utils import *
4 | import random
5 |
6 |
7 | def utility_goal(map, unexplored, loc, steps, edge_len=7):
8 | _, vis = bfs(map, loc, steps)
9 | H, W = map.shape
10 | utility = np.zeros((H, W), dtype=np.int32)
11 | frontiers = []
12 | for x in range(H):
13 | for y in range(W):
14 | if map[x, y] == 2 and vis[x, y]:
15 | mat = generate_square(H, W, (x, y), edge_len)
16 | utility[x, y] = unexplored[mat == 1].sum()
17 | frontiers.append((x, y))
18 | mx = utility.max()
19 | value = [utility[x, y] for x, y in frontiers]
20 | candidates = [(x, y) for i, (x, y) in enumerate(frontiers) if value[i] == mx]
21 | if len(candidates) > 0:
22 | goal = random.choice(candidates)
23 | else:
24 | goal = random.randint(0, H-1), random.randint(0, W-1)
25 | return goal
26 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/frontier/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from queue import deque
3 | import math
4 |
5 |
6 | def l1distance(x, y):
7 | return abs(x[0] - y[0]) + abs(x[1] - y[1])
8 |
9 |
10 | def l2distance(x, y):
11 | return math.hypot(x[0]-y[0], x[1]-y[1])
12 |
13 |
14 | def bfs(map, start, steps):
15 | sx, sy = start[0], start[1]
16 | assert map[sx, sy] != 1, ('start position should not be obstacle')
17 | que = deque([(sx, sy)])
18 | H, W = map.shape
19 | dis = np.ones((H, W), dtype=np.int32) * 1e9
20 | dis[sx, sy] = 0
21 | while len(que) > 0:
22 | x, y = que.popleft()
23 | neigh = [(x+dx, y+dy) for dx, dy in steps]
24 |         neigh = [(x, y) for x, y in neigh
25 |                  if 0 <= x < H and 0 <= y < W and map[x, y] != 1 and dis[x, y] == 1e9]
26 | for u, v in neigh:
27 | dis[u, v] = dis[x, y] + 1
28 | que.append((u, v))
29 | vis = (dis < 1e9)
30 | return dis, vis
31 |
32 |
33 | def generate_square(H, W, loc, d):
34 | mat = np.zeros((H, W), dtype=np.int32)
35 | for x in range(H):
36 | for y in range(W):
37 | if max(abs(x-loc[0]), abs(y-loc[1])) <= d:
38 | mat[x, y] = 1
39 | return mat
40 |
41 |
42 | def get_frontier_cluster(frontiers, cluster_radius=5.0):
43 | num_frontier = len(frontiers)
44 | clusters = []
45 | valid = [True for _ in range(num_frontier)]
46 | for i in range(num_frontier):
47 | if valid[i]:
48 | neigh = []
49 | for j in range(num_frontier):
50 | if valid[j] and l2distance(frontiers[i], frontiers[j]) <= cluster_radius:
51 | valid[j] = False
52 | neigh.append(frontiers[j])
53 | center = None
54 | min_r = 1e9
55 | for p in neigh:
56 | r = max([l2distance(p, q) for q in neigh])
57 | if r < min_r:
58 | min_r = r
59 | center = p
60 | if len(neigh) >= 5:
61 | clusters.append({'center': center, 'weight': len(neigh)})
62 | return clusters
63 |
64 |
65 | def find_rectangle_obstacles(map):
66 | map = map.copy().astype(np.int32)
67 | map[map == 2] = 0
68 | H, W = map.shape
69 | obstacles = []
70 | covered = np.zeros((H, W), dtype=np.int32)
71 | pad = 0.01
72 | for x in range(H):
73 | for y in range(W):
74 | if map[x, y] == 1 and covered[x, y] == 0:
75 | x1 = x
76 | x2 = x
77 | while x2 < H-1 and map[x2 + 1, y] == 1:
78 | x2 = x2 + 1
79 | y1 = y
80 | y2 = y
81 | while y2 < W-1 and map[x1: x2+1, y2 + 1].sum() == x2-x1+1:
82 | y2 = y2 + 1
83 | covered[x1: x2 + 1, y1: y2 + 1] = 1
84 | obstacles.append((x1-pad, y1-pad, x2 + 1 + pad, y2 + 1 + pad))
85 | return obstacles
86 |
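A quick illustration of these helpers on a toy grid (a sketch using only the functions defined above):

```python
import numpy as np
from onpolicy.envs.gridworld.frontier.utils import (
    bfs, generate_square, find_rectangle_obstacles)

grid = np.zeros((6, 6), dtype=np.int32)
grid[2, 1:5] = 1                      # a short horizontal wall
grid[5, 5] = 2                        # a frontier cell
steps = [(-1, 0), (1, 0), (0, -1), (0, 1)]

dis, vis = bfs(grid, (0, 0), steps)   # wave-front distances from (0, 0)
print(dis[5, 5], vis[5, 5])           # distance to the frontier and reachability

print(generate_square(6, 6, (3, 3), 1).sum())  # 3x3 neighbourhood -> 9 cells
print(find_rectangle_obstacles(grid))          # one padded rectangle for the wall
```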
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/frontier/voronoi.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.frontier.utils import generate_square
2 | import numpy as np
3 | from .utils import *
4 | import random
5 |
6 |
7 | def voronoi_goal(map, unexplored, locs, agent_id, steps, edge_len=7):
8 | num_agents = len(locs)
9 |
10 | dis = []
11 | vis = []
12 | for a, loc in enumerate(locs):
13 | _dis, _vis = bfs(map, loc, steps)
14 | dis.append(_dis)
15 | vis.append(_vis)
16 |
17 | H, W = map.shape
18 | my_grids = np.ones((H, W), dtype=np.int32)
19 | for a in range(num_agents):
20 | my_grids[dis[agent_id] > dis[a]] = 0
21 |
22 | utility = np.zeros((H, W), dtype=np.int32)
23 | frontiers = []
24 | for x in range(H):
25 | for y in range(W):
26 | if map[x, y] == 2 and vis[agent_id][x, y] and my_grids[x, y]:
27 | mat = generate_square(H, W, (x, y), edge_len)
28 | utility[x, y] = unexplored[mat == 1].sum()
29 | frontiers.append((x, y))
30 | mx = utility.max()
31 | value = [utility[x, y] for x, y in frontiers]
32 | candidates = [(x, y) for i, (x, y) in enumerate(frontiers) if value[i] == mx]
33 | if len(candidates) > 0:
34 | goal = random.choice(candidates)
35 | else:
36 | goal = random.randint(0, H-1), random.randint(0, W-1)
37 | return goal
38 |
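For contrast, a small sketch calling the three frontier selectors defined in this package (nearest, utility, voronoi) on one toy map; the encoding is the usual 0 free / 1 obstacle / 2 frontier, and unexplored marks unknown cells.

```python
import numpy as np
from onpolicy.envs.gridworld.frontier.nearest import nearest_goal
from onpolicy.envs.gridworld.frontier.utility import utility_goal
from onpolicy.envs.gridworld.frontier.voronoi import voronoi_goal

H = W = 12
grid = np.zeros((H, W), dtype=np.int32)
grid[0, :] = grid[-1, :] = grid[:, 0] = grid[:, -1] = 1  # surrounding walls
grid[3, 3] = grid[8, 8] = 2                              # two frontier cells
unexplored = np.zeros((H, W), dtype=np.int32)
unexplored[6:, 6:] = 1                                   # unknown area near (8, 8)
steps = [(-1, 0), (1, 0), (0, -1), (0, 1)]
locs = [(1, 1), (10, 10)]

print(nearest_goal(grid, locs[0], steps))              # closest frontier to agent 0
print(utility_goal(grid, unexplored, locs[0], steps))  # most unknown cells nearby
print(voronoi_goal(grid, unexplored, locs, 0, steps))  # restricted to agent 0's cell
```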
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/__init__.py:
--------------------------------------------------------------------------------
1 | # Import the envs module so that envs register themselves
2 | import onpolicy.envs.gridworld.gym_minigrid.envs
3 |
4 | # Import wrappers so it's accessible when installing with pip
5 | import onpolicy.envs.gridworld.gym_minigrid.wrappers
6 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/__init__.py:
--------------------------------------------------------------------------------
1 | # from onpolicy.envs.gridworld.gym_minigrid.envs.empty import *
2 | # from onpolicy.envs.gridworld.gym_minigrid.envs.doorkey import *
3 | # from onpolicy.envs.gridworld.gym_minigrid.envs.multiroom import *
4 | # from onpolicy.envs.gridworld.gym_minigrid.envs.fetch import *
5 | # from onpolicy.envs.gridworld.gym_minigrid.envs.gotoobject import *
6 | # from onpolicy.envs.gridworld.gym_minigrid.envs.gotodoor import *
7 | # from onpolicy.envs.gridworld.gym_minigrid.envs.putnear import *
8 | # from onpolicy.envs.gridworld.gym_minigrid.envs.lockedroom import *
9 | # from onpolicy.envs.gridworld.gym_minigrid.envs.keycorridor import *
10 | # from onpolicy.envs.gridworld.gym_minigrid.envs.unlock import *
11 | # from onpolicy.envs.gridworld.gym_minigrid.envs.unlockpickup import *
12 | # from onpolicy.envs.gridworld.gym_minigrid.envs.blockedunlockpickup import *
13 | # from onpolicy.envs.gridworld.gym_minigrid.envs.playground_v0 import *
14 | # from onpolicy.envs.gridworld.gym_minigrid.envs.redbluedoors import *
15 | # from onpolicy.envs.gridworld.gym_minigrid.envs.obstructedmaze import *
16 | # from onpolicy.envs.gridworld.gym_minigrid.envs.memory import *
17 | # from onpolicy.envs.gridworld.gym_minigrid.envs.fourrooms import *
18 | # from onpolicy.envs.gridworld.gym_minigrid.envs.crossing import *
19 | # from onpolicy.envs.gridworld.gym_minigrid.envs.lavagap import *
20 | # from onpolicy.envs.gridworld.gym_minigrid.envs.dynamicobstacles import *
21 | # from onpolicy.envs.gridworld.gym_minigrid.envs.distshift import *
22 | from onpolicy.envs.gridworld.gym_minigrid.envs.human import *
23 | from onpolicy.envs.gridworld.gym_minigrid.envs.multiexploration import *
24 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/blockedunlockpickup.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import Ball
2 | from onpolicy.envs.gridworld.gym_minigrid.roomgrid import RoomGrid
3 | from onpolicy.envs.gridworld.gym_minigrid.register import register
4 |
5 | class BlockedUnlockPickup(RoomGrid):
6 | """
7 | Unlock a door blocked by a ball, then pick up a box
8 | in another room
9 | """
10 |
11 | def __init__(self, seed=None):
12 | room_size = 6
13 | super().__init__(
14 | num_rows=1,
15 | num_cols=2,
16 | room_size=room_size,
17 | max_steps=16*room_size**2,
18 | seed=seed
19 | )
20 |
21 | def _gen_grid(self, width, height):
22 | super()._gen_grid(width, height)
23 |
24 | # Add a box to the room on the right
25 | obj, _ = self.add_object(1, 0, kind="box")
26 | # Make sure the two rooms are directly connected by a locked door
27 | door, pos = self.add_door(0, 0, 0, locked=True)
28 | # Block the door with a ball
29 | color = self._rand_color()
30 | self.grid.set(pos[0]-1, pos[1], Ball(color))
31 | # Add a key to unlock the door
32 | self.add_object(0, 0, 'key', door.color)
33 |
34 | self.place_agent(0, 0)
35 |
36 | self.obj = obj
37 | self.mission = "pick up the %s %s" % (obj.color, obj.type)
38 |
39 | def step(self, action):
40 | obs, reward, done, info = super().step(action)
41 |
42 | if action == self.actions.pickup:
43 | if self.carrying and self.carrying == self.obj:
44 | reward = self._reward()
45 | done = True
46 |
47 | return obs, reward, done, info
48 |
49 | register(
50 | id='MiniGrid-BlockedUnlockPickup-v0',
51 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:BlockedUnlockPickup'
52 | )
53 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/crossing.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
2 | from onpolicy.envs.gridworld.gym_minigrid.register import register
3 |
4 | import itertools as itt
5 |
6 |
7 | class CrossingEnv(MiniGridEnv):
8 | """
9 | Environment with wall or lava obstacles, sparse reward.
10 | """
11 |
12 | def __init__(self, size=9, num_crossings=1, obstacle_type=Lava, seed=None):
13 | self.num_crossings = num_crossings
14 | self.obstacle_type = obstacle_type
15 | super().__init__(
16 | grid_size=size,
17 | max_steps=4*size*size,
18 | # Set this to True for maximum speed
19 | see_through_walls=False,
20 |             seed=seed
21 | )
22 |
23 | def _gen_grid(self, width, height):
24 | assert width % 2 == 1 and height % 2 == 1 # odd size
25 |
26 | # Create an empty grid
27 | self.grid = Grid(width, height)
28 |
29 | # Generate the surrounding walls
30 | self.grid.wall_rect(0, 0, width, height)
31 |
32 | # Place the agent in the top-left corner
33 | self.agent_pos = (1, 1)
34 | self.agent_dir = 0
35 |
36 | # Place a goal square in the bottom-right corner
37 | self.put_obj(Goal(), width - 2, height - 2)
38 |
39 | # Place obstacles (lava or walls)
40 | v, h = object(), object() # singleton `vertical` and `horizontal` objects
41 |
42 | # Lava rivers or walls specified by direction and position in grid
43 | rivers = [(v, i) for i in range(2, height - 2, 2)]
44 | rivers += [(h, j) for j in range(2, width - 2, 2)]
45 | self.np_random.shuffle(rivers)
46 | rivers = rivers[:self.num_crossings] # sample random rivers
47 | rivers_v = sorted([pos for direction, pos in rivers if direction is v])
48 | rivers_h = sorted([pos for direction, pos in rivers if direction is h])
49 | obstacle_pos = itt.chain(
50 | itt.product(range(1, width - 1), rivers_h),
51 | itt.product(rivers_v, range(1, height - 1)),
52 | )
53 | for i, j in obstacle_pos:
54 | self.put_obj(self.obstacle_type(), i, j)
55 |
56 | # Sample path to goal
57 | path = [h] * len(rivers_v) + [v] * len(rivers_h)
58 | self.np_random.shuffle(path)
59 |
60 | # Create openings
61 | limits_v = [0] + rivers_v + [height - 1]
62 | limits_h = [0] + rivers_h + [width - 1]
63 | room_i, room_j = 0, 0
64 | for direction in path:
65 | if direction is h:
66 | i = limits_v[room_i + 1]
67 | j = self.np_random.choice(
68 | range(limits_h[room_j] + 1, limits_h[room_j + 1]))
69 | room_i += 1
70 | elif direction is v:
71 | i = self.np_random.choice(
72 | range(limits_v[room_i] + 1, limits_v[room_i + 1]))
73 | j = limits_h[room_j + 1]
74 | room_j += 1
75 | else:
76 | assert False
77 | self.grid.set(i, j, None)
78 |
79 | self.mission = (
80 | "avoid the lava and get to the green goal square"
81 | if self.obstacle_type == Lava
82 | else "find the opening and get to the green goal square"
83 | )
84 |
85 | class LavaCrossingEnv(CrossingEnv):
86 | def __init__(self):
87 | super().__init__(size=9, num_crossings=1)
88 |
89 | class LavaCrossingS9N2Env(CrossingEnv):
90 | def __init__(self):
91 | super().__init__(size=9, num_crossings=2)
92 |
93 | class LavaCrossingS9N3Env(CrossingEnv):
94 | def __init__(self):
95 | super().__init__(size=9, num_crossings=3)
96 |
97 | class LavaCrossingS11N5Env(CrossingEnv):
98 | def __init__(self):
99 | super().__init__(size=11, num_crossings=5)
100 |
101 | register(
102 | id='MiniGrid-LavaCrossingS9N1-v0',
103 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LavaCrossingEnv'
104 | )
105 |
106 | register(
107 | id='MiniGrid-LavaCrossingS9N2-v0',
108 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LavaCrossingS9N2Env'
109 | )
110 |
111 | register(
112 | id='MiniGrid-LavaCrossingS9N3-v0',
113 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LavaCrossingS9N3Env'
114 | )
115 |
116 | register(
117 | id='MiniGrid-LavaCrossingS11N5-v0',
118 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LavaCrossingS11N5Env'
119 | )
120 |
121 | class SimpleCrossingEnv(CrossingEnv):
122 | def __init__(self):
123 | super().__init__(size=9, num_crossings=1, obstacle_type=Wall)
124 |
125 | class SimpleCrossingS9N2Env(CrossingEnv):
126 | def __init__(self):
127 | super().__init__(size=9, num_crossings=2, obstacle_type=Wall)
128 |
129 | class SimpleCrossingS9N3Env(CrossingEnv):
130 | def __init__(self):
131 | super().__init__(size=9, num_crossings=3, obstacle_type=Wall)
132 |
133 | class SimpleCrossingS11N5Env(CrossingEnv):
134 | def __init__(self):
135 | super().__init__(size=11, num_crossings=5, obstacle_type=Wall)
136 |
137 | register(
138 | id='MiniGrid-SimpleCrossingS9N1-v0',
139 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:SimpleCrossingEnv'
140 | )
141 |
142 | register(
143 | id='MiniGrid-SimpleCrossingS9N2-v0',
144 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:SimpleCrossingS9N2Env'
145 | )
146 |
147 | register(
148 | id='MiniGrid-SimpleCrossingS9N3-v0',
149 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:SimpleCrossingS9N3Env'
150 | )
151 |
152 | register(
153 | id='MiniGrid-SimpleCrossingS11N5-v0',
154 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:SimpleCrossingS11N5Env'
155 | )
156 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/distshift.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
2 | from onpolicy.envs.gridworld.gym_minigrid.register import register
3 |
4 | class DistShiftEnv(MiniGridEnv):
5 | """
6 | Distributional shift environment.
7 | """
8 |
9 | def __init__(
10 | self,
11 | width=9,
12 | height=7,
13 | agent_start_pos=(1,1),
14 | agent_start_dir=0,
15 | strip2_row=2
16 | ):
17 | self.agent_start_pos = agent_start_pos
18 | self.agent_start_dir = agent_start_dir
19 | self.goal_pos = (width-2, 1)
20 | self.strip2_row = strip2_row
21 |
22 | super().__init__(
23 | width=width,
24 | height=height,
25 | max_steps=4*width*height,
26 | # Set this to True for maximum speed
27 | see_through_walls=True
28 | )
29 |
30 | def _gen_grid(self, width, height):
31 | # Create an empty grid
32 | self.grid = Grid(width, height)
33 |
34 | # Generate the surrounding walls
35 | self.grid.wall_rect(0, 0, width, height)
36 |
37 | # Place a goal square in the bottom-right corner
38 | self.put_obj(Goal(), *self.goal_pos)
39 |
40 | # Place the lava rows
41 | for i in range(self.width - 6):
42 | self.grid.set(3+i, 1, Lava())
43 | self.grid.set(3+i, self.strip2_row, Lava())
44 |
45 | # Place the agent
46 | if self.agent_start_pos is not None:
47 | self.agent_pos = self.agent_start_pos
48 | self.agent_dir = self.agent_start_dir
49 | else:
50 | self.place_agent()
51 |
52 | self.mission = "get to the green goal square"
53 |
54 | class DistShift1(DistShiftEnv):
55 | def __init__(self):
56 | super().__init__(strip2_row=2)
57 |
58 | class DistShift2(DistShiftEnv):
59 | def __init__(self):
60 | super().__init__(strip2_row=5)
61 |
62 | register(
63 | id='MiniGrid-DistShift1-v0',
64 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DistShift1'
65 | )
66 |
67 | register(
68 | id='MiniGrid-DistShift2-v0',
69 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DistShift2'
70 | )
71 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/doorkey.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
2 | from onpolicy.envs.gridworld.gym_minigrid.register import register
3 |
4 | class DoorKeyEnv(MiniGridEnv):
5 | """
6 | Environment with a door and key, sparse reward
7 | """
8 |
9 | def __init__(self, size=8):
10 | super().__init__(
11 | grid_size=size,
12 | max_steps=10*size*size
13 | )
14 |
15 | def _gen_grid(self, width, height):
16 | # Create an empty grid
17 | self.grid = Grid(width, height)
18 |
19 | # Generate the surrounding walls
20 | self.grid.wall_rect(0, 0, width, height)
21 |
22 | # Place a goal in the bottom-right corner
23 | self.put_obj(Goal(), width - 2, height - 2)
24 |
25 | # Create a vertical splitting wall
26 | splitIdx = self._rand_int(2, width-2)
27 | self.grid.vert_wall(splitIdx, 0)
28 |
29 | # Place the agent at a random position and orientation
30 | # on the left side of the splitting wall
31 | self.place_agent(size=(splitIdx, height))
32 |
33 | # Place a door in the wall
34 | doorIdx = self._rand_int(1, width-2)
35 | self.put_obj(Door('yellow', is_locked=True), splitIdx, doorIdx)
36 |
37 | # Place a yellow key on the left side
38 | self.place_obj(
39 | obj=Key('yellow'),
40 | top=(0, 0),
41 | size=(splitIdx, height)
42 | )
43 |
44 | self.mission = "use the key to open the door and then get to the goal"
45 |
46 | class DoorKeyEnv5x5(DoorKeyEnv):
47 | def __init__(self):
48 | super().__init__(size=5)
49 |
50 | class DoorKeyEnv6x6(DoorKeyEnv):
51 | def __init__(self):
52 | super().__init__(size=6)
53 |
54 | class DoorKeyEnv16x16(DoorKeyEnv):
55 | def __init__(self):
56 | super().__init__(size=16)
57 |
58 | register(
59 | id='MiniGrid-DoorKey-5x5-v0',
60 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DoorKeyEnv5x5'
61 | )
62 |
63 | register(
64 | id='MiniGrid-DoorKey-6x6-v0',
65 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DoorKeyEnv6x6'
66 | )
67 |
68 | register(
69 | id='MiniGrid-DoorKey-8x8-v0',
70 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DoorKeyEnv'
71 | )
72 |
73 | register(
74 | id='MiniGrid-DoorKey-16x16-v0',
75 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DoorKeyEnv16x16'
76 | )
77 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/dynamicobstacles.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
2 | from onpolicy.envs.gridworld.gym_minigrid.register import register
3 | from operator import add
4 |
5 | class DynamicObstaclesEnv(MiniGridEnv):
6 | """
7 | Single-room square grid environment with moving obstacles
8 | """
9 |
10 | def __init__(
11 | self,
12 | size=8,
13 | agent_start_pos=(1, 1),
14 | agent_start_dir=0,
15 | n_obstacles=4
16 | ):
17 | self.agent_start_pos = agent_start_pos
18 | self.agent_start_dir = agent_start_dir
19 |
20 | # Reduce obstacles if there are too many
21 | if n_obstacles <= size/2 + 1:
22 | self.n_obstacles = int(n_obstacles)
23 | else:
24 | self.n_obstacles = int(size/2)
25 | super().__init__(
26 | grid_size=size,
27 | max_steps=4 * size * size,
28 | # Set this to True for maximum speed
29 | see_through_walls=True,
30 | )
31 |         # Only 3 actions are permitted: left, right, forward
32 | self.action_space = spaces.Discrete(self.actions.forward + 1)
33 | self.reward_range = (-1, 1)
34 |
35 | def _gen_grid(self, width, height):
36 | # Create an empty grid
37 | self.grid = Grid(width, height)
38 |
39 | # Generate the surrounding walls
40 | self.grid.wall_rect(0, 0, width, height)
41 |
42 | # Place a goal square in the bottom-right corner
43 | self.grid.set(width - 2, height - 2, Goal())
44 |
45 | # Place the agent
46 | if self.agent_start_pos is not None:
47 | self.agent_pos = self.agent_start_pos
48 | self.agent_dir = self.agent_start_dir
49 | else:
50 | self.place_agent()
51 |
52 | # Place obstacles
53 | self.obstacles = []
54 | for i_obst in range(self.n_obstacles):
55 | self.obstacles.append(Ball())
56 | self.place_obj(self.obstacles[i_obst], max_tries=100)
57 |
58 | self.mission = "get to the green goal square"
59 |
60 | def step(self, action):
61 | # Invalid action
62 | if action >= self.action_space.n:
63 | action = 0
64 |
65 | # Check if there is an obstacle in front of the agent
66 | front_cell = self.grid.get(*self.front_pos)
67 | not_clear = front_cell and front_cell.type != 'goal'
68 |
69 | # Update obstacle positions
70 | for i_obst in range(len(self.obstacles)):
71 | old_pos = self.obstacles[i_obst].cur_pos
72 | top = tuple(map(add, old_pos, (-1, -1)))
73 |
74 | try:
75 | self.place_obj(self.obstacles[i_obst], top=top, size=(3,3), max_tries=100)
76 | self.grid.set(*old_pos, None)
77 |             except Exception:
78 |                 pass  # placement can fail when no free cell is found; keep the old position
79 |
80 | # Update the agent's position/direction
81 | obs, reward, done, info = MiniGridEnv.step(self, action)
82 |
83 | # If the agent tried to walk over an obstacle or wall
84 | if action == self.actions.forward and not_clear:
85 | reward = -1
86 | done = True
87 | return obs, reward, done, info
88 |
89 | return obs, reward, done, info
90 |
91 | class DynamicObstaclesEnv5x5(DynamicObstaclesEnv):
92 | def __init__(self):
93 | super().__init__(size=5, n_obstacles=2)
94 |
95 | class DynamicObstaclesRandomEnv5x5(DynamicObstaclesEnv):
96 | def __init__(self):
97 | super().__init__(size=5, agent_start_pos=None, n_obstacles=2)
98 |
99 | class DynamicObstaclesEnv6x6(DynamicObstaclesEnv):
100 | def __init__(self):
101 | super().__init__(size=6, n_obstacles=3)
102 |
103 | class DynamicObstaclesRandomEnv6x6(DynamicObstaclesEnv):
104 | def __init__(self):
105 | super().__init__(size=6, agent_start_pos=None, n_obstacles=3)
106 |
107 | class DynamicObstaclesEnv16x16(DynamicObstaclesEnv):
108 | def __init__(self):
109 | super().__init__(size=16, n_obstacles=8)
110 |
111 | register(
112 | id='MiniGrid-Dynamic-Obstacles-5x5-v0',
113 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DynamicObstaclesEnv5x5'
114 | )
115 |
116 | register(
117 | id='MiniGrid-Dynamic-Obstacles-Random-5x5-v0',
118 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DynamicObstaclesRandomEnv5x5'
119 | )
120 |
121 | register(
122 | id='MiniGrid-Dynamic-Obstacles-6x6-v0',
123 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DynamicObstaclesEnv6x6'
124 | )
125 |
126 | register(
127 | id='MiniGrid-Dynamic-Obstacles-Random-6x6-v0',
128 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DynamicObstaclesRandomEnv6x6'
129 | )
130 |
131 | register(
132 | id='MiniGrid-Dynamic-Obstacles-8x8-v0',
133 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DynamicObstaclesEnv'
134 | )
135 |
136 | register(
137 | id='MiniGrid-Dynamic-Obstacles-16x16-v0',
138 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DynamicObstaclesEnv16x16'
139 | )
140 |
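
A minimal interaction sketch for this environment (again assuming the upstream
single-agent MiniGridEnv API; the multi-agent fork in this repo may differ), showing
the restricted three-action space and the -1 penalty returned on collision:

    from onpolicy.envs.gridworld.gym_minigrid.envs.dynamicobstacles import DynamicObstaclesEnv

    env = DynamicObstaclesEnv(size=8, n_obstacles=4)
    obs = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()          # only left / right / forward are valid
        obs, reward, done, info = env.step(action)
        if reward == -1:
            print("collided with an obstacle or wall")  # episode terminates with done=True
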
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/empty.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
2 | from onpolicy.envs.gridworld.gym_minigrid.register import register
3 |
4 | class EmptyEnv(MiniGridEnv):
5 | """
6 | Empty grid environment, no obstacles, sparse reward
7 | """
8 |
9 | def __init__(
10 | self,
11 | size=8,
12 | agent_start_pos=(1,1),
13 | agent_start_dir=0,
14 | ):
15 | self.agent_start_pos = agent_start_pos
16 | self.agent_start_dir = agent_start_dir
17 |
18 | super().__init__(
19 | grid_size=size,
20 | max_steps=4*size*size,
21 | # Set this to True for maximum speed
22 | see_through_walls=True
23 | )
24 |
25 | def _gen_grid(self, width, height):
26 | # Create an empty grid
27 | self.grid = Grid(width, height)
28 |
29 | # Generate the surrounding walls
30 | self.grid.wall_rect(0, 0, width, height)
31 |
32 | # Place a goal square in the bottom-right corner
33 | self.put_obj(Goal(), width - 2, height - 2)
34 |
35 | # Place the agent
36 | if self.agent_start_pos is not None:
37 | self.agent_pos = self.agent_start_pos
38 | self.agent_dir = self.agent_start_dir
39 | else:
40 | self.place_agent()
41 |
42 | self.mission = "get to the green goal square"
43 |
44 | class EmptyEnv5x5(EmptyEnv):
45 | def __init__(self, **kwargs):
46 | super().__init__(size=5, **kwargs)
47 |
48 | class EmptyRandomEnv5x5(EmptyEnv):
49 | def __init__(self):
50 | super().__init__(size=5, agent_start_pos=None)
51 |
52 | class EmptyEnv6x6(EmptyEnv):
53 | def __init__(self, **kwargs):
54 | super().__init__(size=6, **kwargs)
55 |
56 | class EmptyRandomEnv6x6(EmptyEnv):
57 | def __init__(self):
58 | super().__init__(size=6, agent_start_pos=None)
59 |
60 | class EmptyEnv16x16(EmptyEnv):
61 | def __init__(self, **kwargs):
62 | super().__init__(size=16, **kwargs)
63 |
64 | register(
65 | id='MiniGrid-Empty-5x5-v0',
66 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:EmptyEnv5x5'
67 | )
68 |
69 | register(
70 | id='MiniGrid-Empty-Random-5x5-v0',
71 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:EmptyRandomEnv5x5'
72 | )
73 |
74 | register(
75 | id='MiniGrid-Empty-6x6-v0',
76 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:EmptyEnv6x6'
77 | )
78 |
79 | register(
80 | id='MiniGrid-Empty-Random-6x6-v0',
81 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:EmptyRandomEnv6x6'
82 | )
83 |
84 | register(
85 | id='MiniGrid-Empty-8x8-v0',
86 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:EmptyEnv'
87 | )
88 |
89 | register(
90 | id='MiniGrid-Empty-16x16-v0',
91 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:EmptyEnv16x16'
92 | )
93 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/fetch.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
2 | from onpolicy.envs.gridworld.gym_minigrid.register import register
3 |
4 | class FetchEnv(MiniGridEnv):
5 | """
6 | Environment in which the agent has to fetch a random object
7 | named using English text strings
8 | """
9 |
10 | def __init__(
11 | self,
12 | size=8,
13 | numObjs=3
14 | ):
15 | self.numObjs = numObjs
16 |
17 | super().__init__(
18 | grid_size=size,
19 | max_steps=5*size**2,
20 | # Set this to True for maximum speed
21 | see_through_walls=True
22 | )
23 |
24 | def _gen_grid(self, width, height):
25 | self.grid = Grid(width, height)
26 |
27 | # Generate the surrounding walls
28 | self.grid.horz_wall(0, 0)
29 | self.grid.horz_wall(0, height-1)
30 | self.grid.vert_wall(0, 0)
31 | self.grid.vert_wall(width-1, 0)
32 |
33 | types = ['key', 'ball']
34 |
35 | objs = []
36 |
37 | # For each object to be generated
38 | while len(objs) < self.numObjs:
39 | objType = self._rand_elem(types)
40 | objColor = self._rand_elem(COLOR_NAMES)
41 |
42 | if objType == 'key':
43 | obj = Key(objColor)
44 | elif objType == 'ball':
45 | obj = Ball(objColor)
46 |
47 | self.place_obj(obj)
48 | objs.append(obj)
49 |
50 | # Randomize the player start position and orientation
51 | self.place_agent()
52 |
53 | # Choose a random object to be picked up
54 | target = objs[self._rand_int(0, len(objs))]
55 | self.targetType = target.type
56 | self.targetColor = target.color
57 |
58 | descStr = '%s %s' % (self.targetColor, self.targetType)
59 |
60 | # Generate the mission string
61 | idx = self._rand_int(0, 5)
62 | if idx == 0:
63 | self.mission = 'get a %s' % descStr
64 | elif idx == 1:
65 | self.mission = 'go get a %s' % descStr
66 | elif idx == 2:
67 | self.mission = 'fetch a %s' % descStr
68 | elif idx == 3:
69 | self.mission = 'go fetch a %s' % descStr
70 | elif idx == 4:
71 | self.mission = 'you must fetch a %s' % descStr
72 | assert hasattr(self, 'mission')
73 |
74 | def step(self, action):
75 | obs, reward, done, info = MiniGridEnv.step(self, action)
76 |
77 | if self.carrying:
78 | if self.carrying.color == self.targetColor and \
79 | self.carrying.type == self.targetType:
80 | reward = self._reward()
81 | done = True
82 | else:
83 | reward = 0
84 | done = True
85 |
86 | return obs, reward, done, info
87 |
88 | class FetchEnv5x5N2(FetchEnv):
89 | def __init__(self):
90 | super().__init__(size=5, numObjs=2)
91 |
92 | class FetchEnv6x6N2(FetchEnv):
93 | def __init__(self):
94 | super().__init__(size=6, numObjs=2)
95 |
96 | register(
97 | id='MiniGrid-Fetch-5x5-N2-v0',
98 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:FetchEnv5x5N2'
99 | )
100 |
101 | register(
102 | id='MiniGrid-Fetch-6x6-N2-v0',
103 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:FetchEnv6x6N2'
104 | )
105 |
106 | register(
107 | id='MiniGrid-Fetch-8x8-N3-v0',
108 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:FetchEnv'
109 | )
110 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/fourrooms.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
5 | from onpolicy.envs.gridworld.gym_minigrid.register import register
6 |
7 |
8 | class FourRoomsEnv(MiniGridEnv):
9 | """
10 | Classic 4 rooms gridworld environment.
11 |     Agent and goal positions can be specified; otherwise they are set at random.
12 | """
13 |
14 | def __init__(self, agent_pos=None, goal_pos=None):
15 | self._agent_default_pos = agent_pos
16 | self._goal_default_pos = goal_pos
17 | super().__init__(grid_size=19, max_steps=100)
18 |
19 | def _gen_grid(self, width, height):
20 | # Create the grid
21 | self.grid = Grid(width, height)
22 |
23 | # Generate the surrounding walls
24 | self.grid.horz_wall(0, 0)
25 | self.grid.horz_wall(0, height - 1)
26 | self.grid.vert_wall(0, 0)
27 | self.grid.vert_wall(width - 1, 0)
28 |
29 | room_w = width // 2
30 | room_h = height // 2
31 |
32 | # For each row of rooms
33 | for j in range(0, 2):
34 |
35 | # For each column
36 | for i in range(0, 2):
37 | xL = i * room_w
38 | yT = j * room_h
39 | xR = xL + room_w
40 | yB = yT + room_h
41 |
42 |                 # Right wall and door
43 | if i + 1 < 2:
44 | self.grid.vert_wall(xR, yT, room_h)
45 | pos = (xR, self._rand_int(yT + 1, yB))
46 | self.grid.set(*pos, None)
47 |
48 | # Bottom wall and door
49 | if j + 1 < 2:
50 | self.grid.horz_wall(xL, yB, room_w)
51 | pos = (self._rand_int(xL + 1, xR), yB)
52 | self.grid.set(*pos, None)
53 |
54 | # Randomize the player start position and orientation
55 | if self._agent_default_pos is not None:
56 | self.agent_pos = self._agent_default_pos
57 | self.grid.set(*self._agent_default_pos, None)
58 | self.agent_dir = self._rand_int(0, 4) # assuming random start direction
59 | else:
60 | self.place_agent()
61 |
62 | if self._goal_default_pos is not None:
63 | goal = Goal()
64 | self.put_obj(goal, *self._goal_default_pos)
65 | goal.init_pos, goal.cur_pos = self._goal_default_pos
66 | else:
67 | self.place_obj(Goal())
68 |
69 | self.mission = 'Reach the goal'
70 |
71 | def step(self, action):
72 | obs, reward, done, info = MiniGridEnv.step(self, action)
73 | return obs, reward, done, info
74 |
75 | register(
76 | id='MiniGrid-FourRooms-v0',
77 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:FourRoomsEnv'
78 | )
79 |
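
As noted in the docstring, the agent and goal positions can be pinned instead of
randomized; a short sketch (the coordinates below are arbitrary interior cells of the
default 19x19 grid):

    from onpolicy.envs.gridworld.gym_minigrid.envs.fourrooms import FourRoomsEnv

    env_random = FourRoomsEnv()                                    # agent and goal placed at random
    env_fixed = FourRoomsEnv(agent_pos=(1, 1), goal_pos=(17, 17))  # both positions fixed
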
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/gotodoor.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
2 | from onpolicy.envs.gridworld.gym_minigrid.register import register
3 |
4 | class GoToDoorEnv(MiniGridEnv):
5 | """
6 |     Environment in which the agent is instructed to go to a given door
7 | named using an English text string
8 | """
9 |
10 | def __init__(
11 | self,
12 | size=5
13 | ):
14 | assert size >= 5
15 |
16 | super().__init__(
17 | grid_size=size,
18 | max_steps=5*size**2,
19 | # Set this to True for maximum speed
20 | see_through_walls=True
21 | )
22 |
23 | def _gen_grid(self, width, height):
24 | # Create the grid
25 | self.grid = Grid(width, height)
26 |
27 | # Randomly vary the room width and height
28 | width = self._rand_int(5, width+1)
29 | height = self._rand_int(5, height+1)
30 |
31 | # Generate the surrounding walls
32 | self.grid.wall_rect(0, 0, width, height)
33 |
34 | # Generate the 4 doors at random positions
35 | doorPos = []
36 | doorPos.append((self._rand_int(2, width-2), 0))
37 | doorPos.append((self._rand_int(2, width-2), height-1))
38 | doorPos.append((0, self._rand_int(2, height-2)))
39 | doorPos.append((width-1, self._rand_int(2, height-2)))
40 |
41 | # Generate the door colors
42 | doorColors = []
43 | while len(doorColors) < len(doorPos):
44 | color = self._rand_elem(COLOR_NAMES)
45 | if color in doorColors:
46 | continue
47 | doorColors.append(color)
48 |
49 | # Place the doors in the grid
50 | for idx, pos in enumerate(doorPos):
51 | color = doorColors[idx]
52 | self.grid.set(*pos, Door(color))
53 |
54 | # Randomize the agent start position and orientation
55 | self.place_agent(size=(width, height))
56 |
57 | # Select a random target door
58 | doorIdx = self._rand_int(0, len(doorPos))
59 | self.target_pos = doorPos[doorIdx]
60 | self.target_color = doorColors[doorIdx]
61 |
62 | # Generate the mission string
63 | self.mission = 'go to the %s door' % self.target_color
64 |
65 | def step(self, action):
66 | obs, reward, done, info = super().step(action)
67 |
68 | ax, ay = self.agent_pos
69 | tx, ty = self.target_pos
70 |
71 | # Don't let the agent open any of the doors
72 | if action == self.actions.toggle:
73 | done = True
74 |
75 | # Reward performing done action in front of the target door
76 | if action == self.actions.done:
77 | if (ax == tx and abs(ay - ty) == 1) or (ay == ty and abs(ax - tx) == 1):
78 | reward = self._reward()
79 | done = True
80 |
81 | return obs, reward, done, info
82 |
83 | class GoToDoor8x8Env(GoToDoorEnv):
84 | def __init__(self):
85 | super().__init__(size=8)
86 |
87 | class GoToDoor6x6Env(GoToDoorEnv):
88 | def __init__(self):
89 | super().__init__(size=6)
90 |
91 | register(
92 | id='MiniGrid-GoToDoor-5x5-v0',
93 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:GoToDoorEnv'
94 | )
95 |
96 | register(
97 | id='MiniGrid-GoToDoor-6x6-v0',
98 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:GoToDoor6x6Env'
99 | )
100 |
101 | register(
102 | id='MiniGrid-GoToDoor-8x8-v0',
103 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:GoToDoor8x8Env'
104 | )
105 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/gotoobject.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
2 | from onpolicy.envs.gridworld.gym_minigrid.register import register
3 |
4 | class GoToObjectEnv(MiniGridEnv):
5 | """
6 | Environment in which the agent is instructed to go to a given object
7 | named using an English text string
8 | """
9 |
10 | def __init__(
11 | self,
12 | size=6,
13 | numObjs=2
14 | ):
15 | self.numObjs = numObjs
16 |
17 | super().__init__(
18 | grid_size=size,
19 | max_steps=5*size**2,
20 | # Set this to True for maximum speed
21 | see_through_walls=True
22 | )
23 |
24 | def _gen_grid(self, width, height):
25 | self.grid = Grid(width, height)
26 |
27 | # Generate the surrounding walls
28 | self.grid.wall_rect(0, 0, width, height)
29 |
30 | # Types and colors of objects we can generate
31 | types = ['key', 'ball', 'box']
32 |
33 | objs = []
34 | objPos = []
35 |
36 | # Until we have generated all the objects
37 | while len(objs) < self.numObjs:
38 | objType = self._rand_elem(types)
39 | objColor = self._rand_elem(COLOR_NAMES)
40 |
41 | # If this object already exists, try again
42 | if (objType, objColor) in objs:
43 | continue
44 |
45 | if objType == 'key':
46 | obj = Key(objColor)
47 | elif objType == 'ball':
48 | obj = Ball(objColor)
49 | elif objType == 'box':
50 | obj = Box(objColor)
51 |
52 | pos = self.place_obj(obj)
53 | objs.append((objType, objColor))
54 | objPos.append(pos)
55 |
56 | # Randomize the agent start position and orientation
57 | self.place_agent()
58 |
59 | # Choose a random object to be picked up
60 | objIdx = self._rand_int(0, len(objs))
61 | self.targetType, self.target_color = objs[objIdx]
62 | self.target_pos = objPos[objIdx]
63 |
64 | descStr = '%s %s' % (self.target_color, self.targetType)
65 | self.mission = 'go to the %s' % descStr
66 | #print(self.mission)
67 |
68 | def step(self, action):
69 | obs, reward, done, info = MiniGridEnv.step(self, action)
70 |
71 | ax, ay = self.agent_pos
72 | tx, ty = self.target_pos
73 |
74 |         # Toggle action terminates the episode
75 | if action == self.actions.toggle:
76 | done = True
77 |
78 | # Reward performing the done action next to the target object
79 | if action == self.actions.done:
80 | if abs(ax - tx) <= 1 and abs(ay - ty) <= 1:
81 | reward = self._reward()
82 | done = True
83 |
84 | return obs, reward, done, info
85 |
86 | class GotoEnv8x8N2(GoToObjectEnv):
87 | def __init__(self):
88 | super().__init__(size=8, numObjs=2)
89 |
90 | register(
91 | id='MiniGrid-GoToObject-6x6-N2-v0',
92 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:GoToObjectEnv'
93 | )
94 |
95 | register(
96 | id='MiniGrid-GoToObject-8x8-N2-v0',
97 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:GotoEnv8x8N2'
98 | )
99 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/human.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
2 | from icecream import ic
3 | import collections
4 | import math
5 | from copy import deepcopy
6 |
7 | class HumanEnv(MiniGridEnv):
8 | """
9 |     Multi-agent coverage environment: agents receive a reward for each newly
10 |     visited cell and can optionally be given direction-based rewards or commands
11 | """
12 |
13 | def __init__(
14 | self,
15 | num_agents=2,
16 | num_preies=2,
17 | num_obstacles=4,
18 | direction_alpha=0.5,
19 | use_direction_reward = False,
20 | use_human_command=False,
21 | coverage_discounter=0.1,
22 | size=19
23 | ):
24 | self.size = size
25 | self.num_preies = num_preies
26 | self.use_direction_reward = use_direction_reward
27 | self.direction_alpha = direction_alpha
28 | self.use_human_command = use_human_command
29 | self.coverage_discounter = coverage_discounter
30 |         # initialize the covering rate
31 | self.covering_rate = 0
32 | # Reduce obstacles if there are too many
33 | if num_obstacles <= size/2 + 1:
34 | self.num_obstacles = int(num_obstacles)
35 | else:
36 | self.num_obstacles = int(size/2)
37 |
38 | super().__init__(
39 | num_agents=num_agents,
40 | grid_size=size,
41 | max_steps=math.floor(((size-2)**2) / num_agents * 2),
42 | # Set this to True for maximum speed
43 | see_through_walls=True
44 | )
45 |
46 | def _gen_grid(self, width, height):
47 | # Create the grid
48 | self.grid = Grid(width, height)
49 |
50 | # Generate the surrounding walls
51 | self.grid.horz_wall(0, 0)
52 | self.grid.horz_wall(0, height - 1)
53 | self.grid.vert_wall(0, 0)
54 | self.grid.vert_wall(width - 1, 0)
55 |
56 | room_w = width // 2
57 | room_h = height // 2
58 |
59 | # For each row of rooms
60 | for j in range(0, 2):
61 |
62 | # For each column
63 | for i in range(0, 2):
64 | xL = i * room_w
65 | yT = j * room_h
66 | xR = xL + room_w
67 | yB = yT + room_h
68 |
69 |                 # Right wall and door
70 | if i + 1 < 2:
71 | self.grid.vert_wall(xR, yT, room_h)
72 | pos = (xR, self._rand_int(yT + 1, yB))
73 | self.grid.set(*pos, None)
74 |
75 | # Bottom wall and door
76 | if j + 1 < 2:
77 | self.grid.horz_wall(xL, yB, room_w)
78 | pos = (self._rand_int(xL + 1, xR), yB)
79 | self.grid.set(*pos, None)
80 |
81 |         # initialize the coverage grid
82 | self.cover_grid = np.zeros([width,height])
83 | for j in range(0, height):
84 | for i in range(0, width):
85 | if self.grid.get(i,j) != None and self.grid.get(i,j).type == 'wall':
86 | self.cover_grid[j,i] = 1.0
87 | self.cover_grid_initial = self.cover_grid.copy()
88 | self.num_none = collections.Counter(self.cover_grid_initial.flatten())[0.]
89 | # import pdb; pdb.set_trace()
90 |
91 | # Types and colors of objects we can generate
92 | types = ['key']
93 |
94 | objs = []
95 | objPos = []
96 |
97 | # Until we have generated all the objects
98 | while len(objs) < self.num_preies:
99 | objType = self._rand_elem(types)
100 | objColor = self._rand_elem(COLOR_NAMES)
101 |
102 | # If this object already exists, try again
103 | if (objType, objColor) in objs:
104 | continue
105 |
106 | if objType == 'key':
107 | obj = Key(objColor)
108 | elif objType == 'box':
109 | obj = Box(objColor)
110 | elif objType == 'ball':
111 | obj = Ball(objColor)
112 |
113 | pos = self.place_obj(obj)
114 | objs.append((objType, objColor))
115 | objPos.append(pos)
116 |
117 | # Place obstacles
118 | self.obstacles = []
119 | for i_obst in range(self.num_obstacles):
120 | self.obstacles.append(Obstacle())
121 | pos = self.place_obj(self.obstacles[i_obst], max_tries=100)
122 |
123 | self.occupy_grid = self.grid.copy()
124 | # Randomize the agent start position and orientation
125 | self.place_agent()
126 |
127 | # Choose a random object to be picked up
128 | objIdx = self._rand_int(0, len(objs))
129 | self.targetType, self.target_color = objs[objIdx]
130 | self.target_pos = objPos[objIdx]
131 |
132 | # direction
133 | array_direction = np.array([[0,1], [0,-1], [1,0], [-1,0], [1,1], [1,-1], [-1,1], [-1,-1]])
134 | self.direction = []
135 | self.direction_encoder = []
136 | self.direction_index = []
137 | for agent_id in range(self.num_agents):
138 | center_pos = np.array([int((self.size-1)/2),int((self.size-1)/2)])
139 | direction = np.sign(center_pos - self.agent_pos[agent_id])
140 | direction_index = np.argmax(np.all(np.where(array_direction == direction, True, False), axis=1))
141 | direction_encoder = np.eye(8)[direction_index]
142 | self.direction_index.append(direction_index)
143 | self.direction.append(direction)
144 | self.direction_encoder.append(direction_encoder)
145 |
146 | # text
147 | descStr = '%s %s' % (self.target_color, self.targetType)
148 | self.mission = 'go to the %s' % descStr
149 | # print(self.mission)
150 |
151 | def step(self, action):
152 | obs, reward, done, info = MiniGridEnv.step(self, action)
153 |
154 | rewards = []
155 |
156 | for agent_id in range(self.num_agents):
157 | ax, ay = self.agent_pos[agent_id]
158 | tx, ty = self.target_pos
159 | if self.cover_grid[ay,ax] == 0:
160 | reward += self.coverage_discounter
161 | self.cover_grid[ay, ax] = 1.0
162 | self.covering_rate = collections.Counter((self.cover_grid - self.cover_grid_initial).flatten())[1] / self.num_none
163 |
164 | # if abs(ax - tx) < 1 and abs(ay - ty) < 1:
165 | # reward += 1.0
166 | # self.num_reach_goal += 1
167 | # # done = True
168 |
169 | rewards.append(reward)
170 |
171 | rewards = [[np.sum(rewards)]] * self.num_agents
172 |
173 | dones = [done for agent_id in range(self.num_agents)]
174 |
175 | info['num_reach_goal'] = self.num_reach_goal
176 | info['covering_rate'] = self.covering_rate
177 | info['num_same_direction'] = self.num_same_direction
178 |
179 | return obs, rewards, dones, info
180 |
181 |
182 |
183 |
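
The covering-rate bookkeeping above is somewhat opaque; the stand-alone sketch below
reproduces the same computation with plain numpy (the 5x5 grid and the visited cell are
illustrative only):

    import numpy as np

    # cover_grid_initial marks wall cells with 1 and free cells with 0.
    cover_grid_initial = np.zeros((5, 5))
    cover_grid_initial[0, :] = cover_grid_initial[-1, :] = 1.0
    cover_grid_initial[:, 0] = cover_grid_initial[:, -1] = 1.0

    cover_grid = cover_grid_initial.copy()
    num_free = np.sum(cover_grid_initial == 0)   # number of coverable (non-wall) cells

    # An agent visiting cell (2, 2) marks it as covered.
    cover_grid[2, 2] = 1.0

    # Fraction of free cells visited so far; matches the Counter-based expression in step().
    covering_rate = np.sum((cover_grid - cover_grid_initial) == 1) / num_free
    print(covering_rate)   # 1/9 for this example
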
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/keycorridor.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.roomgrid import RoomGrid
2 | from onpolicy.envs.gridworld.gym_minigrid.register import register
3 |
4 | class KeyCorridor(RoomGrid):
5 | """
6 | A ball is behind a locked door, the key is placed in a
7 | random room.
8 | """
9 |
10 | def __init__(
11 | self,
12 | num_rows=3,
13 | obj_type="ball",
14 | room_size=6,
15 | seed=None
16 | ):
17 | self.obj_type = obj_type
18 |
19 | super().__init__(
20 | room_size=room_size,
21 | num_rows=num_rows,
22 | max_steps=30*room_size**2,
23 | seed=seed,
24 | )
25 |
26 | def _gen_grid(self, width, height):
27 | super()._gen_grid(width, height)
28 |
29 | # Connect the middle column rooms into a hallway
30 | for j in range(1, self.num_rows):
31 | self.remove_wall(1, j, 3)
32 |
33 | # Add a locked door on the bottom right
34 | # Add an object behind the locked door
35 | room_idx = self._rand_int(0, self.num_rows)
36 | door, _ = self.add_door(2, room_idx, 2, locked=True)
37 | obj, _ = self.add_object(2, room_idx, kind=self.obj_type)
38 |
39 | # Add a key in a random room on the left side
40 | self.add_object(0, self._rand_int(0, self.num_rows), 'key', door.color)
41 |
42 | # Place the agent in the middle
43 | self.place_agent(1, self.num_rows // 2)
44 |
45 | # Make sure all rooms are accessible
46 | self.connect_all()
47 |
48 | self.obj = obj
49 | self.mission = "pick up the %s %s" % (obj.color, obj.type)
50 |
51 | def step(self, action):
52 | obs, reward, done, info = super().step(action)
53 |
54 | if action == self.actions.pickup:
55 | if self.carrying and self.carrying == self.obj:
56 | reward = self._reward()
57 | done = True
58 |
59 | return obs, reward, done, info
60 |
61 | class KeyCorridorS3R1(KeyCorridor):
62 | def __init__(self, seed=None):
63 | super().__init__(
64 | room_size=3,
65 | num_rows=1,
66 | seed=seed
67 | )
68 |
69 | class KeyCorridorS3R2(KeyCorridor):
70 | def __init__(self, seed=None):
71 | super().__init__(
72 | room_size=3,
73 | num_rows=2,
74 | seed=seed
75 | )
76 |
77 | class KeyCorridorS3R3(KeyCorridor):
78 | def __init__(self, seed=None):
79 | super().__init__(
80 | room_size=3,
81 | num_rows=3,
82 | seed=seed
83 | )
84 |
85 | class KeyCorridorS4R3(KeyCorridor):
86 | def __init__(self, seed=None):
87 | super().__init__(
88 | room_size=4,
89 | num_rows=3,
90 | seed=seed
91 | )
92 |
93 | class KeyCorridorS5R3(KeyCorridor):
94 | def __init__(self, seed=None):
95 | super().__init__(
96 | room_size=5,
97 | num_rows=3,
98 | seed=seed
99 | )
100 |
101 | class KeyCorridorS6R3(KeyCorridor):
102 | def __init__(self, seed=None):
103 | super().__init__(
104 | room_size=6,
105 | num_rows=3,
106 | seed=seed
107 | )
108 |
109 | register(
110 | id='MiniGrid-KeyCorridorS3R1-v0',
111 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:KeyCorridorS3R1'
112 | )
113 |
114 | register(
115 | id='MiniGrid-KeyCorridorS3R2-v0',
116 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:KeyCorridorS3R2'
117 | )
118 |
119 | register(
120 | id='MiniGrid-KeyCorridorS3R3-v0',
121 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:KeyCorridorS3R3'
122 | )
123 |
124 | register(
125 | id='MiniGrid-KeyCorridorS4R3-v0',
126 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:KeyCorridorS4R3'
127 | )
128 |
129 | register(
130 | id='MiniGrid-KeyCorridorS5R3-v0',
131 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:KeyCorridorS5R3'
132 | )
133 |
134 | register(
135 | id='MiniGrid-KeyCorridorS6R3-v0',
136 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:KeyCorridorS6R3'
137 | )
138 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/lavagap.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
2 | from onpolicy.envs.gridworld.gym_minigrid.register import register
3 |
4 | class LavaGapEnv(MiniGridEnv):
5 | """
6 | Environment with one wall of lava with a small gap to cross through
7 | This environment is similar to LavaCrossing but simpler in structure.
8 | """
9 |
10 | def __init__(self, size, obstacle_type=Lava, seed=None):
11 | self.obstacle_type = obstacle_type
12 | super().__init__(
13 | grid_size=size,
14 | max_steps=4*size*size,
15 | # Set this to True for maximum speed
16 | see_through_walls=False,
17 |             seed=seed
18 | )
19 |
20 | def _gen_grid(self, width, height):
21 | assert width >= 5 and height >= 5
22 |
23 | # Create an empty grid
24 | self.grid = Grid(width, height)
25 |
26 | # Generate the surrounding walls
27 | self.grid.wall_rect(0, 0, width, height)
28 |
29 | # Place the agent in the top-left corner
30 | self.agent_pos = (1, 1)
31 | self.agent_dir = 0
32 |
33 | # Place a goal square in the bottom-right corner
34 | self.goal_pos = np.array((width - 2, height - 2))
35 | self.put_obj(Goal(), *self.goal_pos)
36 |
37 | # Generate and store random gap position
38 | self.gap_pos = np.array((
39 | self._rand_int(2, width - 2),
40 | self._rand_int(1, height - 1),
41 | ))
42 |
43 | # Place the obstacle wall
44 | self.grid.vert_wall(self.gap_pos[0], 1, height - 2, self.obstacle_type)
45 |
46 | # Put a hole in the wall
47 | self.grid.set(*self.gap_pos, None)
48 |
49 | self.mission = (
50 | "avoid the lava and get to the green goal square"
51 | if self.obstacle_type == Lava
52 | else "find the opening and get to the green goal square"
53 | )
54 |
55 | class LavaGapS5Env(LavaGapEnv):
56 | def __init__(self):
57 | super().__init__(size=5)
58 |
59 | class LavaGapS6Env(LavaGapEnv):
60 | def __init__(self):
61 | super().__init__(size=6)
62 |
63 | class LavaGapS7Env(LavaGapEnv):
64 | def __init__(self):
65 | super().__init__(size=7)
66 |
67 | register(
68 | id='MiniGrid-LavaGapS5-v0',
69 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LavaGapS5Env'
70 | )
71 |
72 | register(
73 | id='MiniGrid-LavaGapS6-v0',
74 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LavaGapS6Env'
75 | )
76 |
77 | register(
78 | id='MiniGrid-LavaGapS7-v0',
79 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LavaGapS7Env'
80 | )
81 |
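
The obstacle_type argument also selects the mission string, so the same layout can be
reused without lava; a brief sketch (Wall comes from the same minigrid module as Lava):

    from onpolicy.envs.gridworld.gym_minigrid.minigrid import Wall
    from onpolicy.envs.gridworld.gym_minigrid.envs.lavagap import LavaGapEnv

    lava_env = LavaGapEnv(size=6)                       # "avoid the lava and get to the green goal square"
    wall_env = LavaGapEnv(size=6, obstacle_type=Wall)   # "find the opening and get to the green goal square"
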
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/lockedroom.py:
--------------------------------------------------------------------------------
1 | from gym import spaces
2 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
3 | from onpolicy.envs.gridworld.gym_minigrid.register import register
4 |
5 | class Room:
6 | def __init__(self,
7 | top,
8 | size,
9 | doorPos
10 | ):
11 | self.top = top
12 | self.size = size
13 | self.doorPos = doorPos
14 | self.color = None
15 | self.locked = False
16 |
17 | def rand_pos(self, env):
18 | topX, topY = self.top
19 | sizeX, sizeY = self.size
20 | return env._rand_pos(
21 | topX + 1, topX + sizeX - 1,
22 | topY + 1, topY + sizeY - 1
23 | )
24 |
25 | class LockedRoom(MiniGridEnv):
26 | """
27 |     Environment in which the agent must fetch a key, unlock the door of a
28 |     locked room and reach the goal placed inside that room
29 | """
30 |
31 | def __init__(
32 | self,
33 | size=19
34 | ):
35 | super().__init__(grid_size=size, max_steps=10*size)
36 |
37 | def _gen_grid(self, width, height):
38 | # Create the grid
39 | self.grid = Grid(width, height)
40 |
41 | # Generate the surrounding walls
42 | for i in range(0, width):
43 | self.grid.set(i, 0, Wall())
44 | self.grid.set(i, height-1, Wall())
45 | for j in range(0, height):
46 | self.grid.set(0, j, Wall())
47 | self.grid.set(width-1, j, Wall())
48 |
49 | # Hallway walls
50 | lWallIdx = width // 2 - 2
51 | rWallIdx = width // 2 + 2
52 | for j in range(0, height):
53 | self.grid.set(lWallIdx, j, Wall())
54 | self.grid.set(rWallIdx, j, Wall())
55 |
56 | self.rooms = []
57 |
58 | # Room splitting walls
59 | for n in range(0, 3):
60 | j = n * (height // 3)
61 | for i in range(0, lWallIdx):
62 | self.grid.set(i, j, Wall())
63 | for i in range(rWallIdx, width):
64 | self.grid.set(i, j, Wall())
65 |
66 | roomW = lWallIdx + 1
67 | roomH = height // 3 + 1
68 | self.rooms.append(Room(
69 | (0, j),
70 | (roomW, roomH),
71 | (lWallIdx, j + 3)
72 | ))
73 | self.rooms.append(Room(
74 | (rWallIdx, j),
75 | (roomW, roomH),
76 | (rWallIdx, j + 3)
77 | ))
78 |
79 | # Choose one random room to be locked
80 | lockedRoom = self._rand_elem(self.rooms)
81 | lockedRoom.locked = True
82 | goalPos = lockedRoom.rand_pos(self)
83 | self.grid.set(*goalPos, Goal())
84 |
85 | # Assign the door colors
86 | colors = set(COLOR_NAMES)
87 | for room in self.rooms:
88 | color = self._rand_elem(sorted(colors))
89 | colors.remove(color)
90 | room.color = color
91 | if room.locked:
92 | self.grid.set(*room.doorPos, Door(color, is_locked=True))
93 | else:
94 | self.grid.set(*room.doorPos, Door(color))
95 |
96 | # Select a random room to contain the key
97 | while True:
98 | keyRoom = self._rand_elem(self.rooms)
99 | if keyRoom != lockedRoom:
100 | break
101 | keyPos = keyRoom.rand_pos(self)
102 | self.grid.set(*keyPos, Key(lockedRoom.color))
103 |
104 | # Randomize the player start position and orientation
105 | self.agent_pos = self.place_agent(
106 | top=(lWallIdx, 0),
107 | size=(rWallIdx-lWallIdx, height)
108 | )
109 |
110 | # Generate the mission string
111 | self.mission = (
112 | 'get the %s key from the %s room, '
113 | 'unlock the %s door and '
114 | 'go to the goal'
115 | ) % (lockedRoom.color, keyRoom.color, lockedRoom.color)
116 |
117 | def step(self, action):
118 | obs, reward, done, info = MiniGridEnv.step(self, action)
119 | return obs, reward, done, info
120 |
121 | register(
122 | id='MiniGrid-LockedRoom-v0',
123 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LockedRoom'
124 | )
125 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/memory.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
2 | from onpolicy.envs.gridworld.gym_minigrid.register import register
3 |
4 | class MemoryEnv(MiniGridEnv):
5 | """
6 | This environment is a memory test. The agent starts in a small room
7 | where it sees an object. It then has to go through a narrow hallway
8 | which ends in a split. At each end of the split there is an object,
9 | one of which is the same as the object in the starting room. The
10 | agent has to remember the initial object, and go to the matching
11 |     object at the split.
12 | """
13 |
14 | def __init__(
15 | self,
16 | seed,
17 | size=8,
18 | random_length=False,
19 | ):
20 | self.random_length = random_length
21 | super().__init__(
22 | seed=seed,
23 | grid_size=size,
24 | max_steps=5*size**2,
25 | # Set this to True for maximum speed
26 | see_through_walls=False,
27 | )
28 |
29 | def _gen_grid(self, width, height):
30 | self.grid = Grid(width, height)
31 |
32 | # Generate the surrounding walls
33 | self.grid.horz_wall(0, 0)
34 | self.grid.horz_wall(0, height-1)
35 | self.grid.vert_wall(0, 0)
36 | self.grid.vert_wall(width - 1, 0)
37 |
38 | assert height % 2 == 1
39 | upper_room_wall = height // 2 - 2
40 | lower_room_wall = height // 2 + 2
41 | if self.random_length:
42 | hallway_end = self._rand_int(4, width - 2)
43 | else:
44 | hallway_end = width - 3
45 |
46 | # Start room
47 | for i in range(1, 5):
48 | self.grid.set(i, upper_room_wall, Wall())
49 | self.grid.set(i, lower_room_wall, Wall())
50 | self.grid.set(4, upper_room_wall + 1, Wall())
51 | self.grid.set(4, lower_room_wall - 1, Wall())
52 |
53 | # Horizontal hallway
54 | for i in range(5, hallway_end):
55 | self.grid.set(i, upper_room_wall + 1, Wall())
56 | self.grid.set(i, lower_room_wall - 1, Wall())
57 |
58 | # Vertical hallway
59 | for j in range(0, height):
60 | if j != height // 2:
61 | self.grid.set(hallway_end, j, Wall())
62 | self.grid.set(hallway_end + 2, j, Wall())
63 |
64 | # Fix the player's start position and orientation
65 | self.agent_pos = (self._rand_int(1, hallway_end + 1), height // 2)
66 | self.agent_dir = 0
67 |
68 | # Place objects
69 | start_room_obj = self._rand_elem([Key, Ball])
70 | self.grid.set(1, height // 2 - 1, start_room_obj('green'))
71 |
72 | other_objs = self._rand_elem([[Ball, Key], [Key, Ball]])
73 | pos0 = (hallway_end + 1, height // 2 - 2)
74 | pos1 = (hallway_end + 1, height // 2 + 2)
75 | self.grid.set(*pos0, other_objs[0]('green'))
76 | self.grid.set(*pos1, other_objs[1]('green'))
77 |
78 | # Choose the target objects
79 | if start_room_obj == other_objs[0]:
80 | self.success_pos = (pos0[0], pos0[1] + 1)
81 | self.failure_pos = (pos1[0], pos1[1] - 1)
82 | else:
83 | self.success_pos = (pos1[0], pos1[1] - 1)
84 | self.failure_pos = (pos0[0], pos0[1] + 1)
85 |
86 | self.mission = 'go to the matching object at the end of the hallway'
87 |
88 | def step(self, action):
89 | if action == MiniGridEnv.Actions.pickup:
90 | action = MiniGridEnv.Actions.toggle
91 | obs, reward, done, info = MiniGridEnv.step(self, action)
92 |
93 | if tuple(self.agent_pos) == self.success_pos:
94 | reward = self._reward()
95 | done = True
96 | if tuple(self.agent_pos) == self.failure_pos:
97 | reward = 0
98 | done = True
99 |
100 | return obs, reward, done, info
101 |
102 | class MemoryS17Random(MemoryEnv):
103 | def __init__(self, seed=None):
104 | super().__init__(seed=seed, size=17, random_length=True)
105 |
106 | register(
107 | id='MiniGrid-MemoryS17Random-v0',
108 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:MemoryS17Random',
109 | )
110 |
111 | class MemoryS13Random(MemoryEnv):
112 | def __init__(self, seed=None):
113 | super().__init__(seed=seed, size=13, random_length=True)
114 |
115 | register(
116 | id='MiniGrid-MemoryS13Random-v0',
117 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:MemoryS13Random',
118 | )
119 |
120 | class MemoryS13(MemoryEnv):
121 | def __init__(self, seed=None):
122 | super().__init__(seed=seed, size=13)
123 |
124 | register(
125 | id='MiniGrid-MemoryS13-v0',
126 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:MemoryS13',
127 | )
128 |
129 | class MemoryS11(MemoryEnv):
130 | def __init__(self, seed=None):
131 | super().__init__(seed=seed, size=11)
132 |
133 | register(
134 | id='MiniGrid-MemoryS11-v0',
135 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:MemoryS11',
136 | )
137 |
138 | class MemoryS9(MemoryEnv):
139 | def __init__(self, seed=None):
140 | super().__init__(seed=seed, size=9)
141 |
142 | register(
143 | id='MiniGrid-MemoryS9-v0',
144 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:MemoryS9',
145 | )
146 |
147 | class MemoryS7(MemoryEnv):
148 | def __init__(self, seed=None):
149 | super().__init__(seed=seed, size=7)
150 |
151 | register(
152 | id='MiniGrid-MemoryS7-v0',
153 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:MemoryS7',
154 | )
155 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/obstructedmaze.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
2 | from onpolicy.envs.gridworld.gym_minigrid.roomgrid import RoomGrid
3 | from onpolicy.envs.gridworld.gym_minigrid.register import register
4 |
5 | class ObstructedMazeEnv(RoomGrid):
6 | """
7 | A blue ball is hidden in the maze. Doors may be locked,
8 | doors may be obstructed by a ball and keys may be hidden in boxes.
9 | """
10 |
11 | def __init__(self,
12 | num_rows,
13 | num_cols,
14 | num_rooms_visited,
15 | seed=None
16 | ):
17 | room_size = 6
18 | max_steps = 4*num_rooms_visited*room_size**2
19 |
20 | super().__init__(
21 | room_size=room_size,
22 | num_rows=num_rows,
23 | num_cols=num_cols,
24 | max_steps=max_steps,
25 | seed=seed
26 | )
27 |
28 | def _gen_grid(self, width, height):
29 | super()._gen_grid(width, height)
30 |
31 | # Define all possible colors for doors
32 | self.door_colors = self._rand_subset(COLOR_NAMES, len(COLOR_NAMES))
33 | # Define the color of the ball to pick up
34 | self.ball_to_find_color = COLOR_NAMES[0]
35 | # Define the color of the balls that obstruct doors
36 | self.blocking_ball_color = COLOR_NAMES[1]
37 | # Define the color of boxes in which keys are hidden
38 | self.box_color = COLOR_NAMES[2]
39 |
40 | self.mission = "pick up the %s ball" % self.ball_to_find_color
41 |
42 | def step(self, action):
43 | obs, reward, done, info = super().step(action)
44 |
45 | if action == self.actions.pickup:
46 | if self.carrying and self.carrying == self.obj:
47 | reward = self._reward()
48 | done = True
49 |
50 | return obs, reward, done, info
51 |
52 | def add_door(self, i, j, door_idx=0, color=None, locked=False, key_in_box=False, blocked=False):
53 | """
54 | Add a door. If the door must be locked, it also adds the key.
55 | If the key must be hidden, it is put in a box. If the door must
56 | be obstructed, it adds a ball in front of the door.
57 | """
58 |
59 | door, door_pos = super().add_door(i, j, door_idx, color, locked=locked)
60 |
61 | if blocked:
62 | vec = DIR_TO_VEC[door_idx]
63 | blocking_ball = Ball(self.blocking_ball_color) if blocked else None
64 | self.grid.set(door_pos[0]-vec[0], door_pos[1]-vec[1], blocking_ball)
65 |
66 | if locked:
67 | obj = Key(door.color)
68 | if key_in_box:
69 | box = Box(self.box_color) if key_in_box else None
70 | box.contains = obj
71 | obj = box
72 | self.place_in_room(i, j, obj)
73 |
74 | return door, door_pos
75 |
76 | class ObstructedMaze_1Dlhb(ObstructedMazeEnv):
77 | """
78 | A blue ball is hidden in a 2x1 maze. A locked door separates
79 | rooms. Doors are obstructed by a ball and keys are hidden in boxes.
80 | """
81 |
82 | def __init__(self, key_in_box=True, blocked=True, seed=None):
83 | self.key_in_box = key_in_box
84 | self.blocked = blocked
85 |
86 | super().__init__(
87 | num_rows=1,
88 | num_cols=2,
89 | num_rooms_visited=2,
90 | seed=seed
91 | )
92 |
93 | def _gen_grid(self, width, height):
94 | super()._gen_grid(width, height)
95 |
96 | self.add_door(0, 0, door_idx=0, color=self.door_colors[0],
97 | locked=True,
98 | key_in_box=self.key_in_box,
99 | blocked=self.blocked)
100 |
101 | self.obj, _ = self.add_object(1, 0, "ball", color=self.ball_to_find_color)
102 | self.place_agent(0, 0)
103 |
104 | class ObstructedMaze_1Dl(ObstructedMaze_1Dlhb):
105 | def __init__(self, seed=None):
106 | super().__init__(False, False, seed)
107 |
108 | class ObstructedMaze_1Dlh(ObstructedMaze_1Dlhb):
109 | def __init__(self, seed=None):
110 | super().__init__(True, False, seed)
111 |
112 | class ObstructedMaze_Full(ObstructedMazeEnv):
113 | """
114 | A blue ball is hidden in one of the 4 corners of a 3x3 maze. Doors
115 | are locked, doors are obstructed by a ball and keys are hidden in
116 | boxes.
117 | """
118 |
119 | def __init__(self, agent_room=(1, 1), key_in_box=True, blocked=True,
120 | num_quarters=4, num_rooms_visited=25, seed=None):
121 | self.agent_room = agent_room
122 | self.key_in_box = key_in_box
123 | self.blocked = blocked
124 | self.num_quarters = num_quarters
125 |
126 | super().__init__(
127 | num_rows=3,
128 | num_cols=3,
129 | num_rooms_visited=num_rooms_visited,
130 | seed=seed
131 | )
132 |
133 | def _gen_grid(self, width, height):
134 | super()._gen_grid(width, height)
135 |
136 | middle_room = (1, 1)
137 | # Define positions of "side rooms" i.e. rooms that are neither
138 | # corners nor the center.
139 | side_rooms = [(2, 1), (1, 2), (0, 1), (1, 0)][:self.num_quarters]
140 | for i in range(len(side_rooms)):
141 | side_room = side_rooms[i]
142 |
143 | # Add a door between the center room and the side room
144 | self.add_door(*middle_room, door_idx=i, color=self.door_colors[i], locked=False)
145 |
146 | for k in [-1, 1]:
147 | # Add a door to each side of the side room
148 | self.add_door(*side_room, locked=True,
149 | door_idx=(i+k)%4,
150 | color=self.door_colors[(i+k)%len(self.door_colors)],
151 | key_in_box=self.key_in_box,
152 | blocked=self.blocked)
153 |
154 | corners = [(2, 0), (2, 2), (0, 2), (0, 0)][:self.num_quarters]
155 | ball_room = self._rand_elem(corners)
156 |
157 | self.obj, _ = self.add_object(*ball_room, "ball", color=self.ball_to_find_color)
158 | self.place_agent(*self.agent_room)
159 |
160 | class ObstructedMaze_2Dl(ObstructedMaze_Full):
161 | def __init__(self, seed=None):
162 | super().__init__((2, 1), False, False, 1, 4, seed)
163 |
164 | class ObstructedMaze_2Dlh(ObstructedMaze_Full):
165 | def __init__(self, seed=None):
166 | super().__init__((2, 1), True, False, 1, 4, seed)
167 |
168 |
169 | class ObstructedMaze_2Dlhb(ObstructedMaze_Full):
170 | def __init__(self, seed=None):
171 | super().__init__((2, 1), True, True, 1, 4, seed)
172 |
173 | class ObstructedMaze_1Q(ObstructedMaze_Full):
174 | def __init__(self, seed=None):
175 | super().__init__((1, 1), True, True, 1, 5, seed)
176 |
177 | class ObstructedMaze_2Q(ObstructedMaze_Full):
178 | def __init__(self, seed=None):
179 | super().__init__((1, 1), True, True, 2, 11, seed)
180 |
181 | register(
182 | id="MiniGrid-ObstructedMaze-1Dl-v0",
183 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_1Dl"
184 | )
185 |
186 | register(
187 | id="MiniGrid-ObstructedMaze-1Dlh-v0",
188 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_1Dlh"
189 | )
190 |
191 | register(
192 | id="MiniGrid-ObstructedMaze-1Dlhb-v0",
193 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_1Dlhb"
194 | )
195 |
196 | register(
197 | id="MiniGrid-ObstructedMaze-2Dl-v0",
198 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_2Dl"
199 | )
200 |
201 | register(
202 | id="MiniGrid-ObstructedMaze-2Dlh-v0",
203 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_2Dlh"
204 | )
205 |
206 | register(
207 | id="MiniGrid-ObstructedMaze-2Dlhb-v0",
208 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_2Dlhb"
209 | )
210 |
211 | register(
212 | id="MiniGrid-ObstructedMaze-1Q-v0",
213 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_1Q"
214 | )
215 |
216 | register(
217 | id="MiniGrid-ObstructedMaze-2Q-v0",
218 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_2Q"
219 | )
220 |
221 | register(
222 | id="MiniGrid-ObstructedMaze-Full-v0",
223 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_Full"
224 | )
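
For orientation, the registered variants above differ only in the arguments they pass to
ObstructedMaze_1Dlhb / ObstructedMaze_Full; the illustrative dictionary below (not part of
the module) summarizes them:

    # name: (agent_room, key_in_box, blocked, num_quarters, num_rooms_visited)
    OBSTRUCTED_MAZE_FULL_VARIANTS = {
        "2Dl":   ((2, 1), False, False, 1, 4),
        "2Dlh":  ((2, 1), True,  False, 1, 4),
        "2Dlhb": ((2, 1), True,  True,  1, 4),
        "1Q":    ((1, 1), True,  True,  1, 5),
        "2Q":    ((1, 1), True,  True,  2, 11),
        "Full":  ((1, 1), True,  True,  4, 25),   # constructor defaults
    }
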
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/playground_v0.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
2 | from onpolicy.envs.gridworld.gym_minigrid.register import register
3 |
4 | class PlaygroundV0(MiniGridEnv):
5 | """
6 | Environment with multiple rooms and random objects.
7 | This environment has no specific goals or rewards.
8 | """
9 |
10 | def __init__(self):
11 | super().__init__(grid_size=19, max_steps=100)
12 |
13 | def _gen_grid(self, width, height):
14 | # Create the grid
15 | self.grid = Grid(width, height)
16 |
17 | # Generate the surrounding walls
18 | self.grid.horz_wall(0, 0)
19 | self.grid.horz_wall(0, height-1)
20 | self.grid.vert_wall(0, 0)
21 | self.grid.vert_wall(width-1, 0)
22 |
23 | roomW = width // 3
24 | roomH = height // 3
25 |
26 | # For each row of rooms
27 | for j in range(0, 3):
28 |
29 | # For each column
30 | for i in range(0, 3):
31 | xL = i * roomW
32 | yT = j * roomH
33 | xR = xL + roomW
34 | yB = yT + roomH
35 |
36 |                 # Right wall and door
37 | if i+1 < 3:
38 | self.grid.vert_wall(xR, yT, roomH)
39 | pos = (xR, self._rand_int(yT+1, yB-1))
40 | color = self._rand_elem(COLOR_NAMES)
41 | self.grid.set(*pos, Door(color))
42 |
43 | # Bottom wall and door
44 | if j+1 < 3:
45 | self.grid.horz_wall(xL, yB, roomW)
46 | pos = (self._rand_int(xL+1, xR-1), yB)
47 | color = self._rand_elem(COLOR_NAMES)
48 | self.grid.set(*pos, Door(color))
49 |
50 | # Randomize the player start position and orientation
51 | self.place_agent()
52 |
53 | # Place random objects in the world
54 | types = ['key', 'ball', 'box']
55 | for i in range(0, 12):
56 | objType = self._rand_elem(types)
57 | objColor = self._rand_elem(COLOR_NAMES)
58 | if objType == 'key':
59 | obj = Key(objColor)
60 | elif objType == 'ball':
61 | obj = Ball(objColor)
62 | elif objType == 'box':
63 | obj = Box(objColor)
64 | self.place_obj(obj)
65 |
66 | # No explicit mission in this environment
67 | self.mission = ''
68 |
69 | def step(self, action):
70 | obs, reward, done, info = MiniGridEnv.step(self, action)
71 | return obs, reward, done, info
72 |
73 | register(
74 | id='MiniGrid-Playground-v0',
75 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:PlaygroundV0'
76 | )
77 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/putnear.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
2 | from onpolicy.envs.gridworld.gym_minigrid.register import register
3 |
4 | class PutNearEnv(MiniGridEnv):
5 | """
6 | Environment in which the agent is instructed to place an object near
7 | another object through a natural language string.
8 | """
9 |
10 | def __init__(
11 | self,
12 | size=6,
13 | numObjs=2
14 | ):
15 | self.numObjs = numObjs
16 |
17 | super().__init__(
18 | grid_size=size,
19 | max_steps=5*size,
20 | # Set this to True for maximum speed
21 | see_through_walls=True
22 | )
23 |
24 | def _gen_grid(self, width, height):
25 | self.grid = Grid(width, height)
26 |
27 | # Generate the surrounding walls
28 | self.grid.horz_wall(0, 0)
29 | self.grid.horz_wall(0, height-1)
30 | self.grid.vert_wall(0, 0)
31 | self.grid.vert_wall(width-1, 0)
32 |
33 | # Types and colors of objects we can generate
34 | types = ['key', 'ball', 'box']
35 |
36 | objs = []
37 | objPos = []
38 |
39 | def near_obj(env, p1):
40 | for p2 in objPos:
41 | dx = p1[0] - p2[0]
42 | dy = p1[1] - p2[1]
43 | if abs(dx) <= 1 and abs(dy) <= 1:
44 | return True
45 | return False
46 |
47 | # Until we have generated all the objects
48 | while len(objs) < self.numObjs:
49 | objType = self._rand_elem(types)
50 | objColor = self._rand_elem(COLOR_NAMES)
51 |
52 | # If this object already exists, try again
53 | if (objType, objColor) in objs:
54 | continue
55 |
56 | if objType == 'key':
57 | obj = Key(objColor)
58 | elif objType == 'ball':
59 | obj = Ball(objColor)
60 | elif objType == 'box':
61 | obj = Box(objColor)
62 |
63 | pos = self.place_obj(obj, reject_fn=near_obj)
64 |
65 | objs.append((objType, objColor))
66 | objPos.append(pos)
67 |
68 | # Randomize the agent start position and orientation
69 | self.place_agent()
70 |
71 | # Choose a random object to be moved
72 | objIdx = self._rand_int(0, len(objs))
73 | self.move_type, self.moveColor = objs[objIdx]
74 | self.move_pos = objPos[objIdx]
75 |
76 | # Choose a target object (to put the first object next to)
77 | while True:
78 | targetIdx = self._rand_int(0, len(objs))
79 | if targetIdx != objIdx:
80 | break
81 | self.target_type, self.target_color = objs[targetIdx]
82 | self.target_pos = objPos[targetIdx]
83 |
84 | self.mission = 'put the %s %s near the %s %s' % (
85 | self.moveColor,
86 | self.move_type,
87 | self.target_color,
88 | self.target_type
89 | )
90 |
91 | def step(self, action):
92 | preCarrying = self.carrying
93 |
94 | obs, reward, done, info = super().step(action)
95 |
96 | u, v = self.dir_vec
97 | ox, oy = (self.agent_pos[0] + u, self.agent_pos[1] + v)
98 | tx, ty = self.target_pos
99 |
100 | # If we picked up the wrong object, terminate the episode
101 | if action == self.actions.pickup and self.carrying:
102 | if self.carrying.type != self.move_type or self.carrying.color != self.moveColor:
103 | done = True
104 |
105 | # If successfully dropping an object near the target
106 | if action == self.actions.drop and preCarrying:
107 | if self.grid.get(ox, oy) is preCarrying:
108 | if abs(ox - tx) <= 1 and abs(oy - ty) <= 1:
109 | reward = self._reward()
110 | done = True
111 |
112 | return obs, reward, done, info
113 |
114 | class PutNear8x8N3(PutNearEnv):
115 | def __init__(self):
116 | super().__init__(size=8, numObjs=3)
117 |
118 | register(
119 | id='MiniGrid-PutNear-6x6-N2-v0',
120 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:PutNearEnv'
121 | )
122 |
123 | register(
124 | id='MiniGrid-PutNear-8x8-N3-v0',
125 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:PutNear8x8N3'
126 | )
127 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/redbluedoors.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import *
2 | from onpolicy.envs.gridworld.gym_minigrid.register import register
3 |
4 | class RedBlueDoorEnv(MiniGridEnv):
5 | """
6 | Single room with red and blue doors on opposite sides.
7 | The red door must be opened before the blue door to
8 | obtain a reward.
9 | """
10 |
11 | def __init__(self, size=8):
12 | self.size = size
13 |
14 | super().__init__(
15 | width=2*size,
16 | height=size,
17 | max_steps=20*size*size
18 | )
19 |
20 | def _gen_grid(self, width, height):
21 | # Create an empty grid
22 | self.grid = Grid(width, height)
23 |
24 | # Generate the grid walls
25 | self.grid.wall_rect(0, 0, 2*self.size, self.size)
26 | self.grid.wall_rect(self.size//2, 0, self.size, self.size)
27 |
28 | # Place the agent in the top-left corner
29 | self.place_agent(top=(self.size//2, 0), size=(self.size, self.size))
30 |
31 | # Add a red door at a random position in the left wall
32 | pos = self._rand_int(1, self.size - 1)
33 | self.red_door = Door("red")
34 | self.grid.set(self.size//2, pos, self.red_door)
35 |
36 | # Add a blue door at a random position in the right wall
37 | pos = self._rand_int(1, self.size - 1)
38 | self.blue_door = Door("blue")
39 | self.grid.set(self.size//2 + self.size - 1, pos, self.blue_door)
40 |
41 | # Generate the mission string
42 | self.mission = "open the red door then the blue door"
43 |
44 | def step(self, action):
45 | red_door_opened_before = self.red_door.is_open
46 | blue_door_opened_before = self.blue_door.is_open
47 |
48 | obs, reward, done, info = MiniGridEnv.step(self, action)
49 |
50 | red_door_opened_after = self.red_door.is_open
51 | blue_door_opened_after = self.blue_door.is_open
52 |
53 | if blue_door_opened_after:
54 | if red_door_opened_before:
55 | reward = self._reward()
56 | done = True
57 | else:
58 | reward = 0
59 | done = True
60 |
61 | elif red_door_opened_after:
62 | if blue_door_opened_before:
63 | reward = 0
64 | done = True
65 |
66 | return obs, reward, done, info
67 |
68 | class RedBlueDoorEnv6x6(RedBlueDoorEnv):
69 | def __init__(self):
70 | super().__init__(size=6)
71 |
72 | register(
73 | id='MiniGrid-RedBlueDoors-6x6-v0',
74 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:RedBlueDoorEnv6x6'
75 | )
76 |
77 | register(
78 | id='MiniGrid-RedBlueDoors-8x8-v0',
79 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:RedBlueDoorEnv'
80 | )
81 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/unlock.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import Ball
2 | from onpolicy.envs.gridworld.gym_minigrid.roomgrid import RoomGrid
3 | from onpolicy.envs.gridworld.gym_minigrid.register import register
4 |
5 | class Unlock(RoomGrid):
6 | """
7 | Unlock a door
8 | """
9 |
10 | def __init__(self, seed=None):
11 | room_size = 6
12 | super().__init__(
13 | num_rows=1,
14 | num_cols=2,
15 | room_size=room_size,
16 | max_steps=8*room_size**2,
17 | seed=seed
18 | )
19 |
20 | def _gen_grid(self, width, height):
21 | super()._gen_grid(width, height)
22 |
23 | # Make sure the two rooms are directly connected by a locked door
24 | door, _ = self.add_door(0, 0, 0, locked=True)
25 | # Add a key to unlock the door
26 | self.add_object(0, 0, 'key', door.color)
27 |
28 | self.place_agent(0, 0)
29 |
30 | self.door = door
31 | self.mission = "open the door"
32 |
33 | def step(self, action):
34 | obs, reward, done, info = super().step(action)
35 |
36 | if action == self.actions.toggle:
37 | if self.door.is_open:
38 | reward = self._reward()
39 | done = True
40 |
41 | return obs, reward, done, info
42 |
43 | register(
44 | id='MiniGrid-Unlock-v0',
45 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:Unlock'
46 | )
47 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/envs/unlockpickup.py:
--------------------------------------------------------------------------------
1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import Ball
2 | from onpolicy.envs.gridworld.gym_minigrid.roomgrid import RoomGrid
3 | from onpolicy.envs.gridworld.gym_minigrid.register import register
4 |
5 | class UnlockPickup(RoomGrid):
6 | """
7 | Unlock a door, then pick up a box in another room
8 | """
9 |
10 | def __init__(self, seed=None):
11 | room_size = 6
12 | super().__init__(
13 | num_rows=1,
14 | num_cols=2,
15 | room_size=room_size,
16 | max_steps=8*room_size**2,
17 | seed=seed
18 | )
19 |
20 | def _gen_grid(self, width, height):
21 | super()._gen_grid(width, height)
22 |
23 | # Add a box to the room on the right
24 | obj, _ = self.add_object(1, 0, kind="box")
25 | # Make sure the two rooms are directly connected by a locked door
26 | door, _ = self.add_door(0, 0, 0, locked=True)
27 | # Add a key to unlock the door
28 | self.add_object(0, 0, 'key', door.color)
29 |
30 | self.place_agent(0, 0)
31 |
32 | self.obj = obj
33 | self.mission = "pick up the %s %s" % (obj.color, obj.type)
34 |
35 | def step(self, action):
36 | obs, reward, done, info = super().step(action)
37 |
38 | if action == self.actions.pickup:
39 | if self.carrying and self.carrying == self.obj:
40 | reward = self._reward()
41 | done = True
42 |
43 | return obs, reward, done, info
44 |
45 | register(
46 | id='MiniGrid-UnlockPickup-v0',
47 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:UnlockPickup'
48 | )
49 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/register.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register as gym_register
2 |
3 | env_list = []
4 |
5 | def register(
6 | id,
7 | grid_size,
8 | max_steps,
9 | local_step_num,
10 | agent_view_size,
11 | num_obstacles,
12 | num_agents,
13 | agent_pos,
14 | entry_point,
15 | reward_threshold=0.95,
16 | use_merge = True,
17 | use_merge_plan = True,
18 | use_constrict_map = True,
19 | use_fc_net = False,
20 | use_agent_id = False,
21 | use_stack = False,
22 | use_orientation = False,
23 | use_same_location = True,
24 | use_complete_reward = True,
25 | use_agent_obstacle = False,
26 | use_multiroom = False,
27 | use_irregular_room = False,
28 | use_time_penalty = False,
29 | use_overlap_penalty = False,
30 | astar_cost_mode = 'normal'
31 | ):
32 | assert id.startswith("MiniGrid-")
33 | assert id not in env_list
34 |
35 | # Register the environment with OpenAI gym
36 | gym_register(
37 | id=id,
38 | entry_point=entry_point,
39 | kwargs={
40 | 'grid_size': grid_size,
41 | 'max_steps': max_steps,
42 | 'local_step_num': local_step_num,
43 | 'agent_view_size': agent_view_size,
44 | 'num_obstacles': num_obstacles,
45 | 'num_agents': num_agents,
46 | 'agent_pos': agent_pos,
47 | 'use_merge': use_merge,
48 | 'use_merge_plan': use_merge_plan,
49 | 'use_constrict_map': use_constrict_map,
50 | 'use_fc_net':use_fc_net,
51 | 'use_agent_id':use_agent_id,
52 | 'use_stack':use_stack,
53 | 'use_orientation':use_orientation,
54 | 'use_same_location': use_same_location,
55 | 'use_complete_reward': use_complete_reward,
56 | 'use_agent_obstacle': use_agent_obstacle,
57 | 'use_multiroom': use_multiroom,
58 | 'use_irregular_room': use_irregular_room,
59 | 'use_time_penalty': use_time_penalty,
60 | 'use_overlap_penalty': use_overlap_penalty,
61 | 'astar_cost_mode': astar_cost_mode
62 | },
63 | reward_threshold=reward_threshold
64 | )
65 |
66 | # Add the environment to the list of registered ids
67 | env_list.append(id)
68 |
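For reference, the signature above means every environment registered through this helper has to supply the full set of exploration-specific kwargs. A hypothetical call with illustrative values (the real one lives in gym_minigrid/envs/__init__.py, which is not reproduced here, and the entry-point class name is an assumption):

    from onpolicy.envs.gridworld.gym_minigrid.register import register

    # Illustrative values only; grid_size, num_agents, etc. normally come from the
    # training configuration, and 'MultiExplorationEnv' is an assumed class name.
    register(
        id='MiniGrid-MultiExploration-v0',
        grid_size=25,
        max_steps=200,
        local_step_num=1,
        agent_view_size=7,
        num_obstacles=0,
        num_agents=2,
        agent_pos=None,
        entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:MultiExplorationEnv',
    )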
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/rendering.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | def downsample(img, factor):
5 | """
6 | Downsample an image along both dimensions by some factor
7 | """
8 |
9 | assert img.shape[0] % factor == 0
10 | assert img.shape[1] % factor == 0
11 |
12 | img = img.reshape([img.shape[0]//factor, factor, img.shape[1]//factor, factor, 3])
13 | img = img.mean(axis=3)
14 | img = img.mean(axis=1)
15 |
16 | return img
17 |
18 | def fill_coords(img, fn, color):
19 | """
20 | Fill pixels of an image with coordinates matching a filter function
21 | """
22 |
23 | for y in range(img.shape[0]):
24 | for x in range(img.shape[1]):
25 | yf = (y + 0.5) / img.shape[0]
26 | xf = (x + 0.5) / img.shape[1]
27 | if fn(xf, yf):
28 | img[y, x] = color
29 |
30 | return img
31 |
32 | def rotate_fn(fin, cx, cy, theta):
33 | def fout(x, y):
34 | x = x - cx
35 | y = y - cy
36 |
37 | x2 = cx + x * math.cos(-theta) - y * math.sin(-theta)
38 | y2 = cy + y * math.cos(-theta) + x * math.sin(-theta)
39 |
40 | return fin(x2, y2)
41 |
42 | return fout
43 |
44 | def point_in_line(x0, y0, x1, y1, r): # thick line segment with rounded end caps
45 | p0 = np.array([x0, y0])
46 | p1 = np.array([x1, y1])
47 | dir = p1 - p0
48 | dist = np.linalg.norm(dir)
49 | dir = dir / dist
50 |
51 | xmin = min(x0, x1) - r
52 | xmax = max(x0, x1) + r
53 | ymin = min(y0, y1) - r
54 | ymax = max(y0, y1) + r
55 |
56 | def fn(x, y):
57 | # Fast, early escape test
58 | if x < xmin or x > xmax or y < ymin or y > ymax:
59 | return False
60 |
61 | q = np.array([x, y])
62 | pq = q - p0
63 |
64 | # Closest point on line
65 | a = np.dot(pq, dir)
66 | a = np.clip(a, 0, dist)
67 | p = p0 + a * dir
68 |
69 | dist_to_line = np.linalg.norm(q - p)
70 | return dist_to_line <= r
71 |
72 | return fn
73 |
74 | def point_in_circle(cx, cy, r):
75 | def fn(x, y):
76 | return (x-cx)*(x-cx) + (y-cy)*(y-cy) <= r * r
77 | return fn
78 |
79 | def point_in_rect(xmin, xmax, ymin, ymax):
80 | def fn(x, y):
81 | return x >= xmin and x <= xmax and y >= ymin and y <= ymax
82 | return fn
83 |
84 | def point_in_triangle(a, b, c):
85 | a = np.array(a)
86 | b = np.array(b)
87 | c = np.array(c)
88 |
89 | def fn(x, y):
90 | v0 = c - a
91 | v1 = b - a
92 | v2 = np.array((x, y)) - a
93 |
94 | # Compute dot products
95 | dot00 = np.dot(v0, v0)
96 | dot01 = np.dot(v0, v1)
97 | dot02 = np.dot(v0, v2)
98 | dot11 = np.dot(v1, v1)
99 | dot12 = np.dot(v1, v2)
100 |
101 | # Compute barycentric coordinates
102 | inv_denom = 1 / (dot00 * dot11 - dot01 * dot01)
103 | u = (dot11 * dot02 - dot01 * dot12) * inv_denom
104 | v = (dot00 * dot12 - dot01 * dot02) * inv_denom
105 |
106 | # Check if point is in triangle
107 | return (u >= 0) and (v >= 0) and (u + v) < 1
108 |
109 | return fn
110 |
111 | def highlight_img(img, color=(255, 255, 255), alpha=0.30):
112 | """
113 | Add highlighting to an image
114 | """
115 |
116 | blend_img = img + alpha * (np.array(color, dtype=np.uint8) - img)
117 | blend_img = blend_img.clip(0, 255).astype(np.uint8)
118 | img[:, :, :] = blend_img
119 |
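These primitives are predicate-based rasterizers: fill_coords paints every pixel whose normalized (x, y) coordinates satisfy the predicate. A small sketch of how they compose into a tile, similar to how grid objects are drawn (illustrative only; the import path is assumed):

    import numpy as np
    from onpolicy.envs.gridworld.gym_minigrid.rendering import (
        fill_coords, point_in_rect, point_in_circle, point_in_triangle, rotate_fn, downsample)

    tile = np.zeros((96, 96, 3), dtype=np.uint8)
    fill_coords(tile, point_in_rect(0.0, 1.0, 0.0, 1.0), (100, 100, 100))  # grey floor
    fill_coords(tile, point_in_circle(0.5, 0.5, 0.31), (255, 0, 0))        # red ball
    agent = point_in_triangle((0.12, 0.19), (0.87, 0.50), (0.12, 0.81))
    fill_coords(tile, rotate_fn(agent, cx=0.5, cy=0.5, theta=0.5 * np.pi), (0, 0, 255))
    tile = downsample(tile, 3)  # render large, then average down for smoother edges (32x32)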
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/gym_minigrid/window.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import numpy as np
3 |
4 | # Only ask users to install matplotlib if they actually need it
5 | try:
6 | import matplotlib.pyplot as plt
7 | except:
8 | print('To display the environment in a window, please install matplotlib, eg:')
9 | print('pip3 install --user matplotlib')
10 | sys.exit(-1)
11 |
12 | class Window:
13 | """
14 | Window to draw a gridworld instance using Matplotlib
15 | """
16 |
17 | def __init__(self, title):
18 | self.fig = None
19 |
20 | self.imshow_obj = None
21 | self.local_imshow_obj = None
22 |
23 | # Create the figure and axes
24 | self.fig, self.ax = plt.subplots(1,2)
25 |
26 | # Show the env name in the window title
27 | self.fig.canvas.set_window_title(title)
28 |
29 | # Turn off x/y axis numbering/ticks
30 | for ax in self.ax:
31 | ax.xaxis.set_ticks_position('none')
32 | ax.yaxis.set_ticks_position('none')
33 | _ = ax.set_xticklabels([])
34 | _ = ax.set_yticklabels([])
35 |
36 | # Flag indicating the window was closed
37 | self.closed = False
38 |
39 | def close_handler(evt):
40 | self.closed = True
41 |
42 | self.fig.canvas.mpl_connect('close_event', close_handler)
43 |
44 | def show_img(self, img, local_img):
45 | """
46 | Show an image or update the image being shown
47 | """
48 |
49 | # Show the first image of the environment
50 | if self.imshow_obj is None:
51 | self.imshow_obj = self.ax[0].imshow(img, interpolation='bilinear')
52 | if self.local_imshow_obj is None:
53 | self.local_imshow_obj = self.ax[1].imshow(local_img, interpolation='bilinear')
54 |
55 | self.imshow_obj.set_data(img)
56 | self.local_imshow_obj.set_data(local_img)
57 |
58 | self.fig.canvas.draw()
59 |
60 | # Let matplotlib process UI events
61 | # This is needed for interactive mode to work properly
62 | plt.pause(0.001)
63 |
64 | def set_caption(self, text):
65 | """
66 | Set/update the caption text below the image
67 | """
68 |
69 | plt.xlabel(text)
70 |
71 | def reg_key_handler(self, key_handler):
72 | """
73 | Register a keyboard event handler
74 | """
75 |
76 | # Keyboard handler
77 | self.fig.canvas.mpl_connect('key_press_event', key_handler)
78 |
79 | def show(self, block=True):
80 | """
81 | Show the window, and start an event loop
82 | """
83 |
84 | # If not blocking, trigger interactive mode
85 | if not block:
86 | plt.ion()
87 |
88 | # Show the plot
89 | # In non-interactive mode, this enters the matplotlib event loop
90 | # In interactive mode, this call does not block
91 | plt.show()
92 |
93 | def close(self):
94 | """
95 | Close the window
96 | """
97 |
98 | plt.close()
99 | self.closed = True
100 |
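A usage sketch (assumes a working matplotlib backend with a display). Note that this fork's Window draws two panels, so show_img() takes both a global and a local frame:

    import numpy as np
    from onpolicy.envs.gridworld.gym_minigrid.window import Window  # assumed import path

    window = Window('gym_minigrid demo')
    global_frame = np.random.randint(0, 255, (128, 128, 3), dtype=np.uint8)
    local_frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
    window.set_caption('random frames')
    window.show_img(global_frame, local_frame)  # left axis: global view, right axis: local view
    window.show(block=True)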
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/manual_control.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import time
4 | import argparse
5 | import numpy as np
6 | import gym
7 | import gym_minigrid
8 | from gym_minigrid.wrappers import *
9 | from gym_minigrid.window import Window
10 |
11 | def redraw(img):
12 | if not args.agent_view:
13 | img = env.render('rgb_array', tile_size=args.tile_size)
14 |
15 | window.show_img(img, img)  # this fork's Window draws a global and a local panel
16 |
17 | def reset():
18 | if args.seed != -1:
19 | env.seed(args.seed)
20 |
21 | obs = env.reset()
22 |
23 | if hasattr(env, 'mission'):
24 | print('Mission: %s' % env.mission)
25 | window.set_caption(env.mission)
26 |
27 | redraw(obs)
28 |
29 | def step(action):
30 | obs, reward, done, info = env.step(action)
31 | print('step=%s, reward=%.2f' % (env.step_count, reward))
32 |
33 | if done:
34 | print('done!')
35 | reset()
36 | else:
37 | redraw(obs)
38 |
39 | def key_handler(event):
40 | print('pressed', event.key)
41 |
42 | if event.key == 'escape':
43 | window.close()
44 | return
45 |
46 | if event.key == 'backspace':
47 | reset()
48 | return
49 |
50 | if event.key == 'left':
51 | step(env.actions.left)
52 | return
53 | if event.key == 'right':
54 | step(env.actions.right)
55 | return
56 | if event.key == 'up':
57 | step(env.actions.forward)
58 | return
59 |
60 | # Spacebar
61 | if event.key == ' ':
62 | step(env.actions.toggle)
63 | return
64 | if event.key == 'pageup':
65 | step(env.actions.pickup)
66 | return
67 | if event.key == 'pagedown':
68 | step(env.actions.drop)
69 | return
70 |
71 | if event.key == 'enter':
72 | step(env.actions.done)
73 | return
74 |
75 | parser = argparse.ArgumentParser()
76 | parser.add_argument(
77 | "--env",
78 | help="gym environment to load",
79 | default='MiniGrid-MultiRoom-N6-v0'
80 | )
81 | parser.add_argument(
82 | "--seed",
83 | type=int,
84 | help="random seed to generate the environment with",
85 | default=-1
86 | )
87 | parser.add_argument(
88 | "--tile_size",
89 | type=int,
90 | help="size at which to render tiles",
91 | default=32
92 | )
93 | parser.add_argument(
94 | '--agent_view',
95 | default=False,
96 | help="draw the agent sees (partially observable view)",
97 | action='store_true'
98 | )
99 |
100 | args = parser.parse_args()
101 |
102 | env = gym.make(args.env)
103 |
104 | if args.agent_view:
105 | env = RGBImgPartialObsWrapper(env)
106 | env = ImgObsWrapper(env)
107 |
108 | window = Window('gym_minigrid - ' + args.env)
109 | window.reg_key_handler(key_handler)
110 |
111 | reset()
112 |
113 | # Blocking event loop
114 | window.show(block=True)
115 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/run_tests.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import random
4 | import numpy as np
5 | import gym
6 | from gym_minigrid.register import env_list
7 | from gym_minigrid.minigrid import Grid, OBJECT_TO_IDX
8 |
9 | # Test importing a specific environment directly
10 | from gym_minigrid.envs import DoorKeyEnv
11 |
12 | # Test importing wrappers
13 | from gym_minigrid.wrappers import *
14 |
15 | ##############################################################################
16 |
17 | print('%d environments registered' % len(env_list))
18 |
19 | for env_idx, env_name in enumerate(env_list):
20 | print('testing {} ({}/{})'.format(env_name, env_idx+1, len(env_list)))
21 |
22 | # Load the gym environment
23 | env = gym.make(env_name)
24 | env.max_steps = min(env.max_steps, 200)
25 | env.reset()
26 | env.render('rgb_array')
27 |
28 | # Verify that the same seed always produces the same environment
29 | for i in range(0, 5):
30 | seed = 1337 + i
31 | env.seed(seed); env.reset()  # reset so the grid is regenerated from this seed
32 | grid1 = env.grid
33 | env.seed(seed); env.reset()
34 | grid2 = env.grid
35 | assert grid1 == grid2
36 |
37 | env.reset()
38 |
39 | # Run for a few episodes
40 | num_episodes = 0
41 | while num_episodes < 5:
42 | # Pick a random action
43 | action = random.randint(0, env.action_space.n - 1)
44 |
45 | obs, reward, done, info = env.step(action)
46 |
47 | # Validate the agent position
48 | assert env.agent_pos[0] < env.width
49 | assert env.agent_pos[1] < env.height
50 |
51 | # Test observation encode/decode roundtrip
52 | img = obs['image']
53 | grid, vis_mask = Grid.decode(img)
54 | img2 = grid.encode(vis_mask=vis_mask)
55 | assert np.array_equal(img, img2)
56 |
57 | # Test the env to string function
58 | str(env)
59 |
60 | # Check that the reward is within the specified range
61 | assert reward >= env.reward_range[0], reward
62 | assert reward <= env.reward_range[1], reward
63 |
64 | if done:
65 | num_episodes += 1
66 | env.reset()
67 |
68 | env.render('rgb_array')
69 |
70 | # Test the close method
71 | env.close()
72 |
73 | env = gym.make(env_name)
74 | env = ReseedWrapper(env)
75 | for _ in range(10):
76 | env.reset()
77 | env.step(0)
78 | env.close()
79 |
80 | env = gym.make(env_name)
81 | env = ImgObsWrapper(env)
82 | env.reset()
83 | env.step(0)
84 | env.close()
85 |
86 | # Test the fully observable wrapper
87 | env = gym.make(env_name)
88 | env = FullyObsWrapper(env)
89 | env.reset()
90 | obs, _, _, _ = env.step(0)
91 | assert obs['image'].shape == env.observation_space.spaces['image'].shape
92 | env.close()
93 |
94 | # RGB image observation wrapper
95 | env = gym.make(env_name)
96 | env = RGBImgPartialObsWrapper(env)
97 | env.reset()
98 | obs, _, _, _ = env.step(0)
99 | assert obs['image'].mean() > 0
100 | env.close()
101 |
102 | env = gym.make(env_name)
103 | env = FlatObsWrapper(env)
104 | env.reset()
105 | env.step(0)
106 | env.close()
107 |
108 | env = gym.make(env_name)
109 | env = ViewSizeWrapper(env, 5)
110 | env.reset()
111 | env.step(0)
112 | env.close()
113 |
114 | # Test the wrappers return proper observation spaces.
115 | wrappers = [
116 | RGBImgObsWrapper,
117 | RGBImgPartialObsWrapper,
118 | OneHotPartialObsWrapper
119 | ]
120 | for wrapper in wrappers:
121 | env = wrapper(gym.make(env_name))
122 | obs_space, wrapper_name = env.observation_space, wrapper.__name__
123 | assert isinstance(
124 | obs_space, spaces.Dict
125 | ), "Observation space for {0} is not a Dict: {1}.".format(
126 | wrapper_name, obs_space
127 | )
128 | # This should not fail either
129 | ImgObsWrapper(env)
130 | env.reset()
131 | env.step(0)
132 | env.close()
133 |
134 | ##############################################################################
135 |
136 | print('testing agent_sees method')
137 | env = gym.make('MiniGrid-DoorKey-6x6-v0')
138 | goal_pos = (env.grid.width - 2, env.grid.height - 2)
139 |
140 | # Test the "in" operator on grid objects
141 | assert ('green', 'goal') in env.grid
142 | assert ('blue', 'key') not in env.grid
143 |
144 | # Test the env.agent_sees() function
145 | env.reset()
146 | for i in range(0, 500):
147 | action = random.randint(0, env.action_space.n - 1)
148 | obs, reward, done, info = env.step(action)
149 |
150 | grid, _ = Grid.decode(obs['image'])
151 | goal_visible = ('green', 'goal') in grid
152 |
153 | agent_sees_goal = env.agent_sees(*goal_pos)
154 | assert agent_sees_goal == goal_visible
155 | if done:
156 | env.reset()
157 |
--------------------------------------------------------------------------------
/onpolicy/envs/gridworld/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name='gym_minigrid',
5 | version='1.0.2',
6 | keywords='memory, environment, agent, rl, openaigym, openai-gym, gym',
7 | url='https://github.com/maximecb/gym-minigrid',
8 | description='Minimalistic gridworld package for OpenAI Gym',
9 | packages=['gym_minigrid', 'gym_minigrid.envs'],
10 | install_requires=[
11 | 'gym>=0.9.6',
12 | 'numpy>=1.15.0'
13 | ]
14 | )
15 |
--------------------------------------------------------------------------------
/onpolicy/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/scripts/__init__.py
--------------------------------------------------------------------------------
/onpolicy/scripts/render/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/scripts/render/__init__.py
--------------------------------------------------------------------------------
/onpolicy/scripts/render_gridworld.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | env="GridWorld"
3 | scenario="MiniGrid-MultiExploration-v0"
4 | num_agents=2
5 | num_obstacles=0
6 | algo="mappo"
7 | exp="async_global_new_attn_para_single_to_async(1-5)"
8 | seed_max=3
9 |
10 | echo "env is ${env}"
11 | for seed in `seq ${seed_max}`
12 | do
13 | CUDA_VISIBLE_DEVICES=2 python render/render_gridworld.py\
14 | --env_name ${env} --algorithm_name ${algo} --experiment_name ${exp} --scenario_name ${scenario} \
15 | --num_agents ${num_agents} --num_obstacles ${num_obstacles} \
16 | --seed ${seed} --n_training_threads 1 --n_rollout_threads 1 --render_episodes 100 \
17 | --cnn_layers_params '16,7,2,1 32,5,2,1 16,3,1,1' \
18 | --model_dir "./results/GridWorld/MiniGrid-MultiExploration-v0/mappo/async_global_new_attn_para_no_agent_id_overlap_penalty_single_normal50/wandb/run-20220224_140614-3vhpd2ge/files/" \
19 | --max_steps 200 --use_complete_reward --agent_view_size 7 --local_step_num 1 --use_random_pos \
20 | --astar_cost_mode utility --grid_goal --goal_grid_size 5 --cnn_trans_layer 1,3,1,1 \
21 | --use_stack --grid_size 25 --use_recurrent_policy --use_global_goal --use_overlap_penalty --use_eval --wandb_name "mapping" --user_name "yang-xy20" --asynch &
22 | done
23 |
--------------------------------------------------------------------------------
/onpolicy/scripts/render_gridworld_ft.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | env="GridWorld"
3 | scenario="MiniGrid-MultiExploration-v0"
4 | num_agents=1
5 | grid_size=25
6 | num_obstacles=0
7 | local_step_num=1
8 | seed_max=3
9 | algo='ft_rrt'
10 |
11 | echo "env is ${env}"
12 | for seed in `seq ${seed_max}`
13 | do
14 | echo "seed is ${seed}"
15 | exp=new_async_${algo}_grid${grid_size}_stepgoal_${local_step_num}_merge_normal
16 | CUDA_VISIBLE_DEVICES=3 python render/render_gridworld_ft.py\
17 | --env_name ${env} --algorithm_name ${algo} --experiment_name ${exp} --scenario_name ${scenario} \
18 | --num_agents ${num_agents} --num_obstacles ${num_obstacles} \
19 | --seed ${seed} --n_training_threads 1 --n_rollout_threads 1 --render_episodes 100 \
20 | --cnn_layers_params '16,3,1,1 32,3,1,1 16,3,1,1' \
21 | --ifi 0.5 --max_steps 300 --grid_size ${grid_size} --local_step_num ${local_step_num} --use_random_pos \
22 | --agent_view_size 7 --use_merge --use_merge_plan --use_eval \
23 | --astar_cost_mode "normal" --wandb_name "mapping" --user_name "yang-xy20" --asynch
24 | done
25 |
26 |
--------------------------------------------------------------------------------
/onpolicy/scripts/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/scripts/train/__init__.py
--------------------------------------------------------------------------------
/onpolicy/scripts/train_gridworld.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | env="GridWorld"
3 | scenario="MiniGrid-MultiExploration-v0"
4 | num_agents=3
5 | num_obstacles=0
6 | algo="mappo"
7 | exp="async_global_new_attn_para_disc_single_overlap_normal50"
8 | seed_max=1
9 |
10 | echo "env is ${env}, scenario is ${scenario}, algo is ${algo}, exp is ${exp}, max seed is ${seed_max}"
11 | for seed in `seq ${seed_max}`;
12 | do
13 | echo "seed is ${seed}:"
14 | CUDA_VISIBLE_DEVICES=2 python train/train_gridworld.py \
15 | --env_name ${env} --algorithm_name ${algo} --experiment_name ${exp} --scenario_name ${scenario} \
16 | --log_interval 1 --wandb_name "mapping" --user_name "yang-xy20" --num_agents ${num_agents} \
17 | --num_obstacles ${num_obstacles} --cnn_layers_params '16,7,2,1 32,5,2,1 16,3,1,1' --hidden_size 64 --seed ${seed} --n_training_threads 1 \
18 | --n_rollout_threads 50 --num_mini_batch 1 --num_env_steps 80000000 --ppo_epoch 3 --gain 0.01 \
19 | --lr 5e-4 --critic_lr 5e-4 --max_steps 150 --use_complete_reward --agent_view_size 7 --local_step_num 5 --use_random_pos \
20 | --astar_cost_mode normal --cnn_trans_layer 1,3,1,1 --grid_size 25 --use_recurrent_policy \
21 | --use_global_goal --use_overlap_penalty --use_stack --goal_grid_size 5 --use_discrect --asynch
22 | done
--------------------------------------------------------------------------------
/onpolicy/utils/RRT/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/utils/RRT/__init__.py
--------------------------------------------------------------------------------
/onpolicy/utils/RRT/rrt_with_pathsmoothing.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | Path planning Sample Code with RRT with path smoothing
4 |
5 | @author: AtsushiSakai(@Atsushi_twi)
6 |
7 | """
8 |
9 | import math
10 | import os
11 | import random
12 | import sys
13 |
14 | import matplotlib.pyplot as plt
15 |
16 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
17 |
18 | try:
19 | from rrt import RRT
20 | except ImportError:
21 | raise
22 |
23 | show_animation = True
24 |
25 |
26 | def get_path_length(path):
27 | le = 0
28 | for i in range(len(path) - 1):
29 | dx = path[i + 1][0] - path[i][0]
30 | dy = path[i + 1][1] - path[i][1]
31 | d = math.sqrt(dx * dx + dy * dy)
32 | le += d
33 |
34 | return le
35 |
36 |
37 | def get_target_point(path, targetL):
38 | le = 0
39 | ti = 0
40 | lastPairLen = 0
41 | for i in range(len(path) - 1):
42 | dx = path[i + 1][0] - path[i][0]
43 | dy = path[i + 1][1] - path[i][1]
44 | d = math.sqrt(dx * dx + dy * dy)
45 | le += d
46 | if le >= targetL:
47 | ti = i - 1
48 | lastPairLen = d
49 | break
50 |
51 | partRatio = (le - targetL) / lastPairLen
52 |
53 | x = path[ti][0] + (path[ti + 1][0] - path[ti][0]) * partRatio
54 | y = path[ti][1] + (path[ti + 1][1] - path[ti][1]) * partRatio
55 |
56 | return [x, y, ti]
57 |
58 |
59 | def line_collision_check(first, second, obstacleList):
60 | # Line Equation
61 |
62 | x1 = first[0]
63 | y1 = first[1]
64 | x2 = second[0]
65 | y2 = second[1]
66 |
67 | try:
68 | a = y2 - y1
69 | b = -(x2 - x1)
70 | c = y2 * (x2 - x1) - x2 * (y2 - y1)
71 | except ZeroDivisionError:
72 | return False
73 |
74 | for (ox, oy, size) in obstacleList:
75 | d = abs(a * ox + b * oy + c) / (math.sqrt(a * a + b * b))
76 | if d <= size:
77 | return False
78 |
79 | return True # OK
80 |
81 |
82 | def path_smoothing(path, max_iter, obstacle_list):
83 | le = get_path_length(path)
84 |
85 | for i in range(max_iter):
86 | # Sample two points
87 | pickPoints = [random.uniform(0, le), random.uniform(0, le)]
88 | pickPoints.sort()
89 | first = get_target_point(path, pickPoints[0])
90 | second = get_target_point(path, pickPoints[1])
91 |
92 | if first[2] <= 0 or second[2] <= 0:
93 | continue
94 |
95 | if (second[2] + 1) > len(path):
96 | continue
97 |
98 | if second[2] == first[2]:
99 | continue
100 |
101 | # collision check
102 | if not line_collision_check(first, second, obstacle_list):
103 | continue
104 |
105 | # Create New path
106 | newPath = []
107 | newPath.extend(path[:first[2] + 1])
108 | newPath.append([first[0], first[1]])
109 | newPath.append([second[0], second[1]])
110 | newPath.extend(path[second[2] + 1:])
111 | path = newPath
112 | le = get_path_length(path)
113 |
114 | return path
115 |
116 |
117 | def main():
118 | # ====Search Path with RRT====
119 | # Parameter
120 | obstacleList = [
121 | (5, 5, 1),
122 | (3, 6, 2),
123 | (3, 8, 2),
124 | (3, 10, 2),
125 | (7, 5, 2),
126 | (9, 5, 2)
127 | ] # [x,y,size]
128 | rrt = RRT(start=[0, 0], goal=[6, 10],
129 | rand_area=[-2, 15], obstacle_list=obstacleList)
130 | path = rrt.planning(animation=show_animation)
131 |
132 | # Path smoothing
133 | maxIter = 1000
134 | smoothedPath = path_smoothing(path, maxIter, obstacleList)
135 |
136 | # Draw final path
137 | if show_animation:
138 | rrt.draw_graph()
139 | plt.plot([x for (x, y) in path], [y for (x, y) in path], '-r')
140 |
141 | plt.plot([x for (x, y) in smoothedPath], [
142 | y for (x, y) in smoothedPath], '-c')
143 |
144 | plt.grid(True)
145 | plt.pause(0.01) # Need for Mac
146 | plt.show()
147 |
148 |
149 | if __name__ == '__main__':
150 | main()
151 |
--------------------------------------------------------------------------------
/onpolicy/utils/RRT/sobol/__init__.py:
--------------------------------------------------------------------------------
1 | from .sobol import i4_sobol as sobol_quasirand
2 |
--------------------------------------------------------------------------------
/onpolicy/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/utils/__init__.py
--------------------------------------------------------------------------------
/onpolicy/utils/multi_discrete.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 |
4 | # An old version of OpenAI Gym's multi_discrete.py. (Was getting affected by Gym updates)
5 | # (https://github.com/openai/gym/blob/1fb81d4e3fb780ccf77fec731287ba07da35eb84/gym/spaces/multi_discrete.py)
6 | class MultiDiscrete(gym.Space):
7 | """
8 | - The multi-discrete action space consists of a series of discrete action spaces with different parameters
9 | - It can be adapted to both a Discrete action space or a continuous (Box) action space
10 | - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space
11 | - It is parametrized by passing an array of arrays containing [min, max] for each discrete action space where the discrete action space can take any integers from `min` to `max` (both inclusive)
12 | Note: A value of 0 always needs to represent the NOOP action.
13 | e.g. Nintendo Game Controller
14 | - Can be conceptualized as 3 discrete action spaces:
15 | 1) Arrow Keys: Discrete 5 - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4] - params: min: 0, max: 4
16 | 2) Button A: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1
17 | 3) Button B: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1
18 | - Can be initialized as
19 | MultiDiscrete([ [0,4], [0,1], [0,1] ])
20 | """
21 |
22 | def __init__(self, array_of_param_array):
23 | self.low = np.array([x[0] for x in array_of_param_array])
24 | self.high = np.array([x[1] for x in array_of_param_array])
25 | self.num_discrete_space = self.low.shape[0]
26 | self.n = np.sum(self.high) + 2
27 |
28 | def sample(self):
29 | """ Returns a array with one sample from each discrete action space """
30 | # For each row: round(random .* (max - min) + min, 0)
31 | random_array = np.random.rand(self.num_discrete_space)
32 | return [int(x) for x in np.floor(np.multiply((self.high - self.low + 1.), random_array) + self.low)]
33 |
34 | def contains(self, x):
35 | return len(x) == self.num_discrete_space and (np.array(x) >= self.low).all() and (np.array(x) <= self.high).all()
36 |
37 | @property
38 | def shape(self):
39 | return self.num_discrete_space
40 |
41 | def __repr__(self):
42 | return "MultiDiscrete" + str(self.num_discrete_space)
43 |
44 | def __eq__(self, other):
45 | return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high)
46 |
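A quick sanity check of the space described in the docstring above (a usage sketch, not repo code):

    from onpolicy.utils.multi_discrete import MultiDiscrete

    space = MultiDiscrete([[0, 4], [0, 1], [0, 1]])  # arrow keys, button A, button B
    action = space.sample()                          # e.g. [3, 0, 1]
    assert space.contains(action)
    assert space.shape == 3                          # number of discrete sub-spaces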
--------------------------------------------------------------------------------
/onpolicy/utils/util.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import numpy as np
4 | import math
5 | import torch
6 | import torch.nn as nn
7 |
8 | class AsynchControl:
9 | def __init__(self, num_envs, num_agents, limit, random_fn, min_length, max_length):
10 | self.num_envs = num_envs
11 | self.num_agents = num_agents
12 | self.limit = limit
13 | self.random_fn = random_fn
14 | self.min_length = min_length
15 | self.max_length = max_length
16 |
17 | self.reset()
18 |
19 | def reset(self):
20 | self.cnt = np.zeros((self.num_envs, self.num_agents), dtype=np.int32)
21 | self.rest = np.zeros((self.num_envs, self.num_agents), dtype=np.int32)
22 | self.active = np.ones((self.num_envs, self.num_agents), dtype=np.int32)
23 | for e in range(self.num_envs):
24 | for a in range(self.num_agents):
25 | self.rest[e, a] = self.random_fn() # the first interval comes straight from random_fn, not clamped to [min_length, max_length]
26 |
27 | def step(self):
28 | for e in range(self.num_envs):
29 | for a in range(self.num_agents):
30 | self.rest[e, a] -= 1
31 | self.active[e, a] = 0
32 | if self.rest[e, a] <= 0:
33 | if self.cnt[e, a] < self.limit:
34 | self.cnt[e, a] += 1
35 | self.active[e, a] = 1
36 | self.rest[e, a] = min(max(self.random_fn(), self.min_length), self.max_length)
37 |
38 | def active_agents(self):
39 | ret = []
40 | for e in range(self.num_envs):
41 | for a in range(self.num_agents):
42 | if self.active[e, a]:
43 | ret.append((e, a, self.cnt[e, a]))
44 | return ret
45 |
46 | def check(input):
47 | # convert numpy arrays to tensors; pass tensors and other inputs through unchanged
48 | return torch.from_numpy(input) if type(input) == np.ndarray else input
49 |
50 | def get_gard_norm(it):
51 | sum_grad = 0
52 | for x in it:
53 | if x.grad is None:
54 | continue
55 | sum_grad += x.grad.norm() ** 2
56 | return math.sqrt(sum_grad)
57 |
58 | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
59 | """Decreases the learning rate linearly"""
60 | lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
61 | for param_group in optimizer.param_groups:
62 | param_group['lr'] = lr
63 |
64 | def huber_loss(e, d):
65 | a = (abs(e) <= d).float()
66 | b = (e > d).float()
67 | return a*e**2/2 + b*d*(abs(e)-d/2)
68 |
69 | def mse_loss(e):
70 | return e**2/2
71 |
72 | def get_shape_from_obs_space(obs_space):
73 | if obs_space.__class__.__name__ == 'Box':
74 | obs_shape = obs_space.shape
75 | elif obs_space.__class__.__name__ == 'list':
76 | obs_shape = obs_space
77 | elif obs_space.__class__.__name__ == 'Dict':
78 | obs_shape = obs_space.spaces
79 | else:
80 | raise NotImplementedError
81 | return obs_shape
82 |
83 | def get_shape_from_act_space(act_space):
84 | if act_space.__class__.__name__ == 'Discrete':
85 | act_shape = 1
86 | elif act_space.__class__.__name__ == "MultiDiscrete":
87 | act_shape = act_space.shape
88 | elif act_space.__class__.__name__ == "Box":
89 | act_shape = act_space.shape[0]
90 | elif act_space.__class__.__name__ == "MultiBinary":
91 | act_shape = act_space.shape[0]
92 | else: # agar
93 | act_shape = act_space[0].shape[0] + 1
94 | return act_shape
95 |
96 |
97 | def tile_images(img_nhwc):
98 | """
99 | Tile N images into one big PxQ image
100 | (P,Q) are chosen to be as close as possible, and if N
101 | is square, then P=Q.
102 | input: img_nhwc, list or array of images, ndim=4 once turned into array
103 | n = batch index, h = height, w = width, c = channel
104 | returns:
105 | bigim_HWc, ndarray with ndim=3
106 | """
107 | img_nhwc = np.asarray(img_nhwc)
108 | N, h, w, c = img_nhwc.shape
109 | H = int(np.ceil(np.sqrt(N)))
110 | W = int(np.ceil(float(N)/H))
111 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)])
112 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c)
113 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4)
114 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c)
115 | return img_Hh_Ww_c
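AsynchControl is what drives the asynchronous rollout: each (env, agent) pair gets its own countdown, and only pairs whose countdown has expired act at a given step. A minimal sketch with illustrative parameters:

    import numpy as np
    from onpolicy.utils.util import AsynchControl

    ctrl = AsynchControl(num_envs=2, num_agents=3, limit=100,
                         random_fn=lambda: np.random.randint(1, 6),
                         min_length=1, max_length=5)

    for t in range(10):
        ctrl.step()
        for e, a, cnt in ctrl.active_agents():
            # only these (env, agent) pairs would pick a new goal/action at step t
            print(f"t={t}: env {e}, agent {a} acts (decision #{cnt})")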
--------------------------------------------------------------------------------
/onpolicy/utils/valuenorm.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class ValueNorm(nn.Module):
9 | """ Normalize a vector of observations - across the first norm_axes dimensions"""
10 |
11 | def __init__(self, input_shape, norm_axes=1, beta=0.99999, per_element_update=False, epsilon=1e-5, device=torch.device("cpu")):
12 | super(ValueNorm, self).__init__()
13 |
14 | self.input_shape = input_shape
15 | self.norm_axes = norm_axes
16 | self.epsilon = epsilon
17 | self.beta = beta
18 | self.per_element_update = per_element_update
19 | self.tpdv = dict(dtype=torch.float32, device=device)
20 |
21 | self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv)
22 | self.running_mean_sq = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv)
23 | self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False).to(**self.tpdv)
24 |
25 | self.reset_parameters()
26 |
27 | def reset_parameters(self):
28 | self.running_mean.zero_()
29 | self.running_mean_sq.zero_()
30 | self.debiasing_term.zero_()
31 |
32 | def running_mean_var(self):
33 | debiased_mean = self.running_mean / self.debiasing_term.clamp(min=self.epsilon)
34 | debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(min=self.epsilon)
35 | debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(min=1e-2)
36 | return debiased_mean, debiased_var
37 |
38 | @torch.no_grad()
39 | def update(self, input_vector):
40 | if type(input_vector) == np.ndarray:
41 | input_vector = torch.from_numpy(input_vector)
42 | input_vector = input_vector.to(**self.tpdv)
43 |
44 | batch_mean = input_vector.mean(dim=tuple(range(self.norm_axes)))
45 | batch_sq_mean = (input_vector ** 2).mean(dim=tuple(range(self.norm_axes)))
46 |
47 | if self.per_element_update:
48 | batch_size = np.prod(input_vector.size()[:self.norm_axes])
49 | weight = self.beta ** batch_size
50 | else:
51 | weight = self.beta
52 |
53 | self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight))
54 | self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight))
55 | self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight))
56 |
57 | def normalize(self, input_vector):
58 | # Make sure input is float32
59 | if type(input_vector) == np.ndarray:
60 | input_vector = torch.from_numpy(input_vector)
61 | input_vector = input_vector.to(**self.tpdv)
62 |
63 | mean, var = self.running_mean_var()
64 | out = (input_vector - mean[(None,) * self.norm_axes]) / torch.sqrt(var)[(None,) * self.norm_axes]
65 |
66 | return out
67 |
68 | def denormalize(self, input_vector):
69 | """ Transform normalized data back into original distribution """
70 | if type(input_vector) == np.ndarray:
71 | input_vector = torch.from_numpy(input_vector)
72 | input_vector = input_vector.to(**self.tpdv)
73 |
74 | mean, var = self.running_mean_var()
75 | out = input_vector * torch.sqrt(var)[(None,) * self.norm_axes] + mean[(None,) * self.norm_axes]
76 |
77 | out = out.cpu().numpy()
78 |
79 | return out
80 |
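A minimal usage sketch: update the running statistics with return targets, train the critic against normalized targets, and denormalize predictions when they are needed back on the original scale (shapes and values below are illustrative):

    import numpy as np
    import torch
    from onpolicy.utils.valuenorm import ValueNorm

    value_normalizer = ValueNorm(input_shape=1)

    returns = np.random.randn(128, 1).astype(np.float32)  # illustrative return targets
    value_normalizer.update(returns)                      # update running mean / mean-square
    normalized_targets = value_normalizer.normalize(returns)

    values = torch.randn(128, 1)                          # critic outputs on the normalized scale
    values_original_scale = value_normalizer.denormalize(values)  # back to raw returns (numpy)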
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.9.0
2 | aiohttp==3.6.2
3 | aioredis==1.3.1
4 | astor==0.8.0
5 | astunparse==1.6.3
6 | async-timeout==3.0.1
7 | atari-py==0.2.6
8 | atomicwrites==1.2.1
9 | attrs==18.2.0
10 | beautifulsoup4==4.9.1
11 | blessings==1.7
12 | cachetools==4.1.1
13 | certifi==2020.4.5.2
14 | cffi==1.14.1
15 | chardet==3.0.4
16 | click==7.1.2
17 | cloudpickle==1.3.0
18 | colorama==0.4.3
19 | colorful==0.5.4
20 | configparser==5.0.1
21 | contextvars==2.4
22 | cycler==0.10.0
23 | Cython==0.29.21
24 | deepdiff==4.3.2
25 | dill==0.3.2
26 | docker-pycreds==0.4.0
27 | docopt==0.6.2
28 | fasteners==0.15
29 | filelock==3.0.12
30 | funcsigs==1.0.2
31 | future==0.16.0
32 | gast==0.2.2
33 | gin==0.1.6
34 | gin-config==0.3.0
35 | gitdb==4.0.5
36 | GitPython==3.1.9
37 | glfw==1.12.0
38 | google==3.0.0
39 | google-api-core==1.22.1
40 | google-auth==1.21.0
41 | google-auth-oauthlib==0.4.1
42 | google-pasta==0.2.0
43 | googleapis-common-protos==1.52.0
44 | gpustat==0.6.0
45 | gql==0.2.0
46 | graphql-core==1.1
47 | grpcio==1.31.0
48 | gym==0.17.2
49 | h5py==2.10.0
50 | hiredis==1.1.0
51 | idna==2.7
52 | idna-ssl==1.1.0
53 | imageio==2.4.1
54 | immutables==0.14
55 | importlib-metadata==1.7.0
56 | joblib==0.16.0
57 | jsonnet==0.16.0
58 | jsonpickle==0.9.6
59 | jsonschema==3.2.0
60 | Keras-Applications==1.0.8
61 | Keras-Preprocessing==1.1.2
62 | kiwisolver==1.0.1
63 | lockfile==0.12.2
64 | Markdown==3.1.1
65 | matplotlib==3.0.0
66 | mkl-fft==1.1.0
67 | mkl-random==1.1.1
68 | mkl-service==2.3.0
69 | mock==2.0.0
70 | monotonic==1.5
71 | more-itertools==4.3.0
72 | mpi4py==3.0.3
73 | mpyq==0.2.5
74 | msgpack==1.0.0
75 | mujoco-py==2.0.2.13
76 | multidict==4.7.6
77 | munch==2.3.2
78 | numpy==1.18.5
79 | nvidia-ml-py3==7.352.0
80 | oauthlib==3.1.0
81 | opencensus==0.7.10
82 | opencensus-context==0.1.1
83 | opencv-python==4.2.0.34
84 | opt-einsum==3.1.0
85 | ordered-set==4.0.2
86 | packaging==20.4
87 | pandas==1.1.1
88 | pathlib2==2.3.2
89 | pathtools==0.1.2
90 | pbr==4.3.0
91 | Pillow==5.3.0
92 | pluggy==0.7.1
93 | portpicker==1.2.0
94 | probscale==0.2.3
95 | progressbar2==3.53.1
96 | prometheus-client==0.8.0
97 | promise==2.3
98 | protobuf==3.12.3
99 | psutil==5.7.2
100 | py==1.6.0
101 | py-spy==0.3.3
102 | pyasn1==0.4.8
103 | pyasn1-modules==0.2.8
104 | pycparser==2.20
105 | pygame==1.9.4
106 | pyglet==1.5.0
107 | PyOpenGL==3.1.5
108 | PyOpenGL-accelerate==3.1.5
109 | pyparsing==2.2.2
110 | pyrsistent==0.16.0
111 | PySC2==3.0.0
112 | pytest==3.8.2
113 | python-dateutil==2.7.3
114 | python-utils==2.4.0
115 | pytz==2020.1
116 | PyYAML==3.13
117 | pyzmq==19.0.2
118 | ray==0.8.0
119 | redis==3.4.1
120 | requests==2.24.0
121 | requests-oauthlib==1.3.0
122 | rsa==4.6
123 | s2clientprotocol==4.10.1.75800.0
124 | s2protocol==4.11.4.78285.0
125 | sacred==0.7.2
126 | scipy==1.4.1
127 | seaborn==0.10.1
128 | sentry-sdk==0.18.0
129 | setproctitle==1.1.10
130 | shortuuid==1.0.1
131 | six==1.15.0
132 | sk-video==1.1.10
133 | smmap==3.0.4
134 | snakeviz==1.0.0
135 | soupsieve==2.0.1
136 | subprocess32==3.5.4
137 | tabulate==0.8.7
138 | tensorboard==2.0.2
139 | tensorboard-logger==0.1.0
140 | tensorboard-plugin-wit==1.7.0
141 | tensorboardX==2.0
142 | tensorflow==2.0.0
143 | tensorflow-estimator==2.0.0
144 | termcolor==1.1.0
145 | torch==1.5.1+cu101
146 | torchvision==0.6.1+cu101
147 | tornado==5.1.1
148 | tqdm==4.48.2
149 | typing-extensions==3.7.4.3
150 | urllib3==1.23
151 | wandb==0.10.5
152 | watchdog==0.10.3
153 | websocket-client==0.53.0
154 | Werkzeug==0.16.1
155 | whichcraft==0.5.2
156 | wrapt==1.12.1
157 | xmltodict==0.12.0
158 | yarl==1.5.1
159 | zipp==3.1.0
160 | zmq==0.0.0
161 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | from setuptools import setup, find_packages
6 | import setuptools
7 |
8 | def get_version() -> str:
9 | # https://packaging.python.org/guides/single-sourcing-package-version/
10 | init = open(os.path.join("onpolicy", "__init__.py"), "r").read().split()
11 | return init[init.index("__version__") + 2][1:-1]
12 |
13 | setup(
14 | name="onpolicy", # Replace with your own username
15 | version=get_version(),
16 | description="on-policy algorithms of marlbenchmark",
17 | long_description=open("README.md", encoding="utf8").read(),
18 | long_description_content_type="text/markdown",
19 | author="yuchao",
20 | author_email="zoeyuchao@gmail.com",
21 | packages=setuptools.find_packages(),
22 | classifiers=[
23 | "Development Status :: 3 - Alpha",
24 | "Intended Audience :: Science/Research",
25 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
26 | "Topic :: Software Development :: Libraries :: Python Modules",
27 | "Programming Language :: Python :: 3",
28 | "License :: OSI Approved :: MIT License",
29 | "Operating System :: OS Independent",
30 | ],
31 | keywords="multi-agent reinforcement learning platform pytorch",
32 | python_requires='>=3.6',
33 | )
34 |
--------------------------------------------------------------------------------